// cache.rs (12.6 kB)
use codegraph_core::NodeId;
use dashmap::DashMap;
use parking_lot::{Mutex, RwLock};
use std::collections::{HashMap, VecDeque};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time;
/// Tunable limits and timings for an [`LfuCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Entry count at which inserting triggers an eviction.
    pub max_entries: usize,
    /// Maximum age of an entry, measured from its creation, before it counts as expired.
    pub ttl: Duration,
    /// Period of the background cleanup task.
    pub cleanup_interval: Duration,
    /// Whether hit/miss/eviction counters are recorded by the cache operations.
    pub enable_stats: bool,
}

impl Default for CacheConfig {
    /// Returns the stock configuration: 10 000 entries, one-hour TTL,
    /// five-minute cleanup period, statistics enabled.
    fn default() -> Self {
        let one_hour = Duration::from_secs(60 * 60);
        let five_minutes = Duration::from_secs(5 * 60);
        CacheConfig {
            enable_stats: true,
            cleanup_interval: five_minutes,
            ttl: one_hour,
            max_entries: 10_000,
        }
    }
}
/// Counters describing how a cache has been used.
///
/// `Default` yields the same all-zero value as [`CacheStats::new`].
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups that returned a value.
    pub hits: u64,
    /// Number of lookups that returned nothing.
    pub misses: u64,
    /// Number of entries removed to make room for others.
    pub evictions: u64,
    /// Number of entries held when the counters were last refreshed.
    pub entries: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has been recorded.
    /// Only current after [`CacheStats::update_hit_ratio`] has been called.
    pub hit_ratio: f64,
}

impl CacheStats {
    /// Creates a set of counters with every field zeroed.
    pub fn new() -> Self {
        Self::default()
    }

    /// Recomputes `hit_ratio` from the current `hits` and `misses`.
    ///
    /// With no recorded lookups the ratio is `0.0` rather than NaN.
    pub fn update_hit_ratio(&mut self) {
        // Saturate so that extreme counter values cannot overflow the sum.
        let total = self.hits.saturating_add(self.misses);
        self.hit_ratio = if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        };
    }
}
/// A cached value together with its bookkeeping timestamps and access counter.
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    /// The cached payload.
    value: T,
    /// When the entry was created; the TTL is measured from here.
    created_at: Instant,
    /// When the entry was last touched via [`CacheEntry::access`].
    last_accessed: Instant,
    /// Number of accesses, counting the initial insertion as the first.
    access_count: u64,
}

impl<T> CacheEntry<T> {
    /// Wraps `value`, stamping both timestamps with the current instant and
    /// starting the access counter at one.
    fn new(value: T) -> Self {
        let stamp = Instant::now();
        CacheEntry {
            access_count: 1,
            last_accessed: stamp,
            created_at: stamp,
            value,
        }
    }

    /// Records one more access and refreshes the last-access time.
    fn access(&mut self) {
        self.access_count += 1;
        self.last_accessed = Instant::now();
    }

    /// Returns `true` once strictly more than `ttl` has passed since creation.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = Instant::now().duration_since(self.created_at);
        age > ttl
    }
}
/// Concurrent cache that evicts the key with the lowest recorded access
/// frequency and can expire entries by age from a background task.
pub struct LfuCache<K, V>
where
K: Hash + Eq + Clone,
V: Clone,
{
/// Cached values with their creation/access metadata.
data: Arc<DashMap<K, CacheEntry<V>>>,
/// Per-key access counter used to pick eviction victims.
frequency: Arc<DashMap<K, u64>>,
/// Keys in access order, most recent at the back. Maintained by the cache
/// operations; no eviction decision in this file reads it.
access_order: Arc<Mutex<VecDeque<K>>>,
/// Limits and timings supplied at construction.
config: CacheConfig,
/// Hit/miss/eviction counters shared with the cleanup task.
stats: Arc<RwLock<CacheStats>>,
/// Handle of the background cleanup task, if one has been started; aborted on drop.
cleanup_handle: Option<tokio::task::JoinHandle<()>>,
}
impl<K, V> LfuCache<K, V>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
/// Creates an empty cache governed by `config`.
///
/// No background task is spawned here; call `start_cleanup_task` to enable
/// periodic expiry and capacity enforcement.
pub fn new(config: CacheConfig) -> Self {
    Self {
        cleanup_handle: None,
        stats: Arc::new(RwLock::new(CacheStats::new())),
        access_order: Arc::new(Mutex::new(VecDeque::new())),
        frequency: Arc::new(DashMap::new()),
        data: Arc::new(DashMap::new()),
        config,
    }
}
pub fn start_cleanup_task(&mut self) {
let data = Arc::clone(&self.data);
let frequency = Arc::clone(&self.frequency);
let access_order = Arc::clone(&self.access_order);
let stats = Arc::clone(&self.stats);
let interval = self.config.cleanup_interval;
let ttl = self.config.ttl;
let max_entries = self.config.max_entries;
let handle = tokio::spawn(async move {
let mut cleanup_interval = time::interval(interval);
loop {
cleanup_interval.tick().await;
// Remove expired entries
let expired_keys: Vec<K> = data
.iter()
.filter_map(|entry| {
if entry.value().is_expired(ttl) {
Some(entry.key().clone())
} else {
None
}
})
.collect();
for key in expired_keys {
data.remove(&key);
frequency.remove(&key);
if let Some(mut order) = access_order.try_lock() {
order.retain(|k| k != &key);
}
}
// Evict least frequently used entries if over capacity
if data.len() > max_entries {
let evict_count = data.len() - max_entries;
let mut keys_to_evict = Vec::new();
// Find least frequently used keys
let mut freq_list: Vec<_> = frequency
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
freq_list.sort_by(|a, b| a.1.cmp(&b.1));
for (key, _) in freq_list.into_iter().take(evict_count) {
keys_to_evict.push(key);
}
for key in keys_to_evict {
data.remove(&key);
frequency.remove(&key);
if let Some(mut order) = access_order.try_lock() {
order.retain(|k| k != &key);
}
if let Some(mut stats) = stats.try_write() {
stats.evictions += 1;
}
}
}
// Update stats
if let Some(mut stats) = stats.try_write() {
stats.entries = data.len();
stats.update_hit_ratio();
}
}
});
self.cleanup_handle = Some(handle);
}
/// Looks up `key` and returns a clone of its value.
///
/// A hit bumps the key's frequency counter, refreshes the entry's access
/// metadata and moves the key to the back of the access queue (best effort:
/// skipped when the queue lock is contended).
///
/// An entry older than the configured TTL is treated as a miss and removed on
/// the spot, so expired values are never returned even when the background
/// cleanup task is not running.
///
/// When statistics are enabled a hit or miss is recorded (best effort).
pub fn get(&self, key: &K) -> Option<V> {
    let ttl = self.config.ttl;
    let mut expired = false;

    // `and_then` consumes the shard guard, so it is released before the
    // expiry cleanup below touches `data` again.
    let found = self.data.get_mut(key).and_then(|mut entry| {
        if entry.is_expired(ttl) {
            expired = true;
            return None;
        }
        entry.access();
        // Update frequency while the entry is still locked, so the counter is
        // only ever bumped for a key that is present in `data`.
        self.frequency
            .entry(key.clone())
            .and_modify(|freq| *freq += 1)
            .or_insert(1);
        // Update access order.
        if let Some(mut order) = self.access_order.try_lock() {
            order.retain(|k| k != key);
            order.push_back(key.clone());
        }
        Some(entry.value.clone())
    });

    if expired {
        // Re-check under the shard lock so a value stored concurrently after
        // the lookup above is not discarded.
        let removed = self
            .data
            .remove_if(key, |_, entry| entry.is_expired(ttl))
            .is_some();
        if removed {
            self.frequency.remove(key);
            if let Some(mut order) = self.access_order.try_lock() {
                order.retain(|k| k != key);
            }
        }
    }

    if self.config.enable_stats {
        if let Some(mut stats) = self.stats.try_write() {
            if found.is_some() {
                stats.hits += 1;
            } else {
                stats.misses += 1;
            }
            stats.update_hit_ratio();
        }
    }

    found
}
/// Stores `value` under `key`, replacing any previous value.
///
/// When `key` is new and the cache already holds `max_entries` entries, the
/// least frequently used entry is evicted first. Overwriting an existing key
/// never evicts, because it does not grow the cache.
///
/// The key's frequency counter is (re)set to 1 and the key is placed at the
/// back of the access queue (best effort: skipped when the queue lock is
/// contended).
pub fn put(&self, key: K, value: V) {
    // Only a new key can push the cache past capacity; evicting on an
    // overwrite would throw away an unrelated entry for nothing.
    let is_new = !self.data.contains_key(&key);
    if is_new && self.data.len() >= self.config.max_entries {
        self.evict_lfu();
    }

    self.data.insert(key.clone(), CacheEntry::new(value));
    self.frequency.insert(key.clone(), 1);

    if let Some(mut order) = self.access_order.try_lock() {
        if !is_new {
            // Drop the old position so repeated overwrites cannot grow the
            // queue with duplicates of the same key.
            order.retain(|queued| queued != &key);
        }
        order.push_back(key);
    }

    if self.config.enable_stats {
        if let Some(mut stats) = self.stats.try_write() {
            stats.entries = self.data.len();
        }
    }
}
fn evict_lfu(&self) {
// Find the least frequently used key
if let Some(min_entry) = self
.frequency
.iter()
.min_by(|a, b| a.value().cmp(b.value()))
{
let key_to_evict = min_entry.key().clone();
self.data.remove(&key_to_evict);
self.frequency.remove(&key_to_evict);
if let Some(mut order) = self.access_order.try_lock() {
order.retain(|k| k != &key_to_evict);
}
if self.config.enable_stats {
if let Some(mut stats) = self.stats.try_write() {
stats.evictions += 1;
}
}
}
}
/// Removes `key` from the cache and returns its value, if it was present.
///
/// The key's frequency counter is dropped as well; its position in the access
/// queue is cleared only when the queue lock can be taken without waiting.
pub fn remove(&self, key: &K) -> Option<V> {
    let removed = self.data.remove(key);
    self.frequency.remove(key);
    if let Some(mut order) = self.access_order.try_lock() {
        order.retain(|queued| queued != key);
    }
    removed.map(|(_, entry)| entry.value)
}
/// Empties the cache: all values and frequency counters are dropped.
///
/// The access queue is cleared, and (when statistics are enabled) the
/// counters are reset to zero; each of those two steps is skipped if its lock
/// cannot be taken without waiting.
pub fn clear(&self) {
    self.data.clear();
    self.frequency.clear();
    if let Some(mut queue) = self.access_order.try_lock() {
        queue.clear();
    }
    if !self.config.enable_stats {
        return;
    }
    if let Some(mut counters) = self.stats.try_write() {
        *counters = CacheStats::new();
    }
}
/// Returns the number of entries currently stored, without checking their age
/// against the TTL.
pub fn len(&self) -> usize {
self.data.len()
}
/// Returns `true` when no entries are stored.
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn get_stats(&self) -> CacheStats {
if let Some(stats) = self.stats.try_read() {
stats.clone()
} else {
CacheStats::new()
}
}
/// Returns `true` when an entry is stored under `key`.
///
/// This is a plain presence check: it does not consult the TTL and does not
/// touch the frequency counter, access queue or statistics.
pub fn contains_key(&self, key: &K) -> bool {
self.data.contains_key(key)
}
}
impl<K, V> Drop for LfuCache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    /// Aborts the background cleanup task, if one was started, so it does not
    /// keep running after the cache is gone.
    fn drop(&mut self) {
        let Some(task) = self.cleanup_handle.take() else {
            return;
        };
        task.abort();
    }
}
// Specialized caches for search results
/// Cache from a query fingerprint to a list of `(NodeId, f32)` result pairs.
pub type QueryResultCache = LfuCache<QueryHash, Vec<(NodeId, f32)>>;
/// Cache from a node id to its `Vec<f32>` embedding.
pub type EmbeddingCache = LfuCache<NodeId, Vec<f32>>;
/// Cache from a context key to its `f32` score.
pub type ContextCache = LfuCache<ContextHash, f32>;
/// Cache key identifying a search query: a fingerprint of the query embedding,
/// the requested result count and a fingerprint of the search configuration.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct QueryHash {
    /// Hash of the embedding's length and sampled components.
    embedding_hash: u64,
    /// Number of results requested.
    k: usize,
    /// Hash of the configuration string.
    config_hash: u64,
}

impl QueryHash {
    /// Builds the key for a query.
    ///
    /// * `embedding` - query vector; its length and every 10th component
    ///   (indices 0, 10, 20, ...) are hashed.
    /// * `k` - number of results requested.
    /// * `config` - textual form of the search configuration.
    ///
    /// Components are hashed by their exact bit pattern. A numeric cast to an
    /// integer would truncate fractions and clamp negatives to zero, making
    /// every embedding with components in (-1, 1) produce the same key and
    /// therefore share cached results.
    pub fn new(embedding: &[f32], k: usize, config: &str) -> Self {
        let mut hasher = DefaultHasher::new();
        // Include the length so vectors of different dimension that share
        // their sampled components do not collide.
        embedding.len().hash(&mut hasher);
        // Sample every 10th element for performance.
        for val in embedding.iter().step_by(10) {
            val.to_bits().hash(&mut hasher);
        }
        let embedding_hash = hasher.finish();

        let mut config_hasher = DefaultHasher::new();
        config.hash(&mut config_hasher);
        let config_hash = config_hasher.finish();

        Self {
            embedding_hash,
            k,
            config_hash,
        }
    }
}
/// Cache key for a context score: a list of nodes plus a context-type label.
///
/// Equality and hashing are derived from both fields; because `nodes` is a
/// `Vec`, the same nodes in a different order form a different key.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ContextHash {
    /// Nodes that make up the context, in caller-supplied order.
    nodes: Vec<NodeId>,
    /// Label distinguishing different kinds of context over the same nodes.
    context_type: String,
}

impl ContextHash {
    /// Builds a key from the given nodes and context-type label, taking
    /// ownership of both.
    pub fn new(nodes: Vec<NodeId>, context_type: String) -> Self {
        ContextHash {
            context_type,
            nodes,
        }
    }
}
// High-performance cache manager for the KNN engine
/// Bundles the three search caches (query results, embeddings, context scores)
/// behind one facade.
pub struct SearchCacheManager {
    /// Query fingerprint -> ranked results.
    query_cache: QueryResultCache,
    /// Node id -> embedding vector.
    embedding_cache: EmbeddingCache,
    /// Context key -> score.
    context_cache: ContextCache,
}

impl SearchCacheManager {
    /// Creates the three caches from their configurations and starts each
    /// cache's background cleanup task.
    ///
    /// Must be called from within a Tokio runtime, because the cleanup tasks
    /// are started with `tokio::spawn`.
    pub fn new(
        query_config: CacheConfig,
        embedding_config: CacheConfig,
        context_config: CacheConfig,
    ) -> Self {
        let mut manager = Self {
            query_cache: QueryResultCache::new(query_config),
            embedding_cache: EmbeddingCache::new(embedding_config),
            context_cache: ContextCache::new(context_config),
        };
        manager.query_cache.start_cleanup_task();
        manager.embedding_cache.start_cleanup_task();
        manager.context_cache.start_cleanup_task();
        manager
    }

    /// Returns the cached results for `query_hash`, if any.
    pub fn get_query_results(&self, query_hash: &QueryHash) -> Option<Vec<(NodeId, f32)>> {
        self.query_cache.get(query_hash)
    }

    /// Stores `results` under `query_hash`.
    pub fn cache_query_results(&self, query_hash: QueryHash, results: Vec<(NodeId, f32)>) {
        self.query_cache.put(query_hash, results);
    }

    /// Returns the cached embedding for `node_id`, if any.
    pub fn get_embedding(&self, node_id: &NodeId) -> Option<Vec<f32>> {
        self.embedding_cache.get(node_id)
    }

    /// Stores `embedding` under `node_id`.
    pub fn cache_embedding(&self, node_id: NodeId, embedding: Vec<f32>) {
        self.embedding_cache.put(node_id, embedding);
    }

    /// Returns the cached score for `context_hash`, if any.
    pub fn get_context_score(&self, context_hash: &ContextHash) -> Option<f32> {
        self.context_cache.get(context_hash)
    }

    /// Stores `score` under `context_hash`.
    pub fn cache_context_score(&self, context_hash: ContextHash, score: f32) {
        self.context_cache.put(context_hash, score);
    }

    /// Empties all three caches.
    pub fn clear_all(&self) {
        self.query_cache.clear();
        self.embedding_cache.clear();
        self.context_cache.clear();
    }

    /// Returns a statistics snapshot per cache, keyed by `"query_cache"`,
    /// `"embedding_cache"` and `"context_cache"`.
    pub fn get_cache_stats(&self) -> HashMap<String, CacheStats> {
        let query = self.query_cache.get_stats();
        let embedding = self.embedding_cache.get_stats();
        let context = self.context_cache.get_stats();
        HashMap::from([
            ("query_cache".to_string(), query),
            ("embedding_cache".to_string(), embedding),
            ("context_cache".to_string(), context),
        ])
    }
}