cache.py (20.4 kB)
""" Session cache management for Shannon MCP Server. This module provides session caching with: - LRU cache implementation - Session expiration - Persistence support - Memory management - Cache statistics """ import asyncio from typing import Dict, Optional, List, Any, Tuple from datetime import datetime, timedelta from collections import OrderedDict from dataclasses import dataclass, field import json import pickle import structlog from pathlib import Path from ..utils.logging import get_logger from ..utils.errors import CacheError, handle_errors from ..utils.notifications import emit, EventCategory logger = get_logger("shannon-mcp.cache") @dataclass class CacheEntry: """Single cache entry.""" key: str value: Any created_at: datetime = field(default_factory=datetime.utcnow) last_accessed: datetime = field(default_factory=datetime.utcnow) access_count: int = 0 size_bytes: int = 0 ttl_seconds: Optional[int] = None @property def is_expired(self) -> bool: """Check if entry has expired.""" if self.ttl_seconds is None: return False age = (datetime.utcnow() - self.created_at).total_seconds() return age > self.ttl_seconds def touch(self): """Update last accessed time.""" self.last_accessed = datetime.utcnow() self.access_count += 1 @dataclass class CacheStats: """Cache statistics.""" total_entries: int = 0 total_size_bytes: int = 0 hits: int = 0 misses: int = 0 evictions: int = 0 expirations: int = 0 @property def hit_rate(self) -> float: """Calculate cache hit rate.""" total = self.hits + self.misses return self.hits / total if total > 0 else 0.0 class LRUCache: """Least Recently Used cache implementation.""" def __init__( self, max_entries: int = 1000, max_size_bytes: int = 100 * 1024 * 1024, # 100MB default_ttl: Optional[int] = 3600, # 1 hour persistence_path: Optional[Path] = None ): """ Initialize LRU cache. Args: max_entries: Maximum number of entries max_size_bytes: Maximum cache size in bytes default_ttl: Default TTL in seconds persistence_path: Optional path for persistence """ self.max_entries = max_entries self.max_size_bytes = max_size_bytes self.default_ttl = default_ttl self.persistence_path = persistence_path # OrderedDict maintains insertion order for LRU self._cache: OrderedDict[str, CacheEntry] = OrderedDict() self._stats = CacheStats() self._lock = asyncio.Lock() # Background tasks self._cleanup_task: Optional[asyncio.Task] = None self._persist_task: Optional[asyncio.Task] = None async def initialize(self) -> None: """Initialize cache and start background tasks.""" # Load persisted cache if available if self.persistence_path: await self._load_persisted() # Start background tasks self._cleanup_task = asyncio.create_task(self._cleanup_expired()) if self.persistence_path: self._persist_task = asyncio.create_task(self._periodic_persist()) logger.info( "cache_initialized", max_entries=self.max_entries, max_size_mb=self.max_size_bytes / 1024 / 1024, persistence_enabled=bool(self.persistence_path) ) async def shutdown(self) -> None: """Shutdown cache and save state.""" # Cancel background tasks if self._cleanup_task: self._cleanup_task.cancel() if self._persist_task: self._persist_task.cancel() # Final persistence if self.persistence_path: await self._save_persisted() logger.info("cache_shutdown", stats=self.get_stats()) async def get(self, key: str) -> Optional[Any]: """ Get value from cache. 
Args: key: Cache key Returns: Cached value or None """ async with self._lock: entry = self._cache.get(key) if entry is None: self._stats.misses += 1 return None # Check expiration if entry.is_expired: self._stats.expirations += 1 del self._cache[key] self._stats.total_entries -= 1 self._stats.total_size_bytes -= entry.size_bytes return None # Update LRU order self._cache.move_to_end(key) entry.touch() self._stats.hits += 1 return entry.value async def set( self, key: str, value: Any, ttl: Optional[int] = None ) -> None: """ Set value in cache. Args: key: Cache key value: Value to cache ttl: Optional TTL override """ # Calculate size try: size_bytes = len(pickle.dumps(value)) except: # Fallback for non-pickleable objects size_bytes = len(str(value).encode()) async with self._lock: # Check if key exists if key in self._cache: old_entry = self._cache[key] self._stats.total_size_bytes -= old_entry.size_bytes # Create new entry entry = CacheEntry( key=key, value=value, size_bytes=size_bytes, ttl_seconds=ttl or self.default_ttl ) # Add to cache self._cache[key] = entry self._cache.move_to_end(key) # Update stats self._stats.total_entries = len(self._cache) self._stats.total_size_bytes += size_bytes # Evict if necessary await self._evict_if_needed() async def delete(self, key: str) -> bool: """ Delete entry from cache. Args: key: Cache key Returns: True if deleted, False if not found """ async with self._lock: if key not in self._cache: return False entry = self._cache[key] del self._cache[key] self._stats.total_entries -= 1 self._stats.total_size_bytes -= entry.size_bytes return True async def clear(self) -> None: """Clear all cache entries.""" async with self._lock: self._cache.clear() self._stats.total_entries = 0 self._stats.total_size_bytes = 0 self._stats.evictions += len(self._cache) async def _evict_if_needed(self) -> None: """Evict entries if cache limits exceeded.""" evicted = 0 # Evict by entry count while len(self._cache) > self.max_entries: # Remove least recently used key, entry = self._cache.popitem(last=False) self._stats.total_size_bytes -= entry.size_bytes evicted += 1 # Evict by size while self._stats.total_size_bytes > self.max_size_bytes: if not self._cache: break key, entry = self._cache.popitem(last=False) self._stats.total_size_bytes -= entry.size_bytes evicted += 1 if evicted > 0: self._stats.evictions += evicted self._stats.total_entries = len(self._cache) logger.info( "cache_eviction", evicted_count=evicted, remaining_entries=self._stats.total_entries, size_mb=self._stats.total_size_bytes / 1024 / 1024 ) async def _cleanup_expired(self) -> None: """Background task to clean up expired entries.""" while True: try: await asyncio.sleep(60) # Check every minute async with self._lock: expired_keys = [] for key, entry in self._cache.items(): if entry.is_expired: expired_keys.append(key) for key in expired_keys: entry = self._cache[key] del self._cache[key] self._stats.total_entries -= 1 self._stats.total_size_bytes -= entry.size_bytes self._stats.expirations += 1 if expired_keys: logger.info( "expired_entries_cleaned", count=len(expired_keys) ) except asyncio.CancelledError: break except Exception as e: logger.error("cleanup_error", error=str(e)) async def _periodic_persist(self) -> None: """Background task to persist cache periodically.""" while True: try: await asyncio.sleep(300) # Save every 5 minutes await self._save_persisted() except asyncio.CancelledError: break except Exception as e: logger.error("persist_error", error=str(e)) async def _save_persisted(self) -> None: 
"""Save cache to disk.""" if not self.persistence_path: return try: # Create directory if needed self.persistence_path.parent.mkdir(parents=True, exist_ok=True) # Prepare data data = { "version": 1, "timestamp": datetime.utcnow().isoformat(), "stats": { "hits": self._stats.hits, "misses": self._stats.misses, "evictions": self._stats.evictions, "expirations": self._stats.expirations }, "entries": {} } # Add non-expired entries async with self._lock: for key, entry in self._cache.items(): if not entry.is_expired: try: # Try to serialize value data["entries"][key] = { "value": pickle.dumps(entry.value).hex(), "created_at": entry.created_at.isoformat(), "last_accessed": entry.last_accessed.isoformat(), "access_count": entry.access_count, "size_bytes": entry.size_bytes, "ttl_seconds": entry.ttl_seconds } except: # Skip non-serializable entries logger.warning( "skip_non_serializable", key=key ) # Write atomically temp_path = self.persistence_path.with_suffix('.tmp') with open(temp_path, 'w') as f: json.dump(data, f) temp_path.replace(self.persistence_path) logger.info( "cache_persisted", entries=len(data["entries"]), path=str(self.persistence_path) ) except Exception as e: logger.error( "persist_save_error", error=str(e), path=str(self.persistence_path) ) async def _load_persisted(self) -> None: """Load cache from disk.""" if not self.persistence_path or not self.persistence_path.exists(): return try: with open(self.persistence_path, 'r') as f: data = json.load(f) # Check version if data.get("version") != 1: logger.warning("unsupported_cache_version", version=data.get("version")) return # Restore stats stats = data.get("stats", {}) self._stats.hits = stats.get("hits", 0) self._stats.misses = stats.get("misses", 0) self._stats.evictions = stats.get("evictions", 0) self._stats.expirations = stats.get("expirations", 0) # Restore entries loaded = 0 for key, entry_data in data.get("entries", {}).items(): try: # Deserialize value value = pickle.loads(bytes.fromhex(entry_data["value"])) # Create entry entry = CacheEntry( key=key, value=value, created_at=datetime.fromisoformat(entry_data["created_at"]), last_accessed=datetime.fromisoformat(entry_data["last_accessed"]), access_count=entry_data["access_count"], size_bytes=entry_data["size_bytes"], ttl_seconds=entry_data.get("ttl_seconds") ) # Skip if expired if not entry.is_expired: self._cache[key] = entry self._stats.total_entries += 1 self._stats.total_size_bytes += entry.size_bytes loaded += 1 except Exception as e: logger.warning( "skip_corrupted_entry", key=key, error=str(e) ) logger.info( "cache_loaded", loaded=loaded, path=str(self.persistence_path) ) except Exception as e: logger.error( "persist_load_error", error=str(e), path=str(self.persistence_path) ) def get_stats(self) -> Dict[str, Any]: """Get cache statistics.""" return { "total_entries": self._stats.total_entries, "total_size_mb": self._stats.total_size_bytes / 1024 / 1024, "hits": self._stats.hits, "misses": self._stats.misses, "hit_rate": self._stats.hit_rate, "evictions": self._stats.evictions, "expirations": self._stats.expirations, "max_entries": self.max_entries, "max_size_mb": self.max_size_bytes / 1024 / 1024 } def get_entry_stats(self) -> List[Dict[str, Any]]: """Get statistics for all entries.""" stats = [] for key, entry in self._cache.items(): stats.append({ "key": key, "size_bytes": entry.size_bytes, "created_at": entry.created_at.isoformat(), "last_accessed": entry.last_accessed.isoformat(), "access_count": entry.access_count, "age_seconds": (datetime.utcnow() - 
entry.created_at).total_seconds(), "is_expired": entry.is_expired }) return sorted(stats, key=lambda x: x["last_accessed"], reverse=True) class SessionCache(LRUCache): """Specialized cache for sessions with additional features.""" def __init__( self, max_sessions: int = 100, max_size_mb: int = 500, session_ttl: int = 3600, persistence_dir: Optional[Path] = None ): """ Initialize session cache. Args: max_sessions: Maximum number of cached sessions max_size_mb: Maximum cache size in MB session_ttl: Session TTL in seconds persistence_dir: Directory for persistence """ persistence_path = None if persistence_dir: persistence_path = persistence_dir / "session_cache.json" super().__init__( max_entries=max_sessions, max_size_bytes=max_size_mb * 1024 * 1024, default_ttl=session_ttl, persistence_path=persistence_path ) self.persistence_dir = persistence_dir async def get_session(self, session_id: str): """Get session from cache.""" session = await self.get(session_id) if session: logger.debug( "session_cache_hit", session_id=session_id ) # Emit cache hit event await emit( "session_cache_hit", EventCategory.SESSION, {"session_id": session_id} ) else: logger.debug( "session_cache_miss", session_id=session_id ) return session async def cache_session(self, session, ttl: Optional[int] = None) -> None: """Cache a session.""" await self.set(session.id, session, ttl) logger.info( "session_cached", session_id=session.id, state=session.state.value, ttl=ttl or self.default_ttl ) # Emit cache event await emit( "session_cached", EventCategory.SESSION, { "session_id": session.id, "ttl": ttl or self.default_ttl } ) async def invalidate_session(self, session_id: str) -> bool: """Remove session from cache.""" removed = await self.delete(session_id) if removed: logger.info( "session_invalidated", session_id=session_id ) # Emit invalidation event await emit( "session_cache_invalidated", EventCategory.SESSION, {"session_id": session_id} ) return removed async def get_active_sessions(self) -> List[str]: """Get list of active session IDs.""" active = [] async with self._lock: for key, entry in self._cache.items(): if not entry.is_expired: session = entry.value if hasattr(session, 'state') and session.state.value in ['ready', 'running']: active.append(key) return active async def cleanup_stale_sessions(self) -> int: """Clean up stale sessions.""" cleaned = 0 async with self._lock: stale_keys = [] for key, entry in self._cache.items(): session = entry.value if hasattr(session, 'state'): # Remove failed or timed out sessions if session.state.value in ['failed', 'timeout']: age = (datetime.utcnow() - entry.created_at).total_seconds() if age > 300: # 5 minutes stale_keys.append(key) for key in stale_keys: del self._cache[key] cleaned += 1 self._stats.total_entries = len(self._cache) if cleaned > 0: logger.info( "stale_sessions_cleaned", count=cleaned ) return cleaned # Export public API __all__ = ['LRUCache', 'SessionCache', 'CacheEntry', 'CacheStats']
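Below is a minimal usage sketch for the LRUCache defined above. It is not part of cache.py: the import path is a guess for illustration (the module's real package location is not shown on this page), and the persistence file is a throwaway path.

# Minimal usage sketch (not part of cache.py). The import path below is
# hypothetical; adjust it to wherever the module lives in shannon-mcp.
import asyncio
from pathlib import Path

from shannon_mcp.cache import LRUCache  # hypothetical import path


async def main() -> None:
    # Small cache: 10 entries max, 1 MB max, 60-second default TTL,
    # persisted to a JSON file between runs.
    cache = LRUCache(
        max_entries=10,
        max_size_bytes=1024 * 1024,
        default_ttl=60,
        persistence_path=Path("/tmp/demo_cache.json"),
    )
    await cache.initialize()

    await cache.set("greeting", {"text": "hello"})   # uses the default TTL
    await cache.set("short_lived", "bye", ttl=1)     # per-entry TTL override

    print(await cache.get("greeting"))   # {'text': 'hello'} -> recorded as a hit
    print(await cache.get("missing"))    # None -> recorded as a miss
    print(cache.get_stats())             # hit_rate, sizes, evictions, ...

    await cache.shutdown()               # cancels background tasks, persists to disk


if __name__ == "__main__":
    asyncio.run(main())

SessionCache is driven the same way, but it keys entries by session.id and expects session objects exposing id and state, so it is omitted from this sketch.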


MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/krzemienski/shannon-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.