"""缓存管理器 - 用于优化查询性能."""
import time
from typing import Any, Dict, Optional
from datetime import datetime, timedelta
from .logger import default_logger as logger
class CacheManager:
"""简单的内存缓存管理器."""
def __init__(self, default_ttl: int = 300):
"""
初始化缓存管理器.
Args:
default_ttl: 默认缓存过期时间(秒),默认 5 分钟
"""
self._cache: Dict[str, Dict[str, Any]] = {}
self.default_ttl = default_ttl
# 在 MCP 模式下禁用启动日志,避免干扰 JSON-RPC 通信
# logger.info(f"缓存管理器已初始化,默认 TTL: {default_ttl} 秒")
def get(self, key: str) -> Optional[Any]:
"""
获取缓存值.
Args:
key: 缓存键
Returns:
缓存值,如果不存在或已过期则返回 None
"""
if key not in self._cache:
return None
entry = self._cache[key]
expire_time = entry.get("expire_time")
# 检查是否过期
if expire_time and datetime.now() > expire_time:
del self._cache[key]
logger.debug(f"缓存键 '{key}' 已过期,已删除")
return None
logger.debug(f"缓存命中: {key}")
return entry.get("value")
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
"""
设置缓存值.
Args:
key: 缓存键
value: 缓存值
ttl: 过期时间(秒),如果为 None 则使用默认值
"""
if ttl is None:
ttl = self.default_ttl
expire_time = datetime.now() + timedelta(seconds=ttl)
self._cache[key] = {
"value": value,
"expire_time": expire_time,
"created_at": datetime.now()
}
logger.debug(f"缓存已设置: {key}, TTL: {ttl} 秒")
def delete(self, key: str) -> None:
"""删除缓存键."""
if key in self._cache:
del self._cache[key]
logger.debug(f"缓存已删除: {key}")
def clear(self) -> None:
"""清空所有缓存."""
count = len(self._cache)
self._cache.clear()
logger.info(f"已清空所有缓存(共 {count} 个条目)")
def invalidate_pattern(self, pattern: str) -> None:
"""
根据模式删除匹配的缓存键.
Args:
pattern: 模式字符串(简单的前缀匹配)
"""
keys_to_delete = [key for key in self._cache.keys() if key.startswith(pattern)]
for key in keys_to_delete:
del self._cache[key]
logger.info(f"已删除匹配模式 '{pattern}' 的缓存(共 {len(keys_to_delete)} 个)")
def get_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息."""
now = datetime.now()
valid_entries = 0
expired_entries = 0
for entry in self._cache.values():
expire_time = entry.get("expire_time")
if expire_time and now > expire_time:
expired_entries += 1
else:
valid_entries += 1
return {
"total_entries": len(self._cache),
"valid_entries": valid_entries,
"expired_entries": expired_entries,
"default_ttl": self.default_ttl
}
def cleanup_expired(self) -> int:
"""
清理过期的缓存条目.
Returns:
清理的条目数量
"""
now = datetime.now()
keys_to_delete = []
for key, entry in self._cache.items():
expire_time = entry.get("expire_time")
if expire_time and now > expire_time:
keys_to_delete.append(key)
for key in keys_to_delete:
del self._cache[key]
if keys_to_delete:
logger.info(f"已清理 {len(keys_to_delete)} 个过期缓存条目")
return len(keys_to_delete)
# 全局缓存实例
_cache_manager: Optional[CacheManager] = None
def get_cache_manager() -> CacheManager:
"""获取全局缓存管理器实例."""
global _cache_manager
if _cache_manager is None:
_cache_manager = CacheManager()
return _cache_manager