# test_repository.py
"""Test repository implementation."""
from datetime import datetime, timezone

import pytest
from sqlalchemy import DateTime, String
from sqlalchemy.orm import Mapped, mapped_column

from basic_memory.models import Base
from basic_memory.repository.repository import Repository
def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The ``DateTime`` columns below are declared without
    ``timezone=True``, so tzinfo is stripped to keep stored values naive,
    exactly as ``utcnow()`` produced them.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class ModelTest(Base):
    """Minimal ORM model used to exercise the generic ``Repository``.

    Uses a string primary key so tests can pick readable ids such as ``"test_0"``.
    """

    __tablename__ = "test_model"

    id: Mapped[str] = mapped_column(String(255), primary_key=True)
    name: Mapped[str] = mapped_column(String(255))
    description: Mapped[str | None] = mapped_column(String(255), nullable=True)
    # SQLAlchemy calls _utcnow on INSERT (default) and, for updated_at, on UPDATE (onupdate).
    created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, default=_utcnow, onupdate=_utcnow
    )
@pytest.fixture
def repository(session_maker):
    """Provide a ``Repository`` that persists ``ModelTest`` rows.

    ``session_maker`` is not defined in this module; presumably it is a
    conftest fixture supplying the session factory.
    """
    repo = Repository(session_maker, ModelTest)
    return repo
@pytest.mark.asyncio
async def test_add(repository):
    """Adding a single model instance persists it so it can be found by id."""
    instance = ModelTest(id="test_add", name="Test Add")
    await repository.add(instance)

    # Re-read through the repository rather than trusting the in-memory object.
    found = await repository.find_by_id("test_add")
    assert found is not None
    assert found.name == "Test Add"
@pytest.mark.asyncio
async def test_add_all(repository):
    """Adding a list of model instances persists every one of them."""
    instances = [ModelTest(id=f"test_{i}", name=f"Test {i}") for i in range(3)]
    await repository.add_all(instances)

    # Check every instance, not just the first, so a partial insert is caught.
    for i in range(3):
        found = await repository.find_by_id(f"test_{i}")
        assert found is not None
        assert found.name == f"Test {i}"
@pytest.mark.asyncio
async def test_bulk_create(repository):
    """Bulk-creating rows from plain column dicts persists every row."""
    # Use plain column dicts rather than ``instance.__dict__``, which would also
    # carry SQLAlchemy's internal ``_sa_instance_state`` into the payload.
    rows = [{"id": f"test_{i}", "name": f"Test {i}"} for i in range(3)]
    await repository.create_all(rows)

    for row in rows:
        found = await repository.find_by_id(row["id"])
        assert found is not None
        assert found.name == row["name"]
@pytest.mark.asyncio
async def test_find_all(repository):
    """find_all honours its ``limit`` argument."""
    rows = [{"id": f"test_{i}", "name": f"Test {i}"} for i in range(5)]
    await repository.create_all(rows)

    # Five rows exist, so a limit of 3 must truncate the result.
    found = await repository.find_all(limit=3)
    assert len(found) == 3
@pytest.mark.asyncio
async def test_find_by_ids(repository):
    """find_by_ids returns exactly the rows whose ids exist and skips unknown ids."""
    rows = [{"id": f"test_{i}", "name": f"Test {i}"} for i in range(5)]
    await repository.create_all(rows)

    # Subset of existing ids.
    ids_to_find = ["test_0", "test_2", "test_4"]
    found = await repository.find_by_ids(ids_to_find)
    assert len(found) == 3
    assert sorted(e.id for e in found) == sorted(ids_to_find)

    # Mix of existing and unknown ids: unknown ones are simply left out.
    mixed_ids = ["test_0", "nonexistent", "test_4"]
    partial_found = await repository.find_by_ids(mixed_ids)
    assert len(partial_found) == 2
    assert sorted(e.id for e in partial_found) == ["test_0", "test_4"]

    # Empty id list.
    empty_found = await repository.find_by_ids([])
    assert len(empty_found) == 0

    # Only unknown ids.
    not_found = await repository.find_by_ids(["fake1", "fake2"])
    assert len(not_found) == 0
@pytest.mark.asyncio
async def test_delete_by_ids(repository):
    """delete_by_ids removes exactly the requested rows and returns how many."""
    rows = [{"id": f"test_{i}", "name": f"Test {i}"} for i in range(5)]
    await repository.create_all(rows)

    ids_to_delete = ["test_0", "test_2", "test_4"]
    deleted_count = await repository.delete_by_ids(ids_to_delete)
    assert deleted_count == 3

    # The untouched rows must survive ...
    ids_to_keep = ["test_1", "test_3"]
    found = await repository.find_by_ids(ids_to_keep)
    assert len(found) == 2
    assert sorted(e.id for e in found) == sorted(ids_to_keep)

    # ... and every deleted row must be gone.
    for deleted_id in ids_to_delete:
        assert await repository.find_by_id(deleted_id) is None
@pytest.mark.asyncio
async def test_update(repository):
    """update() with a dict of column values changes the stored row."""
    await repository.add(ModelTest(id="test_add", name="Test Add"))

    modified = await repository.update("test_add", {"name": "Updated"})
    assert modified is not None
    assert modified.name == "Updated"

    # The change must be persisted, not only reflected in the returned object.
    found = await repository.find_by_id("test_add")
    assert found is not None
    assert found.name == "Updated"
@pytest.mark.asyncio
async def test_update_model(repository):
    """update() also accepts a model instance as the source of new values."""
    instance = ModelTest(id="test_add", name="Test Add")
    await repository.add(instance)

    # Mutate the same instance and hand it back as the update payload.
    instance.name = "Updated"
    modified = await repository.update(instance.id, instance)
    assert modified is not None
    assert modified.name == "Updated"
@pytest.mark.asyncio
async def test_update_model_not_found(repository):
    """update() returns None when no row has the given id."""
    # Seed one row so the table is non-empty and the miss comes from the id alone.
    await repository.add(ModelTest(id="test_add", name="Test Add"))

    # Use a string id to match ModelTest's String primary key.
    modified = await repository.update("nonexistent", {})
    assert modified is None
@pytest.mark.asyncio
async def test_count(repository):
    """count() reports the number of stored rows."""
    # Baseline: other tests reuse the same ids, so each test presumably starts
    # from an empty table.
    assert await repository.count() == 0

    await repository.add(ModelTest(id="test_add", name="Test Add"))

    count = await repository.count()
    assert count == 1