"""TTL cache with optional disk persistence."""
import json
import pickle
from collections import OrderedDict
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional
class TTLCache:
    """Time-to-live cache with LRU eviction and optional disk persistence.

    An entry expires ``ttl_seconds`` after it was last written with
    :meth:`set`. When the cache already holds ``max_size`` entries, inserting
    a new key first discards expired entries and then, if still full, evicts
    the least recently used entry. Both :meth:`get` hits and :meth:`set` count
    as a use.

    If ``cache_dir`` is given, the cache is loaded from ``cache.pkl`` and
    ``cache_meta.json`` in that directory on construction. It is written back
    every ``save_every`` calls to :meth:`set`, on :meth:`clear`, and when the
    object is garbage-collected. Persistence is best-effort: I/O and
    (un)pickling errors never propagate to the caller.

    Security note: the cache file is read with :mod:`pickle`, which can
    execute arbitrary code while loading. Only point ``cache_dir`` at a
    directory that untrusted users cannot write to.
    """

    _CACHE_FILENAME = "cache.pkl"
    _META_FILENAME = "cache_meta.json"

    def __init__(
        self,
        ttl_seconds: int = 3600,
        max_size: int = 100,
        cache_dir: Optional[Path] = None,
        save_every: int = 10,
    ):
        """Create the cache, loading persisted state if ``cache_dir`` is set.

        Args:
            ttl_seconds: Lifetime of an entry in seconds, measured from its
                last :meth:`set`.
            max_size: Maximum number of entries; must be at least 1.
            cache_dir: Directory used for persistence; created if missing.
                A falsy value disables persistence.
            save_every: Persist to disk after this many :meth:`set` calls;
                must be at least 1.

        Raises:
            ValueError: If ``max_size`` or ``save_every`` is less than 1.
        """
        if max_size < 1:
            # Eviction needs at least one slot; a non-positive capacity would
            # otherwise make the first set() fail with StopIteration.
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        if save_every < 1:
            raise ValueError(f"save_every must be >= 1, got {save_every}")
        self.ttl = timedelta(seconds=ttl_seconds)
        self.max_size = max_size
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.save_every = save_every
        # Insertion order doubles as LRU order: the front is least recently used.
        self._cache: OrderedDict[str, Any] = OrderedDict()
        # Time of the last set() for each key in _cache.
        self._timestamps: dict[str, datetime] = {}
        self._writes_since_save = 0
        self.hits = 0
        self.misses = 0
        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self._load_from_disk()

    def _load_from_disk(self):
        """Restore entries, timestamps and hit/miss counters from ``cache_dir``.

        Nothing is loaded unless both files exist. Keys present in only one
        of the two files are dropped; that can happen after a crash between
        the two writes in :meth:`_save_to_disk`. Expired entries are then
        purged and the cache is trimmed to ``max_size``. If reading or
        parsing fails for any reason, the cache starts empty.
        """
        cache_file = self.cache_dir / self._CACHE_FILENAME
        meta_file = self.cache_dir / self._META_FILENAME
        if not (cache_file.exists() and meta_file.exists()):
            return
        try:
            # SECURITY: pickle.load can run code embedded in the file; the
            # directory must be trusted (see class docstring).
            with open(cache_file, "rb") as f:
                entries = OrderedDict(pickle.load(f))
            with open(meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)
            timestamps = {
                k: datetime.fromisoformat(v)
                for k, v in meta.get("timestamps", {}).items()
            }
            hits = int(meta.get("hits", 0))
            misses = int(meta.get("misses", 0))
        except Exception:
            # Corrupt, truncated or incompatible files: start fresh rather
            # than fail construction.
            self._cache = OrderedDict()
            self._timestamps = {}
            self.hits = 0
            self.misses = 0
            return
        # Keep only keys with both a value and a timestamp, so no entry can
        # escape expiry or occupy a slot while being unreadable via get().
        self._cache = OrderedDict(
            (k, v) for k, v in entries.items() if k in timestamps
        )
        self._timestamps = {k: timestamps[k] for k in self._cache}
        self.hits = hits
        self.misses = misses
        self._cleanup_expired()
        # max_size may be smaller than when the files were written.
        while len(self._cache) > self.max_size:
            self._evict_lru()

    @staticmethod
    def _atomic_write(path: Path, data: bytes):
        """Write ``data`` to ``path`` via a temporary sibling and a rename.

        The rename (``os.replace`` semantics) swaps the file in one step on
        POSIX, so readers never see a half-written file.
        """
        tmp = path.with_name(path.name + ".tmp")
        try:
            tmp.write_bytes(data)
            tmp.replace(path)
        finally:
            # No-op after a successful rename; removes leftovers on failure.
            tmp.unlink(missing_ok=True)

    def _save_to_disk(self):
        """Persist entries and metadata to ``cache_dir`` (best-effort).

        Does nothing when persistence is disabled. Any error is swallowed on
        purpose: persistence must never break cache users, and this method
        also runs from ``__del__``.
        """
        if not self.cache_dir:
            return
        cache_file = self.cache_dir / self._CACHE_FILENAME
        meta_file = self.cache_dir / self._META_FILENAME
        try:
            meta = {
                "timestamps": {k: v.isoformat() for k, v in self._timestamps.items()},
                "hits": self.hits,
                "misses": self.misses,
            }
            self._atomic_write(cache_file, pickle.dumps(self._cache))
            self._atomic_write(meta_file, json.dumps(meta).encode("utf-8"))
        except Exception:
            # Best-effort: unpicklable values, full disk, removed directory...
            pass

    def _cleanup_expired(self):
        """Drop every entry whose last write is older than ``ttl``."""
        now = datetime.now()
        expired = [key for key, ts in self._timestamps.items() if now - ts > self.ttl]
        for key in expired:
            self._cache.pop(key, None)
            self._timestamps.pop(key, None)

    def _evict_lru(self):
        """Remove the least recently used entry (the front of ``_cache``)."""
        oldest_key, _ = self._cache.popitem(last=False)
        self._timestamps.pop(oldest_key, None)

    def get(self, key: str) -> Optional[Any]:
        """Return the value cached under ``key``, or ``None`` if absent or expired.

        Expired entries are purged from the whole cache first. A hit marks
        ``key`` as most recently used and increments ``hits``; otherwise
        ``misses`` is incremented. A stored ``None`` value is therefore
        indistinguishable from a miss by return value, though it counts as a hit.
        """
        self._cleanup_expired()
        if key in self._cache and key in self._timestamps:
            self.hits += 1
            self._cache.move_to_end(key)
            return self._cache[key]
        self.misses += 1
        return None

    def set(self, key: str, value: Any):
        """Store ``value`` under ``key``, stamped with the current time.

        Re-setting an existing key refreshes its TTL and marks it most
        recently used. Inserting a new key into a full cache first purges
        expired entries, then evicts least recently used entries as needed.
        Every ``save_every``-th call persists the cache to disk.
        """
        if key not in self._cache and len(self._cache) >= self.max_size:
            # Reclaim expired slots before sacrificing a still-valid entry.
            self._cleanup_expired()
            while len(self._cache) >= self.max_size:
                self._evict_lru()
        self._cache[key] = value
        self._timestamps[key] = datetime.now()
        self._cache.move_to_end(key)
        # Throttle saves by number of writes, not by size: once the cache is
        # full its size stops changing, so a size-based check would save on
        # every write (or never).
        self._writes_since_save += 1
        if self._writes_since_save >= self.save_every:
            self._writes_since_save = 0
            self._save_to_disk()

    def clear(self):
        """Remove all entries, reset hit/miss counters and persist the empty state."""
        self._cache.clear()
        self._timestamps.clear()
        self.hits = 0
        self.misses = 0
        self._writes_since_save = 0
        self._save_to_disk()

    def stats(self) -> dict:
        """Return size, capacity, hits, misses, hit rate (percent) and TTL.

        Expired entries are purged first, so ``size`` counts only live entries.
        """
        self._cleanup_expired()
        total_requests = self.hits + self.misses
        hit_rate = (self.hits / total_requests * 100) if total_requests > 0 else 0
        return {
            "size": len(self._cache),
            "max_size": self.max_size,
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": round(hit_rate, 2),
            "ttl_seconds": int(self.ttl.total_seconds()),
        }

    def __del__(self):
        """Persist the cache when the object is garbage-collected (best-effort)."""
        # __init__ may have raised before these attributes were assigned.
        if not hasattr(self, "_timestamps") or not getattr(self, "cache_dir", None):
            return
        self._save_to_disk()