"""Tests for context service."""
from datetime import datetime, timedelta, UTC
import pytest
import pytest_asyncio
from basic_memory.repository.search_repository import SearchIndexRow
from basic_memory.schemas.memory import memory_url, memory_url_path
from basic_memory.schemas.search import SearchItemType
from basic_memory.services.context_service import ContextService
from basic_memory.models.knowledge import Entity, Relation
from basic_memory.models.project import Project
@pytest_asyncio.fixture
async def context_service(search_repository, entity_repository, observation_repository):
"""Create context service for testing."""
return ContextService(search_repository, entity_repository, observation_repository)


@pytest.mark.asyncio
async def test_find_connected_depth_limit(context_service, test_graph):
"""Test depth limiting works.
Our traversal path is:
- Depth 0: Root
- Depth 1: Relations + directly connected entities (Connected1, Connected2)
- Depth 2: Relations + next level entities (Deep)
"""
type_id_pairs = [("entity", test_graph["root"].id)]
# With depth=1, we get direct connections
# shallow_results = await context_service.find_related(type_id_pairs, max_depth=1)
# shallow_entities = {(r.id, r.type) for r in shallow_results if r.type == "entity"}
#
# assert (test_graph["deep"].id, "entity") not in shallow_entities
    # Search deeper
    deep_results = await context_service.find_related(type_id_pairs, max_depth=3, max_results=100)
    deep_entities = {(r.id, r.type) for r in deep_results if r.type == "entity"}
# Should now include Deep entity
assert (test_graph["deep"].id, "entity") in deep_entities
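

# A minimal sketch of the shallow-depth check hinted at by the commented-out lines
# above. It assumes that with max_depth=1 the traversal stops at the directly
# connected entities and never reaches the Deep entity; this is an illustrative
# addition, not part of the original suite.
@pytest.mark.asyncio
async def test_find_connected_shallow_depth_sketch(context_service, test_graph):
    """Sketch: with max_depth=1, only direct connections should be returned."""
    type_id_pairs = [("entity", test_graph["root"].id)]
    shallow_results = await context_service.find_related(type_id_pairs, max_depth=1)
    shallow_entities = {(r.id, r.type) for r in shallow_results if r.type == "entity"}
    # Deep is two hops away from Root, so it should not appear at depth 1 (assumption).
    assert (test_graph["deep"].id, "entity") not in shallow_entities
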
@pytest.mark.asyncio
async def test_find_connected_timeframe(
context_service, test_graph, search_repository, entity_repository
):
"""Test timeframe filtering.
This tests how traversal is affected by the item dates.
When we filter by date, items are only included if:
1. They match the timeframe
2. There is a valid path to them through other items in the timeframe
"""
now = datetime.now(UTC)
old_date = now - timedelta(days=10)
recent_date = now - timedelta(days=1)
# Update entity table timestamps directly
# Root entity uses old date
root_entity = test_graph["root"]
await entity_repository.update(root_entity.id, {"created_at": old_date, "updated_at": old_date})
# Connected entity uses recent date
connected_entity = test_graph["connected1"]
await entity_repository.update(
connected_entity.id, {"created_at": recent_date, "updated_at": recent_date}
)
# Also update search_index for test consistency
await search_repository.index_item(
SearchIndexRow(
project_id=entity_repository.project_id,
id=test_graph["root"].id,
title=test_graph["root"].title,
content_snippet="Root content",
permalink=test_graph["root"].permalink,
file_path=test_graph["root"].file_path,
type=SearchItemType.ENTITY,
metadata={"created_at": old_date.isoformat()},
created_at=old_date.isoformat(),
updated_at=old_date.isoformat(),
)
)
await search_repository.index_item(
SearchIndexRow(
project_id=entity_repository.project_id,
id=test_graph["relations"][0].id,
title="Root Entity → Connected Entity 1",
content_snippet="",
permalink=f"{test_graph['root'].permalink}/connects_to/{test_graph['connected1'].permalink}",
file_path=test_graph["root"].file_path,
type=SearchItemType.RELATION,
from_id=test_graph["root"].id,
to_id=test_graph["connected1"].id,
relation_type="connects_to",
metadata={"created_at": old_date.isoformat()},
created_at=old_date.isoformat(),
updated_at=old_date.isoformat(),
)
)
await search_repository.index_item(
SearchIndexRow(
project_id=entity_repository.project_id,
id=test_graph["connected1"].id,
title=test_graph["connected1"].title,
content_snippet="Connected 1 content",
permalink=test_graph["connected1"].permalink,
file_path=test_graph["connected1"].file_path,
type=SearchItemType.ENTITY,
metadata={"created_at": recent_date.isoformat()},
created_at=recent_date.isoformat(),
updated_at=recent_date.isoformat(),
)
)
type_id_pairs = [("entity", test_graph["root"].id)]
# Search with a 7-day cutoff
since_date = now - timedelta(days=7)
results = await context_service.find_related(type_id_pairs, since=since_date)
# Only connected1 is recent, but we can't get to it
# because its connecting relation is too old and is filtered out
# (we can only reach connected1 through a relation starting from root)
entity_ids = {r.id for r in results if r.type == "entity"}
assert len(entity_ids) == 0 # No accessible entities within timeframe
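    # Illustrative sketch (added, not in the original test): without a `since`
    # cutoff the old relation is not filtered out, so connected1 should be
    # reachable from root again. This assumes the default traversal depth covers
    # direct neighbours, as the build_context tests below suggest.
    unfiltered_results = await context_service.find_related(type_id_pairs)
    unfiltered_ids = {r.id for r in unfiltered_results if r.type == "entity"}
    assert test_graph["connected1"].id in unfiltered_ids
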
@pytest.mark.asyncio
async def test_build_context(context_service, test_graph):
"""Test exact permalink lookup."""
url = memory_url.validate_strings("memory://test/root")
context_result = await context_service.build_context(url)
# Check metadata
assert context_result.metadata.uri == memory_url_path(url)
assert context_result.metadata.depth == 1
assert context_result.metadata.primary_count == 1
assert context_result.metadata.related_count > 0
assert context_result.metadata.generated_at is not None
# Check results
assert len(context_result.results) == 1
context_item = context_result.results[0]
# Check primary result
primary_result = context_item.primary_result
assert primary_result.id == test_graph["root"].id
assert primary_result.type == "entity"
assert primary_result.title == "Root"
assert primary_result.permalink == "test/root"
assert primary_result.file_path == "test/Root.md"
assert primary_result.created_at is not None
# Check related results
assert len(context_item.related_results) > 0
# Find related relation
relation = next((r for r in context_item.related_results if r.type == "relation"), None)
assert relation is not None
assert relation.relation_type == "connects_to"
assert relation.from_id == test_graph["root"].id
assert relation.to_id == test_graph["connected1"].id
# Find related entity
related_entity = next((r for r in context_item.related_results if r.type == "entity"), None)
assert related_entity is not None
assert related_entity.id == test_graph["connected1"].id
assert related_entity.title == test_graph["connected1"].title
assert related_entity.permalink == test_graph["connected1"].permalink


@pytest.mark.asyncio
async def test_build_context_with_observations(context_service, test_graph):
"""Test context building with observations."""
    # The test_graph fixture already creates observations for the root entity,
    # so build the context directly and rely on those.
url = memory_url.validate_strings("memory://test/root")
context_result = await context_service.build_context(url, include_observations=True)
# Check the metadata
assert context_result.metadata.total_observations > 0
assert len(context_result.results) == 1
# Check that observations were included
context_item = context_result.results[0]
assert len(context_item.observations) > 0
# Check observation properties
for observation in context_item.observations:
assert observation.type == "observation"
assert observation.category in ["note", "tech"] # Categories from test_graph fixture
assert observation.entity_id == test_graph["root"].id
# Verify at least one observation has the correct category and content
note_observation = next((o for o in context_item.observations if o.category == "note"), None)
assert note_observation is not None
assert "Root note" in note_observation.content


@pytest.mark.asyncio
async def test_build_context_not_found(context_service):
"""Test handling non-existent permalinks."""
context = await context_service.build_context("memory://does/not/exist")
assert len(context.results) == 0
assert context.metadata.primary_count == 0
assert context.metadata.related_count == 0


@pytest.mark.asyncio
async def test_context_metadata(context_service, test_graph):
"""Test metadata is correctly populated."""
context = await context_service.build_context("memory://test/root", depth=2)
metadata = context.metadata
assert metadata.uri == "test/root"
assert metadata.depth == 2
assert metadata.generated_at is not None
assert metadata.primary_count > 0


@pytest.mark.asyncio
async def test_project_isolation_in_find_related(session_maker):
"""Test that find_related respects project boundaries and doesn't leak data."""
from basic_memory.repository.entity_repository import EntityRepository
from basic_memory.repository.observation_repository import ObservationRepository
from basic_memory.repository.search_repository import SearchRepository
from basic_memory import db
# Create database session
async with db.scoped_session(session_maker) as db_session:
# Create two separate projects
project1 = Project(name="project1", path="/test1")
project2 = Project(name="project2", path="/test2")
db_session.add(project1)
db_session.add(project2)
await db_session.flush()
# Create entities in project1
entity1_p1 = Entity(
title="Entity1_P1",
entity_type="document",
content_type="text/markdown",
project_id=project1.id,
permalink="project1/entity1",
file_path="project1/entity1.md",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
entity2_p1 = Entity(
title="Entity2_P1",
entity_type="document",
content_type="text/markdown",
project_id=project1.id,
permalink="project1/entity2",
file_path="project1/entity2.md",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
# Create entities in project2
entity1_p2 = Entity(
title="Entity1_P2",
entity_type="document",
content_type="text/markdown",
project_id=project2.id,
permalink="project2/entity1",
file_path="project2/entity1.md",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
db_session.add_all([entity1_p1, entity2_p1, entity1_p2])
await db_session.flush()
# Create relation in project1 (between entities of project1)
relation_p1 = Relation(
from_id=entity1_p1.id,
to_id=entity2_p1.id,
to_name="Entity2_P1",
relation_type="connects_to",
)
db_session.add(relation_p1)
await db_session.commit()
# Create repositories for project1
search_repo_p1 = SearchRepository(session_maker, project1.id)
entity_repo_p1 = EntityRepository(session_maker, project1.id)
obs_repo_p1 = ObservationRepository(session_maker, project1.id)
context_service_p1 = ContextService(search_repo_p1, entity_repo_p1, obs_repo_p1)
# Create repositories for project2
search_repo_p2 = SearchRepository(session_maker, project2.id)
entity_repo_p2 = EntityRepository(session_maker, project2.id)
obs_repo_p2 = ObservationRepository(session_maker, project2.id)
context_service_p2 = ContextService(search_repo_p2, entity_repo_p2, obs_repo_p2)
# Test: find_related for project1 should only return project1 entities
type_id_pairs_p1 = [("entity", entity1_p1.id)]
related_p1 = await context_service_p1.find_related(type_id_pairs_p1, max_depth=2)
# Verify only project1 entities are returned
related_entity_ids = [r.id for r in related_p1 if r.type == "entity"]
assert entity2_p1.id in related_entity_ids # Should find connected entity2 in project1
assert entity1_p2.id not in related_entity_ids # Should NOT find entity from project2
# Test: find_related for project2 should return empty (no relations)
type_id_pairs_p2 = [("entity", entity1_p2.id)]
related_p2 = await context_service_p2.find_related(type_id_pairs_p2, max_depth=2)
# Project2 has no relations, so should return empty
assert len(related_p2) == 0
        # Double-check: verify each entity belongs to its expected project
assert entity1_p1.project_id == project1.id
assert entity2_p1.project_id == project1.id
assert entity1_p2.project_id == project2.id
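
        # Illustrative sketch (added): the reverse direction should also hold.
        # Querying project2's context service with a project1 entity id is
        # assumed to surface nothing, because every repository involved is
        # scoped to project2.
        cross_project_pairs = [("entity", entity1_p1.id)]
        cross_related = await context_service_p2.find_related(cross_project_pairs, max_depth=2)
        assert len(cross_related) == 0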