# memory_cache.py • 8.51 kB
"""
In-memory cache implementation
TDD Green Phase: Implement to pass tests
"""
import asyncio
import logging
import time
import threading
from typing import Any, Dict, Optional, Set
from datetime import datetime, timedelta
from collections import OrderedDict
from ..exceptions import CacheError
class MemoryCache:
    """Thread-safe in-memory cache with TTL and size limits.

    Entries are stored in an ``OrderedDict`` kept in least-recently-used
    order (oldest first); when the cache is full, ``set`` evicts from the
    front. Every entry carries an absolute expiry timestamp based on
    ``time.time()``. Expired entries are removed lazily when touched and
    periodically by a background asyncio task started on the first
    ``get``/``set``.

    Internal state is guarded by a ``threading.RLock``; the lock is never
    held across an ``await``.
    """

    def __init__(self, max_size: int = 1000, default_ttl: int = 3600):
        """Initialize memory cache.

        Args:
            max_size: Maximum number of items in cache
            default_ttl: Default time-to-live in seconds
        """
        self.max_size = max_size
        self.default_ttl = default_ttl
        # key -> value, ordered from least to most recently used.
        self._cache: OrderedDict = OrderedDict()
        # key -> absolute expiry time (seconds since the epoch).
        self._expiry_times: Dict[str, float] = {}
        self._lock = threading.RLock()
        self._hits = 0
        self._misses = 0
        self._sets = 0
        self._deletes = 0
        self.logger = logging.getLogger("mcp_stock_details.memory_cache")
        # Background cleanup task (started lazily on first get/set).
        self._cleanup_task: Optional[asyncio.Task] = None

    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache.

        A hit marks the key as most recently used. Note that a stored value
        of ``None`` is indistinguishable from a miss.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found/expired
        """
        await self._ensure_cleanup_task()
        with self._lock:
            if key not in self._cache:
                self._misses += 1
                return None
            if self._is_expired(key):
                self._remove_key_locked(key)
                self._misses += 1
                return None
            # Mark as most recently used (LRU bookkeeping).
            self._cache.move_to_end(key)
            self._hits += 1
            return self._cache[key]

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Set value in cache.

        Overwriting an existing key refreshes its recency and never evicts
        another entry. Inserting a new key into a full cache evicts
        least-recently-used entries until there is room.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time-to-live in seconds (optional). A falsy value (None or 0)
                falls back to ``default_ttl``.

        Returns:
            True if set successfully

        Raises:
            CacheError: If storing the value fails.
        """
        await self._ensure_cleanup_task()
        try:
            with self._lock:
                ttl = ttl or self.default_ttl
                expiry_time = time.time() + ttl
                # Drop any previous entry for this key first (not counted as a
                # delete) so an overwrite neither evicts a different key nor
                # keeps the key's stale LRU position.
                self._cache.pop(key, None)
                self._expiry_times.pop(key, None)
                # Remove oldest items if at capacity.
                while len(self._cache) >= self.max_size:
                    oldest_key = next(iter(self._cache))
                    self._remove_key_locked(oldest_key)
                self._cache[key] = value
                self._expiry_times[key] = expiry_time
                self._sets += 1
            self.logger.debug("Cached key: %s, TTL: %ss", key, ttl)
            return True
        except Exception as e:
            self.logger.error(f"Error setting cache key {key}: {e}")
            raise CacheError(f"Failed to set cache key: {key}") from e

    async def delete(self, key: str) -> bool:
        """Delete key from cache.

        Args:
            key: Cache key to delete

        Returns:
            True if deleted, False if not found
        """
        with self._lock:
            if key not in self._cache:
                return False
            self._remove_key_locked(key)
            return True

    async def clear(self) -> None:
        """Clear all cache entries and reset the statistics counters."""
        with self._lock:
            self._cache.clear()
            self._expiry_times.clear()
            self._hits = 0
            self._misses = 0
            self._sets = 0
            self._deletes = 0
        self.logger.info("Memory cache cleared")

    async def exists(self, key: str) -> bool:
        """Check if key exists and is not expired.

        Does not affect hit/miss counters or LRU order; an expired key found
        here is removed.

        Args:
            key: Cache key

        Returns:
            True if exists and not expired
        """
        with self._lock:
            if key not in self._cache:
                return False
            if self._is_expired(key):
                self._remove_key_locked(key)
                return False
            return True

    async def keys(self, pattern: Optional[str] = None) -> Set[str]:
        """Get all cache keys, optionally filtered by pattern.

        Expired entries are purged before the keys are collected.

        Args:
            pattern: Optional pattern to match (simple string contains)

        Returns:
            Set of matching keys
        """
        with self._lock:
            self._purge_expired_locked()
            matching = set(self._cache)
        if pattern:
            matching = {k for k in matching if pattern in k}
        return matching

    @property
    def size(self) -> int:
        """Get current cache size (may include not-yet-purged expired entries)."""
        with self._lock:
            return len(self._cache)

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with cache statistics
        """
        with self._lock:
            lookups = self._hits + self._misses
            hit_rate = self._hits / lookups if lookups > 0 else 0
            return {
                "type": "memory",
                "hits": self._hits,
                "misses": self._misses,
                "sets": self._sets,
                "deletes": self._deletes,
                "hit_rate": round(hit_rate, 3),
                "size": len(self._cache),
                "max_size": self.max_size,
                "default_ttl": self.default_ttl,
            }

    def _is_expired(self, key: str) -> bool:
        """Check if key is expired (a key with no expiry record counts as expired)."""
        if key not in self._expiry_times:
            return True
        return time.time() > self._expiry_times[key]

    def _remove_key_locked(self, key: str) -> None:
        """Remove ``key`` and count it as a delete; caller must hold ``self._lock``."""
        self._cache.pop(key, None)
        self._expiry_times.pop(key, None)
        self._deletes += 1

    def _purge_expired_locked(self) -> int:
        """Remove all expired entries; caller must hold ``self._lock``.

        Returns:
            Number of entries removed
        """
        expired_keys = [k for k in self._cache if self._is_expired(k)]
        for k in expired_keys:
            self._remove_key_locked(k)
        return len(expired_keys)

    async def _delete_key(self, key: str) -> None:
        """Internal method to delete a key.

        Kept as a coroutine for backward compatibility; it does not yield.
        """
        with self._lock:
            self._remove_key_locked(key)

    async def _ensure_cleanup_task(self) -> None:
        """Start the background cleanup task if it is not already running."""
        if self._cleanup_task is None or self._cleanup_task.done():
            try:
                self._cleanup_task = asyncio.create_task(self._cleanup_expired())
            except RuntimeError:
                # No event loop running, cleanup will be done manually
                pass

    async def _cleanup_expired(self) -> None:
        """Background task: purge expired entries every 60 seconds until cancelled."""
        while True:
            try:
                with self._lock:
                    removed = self._purge_expired_locked()
                if removed:
                    self.logger.debug(f"Cleaned up {removed} expired keys")
                await asyncio.sleep(60)
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Error in cache cleanup: {e}")
                await asyncio.sleep(60)

    async def close(self) -> None:
        """Stop the cleanup task (if one was started) and clear the cache."""
        task, self._cleanup_task = self._cleanup_task, None
        # The task is None when no get/set ever ran; don't call .cancel() on it.
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        await self.clear()
        self.logger.info("Memory cache closed")

    def __len__(self) -> int:
        """Get cache size."""
        return self.size

    def __contains__(self, key: str) -> bool:
        """Check if key exists and is not expired (sync; does not purge)."""
        with self._lock:
            return key in self._cache and not self._is_expired(key)