# test_relation_repository.py (12 kB)
"""Tests for the RelationRepository."""
from datetime import datetime, timezone
import pytest
import pytest_asyncio
import sqlalchemy
from basic_memory import db
from basic_memory.models import Entity, Relation, Project
from basic_memory.repository.relation_repository import RelationRepository
@pytest_asyncio.fixture
async def source_entity(session_maker, test_project: Project):
    """Add the entity used as the origin of test relations and return it.

    The entity belongs to ``test_project`` and is added and flushed inside a
    ``db.scoped_session`` block.
    """
    source = Entity(
        project_id=test_project.id,
        title="test_source",
        entity_type="test",
        permalink="source/test-source",
        file_path="source/test_source.md",
        content_type="text/markdown",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )

    async with db.scoped_session(session_maker) as db_session:
        db_session.add(source)
        # Flush so the row is sent to the database before the entity is handed out.
        await db_session.flush()
        return source
@pytest_asyncio.fixture
async def target_entity(session_maker, test_project: Project):
    """Add the entity used as the destination of test relations and return it.

    The entity belongs to ``test_project`` and is added and flushed inside a
    ``db.scoped_session`` block.
    """
    target = Entity(
        project_id=test_project.id,
        title="test_target",
        entity_type="test",
        permalink="target/test-target",
        file_path="target/test_target.md",
        content_type="text/markdown",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )

    async with db.scoped_session(session_maker) as db_session:
        db_session.add(target)
        # Flush so the row is sent to the database before the entity is handed out.
        await db_session.flush()
        return target
@pytest_asyncio.fixture
async def test_relations(session_maker, source_entity, target_entity):
    """Add two relations from ``source_entity`` to ``target_entity``.

    Returns the relations in order: ``connects_to`` first, ``depends_on`` second.
    """
    # Both relations share endpoints and differ only in their type.
    created = [
        Relation(
            from_id=source_entity.id,
            to_id=target_entity.id,
            to_name=target_entity.title,
            relation_type=kind,
        )
        for kind in ("connects_to", "depends_on")
    ]

    async with db.scoped_session(session_maker) as db_session:
        db_session.add_all(created)
        await db_session.flush()
        return created
@pytest_asyncio.fixture(scope="function")
async def related_entity(entity_repository):
    """Create, via the entity repository, a second entity to relate to."""
    created = await entity_repository.create(
        {
            "title": "Related Entity",
            "entity_type": "test",
            "permalink": "test/related-entity",
            "file_path": "test/related_entity.md",
            "summary": "A related test entity",
            "content_type": "text/markdown",
            "created_at": datetime.now(timezone.utc),
            "updated_at": datetime.now(timezone.utc),
        }
    )
    return created
@pytest_asyncio.fixture(scope="function")
async def sample_relation(
    relation_repository: RelationRepository, sample_entity: Entity, related_entity: Entity
):
    """Create one ``test_relation`` from ``sample_entity`` to ``related_entity``."""
    created = await relation_repository.create(
        {
            "from_id": sample_entity.id,
            "to_id": related_entity.id,
            "to_name": related_entity.title,
            "relation_type": "test_relation",
            "context": "test-context",
        }
    )
    return created
@pytest_asyncio.fixture(scope="function")
async def multiple_relations(
    relation_repository: RelationRepository, sample_entity: Entity, related_entity: Entity
):
    """Create three relations between ``sample_entity`` and ``related_entity``.

    Returned in creation order:

    1. sample -> related, type ``relation_one``
    2. sample -> related, type ``relation_two``
    3. related -> sample, type ``relation_one`` (reverse direction)
    """
    relations_data = [
        {
            "from_id": sample_entity.id,
            "to_id": related_entity.id,
            "to_name": related_entity.title,
            "relation_type": "relation_one",
            "context": "context_one",
        },
        {
            "from_id": sample_entity.id,
            "to_id": related_entity.id,
            "to_name": related_entity.title,
            "relation_type": "relation_two",
            "context": "context_two",
        },
        {
            "from_id": related_entity.id,
            "to_id": sample_entity.id,
            # The target of the reverse relation is sample_entity, so to_name
            # must carry its title (previously it repeated the source's title).
            "to_name": sample_entity.title,
            "relation_type": "relation_one",
            "context": "context_three",
        },
    ]
    # Created one at a time so the returned list matches the order above.
    return [await relation_repository.create(data) for data in relations_data]
@pytest.mark.asyncio
async def test_create_relation(
    relation_repository: RelationRepository, sample_entity: Entity, related_entity: Entity
):
    """Test that a created relation keeps its endpoints and type and gets an id."""
    created = await relation_repository.create(
        {
            "from_id": sample_entity.id,
            "to_id": related_entity.id,
            "to_name": related_entity.title,
            "relation_type": "test_relation",
            "context": "test-context",
        }
    )

    assert created.from_id == sample_entity.id
    assert created.to_id == related_entity.id
    assert created.relation_type == "test_relation"
    # No id was supplied in the payload, so one must have been assigned.
    assert created.id is not None
@pytest.mark.asyncio
async def test_create_relation_entity_does_not_exist(
    relation_repository: RelationRepository, sample_entity: Entity, related_entity: Entity
):
    """Test that creating a relation with a non-matching ``from_id`` raises IntegrityError."""
    # "not_exist" is not the id of any entity created by the fixtures.
    bad_payload = {
        "from_id": "not_exist",
        "to_id": related_entity.id,
        "to_name": related_entity.title,
        "relation_type": "test_relation",
        "context": "test-context",
    }

    with pytest.raises(sqlalchemy.exc.IntegrityError):
        await relation_repository.create(bad_payload)
@pytest.mark.asyncio
async def test_find_by_entities(
    relation_repository: RelationRepository,
    sample_relation: Relation,
    sample_entity: Entity,
    related_entity: Entity,
):
    """Test that looking up by both endpoint ids returns the sample relation only."""
    found = await relation_repository.find_by_entities(sample_entity.id, related_entity.id)

    assert len(found) == 1
    only_match = found[0]
    assert only_match.id == sample_relation.id
    assert only_match.relation_type == sample_relation.relation_type
@pytest.mark.asyncio
async def test_find_relation(relation_repository: RelationRepository, sample_relation: Relation):
    """Test finding a single relation by endpoint permalinks and relation type."""
    relation = await relation_repository.find_relation(
        from_permalink=sample_relation.from_entity.permalink,
        to_permalink=sample_relation.to_entity.permalink,
        relation_type=sample_relation.relation_type,
    )
    # Check for a miss explicitly so the failure is a clear assertion rather
    # than an AttributeError on ``None.id``.
    assert relation is not None
    assert relation.id == sample_relation.id
@pytest.mark.asyncio
async def test_find_by_type(relation_repository: RelationRepository, sample_relation: Relation):
    """Test that filtering by relation type returns exactly the sample relation."""
    matches = await relation_repository.find_by_type("test_relation")

    assert len(matches) == 1
    (only_match,) = matches
    assert only_match.id == sample_relation.id
@pytest.mark.asyncio
async def test_find_unresolved_relations(
    relation_repository: RelationRepository, sample_entity: Entity, related_entity: Entity
):
    """Test that a relation created with ``to_id=None`` is reported as unresolved."""
    # Only the target's name is recorded; the target id is left unset.
    dangling = await relation_repository.create(
        {
            "from_id": sample_entity.id,
            "to_id": None,
            "to_name": related_entity.title,
            "relation_type": "test_relation",
            "context": "test-context",
        }
    )
    assert dangling.from_id == sample_entity.id
    assert dangling.to_id is None

    pending = await relation_repository.find_unresolved_relations()
    assert len(pending) == 1
    assert pending[0].id == dangling.id
@pytest.mark.asyncio
async def test_delete_by_fields_single_field(
    relation_repository: RelationRepository, multiple_relations: list[Relation]
):
    """Test that deleting by relation type alone removes every relation of that type."""
    deleted = await relation_repository.delete_by_fields(
        relation_type="relation_one"  # pyright: ignore [reportArgumentType]
    )
    assert deleted is True

    # Nothing of the deleted type may be left.
    leftover = await relation_repository.find_by_type("relation_one")
    assert len(leftover) == 0

    # The relation of the other type must still be there.
    untouched = await relation_repository.find_by_type("relation_two")
    assert len(untouched) == 1
@pytest.mark.asyncio
async def test_delete_by_fields_multiple_fields(
    relation_repository: RelationRepository,
    multiple_relations: list[Relation],
    sample_entity: Entity,
    related_entity: Entity,
):
    """Test that deleting by ``from_id`` and ``relation_type`` removes only the match."""
    deleted = await relation_repository.delete_by_fields(
        from_id=sample_entity.id,  # pyright: ignore [reportArgumentType]
        relation_type="relation_one",  # pyright: ignore [reportArgumentType]
    )
    assert deleted is True

    # Of the sample -> related relations, only relation_two should be left.
    survivors = await relation_repository.find_by_entities(sample_entity.id, related_entity.id)
    assert len(survivors) == 1
    assert survivors[0].relation_type == "relation_two"
@pytest.mark.asyncio
async def test_delete_by_fields_no_match(
    relation_repository: RelationRepository, multiple_relations: list[Relation]
):
    """Test that ``delete_by_fields`` returns False when no relation matches the filter."""
    deleted = await relation_repository.delete_by_fields(
        relation_type="nonexistent_type"  # pyright: ignore [reportArgumentType]
    )

    assert deleted is False
@pytest.mark.asyncio
async def test_delete_by_fields_all_fields(
    relation_repository: RelationRepository,
    multiple_relations: list[Relation],
    sample_entity: Entity,
    related_entity: Entity,
):
    """Test that filtering on ``from_id``, ``to_id`` and type deletes only the exact match."""
    # The first fixture relation supplies the filter values.
    doomed = multiple_relations[0]

    deleted = await relation_repository.delete_by_fields(
        from_id=doomed.from_id,  # pyright: ignore [reportArgumentType]
        to_id=doomed.to_id,  # pyright: ignore [reportArgumentType]
        relation_type=doomed.relation_type,  # pyright: ignore [reportArgumentType]
    )
    assert deleted is True

    # The fixture's reverse-direction relation has the same type and must survive.
    same_type = await relation_repository.find_by_type(doomed.relation_type)
    assert len(same_type) == 1
@pytest.mark.asyncio
async def test_delete_relation_by_id(relation_repository, test_relations):
    """Test that deleting by primary key removes the relation."""
    doomed = test_relations[0]

    deleted = await relation_repository.delete(doomed.id)
    assert deleted is True

    # Looking the same id up again must find nothing.
    lookup = relation_repository.select(Relation).filter(Relation.id == doomed.id)
    found = await relation_repository.find_one(lookup)
    assert found is None
@pytest.mark.asyncio
async def test_delete_relations_by_type(relation_repository, test_relations):
    """Test that deleting ``connects_to`` relations leaves ``depends_on`` in place."""
    deleted = await relation_repository.delete_by_fields(relation_type="connects_to")
    assert deleted is True

    # The deleted type is gone.
    leftover = await relation_repository.find_by_type("connects_to")
    assert len(leftover) == 0

    # The other type is unaffected.
    untouched = await relation_repository.find_by_type("depends_on")
    assert len(untouched) == 1
@pytest.mark.asyncio
async def test_delete_relations_by_entities(
    relation_repository, test_relations, source_entity, target_entity
):
    """Test that filtering on both endpoint ids deletes every relation between them."""
    deleted = await relation_repository.delete_by_fields(
        from_id=source_entity.id,
        to_id=target_entity.id,
    )
    assert deleted is True

    # No relation between the two entities may be left.
    leftover = await relation_repository.find_by_entities(source_entity.id, target_entity.id)
    assert len(leftover) == 0
@pytest.mark.asyncio
async def test_delete_nonexistent_relation(relation_repository):
    """Test that ``delete_by_fields`` returns False for a type that no relation has."""
    deleted = await relation_repository.delete_by_fields(relation_type="nonexistent")

    assert deleted is False