# redis_cache.py • 14.5 kB
"""
Redis cache implementation
TDD Green Phase: Implement to pass tests
"""
import asyncio
import base64
import fnmatch
import json
import logging
import pickle
import time
from typing import Any, Dict, Optional, Set, Union

from ..exceptions import CacheError
class RedisCache:
    """Redis-based cache implementation with async operations.

    Public keys are namespaced with ``key_prefix`` before they reach Redis.
    When ``mock_mode`` is on, values are kept in an in-process dict with
    wall-clock expiry instead of Redis. Mock mode is on either because it was
    set explicitly, or because ``connect`` switched it on after failing to
    import an async Redis client or to reach the server.

    Mock mode stores the Python objects themselves, with no serialization.
    Redis mode stores the bytes produced by ``_serialize``.
    """

    # Maximum number of keys passed to a single DEL command by ``clear``.
    _DELETE_BATCH_SIZE = 1000

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        default_ttl: int = 3600,
        key_prefix: str = "mcp_stock:",
        mock_mode: bool = False,
    ):
        """Initialize Redis cache.

        No connection is opened here. The first cache operation, or an
        explicit ``connect()``, opens it.

        Args:
            host: Redis host
            port: Redis port
            db: Redis database number
            password: Redis password (optional)
            default_ttl: Default time-to-live in seconds
            key_prefix: Prefix for all cache keys
            mock_mode: Use mock Redis for testing
        """
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.default_ttl = default_ttl
        self.key_prefix = key_prefix
        self.mock_mode = mock_mode
        self.logger = logging.getLogger("mcp_stock_details.redis_cache")

        # Stats tracking
        self._hits = 0
        self._misses = 0
        self._sets = 0
        self._deletes = 0

        # Mock storage is always initialized because connect() may switch to
        # mock mode at runtime. _mock_expiries maps full key -> time.time()
        # deadline.
        self._mock_store: Dict[str, Any] = {}
        self._mock_expiries: Dict[str, float] = {}

        self._redis = None
        self._connected = False

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    @staticmethod
    def _import_async_redis():
        """Return an async Redis client module, or None if none is importable.

        Prefers ``redis.asyncio`` (bundled with redis-py). Falls back to the
        legacy ``aioredis`` package.
        """
        try:
            import redis.asyncio as async_redis
            return async_redis
        except ImportError:
            pass
        try:
            import aioredis
            return aioredis
        except (ImportError, TypeError):
            # aioredis 2.x raises TypeError at import time on Python 3.11+
            # ("duplicate base class TimeoutError").
            return None

    @staticmethod
    async def _close_client(client) -> None:
        """Close an async Redis client.

        Uses ``aclose()`` when the client provides it and ``close()`` otherwise.
        """
        # Newer redis-py releases deprecate close() in favour of aclose().
        closer = getattr(client, "aclose", None) or client.close
        await closer()

    async def _discard_client(self) -> None:
        """Drop the current client, attempting a best-effort close first."""
        client, self._redis = self._redis, None
        if client is None:
            return
        try:
            await self._close_client(client)
        except Exception:
            # The client is being abandoned anyway; a failed close is not
            # actionable, but keep a trace for debugging.
            self.logger.debug("Ignoring error while closing Redis client", exc_info=True)

    async def connect(self) -> bool:
        """Connect to Redis server.

        If no async client library is available or the server cannot be
        reached, switch to mock mode instead of failing.

        Returns:
            True in every case: either a real connection was established
            or the cache fell back to mock mode.
        """
        if self.mock_mode:
            self._connected = True
            self.logger.info("Using mock Redis for testing")
            return True

        redis_module = self._import_async_redis()
        if redis_module is None:
            self.logger.warning(
                "No async Redis client (redis.asyncio or aioredis) installed, using mock mode"
            )
            self.mock_mode = True
            self._connected = True
            return True

        try:
            # The returned client is awaited, as the original aioredis code did.
            # decode_responses=False because this class handles encoding itself.
            self._redis = await redis_module.from_url(
                f"redis://{self.host}:{self.port}/{self.db}",
                password=self.password,
                decode_responses=False,
            )
            # Test connection
            await self._redis.ping()
        except Exception as e:
            self.logger.error("Failed to connect to Redis: %s", e)
            # Do not leak a client whose ping failed.
            await self._discard_client()
            # Fallback to mock mode
            self.mock_mode = True
            self._connected = True
            self.logger.info("Falling back to mock Redis")
            return True

        self._connected = True
        self.logger.info("Connected to Redis at %s:%s", self.host, self.port)
        return True

    async def close(self) -> None:
        """Close Redis connection (if any) and mark the cache disconnected."""
        client, self._redis = self._redis, None
        try:
            if client is not None:
                await self._close_client(client)
        finally:
            self._connected = False
        self.logger.info("Redis cache connection closed")

    # ------------------------------------------------------------------
    # Public cache API
    # ------------------------------------------------------------------

    def _full_key(self, key: str) -> str:
        """Return ``key`` namespaced with the configured prefix."""
        return f"{self.key_prefix}{key}"

    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache.

        Args:
            key: Cache key (without prefix)

        Returns:
            Cached value, or None if it is missing, expired, or an error
            occurred. Errors are logged and counted as misses.
        """
        if not self._connected:
            await self.connect()
        full_key = self._full_key(key)
        try:
            if self.mock_mode:
                return await self._mock_get(full_key)
            raw_data = await self._redis.get(full_key)
            if raw_data is None:
                self._misses += 1
                return None
            value = self._deserialize(raw_data)
            self._hits += 1
            self.logger.debug("Cache hit for key: %s", key)
            return value
        except Exception as e:
            self.logger.error("Error getting cache key %s: %s", key, e)
            self._misses += 1
            return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Set value in cache.

        Args:
            key: Cache key (without prefix)
            value: Value to cache
            ttl: Time-to-live in seconds. A falsy value (None or 0) falls
                back to ``default_ttl``; Redis SETEX rejects a zero expiry.

        Returns:
            True if set successfully

        Raises:
            CacheError: if serialization or the Redis write fails.
        """
        if not self._connected:
            await self.connect()
        full_key = self._full_key(key)
        ttl = ttl or self.default_ttl
        try:
            if self.mock_mode:
                return await self._mock_set(full_key, value, ttl)
            serialized_data = self._serialize(value)
            result = await self._redis.setex(full_key, ttl, serialized_data)
        except Exception as e:
            self.logger.error("Error setting cache key %s: %s", key, e)
            raise CacheError(f"Failed to set cache key: {key}") from e

        if result:
            self._sets += 1
            self.logger.debug("Cached key: %s, TTL: %ss", key, ttl)
            return True
        return False

    async def delete(self, key: str) -> bool:
        """Delete key from cache.

        Args:
            key: Cache key to delete (without prefix)

        Returns:
            True if deleted, False if not found or on error.
        """
        if not self._connected:
            await self.connect()
        full_key = self._full_key(key)
        try:
            if self.mock_mode:
                return await self._mock_delete(full_key)
            result = await self._redis.delete(full_key)
            if result > 0:
                self._deletes += 1
                return True
            return False
        except Exception as e:
            self.logger.error("Error deleting cache key %s: %s", key, e)
            return False

    async def clear(self, pattern: Optional[str] = None) -> int:
        """Clear cache entries.

        Args:
            pattern: Optional Redis glob pattern matched against FULL keys,
                i.e. including the prefix. Defaults to ``"<key_prefix>*"``.

        Returns:
            Number of keys deleted (0 on error).
        """
        if not self._connected:
            await self.connect()
        pattern = pattern or f"{self.key_prefix}*"
        try:
            if self.mock_mode:
                return await self._mock_clear(pattern)
            keys = await self._scan_keys(pattern)
            if not keys:
                return 0
            deleted = 0
            batch = self._DELETE_BATCH_SIZE
            # Batch DEL so a huge match does not become one enormous command.
            for start in range(0, len(keys), batch):
                deleted += await self._redis.delete(*keys[start:start + batch])
            self._deletes += deleted
            self.logger.info("Cleared %d cache entries", deleted)
            return deleted
        except Exception as e:
            self.logger.error("Error clearing cache: %s", e)
            return 0

    async def exists(self, key: str) -> bool:
        """Check if key exists (and, in mock mode, has not expired).

        Args:
            key: Cache key (without prefix)

        Returns:
            True if exists, False otherwise or on error.
        """
        if not self._connected:
            await self.connect()
        full_key = self._full_key(key)
        try:
            if self.mock_mode:
                return await self._mock_exists(full_key)
            result = await self._redis.exists(full_key)
            return result > 0
        except Exception as e:
            self.logger.error("Error checking key existence %s: %s", key, e)
            return False

    async def keys(self, pattern: Optional[str] = None) -> Set[str]:
        """Get all cache keys matching pattern.

        Args:
            pattern: Redis glob pattern matched against FULL keys. Defaults
                to all keys with the configured prefix.

        Returns:
            Set of matching keys with the prefix removed. Keys that do not
            start with the prefix are omitted. Returns an empty set on error.
        """
        if not self._connected:
            await self.connect()
        search_pattern = pattern or f"{self.key_prefix}*"
        try:
            if self.mock_mode:
                return await self._mock_keys(search_pattern)
            raw_keys = await self._scan_keys(search_pattern)
            return self._strip_prefix(raw_keys)
        except Exception as e:
            self.logger.error("Error getting keys: %s", e)
            return set()

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with counters, hit rate (rounded to 3 places; 0 when
            there have been no lookups), and connection settings.
        """
        lookups = self._hits + self._misses
        hit_rate = self._hits / lookups if lookups > 0 else 0
        return {
            "type": "redis",
            "hits": self._hits,
            "misses": self._misses,
            "sets": self._sets,
            "deletes": self._deletes,
            "hit_rate": round(hit_rate, 3),
            "connected": self._connected,
            "mock_mode": self.mock_mode,
            "host": self.host,
            "port": self.port,
            "default_ttl": self.default_ttl,
        }

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------

    async def _scan_keys(self, pattern: str) -> list:
        """Collect keys matching a Redis glob pattern using incremental SCAN.

        SCAN is used instead of KEYS because KEYS blocks the server while it
        walks the entire keyspace. SCAN may return a key more than once, so
        the result is de-duplicated.
        """
        unique = {key async for key in self._redis.scan_iter(match=pattern)}
        return list(unique)

    def _strip_prefix(self, keys) -> Set[str]:
        """Decode byte keys and return those carrying the prefix, prefix removed."""
        prefix = self.key_prefix
        clean_keys = set()
        for key in keys:
            if isinstance(key, bytes):
                key = key.decode("utf-8")
            if key.startswith(prefix):
                clean_keys.add(key[len(prefix):])
        return clean_keys

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    def _serialize(self, value: Any) -> bytes:
        """Serialize value for Redis storage.

        Produces ``b"json:<utf-8 json>"`` when JSON round-trips the value
        exactly, and ``b"pickle:<base64 pickle>"`` otherwise.

        Raises:
            Whatever ``pickle.dumps`` raises for unpicklable values.
        """
        if isinstance(value, (str, int, float, bool, list, dict, type(None))):
            try:
                json_str = json.dumps(value, ensure_ascii=False)
                # JSON silently coerces some values: tuples become lists and
                # non-str dict keys become strings. Keep the JSON form only
                # if it decodes back to an equal value.
                if json.loads(json_str) == value:
                    return f"json:{json_str}".encode("utf-8")
            except (TypeError, ValueError):
                # Nested non-JSON types, circular refs, or unencodable
                # surrogates: fall through to pickle.
                pass
        return b"pickle:" + base64.b64encode(pickle.dumps(value))

    def _deserialize(self, data: bytes) -> Any:
        """Deserialize value from Redis storage.

        Understands the tagged formats written by ``_serialize``. Untagged
        legacy entries are tried as UTF-8 JSON first, then as raw pickle.

        Returns:
            The decoded value, or None (logged) if decoding fails.
        """
        # SECURITY: pickle.loads executes arbitrary code embedded in the data.
        # This is only safe while the Redis instance is trusted and not
        # writable by untrusted parties.
        try:
            # Check tags on the raw bytes: decoding first would make raw
            # legacy pickle payloads (not valid UTF-8) fail before the
            # pickle fallback is reached.
            if data.startswith(b"json:"):
                return json.loads(data[5:].decode("utf-8"))
            if data.startswith(b"pickle:"):
                return pickle.loads(base64.b64decode(data[7:]))
            # Legacy: try JSON first, then pickle
            try:
                return json.loads(data.decode("utf-8"))
            except (UnicodeDecodeError, json.JSONDecodeError):
                return pickle.loads(data)
        except Exception as e:
            self.logger.error("Error deserializing data: %s", e)
            return None

    # ------------------------------------------------------------------
    # Mock Redis methods for testing
    # ------------------------------------------------------------------

    def _mock_expired(self, key: str) -> bool:
        """Evict ``key`` from the mock store if its deadline passed.

        Returns:
            True if the key was expired (and has now been removed).
        """
        deadline = self._mock_expiries.get(key)
        if deadline is not None and time.time() > deadline:
            self._mock_store.pop(key, None)
            del self._mock_expiries[key]
            return True
        return False

    def _mock_purge_expired(self) -> None:
        """Evict every expired key from the mock store."""
        now = time.time()
        expired = [k for k, deadline in self._mock_expiries.items() if now > deadline]
        for key in expired:
            self._mock_store.pop(key, None)
            self._mock_expiries.pop(key, None)

    def _mock_matching_keys(self, pattern: str) -> list:
        """Return mock keys matching a Redis-style glob pattern.

        ``fnmatchcase`` handles ``*``, ``?`` and ``[...]`` the way Redis does.
        Redis's backslash escapes and ``[^...]`` negation are not emulated.
        """
        return [k for k in self._mock_store if fnmatch.fnmatchcase(k, pattern)]

    async def _mock_get(self, key: str) -> Optional[Any]:
        """Mock get: return the stored object, honouring expiry."""
        if key not in self._mock_store or self._mock_expired(key):
            self._misses += 1
            return None
        self._hits += 1
        return self._mock_store[key]

    async def _mock_set(self, key: str, value: Any, ttl: int) -> bool:
        """Mock set: store the object with a ``ttl``-second deadline."""
        self._mock_store[key] = value
        self._mock_expiries[key] = time.time() + ttl
        self._sets += 1
        return True

    async def _mock_delete(self, key: str) -> bool:
        """Mock delete: remove the key. Returns False if it was absent."""
        if key not in self._mock_store:
            return False
        del self._mock_store[key]
        self._mock_expiries.pop(key, None)
        self._deletes += 1
        return True

    async def _mock_clear(self, pattern: str) -> int:
        """Mock clear: delete live keys matching ``pattern`` and return the count."""
        # Purge first so expired entries are not reported as deleted,
        # matching Redis where expired keys no longer exist.
        self._mock_purge_expired()
        doomed = self._mock_matching_keys(pattern)
        for key in doomed:
            del self._mock_store[key]
            self._mock_expiries.pop(key, None)
        self._deletes += len(doomed)
        return len(doomed)

    async def _mock_exists(self, key: str) -> bool:
        """Mock exists: True if the key is present and not expired."""
        return key in self._mock_store and not self._mock_expired(key)

    async def _mock_keys(self, pattern: str) -> Set[str]:
        """Mock keys: live keys matching ``pattern``, prefix removed."""
        self._mock_purge_expired()
        return self._strip_prefix(self._mock_matching_keys(pattern))