"""Database engine, session, and migration management for basic-memory."""
import asyncio
import os
from contextlib import asynccontextmanager
from enum import Enum, auto
from pathlib import Path
from typing import AsyncGenerator, Optional
from basic_memory.config import BasicMemoryConfig, ConfigManager
from alembic import command
from alembic.config import Config
from loguru import logger
from sqlalchemy import text, event
from sqlalchemy.ext.asyncio import (
create_async_engine,
async_sessionmaker,
AsyncSession,
AsyncEngine,
async_scoped_session,
)
from sqlalchemy.pool import NullPool
from basic_memory.repository.search_repository import SearchRepository
# Module level state
# Process-wide singletons: the engine and session maker are created once by
# get_or_create_db() and shared; shutdown_db() and engine_session_factory()
# reset them back to None.
_engine: Optional[AsyncEngine] = None
_session_maker: Optional[async_sessionmaker[AsyncSession]] = None
# Tracks whether alembic migrations have already run in this process so that
# run_migrations() can skip repeat invocations (unless force=True).
_migrations_completed: bool = False
class DatabaseType(Enum):
    """Kinds of SQLite databases this module can manage."""

    MEMORY = auto()
    FILESYSTEM = auto()

    @classmethod
    def get_db_url(cls, db_path: Path, db_type: "DatabaseType") -> str:
        """Return the SQLAlchemy (aiosqlite) URL for the given database.

        MEMORY databases get a path-less URL; FILESYSTEM databases embed
        db_path in the URL.
        """
        if db_type is not cls.MEMORY:
            return f"sqlite+aiosqlite:///{db_path}"  # pragma: no cover
        logger.info("Using in-memory SQLite database")
        return "sqlite+aiosqlite://"
def get_scoped_session_factory(
    session_maker: async_sessionmaker[AsyncSession],
) -> async_scoped_session:
    """Return a session factory scoped to the current asyncio task.

    Using asyncio.current_task as the scope function gives each task its
    own session instance from the same underlying maker.
    """
    return async_scoped_session(session_maker, scopefunc=asyncio.current_task)
@asynccontextmanager
async def scoped_session(
    session_maker: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[AsyncSession, None]:
    """Yield a task-scoped session with commit/rollback lifecycle handling.

    Enables SQLite foreign-key enforcement before yielding. Commits on
    normal exit, rolls back (and re-raises) on error, and always closes the
    session and removes it from the task-scoped registry.

    Args:
        session_maker: Session maker to create scoped sessions from
    """
    scoped_factory = get_scoped_session_factory(session_maker)
    db_session = scoped_factory()
    try:
        # SQLite does not enforce foreign keys unless asked per-connection.
        await db_session.execute(text("PRAGMA foreign_keys=ON"))
        yield db_session
        await db_session.commit()
    except Exception:
        await db_session.rollback()
        raise
    finally:
        await db_session.close()
        await scoped_factory.remove()
def _configure_sqlite_connection(dbapi_conn, enable_wal: bool = True) -> None:
"""Configure SQLite connection with WAL mode and optimizations.
Args:
dbapi_conn: Database API connection object
enable_wal: Whether to enable WAL mode (should be False for in-memory databases)
"""
cursor = dbapi_conn.cursor()
try:
# Enable WAL mode for better concurrency (not supported for in-memory databases)
if enable_wal:
cursor.execute("PRAGMA journal_mode=WAL")
# Set busy timeout to handle locked databases
cursor.execute("PRAGMA busy_timeout=10000") # 10 seconds
# Optimize for performance
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA cache_size=-64000") # 64MB cache
cursor.execute("PRAGMA temp_store=MEMORY")
# Windows-specific optimizations
if os.name == "nt":
cursor.execute("PRAGMA locking_mode=NORMAL") # Ensure normal locking on Windows
except Exception as e:
# Log but don't fail - some PRAGMAs may not be supported
logger.warning(f"Failed to configure SQLite connection: {e}")
finally:
cursor.close()
def _create_engine_and_session(
    db_path: Path, db_type: DatabaseType = DatabaseType.FILESYSTEM
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
    """Internal helper to create engine and session maker."""
    db_url = DatabaseType.get_db_url(db_path, db_type)
    logger.debug(f"Creating engine for db_url: {db_url}")

    # check_same_thread=False allows the async driver to use the connection
    # from multiple threads.
    connect_args: dict[str, bool | float | None] = {"check_same_thread": False}
    engine_kwargs: dict = {"connect_args": connect_args}

    if os.name == "nt":  # Windows
        # Windows-specific parameters to improve reliability
        connect_args["timeout"] = 30.0  # Increase timeout to 30 seconds for Windows
        connect_args["isolation_level"] = None  # Use autocommit mode
        # Use NullPool for Windows filesystem databases to avoid connection
        # pooling issues. Important: do NOT use NullPool for in-memory
        # databases, as it would destroy the database between connections.
        if db_type == DatabaseType.FILESYSTEM:
            engine_kwargs["poolclass"] = NullPool  # Disable connection pooling on Windows
            engine_kwargs["echo"] = False

    engine = create_async_engine(db_url, **engine_kwargs)

    # Enable WAL mode for better concurrency and reliability.
    # Note: WAL mode is not supported for in-memory databases.
    enable_wal = db_type != DatabaseType.MEMORY

    @event.listens_for(engine.sync_engine, "connect")
    def enable_wal_mode(dbapi_conn, connection_record):
        """Apply per-connection SQLite PRAGMA configuration."""
        _configure_sqlite_connection(dbapi_conn, enable_wal=enable_wal)

    session_maker = async_sessionmaker(engine, expire_on_commit=False)
    return engine, session_maker
async def get_or_create_db(
    db_path: Path,
    db_type: DatabaseType = DatabaseType.FILESYSTEM,
    ensure_migrations: bool = True,
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:  # pragma: no cover
    """Get or create database engine and session maker.

    Lazily initializes the module-level engine/session maker singletons on
    first call; later calls return the cached pair without re-running
    migrations.

    Args:
        db_path: Path to the SQLite database file.
        db_type: Kind of database to create (filesystem or in-memory).
        ensure_migrations: When True, run alembic migrations right after the
            engine is first created.

    Raises:
        RuntimeError: If the engine or session maker is unexpectedly None
            after creation.
    """
    global _engine, _session_maker

    if _engine is None:
        _engine, _session_maker = _create_engine_and_session(db_path, db_type)

        # Run migrations automatically unless explicitly disabled
        if ensure_migrations:
            app_config = ConfigManager().config
            await run_migrations(app_config, db_type)

    # These checks should never fail since we just created the engine and session maker
    # if they were None, but we'll check anyway for the type checker
    if _engine is None:
        logger.error("Failed to create database engine", db_path=str(db_path))
        raise RuntimeError("Database engine initialization failed")

    if _session_maker is None:
        logger.error("Failed to create session maker", db_path=str(db_path))
        raise RuntimeError("Session maker initialization failed")

    return _engine, _session_maker
async def shutdown_db() -> None:  # pragma: no cover
    """Dispose the engine (if any) and reset module-level database state."""
    global _engine, _session_maker, _migrations_completed

    if _engine is not None:
        await _engine.dispose()
        # Reset the singletons so a later get_or_create_db() starts fresh.
        _engine = None
        _session_maker = None
        _migrations_completed = False
@asynccontextmanager
async def engine_session_factory(
    db_path: Path,
    db_type: DatabaseType = DatabaseType.MEMORY,
) -> AsyncGenerator[tuple[AsyncEngine, async_sessionmaker[AsyncSession]], None]:
    """Create engine and session factory.

    Note: This is primarily used for testing where we want a fresh database
    for each test. For production use, use get_or_create_db() instead.

    Args:
        db_path: Path to the SQLite database file (unused in the URL for
            MEMORY databases).
        db_type: Kind of database to create; defaults to in-memory.

    Yields:
        Tuple of (engine, session maker). The engine is disposed and all
        module-level state reset when the context exits.
    """
    global _engine, _session_maker, _migrations_completed

    # Delegate to the shared helper instead of duplicating the connection
    # args, Windows/NullPool branching, and WAL listener setup that
    # _create_engine_and_session already implements.
    _engine, _session_maker = _create_engine_and_session(db_path, db_type)

    try:
        # Verify that engine and session maker are initialized; these should
        # never trigger, but keep the type checker satisfied.
        if _engine is None:  # pragma: no cover
            logger.error("Database engine is None in engine_session_factory")
            raise RuntimeError("Database engine initialization failed")

        if _session_maker is None:  # pragma: no cover
            logger.error("Session maker is None in engine_session_factory")
            raise RuntimeError("Session maker initialization failed")

        yield _engine, _session_maker
    finally:
        # Always dispose and reset so the next test starts from a clean slate.
        if _engine:
            await _engine.dispose()
            _engine = None
            _session_maker = None
            _migrations_completed = False
async def run_migrations(
    app_config: BasicMemoryConfig, database_type=DatabaseType.FILESYSTEM, force: bool = False
):  # pragma: no cover
    """Run any pending alembic migrations.

    Also (re)initializes the search index schema afterwards. Completion is
    recorded in the module-level _migrations_completed flag so repeat calls
    are no-ops unless force=True.

    Args:
        app_config: Config providing the database path to migrate.
        database_type: Kind of database being migrated.
        force: When True, run migrations even if already completed in this
            process.

    Raises:
        Exception: Re-raises whatever alembic or the search-index setup raised.
    """
    global _migrations_completed

    # Skip if migrations already completed unless forced
    if _migrations_completed and not force:
        logger.debug("Migrations already completed in this session, skipping")
        return

    logger.info("Running database migrations...")
    try:
        # Get the absolute path to the alembic directory relative to this file
        alembic_dir = Path(__file__).parent / "alembic"
        config = Config()

        # Set required Alembic config options programmatically
        config.set_main_option("script_location", str(alembic_dir))
        config.set_main_option(
            "file_template",
            "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s",
        )
        config.set_main_option("timezone", "UTC")
        config.set_main_option("revision_environment", "false")
        config.set_main_option(
            "sqlalchemy.url", DatabaseType.get_db_url(app_config.database_path, database_type)
        )

        command.upgrade(config, "head")
        logger.info("Migrations completed successfully")

        # Get session maker - ensure we don't trigger recursive migration calls
        if _session_maker is None:
            # NOTE(review): the engine created here is discarded and never
            # disposed — presumably acceptable for a one-off migration pass,
            # but worth confirming.
            _, session_maker = _create_engine_and_session(app_config.database_path, database_type)
        else:
            session_maker = _session_maker

        # initialize the search Index schema
        # the project_id is not used for init_search_index, so we pass a dummy value
        await SearchRepository(session_maker, 1).init_search_index()

        # Mark migrations as completed
        _migrations_completed = True
    except Exception as e:  # pragma: no cover
        logger.error(f"Error running migrations: {e}")
        raise