# cache.py (15.2 kB)
"""Enterprise-grade caching and rate limiting for GCP MCP server."""
import asyncio
import functools
import hashlib
import json
import time
from collections import defaultdict, deque
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List, Optional

import structlog
from pydantic import BaseModel
# Module-wide structlog logger; the cache, the rate limiter and the `cached`
# decorator all emit their events through it.
logger = structlog.get_logger(__name__)
class CacheEntry(BaseModel):
    """A single cached value plus the bookkeeping EnterpriseCache keeps for it."""

    # The cached payload, stored as-is.
    data: Any
    # time.time() reading taken when the entry was written.
    timestamp: float
    # Lifetime in seconds; the entry is treated as expired once
    # (now - timestamp) exceeds this value.
    ttl: int
    # Number of times the entry has been stored or served.
    access_count: int = 0
    # time.time() reading of the most recent store or hit.
    last_access: float = 0
    # Estimated payload size, charged against the cache's byte budget.
    size_bytes: int = 0
class RateLimitEntry(BaseModel):
    """Per-key counter state used by RateLimiter for one limit window."""

    # Requests admitted so far in the current window.
    count: int = 0
    # time.time() reading at which the current window began (0 = never started).
    window_start: float = 0
    # time.time() reading until which requests for this key are rejected outright.
    blocked_until: float = 0
class EnterpriseCache:
    """In-memory async cache with per-entry TTL, LRU eviction and a byte budget.

    Entries live in ``self.cache``; ``self.access_order`` lists their keys from
    least to most recently used. ``get``, ``set``, ``clear`` and ``get_stats``
    each run under one ``asyncio.Lock``.
    """

    def __init__(self, max_size_mb: int = 100, default_ttl: int = 300):
        """Initialize cache.

        Args:
            max_size_mb: Maximum cache size in megabytes
            default_ttl: Default TTL in seconds
        """
        self.max_size_bytes = max_size_mb * 1024 * 1024
        self.default_ttl = default_ttl
        self.cache: Dict[str, CacheEntry] = {}
        self.access_order: deque = deque()
        self.current_size_bytes = 0
        self.stats = {
            "hits": 0,
            "misses": 0,
            "evictions": 0,
            "size_evictions": 0
        }
        self._lock = asyncio.Lock()

    def _generate_key(self, prefix: str, *args, **kwargs) -> str:
        """Generate cache key from arguments.

        The prefix, positional arguments and sorted keyword arguments are
        serialized to JSON (non-JSON values via ``str``) and hashed, so the
        keyword order used by the caller does not affect the key.

        Returns:
            A 32-character hexadecimal MD5 digest.
        """
        key_data = {
            "prefix": prefix,
            "args": args,
            "kwargs": sorted(kwargs.items()) if kwargs else {}
        }
        key_str = json.dumps(key_data, sort_keys=True, default=str)
        # MD5 serves only as a compact fingerprint here, not as a security measure.
        return hashlib.md5(key_str.encode()).hexdigest()

    def _estimate_size(self, data: Any) -> int:
        """Estimate the size of data in bytes.

        Strings count their UTF-8 length, bytes their length, lists and dicts
        the length of their JSON form, anything else the length of ``str(data)``.

        Returns:
            The estimate, or 1024 when the value cannot be measured.
        """
        try:
            if isinstance(data, str):
                return len(data.encode())
            if isinstance(data, bytes):
                return len(data)
            if isinstance(data, (list, dict)):
                return len(json.dumps(data, default=str).encode())
            return len(str(data).encode())
        except Exception:
            # Deliberately best-effort: an unmeasurable value must not break caching.
            return 1024  # Default estimate

    def _touch(self, key: str) -> None:
        """Mark ``key`` as the most recently used entry."""
        try:
            self.access_order.remove(key)
        except ValueError:
            # Key was not tracked yet; it is simply appended below.
            pass
        self.access_order.append(key)

    def _discard(self, key: str) -> None:
        """Drop ``key`` from the cache, the LRU order and the size accounting."""
        entry = self.cache.pop(key, None)
        if entry is not None:
            self.current_size_bytes -= entry.size_bytes
        try:
            self.access_order.remove(key)
        except ValueError:
            pass

    async def get(self, key: str) -> Optional[Any]:
        """Get item from cache.

        Returns:
            The cached data, or ``None`` when the key is absent or expired.
            An expired entry is removed and counted as a miss.
        """
        async with self._lock:
            current_time = time.time()
            entry = self.cache.get(key)
            if entry is None:
                self.stats["misses"] += 1
                return None

            # Check if expired
            if current_time - entry.timestamp > entry.ttl:
                self._discard(key)
                self.stats["misses"] += 1
                return None

            # Update access statistics and move to the most-recent end
            entry.access_count += 1
            entry.last_access = current_time
            self._touch(key)
            self.stats["hits"] += 1
            return entry.data

    async def set(self, key: str, data: Any, ttl: Optional[int] = None) -> None:
        """Set item in cache.

        Args:
            key: Cache key.
            data: Value to store.
            ttl: Lifetime in seconds; ``None`` selects ``default_ttl``. An
                explicit ``0`` is honored and makes the entry expire right away.

        Any previous value for ``key`` is dropped first, so its bytes are not
        counted when deciding how much to evict. A value larger than the whole
        cache is rejected with a warning and leaves the other entries untouched.
        """
        async with self._lock:
            current_time = time.time()
            effective_ttl = self.default_ttl if ttl is None else ttl
            size_bytes = self._estimate_size(data)

            # Remove the old value up front: it must neither be served again
            # nor inflate the amount of space we think we have to free.
            self._discard(key)

            # Reject oversized items before evicting anything on their behalf.
            if size_bytes > self.max_size_bytes:
                logger.warning("Item too large for cache", key=key, size_mb=size_bytes/1024/1024)
                return

            # Evict least recently used entries until the new item fits.
            while (self.current_size_bytes + size_bytes > self.max_size_bytes
                   and self.access_order):
                await self._evict_lru()

            self.cache[key] = CacheEntry(
                data=data,
                timestamp=current_time,
                ttl=effective_ttl,
                access_count=1,
                last_access=current_time,
                size_bytes=size_bytes
            )
            self.access_order.append(key)
            self.current_size_bytes += size_bytes

    async def _evict_lru(self) -> None:
        """Evict least recently used item.

        Does not take the lock itself; ``set`` calls it while holding the lock.
        """
        if not self.access_order:
            return
        lru_key = self.access_order.popleft()
        entry = self.cache.pop(lru_key, None)
        if entry is not None:
            self.current_size_bytes -= entry.size_bytes
            self.stats["evictions"] += 1
            self.stats["size_evictions"] += 1

    async def clear(self) -> None:
        """Clear all cache entries. Hit/miss/eviction counters are kept."""
        async with self._lock:
            self.cache.clear()
            self.access_order.clear()
            self.current_size_bytes = 0

    async def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Counters, hit rate in percent (0 when nothing was requested yet),
            entry count, and current and maximum size in megabytes.
        """
        async with self._lock:
            total_requests = self.stats["hits"] + self.stats["misses"]
            hit_rate = (self.stats["hits"] / total_requests * 100) if total_requests > 0 else 0
            return {
                "hits": self.stats["hits"],
                "misses": self.stats["misses"],
                "hit_rate_percent": round(hit_rate, 2),
                "evictions": self.stats["evictions"],
                "size_evictions": self.stats["size_evictions"],
                "entries_count": len(self.cache),
                "size_mb": round(self.current_size_bytes / 1024 / 1024, 2),
                "max_size_mb": round(self.max_size_bytes / 1024 / 1024, 2)
            }
class RateLimiter:
    """Fixed-window rate limiter plus a counter of in-flight requests.

    Per-key state lives in ``self.limits`` under ``"<identifier>:<limit_type>"``.
    Every public coroutine runs under one ``asyncio.Lock``. Note that
    ``check_concurrent_limit`` and ``acquire_request_slot`` take that lock
    separately, so calling one after the other is not a single atomic step.
    """

    # Window length in seconds for the limit types that have a default
    # maximum in ``global_limits``.
    _WINDOW_SECONDS = {
        "queries_per_minute": 60,
        "queries_per_hour": 3600,
        "projects_per_hour": 3600,
    }
    # Fallback window and maximum for any other limit type.
    _DEFAULT_WINDOW_SECONDS = 60
    _DEFAULT_MAX_REQUESTS = 60

    def __init__(self):
        """Initialize rate limiter."""
        self.limits: Dict[str, RateLimitEntry] = defaultdict(RateLimitEntry)
        self.global_limits = {
            "queries_per_minute": 100,
            "queries_per_hour": 1000,
            "projects_per_hour": 50,
            "concurrent_requests": 10
        }
        self.user_limits = {}
        self.current_requests = 0
        self._lock = asyncio.Lock()

    def _resolve_limit(self, limit_type: str, custom_limit: Optional[int]) -> tuple:
        """Return ``(window_seconds, max_requests)`` for ``limit_type``.

        ``custom_limit`` wins whenever it is not ``None`` -- including ``0``,
        which therefore rejects every request instead of silently falling
        back to the default maximum.
        """
        if limit_type in self._WINDOW_SECONDS:
            window_size = self._WINDOW_SECONDS[limit_type]
            default_max = self.global_limits[limit_type]
        else:  # default
            window_size = self._DEFAULT_WINDOW_SECONDS
            default_max = self._DEFAULT_MAX_REQUESTS
        max_requests = default_max if custom_limit is None else custom_limit
        return window_size, max_requests

    async def check_rate_limit(
        self,
        identifier: str,
        limit_type: str = "default",
        custom_limit: Optional[int] = None
    ) -> bool:
        """Check if request is within rate limits.

        An allowed request is counted against the current window.

        Args:
            identifier: Unique identifier (user, IP, etc.)
            limit_type: Type of limit to check
            custom_limit: Override default limit

        Returns:
            True if request is allowed, False if rate limited
        """
        async with self._lock:
            current_time = time.time()
            key = f"{identifier}:{limit_type}"
            entry = self.limits[key]

            # Check if currently blocked
            if entry.blocked_until > current_time:
                logger.warning(
                    "Request blocked by rate limiter",
                    identifier=identifier,
                    limit_type=limit_type,
                    blocked_until=entry.blocked_until
                )
                return False

            window_size, max_requests = self._resolve_limit(limit_type, custom_limit)

            # Reset window if needed
            if current_time - entry.window_start >= window_size:
                entry.count = 0
                entry.window_start = current_time

            # Check if limit exceeded
            if entry.count >= max_requests:
                # Block for remaining window time
                entry.blocked_until = entry.window_start + window_size
                logger.warning(
                    "Rate limit exceeded",
                    identifier=identifier,
                    limit_type=limit_type,
                    count=entry.count,
                    limit=max_requests
                )
                return False

            # Increment counter
            entry.count += 1
            return True

    async def check_concurrent_limit(self) -> bool:
        """Check concurrent request limit.

        Returns:
            False when ``current_requests`` has reached the configured
            ``concurrent_requests`` maximum, True otherwise. No slot is taken.
        """
        async with self._lock:
            limit = self.global_limits["concurrent_requests"]
            if self.current_requests >= limit:
                logger.warning(
                    "Concurrent request limit exceeded",
                    current=self.current_requests,
                    limit=limit
                )
                return False
            return True

    async def acquire_request_slot(self) -> None:
        """Acquire a concurrent request slot (unconditionally increments)."""
        async with self._lock:
            self.current_requests += 1

    async def release_request_slot(self) -> None:
        """Release a concurrent request slot; the counter never drops below 0."""
        async with self._lock:
            self.current_requests = max(0, self.current_requests - 1)

    async def get_stats(self) -> Dict[str, Any]:
        """Get rate limiter statistics.

        Returns:
            A snapshot with a copy of the global limits, the in-flight request
            count, every key that has a non-zero count or an active block, and
            the number of currently blocked keys.
        """
        async with self._lock:
            current_time = time.time()
            active_limits = {}
            blocked_count = 0
            for key, entry in self.limits.items():
                is_blocked = entry.blocked_until > current_time
                if entry.count > 0 or is_blocked:
                    active_limits[key] = {
                        "count": entry.count,
                        "window_start": entry.window_start,
                        "blocked_until": entry.blocked_until,
                        "is_blocked": is_blocked
                    }
                if is_blocked:
                    blocked_count += 1
            return {
                # Copy, so callers cannot alter the limiter's configuration
                # by mutating the returned statistics.
                "global_limits": dict(self.global_limits),
                "current_requests": self.current_requests,
                "active_limits": active_limits,
                "blocked_count": blocked_count
            }
def cached(ttl: int = 300, cache_key_prefix: str = "default"):
    """Decorator for caching the results of async methods.

    The decorated method's object must expose a ``cache`` attribute providing
    ``_generate_key``, ``get`` and ``set`` as used below. The cache key is
    built from the prefix, the method name and the call arguments (``self``
    is not part of the key).

    A ``None`` result is never stored: the lookup reports a miss as ``None``,
    so a stored ``None`` could not be told apart from a miss and would never
    be served.

    Args:
        ttl: Time to live in seconds
        cache_key_prefix: Prefix for cache keys
    """
    def decorator(func: Callable):
        key_prefix = f"{cache_key_prefix}:{func.__name__}"

        # Preserve the wrapped method's name, docstring and signature metadata.
        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            # Generate cache key
            cache_key = self.cache._generate_key(key_prefix, *args, **kwargs)

            # Try to get from cache
            cached_result = await self.cache.get(cache_key)
            if cached_result is not None:
                logger.debug("Cache hit", function=func.__name__, key=cache_key[:8])
                return cached_result

            # Execute function
            result = await func(self, *args, **kwargs)
            if result is None:
                # Nothing worth storing; see the docstring.
                return result

            # Cache result
            await self.cache.set(cache_key, result, ttl)
            logger.debug("Cache set", function=func.__name__, key=cache_key[:8])
            return result

        return wrapper

    return decorator
def rate_limited(
    limit_type: str = "default",
    custom_limit: Optional[int] = None,
    identifier_func: Optional[Callable] = None
):
    """Decorator for rate limiting async method calls.

    The decorated method's object must expose a ``rate_limiter`` attribute
    providing ``check_rate_limit``, ``check_concurrent_limit``,
    ``acquire_request_slot`` and ``release_request_slot``. The slot taken
    before the call is always released, even when the method raises.

    NOTE(review): the concurrency check and the slot acquisition are two
    separate awaits, so the cap looks like it can be overshot when several
    calls pass the check at once -- confirm whether that matters here.

    Args:
        limit_type: Type of rate limit
        custom_limit: Override default limit
        identifier_func: Function to generate identifier; called with the
            method's arguments (without ``self``). Without it every call
            shares the identifier ``"default"``.

    Raises:
        GCPServiceError: When the rate limit or the concurrency limit rejects
            the call.
    """
    def _rejection(message: str) -> Exception:
        # Imported at call time, as the original code did -- presumably to
        # avoid an import cycle with the exceptions module; confirm.
        from .exceptions import GCPServiceError
        return GCPServiceError(message)

    def decorator(func: Callable):
        # Preserve the wrapped method's name, docstring and signature metadata.
        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            # Generate identifier
            identifier = identifier_func(*args, **kwargs) if identifier_func else "default"

            # Resolve the limiter once so the slot is released on the same
            # object it was acquired from.
            limiter = self.rate_limiter

            # Check rate limit
            if not await limiter.check_rate_limit(identifier, limit_type, custom_limit):
                raise _rejection(
                    f"Rate limit exceeded for {limit_type}. Please try again later."
                )

            # Check concurrent limit
            if not await limiter.check_concurrent_limit():
                raise _rejection(
                    "Too many concurrent requests. Please try again later."
                )

            # Acquire slot and execute
            await limiter.acquire_request_slot()
            try:
                return await func(self, *args, **kwargs)
            finally:
                await limiter.release_request_slot()

        return wrapper

    return decorator
class CacheManager:
    """Registry of named EnterpriseCache instances plus one RateLimiter."""

    def __init__(self):
        """Start with an empty cache registry and a fresh rate limiter."""
        self.caches = {}
        self.rate_limiter = RateLimiter()

    def get_cache(self, name: str, **kwargs) -> EnterpriseCache:
        """Return the cache registered under ``name``, creating it on first use.

        ``kwargs`` are forwarded to ``EnterpriseCache`` only when the cache is
        created; they are ignored for a name that already exists.
        """
        if name in self.caches:
            return self.caches[name]
        created = EnterpriseCache(**kwargs)
        self.caches[name] = created
        return created

    async def get_all_stats(self) -> Dict[str, Any]:
        """Collect the rate limiter's statistics and those of every cache.

        Returns:
            ``{"rate_limiter": <limiter stats>, "caches": {name: <cache stats>}}``
        """
        limiter_stats = await self.rate_limiter.get_stats()
        per_cache = {
            name: await cache.get_stats()
            for name, cache in self.caches.items()
        }
        return {
            "rate_limiter": limiter_stats,
            "caches": per_cache
        }

    async def clear_all(self) -> None:
        """Empty every registered cache, one after another."""
        for registered in self.caches.values():
            await registered.clear()
# Global cache manager instance, created once when this module is imported.
cache_manager = CacheManager()