# cache_utils.py
#!/usr/bin/env python3
"""
轻量级缓存工具
基于Python内置functools.lru_cache和自定义缓存实现
"""
import time
import hashlib
from functools import lru_cache, wraps
from typing import Dict, Any, List, Optional, Callable, Tuple
from collections import OrderedDict
class MemorySearchCache:
    """LRU + TTL cache dedicated to memory-search results.

    Entries are keyed by an MD5 digest of the normalized query plus every
    search flag, and are dropped either when they outlive ``ttl_seconds``
    or when the cache exceeds ``max_size`` (least-recently-used first).
    """

    def __init__(self, max_size: int = 1000, ttl_seconds: int = 300):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cache: OrderedDict = OrderedDict()   # key -> cached result list, in LRU order
        self.timestamps: Dict[str, float] = {}    # key -> insertion time, for TTL checks
        self.hit_count = 0
        self.miss_count = 0

    def _generate_key(self, query: str, top_k: int, use_reranker: bool,
                      use_feedback_boost: bool, use_time_decay: bool) -> str:
        """Build a cache key from the normalized query and all search flags."""
        # Lower-case + strip so trivially different spellings share one entry.
        key_data = f"{query.lower().strip()}|{top_k}|{use_reranker}|{use_feedback_boost}|{use_time_decay}"
        return hashlib.md5(key_data.encode()).hexdigest()

    def get(self, query: str, top_k: int = 5, use_reranker: bool = True,
            use_feedback_boost: bool = True, use_time_decay: bool = True) -> Optional[List[Dict[str, Any]]]:
        """Return the cached results for this query/flags combination, or None.

        A hit refreshes the entry's LRU position; an expired entry is
        deleted and counted as a miss.
        """
        key = self._generate_key(query, top_k, use_reranker, use_feedback_boost, use_time_decay)
        if key in self.cache:
            if time.time() - self.timestamps.get(key, 0) < self.ttl_seconds:
                self.cache.move_to_end(key)  # refresh LRU position
                self.hit_count += 1
                return self.cache[key]
            # Expired: drop the stale entry and fall through to a miss.
            del self.cache[key]
            del self.timestamps[key]
        self.miss_count += 1
        return None

    def put(self, query: str, results: List[Dict[str, Any]], top_k: int = 5,
            use_reranker: bool = True, use_feedback_boost: bool = True,
            use_time_decay: bool = True):
        """Store search results, evicting the LRU entry if the cache is full."""
        key = self._generate_key(query, top_k, use_reranker, use_feedback_boost, use_time_decay)
        # Only evict when inserting a genuinely new key. Bug fix: the old
        # code evicted the LRU entry even when merely replacing an existing
        # key, silently dropping an unrelated live entry for no size gain.
        if key not in self.cache and len(self.cache) >= self.max_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
            del self.timestamps[oldest_key]
        # Shallow copy guards against the caller mutating the list object
        # itself; NOTE: the dicts inside are still shared with the caller
        # (the original comment claimed a deep copy, which this is not).
        self.cache[key] = results.copy()
        self.cache.move_to_end(key)  # a replacement also counts as recent use
        self.timestamps[key] = time.time()

    def clear(self):
        """Empty the cache and reset the hit/miss counters."""
        self.cache.clear()
        self.timestamps.clear()
        self.hit_count = 0
        self.miss_count = 0

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of size, hit/miss counts, hit rate and TTL."""
        total_requests = self.hit_count + self.miss_count
        hit_rate = self.hit_count / total_requests if total_requests > 0 else 0
        return {
            "size": len(self.cache),
            "max_size": self.max_size,
            "hit_count": self.hit_count,
            "miss_count": self.miss_count,
            "hit_rate": hit_rate,
            "ttl_seconds": self.ttl_seconds
        }

    def print_stats(self):
        """Print a human-readable summary of get_stats()."""
        stats = self.get_stats()
        print(f"📊 搜索缓存统计:")
        print(f" 缓存大小: {stats['size']}/{stats['max_size']}")
        print(f" 命中次数: {stats['hit_count']}")
        print(f" 未命中次数: {stats['miss_count']}")
        print(f" 命中率: {stats['hit_rate']:.2%}")
        print(f" TTL: {stats['ttl_seconds']}秒")
class EmbeddingCache:
    """LRU cache mapping text to its embedding vector (no TTL)."""

    def __init__(self, max_size: int = 5000):
        self.max_size = max_size
        self.cache: OrderedDict = OrderedDict()  # md5(text) -> embedding, in LRU order
        self.hit_count = 0
        self.miss_count = 0

    def _generate_key(self, text: str) -> str:
        """Hash the raw text into a fixed-size cache key."""
        return hashlib.md5(text.encode()).hexdigest()

    def get(self, text: str) -> Optional[List[float]]:
        """Return the cached embedding for *text*, or None on a miss."""
        key = self._generate_key(text)
        if key in self.cache:
            self.cache.move_to_end(key)  # refresh LRU position
            self.hit_count += 1
            return self.cache[key]
        self.miss_count += 1
        return None

    def put(self, text: str, embedding: List[float]):
        """Store an embedding, evicting the LRU entry if the cache is full."""
        key = self._generate_key(text)
        # Only evict for genuinely new keys. Bug fix: replacing an existing
        # key used to evict an unrelated LRU entry for no size gain.
        if key not in self.cache and len(self.cache) >= self.max_size:
            del self.cache[next(iter(self.cache))]
        # Copy so later caller-side mutation cannot corrupt the cached vector.
        self.cache[key] = embedding.copy()
        self.cache.move_to_end(key)  # a replacement also counts as recent use

    def clear(self):
        """Empty the cache and reset counters (mirrors MemorySearchCache.clear)."""
        self.cache.clear()
        self.hit_count = 0
        self.miss_count = 0

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of size, hit/miss counts and hit rate."""
        total_requests = self.hit_count + self.miss_count
        hit_rate = self.hit_count / total_requests if total_requests > 0 else 0
        return {
            "size": len(self.cache),
            "max_size": self.max_size,
            "hit_count": self.hit_count,
            "miss_count": self.miss_count,
            "hit_rate": hit_rate
        }
def cached_search(cache_instance: MemorySearchCache):
    """Decorator factory: serve search results from *cache_instance* when possible.

    On a cache miss the wrapped search method runs normally and any truthy
    (non-empty) result is written back to the cache for subsequent calls.
    """
    def decorator(func: Callable):
        @wraps(func)
        def wrapper(self, query: str, top_k: int = 5, use_reranker: bool = True,
                    use_feedback_boost: bool = True, use_time_decay: bool = True, **kwargs):
            params = (query, top_k, use_reranker, use_feedback_boost, use_time_decay)
            hit = cache_instance.get(*params)
            if hit is not None:
                print(f"🎯 缓存命中,直接返回结果")
                return hit
            # Miss: run the real search, then cache any non-empty result.
            fresh = func(self, *params, **kwargs)
            if fresh:
                cache_instance.put(query, fresh, top_k, use_reranker,
                                   use_feedback_boost, use_time_decay)
            return fresh
        return wrapper
    return decorator
def cached_embedding(cache_instance: EmbeddingCache):
    """Decorator factory: memoize an embedding method through *cache_instance*."""
    def decorator(func: Callable):
        @wraps(func)
        def wrapper(self, text: str, **kwargs):
            vector = cache_instance.get(text)
            if vector is None:
                # Miss: compute the embedding and cache it when truthy.
                vector = func(self, text, **kwargs)
                if vector:
                    cache_instance.put(text, vector)
            return vector
        return wrapper
    return decorator
# Global shared cache instances used by the decorators and helper functions below.
search_cache = MemorySearchCache(max_size=1000, ttl_seconds=300)  # 5-minute TTL
embedding_cache = EmbeddingCache(max_size=5000)
def print_all_cache_stats():
    """Print a combined report for the global search and embedding caches."""
    print("📊 缓存系统统计:")
    print("=" * 40)
    print("🔍 搜索结果缓存:")
    search_cache.print_stats()
    # EmbeddingCache has no print_stats(); render its stats dict here.
    print("\n🧮 嵌入向量缓存:")
    stats = embedding_cache.get_stats()
    report = [
        f" 缓存大小: {stats['size']}/{stats['max_size']}",
        f" 命中次数: {stats['hit_count']}",
        f" 未命中次数: {stats['miss_count']}",
        f" 命中率: {stats['hit_rate']:.2%}",
    ]
    for line in report:
        print(line)
def clear_all_caches():
    """Empty both global caches and reset the embedding cache's counters."""
    # Reset the embedding cache's state field by field, then clear the
    # search cache (which handles its own counters in clear()).
    embedding_cache.cache.clear()
    embedding_cache.hit_count = 0
    embedding_cache.miss_count = 0
    search_cache.clear()
    print("🧹 所有缓存已清空")