"""
Cache Service
Provides caching for slot resolution results to improve performance.
Implements request-level memoization and LRU caching.
"""
import logging
import time
from typing import Optional, Dict, Any
from functools import lru_cache
from slot_resolution.core.models import SlotResolutionResponse
logger = logging.getLogger(__name__)
class CacheService:
    """
    Service for caching slot resolution results.

    Entries live in a plain dict keyed by
    ``"tenant_id:entity_type:normalized_query"`` and carry the timestamp
    of their insertion. Expired entries are removed lazily on lookup.
    When the cache is full, the entry with the oldest timestamp is
    evicted; ``put()`` maintains the invariant that dict insertion order
    matches timestamp order, so eviction is O(1).

    NOTE(review): no locking is performed anywhere in this class, so it
    is not safe for concurrent mutation without external synchronization.
    """

    def __init__(self, ttl_seconds: int = 900, max_size: int = 1000):
        """
        Initialize the cache service.

        Args:
            ttl_seconds: Time-to-live for cache entries (default: 15 minutes)
            max_size: Maximum number of cache entries
        """
        self.ttl_seconds = ttl_seconds
        self.max_size = max_size
        # Maps cache key -> {"response": ..., "timestamp": ...}.
        # put() keeps insertion order aligned with timestamp order, so
        # the first key in the dict is always the oldest entry.
        self._cache: Dict[str, Dict[str, Any]] = {}
        # Hit/miss counters exposed via get_stats(); an expired entry
        # found on lookup counts as a miss.
        self._hits: int = 0
        self._misses: int = 0
        # Same logger instance as logging.getLogger(__name__) at module
        # scope; held on the instance so the class is self-contained.
        self._log = logging.getLogger(__name__)
        self._log.info(
            f"CacheService initialized with TTL={ttl_seconds}s, max_size={max_size}"
        )

    def get(
        self,
        tenant_id: str,
        entity_type: str,
        normalized_query: str
    ) -> Optional["SlotResolutionResponse"]:
        """
        Get cached resolution result.

        Args:
            tenant_id: Tenant identifier
            entity_type: Type of entity
            normalized_query: Normalized search query

        Returns:
            Cached SlotResolutionResponse if found and not expired, None otherwise
        """
        cache_key = self._make_cache_key(tenant_id, entity_type, normalized_query)
        entry = self._cache.get(cache_key)
        if entry is None:
            self._misses += 1
            return None
        # Expired entries are deleted lazily here rather than by a
        # background sweep.
        if time.time() - entry["timestamp"] > self.ttl_seconds:
            self._log.debug(f"Cache entry expired: {cache_key}")
            del self._cache[cache_key]
            self._misses += 1
            return None
        self._hits += 1
        self._log.debug(f"Cache hit: {cache_key}")
        return entry["response"]

    def put(
        self,
        tenant_id: str,
        entity_type: str,
        normalized_query: str,
        response: "SlotResolutionResponse"
    ) -> None:
        """
        Store resolution result in cache.

        Args:
            tenant_id: Tenant identifier
            entity_type: Type of entity
            normalized_query: Normalized search query
            response: Resolution response to cache
        """
        cache_key = self._make_cache_key(tenant_id, entity_type, normalized_query)
        # Remove any existing entry first. This (a) fixes a bug where
        # overwriting an existing key at capacity needlessly evicted an
        # unrelated entry, and (b) re-inserts the key at the end so that
        # dict insertion order stays aligned with timestamp order,
        # keeping _evict_oldest() O(1).
        self._cache.pop(cache_key, None)
        if len(self._cache) >= self.max_size:
            self._evict_oldest()
        self._cache[cache_key] = {
            "response": response,
            "timestamp": time.time()
        }
        self._log.debug(f"Cache put: {cache_key}")

    def invalidate(
        self,
        tenant_id: Optional[str] = None,
        entity_type: Optional[str] = None
    ) -> None:
        """
        Invalidate cache entries.

        Args:
            tenant_id: If specified, invalidate only for this tenant
            entity_type: If specified, invalidate only for this entity type
        """
        if tenant_id is None and entity_type is None:
            # No filters: drop everything.
            self._cache.clear()
            self._log.info("Cleared all cache")
            return
        # Selective invalidation: parse tenant/entity back out of the
        # key. Assumes tenant ids and entity types contain no ":" --
        # TODO confirm with key producers; the query itself may contain
        # ":" since only the first two segments are inspected.
        keys_to_remove = []
        for key in self._cache:
            parts = key.split(":")
            if len(parts) < 3:
                continue
            key_tenant, key_entity = parts[0], parts[1]
            if tenant_id and key_tenant != tenant_id:
                continue
            if entity_type and key_entity != entity_type:
                continue
            keys_to_remove.append(key)
        for key in keys_to_remove:
            del self._cache[key]
        self._log.info(
            f"Invalidated {len(keys_to_remove)} cache entries "
            f"(tenant={tenant_id}, entity_type={entity_type})"
        )

    def _make_cache_key(
        self,
        tenant_id: str,
        entity_type: str,
        normalized_query: str
    ) -> str:
        """Create the colon-delimited cache key from parameters."""
        return f"{tenant_id}:{entity_type}:{normalized_query}"

    def _evict_oldest(self) -> None:
        """Evict the cache entry with the oldest timestamp."""
        if not self._cache:
            return
        # put() guarantees insertion order == timestamp order, so the
        # first key is the oldest -- O(1) instead of a min() scan.
        oldest_key = next(iter(self._cache))
        del self._cache[oldest_key]
        self._log.debug(f"Evicted oldest cache entry: {oldest_key}")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache stats: current size, configured limits,
            and cumulative hit/miss counts ("hits"/"misses" are additive,
            backward-compatible keys).
        """
        return {
            "size": len(self._cache),
            "max_size": self.max_size,
            "ttl_seconds": self.ttl_seconds,
            "hits": self._hits,
            "misses": self._misses
        }