# conftest.py (15.3 kB)
"""Common test fixtures."""
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from textwrap import dedent
from typing import AsyncGenerator
import os
import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
from basic_memory import db
from basic_memory.config import ProjectConfig, BasicMemoryConfig, ConfigManager
from basic_memory.db import DatabaseType
from basic_memory.markdown import EntityParser
from basic_memory.markdown.markdown_processor import MarkdownProcessor
from basic_memory.models import Base
from basic_memory.models.knowledge import Entity
from basic_memory.models.project import Project
from basic_memory.repository.entity_repository import EntityRepository
from basic_memory.repository.observation_repository import ObservationRepository
from basic_memory.repository.project_repository import ProjectRepository
from basic_memory.repository.relation_repository import RelationRepository
from basic_memory.repository.search_repository import SearchRepository
from basic_memory.schemas.base import Entity as EntitySchema
from basic_memory.services import (
    EntityService,
    ProjectService,
)
from basic_memory.services.directory_service import DirectoryService
from basic_memory.services.file_service import FileService
from basic_memory.services.link_resolver import LinkResolver
from basic_memory.services.search_service import SearchService
from basic_memory.sync.sync_service import SyncService
from basic_memory.sync.watch_service import WatchService
@pytest.fixture
def anyio_backend():
    """Return the async backend name ("asyncio") for anyio-driven tests."""
    backend_name = "asyncio"
    return backend_name
@pytest.fixture
def project_root() -> Path:
    """Return the parent of the directory that contains this conftest file."""
    this_file = Path(__file__)
    return this_file.parent.parent
@pytest.fixture
def config_home(tmp_path, monkeypatch) -> Path:
    """Redirect the user's home directory to ``tmp_path`` for one test.

    HOME is pointed at ``tmp_path``; on Windows USERPROFILE is pointed there
    as well. BASIC_MEMORY_HOME becomes ``<tmp_path>/basic-memory``. The
    variables are set through ``monkeypatch``, so pytest restores them after
    the test.

    Returns:
        The ``tmp_path`` directory now acting as HOME.
    """
    home_str = str(tmp_path)
    monkeypatch.setenv("HOME", home_str)
    if os.name == "nt":
        # Windows resolves the home directory from USERPROFILE rather than HOME.
        monkeypatch.setenv("USERPROFILE", home_str)
    monkeypatch.setenv("BASIC_MEMORY_HOME", str(tmp_path / "basic-memory"))
    return tmp_path
@pytest.fixture(scope="function", autouse=True)
def app_config(config_home, tmp_path, monkeypatch) -> BasicMemoryConfig:
    """Build the BasicMemoryConfig that every test uses.

    A single "test-project" rooted at ``config_home`` is registered and made
    the default. The project map is written out here instead of being derived
    from the ``test_project`` fixture, because ``test_project`` itself depends
    on this fixture (through ``engine_factory``), which would form a cycle.

    ``tmp_path`` and ``monkeypatch`` are requested but not used in the body.
    """
    return BasicMemoryConfig(
        env="test",
        projects={"test-project": str(config_home)},
        default_project="test-project",
        update_permalinks_on_move=True,
    )
@pytest.fixture(autouse=True)
def config_manager(
    app_config: BasicMemoryConfig, project_config: ProjectConfig, config_home: Path, monkeypatch
) -> ConfigManager:
    """Return a ConfigManager whose config directory lives under ``config_home``.

    The fixture first clears the module-level config cache. It then points the
    manager at ``<config_home>/.basic-memory/config.json`` and writes
    ``app_config`` to that file.

    ``project_config`` and ``monkeypatch`` are not used in the body. Requesting
    ``project_config`` still forces it, and ``test_project``, to be set up first.
    """
    from basic_memory import config as config_module

    # Drop any cached config so state from an earlier test cannot leak in.
    config_module._CONFIG_CACHE = None

    manager = ConfigManager()
    manager.config_dir = config_home / ".basic-memory"
    manager.config_file = manager.config_dir / "config.json"
    manager.config_dir.mkdir(parents=True, exist_ok=True)

    # Persist the test configuration so that readers of config.json see it.
    manager.save_config(app_config)
    return manager
@pytest.fixture(scope="function", autouse=True)
def project_config(test_project):
    """Build a ProjectConfig with the same name and path as ``test_project``."""
    return ProjectConfig(name=test_project.name, home=Path(test_project.path))
@dataclass
class TestConfig:
    """Bundle of the configuration fixtures, returned by the ``test_config`` fixture."""

    # The "Test" prefix makes pytest try to collect this class as a test class
    # wherever it is imported into a test module. It has an __init__, so that
    # attempt only produces a collection warning. Opt out explicitly. This is
    # not an annotated attribute, so dataclass does not turn it into a field.
    __test__ = False

    config_home: Path  # temporary directory acting as HOME for the test
    project_config: ProjectConfig  # config for the "test-project" project
    app_config: BasicMemoryConfig  # application-wide test configuration
    config_manager: ConfigManager  # manager that writes config under config_home
@pytest.fixture
def test_config(config_home, project_config, app_config, config_manager) -> TestConfig:
    """Group the individual configuration fixtures into one TestConfig."""
    return TestConfig(
        config_home=config_home,
        project_config=project_config,
        app_config=app_config,
        config_manager=config_manager,
    )
@pytest_asyncio.fixture(scope="function")
async def engine_factory(
    app_config,
) -> AsyncGenerator[tuple[AsyncEngine, async_sessionmaker[AsyncSession]], None]:
    """Yield an ``(engine, session_maker)`` pair backed by an in-memory SQLite database.

    All tables in ``Base.metadata`` are created before the pair is yielded. The
    ``engine_session_factory`` context is exited when the fixture is torn down.
    """
    engine_ctx = db.engine_session_factory(
        db_path=app_config.database_path,
        db_type=DatabaseType.MEMORY,
    )
    async with engine_ctx as (engine, session_maker):
        async with engine.begin() as connection:
            # Create every mapped table in the database this engine points to.
            await connection.run_sync(Base.metadata.create_all)
        yield engine, session_maker
@pytest_asyncio.fixture
async def session_maker(engine_factory) -> async_sessionmaker[AsyncSession]:
    """Return only the session-factory half of the ``engine_factory`` pair."""
    return engine_factory[1]
## Repositories
@pytest_asyncio.fixture(scope="function")
async def entity_repository(
    session_maker: async_sessionmaker[AsyncSession], test_project: Project
) -> EntityRepository:
    """Return an EntityRepository scoped to the ``test_project`` id."""
    project_id = test_project.id
    return EntityRepository(session_maker, project_id=project_id)
@pytest_asyncio.fixture(scope="function")
async def observation_repository(
    session_maker: async_sessionmaker[AsyncSession], test_project: Project
) -> ObservationRepository:
    """Return an ObservationRepository scoped to the ``test_project`` id."""
    project_id = test_project.id
    return ObservationRepository(session_maker, project_id=project_id)
@pytest_asyncio.fixture(scope="function")
async def relation_repository(
    session_maker: async_sessionmaker[AsyncSession], test_project: Project
) -> RelationRepository:
    """Return a RelationRepository scoped to the ``test_project`` id."""
    project_id = test_project.id
    return RelationRepository(session_maker, project_id=project_id)
@pytest_asyncio.fixture(scope="function")
async def project_repository(
    session_maker: async_sessionmaker[AsyncSession],
) -> ProjectRepository:
    """Return a ProjectRepository that uses the test session factory and no project scope."""
    repository = ProjectRepository(session_maker)
    return repository
@pytest_asyncio.fixture(scope="function")
async def test_project(config_home, engine_factory) -> Project:
    """Insert the "test-project" row that the other repositories are scoped to.

    The project points at ``config_home`` and is flagged as active and default.
    The default flag is what CLI operations rely on.

    Returns:
        The Project created by ``ProjectRepository.create``.
    """
    project_data = {
        "name": "test-project",
        "description": "Project used as context for tests",
        "path": str(config_home),
        "is_active": True,
        "is_default": True,  # Explicitly set as the default project (for cli operations)
    }
    # Only the session factory is needed; the engine half of the pair is unused.
    _, session_maker = engine_factory
    repository = ProjectRepository(session_maker)
    return await repository.create(project_data)
## Services
@pytest_asyncio.fixture
async def entity_service(
    entity_repository: EntityRepository,
    observation_repository: ObservationRepository,
    relation_repository: RelationRepository,
    entity_parser: EntityParser,
    file_service: FileService,
    link_resolver: LinkResolver,
    app_config: BasicMemoryConfig,
) -> EntityService:
    """Build an EntityService from the repository, parser, file and link fixtures."""
    dependencies = dict(
        entity_parser=entity_parser,
        entity_repository=entity_repository,
        observation_repository=observation_repository,
        relation_repository=relation_repository,
        file_service=file_service,
        link_resolver=link_resolver,
        app_config=app_config,
    )
    return EntityService(**dependencies)
@pytest.fixture
def file_service(
    project_config: ProjectConfig, markdown_processor: MarkdownProcessor
) -> FileService:
    """Return a FileService rooted at the test project's home directory."""
    base_dir = project_config.home
    return FileService(base_dir, markdown_processor)
@pytest.fixture
def markdown_processor(entity_parser: EntityParser) -> MarkdownProcessor:
    """Create a MarkdownProcessor that uses the shared ``entity_parser`` fixture."""
    return MarkdownProcessor(entity_parser)
@pytest.fixture
def link_resolver(
    entity_repository: EntityRepository, search_service: SearchService
) -> LinkResolver:
    """Create a LinkResolver backed by the entity repository and the search service."""
    return LinkResolver(entity_repository, search_service)
@pytest.fixture
def entity_parser(project_config: ProjectConfig) -> EntityParser:
    """Return an EntityParser anchored at the test project's home directory."""
    home_dir = project_config.home
    return EntityParser(home_dir)
@pytest_asyncio.fixture
async def sync_service(
    app_config: BasicMemoryConfig,
    entity_service: EntityService,
    entity_parser: EntityParser,
    project_repository: ProjectRepository,
    entity_repository: EntityRepository,
    relation_repository: RelationRepository,
    search_service: SearchService,
    file_service: FileService,
) -> SyncService:
    """Build a SyncService from the test config, services and repositories."""
    dependencies = dict(
        app_config=app_config,
        entity_service=entity_service,
        project_repository=project_repository,
        entity_repository=entity_repository,
        relation_repository=relation_repository,
        entity_parser=entity_parser,
        search_service=search_service,
        file_service=file_service,
    )
    return SyncService(**dependencies)
@pytest_asyncio.fixture
async def directory_service(entity_repository, project_config) -> DirectoryService:
    """Return a DirectoryService over the project-scoped entity repository.

    ``project_config`` is requested but is not passed to the service.
    """
    service = DirectoryService(entity_repository=entity_repository)
    return service
@pytest_asyncio.fixture
async def search_repository(
    session_maker: async_sessionmaker[AsyncSession], test_project: Project
) -> SearchRepository:
    """Return a SearchRepository scoped to the ``test_project`` id."""
    project_id = test_project.id
    return SearchRepository(session_maker, project_id=project_id)
@pytest_asyncio.fixture(autouse=True)
async def init_search_index(search_service: SearchService) -> None:
    """Autouse hook that makes every test set up ``search_service`` and its index.

    The ``search_service`` fixture already calls ``init_search_index()`` once,
    and the call below repeats it.
    NOTE(review): this assumes ``init_search_index()`` is safe to call more than
    once. Confirm before relying on it elsewhere.
    """
    await search_service.init_search_index()
@pytest_asyncio.fixture
async def search_service(
    search_repository: SearchRepository,
    entity_repository: EntityRepository,
    file_service: FileService,
) -> SearchService:
    """Return a SearchService over the test repositories with its index already initialized."""
    search = SearchService(search_repository, entity_repository, file_service)
    await search.init_search_index()
    return search
@pytest_asyncio.fixture(scope="function")
async def sample_entity(entity_repository: EntityRepository) -> Entity:
    """Insert a minimal "Test Entity" row with permalink ``test/test-entity``.

    The row is created directly through the repository; no markdown file is
    written. Both timestamps are taken in UTC.
    """
    fields = dict(
        project_id=entity_repository.project_id,
        title="Test Entity",
        entity_type="test",
        permalink="test/test-entity",
        file_path="test/test_entity.md",
        content_type="text/markdown",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    return await entity_repository.create(fields)
@pytest_asyncio.fixture
async def project_service(
    project_repository: ProjectRepository,
) -> ProjectService:
    """Return a ProjectService that wraps the test project repository."""
    service = ProjectService(repository=project_repository)
    return service
@pytest_asyncio.fixture
async def full_entity(sample_entity, entity_repository, file_service, entity_service) -> Entity:
    """Create "Search_Entity" in folder "test" with observations and relations.

    Its content has two observations ([tech], [design]) and two relations
    (out1, out2), both pointing at "Test Entity". Requesting ``sample_entity``
    ensures that "Test Entity" exists first. ``entity_repository`` and
    ``file_service`` are requested but not used directly.

    Returns:
        The entity returned by ``create_or_update_entity``.
    """
    entity, _ = await entity_service.create_or_update_entity(
        EntitySchema(
            title="Search_Entity",
            folder="test",
            entity_type="test",
            content=dedent("""
                ## Observations
                - [tech] Tech note
                - [design] Design note
                ## Relations
                - out1 [[Test Entity]]
                - out2 [[Test Entity]]
                """),
        )
    )
    return entity
@pytest_asyncio.fixture
async def test_graph(
    entity_repository,
    relation_repository,
    observation_repository,
    search_service,
    file_service,
    entity_service,
):
    """Build a small linked knowledge graph and index it for search.

    Link chain: Root -> Connected Entity 1 -> Connected Entity 2 -> Deep Entity
    -> Deeper Entity. Entities are created leaf-first so that each link target
    already exists when the entity that links to it is written.

    Returns:
        A dict with:
        - "root", "connected1", "connected2", "deep": the entities returned by
          ``create_or_update_entity``;
        - "observations": one list of observations per entity from ``find_all``;
        - "relations": every relation from ``relation_repository.find_all``.
    """

    async def _create(title: str, entity_type: str, body: str):
        # Write one entity into the "test" folder and return the entity model.
        created, _ = await entity_service.create_or_update_entity(
            EntitySchema(
                title=title,
                entity_type=entity_type,
                folder="test",
                content=dedent(body),
            )
        )
        return created

    # Leaf of the chain. It is linked to but is not part of the returned mapping.
    await _create(
        "Deeper Entity",
        "deeper",
        """
                # Deeper Entity
                """,
    )
    deep = await _create(
        "Deep Entity",
        "deep",
        """
                # Deep Entity
                - deeper_connection [[Deeper Entity]]
                """,
    )
    connected_2 = await _create(
        "Connected Entity 2",
        "test",
        """
                # Connected Entity 2
                - deep_connection [[Deep Entity]]
                """,
    )
    connected_1 = await _create(
        "Connected Entity 1",
        "test",
        """
                # Connected Entity 1
                - [note] Connected 1 note
                - connected_to [[Connected Entity 2]]
                """,
    )
    root = await _create(
        "Root",
        "test",
        """
                # Root Entity
                - [note] Root note 1
                - [tech] Root tech note
                - connects_to [[Connected Entity 1]]
                """,
    )

    # Fetch the latest stored state of every entity and relation.
    all_entities = await entity_repository.find_all()
    all_relations = await relation_repository.find_all()

    for stored in all_entities:
        await search_service.index_entity(stored)

    return {
        "root": root,
        "connected1": connected_1,
        "connected2": connected_2,
        "deep": deep,
        "observations": [e.observations for e in all_entities],
        "relations": all_relations,
    }
@pytest.fixture
def watch_service(app_config: BasicMemoryConfig, project_repository) -> WatchService:
    """Return a WatchService built from the test app config and project repository."""
    service = WatchService(app_config=app_config, project_repository=project_repository)
    return service
@pytest.fixture
def test_files(project_config, project_root) -> dict[str, Path]:
    """Copy the non-markdown sample files into the test project's home directory.

    Returns:
        Mapping of a short label ("pdf", "image") to the copied file's path.
    """
    tests_dir = project_root / "tests"
    sources = {
        "pdf": tests_dir / "Non-MarkdownFileSupport.pdf",
        "image": tests_dir / "Screenshot.png",
    }

    copied: dict[str, Path] = {}
    for label, source in sources.items():
        payload = source.read_bytes()
        target = project_config.home / source.name
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(payload)
        copied[label] = target
    return copied
@pytest_asyncio.fixture
async def synced_files(sync_service, project_config, test_files):
    """Run one sync over the project home after ``test_files`` are copied in.

    Returns:
        The same mapping produced by the ``test_files`` fixture.
    """
    project_home = project_config.home
    await sync_service.sync(project_home)
    return test_files