"""Simple TTL cache for stock data."""
import os
import time
from typing import Any, Optional, TypeVar, Generic
# Type parameter for the payload held by a CacheEntry.
T = TypeVar("T")
def _get_env_int(key: str, default: int) -> int:
    """Read environment variable *key* as an integer.

    Args:
        key: Name of the environment variable.
        default: Value returned when the variable is unset or its value
            cannot be parsed by ``int()``.

    Returns:
        The parsed integer, or *default*.
    """
    raw = os.environ.get(key)
    if raw is None:
        return default
    try:
        parsed = int(raw)
    except ValueError:
        # Malformed configuration falls back to the default.
        parsed = default
    return parsed
class CacheEntry(Generic[T]):
    """A single cached value paired with its expiry deadline.

    The deadline comes from ``time.monotonic()``, which is not affected by
    system clock adjustments (NTP steps, manual changes). With the wall clock
    (``time.time()``), a backwards jump could keep an entry alive far beyond
    its TTL and a forward jump could expire it early.

    Args:
        value: The payload to cache.
        ttl: Lifetime in seconds. A TTL of 0 or less produces an entry that
            is already expired.
    """

    def __init__(self, value: T, ttl: float):
        self.value = value
        # Absolute deadline on the monotonic clock -- NOT a Unix timestamp,
        # so it must not be compared with time.time().
        self.expires_at = time.monotonic() + ttl

    def is_expired(self) -> bool:
        """Return True once the entry's TTL has elapsed.

        ``>=`` (rather than ``>``) makes the entry valid for exactly ``ttl``
        seconds. A non-positive TTL therefore never yields a cache hit, even
        when consecutive clock reads return the same value.
        """
        return time.monotonic() >= self.expires_at
class TTLCache:
    """In-memory time-to-live cache with a separate TTL per stock-data type.

    Values are stored under string keys built by ``_make_key`` as
    ``"<prefix>:<arg1>:<arg2>..."``. Expired entries are dropped when read.
    They are also swept periodically from ``set``, so keys that are written
    once and never read again (e.g. one-off history ranges or search
    queries) do not accumulate forever.

    ``get`` returns ``None`` on a miss, so a cached value of ``None`` cannot
    be told apart from a miss. The cache performs no locking.

    Args:
        quote_ttl: Seconds a quote stays cached.
        fundamentals_ttl: Seconds fundamentals stay cached.
        statements_ttl: Seconds financial statements stay cached.
        history_ttl: Seconds price history stays cached.
        search_ttl: Seconds search results stay cached.
    """

    # Sweep all expired entries after this many ``set`` calls. This bounds
    # growth from stale, never-re-read keys while keeping sweeps infrequent.
    _SWEEP_EVERY = 256

    def __init__(
        self,
        quote_ttl: int = 5,
        fundamentals_ttl: int = 3600,
        statements_ttl: int = 86400,
        history_ttl: int = 300,
        search_ttl: int = 600,
    ):
        self.quote_ttl = quote_ttl
        self.fundamentals_ttl = fundamentals_ttl
        self.statements_ttl = statements_ttl
        self.history_ttl = history_ttl
        self.search_ttl = search_ttl
        self._cache: dict[str, CacheEntry] = {}
        # Number of set() calls since the last periodic sweep.
        self._sets_since_sweep = 0

    def _make_key(self, prefix: str, *args) -> str:
        """Build a cache key as *prefix* followed by ``str(arg)`` for each arg, joined by ':'.

        NOTE(review): args are not escaped, so arguments that themselves
        contain ':' could in principle produce colliding keys.
        """
        return f"{prefix}:" + ":".join(str(arg) for arg in args)

    def get(self, key: str) -> Optional[Any]:
        """Return the value cached under *key*, or None if absent or expired.

        An expired entry is evicted as a side effect of the lookup.
        """
        entry = self._cache.get(key)
        if entry is None:
            return None
        if entry.is_expired():
            self._cache.pop(key, None)
            return None
        return entry.value

    def set(self, key: str, value: Any, ttl: float) -> None:
        """Store *value* under *key* for *ttl* seconds, replacing any prior entry.

        Every ``_SWEEP_EVERY`` calls, all expired entries are removed first.
        """
        self._sets_since_sweep += 1
        if self._sets_since_sweep >= self._SWEEP_EVERY:
            self._sets_since_sweep = 0
            self.cleanup()
        self._cache[key] = CacheEntry(value, ttl)

    def get_quote(self, symbol: str) -> Optional[Any]:
        """Return the cached quote for *symbol*, if fresh."""
        return self.get(self._make_key("quote", symbol))

    def set_quote(self, symbol: str, value: Any):
        """Cache a quote for *symbol* using ``quote_ttl``."""
        self.set(self._make_key("quote", symbol), value, self.quote_ttl)

    def get_fundamentals(self, symbol: str, period: str) -> Optional[Any]:
        """Return cached fundamentals for (*symbol*, *period*), if fresh."""
        return self.get(self._make_key("fundamentals", symbol, period))

    def set_fundamentals(self, symbol: str, period: str, value: Any):
        """Cache fundamentals for (*symbol*, *period*) using ``fundamentals_ttl``."""
        self.set(
            self._make_key("fundamentals", symbol, period),
            value,
            self.fundamentals_ttl,
        )

    def get_statements(self, symbol: str, statement: str, period: str) -> Optional[Any]:
        """Return a cached financial statement, if fresh."""
        return self.get(self._make_key("statements", symbol, statement, period))

    def set_statements(self, symbol: str, statement: str, period: str, value: Any):
        """Cache a financial statement using ``statements_ttl``."""
        self.set(
            self._make_key("statements", symbol, statement, period),
            value,
            self.statements_ttl,
        )

    def get_history(self, symbol: str, start: str, end: str, interval: str) -> Optional[Any]:
        """Return cached price history for the given range/interval, if fresh."""
        return self.get(self._make_key("history", symbol, start, end, interval))

    def set_history(self, symbol: str, start: str, end: str, interval: str, value: Any):
        """Cache price history for the given range/interval using ``history_ttl``."""
        self.set(
            self._make_key("history", symbol, start, end, interval),
            value,
            self.history_ttl,
        )

    def get_search(self, query: str, limit: int) -> Optional[Any]:
        """Return cached search results for (*query*, *limit*), if fresh."""
        return self.get(self._make_key("search", query, limit))

    def set_search(self, query: str, limit: int, value: Any):
        """Cache search results for (*query*, *limit*) using ``search_ttl``."""
        self.set(self._make_key("search", query, limit), value, self.search_ttl)

    def clear(self):
        """Remove every entry, expired or not."""
        self._cache.clear()
        self._sets_since_sweep = 0

    def cleanup(self) -> int:
        """Remove all expired entries.

        Returns:
            The number of entries removed.
        """
        # Collect first: a dict must not change size while being iterated.
        expired_keys = [key for key, entry in self._cache.items() if entry.is_expired()]
        for key in expired_keys:
            del self._cache[key]
        return len(expired_keys)
# Process-wide cache, created lazily by get_cache().
_global_cache: Optional[TTLCache] = None


def get_cache() -> TTLCache:
    """Return the shared TTLCache, constructing it on the first call.

    TTLs are read once, at construction time, from the
    ``STOCK_MCP_CACHE_TTL_*`` environment variables via ``_get_env_int``.
    The fallback values are listed in the table below.
    """
    global _global_cache
    if _global_cache is not None:
        return _global_cache

    # constructor keyword -> (environment variable, fallback seconds)
    ttl_settings = {
        "quote_ttl": ("STOCK_MCP_CACHE_TTL_QUOTE", 5),
        "fundamentals_ttl": ("STOCK_MCP_CACHE_TTL_FUNDAMENTALS", 3600),
        "statements_ttl": ("STOCK_MCP_CACHE_TTL_STATEMENTS", 86400),
        "history_ttl": ("STOCK_MCP_CACHE_TTL_HISTORY", 300),
        "search_ttl": ("STOCK_MCP_CACHE_TTL_SEARCH", 600),
    }
    _global_cache = TTLCache(
        **{
            keyword: _get_env_int(env_var, fallback)
            for keyword, (env_var, fallback) in ttl_settings.items()
        }
    )
    return _global_cache