"""
Caching utilities for GeoSight MCP Server.
Provides Redis-based caching with a fallback to an in-memory LRU cache.
"""
import asyncio
import functools
import hashlib
import json
import pickle
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Awaitable, Callable, Generic, Optional, TypeVar
import structlog
logger = structlog.get_logger(__name__)
T = TypeVar("T")
class CacheBackend(ABC, Generic[T]):
"""Abstract base class for cache backends."""
@abstractmethod
async def get(self, key: str) -> Optional[T]:
"""Get value from cache."""
pass
@abstractmethod
async def set(self, key: str, value: T, ttl: Optional[int] = None) -> None:
"""Set value in cache."""
pass
@abstractmethod
async def delete(self, key: str) -> None:
"""Delete value from cache."""
pass
@abstractmethod
async def exists(self, key: str) -> bool:
"""Check if key exists in cache."""
pass
@abstractmethod
async def clear(self) -> None:
"""Clear all cache entries."""
pass
class InMemoryCache(CacheBackend[T]):
"""
Simple in-memory LRU cache with TTL support.
Suitable for development and single-instance deployments.
"""
def __init__(self, max_size: int = 1000, default_ttl: int = 3600):
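        # Each entry maps key -> {"value": ..., "expires_at": unix timestamp or None}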
self._cache: OrderedDict[str, dict] = OrderedDict()
self._max_size = max_size
self._default_ttl = default_ttl
self._lock = asyncio.Lock()
async def get(self, key: str) -> Optional[T]:
async with self._lock:
if key not in self._cache:
return None
entry = self._cache[key]
            # Check TTL; expired entries are evicted lazily on read
            if entry["expires_at"] is not None and time.time() > entry["expires_at"]:
                del self._cache[key]
                return None
# Move to end (most recently used)
self._cache.move_to_end(key)
return entry["value"]
async def set(self, key: str, value: T, ttl: Optional[int] = None) -> None:
async with self._lock:
            ttl = ttl if ttl is not None else self._default_ttl
            # A non-positive TTL means the entry never expires
            expires_at = time.time() + ttl if ttl > 0 else None
self._cache[key] = {
"value": value,
"expires_at": expires_at,
}
# Move to end
self._cache.move_to_end(key)
# Evict oldest if over capacity
while len(self._cache) > self._max_size:
self._cache.popitem(last=False)
async def delete(self, key: str) -> None:
async with self._lock:
self._cache.pop(key, None)
async def exists(self, key: str) -> bool:
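        # Delegates to get(), which also refreshes the key's LRU position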
return await self.get(key) is not None
async def clear(self) -> None:
async with self._lock:
self._cache.clear()
@property
def size(self) -> int:
return len(self._cache)
class RedisCache(CacheBackend[T]):
"""
Redis-based cache backend.
Suitable for production and multi-instance deployments.
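    Values are serialized with pickle, so only cache data from trusted sources.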
"""
def __init__(
self,
url: str = "redis://localhost:6379/0",
prefix: str = "geosight:",
default_ttl: int = 3600,
):
self._url = url
self._prefix = prefix
self._default_ttl = default_ttl
self._redis = None
async def _get_client(self):
"""Lazy initialization of Redis client."""
if self._redis is None:
try:
import redis.asyncio as redis
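                # from_url() is lazy: no connection is made until the first command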
self._redis = redis.from_url(self._url, decode_responses=False)
            except ImportError:
                logger.warning("redis_not_installed")
                raise
return self._redis
def _make_key(self, key: str) -> str:
return f"{self._prefix}{key}"
async def get(self, key: str) -> Optional[T]:
try:
client = await self._get_client()
data = await client.get(self._make_key(key))
if data is None:
return None
return pickle.loads(data)
except Exception as e:
logger.error("cache_get_error", key=key, error=str(e))
return None
async def set(self, key: str, value: T, ttl: Optional[int] = None) -> None:
try:
client = await self._get_client()
ttl = ttl or self._default_ttl
data = pickle.dumps(value)
await client.setex(self._make_key(key), ttl, data)
except Exception as e:
logger.error("cache_set_error", key=key, error=str(e))
async def delete(self, key: str) -> None:
try:
client = await self._get_client()
await client.delete(self._make_key(key))
except Exception as e:
logger.error("cache_delete_error", key=key, error=str(e))
async def exists(self, key: str) -> bool:
try:
client = await self._get_client()
return await client.exists(self._make_key(key)) > 0
except Exception as e:
logger.error("cache_exists_error", key=key, error=str(e))
return False
async def clear(self) -> None:
try:
client = await self._get_client()
# Use SCAN to find all keys with prefix
cursor = 0
while True:
cursor, keys = await client.scan(cursor, match=f"{self._prefix}*")
if keys:
await client.delete(*keys)
if cursor == 0:
break
except Exception as e:
logger.error("cache_clear_error", error=str(e))
class Cache:
"""
High-level cache interface with automatic backend selection.
Tries Redis first, falls back to in-memory if unavailable.
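
    Example (a minimal sketch, assuming an async context; names are illustrative):

        cache = Cache(redis_url="redis://localhost:6379/0")
        value = await cache.get_or_set("ndvi:12.5:-3.2", compute_ndvi, ttl=300)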
"""
def __init__(
self,
redis_url: Optional[str] = None,
max_memory_items: int = 1000,
default_ttl: int = 3600,
):
self._default_ttl = default_ttl
self._backend: Optional[CacheBackend] = None
self._redis_url = redis_url
self._max_memory_items = max_memory_items
async def _get_backend(self) -> CacheBackend:
"""Get or initialize cache backend."""
if self._backend is not None:
return self._backend
# Try Redis first
if self._redis_url:
try:
backend = RedisCache(
url=self._redis_url,
default_ttl=self._default_ttl,
)
                # Verify connectivity; from_url alone does not open a connection
                client = await backend._get_client()
                await client.ping()
self._backend = backend
logger.info("cache_initialized", backend="redis")
return self._backend
except Exception as e:
logger.warning("redis_unavailable", error=str(e))
# Fall back to in-memory
self._backend = InMemoryCache(
max_size=self._max_memory_items,
default_ttl=self._default_ttl,
)
logger.info("cache_initialized", backend="memory")
return self._backend
async def get(self, key: str) -> Optional[Any]:
backend = await self._get_backend()
return await backend.get(key)
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
backend = await self._get_backend()
await backend.set(key, value, ttl)
async def delete(self, key: str) -> None:
backend = await self._get_backend()
await backend.delete(key)
async def get_or_set(
self,
key: str,
        factory: Callable[[], Awaitable[Any]],
ttl: Optional[int] = None,
) -> Any:
"""
Get value from cache or compute and store it.
Args:
key: Cache key
factory: Async callable to compute value if not cached
ttl: Optional TTL in seconds
Returns:
Cached or computed value
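
        Note:
            A factory result of None is indistinguishable from a miss and
            will be recomputed on every call.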
"""
value = await self.get(key)
if value is not None:
logger.debug("cache_hit", key=key)
return value
logger.debug("cache_miss", key=key)
value = await factory()
await self.set(key, value, ttl)
return value
def cache_key(*args, **kwargs) -> str:
"""
Generate a cache key from arguments.
    Uses an MD5 hash for short, deterministic keys (not a security measure).
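
    Example (illustrative arguments):

        key = cache_key(12.5, -3.2, layer="ndvi")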
"""
data = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True, default=str)
return hashlib.md5(data.encode()).hexdigest()
def cached(ttl: int = 3600, key_prefix: str = ""):
"""
Decorator for caching async function results.
Args:
ttl: Cache TTL in seconds
key_prefix: Optional prefix for cache keys
Usage:
@cached(ttl=300, key_prefix="ndvi")
async def calculate_ndvi(lat, lon):
...
"""
    def decorator(func):
        # Each decorated function gets its own private in-memory cache
        _cache = InMemoryCache(default_ttl=ttl)
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# Generate cache key
key = f"{key_prefix}:{cache_key(*args, **kwargs)}"
# Try cache
cached_value = await _cache.get(key)
if cached_value is not None:
return cached_value
# Compute and cache
result = await func(*args, **kwargs)
await _cache.set(key, result, ttl)
return result
# Expose cache for testing/management
wrapper._cache = _cache
return wrapper
return decorator
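

if __name__ == "__main__":
    # Minimal smoke test of the in-memory path (illustrative; keys and values are arbitrary).
    async def _demo() -> None:
        cache = Cache(max_memory_items=10, default_ttl=60)
        await cache.set("answer", 42)
        assert await cache.get("answer") == 42

        @cached(ttl=5, key_prefix="square")
        async def square(x: int) -> int:
            return x * x

        assert await square(3) == 9
        assert await square(3) == 9  # second call is served from the cache

    asyncio.run(_demo())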