# init_db.py (3.74 kB)
"""Database initialization and setup utilities."""
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from ..utils.errors import DatabaseError
from .models.base import Base
# Connection URL: read from the DATABASE_URL environment variable, falling
# back to a SQLite file named mcp_server.db (the three-slash form is a path
# relative to the current working directory).
DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:///mcp_server.db")
# Create engine with appropriate configuration for the target backend
def create_db_engine(database_url=None, echo=None):
    """Create a database engine configured for the backend in the URL.

    The backend is selected from the URL scheme alone. SQLite-specific
    connect arguments and PRAGMAs are applied only when the URL really points
    at SQLite. Applying them to another driver (e.g. when ``TESTING=true``
    is combined with a PostgreSQL URL) makes the driver reject the unknown
    ``check_same_thread`` option and sends ``PRAGMA`` to a server that does
    not understand it.

    Args:
        database_url: URL to connect to. Defaults to the module-level
            ``DATABASE_URL``.
        echo: Override SQLAlchemy statement echoing. ``None`` keeps the
            per-backend default (off for SQLite, on otherwise).

    Returns:
        A configured SQLAlchemy ``Engine``.
    """
    url = DATABASE_URL if database_url is None else database_url

    if url.startswith("sqlite"):
        # SQLite configuration (including tests). check_same_thread=False
        # allows a connection to be used from a thread other than the one
        # that opened it.
        engine = create_engine(
            url,
            echo=False if echo is None else echo,
            connect_args={"check_same_thread": False},
        )

        # SQLite leaves foreign-key enforcement off unless it is enabled on
        # each new connection, so turn it on at connect time.
        from sqlalchemy import event

        @event.listens_for(engine, "connect")
        def set_sqlite_pragma(dbapi_connection, connection_record):
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("PRAGMA foreign_keys=ON")
            finally:
                cursor.close()

        return engine

    # Production PostgreSQL configuration with a bounded connection pool.
    return create_engine(
        url,
        # NOTE(review): echo=True logs every SQL statement; consider passing
        # echo=False in production.
        echo=True if echo is None else echo,
        pool_size=20,
        max_overflow=2,
        pool_timeout=30,  # seconds to wait for a free pooled connection
        pool_recycle=3600,  # replace connections older than one hour
    )
# Single module-level engine, shared by every session the factory creates.
engine = create_db_engine()

# Session factory bound to that engine. autoflush=False: pending changes are
# not flushed automatically before queries. autocommit=False: work happens
# inside a transaction that must be committed explicitly.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
def init_db(force: bool = False) -> None:
    """Initialize the database, creating all tables.

    Args:
        force: If True, drop existing tables before creation. Only allowed
            when the ``TESTING`` environment variable is "true".

    Raises:
        DatabaseError: If a required model class is not registered, if
            ``force`` is requested outside testing, or if creating the
            schema fails. Other underlying exceptions are wrapped in a
            ``DatabaseError`` and chained as its ``__cause__``.
    """
    try:
        # Importing the model modules registers their mapped classes on Base.
        # Only the import side effect matters; the names are unused.
        from .models import (  # noqa: F401
            entities,
            relationships,
            observations,
            providers,
            arguments,
            ansible,
            parameters,
        )

        # These are mapped *class* names. Base.metadata.tables is keyed by
        # table name (__tablename__), which need not equal the class name,
        # so compare against the classes registered with Base's registry.
        required_models = {
            "Entity", "Relationship", "Observation",
            "Provider", "Argument", "AnsibleCollection",
            "Parameter",
        }
        registered_models = {
            mapper.class_.__name__ for mapper in Base.registry.mappers
        }
        missing_models = required_models - registered_models
        if missing_models:
            raise DatabaseError(
                message="Missing required models",
                details={"missing_models": sorted(missing_models)},
            )

        if force:
            # Dropping tables is destructive; refuse unless in a test run.
            if os.getenv("TESTING", "").lower() != "true":
                raise DatabaseError(
                    message="Cannot force drop tables outside of testing environment"
                )
            Base.metadata.drop_all(bind=engine)

        Base.metadata.create_all(bind=engine)
    except DatabaseError:
        # Already specific; re-wrapping would hide its message and details.
        raise
    except Exception as e:
        raise DatabaseError(
            message=f"Failed to initialize database: {e}",
            details={"error": str(e)},
        ) from e
def get_db():
    """Yield a database session, committing on success and always closing it.

    Meant to be driven as a generator dependency. The caller receives the
    session at ``yield``. If the generator is resumed normally, the session
    is committed. If an exception is thrown into it, the session is rolled
    back.

    Yields:
        A session created by ``SessionLocal``.

    Raises:
        DatabaseError: When a SQLAlchemy error occurs (including during
            commit). The original error is chained as ``__cause__``.
            Any other exception, including one raised by the caller's own
            code, is re-raised unchanged after rollback rather than being
            relabelled as a database failure.
    """
    from sqlalchemy import text
    from sqlalchemy.exc import SQLAlchemyError

    db = SessionLocal()
    try:
        # Per-session server timeouts. These are PostgreSQL settings; this
        # assumes every non-SQLite backend is PostgreSQL.
        if db.get_bind().dialect.name != "sqlite":
            db.execute(text("SET statement_timeout = 10000"))  # 10s
            db.execute(text("SET idle_in_transaction_session_timeout = 60000"))  # 60s
        yield db
        db.commit()
    except SQLAlchemyError as e:
        db.rollback()
        raise DatabaseError(f"Database operation failed: {str(e)}") from e
    except Exception:
        # Not a database failure: undo pending work but keep the original
        # exception type so callers/frameworks can handle it as intended.
        db.rollback()
        raise
    finally:
        db.close()