"""Tests for LLM response caching."""
import hashlib
import re
import time
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from local_deepwiki.config import LLMCacheConfig
from local_deepwiki.core.llm_cache import LLMCache
from local_deepwiki.providers.base import LLMProvider
from local_deepwiki.providers.llm.cached import CachingLLMProvider
class MockEmbeddingProvider:
"""Mock embedding provider for testing."""
def __init__(self, dimension: int = 384):
self._dimension = dimension
async def embed(self, texts: list[str]) -> list[list[float]]:
"""Return deterministic embeddings based on text hash."""
embeddings = []
for text in texts:
# Create deterministic embedding from text hash
hash_bytes = hashlib.md5(text.encode()).digest()
# Expand to dimension size
values = []
for i in range(self._dimension):
byte_idx = i % len(hash_bytes)
values.append((hash_bytes[byte_idx] - 128) / 128.0)
embeddings.append(values)
return embeddings
def get_dimension(self) -> int:
return self._dimension
@property
def name(self) -> str:
return "mock-embedding"
class MockLLMProvider(LLMProvider):
"""Mock LLM provider for testing."""
def __init__(self):
self.calls: list[dict] = []
async def generate(
self,
prompt: str,
system_prompt: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
) -> str:
self.calls.append(
{
"prompt": prompt,
"system_prompt": system_prompt,
"max_tokens": max_tokens,
"temperature": temperature,
}
)
return f"Response to: {prompt[:50]}"
async def generate_stream(
self,
prompt: str,
system_prompt: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
):
self.calls.append(
{
"prompt": prompt,
"system_prompt": system_prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": True,
}
)
response = f"Streamed response to: {prompt[:50]}"
for chunk in response.split():
yield chunk + " "
@property
def name(self) -> str:
return "mock-llm"
class TestLLMCache:
"""Tests for the LLMCache class."""
@pytest.fixture
def cache_path(self, tmp_path: Path) -> Path:
"""Create a temporary cache path."""
return tmp_path / "test_cache.lance"
@pytest.fixture
def embedding_provider(self) -> MockEmbeddingProvider:
"""Create a mock embedding provider."""
return MockEmbeddingProvider()
@pytest.fixture
def config(self) -> LLMCacheConfig:
"""Create a cache config with default settings."""
return LLMCacheConfig(
enabled=True,
ttl_seconds=3600,
max_entries=1000,
similarity_threshold=0.95,
max_cacheable_temperature=0.3,
)
@pytest.fixture
def cache(
self, cache_path: Path, embedding_provider: MockEmbeddingProvider, config: LLMCacheConfig
) -> LLMCache:
"""Create an LLMCache instance."""
return LLMCache(cache_path, embedding_provider, config)
@pytest.mark.asyncio
async def test_cache_miss_on_empty_cache(self, cache: LLMCache):
"""Test that empty cache returns None."""
result = await cache.get(
prompt="test prompt",
system_prompt="test system",
temperature=0.1,
model_name="test-model",
)
assert result is None
assert cache.stats["misses"] == 1
@pytest.mark.asyncio
async def test_cache_set_and_get_exact_match(self, cache: LLMCache):
"""Test that exact same prompt returns cached response."""
prompt = "What is the meaning of life?"
system_prompt = "You are a philosopher"
response = "42"
# Set cache entry
await cache.set(
prompt=prompt,
response=response,
system_prompt=system_prompt,
temperature=0.1,
model_name="test-model",
)
# Get cache entry
result = await cache.get(
prompt=prompt,
system_prompt=system_prompt,
temperature=0.1,
model_name="test-model",
)
assert result == response
assert cache.stats["hits"] == 1
@pytest.mark.asyncio
async def test_high_temperature_not_cached(self, cache: LLMCache):
"""Test that high temperature responses are not cached."""
prompt = "Be creative"
response = "Random creative output"
# Try to cache with high temperature
await cache.set(
prompt=prompt,
response=response,
temperature=0.8, # Above max_cacheable_temperature
model_name="test-model",
)
# Should not be in cache
result = await cache.get(
prompt=prompt,
temperature=0.8,
model_name="test-model",
)
assert result is None
assert cache.stats["skipped"] == 1
@pytest.mark.asyncio
async def test_high_temperature_get_skipped(self, cache: LLMCache):
"""Test that cache lookup is skipped for high temperature requests."""
# First, cache something at low temperature
await cache.set(
prompt="test",
response="cached",
temperature=0.1,
model_name="test-model",
)
# Try to get with high temperature - should skip cache
result = await cache.get(
prompt="test",
temperature=0.8, # Above threshold
model_name="test-model",
)
assert result is None
assert cache.stats["skipped"] == 1
@pytest.mark.asyncio
async def test_different_system_prompts_different_cache_entries(self, cache: LLMCache):
"""Test that different system prompts result in different cache entries."""
prompt = "Hello"
await cache.set(
prompt=prompt,
response="Response A",
system_prompt="You are A",
temperature=0.1,
model_name="test-model",
)
await cache.set(
prompt=prompt,
response="Response B",
system_prompt="You are B",
temperature=0.1,
model_name="test-model",
)
result_a = await cache.get(
prompt=prompt,
system_prompt="You are A",
temperature=0.1,
model_name="test-model",
)
result_b = await cache.get(
prompt=prompt,
system_prompt="You are B",
temperature=0.1,
model_name="test-model",
)
assert result_a == "Response A"
assert result_b == "Response B"
@pytest.mark.asyncio
async def test_clear_cache(self, cache: LLMCache):
"""Test clearing the cache."""
# Add some entries
await cache.set(prompt="p1", response="r1", temperature=0.1, model_name="m")
await cache.set(prompt="p2", response="r2", temperature=0.1, model_name="m")
assert cache.get_entry_count() == 2
# Clear cache
cleared = await cache.clear()
assert cleared == 2
assert cache.get_entry_count() == 0
@pytest.mark.asyncio
async def test_cache_stats(self, cache: LLMCache):
"""Test that cache statistics are tracked correctly."""
# Miss on empty cache
await cache.get(prompt="p1", temperature=0.1, model_name="m")
# Set and hit
await cache.set(prompt="p2", response="r2", temperature=0.1, model_name="m")
await cache.get(prompt="p2", temperature=0.1, model_name="m")
# Skip due to high temperature
await cache.get(prompt="p3", temperature=0.9, model_name="m")
stats = cache.stats
assert stats["misses"] == 1
assert stats["hits"] == 1
assert stats["skipped"] == 1
class TestCachingLLMProvider:
"""Tests for the CachingLLMProvider class."""
@pytest.fixture
def cache_path(self, tmp_path: Path) -> Path:
"""Create a temporary cache path."""
return tmp_path / "test_cache.lance"
@pytest.fixture
def mock_llm(self) -> MockLLMProvider:
"""Create a mock LLM provider."""
return MockLLMProvider()
@pytest.fixture
def cache(self, cache_path: Path) -> LLMCache:
"""Create an LLMCache instance."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=3600,
max_entries=1000,
similarity_threshold=0.95,
max_cacheable_temperature=0.3,
)
embedding_provider = MockEmbeddingProvider()
return LLMCache(cache_path, embedding_provider, config)
@pytest.fixture
def cached_provider(self, mock_llm: MockLLMProvider, cache: LLMCache) -> CachingLLMProvider:
"""Create a CachingLLMProvider instance."""
return CachingLLMProvider(mock_llm, cache)
def test_name_includes_cache_prefix(self, cached_provider: CachingLLMProvider):
"""Test that provider name includes cache prefix."""
assert cached_provider.name == "cached:mock-llm"
@pytest.mark.asyncio
async def test_first_call_goes_to_provider(
self, cached_provider: CachingLLMProvider, mock_llm: MockLLMProvider
):
"""Test that first call goes to underlying provider."""
result = await cached_provider.generate(
prompt="test prompt",
system_prompt="test system",
temperature=0.1,
)
assert len(mock_llm.calls) == 1
assert "Response to: test prompt" in result
@pytest.mark.asyncio
async def test_second_call_uses_cache(
self, cached_provider: CachingLLMProvider, mock_llm: MockLLMProvider
):
"""Test that second identical call uses cache."""
prompt = "What is 2+2?"
system_prompt = "You are a calculator"
# First call
result1 = await cached_provider.generate(
prompt=prompt,
system_prompt=system_prompt,
temperature=0.1,
)
# Second call (should hit cache)
result2 = await cached_provider.generate(
prompt=prompt,
system_prompt=system_prompt,
temperature=0.1,
)
# Should only have one call to underlying provider
assert len(mock_llm.calls) == 1
assert result1 == result2
@pytest.mark.asyncio
async def test_high_temperature_bypasses_cache(
self, cached_provider: CachingLLMProvider, mock_llm: MockLLMProvider
):
"""Test that high temperature calls don't use cache."""
prompt = "Be creative"
# First call with high temp
await cached_provider.generate(prompt=prompt, temperature=0.8)
# Second call with high temp (should not hit cache)
await cached_provider.generate(prompt=prompt, temperature=0.8)
# Both calls should go to provider
assert len(mock_llm.calls) == 2
@pytest.mark.asyncio
    async def test_stats_accessible(self, cached_provider: CachingLLMProvider):
"""Test that cache stats are accessible through provider."""
await cached_provider.generate(prompt="p1", temperature=0.1)
await cached_provider.generate(prompt="p1", temperature=0.1)
stats = cached_provider.stats
assert stats["hits"] == 1
assert stats["misses"] == 1
@pytest.mark.asyncio
async def test_stream_first_call_caches(
self, cached_provider: CachingLLMProvider, mock_llm: MockLLMProvider
):
"""Test that streaming call caches the complete response."""
prompt = "Tell me a story"
# First streaming call
chunks1 = []
async for chunk in cached_provider.generate_stream(
prompt=prompt,
temperature=0.1,
):
chunks1.append(chunk)
# Should have called provider
assert len(mock_llm.calls) == 1
# Second streaming call (should hit cache)
chunks2 = []
async for chunk in cached_provider.generate_stream(
prompt=prompt,
temperature=0.1,
):
chunks2.append(chunk)
# Should not have made another call
assert len(mock_llm.calls) == 1
# Both should produce same content
assert "".join(chunks1) == "".join(chunks2)
@pytest.mark.asyncio
async def test_different_prompts_different_cache_entries(
self, cached_provider: CachingLLMProvider, mock_llm: MockLLMProvider
):
"""Test that different prompts get different cache entries."""
await cached_provider.generate(prompt="Question 1", temperature=0.1)
await cached_provider.generate(prompt="Question 2", temperature=0.1)
await cached_provider.generate(prompt="Question 1", temperature=0.1) # Cache hit
await cached_provider.generate(prompt="Question 2", temperature=0.1) # Cache hit
# Only 2 calls to provider (first two)
assert len(mock_llm.calls) == 2
# 2 cache hits
assert cached_provider.stats["hits"] == 2
class TestLLMCacheConfig:
"""Tests for LLMCacheConfig validation."""
def test_default_values(self):
"""Test default configuration values."""
config = LLMCacheConfig()
assert config.enabled is True
assert config.ttl_seconds == 604800 # 7 days
assert config.max_entries == 10000
assert config.similarity_threshold == 0.95
assert config.max_cacheable_temperature == 0.3
def test_ttl_validation_min(self):
"""Test that TTL has minimum bound."""
with pytest.raises(ValueError):
LLMCacheConfig(ttl_seconds=30) # Below 60
def test_ttl_validation_max(self):
"""Test that TTL has maximum bound."""
with pytest.raises(ValueError):
LLMCacheConfig(ttl_seconds=3000000) # Above 30 days
def test_similarity_threshold_bounds(self):
"""Test similarity threshold bounds."""
# Valid values
LLMCacheConfig(similarity_threshold=0.0)
LLMCacheConfig(similarity_threshold=1.0)
# Invalid values
with pytest.raises(ValueError):
LLMCacheConfig(similarity_threshold=-0.1)
with pytest.raises(ValueError):
LLMCacheConfig(similarity_threshold=1.1)
def test_max_entries_bounds(self):
"""Test max_entries bounds."""
with pytest.raises(ValueError):
LLMCacheConfig(max_entries=50) # Below 100
with pytest.raises(ValueError):
LLMCacheConfig(max_entries=200000) # Above 100000
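    def test_boundary_values_accepted(self):
        """Added sketch: documented minimums should be accepted (assumes the "Minimum allowed" comments elsewhere in this module are accurate)."""
        LLMCacheConfig(ttl_seconds=60)
        LLMCacheConfig(max_entries=100)
        LLMCacheConfig(max_cacheable_temperature=0.3)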
class TestLLMCacheEdgeCases:
"""Tests for LLMCache edge cases and exception handling."""
@pytest.fixture
def cache_path(self, tmp_path: Path) -> Path:
"""Create a temporary cache path."""
return tmp_path / "test_cache.lance"
@pytest.fixture
def embedding_provider(self) -> MockEmbeddingProvider:
"""Create a mock embedding provider."""
return MockEmbeddingProvider()
@pytest.fixture
def config(self) -> LLMCacheConfig:
"""Create a cache config with default settings."""
return LLMCacheConfig(
enabled=True,
ttl_seconds=3600,
max_entries=1000,
similarity_threshold=0.95,
max_cacheable_temperature=0.3,
)
@pytest.fixture
def cache(
self, cache_path: Path, embedding_provider: MockEmbeddingProvider, config: LLMCacheConfig
) -> LLMCache:
"""Create an LLMCache instance."""
return LLMCache(cache_path, embedding_provider, config)
@pytest.mark.asyncio
async def test_ensure_table_returns_existing_table(self, cache: LLMCache):
"""Test _ensure_table returns table when it already exists."""
# First, create an entry to create the table
await cache.set(prompt="test", response="response", temperature=0.1, model_name="m")
# Reset _table reference to force lookup
cache._table = None
# Call _ensure_table - should find existing table
table = cache._ensure_table(embedding_dim=384)
assert table is not None
@pytest.mark.asyncio
async def test_ensure_table_returns_none_when_no_table(self, cache: LLMCache):
"""Test _ensure_table returns None when table doesn't exist."""
# Never created any entries, so table doesn't exist
table = cache._ensure_table(embedding_dim=384)
assert table is None
@pytest.mark.asyncio
async def test_ensure_table_uses_cached_table(self, cache: LLMCache):
"""Test _ensure_table returns cached _table if already set."""
# Create entry to create table
await cache.set(prompt="test", response="response", temperature=0.1, model_name="m")
# _table should be set now
cached_table = cache._table
# Call _ensure_table again - should return same table
table = cache._ensure_table(embedding_dim=384)
assert table is cached_table
@pytest.mark.asyncio
async def test_get_table_when_table_exists(self, cache: LLMCache):
"""Test _get_table when table exists in database."""
# Create entry to create table
await cache.set(prompt="test", response="response", temperature=0.1, model_name="m")
# Reset _table to force lookup
cache._table = None
# _get_table should find it
table = cache._get_table()
assert table is not None
@pytest.mark.asyncio
async def test_get_exact_hash_exception_handling(self, cache: LLMCache):
"""Test that exact hash lookup exceptions are handled gracefully."""
# Create an entry first
await cache.set(prompt="test", response="response", temperature=0.1, model_name="m")
# Mock the table's search to raise an exception
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_search = MagicMock()
mock_search.where.return_value.limit.return_value.to_list.side_effect = RuntimeError(
"Database error"
)
mock_table.search.return_value = mock_search
mock_get_table.return_value = mock_table
# Should handle exception and continue to similarity search
result = await cache.get(prompt="test", temperature=0.1, model_name="m")
# Will miss because similarity search also uses the mocked table
assert result is None
@pytest.mark.asyncio
async def test_get_similarity_search_exception_handling(self, cache: LLMCache):
"""Test that similarity search exceptions are handled gracefully."""
# Create an entry first
await cache.set(prompt="test", response="response", temperature=0.1, model_name="m")
# Mock to simulate exact match miss then similarity search failure
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_search = MagicMock()
# Exact match returns empty
mock_search.where.return_value.limit.return_value.to_list.return_value = []
# Similarity search raises
mock_search.limit.return_value.to_list.side_effect = ValueError("Embedding error")
mock_table.search.return_value = mock_search
mock_get_table.return_value = mock_table
result = await cache.get(prompt="test", temperature=0.1, model_name="m")
assert result is None
assert cache.stats["misses"] >= 1
@pytest.mark.asyncio
    async def test_exact_hash_ignores_model_name(
        self, cache_path: Path, config: LLMCacheConfig
    ):
        """Test that the exact-hash path does not discriminate by model name."""
        embedding_provider = MockEmbeddingProvider()
        cache = LLMCache(cache_path, embedding_provider, config)
        # Create entry with model A
        await cache.set(
            prompt="What is Python?",
            response="Python is a programming language",
            temperature=0.1,
            model_name="model-a",
        )
        # Get with a different model: the exact-hash key does not include the
        # model name, so this still hits; only the similarity path checks the
        # model (covered by test_similarity_search_model_mismatch below)
        result = await cache.get(
            prompt="What is Python?",  # Same prompt
            temperature=0.1,
            model_name="model-b",  # Different model
        )
        assert result == "Python is a programming language"
@pytest.mark.asyncio
    async def test_exact_match_same_model_returns_response(
        self, cache_path: Path, config: LLMCacheConfig
    ):
        """Test that a repeated prompt with the same model returns the cached response."""
        embedding_provider = MockEmbeddingProvider()
        cache = LLMCache(cache_path, embedding_provider, config)
        # Create entry
        await cache.set(
            prompt="What is Python programming language?",
            response="Python is a high-level language",
            temperature=0.1,
            model_name="model-a",
        )
        # Query with the same prompt and model; this hits via the exact-hash
        # path (the similarity path is covered by the mocked tests below)
        result = await cache.get(
            prompt="What is Python programming language?",
            temperature=0.1,
            model_name="model-a",
        )
        # Should hit cache via exact match
        assert result == "Python is a high-level language"
@pytest.mark.asyncio
async def test_set_exception_handling(self, cache: LLMCache):
"""Test that set exceptions are handled gracefully."""
# Mock embedding provider to raise
with patch.object(cache.embedding_provider, "embed", side_effect=ValueError("Bad input")):
# Should not raise, just log warning
await cache.set(prompt="test", response="response", temperature=0.1, model_name="m")
# Cache should still be functional
assert cache.get_entry_count() == 0
@pytest.mark.asyncio
    async def test_set_creates_table_and_entry(self, cache: LLMCache):
        """Test that the first set creates the table and stores the entry."""
        # First entry creates the table; index-creation failures are
        # exercised separately in test_set_index_creation_failure
        await cache.set(prompt="test", response="response", temperature=0.1, model_name="m")
        # Verify entry was created
        assert cache.get_entry_count() == 1
@pytest.mark.asyncio
async def test_record_hit_when_table_none(self, cache: LLMCache):
"""Test _record_hit does nothing when table is None."""
# Never created table
cache._table = None
# Should not raise
await cache._record_hit("some-id")
@pytest.mark.asyncio
async def test_record_hit_exception_handling(self, cache: LLMCache):
"""Test _record_hit handles exceptions gracefully."""
# Create an entry
await cache.set(prompt="test", response="response", temperature=0.1, model_name="m")
        # Recording a hit for an ID that is not in the table must not raise
await cache._record_hit("non-existent-id")
@pytest.mark.asyncio
async def test_maybe_evict_when_table_none(self, cache: LLMCache):
"""Test _maybe_evict does nothing when table is None."""
cache._table = None
# Should not raise
await cache._maybe_evict()
@pytest.mark.asyncio
async def test_maybe_evict_under_max_entries(self, cache: LLMCache):
"""Test _maybe_evict does nothing when under max_entries."""
# Create a few entries (well under max of 1000)
await cache.set(prompt="p1", response="r1", temperature=0.1, model_name="m")
await cache.set(prompt="p2", response="r2", temperature=0.1, model_name="m")
count_before = cache.get_entry_count()
# Manually call _maybe_evict
await cache._maybe_evict()
# Count should be same (no eviction needed)
assert cache.get_entry_count() == count_before
@pytest.mark.asyncio
    async def test_maybe_evict_runs_on_real_table(self, cache_path: Path):
        """Test _maybe_evict runs cleanly against a real (non-mocked) table."""
        # Expiring entries in place is awkward with LanceDB, so expiry-driven
        # deletion is exercised with mocks in test_eviction_deletes_expired_entries;
        # here we only verify the eviction path completes without error
        config = LLMCacheConfig(
            enabled=True,
            ttl_seconds=60,  # Minimum allowed
            max_entries=100,  # Minimum allowed
            similarity_threshold=0.95,
            max_cacheable_temperature=0.3,
        )
        embedding_provider = MockEmbeddingProvider()
        cache = LLMCache(cache_path, embedding_provider, config)
        # Create a fresh entry; it is valid and under max_entries, so it survives
        await cache.set(prompt="p1", response="r1", temperature=0.1, model_name="m")
        await cache._maybe_evict()
        assert cache.get_entry_count() == 1
@pytest.mark.asyncio
async def test_maybe_evict_exception_handling(self, cache: LLMCache):
"""Test _maybe_evict handles exceptions gracefully."""
# Create entries
await cache.set(prompt="p1", response="r1", temperature=0.1, model_name="m")
# Mock table to raise during eviction
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_table.count_rows.side_effect = RuntimeError("DB error")
mock_get_table.return_value = mock_table
# Should not raise
await cache._maybe_evict()
@pytest.mark.asyncio
async def test_clear_exception_handling(self, cache: LLMCache):
"""Test clear handles exceptions gracefully."""
# Create entries
await cache.set(prompt="p1", response="r1", temperature=0.1, model_name="m")
# Mock to raise during clear
with patch.object(cache, "_connect") as mock_connect:
mock_db = MagicMock()
mock_db.list_tables.return_value.tables = [cache.TABLE_NAME]
mock_db.open_table.side_effect = OSError("File error")
mock_connect.return_value = mock_db
result = await cache.clear()
# Should return 0 on error
assert result == 0
def test_get_entry_count_exception_handling(self, cache: LLMCache):
"""Test get_entry_count handles exceptions gracefully."""
# Mock _get_table to return a table that raises
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_table.count_rows.side_effect = RuntimeError("DB error")
mock_get_table.return_value = mock_table
result = cache.get_entry_count()
# Should return 0 on error
assert result == 0
def test_get_entry_count_when_table_none(self, cache: LLMCache):
"""Test get_entry_count returns 0 when table is None."""
# Never created table
result = cache.get_entry_count()
assert result == 0
@pytest.mark.asyncio
async def test_clear_when_table_not_exists(self, cache: LLMCache):
"""Test clear returns 0 when table doesn't exist."""
# Never created any entries
result = await cache.clear()
assert result == 0
@pytest.mark.asyncio
async def test_is_valid_entry_expired(self, cache: LLMCache):
"""Test _is_valid_entry returns False for expired entries."""
# Create an entry dict that is expired
expired_entry = {
"created_at": time.time() - 10000, # 10000 seconds ago
"ttl_seconds": 60, # Only valid for 60 seconds
}
result = cache._is_valid_entry(expired_entry)
assert result is False
@pytest.mark.asyncio
async def test_is_valid_entry_valid(self, cache: LLMCache):
"""Test _is_valid_entry returns True for valid entries."""
valid_entry = {
"created_at": time.time(), # Just created
"ttl_seconds": 3600, # Valid for 1 hour
}
result = cache._is_valid_entry(valid_entry)
assert result is True
@pytest.mark.asyncio
async def test_is_valid_entry_uses_config_ttl_as_default(self, cache: LLMCache):
"""Test _is_valid_entry uses config ttl when entry has no ttl_seconds."""
entry = {
"created_at": time.time(),
# No ttl_seconds field
}
result = cache._is_valid_entry(entry)
# Should be valid since just created and config ttl is 3600
assert result is True
@pytest.mark.asyncio
async def test_compute_hash_with_none_system_prompt(self, cache: LLMCache):
"""Test _compute_hash handles None system_prompt."""
hash1 = cache._compute_hash(None, "test prompt")
hash2 = cache._compute_hash("", "test prompt")
# Both should produce consistent results
assert len(hash1) == 64 # SHA256 hex length
assert len(hash2) == 64
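    @pytest.mark.asyncio
    async def test_compute_hash_is_deterministic_and_prompt_sensitive(self, cache: LLMCache):
        """Added sketch: _compute_hash should be a pure function of its inputs."""
        # Relies only on the (system_prompt, prompt) call shape used above
        hash_a = cache._compute_hash(None, "prompt A")
        hash_b = cache._compute_hash(None, "prompt B")
        # Same inputs give the same hash; different prompts give different hashes
        assert hash_a == cache._compute_hash(None, "prompt A")
        assert hash_a != hash_b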
@pytest.mark.asyncio
async def test_similarity_search_checks_validity(
self, cache_path: Path, embedding_provider: MockEmbeddingProvider
):
"""Test similarity search only returns valid (non-expired) entries."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=60,
max_entries=1000,
similarity_threshold=0.0, # Accept any similarity for test
max_cacheable_temperature=0.5,
)
cache = LLMCache(cache_path, embedding_provider, config)
# Create entry
await cache.set(
prompt="test query",
response="test response",
temperature=0.1,
model_name="test-model",
)
# Mock the search result to appear expired
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_search = MagicMock()
# Exact match returns empty
mock_search.where.return_value.limit.return_value.to_list.return_value = []
# Similarity returns expired entry
mock_search.limit.return_value.to_list.return_value = [
{
"id": "test-id",
"_distance": 0.01, # Very close
"model_name": "test-model",
"response": "cached response",
"created_at": time.time() - 10000, # Expired
"ttl_seconds": 60,
}
]
mock_table.search.return_value = mock_search
mock_get_table.return_value = mock_table
result = await cache.get(
prompt="test query",
temperature=0.1,
model_name="test-model",
)
# Should not return expired entry
assert result is None
@pytest.mark.asyncio
async def test_similarity_search_model_mismatch(
self, cache_path: Path, embedding_provider: MockEmbeddingProvider
):
"""Test similarity search rejects entries with different model."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=3600,
max_entries=1000,
similarity_threshold=0.0, # Accept any similarity for test
max_cacheable_temperature=0.5,
)
cache = LLMCache(cache_path, embedding_provider, config)
# Create entry with one model
await cache.set(
prompt="test query",
response="test response",
temperature=0.1,
model_name="model-a",
)
# Mock the search result with different model
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_search = MagicMock()
# Exact match returns empty
mock_search.where.return_value.limit.return_value.to_list.return_value = []
# Similarity returns entry with different model
mock_search.limit.return_value.to_list.return_value = [
{
"id": "test-id",
"_distance": 0.01, # Very close
"model_name": "model-b", # Different model
"response": "cached response",
"created_at": time.time(),
"ttl_seconds": 3600,
}
]
mock_table.search.return_value = mock_search
mock_get_table.return_value = mock_table
result = await cache.get(
prompt="test query",
temperature=0.1,
model_name="model-a", # Looking for model-a
)
# Should not return entry with different model
assert result is None
@pytest.mark.asyncio
async def test_eviction_deletes_expired_entries(self, cache_path: Path):
"""Test that eviction actually deletes expired entries."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=60,
max_entries=100, # Low to trigger eviction check
similarity_threshold=0.95,
max_cacheable_temperature=0.5,
)
embedding_provider = MockEmbeddingProvider()
cache = LLMCache(cache_path, embedding_provider, config)
# Mock the table with many entries including expired ones
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_table.count_rows.return_value = 150 # Over limit
# Return mix of expired and valid entries
mock_table.search.return_value.limit.return_value.to_list.return_value = [
{"id": "expired-1", "created_at": time.time() - 10000, "ttl_seconds": 60},
{"id": "valid-1", "created_at": time.time(), "ttl_seconds": 3600},
{"id": "expired-2", "created_at": time.time() - 5000, "ttl_seconds": 60},
]
mock_get_table.return_value = mock_table
await cache._maybe_evict()
# Verify delete was called for expired entries
assert mock_table.delete.call_count >= 1
@pytest.mark.asyncio
async def test_eviction_delete_individual_failure(self, cache_path: Path):
"""Test that eviction continues even if individual deletes fail."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=60,
max_entries=100,
similarity_threshold=0.95,
max_cacheable_temperature=0.5,
)
embedding_provider = MockEmbeddingProvider()
cache = LLMCache(cache_path, embedding_provider, config)
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_table.count_rows.return_value = 150
mock_table.search.return_value.limit.return_value.to_list.return_value = [
{"id": "expired-1", "created_at": time.time() - 10000, "ttl_seconds": 60},
{"id": "expired-2", "created_at": time.time() - 5000, "ttl_seconds": 60},
]
# First delete succeeds, second fails
mock_table.delete.side_effect = [None, RuntimeError("Delete failed")]
mock_get_table.return_value = mock_table
# Should not raise even when some deletes fail
await cache._maybe_evict()
@pytest.mark.asyncio
async def test_similarity_search_hit_returns_response(
self, cache_path: Path, embedding_provider: MockEmbeddingProvider
):
"""Test similarity search successfully returns cached response."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=3600,
max_entries=1000,
similarity_threshold=0.5, # Low threshold for test
max_cacheable_temperature=0.5,
)
cache = LLMCache(cache_path, embedding_provider, config)
# Mock the get to simulate similarity search hit
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_search = MagicMock()
# Exact match returns empty (no exact hash match)
mock_search.where.return_value.limit.return_value.to_list.return_value = []
# Similarity search returns valid entry
mock_search.limit.return_value.to_list.return_value = [
{
"id": "test-id-12345678",
"_distance": 0.1, # Low distance = high similarity (1 - 0.1 = 0.9 > 0.5 threshold)
"model_name": "test-model",
"response": "This is the cached response",
"created_at": time.time(), # Not expired
"ttl_seconds": 3600,
}
]
mock_table.search.return_value = mock_search
mock_get_table.return_value = mock_table
# Perform the get
result = await cache.get(
prompt="test query",
temperature=0.1,
model_name="test-model",
)
# Should return the cached response from similarity search
assert result == "This is the cached response"
assert cache.stats["hits"] == 1
@pytest.mark.asyncio
async def test_set_index_creation_failure(self, cache_path: Path):
"""Test that index creation failure is handled gracefully."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=3600,
max_entries=1000,
similarity_threshold=0.95,
max_cacheable_temperature=0.5,
)
embedding_provider = MockEmbeddingProvider()
cache = LLMCache(cache_path, embedding_provider, config)
# Mock to simulate index creation failure
with patch.object(cache, "_connect") as mock_connect:
mock_db = MagicMock()
# Table doesn't exist yet
mock_db.list_tables.return_value.tables = []
# Create table succeeds
mock_table = MagicMock()
mock_db.create_table.return_value = mock_table
# Index creation fails
mock_table.create_scalar_index.side_effect = ValueError("Index creation failed")
# For _maybe_evict
mock_table.count_rows.return_value = 1
mock_connect.return_value = mock_db
# Should not raise, index failure is caught
await cache.set(
prompt="test",
response="response",
temperature=0.1,
model_name="m",
)
# Verify create_scalar_index was called (even though it failed)
mock_table.create_scalar_index.assert_called_once_with("exact_hash")
@pytest.mark.asyncio
async def test_record_hit_updates_entry(
self, cache_path: Path, embedding_provider: MockEmbeddingProvider
):
"""Test that _record_hit properly updates hit_count and last_hit_at."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=3600,
max_entries=1000,
similarity_threshold=0.95,
max_cacheable_temperature=0.5,
)
cache = LLMCache(cache_path, embedding_provider, config)
# Create an entry
await cache.set(
prompt="test prompt",
response="test response",
temperature=0.1,
model_name="test-model",
)
# Get the entry to find its ID
table = cache._get_table()
assert table is not None
entries = table.search().limit(1).to_list()
assert len(entries) == 1
entry_id = entries[0]["id"]
original_hit_count = entries[0].get("hit_count", 0)
# Record a hit
await cache._record_hit(entry_id)
# Verify the entry was updated
updated_entries = table.search().where(f"id = '{entry_id}'").limit(1).to_list()
assert len(updated_entries) == 1
assert updated_entries[0]["hit_count"] == original_hit_count + 1
assert updated_entries[0]["last_hit_at"] > entries[0]["last_hit_at"]
@pytest.mark.asyncio
async def test_record_hit_nonexistent_entry(
self, cache_path: Path, embedding_provider: MockEmbeddingProvider
):
"""Test that _record_hit handles nonexistent entries gracefully."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=3600,
max_entries=1000,
similarity_threshold=0.95,
max_cacheable_temperature=0.5,
)
cache = LLMCache(cache_path, embedding_provider, config)
# Create the table with an entry
await cache.set(
prompt="test",
response="response",
temperature=0.1,
model_name="m",
)
# Try to record hit for nonexistent entry - should not raise
await cache._record_hit("nonexistent-id-12345")
@pytest.mark.asyncio
async def test_lru_eviction_when_over_limit(self, cache_path: Path):
"""Test that LRU eviction removes oldest entries when over max_entries."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=3600,
max_entries=100, # Low limit for testing
similarity_threshold=0.95,
max_cacheable_temperature=0.5,
)
embedding_provider = MockEmbeddingProvider()
cache = LLMCache(cache_path, embedding_provider, config)
# Mock the table with entries over limit (all valid, none expired)
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_table.count_rows.return_value = 150 # Over max_entries of 100
# Create entries with varying last_hit_at times
now = time.time()
entries = [
{
"id": f"entry-{i}",
"created_at": now - 1000,
"last_hit_at": now - (100 - i), # entry-0 is oldest, entry-99 is newest
"ttl_seconds": 3600,
}
for i in range(150)
]
mock_search = MagicMock()
mock_search.limit.return_value.to_list.return_value = entries
mock_table.search.return_value = mock_search
# Track deleted entries
deleted_ids = []
            def track_delete(filter_expr):
                # Extract the ID from a filter like "id = 'entry-5'"
                match = re.search(r"id = '([^']+)'", filter_expr)
                if match:
                    deleted_ids.append(match.group(1))
mock_table.delete.side_effect = track_delete
mock_get_table.return_value = mock_table
# Trigger eviction
await cache._maybe_evict()
# Should have deleted oldest entries (LRU)
# Target is 80% of max_entries = 80, so need to delete 150 - 80 = 70
assert len(deleted_ids) > 0
# The oldest entries (lowest last_hit_at) should be deleted first
# entry-0 through entry-69 should be deleted (they have oldest last_hit_at)
for deleted_id in deleted_ids:
entry_num = int(deleted_id.split("-")[1])
# Should be from the oldest entries
assert entry_num < 70, f"Entry {deleted_id} should not have been evicted (not old enough)"
@pytest.mark.asyncio
async def test_eviction_prefers_expired_over_lru(self, cache_path: Path):
"""Test that eviction removes expired entries before LRU entries."""
config = LLMCacheConfig(
enabled=True,
ttl_seconds=60,
max_entries=100,
similarity_threshold=0.95,
max_cacheable_temperature=0.5,
)
embedding_provider = MockEmbeddingProvider()
cache = LLMCache(cache_path, embedding_provider, config)
with patch.object(cache, "_get_table") as mock_get_table:
mock_table = MagicMock()
mock_table.count_rows.return_value = 120 # Over limit
now = time.time()
# Mix of expired and valid entries
entries = [
# 20 expired entries (created long ago)
*[
{
"id": f"expired-{i}",
"created_at": now - 10000, # Long ago
"last_hit_at": now - 1, # Recent hit (would be kept by LRU)
"ttl_seconds": 60, # But expired
}
for i in range(20)
],
# 100 valid entries
*[
{
"id": f"valid-{i}",
"created_at": now - 30, # Recent
"last_hit_at": now - i, # Varying recency
"ttl_seconds": 3600,
}
for i in range(100)
],
]
mock_search = MagicMock()
mock_search.limit.return_value.to_list.return_value = entries
mock_table.search.return_value = mock_search
deleted_ids = []
mock_table.delete.side_effect = lambda f: deleted_ids.append(f)
mock_get_table.return_value = mock_table
await cache._maybe_evict()
# All 20 expired entries should be deleted
expired_deleted = sum(1 for d in deleted_ids if "expired-" in d)
assert expired_deleted == 20, "All expired entries should be deleted first"