"""Shared pandas helpers for Schwab price-history tooling.

Covers interval normalization, fetching OHLCV candles into DataFrames, and
serializing pandas objects into JSON-friendly rows.
"""

from __future__ import annotations

import datetime as _dt
from dataclasses import dataclass
from typing import Any, Iterable, Mapping, cast

import pandas as pd

from schwab_mcp.context import SchwabContext
from schwab_mcp.tools.utils import JSONType, call

from . import pandas_ta as _pandas_ta

__all__ = [
"normalize_interval",
"fetch_price_frame",
"series_to_json",
"frame_to_json",
"ensure_columns",
"pandas_ta",
]


@dataclass(frozen=True)
class _IntervalConfig:
    """Pairing of a schwab-py price-history method name with its bar duration."""

    method_name: str
    bar_size: _dt.timedelta


_INTERVAL_CONFIGS: dict[str, _IntervalConfig] = {
"1m": _IntervalConfig(
method_name="get_price_history_every_minute",
bar_size=_dt.timedelta(minutes=1),
),
"5m": _IntervalConfig(
method_name="get_price_history_every_five_minutes",
bar_size=_dt.timedelta(minutes=5),
),
"10m": _IntervalConfig(
method_name="get_price_history_every_ten_minutes",
bar_size=_dt.timedelta(minutes=10),
),
"15m": _IntervalConfig(
method_name="get_price_history_every_fifteen_minutes",
bar_size=_dt.timedelta(minutes=15),
),
"30m": _IntervalConfig(
method_name="get_price_history_every_thirty_minutes",
bar_size=_dt.timedelta(minutes=30),
),
"1d": _IntervalConfig(
method_name="get_price_history_every_day",
bar_size=_dt.timedelta(days=1),
),
"1w": _IntervalConfig(
method_name="get_price_history_every_week",
bar_size=_dt.timedelta(days=7),
),
}
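
# Illustration of how the registry resolves (both values follow directly from
# the entries above):
#
#     >>> _INTERVAL_CONFIGS["15m"].method_name
#     'get_price_history_every_fifteen_minutes'
#     >>> _INTERVAL_CONFIGS["1w"].bar_size
#     datetime.timedelta(days=7)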


def normalize_interval(value: str) -> str:
    """Return the canonical short form (e.g., 1d, 15m) for the supplied interval."""
normalized = value.strip().lower()
if normalized in _INTERVAL_CONFIGS:
return normalized
raise ValueError(
f"Unsupported interval '{value}'. "
f"Choose from: {', '.join(sorted(_INTERVAL_CONFIGS))}"
)
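
# Example: lookups are case-insensitive and whitespace-tolerant, and unknown
# intervals raise with the supported choices listed.
#
#     >>> normalize_interval(" 15M ")
#     '15m'
#     >>> normalize_interval("2h")
#     Traceback (most recent call last):
#         ...
#     ValueError: Unsupported interval '2h'. Choose from: 10m, 15m, 1d, 1m, 1w, 30m, 5m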


def _add_utc_timezone(value: _dt.datetime) -> _dt.datetime:
    """Interpret naive datetimes as UTC and convert aware ones to UTC."""
    if value.tzinfo is None:
        return value.replace(tzinfo=_dt.timezone.utc)
    return value.astimezone(_dt.timezone.utc)


def _parse_timestamp(value: str | _dt.datetime | None) -> _dt.datetime | None:
    """Parse an ISO-8601 string or datetime into a timezone-aware UTC datetime."""
    if value is None:
        return None
    if isinstance(value, _dt.datetime):
        return _add_utc_timezone(value)
    return _add_utc_timezone(_dt.datetime.fromisoformat(value))
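
# Example: naive timestamps are treated as UTC rather than local time, so the
# same wall-clock string always resolves to the same instant.
#
#     >>> _parse_timestamp("2024-01-02T09:30:00")
#     datetime.datetime(2024, 1, 2, 9, 30, tzinfo=datetime.timezone.utc)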


def _default_start(
    *, end: _dt.datetime, interval: _IntervalConfig, bars: int | None
) -> _dt.datetime | None:
    """Derive a window start by stepping ``bars`` bar-widths back from ``end``."""
    if bars is None or bars <= 0:
        return None
    return end - (interval.bar_size * bars)
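
# Worked example: five 15-minute bars ending at 16:00 UTC start at 14:45 UTC
# (5 * 15 minutes are stepped back from the end).
#
#     >>> _default_start(
#     ...     end=_dt.datetime(2024, 1, 2, 16, tzinfo=_dt.timezone.utc),
#     ...     interval=_INTERVAL_CONFIGS["15m"],
#     ...     bars=5,
#     ... )
#     datetime.datetime(2024, 1, 2, 14, 45, tzinfo=datetime.timezone.utc)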


def _candles_to_dataframe(candles: Iterable[Mapping[str, Any]]) -> pd.DataFrame:
    """Build an OHLCV DataFrame indexed by UTC timestamps from raw candle dicts."""
    frame = pd.DataFrame.from_records(candles)
    if frame.empty:
        return frame
    if "datetime" in frame.columns:
        # Candle timestamps arrive as epoch milliseconds.
        frame["datetime"] = pd.to_datetime(
            frame["datetime"], unit="ms", utc=True, errors="coerce"
        )
        frame = frame.dropna(subset=["datetime"]).set_index("datetime")
    numeric_columns = [
        column
        for column in ("open", "high", "low", "close", "volume")
        if column in frame.columns
    ]
    if numeric_columns:
        frame[numeric_columns] = frame[numeric_columns].apply(
            pd.to_numeric, errors="coerce"
        )
    return frame.sort_index().dropna(how="all")
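
# Sketch of the expected input shape: candle dicts keyed by "datetime"
# (epoch milliseconds) plus OHLCV fields. The values below are illustrative,
# not real quotes; 1704207600000 ms is 2024-01-02 15:00 UTC.
#
#     >>> _candles_to_dataframe(
#     ...     [{"datetime": 1704207600000, "open": 187.1, "high": 187.6,
#     ...       "low": 186.9, "close": 187.4, "volume": 1200}]
#     ... )
# yields a one-row frame indexed at 2024-01-02 15:00:00+00:00.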


def ensure_columns(frame: pd.DataFrame, columns: Iterable[str]) -> None:
    """Raise ``ValueError`` if any of ``columns`` is missing from ``frame``."""
    missing = [column for column in columns if column not in frame.columns]
    if missing:
        raise ValueError(
            "Price history missing required columns: " + ", ".join(sorted(missing))
        )
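
# Typical guard before computing an indicator (``frame`` is a hypothetical
# caller's DataFrame):
#
#     >>> ensure_columns(frame, ("high", "low", "close"))  # raises if any are missing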


async def fetch_price_frame(
    ctx: SchwabContext,
    symbol: str,
    *,
    interval: str,
    start: str | None = None,
    end: str | None = None,
    bars: int | None = None,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Fetch OHLCV data for the requested interval as a DataFrame plus metadata.

    The metadata dict records the resolved symbol, interval, window bounds,
    requested bar count, and the number of candles returned.
    """
    interval_key = normalize_interval(interval)
    config = _INTERVAL_CONFIGS[interval_key]
    # Default the window end to "now"; when ``start`` is absent, derive it by
    # stepping ``bars`` bar-widths back from the end.
    end_dt = _parse_timestamp(end) or _dt.datetime.now(tz=_dt.timezone.utc)
    start_dt = _parse_timestamp(start) or _default_start(
        end=end_dt, interval=config, bars=bars
    )
fetcher = getattr(ctx.price_history, config.method_name)
response: JSONType = await call(
fetcher,
symbol,
start_datetime=start_dt,
end_datetime=end_dt,
)
if not isinstance(response, Mapping):
raise TypeError("Unexpected response type for price history payload")
candles = response.get("candles", [])
frame = _candles_to_dataframe(candles)
empty = bool(response.get("empty")) or frame.empty
metadata = {
"symbol": str(response.get("symbol", symbol)).upper(),
"interval": interval_key,
"start": start_dt.isoformat() if start_dt else None,
"end": end_dt.isoformat(),
"bars_requested": bars,
"empty": empty,
"candles_returned": len(frame),
}
return frame, metadata
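
# Minimal usage sketch (``ctx`` is a SchwabContext supplied by the host
# application; not executed here):
#
#     >>> frame, meta = await fetch_price_frame(ctx, "AAPL", interval="1d", bars=30)
#     >>> meta["interval"], meta["candles_returned"] == len(frame)
#     ('1d', True)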


def series_to_json(
    series: pd.Series,
    *,
    limit: int | None = None,
    value_key: str | None = None,
) -> list[dict[str, Any]]:
    """Convert a pandas Series indexed by timestamps into JSON-serializable rows."""
    # Dropping NaN first lets a single emptiness check cover both cases.
    series = series.dropna()
    if series.empty:
        return []
if limit is not None and limit > 0:
series = series.tail(limit)
value_key = value_key or (str(series.name) if series.name else "value")
index = _normalize_index(series.index)
values = series.to_numpy()
rows: list[dict[str, Any]] = []
for timestamp, value in zip(index, values):
if pd.isna(timestamp) or pd.isna(value):
continue
rows.append({"timestamp": timestamp.isoformat(), value_key: float(value)})
return rows
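
# Example output shape (illustrative values), e.g. serializing a moving average
# over a hypothetical ``frame``:
#
#     >>> series_to_json(frame["close"].rolling(3).mean(), limit=1, value_key="sma")
#     [{'timestamp': '2024-01-02T15:00:00+00:00', 'sma': 187.4}]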


def frame_to_json(
frame: pd.DataFrame,
*,
limit: int | None = None,
) -> list[dict[str, Any]]:
"""Convert a pandas DataFrame indexed by timestamps into JSON rows."""
if frame.empty:
return []
numeric = frame.apply(pd.to_numeric, errors="coerce")
numeric = numeric.dropna(how="all")
if numeric.empty:
return []
if limit is not None and limit > 0:
numeric = numeric.tail(limit)
index = _normalize_index(numeric.index)
rows: list[dict[str, Any]] = []
    for timestamp, (_, row) in zip(index, numeric.iterrows()):
        if pd.isna(timestamp):
            continue
        # Omit NaN cells per row so partially populated rows still serialize.
        valid_items = {
            str(column): float(value)
            for column, value in row.items()
            if pd.notna(value)
        }
        if not valid_items:
            continue
        rows.append({"timestamp": timestamp.isoformat(), **valid_items})
return rows
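
# Example output shape (illustrative values): one dict per row with NaN cells
# omitted.
#
#     >>> frame_to_json(frame.tail(1))
#     [{'timestamp': '2024-01-02T15:00:00+00:00', 'open': 187.1, 'high': 187.6,
#       'low': 186.9, 'close': 187.4, 'volume': 1200.0}]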


def _normalize_index(index: pd.Index) -> pd.DatetimeIndex:
    """Coerce an index to a UTC ``DatetimeIndex``; unparseable entries become NaT."""
    if isinstance(index, pd.DatetimeIndex):
        if index.tz is None:
            return index.tz_localize("UTC")
        return index.tz_convert("UTC")
    converted = pd.to_datetime(index, utc=True, errors="coerce")
    if not isinstance(converted, pd.DatetimeIndex):
        converted = pd.DatetimeIndex(converted)
    return converted
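
# Example: string or naive indexes come back UTC-aware; unparseable entries
# become NaT, which the serializers above skip.
#
#     >>> idx = _normalize_index(pd.Index(["2024-01-02", "not-a-date"]))
#     >>> idx.isna().tolist()
#     [False, True]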


# Re-export the optional dependency so submodules can share the import guard.
pandas_ta = cast(Any, _pandas_ta)
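
# Sibling modules can then import ``pandas_ta`` from here instead of repeating
# the optional-dependency handling, e.g. ``from .indicators import pandas_ta``
# (module name hypothetical).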