"""
Enterprise-grade caching system for MCP
Supports multiple backends with intelligent cache warming and invalidation
"""
import asyncio
import functools
import hashlib
import json
import logging
import pickle
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, Any, Optional, List, Union, Callable
# Optional dependencies for different cache backends
try:
import redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
try:
import memcache
MEMCACHE_AVAILABLE = True
except ImportError:
MEMCACHE_AVAILABLE = False
logger = logging.getLogger(__name__)
class CacheBackend(Enum):
    """Available cache backend types."""
    MEMORY = "memory"      # in-process dict-based cache (MemoryCacheBackend)
    REDIS = "redis"        # external Redis server (RedisCacheBackend)
    MEMCACHE = "memcache"  # no backend implementation in this module
    HYBRID = "hybrid"      # no backend implementation in this module
class CacheStrategy(Enum):
    """Cache eviction strategies.

    NOTE(review): declared but not referenced anywhere in this module —
    MemoryCacheBackend hard-codes LRU eviction. Confirm whether callers
    elsewhere consume this enum before removing it.
    """
    LRU = "lru"            # least recently used
    LFU = "lfu"            # least frequently used
    TTL = "ttl"            # time-to-live based expiry
    ADAPTIVE = "adaptive"  # adaptive policy (not implemented here)
@dataclass
class CacheEntry:
    """A cached value plus the metadata used for TTL expiry and LRU accounting."""
    key: str
    value: Any
    created_at: datetime              # naive local timestamp (datetime.now())
    last_accessed: datetime           # updated by the backend on every hit
    access_count: int                 # incremented by the backend on every hit
    ttl_seconds: Optional[int] = None  # None means the entry never expires
    # Fix: the original used the mutable-default workaround (tags = None plus a
    # __post_init__ fixup); default_factory is the idiomatic dataclass form.
    tags: List[str] = field(default_factory=list)
    size_bytes: int = 0               # pickled size, computed by the backend

    def __post_init__(self):
        # Kept for backward compatibility: callers may still pass tags=None
        # explicitly and expect it to be normalized to an empty list.
        if self.tags is None:
            self.tags = []

    @property
    def is_expired(self) -> bool:
        """True when the entry has outlived its TTL; never True if ttl_seconds is None."""
        if self.ttl_seconds is None:
            return False
        return (datetime.now() - self.created_at).total_seconds() > self.ttl_seconds

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since the entry was created."""
        return (datetime.now() - self.created_at).total_seconds()
class CacheBackendInterface(ABC):
    """Abstract interface for cache backends.

    All operations are async. Mutating operations return booleans to signal
    success/failure rather than raising (implementations in this module log
    and swallow backend errors).
    """
    @abstractmethod
    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache; returns None on a miss."""
        pass

    @abstractmethod
    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Set value in cache; returns True on success."""
        pass

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """Delete value from cache; returns True if something was removed."""
        pass

    @abstractmethod
    async def clear(self) -> bool:
        """Clear all cache entries; returns True on success."""
        pass

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Check if key exists in cache."""
        pass

    @abstractmethod
    async def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics as a flat dict."""
        pass
class MemoryCacheBackend(CacheBackendInterface):
    """In-memory cache backend with LRU eviction.

    Entries are stored in a plain dict; LRU order is tracked in a parallel
    list where index 0 is the least recently used key. Expiry is lazy: an
    expired entry is removed only when it is next read.

    NOTE(review): list-based LRU bookkeeping is O(n) per access; fine for the
    default max_size=1000, but consider an OrderedDict for larger caches.
    Not safe for concurrent mutation from multiple threads.
    """

    def __init__(self, max_size: int = 1000, default_ttl: int = 3600):
        """Args:
            max_size: Maximum number of entries before LRU eviction kicks in.
            default_ttl: TTL (seconds) applied when set() is called with ttl=None.
        """
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.cache: Dict[str, CacheEntry] = {}
        self.access_order: List[str] = []  # index 0 = least recently used
        self.stats = {
            'hits': 0,
            'misses': 0,
            'sets': 0,
            'deletes': 0,
            'evictions': 0
        }

    async def get(self, key: str) -> Optional[Any]:
        """Get value from memory cache; expired entries are lazily removed and count as misses."""
        if key not in self.cache:
            self.stats['misses'] += 1
            return None
        entry = self.cache[key]
        if entry.is_expired:
            await self.delete(key)
            self.stats['misses'] += 1
            return None
        # Update access metadata and move the key to the MRU position.
        entry.last_accessed = datetime.now()
        entry.access_count += 1
        if key in self.access_order:
            self.access_order.remove(key)
        self.access_order.append(key)
        self.stats['hits'] += 1
        return entry.value

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Set value in memory cache.

        Returns False (after logging) if the value cannot be pickled for
        size accounting.
        """
        try:
            # Size accounting only; the live object itself is stored.
            size_bytes = len(pickle.dumps(value))
            entry = CacheEntry(
                key=key,
                value=value,
                created_at=datetime.now(),
                last_accessed=datetime.now(),
                access_count=0,
                # Fix: `ttl or self.default_ttl` silently promoted ttl=0 to the
                # default; only fall back when ttl is actually None.
                ttl_seconds=self.default_ttl if ttl is None else ttl,
                size_bytes=size_bytes
            )
            # Fix: only evict when inserting a NEW key. Replacing an existing
            # key does not grow the cache, and the old code could evict the
            # very key being replaced when it happened to be the LRU entry.
            if key not in self.cache:
                await self._evict_if_needed()
            self.cache[key] = entry
            if key not in self.access_order:
                self.access_order.append(key)
            self.stats['sets'] += 1
            return True
        except Exception as e:
            logger.error(f"Failed to set cache entry {key}: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete value from memory cache; returns True if the key existed."""
        if key in self.cache:
            del self.cache[key]
            if key in self.access_order:
                self.access_order.remove(key)
            self.stats['deletes'] += 1
            return True
        return False

    async def clear(self) -> bool:
        """Clear all cache entries (cumulative hit/miss stats are kept)."""
        self.cache.clear()
        self.access_order.clear()
        return True

    async def exists(self, key: str) -> bool:
        """Check if key exists and is not expired, without touching LRU order or stats."""
        return key in self.cache and not self.cache[key].is_expired

    async def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics including hit rate and pickled-size memory footprint."""
        total_requests = self.stats['hits'] + self.stats['misses']
        hit_rate = (self.stats['hits'] / total_requests * 100) if total_requests > 0 else 0
        return {
            **self.stats,
            'hit_rate_percent': hit_rate,
            'total_entries': len(self.cache),
            'max_size': self.max_size,
            'memory_usage_bytes': sum(entry.size_bytes for entry in self.cache.values())
        }

    async def _evict_if_needed(self):
        """Evict LRU entries until there is room for one more entry.

        Fix: removes entries directly instead of routing through delete(), so
        evictions no longer double-count in the 'deletes' statistic.
        """
        while len(self.cache) >= self.max_size and self.access_order:
            lru_key = self.access_order.pop(0)
            self.cache.pop(lru_key, None)
            self.stats['evictions'] += 1
class RedisCacheBackend(CacheBackendInterface):
    """Redis cache backend.

    Values are serialized with pickle; keys are namespaced with `key_prefix`.
    All errors are logged and converted to None/False rather than raised.

    NOTE(review): this uses the *synchronous* redis client and never awaits
    its calls, so every operation blocks the event loop while talking to
    Redis — consider redis.asyncio. Also, pickle.loads on cache contents
    assumes the Redis instance is fully trusted.
    """
    def __init__(self, redis_url: str = "redis://localhost:6379", key_prefix: str = "mcp:"):
        """Args:
            redis_url: Connection URL passed to redis.from_url.
            key_prefix: Prefix prepended to every key for namespacing.

        Raises:
            ImportError: If the optional redis package is not installed.
        """
        if not REDIS_AVAILABLE:
            raise ImportError("Redis not available. Install with: pip install redis")
        self.redis_url = redis_url
        self.key_prefix = key_prefix
        self.redis_client: Optional[redis.Redis] = None

    async def _ensure_connection(self):
        """Lazily create the Redis client on first use (no awaiting happens here)."""
        if self.redis_client is None:
            # decode_responses=False: values are raw pickle bytes.
            self.redis_client = redis.from_url(self.redis_url, decode_responses=False)

    def _make_key(self, key: str) -> str:
        """Add the configured namespace prefix to a key."""
        return f"{self.key_prefix}{key}"

    async def get(self, key: str) -> Optional[Any]:
        """Get value from Redis cache; returns None on miss or on any error."""
        try:
            await self._ensure_connection()
            data = self.redis_client.get(self._make_key(key))
            if data is None:
                return None
            return pickle.loads(data)
        except Exception as e:
            logger.error(f"Redis get error for key {key}: {e}")
            return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Set value in Redis cache; uses SETEX when a TTL is given, plain SET otherwise."""
        try:
            await self._ensure_connection()
            data = pickle.dumps(value)
            redis_key = self._make_key(key)
            if ttl:
                return self.redis_client.setex(redis_key, ttl, data)
            else:
                return self.redis_client.set(redis_key, data)
        except Exception as e:
            logger.error(f"Redis set error for key {key}: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete value from Redis cache; True if at least one key was removed."""
        try:
            await self._ensure_connection()
            result = self.redis_client.delete(self._make_key(key))
            return result > 0
        except Exception as e:
            logger.error(f"Redis delete error for key {key}: {e}")
            return False

    async def clear(self) -> bool:
        """Clear all cache entries carrying this backend's prefix.

        NOTE(review): KEYS is a blocking full-keyspace scan on the Redis
        server; for large datasets SCAN + batched deletes is preferable.
        """
        try:
            await self._ensure_connection()
            keys = self.redis_client.keys(f"{self.key_prefix}*")
            if keys:
                return self.redis_client.delete(*keys) > 0
            return True
        except Exception as e:
            logger.error(f"Redis clear error: {e}")
            return False

    async def exists(self, key: str) -> bool:
        """Check if key exists in Redis cache; False on any error."""
        try:
            await self._ensure_connection()
            return self.redis_client.exists(self._make_key(key)) > 0
        except Exception as e:
            logger.error(f"Redis exists error for key {key}: {e}")
            return False

    async def get_stats(self) -> Dict[str, Any]:
        """Get Redis cache statistics (server INFO plus a prefixed-key count).

        NOTE(review): uses KEYS again — O(total keys) on the server.
        """
        try:
            await self._ensure_connection()
            info = self.redis_client.info()
            keys_count = len(self.redis_client.keys(f"{self.key_prefix}*"))
            return {
                'backend': 'redis',
                'total_entries': keys_count,
                'memory_usage_bytes': info.get('used_memory', 0),
                'connected_clients': info.get('connected_clients', 0),
                'uptime_seconds': info.get('uptime_in_seconds', 0)
            }
        except Exception as e:
            logger.error(f"Redis stats error: {e}")
            return {'backend': 'redis', 'error': str(e)}
class CacheManager:
    """
    Enterprise-grade cache manager with multiple backends and advanced features.

    Owns one CacheBackendInterface implementation and layers on top of it:
      * deterministic cache-key generation from call arguments
      * tag-based invalidation (tag -> keys map kept in this process only)
      * cache warming and get-or-set helpers
      * decorators that transparently cache function results
    """

    def __init__(
        self,
        backend: CacheBackend = CacheBackend.MEMORY,
        config: Dict[str, Any] = None
    ):
        """Create a manager and eagerly initialize its backend.

        Args:
            backend: Backend type; only MEMORY and REDIS are supported here.
            config: Optional settings — 'max_size', 'default_ttl' (memory),
                'redis_url', 'key_prefix' (redis), plus the feature flags
                'cache_warming' and 'cache_invalidation' (both default True).

        Raises:
            ValueError: If the backend type is unsupported.
        """
        self.backend_type = backend
        self.config = config or {}
        self.backend: Optional[CacheBackendInterface] = None
        self.cache_warming_tasks: Dict[str, asyncio.Task] = {}
        # tag -> list of keys set with that tag. In-process only, so tag
        # invalidation does not span workers even with a shared Redis backend.
        self.invalidation_patterns: Dict[str, List[str]] = {}
        self._initialize_backend()
        self.warming_enabled = self.config.get('cache_warming', True)
        self.invalidation_enabled = self.config.get('cache_invalidation', True)

    def _initialize_backend(self):
        """Instantiate the concrete backend from self.backend_type and self.config.

        Raises:
            ValueError: For MEMCACHE/HYBRID (declared in the enum but not
                implemented in this module) or any other unsupported type.
        """
        if self.backend_type == CacheBackend.MEMORY:
            self.backend = MemoryCacheBackend(
                max_size=self.config.get('max_size', 1000),
                default_ttl=self.config.get('default_ttl', 3600)
            )
        elif self.backend_type == CacheBackend.REDIS:
            self.backend = RedisCacheBackend(
                redis_url=self.config.get('redis_url', 'redis://localhost:6379'),
                key_prefix=self.config.get('key_prefix', 'mcp:')
            )
        else:
            raise ValueError(f"Unsupported cache backend: {self.backend_type}")

    def cache_key(self, prefix: str, *args, **kwargs) -> str:
        """Generate a deterministic cache key from a prefix and call arguments.

        Scalars (str/int/float/bool) are embedded verbatim; anything else is
        reduced to the first 8 hex chars of the MD5 of its str() form (MD5 is
        used purely for key dispersion, not security). Keyword arguments are
        sorted so argument order never changes the key.
        """
        key_parts = [prefix]
        for arg in args:
            if isinstance(arg, (str, int, float, bool)):
                key_parts.append(str(arg))
            else:
                # NOTE(review): str() of dicts/sets is ordering-sensitive, so
                # logically equal complex arguments may still hash differently.
                key_parts.append(hashlib.md5(str(arg).encode()).hexdigest()[:8])
        for k, v in sorted(kwargs.items()):
            if isinstance(v, (str, int, float, bool)):
                key_parts.append(f"{k}:{v}")
            else:
                key_parts.append(f"{k}:{hashlib.md5(str(v).encode()).hexdigest()[:8]}")
        return ":".join(key_parts)

    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache; returns None on a miss."""
        return await self.backend.get(key)

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None,
        tags: List[str] = None
    ) -> bool:
        """Set value in cache and, on success, record its tags for later invalidation."""
        success = await self.backend.set(key, value, ttl)
        if success and tags and self.invalidation_enabled:
            for tag in tags:
                # setdefault replaces the original membership-test-then-init dance.
                keys_for_tag = self.invalidation_patterns.setdefault(tag, [])
                if key not in keys_for_tag:
                    keys_for_tag.append(key)
        return success

    async def delete(self, key: str) -> bool:
        """Delete value from cache (tag bookkeeping is not cleaned up here)."""
        return await self.backend.delete(key)

    async def clear(self) -> bool:
        """Clear all cache entries and the tag-invalidation index."""
        self.invalidation_patterns.clear()
        return await self.backend.clear()

    async def invalidate_by_tag(self, tag: str) -> int:
        """Delete every cached key recorded under `tag`.

        Returns:
            Number of keys actually deleted (0 if invalidation is disabled
            or the tag is unknown).
        """
        if not self.invalidation_enabled or tag not in self.invalidation_patterns:
            return 0
        keys_to_delete = self.invalidation_patterns[tag].copy()
        deleted_count = 0
        for key in keys_to_delete:
            if await self.backend.delete(key):
                deleted_count += 1
        del self.invalidation_patterns[tag]
        logger.info(f"Invalidated {deleted_count} cache entries with tag '{tag}'")
        return deleted_count

    async def warm_cache(
        self,
        key: str,
        value_func: Callable,
        ttl: Optional[int] = None,
        force: bool = False
    ):
        """Pre-populate `key` with the result of value_func (sync or async).

        No-op when warming is disabled, or when the key already exists and
        `force` is False. Errors from value_func are logged, never raised.
        """
        if not self.warming_enabled:
            return
        if not force and await self.backend.exists(key):
            return
        try:
            if asyncio.iscoroutinefunction(value_func):
                value = await value_func()
            else:
                value = value_func()
            await self.set(key, value, ttl)
            logger.debug(f"Cache warmed for key: {key}")
        except Exception as e:
            logger.error(f"Cache warming failed for key {key}: {e}")

    async def get_or_set(
        self,
        key: str,
        value_func: Callable,
        ttl: Optional[int] = None,
        tags: List[str] = None
    ) -> Any:
        """Return the cached value, or compute it via value_func, cache it, and return it.

        NOTE: a legitimately cached value of None is indistinguishable from a
        miss here, so None results are recomputed on every call.

        Raises:
            Whatever value_func raises (logged, then re-raised).
        """
        cached_value = await self.get(key)
        if cached_value is not None:
            return cached_value
        try:
            if asyncio.iscoroutinefunction(value_func):
                value = await value_func()
            else:
                value = value_func()
            await self.set(key, value, ttl, tags)
            return value
        except Exception as e:
            logger.error(f"Failed to compute and cache value for key {key}: {e}")
            raise

    async def get_stats(self) -> Dict[str, Any]:
        """Get backend statistics augmented with manager-level information."""
        backend_stats = await self.backend.get_stats()
        return {
            **backend_stats,
            'backend_type': self.backend_type.value,
            'warming_enabled': self.warming_enabled,
            'invalidation_enabled': self.invalidation_enabled,
            'invalidation_patterns_count': len(self.invalidation_patterns),
            'active_warming_tasks': len(self.cache_warming_tasks)
        }

    def cache_decorator(
        self,
        key_prefix: str,
        ttl: Optional[int] = None,
        tags: List[str] = None,
        skip_args: List[int] = None
    ):
        """Decorator for automatic function result caching.

        Chooses an async or sync wrapper based on the decorated function.
        `skip_args` lists positional-argument indexes to exclude from the
        cache key (e.g. a `self` or connection argument at index 0).

        NOTE: results equal to None are never treated as cache hits.
        """
        def decorator(func):
            @functools.wraps(func)  # fix: preserve wrapped function's metadata
            async def async_wrapper(*args, **kwargs):
                cache_args = args
                if skip_args:
                    cache_args = [arg for i, arg in enumerate(args) if i not in skip_args]
                cache_key = self.cache_key(key_prefix, *cache_args, **kwargs)
                cached_result = await self.get(cache_key)
                if cached_result is not None:
                    return cached_result
                result = await func(*args, **kwargs)
                await self.set(cache_key, result, ttl, tags)
                return result

            @functools.wraps(func)  # fix: preserve wrapped function's metadata
            def sync_wrapper(*args, **kwargs):
                cache_args = args
                if skip_args:
                    cache_args = [arg for i, arg in enumerate(args) if i not in skip_args]
                cache_key = self.cache_key(key_prefix, *cache_args, **kwargs)
                # Drive the async backend from sync code via an event loop.
                # NOTE(review): get_event_loop() is deprecated for this use
                # since 3.10, and run_until_complete fails if a loop is
                # already running in this thread — confirm sync call sites
                # are loop-free.
                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                cached_result = loop.run_until_complete(self.get(cache_key))
                if cached_result is not None:
                    return cached_result
                result = func(*args, **kwargs)
                loop.run_until_complete(self.set(cache_key, result, ttl, tags))
                return result

            return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
        return decorator
# Global cache manager instance (lazily created; see get_cache_manager)
_cache_manager: Optional[CacheManager] = None

def get_cache_manager() -> CacheManager:
    """Return the process-wide CacheManager, creating a default
    (memory-backend) instance on first call.

    NOTE(review): first-call initialization is not guarded by a lock —
    confirm single-threaded startup if multiple threads may call this.
    """
    global _cache_manager
    if _cache_manager is None:
        _cache_manager = CacheManager()
    return _cache_manager
def setup_caching(
    backend: CacheBackend = CacheBackend.MEMORY,
    config: Dict[str, Any] = None
) -> CacheManager:
    """Setup and configure the global cache manager.

    Replaces any previously created global manager; entries held by the old
    manager are not migrated, and decorators created earlier keep referring
    to the manager they captured at decoration time.
    """
    global _cache_manager
    _cache_manager = CacheManager(backend, config)
    return _cache_manager
# Convenience decorators using global cache manager
def cache_result(
    key_prefix: str,
    ttl: Optional[int] = None,
    tags: List[str] = None,
    skip_args: List[int] = None
):
    """Cache a function's results through the module-wide cache manager.

    Thin pass-through to CacheManager.cache_decorator on the lazily
    created global instance; all parameters are forwarded unchanged.
    """
    return get_cache_manager().cache_decorator(key_prefix, ttl, tags, skip_args)
# Example usage decorators for MCP components
def cache_sql_query(ttl: int = 300):
    """Preset decorator: cache SQL query results under the "sql_query"
    key prefix with a default TTL of 300 seconds (5 minutes)."""
    sql_tags = ["sql", "database"]
    return cache_result("sql_query", ttl=ttl, tags=sql_tags)
def cache_schema_info(ttl: int = 3600):
    """Preset decorator: cache schema information under the "schema"
    key prefix with a default TTL of 3600 seconds (1 hour)."""
    schema_tags = ["schema", "database"]
    return cache_result("schema", ttl=ttl, tags=schema_tags)
def cache_semantic_search(ttl: int = 1800):
    """Preset decorator: cache semantic search results under the "semantic"
    key prefix with a default TTL of 1800 seconds (30 minutes)."""
    search_tags = ["semantic", "search"]
    return cache_result("semantic", ttl=ttl, tags=search_tags)
def cache_synthesis(ttl: int = 600):
    """Preset decorator: cache synthesis results under the "synthesis"
    key prefix with a default TTL of 600 seconds (10 minutes)."""
    synthesis_tags = ["synthesis", "llm"]
    return cache_result("synthesis", ttl=ttl, tags=synthesis_tags)