test_embeddings.py•13.9 kB
"""
Tests for embedding generation and vector storage
"""
import os
import shutil
import tempfile
import pytest
from scout.chunker import CodeChunk
from scout.embeddings import EmbeddingStore, SearchResult
@pytest.fixture
def temp_db_path():
"""Create temporary database path"""
temp_dir = tempfile.mkdtemp()
yield temp_dir
# Cleanup
shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.fixture
def embedding_store(temp_db_path):
"""Create embedding store with temporary database"""
store = EmbeddingStore(db_path=temp_db_path, collection_name="test_collection")
yield store
# Cleanup
store.reset()
@pytest.fixture
def sample_chunks():
"""Create sample code chunks for testing"""
return [
CodeChunk(
chunk_id="test-repo:chunk:1",
file_path="auth.py",
repo_name="test-repo",
language="python",
chunk_type="function",
code_text="def authenticate_user(username, password):\n return check_credentials(username, password)",
start_line=10,
end_line=11,
symbol_name="authenticate_user",
parent_class=None,
imports=["bcrypt", "jwt"],
context_text="File: auth.py | Function: authenticate_user\n\ndef authenticate_user(username, password):\n return check_credentials(username, password)",
token_count=50,
),
CodeChunk(
chunk_id="test-repo:chunk:2",
file_path="user.py",
repo_name="test-repo",
language="python",
chunk_type="class",
code_text="class User:\n def __init__(self, name):\n self.name = name",
start_line=5,
end_line=7,
symbol_name="User",
parent_class=None,
imports=["dataclasses"],
context_text="File: user.py | Class: User\n\nclass User:\n def __init__(self, name):\n self.name = name",
token_count=40,
),
CodeChunk(
chunk_id="other-repo:chunk:1",
file_path="login.js",
repo_name="other-repo",
language="javascript",
chunk_type="function",
code_text="function login(username, password) {\n return authenticateUser(username, password);\n}",
start_line=1,
end_line=3,
symbol_name="login",
parent_class=None,
imports=["auth"],
context_text="File: login.js | Function: login\n\nfunction login(username, password) {\n return authenticateUser(username, password);\n}",
token_count=45,
),
]
class TestEmbeddingStore:
"""Tests for EmbeddingStore class"""
def test_store_initialization(self, embedding_store):
"""Test embedding store initializes correctly"""
assert embedding_store.collection is not None
assert embedding_store.model is not None
assert embedding_store.collection_name == "test_collection"
def test_add_chunks(self, embedding_store, sample_chunks):
"""Test adding chunks to the store"""
embedding_store.add_chunks(sample_chunks)
# Verify chunks were added by checking collection size
results = embedding_store.collection.get()
assert len(results["ids"]) == 3
def test_add_empty_chunks(self, embedding_store):
"""Test adding empty list of chunks"""
embedding_store.add_chunks([])
results = embedding_store.collection.get()
assert len(results["ids"]) == 0
def test_semantic_search(self, embedding_store, sample_chunks):
"""Test semantic search functionality"""
embedding_store.add_chunks(sample_chunks)
# Search for authentication-related code
results = embedding_store.semantic_search("user authentication", n_results=5)
assert len(results) > 0
assert all(isinstance(r, SearchResult) for r in results)
# Should find the authenticate_user function
auth_results = [r for r in results if "authenticate" in r.code_text.lower()]
assert len(auth_results) > 0
def test_semantic_search_with_repo_filter(self, embedding_store, sample_chunks):
"""Test semantic search with repository filter"""
embedding_store.add_chunks(sample_chunks)
# Search only in test-repo
results = embedding_store.semantic_search(
"authentication", n_results=10, repo_filter=["test-repo"]
)
assert len(results) > 0
assert all(r.repo_name == "test-repo" for r in results)
def test_semantic_search_with_language_filter(self, embedding_store, sample_chunks):
"""Test semantic search with language filter"""
embedding_store.add_chunks(sample_chunks)
# Search only Python code
results = embedding_store.semantic_search(
"user", n_results=10, language_filter="python"
)
assert len(results) > 0
assert all(r.metadata["language"] == "python" for r in results)
def test_find_by_symbol(self, embedding_store, sample_chunks):
"""Test finding code by exact symbol name"""
embedding_store.add_chunks(sample_chunks)
results = embedding_store.find_by_symbol("authenticate_user")
assert len(results) == 1
assert results[0].symbol_name == "authenticate_user"
assert "authenticate_user" in results[0].code_text
def test_find_by_symbol_with_repo_filter(self, embedding_store, sample_chunks):
"""Test finding symbol with repository filter"""
embedding_store.add_chunks(sample_chunks)
results = embedding_store.find_by_symbol("login", repo_filter=["other-repo"])
assert len(results) == 1
assert results[0].repo_name == "other-repo"
def test_find_related_by_file(self, embedding_store, sample_chunks):
"""Test finding related code by file"""
embedding_store.add_chunks(sample_chunks)
results = embedding_store.find_related_by_file(
file_path="auth.py", repo_name="test-repo", n_results=5
)
# Should find related code, excluding the same file
assert all(r.file_path != "auth.py" for r in results)
def test_list_repos(self, embedding_store, sample_chunks):
"""Test listing all repositories"""
embedding_store.add_chunks(sample_chunks)
repos = embedding_store.list_repos()
assert "test-repo" in repos
assert "other-repo" in repos
assert len(repos) == 2
def test_get_repo_stats(self, embedding_store, sample_chunks):
"""Test getting repository statistics"""
embedding_store.add_chunks(sample_chunks)
stats = embedding_store.get_repo_stats("test-repo")
assert stats["repo_name"] == "test-repo"
assert stats["chunk_count"] == 2
assert "auth.py" in stats["files"]
assert "user.py" in stats["files"]
assert "python" in stats["languages"]
def test_get_repo_stats_nonexistent(self, embedding_store):
"""Test getting stats for nonexistent repo"""
stats = embedding_store.get_repo_stats("nonexistent-repo")
assert stats["repo_name"] == "nonexistent-repo"
assert stats["chunk_count"] == 0
assert len(stats["files"]) == 0
def test_delete_repo(self, embedding_store, sample_chunks):
"""Test deleting a repository"""
embedding_store.add_chunks(sample_chunks)
deleted_count = embedding_store.delete_repo("test-repo")
assert deleted_count == 2 # Two chunks from test-repo
# Verify only other-repo remains
remaining = embedding_store.collection.get()
assert len(remaining["ids"]) == 1
assert remaining["metadatas"][0]["repo_name"] == "other-repo"
def test_delete_nonexistent_repo(self, embedding_store):
"""Test deleting nonexistent repository"""
deleted_count = embedding_store.delete_repo("nonexistent")
assert deleted_count == 0
def test_delete_file(self, embedding_store, sample_chunks):
"""Test deleting chunks for a specific file"""
embedding_store.add_chunks(sample_chunks)
# Delete auth.py file
deleted_count = embedding_store.delete_file("test-repo", "auth.py")
assert deleted_count == 1 # One chunk from auth.py
# Verify only auth.py chunks were deleted
remaining = embedding_store.collection.get(where={"repo_name": "test-repo"})
assert len(remaining["ids"]) == 1 # Only user.py remaining
remaining_files = {meta["file_path"] for meta in remaining["metadatas"]}
assert "auth.py" not in remaining_files
assert "user.py" in remaining_files
def test_delete_file_nonexistent(self, embedding_store):
"""Test deleting nonexistent file"""
deleted_count = embedding_store.delete_file("test-repo", "nonexistent.py")
assert deleted_count == 0
def test_delete_file_with_multiple_chunks(self, embedding_store):
"""Test deleting file with multiple chunks"""
from scout.chunker import CodeChunk
# Create multiple chunks for same file
chunks = [
CodeChunk(
chunk_id=f"test:file:chunk{i}",
file_path="multi_chunk.py",
repo_name="test-repo",
language="python",
chunk_type="function",
code_text=f"def function_{i}(): pass",
start_line=i * 10,
end_line=i * 10 + 5,
symbol_name=f"function_{i}",
parent_class=None,
imports=[],
context_text=f"def function_{i}(): pass",
token_count=10,
)
for i in range(5)
]
embedding_store.add_chunks(chunks)
# Delete all chunks for this file
deleted_count = embedding_store.delete_file("test-repo", "multi_chunk.py")
assert deleted_count == 5
# Verify all chunks deleted
results = embedding_store.collection.get(where={"file_path": "multi_chunk.py"})
assert len(results["ids"]) == 0
def test_reset(self, embedding_store, sample_chunks):
"""Test resetting the collection"""
embedding_store.add_chunks(sample_chunks)
# Verify chunks exist
results = embedding_store.collection.get()
assert len(results["ids"]) == 3
# Reset
embedding_store.reset()
# Verify collection is empty
results = embedding_store.collection.get()
assert len(results["ids"]) == 0
def test_search_result_attributes(self, embedding_store, sample_chunks):
"""Test SearchResult has all required attributes"""
embedding_store.add_chunks(sample_chunks)
results = embedding_store.semantic_search("test", n_results=1)
if results:
result = results[0]
assert hasattr(result, "chunk_id")
assert hasattr(result, "file_path")
assert hasattr(result, "repo_name")
assert hasattr(result, "symbol_name")
assert hasattr(result, "code_text")
assert hasattr(result, "score")
assert hasattr(result, "metadata")
def test_persistence(self, temp_db_path, sample_chunks):
"""Test that data persists across store instances"""
# Create store and add chunks
store1 = EmbeddingStore(db_path=temp_db_path, collection_name="persist_test")
store1.add_chunks(sample_chunks)
# Close and recreate store
del store1
store2 = EmbeddingStore(db_path=temp_db_path, collection_name="persist_test")
# Data should still be there
results = store2.collection.get()
assert len(results["ids"]) == 3
# Cleanup
store2.reset()
def test_model_selection_env_var(self, temp_db_path):
"""Test that SCOUT_MODEL environment variable is respected"""
# Set environment variable
os.environ["SCOUT_MODEL"] = "sentence-transformers/all-MiniLM-L6-v2"
try:
store = EmbeddingStore(db_path=temp_db_path, collection_name="model_test")
assert store.model_name == "sentence-transformers/all-MiniLM-L6-v2"
# Cleanup
store.reset()
finally:
# Clean up environment variable
if "SCOUT_MODEL" in os.environ:
del os.environ["SCOUT_MODEL"]
def test_model_selection_parameter_override(self, temp_db_path):
"""Test that explicit model_name parameter overrides environment variable"""
os.environ["SCOUT_MODEL"] = "sentence-transformers/all-MiniLM-L6-v2"
try:
store = EmbeddingStore(
db_path=temp_db_path,
collection_name="override_test",
model_name="sentence-transformers/all-mpnet-base-v2",
)
# Parameter should override env var
assert store.model_name == "sentence-transformers/all-mpnet-base-v2"
# Cleanup
store.reset()
finally:
if "SCOUT_MODEL" in os.environ:
del os.environ["SCOUT_MODEL"]
def test_default_model_selection(self, temp_db_path):
"""Test that default model is used when no override specified"""
# Clear env var if set
if "SCOUT_MODEL" in os.environ:
del os.environ["SCOUT_MODEL"]
store = EmbeddingStore(db_path=temp_db_path, collection_name="default_test")
# Should use the new default
assert (
store.model_name == "sentence-transformers/multi-qa-mpnet-base-dot-v1"
)
# Cleanup
store.reset()
if __name__ == "__main__":
pytest.main([__file__, "-v"])