"""
IRIS Database Connection and Session Management
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import StaticPool
from contextlib import contextmanager
from typing import Generator
import logging
from src.config.settings import settings
from .models import Base
logger = logging.getLogger(__name__)


def _engine_kwargs(database_url) -> dict:
    """Return the ``create_engine()`` keyword arguments suited to *database_url*.

    ``StaticPool`` keeps exactly one DBAPI connection and hands that same
    connection to every checkout. That is required for an in-memory SQLite
    database, where each new connection would otherwise open a brand-new,
    empty database. For any other URL it would make all concurrent sessions
    share a single connection, so those URLs use the dialect's default pool.
    """
    from sqlalchemy.engine import make_url

    url = make_url(database_url)
    kwargs = {
        "pool_pre_ping": True,
        "pool_recycle": 300,
    }
    is_sqlite_memory = url.get_backend_name() == "sqlite" and (
        url.database in (None, "", ":memory:") or url.query.get("mode") == "memory"
    )
    if is_sqlite_memory:
        kwargs["poolclass"] = StaticPool
        # The single shared connection may be used from several threads
        # (e.g. FastAPI's threadpool for sync dependencies), so disable
        # sqlite3's same-thread check as SQLAlchemy's docs recommend.
        kwargs["connect_args"] = {"check_same_thread": False}
    return kwargs


# Database engine
engine = create_engine(
    settings.database_url,
    echo=settings.debug,  # Log SQL queries in debug mode
    **_engine_kwargs(settings.database_url),
)
# Session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def create_tables():
    """Create every table registered on ``Base.metadata`` using the module engine.

    Success is logged at INFO level. Any exception is logged at ERROR level
    and then re-raised unchanged.
    """
    try:
        metadata = Base.metadata
        metadata.create_all(bind=engine)
        logger.info("Database tables created successfully")
    except Exception as exc:
        logger.error(f"Failed to create database tables: {exc}")
        raise
def drop_tables():
    """Drop every table registered on ``Base.metadata`` (destructive; use with caution!).

    A successful drop is logged at WARNING level. Any exception is logged at
    ERROR level and then re-raised unchanged.
    """
    try:
        schema = Base.metadata
        schema.drop_all(bind=engine)
        logger.warning("All database tables dropped")
    except Exception as exc:
        logger.error(f"Failed to drop database tables: {exc}")
        raise
@contextmanager
def get_db_session() -> Generator[Session, None, None]:
    """Provide a transactional session scope for a ``with`` block.

    The session is committed when the ``with`` body finishes without error.
    If the body (or the commit) raises an ``Exception``, the session is rolled
    back, the error is logged, and the exception is re-raised. The session is
    closed on every exit path.
    """
    db = SessionLocal()
    try:
        try:
            yield db
            db.commit()
        except Exception as exc:
            db.rollback()
            logger.error(f"Database session error: {exc}")
            raise
    finally:
        db.close()
def get_db() -> Generator[Session, None, None]:
    """Yield a fresh session for use as a FastAPI dependency.

    No commit or rollback happens here; callers manage their own
    transaction. The session is closed once the generator is finalized.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
class DatabaseManager:
    """Database management utilities.

    Every method is static. Methods that touch the database open their own
    transactional scope via ``get_db_session()``.
    """

    @staticmethod
    def health_check() -> bool:
        """Check database connectivity by executing ``SELECT 1``.

        Returns:
            True if the query succeeded, False otherwise (the failure is logged).
        """
        try:
            from sqlalchemy import text
            with get_db_session() as db:
                db.execute(text("SELECT 1"))
                return True
        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            return False

    @staticmethod
    def get_stats() -> dict:
        """Return row counts for the main tables.

        Returns:
            A dict with the keys ``users``, ``sessions``, ``messages``, ``files``,
            ``logs``, ``active_users`` and ``active_sessions``. Returns an empty
            dict if anything fails (the failure is logged).
        """
        try:
            with get_db_session() as db:
                # Alias the model so it does not shadow sqlalchemy.orm.Session,
                # which this module uses in type annotations.
                from .models import File, Message, SystemLog, User
                from .models import Session as SessionModel
                stats = {
                    "users": db.query(User).count(),
                    "sessions": db.query(SessionModel).count(),
                    "messages": db.query(Message).count(),
                    "files": db.query(File).count(),
                    "logs": db.query(SystemLog).count(),
                    # `== True` is intentional: it builds a SQL expression, not a Python bool.
                    "active_users": db.query(User).filter(User.is_active == True).count(),  # noqa: E712
                    "active_sessions": db.query(SessionModel).filter(SessionModel.state == "active").count(),
                }
                return stats
        except Exception as e:
            logger.error(f"Failed to get database stats: {e}")
            return {}

    @staticmethod
    def cleanup_old_data(days: int = 30) -> dict:
        """Delete system logs and non-active sessions older than ``days`` days.

        Args:
            days: Age threshold in days. Rows older than ``now (UTC) - days`` are
                deleted. Must be non-negative.

        Returns:
            ``{"logs": <deleted log rows>, "sessions": <deleted session rows>}``.

        Raises:
            ValueError: If ``days`` is negative.
            Exception: Any database error is logged and re-raised. The
                transaction is rolled back by ``get_db_session``.
        """
        if days < 0:
            # A negative window puts the cutoff in the future, which would
            # delete every log row regardless of age.
            raise ValueError(f"days must be non-negative, got {days}")
        try:
            from datetime import datetime, timedelta, timezone
            from .models import SystemLog
            from .models import Session as SessionModel

            # Naive UTC timestamp: the same value the deprecated
            # datetime.utcnow() produced, so comparisons against existing
            # naive columns are unchanged.
            now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
            cutoff_date = now_utc - timedelta(days=days)
            with get_db_session() as db:
                # Query.delete() returns the number of matched rows, so no
                # separate COUNT is needed and the logged figure is what was deleted.
                deleted_logs = (
                    db.query(SystemLog)
                    .filter(SystemLog.created_at < cutoff_date)
                    .delete()
                )
                # NOTE(review): rows whose last_message_at is NULL never satisfy
                # `<` and are kept. This is also a bulk DELETE that bypasses ORM
                # relationship cascades; confirm the DB-level ON DELETE behaviour
                # for rows that reference sessions.
                deleted_sessions = (
                    db.query(SessionModel)
                    .filter(
                        SessionModel.last_message_at < cutoff_date,
                        SessionModel.state != "active",
                    )
                    .delete()
                )
            # Logged only after get_db_session has committed successfully.
            logger.info(f"Cleaned up {deleted_logs} old logs and {deleted_sessions} old sessions")
            return {"logs": deleted_logs, "sessions": deleted_sessions}
        except Exception as e:
            logger.error(f"Failed to cleanup old data: {e}")
            raise