from __future__ import annotations
from typing import Any, Dict, Optional, Tuple, List
import warnings
import numpy as np
import pandas as pd
from ..interface import ForecastMethod, ForecastResult
from ..registry import ForecastRegistry
try:
from statsmodels.tsa.holtwinters import SimpleExpSmoothing as _SES, ExponentialSmoothing as _ETS # type: ignore
_SM_ETS_AVAILABLE = True
except Exception:
_SM_ETS_AVAILABLE = False
try:
from statsmodels.tsa.statespace.sarimax import SARIMAX as _SARIMAX # type: ignore
_SM_SARIMAX_AVAILABLE = True
except Exception:
_SM_SARIMAX_AVAILABLE = False
class ETSArimaMethod(ForecastMethod):
    """Base class for ETS and ARIMA methods.

    Supplies the category, package requirement and feature flags shared by
    the statsmodels-backed forecasters; subclasses provide ``name`` and
    ``forecast``.
    """

    @property
    def category(self) -> str:
        return "ets_arima"

    @property
    def required_packages(self) -> List[str]:
        packages = ["statsmodels"]
        return packages

    @property
    def supports_features(self) -> Dict[str, bool]:
        # A fresh dict per call, every feature enabled.
        # NOTE(review): "ci" is advertised for all subclasses, but in this module
        # only the SARIMAX-based methods populate ci_values — confirm callers
        # do not rely on intervals from SES/Holt/Holt-Winters.
        return dict.fromkeys(("price", "return", "volatility", "ci"), True)
@ForecastRegistry.register("ses")
class SESMethod(ETSArimaMethod):
    """Simple exponential smoothing via statsmodels' ``SimpleExpSmoothing``."""

    @property
    def name(self) -> str:
        return "ses"

    def forecast(
        self,
        series: pd.Series,
        horizon: int,
        seasonality: int,
        params: Dict[str, Any],
        exog_future: Optional[pd.DataFrame] = None,
        **kwargs
    ) -> ForecastResult:
        """Fit SES to ``series`` and forecast ``horizon`` steps.

        ``params['alpha']``, when present, fixes the smoothing level; otherwise
        statsmodels optimises it. ``seasonality`` and ``exog_future`` are unused.

        Returns:
            ForecastResult with the point forecast and ``{"alpha": ...}``; the
            alpha is read back from the fitted model when it exposes one.

        Raises:
            RuntimeError: statsmodels is not installed.
        """
        if not _SM_ETS_AVAILABLE:
            raise RuntimeError("SES requires statsmodels")
        alpha = params.get('alpha')
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model = _SES(series.values, initialization_method='heuristic')
            if alpha is not None:
                fitted = model.fit(smoothing_level=float(alpha), optimized=False)
            else:
                fitted = model.fit(optimized=True)
            predictions = np.asarray(fitted.forecast(int(horizon)), dtype=float)

        # Best effort: prefer the smoothing level the fit actually used.
        alpha_used = alpha
        try:
            fitted_params = getattr(fitted, 'params', None)
            if hasattr(fitted_params, 'get'):
                alpha_used = fitted_params.get('smoothing_level', alpha)
        except Exception:
            pass
        return ForecastResult(forecast=predictions, params_used={"alpha": alpha_used})
@ForecastRegistry.register("holt")
class HoltMethod(ETSArimaMethod):
    """Holt's linear (additive) trend method via statsmodels ``ExponentialSmoothing``."""

    @property
    def name(self) -> str:
        return "holt"

    def forecast(
        self,
        series: pd.Series,
        horizon: int,
        seasonality: int,
        params: Dict[str, Any],
        exog_future: Optional[pd.DataFrame] = None,
        **kwargs
    ) -> ForecastResult:
        """Fit an additive-trend ETS model and forecast ``horizon`` steps.

        ``params['damped']`` (default False) toggles a damped trend. All
        smoothing parameters are optimised. ``seasonality`` and
        ``exog_future`` are unused.

        Raises:
            RuntimeError: statsmodels is not installed.
        """
        if not _SM_ETS_AVAILABLE:
            raise RuntimeError("Holt requires statsmodels")
        use_damping = bool(params.get('damped', False))
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            fitted = _ETS(
                series.values,
                trend='add',
                damped_trend=use_damping,
                initialization_method='heuristic',
            ).fit(optimized=True)
            predictions = np.asarray(fitted.forecast(int(horizon)), dtype=float)
        return ForecastResult(forecast=predictions, params_used={"damped": use_damping})
@ForecastRegistry.register("holt_winters_add")
class HoltWintersAddMethod(ETSArimaMethod):
    """Holt-Winters with an additive trend and additive seasonality.

    Also provides ``_forecast_hw``, which the multiplicative subclass reuses
    with a different ``seasonal_type``.
    """

    @property
    def name(self) -> str:
        return "holt_winters_add"

    def forecast(
        self,
        series: pd.Series,
        horizon: int,
        seasonality: int,
        params: Dict[str, Any],
        exog_future: Optional[pd.DataFrame] = None,
        **kwargs
    ) -> ForecastResult:
        """Forecast ``horizon`` steps with additive seasonality of period ``seasonality``."""
        return self._forecast_hw(series, horizon, seasonality, params, 'add')

    def _forecast_hw(self, series, horizon, seasonality, params, seasonal_type):
        """Fit statsmodels ``ExponentialSmoothing`` (additive trend) and forecast.

        Args:
            series: Observed history.
            horizon: Number of steps to forecast.
            seasonality: Seasonal period; must be at least 2.
            params: ``damped`` (default False) toggles a damped trend.
            seasonal_type: statsmodels seasonal component, e.g. ``'add'`` or ``'mul'``.

        Returns:
            ForecastResult with the point forecast and the seasonal type,
            period and damping actually used.

        Raises:
            RuntimeError: statsmodels is not installed.
            ValueError: ``seasonality`` is below 2.
        """
        if not _SM_ETS_AVAILABLE:
            raise RuntimeError("Holt-Winters requires statsmodels")
        m = int(seasonality)
        # statsmodels rejects seasonal_periods <= 1, so a period of 1 is as
        # invalid as 0. Fail here with a clear message instead.
        if m < 2:
            raise ValueError(f"Holt-Winters requires seasonality >= 2, got {m}")
        vals = series.values
        damped = bool(params.get('damped', False))
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model = _ETS(
                vals,
                trend='add',
                seasonal=seasonal_type,
                seasonal_periods=m,
                damped_trend=damped,
                initialization_method='heuristic',
            )
            res = model.fit(optimized=True)
            f_vals = np.asarray(res.forecast(int(horizon)), dtype=float)
        return ForecastResult(forecast=f_vals, params_used={"seasonal": seasonal_type, "m": m, "damped": damped})
@ForecastRegistry.register("holt_winters_mul")
class HoltWintersMulMethod(HoltWintersAddMethod):
    """Holt-Winters with multiplicative seasonality.

    Everything except the seasonal component type is inherited from the
    additive variant.
    """

    @property
    def name(self) -> str:
        return "holt_winters_mul"

    def forecast(
        self,
        series: pd.Series,
        horizon: int,
        seasonality: int,
        params: Dict[str, Any],
        exog_future: Optional[pd.DataFrame] = None,
        **kwargs
    ) -> ForecastResult:
        """Forecast ``horizon`` steps with multiplicative seasonality of period ``seasonality``."""
        seasonal_kind = 'mul'
        return self._forecast_hw(series, horizon, seasonality, params, seasonal_kind)
@ForecastRegistry.register("arima")
class ARIMAMethod(ETSArimaMethod):
    """Non-seasonal ARIMA forecaster backed by statsmodels' SARIMAX.

    Also hosts ``_forecast_sarimax``, the shared fitting routine that the
    seasonal subclass calls with ``seasonal=True``.
    """

    @property
    def name(self) -> str:
        return "arima"

    @property
    def category(self) -> str:
        # Overrides the base class's shared "ets_arima" category.
        return "arima"

    def forecast(
        self,
        series: pd.Series,
        horizon: int,
        seasonality: int,
        params: Dict[str, Any],
        exog_future: Optional[pd.DataFrame] = None,
        **kwargs
    ) -> ForecastResult:
        """Forecast ``horizon`` steps with SARIMAX, without seasonal auto-configuration."""
        return self._forecast_sarimax(series, horizon, seasonality, params, seasonal=False, exog_future=exog_future, **kwargs)

    def _forecast_sarimax(self, series, horizon, seasonality, params, seasonal, exog_future=None, **kwargs):
        """Fit SARIMAX to ``series`` and forecast ``horizon`` steps ahead.

        Args:
            series: Observed history; values are cast to float.
            horizon: Number of steps to forecast.
            seasonality: Seasonal period. It is used only to fill in a default
                seasonal order when ``seasonal`` is True and none was supplied.
            params: Optional keys:
                - ``order`` (default ``(1, 1, 1)``).
                - ``seasonal_order`` (default ``(0, 0, 0, 0)``).
                - ``trend`` (default ``'c'``; ``None`` means no trend).
                - ``alpha``: the confidence-interval significance level
                  (default ``0.05``).
            seasonal: Whether the caller requested a seasonal model.
            exog_future: Future exogenous regressors, as a DataFrame or an
                array (the legacy wrappers pass arrays).
            **kwargs: ``exog_used`` carries the in-sample exogenous regressors.

        Returns:
            ForecastResult with:
                - the point forecast;
                - ``(lower, upper)`` interval arrays when they can be computed,
                  else None;
                - the parameters actually used.

        Raises:
            RuntimeError: statsmodels is not installed.
            ValueError: ``exog_used`` was given without ``exog_future``.
        """
        if not _SM_SARIMAX_AVAILABLE:
            raise RuntimeError("SARIMAX requires statsmodels")
        vals = series.values.astype(float)
        # Normalise to tuples so the default-order comparison below also works
        # when callers pass lists. An explicit None falls back to the default.
        order = params.get('order')
        order = (1, 1, 1) if order is None else tuple(order)
        seasonal_order = params.get('seasonal_order')
        seasonal_order = (0, 0, 0, 0) if seasonal_order is None else tuple(seasonal_order)
        if seasonal and seasonal_order == (0, 0, 0, 0) and seasonality and int(seasonality) > 1:
            # A seasonal model was requested without a seasonal order: default
            # to one seasonal difference plus a seasonal MA(1) term.
            seasonal_order = (0, 1, 1, int(seasonality))
        raw_trend = params.get('trend', 'c')
        # SARIMAX spells "no trend" as 'n'. str(None) would give the invalid 'None'.
        trend = 'n' if raw_trend is None else str(raw_trend)
        ci_alpha = params.get('alpha', 0.05)
        exog_u = kwargs.get('exog_used')
        exog_f = exog_future.values if isinstance(exog_future, pd.DataFrame) else exog_future
        # Out-of-sample forecasts of a model with regressors need future
        # regressor values. Fail before paying for the fit.
        if exog_u is not None and exog_f is None:
            raise ValueError("SARIMAX with exogenous regressors (exog_used) requires exog_future for the forecast horizon")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model = _SARIMAX(
                vals,
                order=order,
                seasonal_order=seasonal_order,
                trend=trend,
                enforce_stationarity=True,
                enforce_invertibility=True,
                exog=exog_u
            )
            res = model.fit(method='lbfgs', disp=False, maxiter=100)
            if exog_f is not None:
                pred = res.get_forecast(steps=int(horizon), exog=exog_f)
            else:
                pred = res.get_forecast(steps=int(horizon))
        f_vals = np.asarray(pred.predicted_mean, dtype=float)
        ci = None
        try:
            _alpha = float(ci_alpha) if ci_alpha is not None else 0.05
            ci_arr = np.asarray(pred.conf_int(alpha=_alpha))
            if ci_arr.ndim == 2 and ci_arr.shape[1] >= 2:
                ci = (ci_arr[:, 0], ci_arr[:, 1])
        except Exception:
            # Intervals are best-effort; the point forecast is still returned.
            pass
        params_used = {"order": order, "seasonal_order": seasonal_order, "trend": trend}
        if exog_u is not None:
            # A 1-D exog is a single regressor; indexing shape[1] on it would raise.
            n_features = 1 if np.ndim(exog_u) == 1 else int(np.shape(exog_u)[1])
            params_used["exog"] = {"n_features": n_features}
        return ForecastResult(forecast=f_vals, ci_values=ci, params_used=params_used)
@ForecastRegistry.register("sarima")
class SARIMAMethod(ARIMAMethod):
    """Seasonal ARIMA.

    Reuses the parent's SARIMAX routine with ``seasonal=True``, so a default
    seasonal order can be filled in from ``seasonality``.
    """

    @property
    def name(self) -> str:
        return "sarima"

    def forecast(
        self,
        series: pd.Series,
        horizon: int,
        seasonality: int,
        params: Dict[str, Any],
        exog_future: Optional[pd.DataFrame] = None,
        **kwargs
    ) -> ForecastResult:
        """Forecast ``horizon`` steps with a (possibly auto-configured) seasonal SARIMAX model."""
        return self._forecast_sarimax(
            series,
            horizon,
            seasonality,
            params,
            seasonal=True,
            exog_future=exog_future,
            **kwargs,
        )
# Backward compatibility wrappers
def forecast_ses(series: np.ndarray, fh: int, alpha: Optional[float] = None) -> Tuple[np.ndarray, Dict[str, Any], Optional[np.ndarray]]:
    """Legacy SES entry point returning ``(forecast, params_used, fitted)``.

    Delegates to the registered "ses" method with no seasonality. The third
    element historically carried fitted values; the registry-based
    implementation does not produce them, so it is always None.
    """
    ses_method = ForecastRegistry.get("ses")
    outcome = ses_method.forecast(pd.Series(series), fh, 0, {"alpha": alpha})
    fitted_values = None
    return outcome.forecast, outcome.params_used, fitted_values
def forecast_holt(series: np.ndarray, fh: int, damped: bool = True) -> Tuple[np.ndarray, Dict[str, Any], Optional[np.ndarray]]:
    """Legacy Holt entry point returning ``(forecast, params_used, None)``.

    Unlike the registry method (default ``damped=False``), this wrapper damps
    the trend by default.
    """
    holt_method = ForecastRegistry.get("holt")
    outcome = holt_method.forecast(pd.Series(series), fh, 0, {"damped": damped})
    return outcome.forecast, outcome.params_used, None
def forecast_holt_winters(series: np.ndarray, fh: int, m: int, seasonal: str = 'add', damped: bool = False) -> Tuple[np.ndarray, Dict[str, Any], Optional[np.ndarray]]:
    """Legacy Holt-Winters entry point returning ``(forecast, params_used, None)``.

    Args:
        series: Observed history.
        fh: Forecast horizon in steps.
        m: Seasonal period.
        seasonal: ``'add'``/``'additive'`` (case-insensitive) selects additive
            seasonality. Any other value keeps the historical multiplicative
            fallback.
        damped: Whether to damp the trend. Defaults to False, which matches
            the previously hard-coded behaviour.
    """
    # Previously only the exact string 'add' chose the additive model, so
    # statsmodels' own spelling 'additive' silently produced a multiplicative fit.
    is_additive = str(seasonal).strip().lower() in ('add', 'additive')
    method_name = "holt_winters_add" if is_additive else "holt_winters_mul"
    res = ForecastRegistry.get(method_name).forecast(pd.Series(series), fh, m, {"damped": bool(damped)})
    return res.forecast, res.params_used, None
def forecast_sarimax(
    series: np.ndarray,
    fh: int,
    order: Tuple[int, int, int],
    seasonal_order: Tuple[int, int, int, int] = (0, 0, 0, 0),
    trend: str = 'c',
    exog_used: Optional[np.ndarray] = None,
    exog_future: Optional[np.ndarray] = None,
    ci_alpha: Optional[float] = 0.05,
) -> Tuple[np.ndarray, Dict[str, Any], Optional[Tuple[np.ndarray, np.ndarray]]]:
    """Legacy SARIMAX entry point returning ``(forecast, params_used, ci)``.

    Routing:
        - A seasonal order whose entries sum to a positive value goes to the
          registered "sarima" method; otherwise "arima" is used.
        - The seasonal period handed on is ``seasonal_order[3]`` when present,
          else 0.
        - ``ci_alpha`` is forwarded as the ``alpha`` parameter.
    """
    has_seasonal_part = sum(seasonal_order) > 0
    model_params = {
        "order": order,
        "seasonal_order": seasonal_order,
        "trend": trend,
        "alpha": ci_alpha,
    }
    registered = ForecastRegistry.get("sarima" if has_seasonal_part else "arima")
    history = pd.Series(series)
    period = seasonal_order[3] if len(seasonal_order) > 3 else 0
    outcome = registered.forecast(
        history,
        fh,
        period,
        model_params,
        exog_used=exog_used,
        exog_future=exog_future,
    )
    return outcome.forecast, outcome.params_used, outcome.ci_values