# models/arima_model.py - FIXED VERSION
import pandas as pd
import numpy as np
import yfinance as yf
import json
import warnings
from datetime import datetime, timedelta
from pathlib import Path
from typing import Tuple, Dict, Any, Optional
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.stattools import adfuller
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
import itertools
warnings.filterwarnings('ignore')
class ARIMAModelManager:
    """ARIMA forecasting with per-ticker parameter caching.

    Each ticker gets its own JSON cache file under ``cache_dir`` so that
    optimized (p, d, q) orders are never shared between symbols. Cached
    parameters expire after ``CACHE_MAX_AGE_DAYS`` days, after which a fresh
    grid search is performed.
    """

    # Cached parameters older than this many days are treated as stale.
    CACHE_MAX_AGE_DAYS = 30

    def __init__(self, cache_dir: str = "cache/arima_params") -> None:
        """Create the manager and ensure the cache directory exists."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _get_cache_file_path(self, ticker: str) -> Path:
        """Return the ticker-specific cache file path."""
        return self.cache_dir / f"{ticker}_arima_params.json"

    def _load_cached_parameters(self, ticker: str) -> Optional[Dict[str, Any]]:
        """Load cached parameters for *ticker* only.

        Returns the cache dict if a readable, non-expired cache file exists
        for this ticker; otherwise returns None. Any unreadable or malformed
        cache file is treated as a cache miss rather than an error.
        """
        cache_file = self._get_cache_file_path(ticker)
        try:
            if not cache_file.exists():
                print(f"[INFO] {ticker}: No cache file found")
                return None
            with open(cache_file, 'r') as f:
                cache_data = json.load(f)
            # Validate cache recency (within CACHE_MAX_AGE_DAYS days).
            cache_date = datetime.fromisoformat(cache_data['timestamp'])
            age_days = (datetime.now() - cache_date).days
            if age_days < self.CACHE_MAX_AGE_DAYS:
                print(f"[OK] {ticker}: Using cached ARIMA parameters "
                      f"{cache_data['optimal_params']}")
                return cache_data
            print(f"[INFO] {ticker}: Cache expired ({age_days} days old)")
            return None
        except (OSError, json.JSONDecodeError, KeyError, ValueError) as e:
            # Narrowed from bare Exception: covers I/O failures, corrupt JSON,
            # missing keys, and unparsable timestamps. Anything else is a bug
            # and should propagate.
            print(f"[WARN] {ticker}: Cache read error: {e}")
            return None

    def _save_optimized_parameters(self, ticker: str,
                                   optimal_params: Tuple[int, int, int],
                                   model_metrics: Dict[str, float]) -> None:
        """Persist ticker-specific optimized parameters to the JSON cache.

        Note: JSON serializes the (p, d, q) tuple as a list; callers that
        reload it must convert back with ``tuple(...)``.
        """
        cache_file = self._get_cache_file_path(ticker)
        cache_data = {
            'ticker': ticker,
            'optimal_params': optimal_params,
            'aic': model_metrics.get('aic', None),
            'bic': model_metrics.get('bic', None),
            'timestamp': datetime.now().isoformat(),
            'data_period': '4y',
            'optimization_date': datetime.now().strftime('%Y-%m-%d')
        }
        try:
            with open(cache_file, 'w') as f:
                json.dump(cache_data, f, indent=2)
            print(f"[OK] {ticker}: Saved optimized parameters {optimal_params}")
        except (OSError, TypeError) as e:
            # Saving is best-effort: a failed write only costs a re-optimization
            # on the next call, so log and continue.
            print(f"[ERROR] {ticker}: Failed to save cache: {e}")

    def _download_stock_data(self, ticker: str, period: str = "4y") -> pd.DataFrame:
        """Download historical OHLCV data for *ticker* via yfinance.

        Raises ValueError if yfinance returns an empty frame; re-raises any
        download failure after logging it.
        """
        print(f"[INFO] {ticker}: Downloading {period} of data...")
        try:
            data = yf.download(ticker, period=period, progress=False)
            if data.empty:
                raise ValueError(f"No data available for {ticker}")
            print(f"[OK] {ticker}: Downloaded {len(data)} data points "
                  f"({data.index[0].date()} to {data.index[-1].date()})")
            return data
        except Exception as e:
            # Broad by design: yfinance raises a variety of network/parse
            # errors; callers treat any failure here as fatal for the ticker.
            print(f"[ERROR] {ticker}: Data download failed: {e}")
            raise

    def _optimize_arima_parameters(self, data: pd.DataFrame,
                                   ticker: str) -> Tuple[Tuple[int, int, int], Dict[str, float]]:
        """Grid-search (p, d, q) over small ranges, minimizing AIC.

        Returns the best order and a metrics dict (aic, bic, combinations
        tested). Falls back to ARIMA(1, 1, 1) if every combination fails.
        """
        print(f"[INFO] {ticker}: Optimizing ARIMA parameters...")
        close_prices = data['Close'].dropna()
        # Search space: up to 3 AR terms, 2 differences, 3 MA terms.
        p_values = range(0, 4)
        d_values = range(0, 3)
        q_values = range(0, 4)
        best_aic = float('inf')
        best_params = None
        best_model = None
        optimization_results = []
        total_combinations = len(p_values) * len(d_values) * len(q_values)
        current_combination = 0
        for p, d, q in itertools.product(p_values, d_values, q_values):
            current_combination += 1
            try:
                model = ARIMA(close_prices, order=(p, d, q))
                fitted_model = model.fit()
                aic = fitted_model.aic
                bic = fitted_model.bic
                optimization_results.append({
                    'params': (p, d, q),
                    'aic': aic,
                    'bic': bic
                })
                if aic < best_aic:
                    best_aic = aic
                    best_params = (p, d, q)
                    best_model = fitted_model
                if current_combination % 10 == 0:
                    print(f"  Progress: {current_combination}/{total_combinations} "
                          f"combinations tested...")
            except Exception:
                # Some orders are non-invertible/non-stationary for this
                # series; skip them rather than aborting the search.
                continue
        if best_params is None:
            # Every combination failed to fit: fall back to a standard order.
            print(f"[WARN] {ticker}: Optimization failed, using fallback ARIMA(1,1,1)")
            best_params = (1, 1, 1)
            model = ARIMA(close_prices, order=best_params)
            best_model = model.fit()
            best_aic = best_model.aic
        metrics = {
            'aic': best_aic,
            'bic': best_model.bic,
            'optimization_combinations_tested': len(optimization_results)
        }
        print(f"[OK] {ticker}: Optimal ARIMA{best_params} found (AIC: {best_aic:.2f})")
        return best_params, metrics

    def get_arima_forecast(self, ticker: str, use_optimized_params: bool = True,
                           force_recalculate: bool = False) -> Tuple[float, np.ndarray, pd.DataFrame]:
        """Produce a one-step-ahead ARIMA forecast for *ticker*.

        Parameters are loaded from this ticker's cache when available (unless
        ``force_recalculate``), otherwise grid-searched and cached. Fresh
        data is always downloaded.

        Returns:
            forecast_price: next day's predicted close price
            residuals: fitted-model residuals (for downstream enhancement)
            data: the historical price DataFrame used
        Raises:
            whatever the data download or final model fit raises.
        """
        print(f"\n[INFO] {ticker}: Starting ARIMA analysis...")
        # Always download fresh data so the forecast uses the latest prices.
        data = self._download_stock_data(ticker)
        close_prices = data['Close'].dropna()
        if use_optimized_params and not force_recalculate:
            # Try the ticker-specific cache first.
            cached_data = self._load_cached_parameters(ticker)
            if cached_data:
                # JSON round-trip turns the tuple into a list; restore it.
                optimal_params = tuple(cached_data['optimal_params'])
                print(f"[INFO] {ticker}: Using cached parameters ARIMA{optimal_params}")
            else:
                print(f"[INFO] {ticker}: No cached parameters found, optimizing...")
                optimal_params, metrics = self._optimize_arima_parameters(data, ticker)
                self._save_optimized_parameters(ticker, optimal_params, metrics)
        else:
            if force_recalculate:
                print(f"[INFO] {ticker}: Force recalculating parameters...")
            else:
                print(f"[INFO] {ticker}: Optimizing parameters...")
            optimal_params, metrics = self._optimize_arima_parameters(data, ticker)
            self._save_optimized_parameters(ticker, optimal_params, metrics)
        print(f"[INFO] {ticker}: Fitting final ARIMA{optimal_params} model...")
        try:
            model = ARIMA(close_prices, order=optimal_params)
            fitted_model = model.fit()
            forecast_result = fitted_model.forecast(steps=1)
            forecast_price = float(forecast_result.iloc[0])
            # Residuals are returned for downstream model enhancement.
            residuals = fitted_model.resid
            last_price = float(close_prices.iloc[-1])
            expected_change = (forecast_price - last_price) / last_price
            print(f"[OK] {ticker}: ARIMA forecast complete")
            print(f"  Last price: ${last_price:.2f}")
            print(f"  Forecast: ${forecast_price:.2f} ({expected_change:+.2%})")
            print(f"  Model: ARIMA{optimal_params}, AIC: {fitted_model.aic:.2f}")
            return forecast_price, residuals, data
        except Exception as e:
            # Broad by design: statsmodels can raise several error types; the
            # caller is expected to handle a failed forecast.
            print(f"[ERROR] {ticker}: ARIMA model fitting failed: {e}")
            raise
# Global instance for backward compatibility.
# NOTE: constructing it creates the cache directory as an import-time side effect.
arima_manager = ARIMAModelManager()
def get_arima_forecast(ticker: str, use_optimized_params: bool = True,
                       force_recalculate: bool = False) -> Tuple[float, np.ndarray, pd.DataFrame]:
    """Module-level convenience wrapper around the shared ARIMAModelManager.

    Delegates to ``arima_manager.get_arima_forecast`` so that:
      1. each ticker keeps its own parameter cache,
      2. parameters are never shared between tickers,
      3. fresh 4-year data is downloaded on every call,
      4. optimization results stay ticker-specific.
    """
    return arima_manager.get_arima_forecast(
        ticker,
        use_optimized_params=use_optimized_params,
        force_recalculate=force_recalculate,
    )
def get_enhanced_arima_forecast(ticker: str) -> Tuple[float, np.ndarray, pd.DataFrame]:
    """Forecast *ticker* with the cache bypassed and a fresh grid search forced."""
    result = get_arima_forecast(
        ticker,
        use_optimized_params=False,
        force_recalculate=True,
    )
    return result
def clear_cache_for_ticker(ticker: str) -> bool:
    """Remove the cached ARIMA parameters for a single ticker.

    Returns:
        True if a cache file existed and was deleted; False if no cache
        existed or the deletion failed.
    """
    cache_file = arima_manager._get_cache_file_path(ticker)
    try:
        if cache_file.exists():
            cache_file.unlink()
            print(f"[OK] {ticker}: Cache cleared")
            return True
        print(f"[INFO] {ticker}: No cache to clear")
        return False
    except OSError as e:
        # Narrowed from bare Exception: only filesystem errors are expected
        # from exists()/unlink(); anything else should propagate as a bug.
        print(f"[ERROR] {ticker}: Failed to clear cache: {e}")
        return False
def clear_all_cache() -> int:
    """Delete every ticker's cached parameter file.

    The original version aborted on the first failed deletion and returned 0,
    discarding the count of files already removed. The try/except now wraps
    each individual unlink so one bad file cannot hide a partial success.

    Returns:
        The number of cache files actually removed.
    """
    cleared_count = 0
    for cache_file in arima_manager.cache_dir.glob("*_arima_params.json"):
        try:
            cache_file.unlink()
            cleared_count += 1
        except OSError as e:
            # Report and keep going; remaining files can still be cleared.
            print(f"[ERROR] Failed to clear {cache_file.name}: {e}")
    print(f"[INFO] Cleared {cleared_count} cache files")
    return cleared_count