volatility.py•13.7 kB
from __future__ import annotations
import datetime as _dt
import math
from typing import Annotated, Any, Callable, Mapping, cast
import numpy as np
import pandas as pd
from mcp.server.fastmcp import FastMCP
from schwab_mcp.context import SchwabContext
from schwab_mcp.tools._registration import register_tool
from schwab_mcp.tools.utils import JSONType, call
from .base import ensure_columns, fetch_price_frame
__all__ = ["register"]
_WEEK_DAYS = 5
_MONTH_DAYS = 21
def _volatility_regime(annualized_pct: float) -> str:
    if annualized_pct < 10:
        return "very_low"
    if annualized_pct < 15:
        return "low"
    if annualized_pct < 20:
        return "normal"
    if annualized_pct < 30:
        return "elevated"
    if annualized_pct < 50:
        return "high"
    return "extreme"
def _compute_percentile(vol_series: pd.Series, latest: float) -> float:
    if vol_series.empty:
        return 50.0
    below = (vol_series < latest).sum()
    return float(below / len(vol_series) * 100.0)
def _round(value: float, digits: int = 2) -> float:
    return float(round(value, digits))
async def historical_volatility(
    ctx: SchwabContext,
    symbol: Annotated[str, "Symbol of the security"],
    period: Annotated[int, "Rolling window size for volatility"] = 20,
    interval: Annotated[
        str,
        ("Price interval. Supported values: 1m, 5m, 10m, 15m, 30m, 1d, 1w."),
    ] = "1d",
    start: Annotated[
        str | None,
        (
            "Optional ISO-8601 timestamp for the first candle used in the calculation. "
            "Defaults to enough history based on the requested period."
        ),
    ] = None,
    end: Annotated[
        str | None,
        "Optional ISO-8601 timestamp for the final candle (defaults to now in UTC).",
    ] = None,
    bars: Annotated[
        int | None,
        (
            "Override the number of candles fetched. Defaults to fetching a padded "
            "window sized for the requested period."
        ),
    ] = None,
    annualize_factor: Annotated[
        int,
        (
            "Trading sessions per year used for annualization (default 252 for US equities)."
        ),
    ] = 252,
    method: Annotated[
        str,
        (
            "Volatility calculation method: close_to_close (default), log_returns, or parkinson."
        ),
    ] = "close_to_close",
) -> JSONType:
    """Compute historical volatility statistics for Schwab price history."""
    if period <= 1:
        raise ValueError("period must be greater than 1")
    if annualize_factor <= 0:
        raise ValueError("annualize_factor must be positive")
    method_key = method.strip().lower()
    valid_methods = {"close_to_close", "log_returns", "parkinson"}
    if method_key not in valid_methods:
        raise ValueError(
            "Invalid method. Choose from close_to_close, log_returns, or parkinson."
        )
    required_points = period + 1 if method_key != "parkinson" else period
    padding = max(period // 2, 10)
    window = max(required_points + padding, period * 2)
    if bars is not None:
        window = max(bars, required_points)
    frame, metadata = await fetch_price_frame(
        ctx,
        symbol,
        interval=interval,
        start=start,
        end=end,
        bars=window,
    )
    if frame.empty:
        raise ValueError("Price history request returned no data.")
    if method_key == "parkinson":
        ensure_columns(frame, ("high", "low"))
        working = frame[["high", "low"]].dropna()
        if len(working) < required_points:
            raise ValueError(
                "Not enough high/low data to compute Parkinson volatility for the requested period."
            )
        hl_ratio = np.log(working["high"] / working["low"])
        hl_ratio_sq = hl_ratio.pow(2)
        rolling_sum = hl_ratio_sq.rolling(window=period, min_periods=period).sum()
        vol_series = (rolling_sum / (period * 4.0 * math.log(2.0))).pow(0.5)
        vol_series = cast(pd.Series, vol_series)
    else:
        ensure_columns(frame, ("close",))
        closes = frame["close"].dropna()
        if len(closes) < required_points:
            raise ValueError(
                "Not enough closing prices to compute historical volatility for the requested period."
            )
        if method_key == "log_returns":
            returns = np.log(closes / closes.shift(1))
        else:
            returns = closes.pct_change()
        returns = returns.dropna()
        if len(returns) < period:
            raise ValueError(
                "Not enough return values to compute historical volatility for the requested period."
            )
        vol_series = returns.rolling(window=period, min_periods=period).std()
        vol_series = cast(pd.Series, vol_series)
    vol_series = vol_series.dropna()
    if vol_series.empty:
        raise ValueError(
            "Unable to compute historical volatility with the provided inputs."
        )
    daily_vol = float(vol_series.iloc[-1])
    daily_vol_pct = daily_vol * 100.0
    weekly_vol_pct = daily_vol * math.sqrt(_WEEK_DAYS) * 100.0
    monthly_vol_pct = daily_vol * math.sqrt(_MONTH_DAYS) * 100.0
    annualized_vol_pct = daily_vol * math.sqrt(annualize_factor) * 100.0
    percentile_rank = _compute_percentile(vol_series, daily_vol)
    if vol_series.empty:
        min_vol = max_vol = mean_vol = 0.0
    else:
        scaled = vol_series * math.sqrt(annualize_factor) * 100.0
        min_vol = float(scaled.min())
        max_vol = float(scaled.max())
        mean_vol = float(scaled.mean())
    return {
        "symbol": metadata["symbol"],
        "interval": metadata["interval"],
        "start": metadata["start"],
        "end": metadata["end"],
        "candles": metadata["candles_returned"],
        "period": period,
        "annualize_factor": annualize_factor,
        "method": method_key,
        "daily_vol": _round(daily_vol_pct),
        "weekly_vol": _round(weekly_vol_pct),
        "monthly_vol": _round(monthly_vol_pct),
        "annualized_vol": _round(annualized_vol_pct),
        "percentile_rank": _round(percentile_rank, 1),
        "regime": _volatility_regime(annualized_vol_pct),
        "min_vol": _round(min_vol),
        "max_vol": _round(max_vol),
        "mean_vol": _round(mean_vol),
    }
async def expected_move(
    ctx: SchwabContext,
    symbol: Annotated[str, "Symbol of the underlying security"],
    call_price: Annotated[float | None, "At-the-money call option premium"] = None,
    put_price: Annotated[float | None, "At-the-money put option premium"] = None,
    interval: Annotated[
        str,
        (
            "Price interval used when fetching underlying data if needed. "
            "Supported values: 1m, 5m, 10m, 15m, 30m, 1d, 1w."
        ),
    ] = "1d",
    underlying_price: Annotated[
        float | None,
        (
            "Optional underlying price to use for the calculation. If omitted, the tool "
            "fetches the most recent close for the requested interval."
        ),
    ] = None,
    multiplier: Annotated[
        float,
        (
            "Statistical adjustment multiplier applied to the raw straddle. "
            "Defaults to 0.85 to approximate a one standard deviation move."
        ),
    ] = 0.85,
) -> JSONType:
    """Calculate the option-priced ±1 standard deviation move."""
    chain: Mapping[str, Any] | None = None
    metadata: dict[str, Any] | None = None
    if call_price is not None and call_price <= 0:
        raise ValueError("call_price must be a positive value")
    if put_price is not None and put_price <= 0:
        raise ValueError("put_price must be a positive value")
    if multiplier <= 0:
        raise ValueError("multiplier must be a positive value")
    underlying = underlying_price
    needs_chain = call_price is None or put_price is None
    if needs_chain:
        chain = await _fetch_option_chain(ctx, symbol)
        chain_underlying = chain.get("underlyingPrice") if chain else None
        if underlying is None and chain_underlying is not None:
            underlying = float(chain_underlying)
    if underlying is None:
        frame, metadata = await fetch_price_frame(
            ctx,
            symbol,
            interval=interval,
            bars=1,
        )
        if frame.empty or "close" not in frame.columns:
            raise ValueError(
                "Unable to determine the underlying price from price history."
            )
        underlying = float(frame["close"].iloc[-1])
    if underlying <= 0:
        raise ValueError("underlying_price must be positive")
    if needs_chain:
        call_contract, put_contract = _select_atm_contracts(chain, underlying)
        if call_price is None:
            call_price = _option_price(call_contract)
        if put_price is None:
            put_price = _option_price(put_contract)
    if call_price is None or put_price is None:
        raise ValueError("Unable to determine ATM call and put premiums.")
    if call_price <= 0 or put_price <= 0:
        raise ValueError("call_price and put_price must be positive values")
    straddle_price = float(call_price) + float(put_price)
    move_percent = straddle_price / float(underlying)
    adjusted_move = straddle_price * float(multiplier)
    adjusted_move_percent = adjusted_move / float(underlying)
    boundaries = {
        "upper_1x": float(underlying) + adjusted_move,
        "lower_1x": float(underlying) - adjusted_move,
        "upper_2x": float(underlying) + (adjusted_move * 2.0),
        "lower_2x": float(underlying) - (adjusted_move * 2.0),
    }
    response: dict[str, JSONType] = {
        "symbol": symbol.upper(),
        "call_price": float(call_price),
        "put_price": float(put_price),
        "underlying_price": float(underlying),
        "expected_move": straddle_price,
        "expected_move_percent": move_percent,
        "multiplier": float(multiplier),
        "adjusted_move": adjusted_move,
        "adjusted_move_percent": adjusted_move_percent,
        "boundaries": boundaries,
    }
    if metadata is not None:
        response.update(
            {
                "interval": metadata["interval"],
                "start": metadata["start"],
                "end": metadata["end"],
                "candles": metadata["candles_returned"],
            }
        )
    else:
        response["interval"] = interval
    return response
def register(
    server: FastMCP,
    *,
    allow_write: bool,
    result_transform: Callable[[Any], Any] | None = None,
) -> None:
    _ = allow_write
    register_tool(server, expected_move, result_transform=result_transform)
    register_tool(server, historical_volatility, result_transform=result_transform)
async def _fetch_option_chain(ctx: SchwabContext, symbol: str) -> Mapping[str, Any]:
    response = await call(
        ctx.options.get_option_chain,
        symbol,
        contract_type=None,
        strike_count=10,
        include_underlying_quote=True,
    )
    if not isinstance(response, Mapping):
        raise TypeError("Unexpected option chain response type")
    return cast(Mapping[str, Any], response)
def _select_atm_contracts(
    chain: Mapping[str, Any] | None, underlying: float
) -> tuple[Mapping[str, Any], Mapping[str, Any]]:
    if not chain:
        raise ValueError("Option chain response missing")
    call_map = chain.get("callExpDateMap") or {}
    put_map = chain.get("putExpDateMap") or {}
    best: tuple[float, _dt.date, float, Mapping[str, Any], Mapping[str, Any]] | None = (
        None
    )
    for exp_key, strikes in call_map.items():
        exp_date = _parse_expiration(exp_key)
        for strike_key, contracts in strikes.items():
            if not contracts:
                continue
            strike = _to_float(strike_key)
            call_contract = contracts[0]
            put_contract = _get_contract(put_map, exp_key, strike_key)
            if put_contract is None:
                continue
            diff = abs(strike - underlying)
            if best is None or (diff, exp_date, strike) < (best[0], best[1], best[2]):
                best = (diff, exp_date, strike, call_contract, put_contract)
    if best is None:
        raise ValueError("Unable to locate at-the-money call and put contracts.")
    return best[3], best[4]
def _get_contract(
    exp_map: Mapping[str, Any], exp_key: str, strike_key: str
) -> Mapping[str, Any] | None:
    strikes = exp_map.get(exp_key)
    if not strikes:
        return None
    contracts = strikes.get(strike_key)
    if not contracts:
        return None
    return contracts[0]
def _option_price(contract: Mapping[str, Any]) -> float:
    for key in ("mark", "markPrice", "mark_price"):
        value = contract.get(key)
        if _is_positive_number(value):
            return _to_float(value)
    bid = contract.get("bid")
    ask = contract.get("ask")
    if _is_positive_number(bid) and _is_positive_number(ask):
        return (_to_float(bid) + _to_float(ask)) / 2.0
    for key in ("last", "lastPrice", "closePrice"):
        value = contract.get(key)
        if _is_positive_number(value):
            return _to_float(value)
    raise ValueError("Option contract missing price information")
def _parse_expiration(value: str) -> _dt.date:
    date_part = value.split(":", 1)[0]
    return _dt.date.fromisoformat(date_part)
def _to_float(value: Any) -> float:
    if isinstance(value, (int, float)):
        return float(value)
    return float(str(value))
def _is_positive_number(value: Any) -> bool:
    try:
        return value is not None and float(value) > 0
    except (TypeError, ValueError):
        return False