# cache_manager.py
"""
Multi-level cache manager
TDD Green Phase: Implement to pass tests
"""
import asyncio
import fnmatch
import hashlib
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set, Union

from .memory_cache import MemoryCache
from .redis_cache import RedisCache
from ..exceptions import CacheError
class CacheManager:
    """Multi-level cache manager with L1 (memory) and L2 (Redis) caches.

    Lookups try L1 first, then L2; an L2 hit is promoted into L1 so the
    next read is served from memory. Every backend call is wrapped in
    try/except and logged, so a failing cache level degrades to a miss
    instead of propagating errors to callers.
    """

    def __init__(
        self,
        l1_cache: Optional[MemoryCache] = None,
        l2_cache: Optional[RedisCache] = None,
        enable_l1: bool = True,
        enable_l2: bool = True,
        l1_to_l2_ratio: float = 0.8,  # When to promote from L2 to L1
        stats_interval: int = 300  # Stats logging interval in seconds
    ):
        """Initialize cache manager.

        Args:
            l1_cache: Memory cache instance (L1)
            l2_cache: Redis cache instance (L2)
            enable_l1: Enable L1 cache
            enable_l2: Enable L2 cache
            l1_to_l2_ratio: Hit ratio threshold to promote L2 items to L1
            stats_interval: Interval for periodic stats logging in seconds
                (a value <= 0 disables the background stats task)
        """
        self.l1_cache = l1_cache
        self.l2_cache = l2_cache
        # A level is usable only when it is both enabled AND configured.
        self.enable_l1 = enable_l1 and l1_cache is not None
        self.enable_l2 = enable_l2 and l2_cache is not None
        self.l1_to_l2_ratio = l1_to_l2_ratio
        self.stats_interval = stats_interval
        self.logger = logging.getLogger("mcp_stock_details.cache_manager")
        # Overall stats counters
        self._total_gets = 0
        self._l1_hits = 0
        self._l2_hits = 0
        self._misses = 0
        self._sets = 0
        self._deletes = 0
        self._invalidations = 0
        # Per-key access counts, consulted by _should_promote_to_l1
        self._key_access_count: Dict[str, int] = {}
        # Background stats task; started lazily on the first async
        # operation because __init__ may run outside an event loop.
        self._stats_task: Optional[asyncio.Task] = None

    async def get(self, key: str, use_l1: bool = True, use_l2: bool = True) -> Optional[Any]:
        """Get value from cache with multi-level lookup.

        Args:
            key: Cache key
            use_l1: Use L1 cache
            use_l2: Use L2 cache

        Returns:
            Cached value or None if not found (backend errors count as
            a miss for that level and are logged, never raised)
        """
        # Ensure background tasks are running
        await self._ensure_background_tasks()
        self._total_gets += 1

        # Try L1 cache first
        if self.enable_l1 and use_l1 and self.l1_cache:
            try:
                value = await self.l1_cache.get(key)
                if value is not None:
                    self._l1_hits += 1
                    self._track_key_access(key)
                    self.logger.debug("L1 cache hit for key: %s", key)
                    return value
            except Exception as e:
                self.logger.warning("L1 cache error for key %s: %s", key, e)

        # Try L2 cache
        if self.enable_l2 and use_l2 and self.l2_cache:
            try:
                value = await self.l2_cache.get(key)
                if value is not None:
                    self._l2_hits += 1
                    self._track_key_access(key)
                    # Always promote to L1 for better performance on the
                    # next access. NOTE(review): _should_promote_to_l1
                    # exists but is deliberately not consulted here;
                    # wiring it in would delay promotion until 2+ hits.
                    if self.enable_l1 and self.l1_cache:
                        try:
                            await self.l1_cache.set(key, value)
                            self.logger.debug("Promoted key %s from L2 to L1", key)
                        except Exception as e:
                            self.logger.warning("Failed to promote key %s to L1: %s", key, e)
                    self.logger.debug("L2 cache hit for key: %s", key)
                    return value
            except Exception as e:
                self.logger.warning("L2 cache error for key %s: %s", key, e)

        # Cache miss on every usable level
        self._misses += 1
        self.logger.debug("Cache miss for key: %s", key)
        return None

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None,
        l1_ttl: Optional[int] = None,
        l2_ttl: Optional[int] = None,
        l1_only: bool = False,
        l2_only: bool = False
    ) -> bool:
        """Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Default TTL for both levels
            l1_ttl: Specific TTL for L1 cache (overrides ttl)
            l2_ttl: Specific TTL for L2 cache (overrides ttl)
            l1_only: Store in L1 only
            l2_only: Store in L2 only

        Returns:
            True if set in at least one cache level

        Raises:
            CacheError: If the value could not be stored in any level
        """
        # Ensure background tasks are running
        await self._ensure_background_tasks()
        self._sets += 1
        success = False

        # Determine per-level TTLs. Compare against None explicitly so an
        # explicit TTL of 0 is honored instead of falling back to `ttl`.
        l1_ttl = ttl if l1_ttl is None else l1_ttl
        l2_ttl = ttl if l2_ttl is None else l2_ttl

        # Set in L1 cache
        if not l2_only and self.enable_l1 and self.l1_cache:
            try:
                if await self.l1_cache.set(key, value, l1_ttl):
                    success = True
                    self.logger.debug("Set key %s in L1 cache", key)
                else:
                    self.logger.warning("Failed to set key %s in L1 cache", key)
            except Exception as e:
                self.logger.error("Error setting key %s in L1: %s", key, e)

        # Set in L2 cache
        if not l1_only and self.enable_l2 and self.l2_cache:
            try:
                if await self.l2_cache.set(key, value, l2_ttl):
                    success = True
                    self.logger.debug("Set key %s in L2 cache", key)
                else:
                    self.logger.warning("Failed to set key %s in L2 cache", key)
            except Exception as e:
                self.logger.error("Error setting key %s in L2: %s", key, e)

        if not success:
            raise CacheError(f"Failed to set key {key} in any cache level")
        return success

    async def delete(self, key: str) -> bool:
        """Delete key from all cache levels.

        Args:
            key: Cache key to delete

        Returns:
            True if deleted from at least one level
        """
        self._deletes += 1
        deleted = False

        # Delete from L1
        if self.enable_l1 and self.l1_cache:
            try:
                if await self.l1_cache.delete(key):
                    deleted = True
                    self.logger.debug("Deleted key %s from L1 cache", key)
            except Exception as e:
                self.logger.warning("Error deleting key %s from L1: %s", key, e)

        # Delete from L2
        if self.enable_l2 and self.l2_cache:
            try:
                if await self.l2_cache.delete(key):
                    deleted = True
                    self.logger.debug("Deleted key %s from L2 cache", key)
            except Exception as e:
                self.logger.warning("Error deleting key %s from L2: %s", key, e)

        # Remove from access tracking regardless of backend outcome
        self._key_access_count.pop(key, None)
        return deleted

    async def clear(self, pattern: Optional[str] = None) -> Dict[str, int]:
        """Clear cache entries from all levels.

        Args:
            pattern: Optional glob-style pattern to match keys
                (same style the backends accept, e.g. "stock:*")

        Returns:
            Dictionary with counts of deleted keys per level
            (-1 for a full L1 clear)
        """
        result = {"l1": 0, "l2": 0}

        # Clear L1
        if self.enable_l1 and self.l1_cache:
            try:
                if pattern:
                    # Get matching keys and delete them individually
                    keys = await self.l1_cache.keys(pattern)
                    for key in keys:
                        await self.l1_cache.delete(key)
                    result["l1"] = len(keys)
                else:
                    await self.l1_cache.clear()
                    result["l1"] = -1  # Indicate full clear
                self.logger.info("Cleared L1 cache: %s keys", result["l1"])
            except Exception as e:
                self.logger.error("Error clearing L1 cache: %s", e)

        # Clear L2 (backend handles pattern matching itself)
        if self.enable_l2 and self.l2_cache:
            try:
                result["l2"] = await self.l2_cache.clear(pattern)
                self.logger.info("Cleared L2 cache: %s keys", result["l2"])
            except Exception as e:
                self.logger.error("Error clearing L2 cache: %s", e)

        # Clear access tracking. Use glob matching (fnmatch) so the same
        # pattern semantics apply here as in the cache backends; a plain
        # substring test would never match patterns containing '*'.
        if pattern:
            keys_to_remove = [
                k for k in self._key_access_count
                if fnmatch.fnmatch(k, pattern)
            ]
            for key in keys_to_remove:
                del self._key_access_count[key]
        else:
            self._key_access_count.clear()
        return result

    async def exists(self, key: str) -> bool:
        """Check if key exists in any cache level.

        Args:
            key: Cache key

        Returns:
            True if the key exists in any level
        """
        # Check L1 first (cheapest)
        if self.enable_l1 and self.l1_cache:
            try:
                if await self.l1_cache.exists(key):
                    return True
            except Exception as e:
                self.logger.warning("Error checking L1 existence for key %s: %s", key, e)

        # Check L2
        if self.enable_l2 and self.l2_cache:
            try:
                if await self.l2_cache.exists(key):
                    return True
            except Exception as e:
                self.logger.warning("Error checking L2 existence for key %s: %s", key, e)
        return False

    async def keys(self, pattern: Optional[str] = None) -> Set[str]:
        """Get all cache keys from all levels.

        Args:
            pattern: Optional glob-style pattern to match

        Returns:
            Set of all matching keys across all levels (union of L1 and L2)
        """
        all_keys = set()

        # Get keys from L1
        if self.enable_l1 and self.l1_cache:
            try:
                l1_keys = await self.l1_cache.keys(pattern)
                all_keys.update(l1_keys)
            except Exception as e:
                self.logger.warning("Error getting L1 keys: %s", e)

        # Get keys from L2
        if self.enable_l2 and self.l2_cache:
            try:
                l2_keys = await self.l2_cache.keys(pattern)
                all_keys.update(l2_keys)
            except Exception as e:
                self.logger.warning("Error getting L2 keys: %s", e)
        return all_keys

    async def invalidate(self, tags: List[str]) -> int:
        """Invalidate cache entries by tags.

        Each tag is expanded to the glob pattern "*tag*" and every
        matching key is deleted from all levels.

        Args:
            tags: List of tags to invalidate

        Returns:
            Number of keys invalidated
        """
        # NOTE: counts one invalidation per call, not per tag or per key.
        self._invalidations += 1
        total_invalidated = 0
        for tag in tags:
            # Use tag as pattern to find matching keys
            pattern = f"*{tag}*"
            try:
                matching_keys = await self.keys(pattern)
                for key in matching_keys:
                    if await self.delete(key):
                        total_invalidated += 1
                self.logger.info(f"Invalidated {len(matching_keys)} keys for tag: {tag}")
            except Exception as e:
                self.logger.error("Error invalidating tag %s: %s", tag, e)
        return total_invalidated

    def generate_cache_key(self, prefix: str, *args, **kwargs) -> str:
        """Generate a standardized cache key.

        Keys are deterministic across processes and interpreter restarts,
        which is required for the shared Redis L2 level. (The built-in
        hash() must NOT be used here: it is randomized per process via
        PYTHONHASHSEED, so different workers would derive different keys
        for the same arguments.)

        Args:
            prefix: Key prefix
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Generated cache key of the form "<prefix>:<md5-hex>"
        """
        # Build a stable string representation of all arguments
        key_parts = [prefix]

        # Add positional arguments; non-primitives are reduced to a
        # stable digest of their string form.
        for arg in args:
            if isinstance(arg, (str, int, float, bool)):
                key_parts.append(str(arg))
            else:
                key_parts.append(hashlib.md5(str(arg).encode()).hexdigest()[:16])

        # Add keyword arguments (sorted for consistency)
        for k, v in sorted(kwargs.items()):
            if isinstance(v, (str, int, float, bool)):
                key_parts.append(f"{k}:{v}")
            else:
                key_parts.append(f"{k}:{hashlib.md5(str(v).encode()).hexdigest()[:16]}")

        # Join and hash for consistent key length. MD5 is used purely as
        # a fast non-cryptographic fingerprint here.
        key_string = "|".join(key_parts)
        key_hash = hashlib.md5(key_string.encode()).hexdigest()
        return f"{prefix}:{key_hash}"

    def stats(self) -> Dict[str, Any]:
        """Get comprehensive cache statistics.

        Returns:
            Dictionary with aggregate counters, hit rates (rounded to 3
            decimals), and per-level stats where available
        """
        hit_rate = (self._l1_hits + self._l2_hits) / self._total_gets if self._total_gets > 0 else 0
        l1_hit_rate = self._l1_hits / self._total_gets if self._total_gets > 0 else 0
        l2_hit_rate = self._l2_hits / self._total_gets if self._total_gets > 0 else 0
        stats = {
            "total_gets": self._total_gets,
            "l1_hits": self._l1_hits,
            "l2_hits": self._l2_hits,
            "misses": self._misses,
            "sets": self._sets,
            "deletes": self._deletes,
            "invalidations": self._invalidations,
            "hit_rate": round(hit_rate, 3),
            "l1_hit_rate": round(l1_hit_rate, 3),
            "l2_hit_rate": round(l2_hit_rate, 3),
            "l1_enabled": self.enable_l1,
            "l2_enabled": self.enable_l2
        }
        # Add individual cache stats
        if self.enable_l1 and self.l1_cache:
            stats["l1_cache"] = self.l1_cache.stats()
        if self.enable_l2 and self.l2_cache:
            stats["l2_cache"] = self.l2_cache.stats()
        return stats

    def _track_key_access(self, key: str) -> None:
        """Track key access for promotion decisions."""
        self._key_access_count[key] = self._key_access_count.get(key, 0) + 1

    def _should_promote_to_l1(self, key: str) -> bool:
        """Determine if a key should be promoted from L2 to L1.

        Simple promotion heuristic: promote after 2+ tracked accesses.
        NOTE(review): currently unused — get() promotes unconditionally.
        """
        access_count = self._key_access_count.get(key, 0)
        return access_count >= 2

    async def _ensure_background_tasks(self) -> None:
        """Ensure the periodic stats task is running (lazy start)."""
        if self.stats_interval > 0 and (self._stats_task is None or self._stats_task.done()):
            try:
                self._stats_task = asyncio.create_task(self._log_stats_periodically())
            except RuntimeError:
                # No event loop running; stats can be fetched manually.
                pass

    async def _log_stats_periodically(self) -> None:
        """Background task to log statistics periodically."""
        while True:
            try:
                await asyncio.sleep(self.stats_interval)
                stats = self.stats()
                self.logger.info(f"Cache stats: {stats}")
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error("Error logging cache stats: %s", e)
                # Back off before retrying so a persistent error does not
                # spin the loop.
                await asyncio.sleep(self.stats_interval)

    async def close(self) -> None:
        """Close cache manager and all cache instances."""
        # Cancel stats task and wait for it to finish
        if self._stats_task:
            self._stats_task.cancel()
            try:
                await self._stats_task
            except asyncio.CancelledError:
                pass

        # Close cache instances
        if self.l1_cache:
            await self.l1_cache.close()
        if self.l2_cache:
            await self.l2_cache.close()
        self.logger.info("Cache manager closed")

    # Context manager support
    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: closes all cache levels."""
        await self.close()