test_project_repository.py•10.5 kB
"""Tests for the ProjectRepository."""
from datetime import datetime, timezone
from pathlib import Path
import pytest
import pytest_asyncio
from sqlalchemy import select
from basic_memory import db
from basic_memory.models.project import Project
from basic_memory.repository.project_repository import ProjectRepository
@pytest_asyncio.fixture
async def sample_project(project_repository: ProjectRepository) -> Project:
"""Create a sample project for testing."""
project_data = {
"name": "Sample Project",
"description": "A sample project",
"path": "/sample/project/path",
"is_active": True,
"is_default": False,
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
}
return await project_repository.create(project_data)
@pytest.mark.asyncio
async def test_create_project(project_repository: ProjectRepository):
"""Test creating a new project."""
project_data = {
"name": "Sample Project",
"description": "A sample project",
"path": "/sample/project/path",
"is_active": True,
"is_default": False,
}
project = await project_repository.create(project_data)
# Verify returned object
assert project.id is not None
assert project.name == "Sample Project"
assert project.description == "A sample project"
assert project.path == "/sample/project/path"
assert project.is_active is True
assert project.is_default is False
assert isinstance(project.created_at, datetime)
assert isinstance(project.updated_at, datetime)
# Verify permalink was generated correctly
assert project.permalink == "sample-project"
# Verify in database
found = await project_repository.find_by_id(project.id)
assert found is not None
assert found.id == project.id
assert found.name == project.name
assert found.description == project.description
assert found.path == project.path
assert found.permalink == "sample-project"
assert found.is_active is True
assert found.is_default is False
@pytest.mark.asyncio
async def test_get_by_name(project_repository: ProjectRepository, sample_project: Project):
"""Test getting a project by name."""
# Test exact match
found = await project_repository.get_by_name(sample_project.name)
assert found is not None
assert found.id == sample_project.id
assert found.name == sample_project.name
# Test non-existent name
found = await project_repository.get_by_name("Non-existent Project")
assert found is None
@pytest.mark.asyncio
async def test_get_by_permalink(project_repository: ProjectRepository, sample_project: Project):
"""Test getting a project by permalink."""
# Verify the permalink value
assert sample_project.permalink == "sample-project"
# Test exact match
found = await project_repository.get_by_permalink(sample_project.permalink)
assert found is not None
assert found.id == sample_project.id
assert found.permalink == sample_project.permalink
# Test non-existent permalink
found = await project_repository.get_by_permalink("non-existent-project")
assert found is None
@pytest.mark.asyncio
async def test_get_by_path(project_repository: ProjectRepository, sample_project: Project):
"""Test getting a project by path."""
# Test exact match
found = await project_repository.get_by_path(sample_project.path)
assert found is not None
assert found.id == sample_project.id
assert found.path == sample_project.path
# Test with Path object
found = await project_repository.get_by_path(Path(sample_project.path))
assert found is not None
assert found.id == sample_project.id
assert found.path == sample_project.path
# Test non-existent path
found = await project_repository.get_by_path("/non/existent/path")
assert found is None
@pytest.mark.asyncio
async def test_get_default_project(project_repository: ProjectRepository):
"""Test getting the default project."""
# We already have a default project from the test_project fixture
# So just create a non-default project
non_default_project_data = {
"name": "Non-Default Project",
"description": "A non-default project",
"path": "/non-default/project/path",
"is_active": True,
"is_default": None, # Not the default project
}
await project_repository.create(non_default_project_data)
# Get default project
default_project = await project_repository.get_default_project()
assert default_project is not None
assert default_project.is_default is True
@pytest.mark.asyncio
async def test_get_active_projects(project_repository: ProjectRepository):
"""Test getting all active projects."""
# Create active and inactive projects
active_project_data = {
"name": "Active Project",
"description": "An active project",
"path": "/active/project/path",
"is_active": True,
}
inactive_project_data = {
"name": "Inactive Project",
"description": "An inactive project",
"path": "/inactive/project/path",
"is_active": False,
}
await project_repository.create(active_project_data)
await project_repository.create(inactive_project_data)
# Get active projects
active_projects = await project_repository.get_active_projects()
assert len(active_projects) >= 1 # Could be more from other tests
# Verify that all returned projects are active
for project in active_projects:
assert project.is_active is True
# Verify active project is included
active_names = [p.name for p in active_projects]
assert "Active Project" in active_names
# Verify inactive project is not included
assert "Inactive Project" not in active_names
@pytest.mark.asyncio
async def test_set_as_default(project_repository: ProjectRepository, test_project: Project):
"""Test setting a project as default."""
# The test_project fixture is already the default
# Create a non-default project
project2_data = {
"name": "Project 2",
"description": "Project 2",
"path": "/project2/path",
"is_active": True,
"is_default": None, # Not default
}
# Get the existing default project
project1 = test_project
project2 = await project_repository.create(project2_data)
# Verify initial state
assert project1.is_default is True
assert project2.is_default is None
# Set project2 as default
updated_project2 = await project_repository.set_as_default(project2.id)
assert updated_project2 is not None
assert updated_project2.is_default is True
# Verify project1 is no longer default
project1_updated = await project_repository.find_by_id(project1.id)
assert project1_updated is not None
assert project1_updated.is_default is None
# Verify project2 is now default
project2_updated = await project_repository.find_by_id(project2.id)
assert project2_updated is not None
assert project2_updated.is_default is True
@pytest.mark.asyncio
async def test_update_project(project_repository: ProjectRepository, sample_project: Project):
"""Test updating a project."""
# Update project
updated_data = {
"name": "Updated Project Name",
"description": "Updated description",
"path": "/updated/path",
}
updated_project = await project_repository.update(sample_project.id, updated_data)
# Verify returned object
assert updated_project is not None
assert updated_project.id == sample_project.id
assert updated_project.name == "Updated Project Name"
assert updated_project.description == "Updated description"
assert updated_project.path == "/updated/path"
# Verify permalink was updated based on new name
assert updated_project.permalink == "updated-project-name"
# Verify in database
found = await project_repository.find_by_id(sample_project.id)
assert found is not None
assert found.name == "Updated Project Name"
assert found.description == "Updated description"
assert found.path == "/updated/path"
assert found.permalink == "updated-project-name"
# Verify we can find by the new permalink
found_by_permalink = await project_repository.get_by_permalink("updated-project-name")
assert found_by_permalink is not None
assert found_by_permalink.id == sample_project.id
@pytest.mark.asyncio
async def test_delete_project(project_repository: ProjectRepository, sample_project: Project):
"""Test deleting a project."""
# Delete project
result = await project_repository.delete(sample_project.id)
assert result is True
# Verify deletion
deleted = await project_repository.find_by_id(sample_project.id)
assert deleted is None
# Verify with direct database query
async with db.scoped_session(project_repository.session_maker) as session:
query = select(Project).filter(Project.id == sample_project.id)
result = await session.execute(query)
assert result.scalar_one_or_none() is None
@pytest.mark.asyncio
async def test_delete_nonexistent_project(project_repository: ProjectRepository):
"""Test deleting a project that doesn't exist."""
result = await project_repository.delete(999) # Non-existent ID
assert result is False
@pytest.mark.asyncio
async def test_update_path(project_repository: ProjectRepository, sample_project: Project):
"""Test updating a project's path."""
new_path = "/new/project/path"
# Update the project path
updated_project = await project_repository.update_path(sample_project.id, new_path)
# Verify returned object
assert updated_project is not None
assert updated_project.id == sample_project.id
assert updated_project.path == new_path
assert updated_project.name == sample_project.name # Other fields unchanged
# Verify in database
found = await project_repository.find_by_id(sample_project.id)
assert found is not None
assert found.path == new_path
assert found.name == sample_project.name
@pytest.mark.asyncio
async def test_update_path_nonexistent_project(project_repository: ProjectRepository):
"""Test updating path for a project that doesn't exist."""
result = await project_repository.update_path(999, "/some/path") # Non-existent ID
assert result is None