Skip to main content
Glama

basic-memory

test_entity_repository.py (29.5 kB)
"""Tests for the EntityRepository.""" from datetime import datetime, timezone import pytest import pytest_asyncio from sqlalchemy import select from basic_memory import db from basic_memory.models import Entity, Observation, Relation, Project from basic_memory.repository.entity_repository import EntityRepository from basic_memory.utils import generate_permalink @pytest_asyncio.fixture async def entity_with_observations(session_maker, sample_entity): """Create an entity with observations.""" async with db.scoped_session(session_maker) as session: observations = [ Observation( entity_id=sample_entity.id, content="First observation", ), Observation( entity_id=sample_entity.id, content="Second observation", ), ] session.add_all(observations) return sample_entity @pytest_asyncio.fixture async def related_results(session_maker, test_project: Project): """Create entities with relations between them.""" async with db.scoped_session(session_maker) as session: source = Entity( project_id=test_project.id, title="source", entity_type="test", permalink="source/source", file_path="source/source.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) target = Entity( project_id=test_project.id, title="target", entity_type="test", permalink="target/target", file_path="target/target.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) session.add(source) session.add(target) await session.flush() relation = Relation( from_id=source.id, to_id=target.id, to_name=target.title, relation_type="connects_to", ) session.add(relation) return source, target, relation @pytest.mark.asyncio async def test_create_entity(entity_repository: EntityRepository): """Test creating a new entity""" entity_data = { "project_id": entity_repository.project_id, "title": "Test", "entity_type": "test", "permalink": "test/test", "file_path": "test/test.md", "content_type": "text/markdown", 
"created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc), } entity = await entity_repository.create(entity_data) # Verify returned object assert entity.id is not None assert entity.title == "Test" assert isinstance(entity.created_at, datetime) assert isinstance(entity.updated_at, datetime) # Verify in database found = await entity_repository.find_by_id(entity.id) assert found is not None assert found.id is not None assert found.id == entity.id assert found.title == entity.title # assert relations are eagerly loaded assert len(entity.observations) == 0 assert len(entity.relations) == 0 @pytest.mark.asyncio async def test_create_all(entity_repository: EntityRepository): """Test creating a new entity""" entity_data = [ { "project_id": entity_repository.project_id, "title": "Test_1", "entity_type": "test", "permalink": "test/test-1", "file_path": "test/test_1.md", "content_type": "text/markdown", "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc), }, { "project_id": entity_repository.project_id, "title": "Test-2", "entity_type": "test", "permalink": "test/test-2", "file_path": "test/test_2.md", "content_type": "text/markdown", "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc), }, ] entities = await entity_repository.create_all(entity_data) assert len(entities) == 2 entity = entities[0] # Verify in database found = await entity_repository.find_by_id(entity.id) assert found is not None assert found.id is not None assert found.id == entity.id assert found.title == entity.title # assert relations are eagerly loaded assert len(entity.observations) == 0 assert len(entity.relations) == 0 @pytest.mark.asyncio async def test_find_by_id(entity_repository: EntityRepository, sample_entity: Entity): """Test finding an entity by ID""" found = await entity_repository.find_by_id(sample_entity.id) assert found is not None assert found.id == sample_entity.id assert found.title == 
sample_entity.title # Verify against direct database query async with db.scoped_session(entity_repository.session_maker) as session: stmt = select(Entity).where(Entity.id == sample_entity.id) result = await session.execute(stmt) db_entity = result.scalar_one() assert db_entity.id == found.id assert db_entity.title == found.title @pytest.mark.asyncio async def test_update_entity(entity_repository: EntityRepository, sample_entity: Entity): """Test updating an entity""" updated = await entity_repository.update(sample_entity.id, {"title": "Updated title"}) assert updated is not None assert updated.title == "Updated title" # Verify in database async with db.scoped_session(entity_repository.session_maker) as session: stmt = select(Entity).where(Entity.id == sample_entity.id) result = await session.execute(stmt) db_entity = result.scalar_one() assert db_entity.title == "Updated title" @pytest.mark.asyncio async def test_delete_entity(entity_repository: EntityRepository, sample_entity): """Test deleting an entity.""" result = await entity_repository.delete(sample_entity.id) assert result is True # Verify deletion deleted = await entity_repository.find_by_id(sample_entity.id) assert deleted is None @pytest.mark.asyncio async def test_delete_entity_with_observations( entity_repository: EntityRepository, entity_with_observations ): """Test deleting an entity cascades to its observations.""" entity = entity_with_observations result = await entity_repository.delete(entity.id) assert result is True # Verify entity deletion deleted = await entity_repository.find_by_id(entity.id) assert deleted is None # Verify observations were cascaded async with db.scoped_session(entity_repository.session_maker) as session: query = select(Observation).filter(Observation.entity_id == entity.id) result = await session.execute(query) remaining_observations = result.scalars().all() assert len(remaining_observations) == 0 @pytest.mark.asyncio async def test_delete_entities_by_type(entity_repository: 
EntityRepository, sample_entity): """Test deleting entities by type.""" result = await entity_repository.delete_by_fields(entity_type=sample_entity.entity_type) assert result is True # Verify deletion async with db.scoped_session(entity_repository.session_maker) as session: query = select(Entity).filter(Entity.entity_type == sample_entity.entity_type) result = await session.execute(query) remaining = result.scalars().all() assert len(remaining) == 0 @pytest.mark.asyncio async def test_delete_entity_with_relations(entity_repository: EntityRepository, related_results): """Test deleting an entity cascades to its relations.""" source, target, relation = related_results # Delete source entity result = await entity_repository.delete(source.id) assert result is True # Verify relation was cascaded async with db.scoped_session(entity_repository.session_maker) as session: query = select(Relation).filter(Relation.from_id == source.id) result = await session.execute(query) remaining_relations = result.scalars().all() assert len(remaining_relations) == 0 # Verify target entity still exists target_exists = await entity_repository.find_by_id(target.id) assert target_exists is not None @pytest.mark.asyncio async def test_delete_nonexistent_entity(entity_repository: EntityRepository): """Test deleting an entity that doesn't exist.""" result = await entity_repository.delete(0) assert result is False @pytest_asyncio.fixture async def test_entities(session_maker, test_project: Project): """Create multiple test entities.""" async with db.scoped_session(session_maker) as session: entities = [ Entity( project_id=test_project.id, title="entity1", entity_type="test", permalink="type1/entity1", file_path="type1/entity1.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=test_project.id, title="entity2", entity_type="test", permalink="type1/entity2", file_path="type1/entity2.md", content_type="text/markdown", 
created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=test_project.id, title="entity3", entity_type="test", permalink="type2/entity3", file_path="type2/entity3.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), ] session.add_all(entities) return entities @pytest.mark.asyncio async def test_find_by_permalinks(entity_repository: EntityRepository, test_entities): """Test finding multiple entities by their type/name pairs.""" # Test finding multiple entities permalinks = [e.permalink for e in test_entities] found = await entity_repository.find_by_permalinks(permalinks) assert len(found) == 3 names = {e.title for e in found} assert names == {"entity1", "entity2", "entity3"} # Test finding subset of entities permalinks = [e.permalink for e in test_entities if e.title != "entity2"] found = await entity_repository.find_by_permalinks(permalinks) assert len(found) == 2 names = {e.title for e in found} assert names == {"entity1", "entity3"} # Test with non-existent entities permalinks = ["type1/entity1", "type3/nonexistent"] found = await entity_repository.find_by_permalinks(permalinks) assert len(found) == 1 assert found[0].title == "entity1" # Test empty input found = await entity_repository.find_by_permalinks([]) assert len(found) == 0 @pytest.mark.asyncio async def test_generate_permalink_from_file_path(): """Test permalink generation from different file paths.""" test_cases = [ ("docs/My Feature.md", "docs/my-feature"), ("specs/API (v2).md", "specs/api-v2"), ("notes/2024/Q1 Planning!!!.md", "notes/2024/q1-planning"), ("test/Über File.md", "test/uber-file"), ("docs/my_feature_name.md", "docs/my-feature-name"), ("specs/multiple--dashes.md", "specs/multiple-dashes"), ("notes/trailing/space/ file.md", "notes/trailing/space/file"), ] for input_path, expected in test_cases: result = generate_permalink(input_path) assert result == expected, f"Failed for {input_path}" # 
Verify the result passes validation Entity( title="test", entity_type="test", permalink=result, file_path=input_path, content_type="text/markdown", ) # This will raise ValueError if invalid @pytest.mark.asyncio async def test_get_by_title(entity_repository: EntityRepository, session_maker): """Test getting an entity by title.""" # Create test entities async with db.scoped_session(session_maker) as session: entities = [ Entity( project_id=entity_repository.project_id, title="Unique Title", entity_type="test", permalink="test/unique-title", file_path="test/unique-title.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=entity_repository.project_id, title="Another Title", entity_type="test", permalink="test/another-title", file_path="test/another-title.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=entity_repository.project_id, title="Another Title", entity_type="test", permalink="test/another-title-1", file_path="test/another-title-1.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), ] session.add_all(entities) await session.flush() # Test getting by exact title found = await entity_repository.get_by_title("Unique Title") assert found is not None assert len(found) == 1 assert found[0].title == "Unique Title" # Test case sensitivity found = await entity_repository.get_by_title("unique title") assert not found # Should be case-sensitive # Test non-existent title found = await entity_repository.get_by_title("Non Existent") assert not found # Test multiple rows found found = await entity_repository.get_by_title("Another Title") assert len(found) == 2 @pytest.mark.asyncio async def test_get_by_file_path(entity_repository: EntityRepository, session_maker): """Test getting an entity by title.""" # Create test entities async with 
db.scoped_session(session_maker) as session: entities = [ Entity( project_id=entity_repository.project_id, title="Unique Title", entity_type="test", permalink="test/unique-title", file_path="test/unique-title.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), ] session.add_all(entities) await session.flush() # Test getting by file_path found = await entity_repository.get_by_file_path("test/unique-title.md") assert found is not None assert found.title == "Unique Title" # Test non-existent file_path found = await entity_repository.get_by_file_path("not/a/real/file.md") assert found is None @pytest.mark.asyncio async def test_get_distinct_directories(entity_repository: EntityRepository, session_maker): """Test getting distinct directory paths from entity file paths.""" # Create test entities with various directory structures async with db.scoped_session(session_maker) as session: entities = [ Entity( project_id=entity_repository.project_id, title="File 1", entity_type="test", permalink="docs/guides/file1", file_path="docs/guides/file1.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=entity_repository.project_id, title="File 2", entity_type="test", permalink="docs/guides/file2", file_path="docs/guides/file2.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=entity_repository.project_id, title="File 3", entity_type="test", permalink="docs/api/file3", file_path="docs/api/file3.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=entity_repository.project_id, title="File 4", entity_type="test", permalink="specs/file4", file_path="specs/file4.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), 
Entity( project_id=entity_repository.project_id, title="File 5", entity_type="test", permalink="notes/2024/q1/file5", file_path="notes/2024/q1/file5.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), ] session.add_all(entities) await session.flush() # Get distinct directories directories = await entity_repository.get_distinct_directories() # Verify directories are extracted correctly assert isinstance(directories, list) assert len(directories) > 0 # Should include all parent directories but not filenames expected_dirs = { "docs", "docs/guides", "docs/api", "notes", "notes/2024", "notes/2024/q1", "specs", } assert set(directories) == expected_dirs # Verify results are sorted assert directories == sorted(directories) # Verify no file paths are included for dir_path in directories: assert not dir_path.endswith(".md") @pytest.mark.asyncio async def test_get_distinct_directories_empty_db(entity_repository: EntityRepository): """Test getting distinct directories when database is empty.""" directories = await entity_repository.get_distinct_directories() assert directories == [] @pytest.mark.asyncio async def test_find_by_directory_prefix(entity_repository: EntityRepository, session_maker): """Test finding entities by directory prefix.""" # Create test entities in various directories async with db.scoped_session(session_maker) as session: entities = [ Entity( project_id=entity_repository.project_id, title="File 1", entity_type="test", permalink="docs/file1", file_path="docs/file1.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=entity_repository.project_id, title="File 2", entity_type="test", permalink="docs/guides/file2", file_path="docs/guides/file2.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=entity_repository.project_id, title="File 
3", entity_type="test", permalink="docs/api/file3", file_path="docs/api/file3.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=entity_repository.project_id, title="File 4", entity_type="test", permalink="specs/file4", file_path="specs/file4.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), ] session.add_all(entities) await session.flush() # Test finding all entities in "docs" directory and subdirectories docs_entities = await entity_repository.find_by_directory_prefix("docs") assert len(docs_entities) == 3 file_paths = {e.file_path for e in docs_entities} assert file_paths == {"docs/file1.md", "docs/guides/file2.md", "docs/api/file3.md"} # Test finding entities in "docs/guides" subdirectory guides_entities = await entity_repository.find_by_directory_prefix("docs/guides") assert len(guides_entities) == 1 assert guides_entities[0].file_path == "docs/guides/file2.md" # Test finding entities in "specs" directory specs_entities = await entity_repository.find_by_directory_prefix("specs") assert len(specs_entities) == 1 assert specs_entities[0].file_path == "specs/file4.md" # Test with root directory (empty string) all_entities = await entity_repository.find_by_directory_prefix("") assert len(all_entities) == 4 # Test with root directory (slash) all_entities = await entity_repository.find_by_directory_prefix("/") assert len(all_entities) == 4 # Test with non-existent directory nonexistent = await entity_repository.find_by_directory_prefix("nonexistent") assert len(nonexistent) == 0 @pytest.mark.asyncio async def test_find_by_directory_prefix_basic_fields_only( entity_repository: EntityRepository, session_maker ): """Test that find_by_directory_prefix returns basic entity fields. Note: This method uses use_query_options=False for performance, so it doesn't eager load relationships. 
Directory trees only need basic entity fields. """ # Create test entity async with db.scoped_session(session_maker) as session: entity = Entity( project_id=entity_repository.project_id, title="Test Entity", entity_type="test", permalink="docs/test", file_path="docs/test.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) session.add(entity) await session.flush() # Query entity by directory prefix entities = await entity_repository.find_by_directory_prefix("docs") assert len(entities) == 1 # Verify basic fields are present (all we need for directory trees) entity = entities[0] assert entity.title == "Test Entity" assert entity.file_path == "docs/test.md" assert entity.permalink == "docs/test" assert entity.entity_type == "test" assert entity.content_type == "text/markdown" assert entity.updated_at is not None @pytest.mark.asyncio async def test_get_all_file_paths(entity_repository: EntityRepository, session_maker): """Test getting all file paths for deletion detection during sync.""" # Create test entities with various file paths async with db.scoped_session(session_maker) as session: entities = [ Entity( project_id=entity_repository.project_id, title="File 1", entity_type="test", permalink="docs/file1", file_path="docs/file1.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=entity_repository.project_id, title="File 2", entity_type="test", permalink="specs/file2", file_path="specs/file2.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), Entity( project_id=entity_repository.project_id, title="File 3", entity_type="test", permalink="notes/file3", file_path="notes/file3.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ), ] session.add_all(entities) await session.flush() # Get all file paths 
file_paths = await entity_repository.get_all_file_paths() # Verify results assert isinstance(file_paths, list) assert len(file_paths) == 3 assert set(file_paths) == {"docs/file1.md", "specs/file2.md", "notes/file3.md"} @pytest.mark.asyncio async def test_get_all_file_paths_empty_db(entity_repository: EntityRepository): """Test getting all file paths when database is empty.""" file_paths = await entity_repository.get_all_file_paths() assert file_paths == [] @pytest.mark.asyncio async def test_get_all_file_paths_performance(entity_repository: EntityRepository, session_maker): """Test that get_all_file_paths doesn't load entities or relationships. This method is optimized for deletion detection during streaming sync. It should only query file_path strings, not full entity objects. """ # Create test entity with observations and relations async with db.scoped_session(session_maker) as session: # Create entities entity1 = Entity( project_id=entity_repository.project_id, title="Entity 1", entity_type="test", permalink="test/entity1", file_path="test/entity1.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) entity2 = Entity( project_id=entity_repository.project_id, title="Entity 2", entity_type="test", permalink="test/entity2", file_path="test/entity2.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) session.add_all([entity1, entity2]) await session.flush() # Add observations to entity1 observation = Observation( entity_id=entity1.id, content="Test observation", category="note", ) session.add(observation) # Add relation between entities relation = Relation( from_id=entity1.id, to_id=entity2.id, to_name=entity2.title, relation_type="relates_to", ) session.add(relation) await session.flush() # Get all file paths - should be fast and not load relationships file_paths = await entity_repository.get_all_file_paths() # Verify results - just file paths, 
no entities or relationships loaded assert len(file_paths) == 2 assert set(file_paths) == {"test/entity1.md", "test/entity2.md"} # Result should be list of strings, not entity objects for path in file_paths: assert isinstance(path, str) @pytest.mark.asyncio async def test_get_all_file_paths_project_isolation( entity_repository: EntityRepository, session_maker ): """Test that get_all_file_paths only returns paths from the current project.""" # Create entities in the repository's project async with db.scoped_session(session_maker) as session: entity1 = Entity( project_id=entity_repository.project_id, title="Project 1 File", entity_type="test", permalink="test/file1", file_path="test/file1.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) session.add(entity1) await session.flush() # Create a second project project2 = Project(name="other-project", path="/tmp/other") session.add(project2) await session.flush() # Create entity in different project entity2 = Entity( project_id=project2.id, title="Project 2 File", entity_type="test", permalink="test/file2", file_path="test/file2.md", content_type="text/markdown", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) session.add(entity2) await session.flush() # Get all file paths for project 1 file_paths = await entity_repository.get_all_file_paths() # Should only include files from project 1 assert len(file_paths) == 1 assert file_paths == ["test/file1.md"]

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/basicmachines-co/basic-memory'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.