# models/__init__.py
"""
Trading Models Package
This package contains machine learning models for stock price prediction:
- ARIMA models for time series forecasting
- XGBoost models for residual correction
- Hybrid models combining both approaches
"""
import logging
import warnings
# Package metadata (exposed as module attributes).
__version__ = "1.0.0"
__author__ = "Trading System"
# NOTE(review): looks like a template placeholder address — TODO replace.
__email__ = "your.email@example.com"

# Package-level logger; the import guards and helpers in this module log
# through it.
logger = logging.getLogger(__name__)
# Import the ARIMA entry points. If the submodule cannot be imported, bind
# each public name to None so the package itself still imports and callers
# can test the names for availability.
try:
    from .arima_model import (
        get_arima_forecast,
        ARIMAForecaster,
        get_enhanced_arima_forecast,
    )
except ImportError as e:
    logger.warning(f"Could not import ARIMA models: {e}")
    # Placeholders keep the names defined (and listed in __all__) on failure.
    get_arima_forecast = ARIMAForecaster = get_enhanced_arima_forecast = None
else:
    logger.info("Successfully imported ARIMA models")
# Import the hybrid ARIMA + XGBoost entry points, with the same None
# placeholders on failure as the ARIMA import guard.
try:
    from .hybrid_model import (
        train_xgboost_on_residuals,
        HybridARIMAXGBoost,
        train_xgboost_on_residuals_enhanced,
    )
except ImportError as e:
    logger.warning(f"Could not import hybrid models: {e}")
    train_xgboost_on_residuals = HybridARIMAXGBoost = train_xgboost_on_residuals_enhanced = None
else:
    logger.info("Successfully imported hybrid models")
# Define what gets imported with "from models import *".
# The model names may be bound to None when their submodule failed to import;
# get_available_models() reports which ones are actually usable.
__all__ = [
    # ARIMA models
    'get_arima_forecast',
    'ARIMAForecaster',
    'get_enhanced_arima_forecast',
    # Hybrid models
    'train_xgboost_on_residuals',
    'HybridARIMAXGBoost',
    'train_xgboost_on_residuals_enhanced',
    # Package helpers (public functions defined in this module)
    'initialize_models',
    'get_available_models',
    'check_dependencies',
    # Package info
    '__version__',
    '__author__',
]
# Package initialization
def initialize_models():
    """Install warning filters and report whether both model groups loaded.

    Side effects: adds two *process-wide* ``warnings`` filters — one ignoring
    every FutureWarning, one ignoring UserWarnings raised from modules whose
    name matches ``'yfinance'`` — and logs the outcome.

    Returns:
        bool: True when both the ARIMA and the hybrid entry points imported;
        False (after logging an error naming the missing groups) otherwise.
    """
    logger.info(f"Initializing trading models package v{__version__}")

    # NOTE: these filters affect the whole interpreter, not only this package.
    warnings.filterwarnings(action='ignore', category=FutureWarning)
    warnings.filterwarnings(action='ignore', category=UserWarning, module='yfinance')

    # One representative entry point per group; each group's import guard
    # sets all of its names to None together on failure.
    group_checks = (
        ("ARIMA models", get_arima_forecast),
        ("Hybrid models", train_xgboost_on_residuals),
    )
    absent = [label for label, entry_point in group_checks if entry_point is None]

    if not absent:
        logger.info("All models successfully initialized")
        return True

    logger.error(f"Missing models: {', '.join(absent)}")
    return False
# Convenience functions
def get_available_models():
    """Return the names of the model entry points that imported successfully.

    Every exported model name is checked on its own (including the
    ``*_enhanced`` variants). ARIMA names are listed first, then hybrid names.

    Returns:
        list[str]: names whose module-level binding exists and is not None.
    """
    candidate_names = (
        # ARIMA models
        'get_arima_forecast',
        'ARIMAForecaster',
        'get_enhanced_arima_forecast',
        # Hybrid models
        'train_xgboost_on_residuals',
        'HybridARIMAXGBoost',
        'train_xgboost_on_residuals_enhanced',
    )
    # Look the names up dynamically so each one is checked individually
    # instead of inferring a sibling's availability from a single name.
    module_globals = globals()
    return [name for name in candidate_names if module_globals.get(name) is not None]
def check_dependencies(packages=None):
    """Check whether the required third-party packages can be imported.

    Each package is actually imported, not merely located. A package that is
    installed but fails while loading is therefore reported as missing, rather
    than crashing the check.

    Args:
        packages: Optional iterable of top-level module names to check.
            Defaults to yfinance, pmdarima, xgboost, sklearn, pandas and numpy.

    Returns:
        bool: True if every package imported cleanly, False otherwise.
        Problems are logged, never raised.
    """
    import importlib

    if packages is None:
        packages = ('yfinance', 'pmdarima', 'xgboost', 'sklearn', 'pandas', 'numpy')

    missing = []
    for package in packages:
        try:
            importlib.import_module(package)
        except ImportError:
            missing.append(package)
        except Exception as exc:
            # Installed but broken (e.g. a native library fails to load).
            # Such errors are not always ImportError subclasses; report them
            # instead of letting the check itself crash.
            logger.warning("Package %r is installed but failed to import: %s", package, exc)
            missing.append(package)

    if missing:
        logger.error(f"Missing required packages: {', '.join(missing)}")
        return False
    logger.info("All dependencies satisfied")
    return True
# Auto-initialize when package is imported.
# This module body runs once per interpreter, on first import, so
# initialize_models() runs once. It installs process-wide warnings filters and
# logs whether both model groups loaded. Its True/False result is discarded,
# so a failed validation shows up only in the log, never as an exception.
initialize_models()