# test_entity_repository.py (29.5 kB)
"""Tests for the EntityRepository."""
from datetime import datetime, timezone
import pytest
import pytest_asyncio
from sqlalchemy import select
from basic_memory import db
from basic_memory.models import Entity, Observation, Relation, Project
from basic_memory.repository.entity_repository import EntityRepository
from basic_memory.utils import generate_permalink
@pytest_asyncio.fixture
async def entity_with_observations(session_maker, sample_entity):
    """Add two observations pointing at ``sample_entity`` and return the entity."""
    contents = ("First observation", "Second observation")
    async with db.scoped_session(session_maker) as session:
        session.add_all(
            [
                Observation(entity_id=sample_entity.id, content=text)
                for text in contents
            ]
        )
        return sample_entity
@pytest_asyncio.fixture
async def related_results(session_maker, test_project: Project):
    """Build a source entity, a target entity and a relation from one to the other.

    Returns:
        A ``(source, target, relation)`` tuple.
    """

    def build_entity(name: str) -> Entity:
        # Title, permalink and file path are all derived from the one name.
        return Entity(
            project_id=test_project.id,
            title=name,
            entity_type="test",
            permalink=f"{name}/{name}",
            file_path=f"{name}/{name}.md",
            content_type="text/markdown",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    async with db.scoped_session(session_maker) as session:
        source = build_entity("source")
        target = build_entity("target")
        session.add_all([source, target])
        # Flush before building the relation, which reads source.id / target.id.
        await session.flush()
        relation = Relation(
            from_id=source.id,
            to_id=target.id,
            to_name=target.title,
            relation_type="connects_to",
        )
        session.add(relation)
        return source, target, relation
@pytest.mark.asyncio
async def test_create_entity(entity_repository: EntityRepository):
    """Create a single entity through the repository and read it back by id."""
    payload = dict(
        project_id=entity_repository.project_id,
        title="Test",
        entity_type="test",
        permalink="test/test",
        file_path="test/test.md",
        content_type="text/markdown",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    created = await entity_repository.create(payload)
    # The object handed back by create() is populated
    assert created.id is not None
    assert created.title == "Test"
    for stamp in (created.created_at, created.updated_at):
        assert isinstance(stamp, datetime)
    # The same row can be fetched again by id
    fetched = await entity_repository.find_by_id(created.id)
    assert fetched is not None
    assert fetched.id is not None
    assert fetched.id == created.id
    assert fetched.title == created.title
    # Relationship collections are readable on the returned object and empty
    for collection in (created.observations, created.relations):
        assert len(collection) == 0
@pytest.mark.asyncio
async def test_create_all(entity_repository: EntityRepository):
    """Test creating several entities with a single ``create_all`` call.

    Every returned entity is looked up again by id, so a batch insert that
    dropped or mixed up a row is detected (checking only the first element
    would let such a failure pass unnoticed).
    """

    def make_payload(title: str, permalink: str, file_path: str) -> dict:
        # Only title, permalink and file path vary between the two rows.
        return {
            "project_id": entity_repository.project_id,
            "title": title,
            "entity_type": "test",
            "permalink": permalink,
            "file_path": file_path,
            "content_type": "text/markdown",
            "created_at": datetime.now(timezone.utc),
            "updated_at": datetime.now(timezone.utc),
        }

    entity_data = [
        make_payload("Test_1", "test/test-1", "test/test_1.md"),
        make_payload("Test-2", "test/test-2", "test/test_2.md"),
    ]
    entities = await entity_repository.create_all(entity_data)
    assert len(entities) == 2
    # Compare as a set: the order of the returned list is not asserted here.
    assert {entity.title for entity in entities} == {"Test_1", "Test-2"}
    for entity in entities:
        # Verify in database
        found = await entity_repository.find_by_id(entity.id)
        assert found is not None
        assert found.id is not None
        assert found.id == entity.id
        assert found.title == entity.title
        # assert relations are eagerly loaded
        assert len(entity.observations) == 0
        assert len(entity.relations) == 0
@pytest.mark.asyncio
async def test_find_by_id(entity_repository: EntityRepository, sample_entity: Entity):
    """Look up ``sample_entity`` by id and compare it with a direct SELECT."""
    fetched = await entity_repository.find_by_id(sample_entity.id)
    assert fetched is not None
    assert fetched.id == sample_entity.id
    assert fetched.title == sample_entity.title
    # Cross-check with a query issued directly against the session
    lookup = select(Entity).where(Entity.id == sample_entity.id)
    async with db.scoped_session(entity_repository.session_maker) as session:
        row = (await session.execute(lookup)).scalar_one()
        assert row.id == fetched.id
        assert row.title == fetched.title
@pytest.mark.asyncio
async def test_update_entity(entity_repository: EntityRepository, sample_entity: Entity):
    """Change an entity's title via ``update`` and confirm the stored row changed."""
    new_title = "Updated title"
    changed = await entity_repository.update(sample_entity.id, {"title": new_title})
    assert changed is not None
    assert changed.title == new_title
    # Read the row back with a direct SELECT
    lookup = select(Entity).where(Entity.id == sample_entity.id)
    async with db.scoped_session(entity_repository.session_maker) as session:
        stored = (await session.execute(lookup)).scalar_one()
        assert stored.title == new_title
@pytest.mark.asyncio
async def test_delete_entity(entity_repository: EntityRepository, sample_entity):
    """Delete an entity by id and check it can no longer be found."""
    entity_id = sample_entity.id
    was_deleted = await entity_repository.delete(entity_id)
    assert was_deleted is True
    # A follow-up lookup must come back empty
    assert await entity_repository.find_by_id(entity_id) is None
@pytest.mark.asyncio
async def test_delete_entity_with_observations(
    entity_repository: EntityRepository, entity_with_observations
):
    """Delete an entity that has observations and check none of them remain."""
    entity = entity_with_observations
    was_deleted = await entity_repository.delete(entity.id)
    assert was_deleted is True
    # The entity itself is gone
    assert await entity_repository.find_by_id(entity.id) is None
    # No observation row still references the deleted entity
    leftover_query = select(Observation).filter(Observation.entity_id == entity.id)
    async with db.scoped_session(entity_repository.session_maker) as session:
        leftovers = (await session.execute(leftover_query)).scalars().all()
        assert len(leftovers) == 0
@pytest.mark.asyncio
async def test_delete_entities_by_type(entity_repository: EntityRepository, sample_entity):
    """Delete by ``entity_type`` and check no entity of that type is left."""
    doomed_type = sample_entity.entity_type
    was_deleted = await entity_repository.delete_by_fields(entity_type=doomed_type)
    assert was_deleted is True
    # A direct SELECT for that type must return nothing
    async with db.scoped_session(entity_repository.session_maker) as session:
        by_type = select(Entity).filter(Entity.entity_type == sample_entity.entity_type)
        survivors = (await session.execute(by_type)).scalars().all()
        assert len(survivors) == 0
@pytest.mark.asyncio
async def test_delete_entity_with_relations(entity_repository: EntityRepository, related_results):
    """Delete the source entity; its outgoing relation goes, the target stays."""
    source, target, _relation = related_results
    # Remove the entity on the "from" side of the relation
    was_deleted = await entity_repository.delete(source.id)
    assert was_deleted is True
    async with db.scoped_session(entity_repository.session_maker) as session:
        # No relation may still originate from the deleted entity
        outgoing = select(Relation).filter(Relation.from_id == source.id)
        leftovers = (await session.execute(outgoing)).scalars().all()
        assert len(leftovers) == 0
        # The entity on the "to" side must be untouched
        surviving_target = await entity_repository.find_by_id(target.id)
        assert surviving_target is not None
@pytest.mark.asyncio
async def test_delete_nonexistent_entity(entity_repository: EntityRepository):
    """Deleting an id with no matching entity reports ``False``."""
    # presumably no row ever has id 0 - confirm against the id scheme
    missing_id = 0
    assert await entity_repository.delete(missing_id) is False
@pytest_asyncio.fixture
async def test_entities(session_maker, test_project: Project):
    """Add three entities (two under ``type1``, one under ``type2``) and return them."""
    # (title, folder) pairs; permalink and file path are built from both.
    layout = [("entity1", "type1"), ("entity2", "type1"), ("entity3", "type2")]
    async with db.scoped_session(session_maker) as session:
        entities = [
            Entity(
                project_id=test_project.id,
                title=title,
                entity_type="test",
                permalink=f"{folder}/{title}",
                file_path=f"{folder}/{title}.md",
                content_type="text/markdown",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            for title, folder in layout
        ]
        session.add_all(entities)
        return entities
@pytest.mark.asyncio
async def test_find_by_permalinks(entity_repository: EntityRepository, test_entities):
    """Look up entities by lists of permalinks: all, a subset, unknown, and empty."""
    # All three fixture permalinks
    every_permalink = [entity.permalink for entity in test_entities]
    matches = await entity_repository.find_by_permalinks(every_permalink)
    assert len(matches) == 3
    assert {match.title for match in matches} == {"entity1", "entity2", "entity3"}
    # Everything except entity2
    without_second = [
        entity.permalink for entity in test_entities if entity.title != "entity2"
    ]
    matches = await entity_repository.find_by_permalinks(without_second)
    assert len(matches) == 2
    assert {match.title for match in matches} == {"entity1", "entity3"}
    # One known permalink mixed with one that matches nothing
    matches = await entity_repository.find_by_permalinks(
        ["type1/entity1", "type3/nonexistent"]
    )
    assert len(matches) == 1
    assert matches[0].title == "entity1"
    # An empty list of permalinks yields no results
    matches = await entity_repository.find_by_permalinks([])
    assert len(matches) == 0
@pytest.mark.asyncio
async def test_generate_permalink_from_file_path():
    """Check ``generate_permalink`` output for a table of file paths.

    Each generated permalink is also passed to the ``Entity`` constructor.
    """
    # file path -> permalink expected from generate_permalink
    expected_by_path = {
        "docs/My Feature.md": "docs/my-feature",
        "specs/API (v2).md": "specs/api-v2",
        "notes/2024/Q1 Planning!!!.md": "notes/2024/q1-planning",
        "test/Über File.md": "test/uber-file",
        "docs/my_feature_name.md": "docs/my-feature-name",
        "specs/multiple--dashes.md": "specs/multiple-dashes",
        "notes/trailing/space/ file.md": "notes/trailing/space/file",
    }
    for input_path, expected in expected_by_path.items():
        permalink = generate_permalink(input_path)
        assert permalink == expected, f"Failed for {input_path}"
        # Constructing an Entity is expected to raise ValueError on an
        # invalid permalink (validation lives in the model, not shown here).
        Entity(
            title="test",
            entity_type="test",
            permalink=permalink,
            file_path=input_path,
            content_type="text/markdown",
        )
@pytest.mark.asyncio
async def test_get_by_title(entity_repository: EntityRepository, session_maker):
    """Look up entities by title: exact, wrong case, unknown, and duplicated."""

    def titled(title: str, slug: str) -> Entity:
        # Permalink and file path share the slug under the "test/" folder.
        return Entity(
            project_id=entity_repository.project_id,
            title=title,
            entity_type="test",
            permalink=f"test/{slug}",
            file_path=f"test/{slug}.md",
            content_type="text/markdown",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    # Seed one unique title and one title shared by two entities
    async with db.scoped_session(session_maker) as session:
        session.add_all(
            [
                titled("Unique Title", "unique-title"),
                titled("Another Title", "another-title"),
                titled("Another Title", "another-title-1"),
            ]
        )
        await session.flush()
    # Exact title match returns the single entity
    matches = await entity_repository.get_by_title("Unique Title")
    assert matches is not None
    assert len(matches) == 1
    assert matches[0].title == "Unique Title"
    # A differently-cased title must not match (lookup is case-sensitive)
    matches = await entity_repository.get_by_title("unique title")
    assert not matches
    # A title nobody has returns nothing
    matches = await entity_repository.get_by_title("Non Existent")
    assert not matches
    # A shared title returns every entity carrying it
    matches = await entity_repository.get_by_title("Another Title")
    assert len(matches) == 2
@pytest.mark.asyncio
async def test_get_by_file_path(entity_repository: EntityRepository, session_maker):
    """Look up an entity by file path, for a known and an unknown path."""
    known_path = "test/unique-title.md"
    # Seed a single entity stored at the known path
    async with db.scoped_session(session_maker) as session:
        seeded = Entity(
            project_id=entity_repository.project_id,
            title="Unique Title",
            entity_type="test",
            permalink="test/unique-title",
            file_path=known_path,
            content_type="text/markdown",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add_all([seeded])
        await session.flush()
    # The known path resolves to the seeded entity
    match = await entity_repository.get_by_file_path(known_path)
    assert match is not None
    assert match.title == "Unique Title"
    # A path no entity uses resolves to None
    match = await entity_repository.get_by_file_path("not/a/real/file.md")
    assert match is None
@pytest.mark.asyncio
async def test_get_distinct_directories(entity_repository: EntityRepository, session_maker):
    """Check the directory list derived from entity file paths.

    Five entities are spread over nested folders; the result must contain
    every folder and parent folder once, in sorted order, and no file names.
    """
    # (title, permalink); the file path is the permalink plus ".md"
    layout = [
        ("File 1", "docs/guides/file1"),
        ("File 2", "docs/guides/file2"),
        ("File 3", "docs/api/file3"),
        ("File 4", "specs/file4"),
        ("File 5", "notes/2024/q1/file5"),
    ]
    async with db.scoped_session(session_maker) as session:
        session.add_all(
            [
                Entity(
                    project_id=entity_repository.project_id,
                    title=title,
                    entity_type="test",
                    permalink=permalink,
                    file_path=f"{permalink}.md",
                    content_type="text/markdown",
                    created_at=datetime.now(timezone.utc),
                    updated_at=datetime.now(timezone.utc),
                )
                for title, permalink in layout
            ]
        )
        await session.flush()
    directories = await entity_repository.get_distinct_directories()
    # A non-empty list comes back
    assert isinstance(directories, list)
    assert len(directories) > 0
    # Every folder and each of its parents appears; file names do not
    assert set(directories) == {
        "docs",
        "docs/guides",
        "docs/api",
        "notes",
        "notes/2024",
        "notes/2024/q1",
        "specs",
    }
    # The list is already in sorted order
    assert directories == sorted(directories)
    # Nothing in the list looks like a markdown file
    assert not any(entry.endswith(".md") for entry in directories)
@pytest.mark.asyncio
async def test_get_distinct_directories_empty_db(entity_repository: EntityRepository):
    """With no entities seeded, the directory list is an empty list."""
    assert await entity_repository.get_distinct_directories() == []
@pytest.mark.asyncio
async def test_find_by_directory_prefix(entity_repository: EntityRepository, session_maker):
    """Find entities by directory prefix across nested, root and unknown folders."""
    # (title, permalink); the file path is the permalink plus ".md"
    layout = [
        ("File 1", "docs/file1"),
        ("File 2", "docs/guides/file2"),
        ("File 3", "docs/api/file3"),
        ("File 4", "specs/file4"),
    ]
    async with db.scoped_session(session_maker) as session:
        session.add_all(
            [
                Entity(
                    project_id=entity_repository.project_id,
                    title=title,
                    entity_type="test",
                    permalink=permalink,
                    file_path=f"{permalink}.md",
                    content_type="text/markdown",
                    created_at=datetime.now(timezone.utc),
                    updated_at=datetime.now(timezone.utc),
                )
                for title, permalink in layout
            ]
        )
        await session.flush()
    # "docs" matches files directly inside it and inside its subfolders
    under_docs = await entity_repository.find_by_directory_prefix("docs")
    assert len(under_docs) == 3
    assert {row.file_path for row in under_docs} == {
        "docs/file1.md",
        "docs/guides/file2.md",
        "docs/api/file3.md",
    }
    # A nested prefix narrows the match to that subfolder
    under_guides = await entity_repository.find_by_directory_prefix("docs/guides")
    assert len(under_guides) == 1
    assert under_guides[0].file_path == "docs/guides/file2.md"
    # A different top-level folder
    under_specs = await entity_repository.find_by_directory_prefix("specs")
    assert len(under_specs) == 1
    assert under_specs[0].file_path == "specs/file4.md"
    # Root given as an empty string matches everything
    everything = await entity_repository.find_by_directory_prefix("")
    assert len(everything) == 4
    # Root given as a slash matches everything too
    everything = await entity_repository.find_by_directory_prefix("/")
    assert len(everything) == 4
    # A folder that holds no entities matches nothing
    nothing = await entity_repository.find_by_directory_prefix("nonexistent")
    assert len(nothing) == 0
@pytest.mark.asyncio
async def test_find_by_directory_prefix_basic_fields_only(
    entity_repository: EntityRepository, session_maker
):
    """Check the scalar fields on rows returned by ``find_by_directory_prefix``.

    Note: the method is described as running with use_query_options=False
    for performance, i.e. without eager-loading relationships, because
    directory trees only need basic entity fields. That is not visible in
    this file; this test only asserts the basic fields.
    """
    # Seed a single entity under "docs"
    async with db.scoped_session(session_maker) as session:
        seeded = Entity(
            project_id=entity_repository.project_id,
            title="Test Entity",
            entity_type="test",
            permalink="docs/test",
            file_path="docs/test.md",
            content_type="text/markdown",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(seeded)
        await session.flush()
    rows = await entity_repository.find_by_directory_prefix("docs")
    assert len(rows) == 1
    # Basic fields are all a directory tree needs
    row = rows[0]
    assert row.title == "Test Entity"
    assert row.file_path == "docs/test.md"
    assert row.permalink == "docs/test"
    assert row.entity_type == "test"
    assert row.content_type == "text/markdown"
    assert row.updated_at is not None
@pytest.mark.asyncio
async def test_get_all_file_paths(entity_repository: EntityRepository, session_maker):
    """Check that ``get_all_file_paths`` lists the file path of every seeded entity.

    Per the original description, the method supports deletion detection
    during sync.
    """
    # (title, permalink); the file path is the permalink plus ".md"
    layout = [
        ("File 1", "docs/file1"),
        ("File 2", "specs/file2"),
        ("File 3", "notes/file3"),
    ]
    async with db.scoped_session(session_maker) as session:
        session.add_all(
            [
                Entity(
                    project_id=entity_repository.project_id,
                    title=title,
                    entity_type="test",
                    permalink=permalink,
                    file_path=f"{permalink}.md",
                    content_type="text/markdown",
                    created_at=datetime.now(timezone.utc),
                    updated_at=datetime.now(timezone.utc),
                )
                for title, permalink in layout
            ]
        )
        await session.flush()
    paths = await entity_repository.get_all_file_paths()
    # A list holding exactly the three seeded paths, in any order
    assert isinstance(paths, list)
    assert len(paths) == 3
    assert set(paths) == {"docs/file1.md", "specs/file2.md", "notes/file3.md"}
@pytest.mark.asyncio
async def test_get_all_file_paths_empty_db(entity_repository: EntityRepository):
    """With no entities seeded, ``get_all_file_paths`` returns an empty list."""
    assert await entity_repository.get_all_file_paths() == []
@pytest.mark.asyncio
async def test_get_all_file_paths_performance(entity_repository: EntityRepository, session_maker):
    """Check ``get_all_file_paths`` yields plain strings for entities with children.

    Two entities are seeded, one with an observation and a relation to the
    other. The method is described as optimized for deletion detection during
    streaming sync, querying only file_path strings; this test checks the
    shape of the result, not timing.
    """

    def numbered(index: int) -> Entity:
        # Title, permalink and file path all follow the entity number.
        return Entity(
            project_id=entity_repository.project_id,
            title=f"Entity {index}",
            entity_type="test",
            permalink=f"test/entity{index}",
            file_path=f"test/entity{index}.md",
            content_type="text/markdown",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    async with db.scoped_session(session_maker) as session:
        first = numbered(1)
        second = numbered(2)
        session.add_all([first, second])
        # Flush before creating children, which read first.id / second.id
        await session.flush()
        # One observation hanging off the first entity
        session.add(
            Observation(
                entity_id=first.id,
                content="Test observation",
                category="note",
            )
        )
        # One relation from the first entity to the second
        session.add(
            Relation(
                from_id=first.id,
                to_id=second.id,
                to_name=second.title,
                relation_type="relates_to",
            )
        )
        await session.flush()
    paths = await entity_repository.get_all_file_paths()
    # Exactly the two entity paths come back
    assert len(paths) == 2
    assert set(paths) == {"test/entity1.md", "test/entity2.md"}
    # Each item is a bare string rather than an entity object
    assert all(isinstance(item, str) for item in paths)
@pytest.mark.asyncio
async def test_get_all_file_paths_project_isolation(
    entity_repository: EntityRepository, session_maker
):
    """Check ``get_all_file_paths`` leaves out paths that belong to another project."""

    def project_file(project_id, title: str, stem: str) -> Entity:
        # Both entities live under "test/"; only project, title and stem vary.
        return Entity(
            project_id=project_id,
            title=title,
            entity_type="test",
            permalink=f"test/{stem}",
            file_path=f"test/{stem}.md",
            content_type="text/markdown",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    async with db.scoped_session(session_maker) as session:
        # An entity owned by the repository's own project
        own_entity = project_file(entity_repository.project_id, "Project 1 File", "file1")
        session.add(own_entity)
        await session.flush()
        # A second, unrelated project
        other_project = Project(name="other-project", path="/tmp/other")
        session.add(other_project)
        # Flush before other_project.id is read below
        await session.flush()
        # An entity owned by that other project
        foreign_entity = project_file(other_project.id, "Project 2 File", "file2")
        session.add(foreign_entity)
        await session.flush()
    paths = await entity_repository.get_all_file_paths()
    # Only the repository project's own file is listed
    assert len(paths) == 1
    assert paths == ["test/file1.md"]