"""
Database management module for IBKR MCP.
This module provides database connection management, session handling,
and initialization for the IBKR trading system.
"""
from __future__ import annotations

import os
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Optional, Union

from loguru import logger
from sqlalchemy import create_engine, text
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    create_async_engine,
)
from sqlalchemy.orm import sessionmaker
# Type alias used in annotations for database connection URL strings,
# e.g. "sqlite+aiosqlite:///path/to/database.db".
DatabaseURL = str
class DatabaseManager:
    """
    Manages database connections, sessions, and initialization.

    Supports:
    - SQLite (default, recommended for single-user trading)
    - PostgreSQL (for production multi-user)
    - MySQL (alternative to PostgreSQL)

    Typical lifecycle: construct, ``await initialize()``, use
    ``get_session()`` for work, then ``await close()`` on shutdown.
    """

    def __init__(self, db_url: Optional[DatabaseURL] = None):
        """
        Initialize database manager.

        No connection is opened here; the engine is created by ``initialize()``.

        Args:
            db_url: Database connection URL. If None (or empty), it is built
                from environment variables.
                Format: sqlite+aiosqlite:///path/to/database.db
                        postgresql+asyncpg://user:password@host:port/dbname
                        mysql+aiomysql://user:password@host:port/dbname
        """
        self.db_url = db_url or self._build_db_url()
        self.engine: Optional[AsyncEngine] = None
        self._session_factory: Optional[sessionmaker] = None

        # Extract database type for logging
        self.db_type = self._extract_db_type(self.db_url)

        logger.info(
            "DatabaseManager initialized | type={db_type} url={db_url}",
            db_type=self.db_type,
            db_url=self._safe_url(self.db_url),
        )

    def _build_db_url(self) -> DatabaseURL:
        """
        Build the database URL from environment variables.

        ``DB_URL`` wins when set and is returned unchanged. Otherwise a SQLite
        URL is built from ``DB_PATH`` (default ``ibkr_data.db``), appending
        ``.db`` when the path does not already end with it.

        Returns:
            The value of ``DB_URL``, or a SQLite URL of the form
            sqlite+aiosqlite:///path/to/file.db
        """
        # A full URL takes precedence over the SQLite path setting
        env_url = os.getenv("DB_URL")
        if env_url:
            return env_url

        sqlite_path = os.getenv("DB_PATH", "ibkr_data.db")
        # Ensure .db extension
        if not sqlite_path.endswith(".db"):
            sqlite_path = f"{sqlite_path}.db"
        return f"sqlite+aiosqlite:///{sqlite_path}"

    @staticmethod
    def _extract_db_type(db_url: DatabaseURL) -> str:
        """
        Extract database type from URL.

        Returns:
            "sqlite", "postgresql" or "mysql" according to the URL prefix,
            or "unknown" when none of them matches.
        """
        for backend in ("sqlite", "postgresql", "mysql"):
            if db_url.startswith(backend):
                return backend
        return "unknown"

    @staticmethod
    def _safe_url(db_url: DatabaseURL) -> str:
        """
        Sanitize database URL for logging (hide password).

        Only the ``user:password@`` part is rewritten; URLs without a scheme
        separator or without an ``@`` are returned unchanged.

        Example:
            postgresql+asyncpg://user:***@host:port/dbname
        """
        if "://" not in db_url:
            return db_url

        scheme, remainder = db_url.split("://", 1)
        if "@" not in remainder:
            return db_url

        # Split on the LAST "@" so an "@" inside the password stays in the
        # credentials half and gets masked with it.
        credentials, location = remainder.rsplit("@", 1)
        if ":" in credentials:
            username, _ = credentials.split(":", 1)
            credentials = f"{username}:***"
        return f"{scheme}://{credentials}@{location}"

    async def initialize(self) -> None:
        """
        Initialize database connection and session factory.

        This method:
        1. Creates the async engine, with pool settings read from the
           DB_ECHO, DB_POOL_SIZE, DB_MAX_OVERFLOW, DB_POOL_TIMEOUT and
           DB_POOL_RECYCLE environment variables
        2. Configures the session factory
        3. Creates all tables via ``_initialize_schema``

        Calling it again while already initialized logs a warning and returns.

        Raises:
            ValueError: If a numeric pool environment variable is not an integer.
            Exception: Whatever schema creation raised. In that case the engine
                is disposed and the manager is reset to its uninitialized
                state, so ``initialize()`` can be retried.
        """
        if self.engine is not None:
            logger.warning("Database already initialized, skipping")
            return

        # Parse settings once so the values used and the value logged agree
        echo_sql = os.getenv("DB_ECHO", "false").lower() == "true"
        pool_size = int(os.getenv("DB_POOL_SIZE", "10"))
        max_overflow = int(os.getenv("DB_MAX_OVERFLOW", "20"))
        pool_timeout = int(os.getenv("DB_POOL_TIMEOUT", "30"))
        pool_recycle = int(os.getenv("DB_POOL_RECYCLE", "3600"))

        # NOTE(review): the pool sizing arguments are passed for every backend.
        # Some SQLAlchemy pool classes do not accept them - confirm this works
        # for SQLite URLs with the installed SQLAlchemy version.
        self.engine = create_async_engine(
            self.db_url,
            echo=echo_sql,
            # Connection pool settings
            pool_size=pool_size,
            max_overflow=max_overflow,
            pool_timeout=pool_timeout,
            pool_recycle=pool_recycle,
            # Pool pre-ping to check connections
            pool_pre_ping=True,
        )

        # Create session factory
        self._session_factory = sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

        # Initialize database schema
        try:
            await self._initialize_schema()
        except Exception:
            # Do not stay half-initialized: without this reset a retry would be
            # skipped as "already initialized" and the engine would never be
            # disposed.
            failed_engine = self.engine
            self.engine = None
            self._session_factory = None
            await failed_engine.dispose()
            raise

        logger.info(
            "Database initialized successfully | type={db_type} pool_size={pool_size}",
            db_type=self.db_type,
            pool_size=pool_size,
        )

    async def _initialize_schema(self) -> None:
        """
        Initialize database schema.

        Creates all tables defined in SQLAlchemy models.
        For production use, consider using Alembic migrations instead.

        Raises:
            RuntimeError: If no engine has been created yet.
            Exception: Any error raised while creating the tables is logged
                and re-raised.
        """
        if self.engine is None:
            raise RuntimeError("Database not initialized")

        try:
            # Import models to ensure they're registered with Base
            from . import models  # noqa: F401

            async with self.engine.begin() as conn:
                # Create all tables
                await conn.run_sync(self._create_all_tables)
            logger.info("Database schema initialized successfully")
        except Exception as e:
            logger.error(
                "Failed to initialize database schema | error={error}",
                error=e,
            )
            raise

    def _create_all_tables(self, connection) -> None:
        """
        Create all tables using synchronous metadata.

        This is a helper for running sync metadata.create_all
        within an async context (it is handed to ``conn.run_sync``).

        Args:
            connection: The connection object supplied by ``run_sync``.
        """
        from .models import Base

        Base.metadata.create_all(bind=connection)

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """
        Get a database session with automatic commit, rollback and cleanup.

        The session is committed when the ``async with`` body finishes
        normally. If the body raises, the session is rolled back, the error is
        logged, and the exception is re-raised.

        Usage:
            async with db_manager.get_session() as session:
                result = await session.execute(select(MyModel))
                session.add(new_model)
                # no explicit commit needed; it happens on exit

        Yields:
            AsyncSession: SQLAlchemy async session

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
        """
        if self._session_factory is None:
            raise RuntimeError(
                "Database not initialized. Call await db_manager.initialize() first."
            )

        async with self._session_factory() as session:
            try:
                yield session
                await session.commit()  # auto-commit on normal exit
            except Exception:
                await session.rollback()
                logger.exception("Database session error, rolling back")
                raise

    async def execute_raw(self, query: str, params: Optional[dict] = None) -> Any:
        """
        Execute raw SQL query.

        The statement runs in its own session from ``get_session``, which
        commits on success, so no explicit commit is issued here.

        Args:
            query: SQL query string
            params: Query parameters (an empty dict is used when omitted)

        Returns:
            The result object returned by ``session.execute``.
            NOTE(review): the session is already committed and closed when the
            caller receives it - confirm rows are still readable for the
            drivers in use.
        """
        async with self.get_session() as session:
            return await session.execute(text(query), params or {})

    async def health_check(self) -> bool:
        """
        Perform database health check by running ``SELECT 1``.

        Returns:
            True if the query succeeded; False on any exception, including
            the manager not being initialized.
        """
        try:
            async with self.get_session() as session:
                await session.execute(text("SELECT 1"))
            logger.debug("Database health check passed")
            return True
        except Exception as e:
            logger.error("Database health check failed | error={error}", error=e)
            return False

    async def close(self) -> None:
        """
        Close database connections and cleanup resources.

        Disposes the engine and clears the session factory. Does nothing when
        no engine exists. Should be called when shutting down the application.
        """
        if self.engine:
            await self.engine.dispose()
            self.engine = None
            self._session_factory = None
            logger.info("Database connections closed")

    async def drop_all(self) -> None:
        """
        Drop all tables (DANGEROUS!).

        WARNING: This will delete ALL data!
        Only use in development/testing.
        Consider using database migrations (Alembic) for schema changes.

        Raises:
            RuntimeError: If the database has not been initialized.
        """
        # Validate first so the destructive warning is only logged when the
        # drop is actually going to be attempted.
        if not self.engine:
            raise RuntimeError("Database not initialized")

        logger.warning(
            "DROPPING ALL TABLES - This will delete all data!"
        )

        async with self.engine.begin() as conn:
            from .models import Base

            await conn.run_sync(Base.metadata.drop_all)
        logger.warning("All tables dropped")

    @property
    def is_initialized(self) -> bool:
        """Check if database is initialized (engine and session factory both set)."""
        return self.engine is not None and self._session_factory is not None
# Global database manager instance, created at import time. No URL is passed,
# so it is resolved from the environment (DB_URL, otherwise DB_PATH). The
# engine itself is only created once `await db_manager.initialize()` is called.
db_manager = DatabaseManager()
# Convenience accessor for the URL held by the global manager
def get_database_url() -> str:
    """
    Return the database URL the global ``db_manager`` was configured with.

    Returns:
        The ``db_url`` string stored on the module-level DatabaseManager
    """
    active_url = db_manager.db_url
    return active_url
# Convenience decorator for automatic session management
def ensure_db_session(func):
    """
    Decorator to provide database session to service methods.

    When the wrapped coroutine is called without a ``session`` keyword
    argument, a session is opened through the global ``db_manager`` and passed
    in as ``session=``; it is committed or rolled back by
    ``db_manager.get_session()`` when the call finishes.

    When the caller passes ``session=`` explicitly, that session is used as-is
    and no managed session is opened, so the caller keeps control of the
    transaction and the global manager is not touched.

    Usage:
        class MyService:
            @ensure_db_session
            async def my_method(self, session: AsyncSession):
                result = await session.execute(select(MyModel))
                return result.scalar()

    Args:
        func: Async function to wrap. It must accept a ``session`` keyword
            argument.

    Returns:
        Wrapped function with automatic session injection

    Raises:
        RuntimeError: From ``db_manager.get_session()`` when a session has to
            be injected but the database is not initialized.
    """
    from functools import wraps

    @wraps(func)
    async def wrapper(*args, **kwargs):
        # Caller supplied its own session: do not open (and commit) a second one
        if "session" in kwargs:
            return await func(*args, **kwargs)

        async with db_manager.get_session() as session:
            return await func(*args, session=session, **kwargs)

    return wrapper
# Default values for the database configuration keys
DEFAULT_CONFIG = dict(
    # SQLite database file
    DB_PATH="ibkr_data.db",
    # SQL logging; set to "true" to log all SQL queries
    DB_ECHO="false",
)
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
"DatabaseManager",
"db_manager",
"ensure_db_session",
"get_database_url",
"DEFAULT_CONFIG",
]