"""Async database session management for sync service.
Provides async SQLAlchemy session factory for use with FastAPI.
"""
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from service.db.models import Base
logger = logging.getLogger(__name__)
# Module-level engine and session factory
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
def get_database_url() -> str:
"""Get async database URL from environment.
Converts postgresql:// to postgresql+asyncpg:// for async driver.
"""
url = os.environ.get(
"CONTEXTFS_POSTGRES_URL",
"postgresql://contextfs:contextfs@localhost:5433/contextfs_sync",
)
# Convert to async driver URL
if url.startswith("postgresql://"):
url = url.replace("postgresql://", "postgresql+asyncpg://", 1)
elif url.startswith("postgres://"):
url = url.replace("postgres://", "postgresql+asyncpg://", 1)
return url
async def init_db(url: str | None = None) -> AsyncEngine:
"""Initialize database engine and session factory.
Args:
url: Database URL (uses environment variable if not provided)
Returns:
Async database engine
"""
global _engine, _session_factory
if _engine is None:
db_url = url or get_database_url()
_engine = create_async_engine(
db_url,
echo=os.environ.get("CONTEXTFS_DB_ECHO", "").lower() == "true",
pool_size=5,
max_overflow=10,
pool_pre_ping=True,
)
_session_factory = async_sessionmaker(
bind=_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
return _engine
async def create_tables() -> None:
"""Create all database tables.
NOTE: For production, use Alembic migrations instead.
This function is kept for testing/development only.
"""
global _engine
if _engine is None:
await init_db()
async with _engine.begin() as conn:
# Enable required PostgreSQL extensions
await conn.execute(
__import__("sqlalchemy").text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
)
await conn.execute(
__import__("sqlalchemy").text('CREATE EXTENSION IF NOT EXISTS "pg_trgm";')
)
await conn.execute(
__import__("sqlalchemy").text('CREATE EXTENSION IF NOT EXISTS "vector";')
)
# Create tables
await conn.run_sync(Base.metadata.create_all)
async def drop_tables() -> None:
"""Drop all database tables (use with caution!)."""
global _engine
if _engine is None:
await init_db()
async with _engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@asynccontextmanager
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""Get async database session.
Usage:
async with get_session() as session:
result = await session.execute(query)
await session.commit()
"""
global _session_factory
if _session_factory is None:
await init_db()
async with _session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
async def get_session_dependency() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI dependency for database session.
Usage:
@app.get("/items")
async def get_items(session: AsyncSession = Depends(get_session_dependency)):
...
"""
async with get_session() as session:
yield session
async def close_db() -> None:
"""Close database engine and cleanup connections."""
global _engine, _session_factory
if _engine is not None:
await _engine.dispose()
_engine = None
_session_factory = None