db.py•11.7 kB
import asyncio
import os
from contextlib import asynccontextmanager
from enum import Enum, auto
from pathlib import Path
from typing import AsyncGenerator, Optional
from basic_memory.config import BasicMemoryConfig, ConfigManager
from alembic import command
from alembic.config import Config
from loguru import logger
from sqlalchemy import text, event
from sqlalchemy.ext.asyncio import (
    create_async_engine,
    async_sessionmaker,
    AsyncSession,
    AsyncEngine,
    async_scoped_session,
)
from sqlalchemy.pool import NullPool
from basic_memory.repository.search_repository import SearchRepository
# Module level state
# Process-wide engine and session maker. Populated by get_or_create_db() or
# engine_session_factory(); cleared by shutdown_db() and when the
# engine_session_factory() context exits.
_engine: Optional[AsyncEngine] = None
_session_maker: Optional[async_sessionmaker[AsyncSession]] = None
# Set by run_migrations() after a successful run so later calls skip the work
# (unless force=True); reset to False whenever the module engine is disposed.
_migrations_completed: bool = False
class DatabaseType(Enum):
    """Kinds of SQLite database the application can run against."""

    # Transient database that lives only as long as its connection(s)
    MEMORY = auto()
    # On-disk database file at a given path
    FILESYSTEM = auto()

    @classmethod
    def get_db_url(cls, db_path: Path, db_type: "DatabaseType") -> str:
        """Build the SQLAlchemy (aiosqlite) URL for a database.

        Args:
            db_path: Location of the database file; ignored for MEMORY.
            db_type: Which kind of database to address.

        Returns:
            A ``sqlite+aiosqlite`` URL string.
        """
        if db_type != cls.MEMORY:
            return f"sqlite+aiosqlite:///{db_path}"  # pragma: no cover
        logger.info("Using in-memory SQLite database")
        return "sqlite+aiosqlite://"
def get_scoped_session_factory(
    session_maker: async_sessionmaker[AsyncSession],
) -> async_scoped_session:
    """Wrap ``session_maker`` in a registry keyed by the running asyncio task.

    Args:
        session_maker: Factory used to create the underlying sessions.

    Returns:
        An ``async_scoped_session`` whose scope is ``asyncio.current_task``.
    """
    task_scope = asyncio.current_task
    registry = async_scoped_session(session_maker, scopefunc=task_scope)
    return registry
@asynccontextmanager
async def scoped_session(
    session_maker: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[AsyncSession, None]:
    """
    Provide a task-scoped session that commits on success and rolls back on error.

    Foreign-key enforcement is switched on before the session is handed out.
    The session is always closed and removed from its registry on exit.

    Args:
        session_maker: Session maker to create scoped sessions from
    """
    registry = get_scoped_session_factory(session_maker)
    db_session = registry()
    try:
        # SQLite leaves FK enforcement off unless enabled per connection
        await db_session.execute(text("PRAGMA foreign_keys=ON"))
        yield db_session
        # The commit stays inside the try so a failing commit is also rolled back
        await db_session.commit()
    except Exception:
        await db_session.rollback()
        raise
    finally:
        await db_session.close()
        await registry.remove()
def _configure_sqlite_connection(dbapi_conn, enable_wal: bool = True) -> None:
    """Configure SQLite connection with WAL mode and optimizations.
    Args:
        dbapi_conn: Database API connection object
        enable_wal: Whether to enable WAL mode (should be False for in-memory databases)
    """
    cursor = dbapi_conn.cursor()
    try:
        # Enable WAL mode for better concurrency (not supported for in-memory databases)
        if enable_wal:
            cursor.execute("PRAGMA journal_mode=WAL")
        # Set busy timeout to handle locked databases
        cursor.execute("PRAGMA busy_timeout=10000")  # 10 seconds
        # Optimize for performance
        cursor.execute("PRAGMA synchronous=NORMAL")
        cursor.execute("PRAGMA cache_size=-64000")  # 64MB cache
        cursor.execute("PRAGMA temp_store=MEMORY")
        # Windows-specific optimizations
        if os.name == "nt":
            cursor.execute("PRAGMA locking_mode=NORMAL")  # Ensure normal locking on Windows
    except Exception as e:
        # Log but don't fail - some PRAGMAs may not be supported
        logger.warning(f"Failed to configure SQLite connection: {e}")
    finally:
        cursor.close()
def _create_engine_and_session(
    db_path: Path, db_type: DatabaseType = DatabaseType.FILESYSTEM
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
    """Build an async engine plus a matching session maker.

    On Windows, the connection gets a longer timeout and autocommit isolation.
    Filesystem databases on Windows also use NullPool. Every new DBAPI
    connection is configured via ``_configure_sqlite_connection``, with WAL
    enabled only for non-memory databases.
    """
    db_url = DatabaseType.get_db_url(db_path, db_type)
    logger.debug(f"Creating engine for db_url: {db_url}")

    connect_args: dict[str, bool | float | None] = {"check_same_thread": False}
    engine_options: dict = {}
    if os.name == "nt":
        connect_args["timeout"] = 30.0  # Longer lock timeout for Windows
        connect_args["isolation_level"] = None  # Autocommit mode
        # NullPool avoids pooling problems for Windows file databases. It must
        # NOT be used for in-memory databases: their state would be lost
        # between connections.
        if db_type == DatabaseType.FILESYSTEM:
            engine_options = {"poolclass": NullPool, "echo": False}

    engine = create_async_engine(db_url, connect_args=connect_args, **engine_options)

    # WAL is not supported for in-memory databases
    wal_enabled = db_type != DatabaseType.MEMORY

    @event.listens_for(engine.sync_engine, "connect")
    def enable_wal_mode(dbapi_conn, connection_record):
        """Configure every freshly opened DBAPI connection."""
        _configure_sqlite_connection(dbapi_conn, enable_wal=wal_enabled)

    return engine, async_sessionmaker(engine, expire_on_commit=False)
async def get_or_create_db(
    db_path: Path,
    db_type: DatabaseType = DatabaseType.FILESYSTEM,
    ensure_migrations: bool = True,
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:  # pragma: no cover
    """Get or create the module-level database engine and session maker.

    The first call creates the engine and, unless ``ensure_migrations`` is
    False, runs migrations. Later calls return the cached pair and ignore
    their arguments.

    Raises:
        RuntimeError: If the engine or session maker is unexpectedly missing.
        Exception: Anything raised by ``run_migrations``. In that case the new
            engine is disposed and module state is reset, so a later call
            starts over instead of reusing an un-migrated engine.
    """
    global _engine, _session_maker
    if _engine is None:
        engine, session_maker = _create_engine_and_session(db_path, db_type)
        # Publish before migrating: run_migrations reuses _session_maker if set
        _engine, _session_maker = engine, session_maker
        # Run migrations automatically unless explicitly disabled
        if ensure_migrations:
            try:
                app_config = ConfigManager().config
                await run_migrations(app_config, db_type)
            except Exception:
                # Without this reset, the next call would see _engine set and
                # skip migrations, handing out a half-initialized database.
                await engine.dispose()
                _engine = None
                _session_maker = None
                raise
    # These checks should never fail since we just created the engine and session maker
    # if they were None, but we'll check anyway for the type checker
    if _engine is None:
        logger.error("Failed to create database engine", db_path=str(db_path))
        raise RuntimeError("Database engine initialization failed")
    if _session_maker is None:
        logger.error("Failed to create session maker", db_path=str(db_path))
        raise RuntimeError("Session maker initialization failed")
    return _engine, _session_maker
async def shutdown_db() -> None:  # pragma: no cover
    """Dispose the module engine (if any) and reset all module-level state."""
    global _engine, _session_maker, _migrations_completed
    if not _engine:
        return
    await _engine.dispose()
    _engine = None
    _session_maker = None
    _migrations_completed = False
@asynccontextmanager
async def engine_session_factory(
    db_path: Path,
    db_type: DatabaseType = DatabaseType.MEMORY,
) -> AsyncGenerator[tuple[AsyncEngine, async_sessionmaker[AsyncSession]], None]:
    """Create engine and session factory.

    Note: This is primarily used for testing where we want a fresh database
    for each test. For production use, use get_or_create_db() instead.

    The engine and session maker are stored in the module globals for the
    duration of the context. On exit the engine is disposed and the globals
    (including the migrations flag) are reset.
    """
    global _engine, _session_maker, _migrations_completed
    # Delegate to the shared builder so test and production engines get the
    # same connect args, pooling and per-connection PRAGMA setup.
    _engine, _session_maker = _create_engine_and_session(db_path, db_type)
    try:
        yield _engine, _session_maker
    finally:
        if _engine:
            await _engine.dispose()
            _engine = None
            _session_maker = None
            _migrations_completed = False
async def run_migrations(
    app_config: BasicMemoryConfig, database_type=DatabaseType.FILESYSTEM, force: bool = False
):  # pragma: no cover
    """Run any pending alembic migrations, then initialize the search index.

    Args:
        app_config: Application config; its ``database_path`` selects the database.
        database_type: Database type used to build the SQLAlchemy URL.
        force: Run even if migrations already completed in this process.

    Raises:
        Exception: Any error from alembic or search-index initialization is
            logged and re-raised.
    """
    global _migrations_completed
    # Skip if migrations already completed unless forced
    if _migrations_completed and not force:
        logger.debug("Migrations already completed in this session, skipping")
        return
    logger.info("Running database migrations...")
    try:
        # Get the absolute path to the alembic directory relative to this file
        alembic_dir = Path(__file__).parent / "alembic"
        config = Config()
        # Set required Alembic config options programmatically
        config.set_main_option("script_location", str(alembic_dir))
        config.set_main_option(
            "file_template",
            "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s",
        )
        config.set_main_option("timezone", "UTC")
        config.set_main_option("revision_environment", "false")
        config.set_main_option(
            "sqlalchemy.url", DatabaseType.get_db_url(app_config.database_path, database_type)
        )
        command.upgrade(config, "head")
        logger.info("Migrations completed successfully")

        # Reuse the module session maker when one exists. This avoids
        # get_or_create_db(), which would recurse into migrations. Otherwise,
        # build a temporary engine and dispose it afterwards so its connection
        # pool is not leaked.
        temp_engine: Optional[AsyncEngine] = None
        if _session_maker is None:
            temp_engine, session_maker = _create_engine_and_session(
                app_config.database_path, database_type
            )
        else:
            session_maker = _session_maker
        try:
            # initialize the search Index schema
            # the project_id is not used for init_search_index, so we pass a dummy value
            await SearchRepository(session_maker, 1).init_search_index()
        finally:
            if temp_engine is not None:
                await temp_engine.dispose()
        # Mark migrations as completed
        _migrations_completed = True
    except Exception as e:  # pragma: no cover
        logger.error(f"Error running migrations: {e}")
        raise