
CodeGraph CLI MCP Server

by Jakedismo
knn.rs • 31.1 kB
use crate::{CacheConfig, QueryHash, SearchCacheManager};
use codegraph_core::{CodeGraphError, CodeNode, Language, NodeId, NodeType, Result};
use dashmap::DashMap;
use faiss::{Index, IndexImpl, MetricType};
use parking_lot::{Mutex, RwLock};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
    pub k: usize,
    pub precision_recall_tradeoff: f32, // 0.0 = max recall, 1.0 = max precision
    pub enable_clustering: bool,
    pub cluster_threshold: f32,
    pub max_parallel_queries: usize,
    pub context_weight: f32,
    pub language_boost: f32,
    pub type_boost: f32,
}

impl Default for SearchConfig {
    fn default() -> Self {
        Self {
            k: 10,
            precision_recall_tradeoff: 0.5,
            enable_clustering: true,
            cluster_threshold: 0.8,
            max_parallel_queries: 8,
            context_weight: 0.3,
            language_boost: 0.2,
            type_boost: 0.1,
        }
    }
}
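// How these weights combine (see `apply_contextual_ranking` below), as a rough
// worked example: with the default `context_weight` of 0.3, a raw similarity
// of 0.90 and a context score of 1.2 blend into
//   final = 0.90 * (1.0 - 0.3) + 1.2 * 0.3 = 0.99,
// so context can promote a slightly less similar but better-matching node.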
#[derive(Debug, Clone)]
pub struct ContextualSearchResult {
    pub node_id: NodeId,
    pub similarity_score: f32,
    pub context_score: f32,
    pub final_score: f32,
    pub node: Option<CodeNode>,
    pub cluster_id: Option<usize>,
    pub explanation: String,
}

#[derive(Debug, Clone)]
struct CodeCluster {
    id: usize,
    centroid: Vec<f32>,
    nodes: HashSet<NodeId>,
    language: Option<Language>,
    node_type: Option<NodeType>,
    avg_similarity: f32,
}

pub struct OptimizedKnnEngine {
    // Core FAISS indices with different configurations
    exact_index: Arc<RwLock<Option<IndexImpl>>>,
    approximate_index: Arc<RwLock<Option<IndexImpl>>>,

    // Node mappings and metadata
    id_mapping: Arc<DashMap<i64, NodeId>>,
    reverse_mapping: Arc<DashMap<NodeId, i64>>,
    node_metadata: Arc<DashMap<NodeId, CodeNode>>,
    embeddings: Arc<DashMap<NodeId, Vec<f32>>>,

    // Clustering system
    clusters: Arc<RwLock<Vec<CodeCluster>>>,
    node_to_cluster: Arc<DashMap<NodeId, usize>>,

    // Performance optimization
    dimension: usize,
    next_id: Arc<Mutex<i64>>,
    query_semaphore: Arc<Semaphore>,
    cache_manager: Arc<SearchCacheManager>,

    // Configuration
    config: SearchConfig,
}

impl OptimizedKnnEngine {
    pub fn new(dimension: usize, config: SearchConfig) -> Result<Self> {
        let query_semaphore = Arc::new(Semaphore::new(config.max_parallel_queries));

        // Initialize cache manager with optimized configurations
        let cache_manager = Arc::new(SearchCacheManager::new(
            CacheConfig {
                max_entries: 5000,
                ttl: Duration::from_secs(1800), // 30 minutes
                cleanup_interval: Duration::from_secs(60),
                enable_stats: true,
            },
            CacheConfig {
                max_entries: 10000,
                ttl: Duration::from_secs(3600), // 1 hour
                cleanup_interval: Duration::from_secs(120),
                enable_stats: true,
            },
            CacheConfig {
                max_entries: 2000,
                ttl: Duration::from_secs(900), // 15 minutes
                cleanup_interval: Duration::from_secs(30),
                enable_stats: true,
            },
        ));

        Ok(Self {
            exact_index: Arc::new(RwLock::new(None)),
            approximate_index: Arc::new(RwLock::new(None)),
            id_mapping: Arc::new(DashMap::new()),
            reverse_mapping: Arc::new(DashMap::new()),
            node_metadata: Arc::new(DashMap::new()),
            embeddings: Arc::new(DashMap::new()),
            clusters: Arc::new(RwLock::new(Vec::new())),
            node_to_cluster: Arc::new(DashMap::new()),
            dimension,
            next_id: Arc::new(Mutex::new(0)),
            query_semaphore,
            cache_manager,
            config,
        })
    }

    pub async fn build_indices(&self, nodes: &[CodeNode]) -> Result<()> {
        let start_time = Instant::now();

        // Filter nodes with embeddings
        let embedded_nodes: Vec<_> = nodes
            .iter()
            .filter(|node| node.embedding.is_some())
            .collect();

        if embedded_nodes.is_empty() {
            return Ok(());
        }

        // Build exact index for high-precision queries
        self.build_exact_index(&embedded_nodes).await?;

        // Build approximate index for fast queries
        self.build_approximate_index(&embedded_nodes).await?;

        // Build clustering if enabled
        if self.config.enable_clustering {
            self.build_clusters(&embedded_nodes).await?;
        }

        tracing::info!(
            "Built KNN indices for {} nodes in {:?}",
            embedded_nodes.len(),
            start_time.elapsed()
        );

        Ok(())
    }

    async fn build_exact_index(&self, nodes: &[&CodeNode]) -> Result<()> {
        let mut exact_index_guard = self.exact_index.write();

        // `add` takes `&mut self`, so the index binding must be mutable.
        let mut exact_index =
            faiss::index_factory(self.dimension, "Flat", MetricType::InnerProduct)
                .map_err(|e| CodeGraphError::Vector(e.to_string()))?;

        let (vectors, mappings) = self.prepare_vectors_and_mappings(nodes)?;

        exact_index
            .add(&vectors)
            .map_err(|e| CodeGraphError::Vector(e.to_string()))?;

        self.store_mappings(mappings);
        *exact_index_guard = Some(exact_index);
        Ok(())
    }

    async fn build_approximate_index(&self, nodes: &[&CodeNode]) -> Result<()> {
        let mut approx_index_guard = self.approximate_index.write();

        // Choose index type based on dataset size and precision/recall tradeoff
        let index_description = self.get_optimal_index_description(nodes.len());

        let mut approx_index =
            faiss::index_factory(self.dimension, &index_description, MetricType::InnerProduct)
                .map_err(|e| CodeGraphError::Vector(e.to_string()))?;

        let (vectors, _) = self.prepare_vectors_and_mappings(nodes)?;

        // IVF and PQ indices must be trained on representative vectors before
        // `add`; for a Flat index, training is a no-op.
        approx_index
            .train(&vectors)
            .map_err(|e| CodeGraphError::Vector(e.to_string()))?;

        approx_index
            .add(&vectors)
            .map_err(|e| CodeGraphError::Vector(e.to_string()))?;

        *approx_index_guard = Some(approx_index);
        Ok(())
    }
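    // The descriptions below use FAISS `index_factory` syntax: "IVF100,Flat"
    // means an inverted-file index with 100 coarse clusters storing
    // uncompressed vectors, while "IVF1000,PQ8" adds product quantization at
    // 8 bytes per vector, trading recall for memory and speed. The cluster
    // counts chosen here are size-based heuristics, not tuned constants.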
    fn get_optimal_index_description(&self, node_count: usize) -> String {
        let precision_factor = self.config.precision_recall_tradeoff;

        match node_count {
            0..=1000 => "Flat".to_string(),
            1001..=10000 => {
                if precision_factor > 0.7 {
                    "IVF100,Flat".to_string()
                } else {
                    "IVF50,Flat".to_string()
                }
            }
            10001..=100000 => {
                if precision_factor > 0.7 {
                    "IVF1000,PQ8".to_string()
                } else {
                    "IVF500,PQ8".to_string()
                }
            }
            _ => {
                if precision_factor > 0.7 {
                    "IVF4000,PQ16".to_string()
                } else {
                    "IVF2000,PQ16".to_string()
                }
            }
        }
    }

    fn prepare_vectors_and_mappings(
        &self,
        nodes: &[&CodeNode],
    ) -> Result<(Vec<f32>, Vec<(NodeId, Vec<f32>)>)> {
        let estimated = nodes.len();
        let mut vectors = Vec::with_capacity(estimated.saturating_mul(self.dimension));
        let mut mappings = Vec::with_capacity(estimated);

        for node in nodes {
            if let Some(embedding) = &node.embedding {
                vectors.extend_from_slice(embedding);
                mappings.push((node.id, embedding.clone()));

                // Store node metadata for context scoring
                self.node_metadata.insert(node.id, (*node).clone());
            }
        }

        Ok((vectors, mappings))
    }

    fn store_mappings(&self, mappings: Vec<(NodeId, Vec<f32>)>) {
        let mut next_id_guard = self.next_id.lock();

        for (node_id, embedding) in mappings {
            let faiss_id = *next_id_guard;
            *next_id_guard += 1;

            self.id_mapping.insert(faiss_id, node_id);
            self.reverse_mapping.insert(node_id, faiss_id);
            self.embeddings.insert(node_id, embedding);
        }
    }

    pub async fn parallel_similarity_search(
        &self,
        queries: Vec<Vec<f32>>,
        config: Option<SearchConfig>,
    ) -> Result<Vec<Vec<ContextualSearchResult>>> {
        let search_config = config.unwrap_or_else(|| self.config.clone());
        let start_time = Instant::now();

        // Process queries in parallel with semaphore-controlled concurrency
        let results: Result<Vec<_>> = futures::future::try_join_all(
            queries.into_iter().map(|query| {
                let engine = self;
                let config = search_config.clone();
                async move {
                    let _permit = engine
                        .query_semaphore
                        .acquire()
                        .await
                        .map_err(|e| CodeGraphError::Vector(e.to_string()))?;
                    engine.single_similarity_search(query, config).await
                }
            }),
        )
        .await;

        let results = results?;

        tracing::debug!(
            "Parallel search completed for {} queries in {:?}",
            results.len(),
            start_time.elapsed()
        );

        Ok(results)
    }

    pub async fn single_similarity_search(
        &self,
        query_embedding: Vec<f32>,
        config: SearchConfig,
    ) -> Result<Vec<ContextualSearchResult>> {
        if query_embedding.len() != self.dimension {
            return Err(CodeGraphError::Vector(format!(
                "Query embedding dimension {} doesn't match index dimension {}",
                query_embedding.len(),
                self.dimension
            )));
        }

        // Create cache key for the query
        let config_str = format!("{:?}", config);
        let query_hash = QueryHash::new(&query_embedding, config.k, &config_str);

        // Check cache first for performance
        if let Some(cached_results) = self.cache_manager.get_query_results(&query_hash) {
            tracing::debug!("Cache hit for query");
            return self
                .convert_to_contextual_results(cached_results, &query_embedding, &config)
                .await;
        }

        let start_time = Instant::now();

        // Choose index based on precision/recall tradeoff
        let use_exact = config.precision_recall_tradeoff > 0.8;

        let raw_results = if use_exact {
            self.search_exact_index(&query_embedding, config.k * 2).await?
        } else {
            self.search_approximate_index(&query_embedding, config.k * 2).await?
        };

        // Cache the raw results for future use
        self.cache_manager
            .cache_query_results(query_hash, raw_results.clone());

        // Apply contextual ranking and filtering
        let mut contextual_results = self
            .apply_contextual_ranking(&query_embedding, raw_results, &config)
            .await?;

        // Apply clustering-based filtering if enabled
        if config.enable_clustering {
            contextual_results = self
                .apply_clustering_filter(contextual_results, &config)
                .await?;
        }

        // Sort by final score and limit results
        contextual_results.sort_by(|a, b| {
            b.final_score
                .partial_cmp(&a.final_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        contextual_results.truncate(config.k);

        tracing::debug!(
            "Single similarity search completed in {:?}",
            start_time.elapsed()
        );

        Ok(contextual_results)
    }

    async fn convert_to_contextual_results(
        &self,
        raw_results: Vec<(NodeId, f32)>,
        query_embedding: &[f32],
        config: &SearchConfig,
    ) -> Result<Vec<ContextualSearchResult>> {
        let contextual_results = self
            .apply_contextual_ranking(query_embedding, raw_results, config)
            .await?;
        Ok(contextual_results)
    }

    async fn search_exact_index(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
        let index_guard = self.exact_index.read();
        let index = index_guard
            .as_ref()
            .ok_or_else(|| CodeGraphError::Vector("Exact index not initialized".to_string()))?;
        self.perform_faiss_search(index, query, k).await
    }

    async fn search_approximate_index(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
        let index_guard = self.approximate_index.read();
        let index = index_guard.as_ref().ok_or_else(|| {
            CodeGraphError::Vector("Approximate index not initialized".to_string())
        })?;
        self.perform_faiss_search(index, query, k).await
    }
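    // Safety note: the raw-pointer cast below smuggles a borrowed FAISS index
    // into `spawn_blocking`. This is sound only under the (unchecked)
    // assumptions that the read guard held by the caller keeps the index alive
    // until the blocking task finishes and that concurrent reads of the index
    // are safe. Holding a `parking_lot` guard across an `.await` also makes
    // these futures !Send; copying the data out or using an owned, shareable
    // index handle would be the safer design.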
    async fn perform_faiss_search(
        &self,
        index: &IndexImpl,
        query: &[f32],
        k: usize,
    ) -> Result<Vec<(NodeId, f32)>> {
        // Raw pointers are !Send, so wrap the pointer to move it into the
        // blocking task; see the safety note above.
        struct SendPtr(*const IndexImpl);
        unsafe impl Send for SendPtr {}

        let results = tokio::task::spawn_blocking({
            let query = query.to_vec();
            let index_ptr = SendPtr(index as *const IndexImpl);
            move || unsafe {
                let index_ref = &*index_ptr.0;
                index_ref.search(&query, k)
            }
        })
        .await
        .map_err(|e| CodeGraphError::Vector(e.to_string()))?
        .map_err(|e| CodeGraphError::Vector(e.to_string()))?;

        let mut node_results = Vec::with_capacity(k);
        for (i, &faiss_id) in results.labels.iter().enumerate() {
            if let Some(node_id) = self.id_mapping.get(&faiss_id) {
                let score = results.distances[i];
                node_results.push((*node_id, score));
            }
        }

        Ok(node_results)
    }
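    // Ranking runs sequentially for small candidate sets and switches to
    // Rayon's parallel iterator above 64 results, where per-item scoring work
    // outweighs the overhead of fanning out to the thread pool. The 64 cutoff
    // is a heuristic, not a measured constant.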
    async fn apply_contextual_ranking(
        &self,
        query_embedding: &[f32],
        raw_results: Vec<(NodeId, f32)>,
        config: &SearchConfig,
    ) -> Result<Vec<ContextualSearchResult>> {
        let query_context = self.extract_query_context(query_embedding).await?;

        // The scoring logic is identical in both branches; only the iterator
        // (sequential vs. parallel) differs.
        let score_one = |&(node_id, similarity_score): &(NodeId, f32)| {
            let node = self.node_metadata.get(&node_id)?;
            let context_score = self.calculate_context_score(&node, &query_context, config);
            let final_score = similarity_score * (1.0 - config.context_weight)
                + context_score * config.context_weight;

            let explanation = format!(
                "Similarity: {:.3}, Context: {:.3}, Language boost: {}, Type boost: {}",
                similarity_score,
                context_score,
                node.language.as_ref().map_or("none", |_| "applied"),
                node.node_type.as_ref().map_or("none", |_| "applied")
            );

            Some(ContextualSearchResult {
                node_id,
                similarity_score,
                context_score,
                final_score,
                node: Some(node.clone()),
                cluster_id: self.node_to_cluster.get(&node_id).map(|v| *v),
                explanation,
            })
        };

        let contextual_results: Vec<_> = if raw_results.len() <= 64 {
            raw_results.iter().filter_map(score_one).collect()
        } else {
            raw_results.par_iter().filter_map(score_one).collect()
        };

        Ok(contextual_results)
    }

    async fn extract_query_context(&self, _query_embedding: &[f32]) -> Result<QueryContext> {
        // In a real implementation, this would analyze the query embedding
        // to infer likely language, node type, etc.
        Ok(QueryContext {
            inferred_language: None,
            inferred_node_type: None,
            complexity_score: 0.5,
        })
    }

    #[inline]
    fn calculate_context_score(
        &self,
        node: &CodeNode,
        query_context: &QueryContext,
        config: &SearchConfig,
    ) -> f32 {
        let mut score = 1.0;

        // Language boost
        if let (Some(node_lang), Some(query_lang)) =
            (&node.language, &query_context.inferred_language)
        {
            if node_lang == query_lang {
                score += config.language_boost;
            }
        }

        // Node type boost
        if let (Some(node_type), Some(query_type)) =
            (&node.node_type, &query_context.inferred_node_type)
        {
            if node_type == query_type {
                score += config.type_boost;
            }
        }

        // Complexity matching
        let complexity_diff =
            (node.complexity.unwrap_or(0.5) - query_context.complexity_score).abs();
        score *= 1.0 - complexity_diff * 0.2;

        score.min(2.0).max(0.0)
    }

    async fn apply_clustering_filter(
        &self,
        results: Vec<ContextualSearchResult>,
        config: &SearchConfig,
    ) -> Result<Vec<ContextualSearchResult>> {
        if results.is_empty() {
            return Ok(results);
        }

        // Group results by clusters
        let mut cluster_groups: HashMap<Option<usize>, Vec<ContextualSearchResult>> =
            HashMap::new();
        for result in results {
            let cluster_id = result.cluster_id;
            cluster_groups.entry(cluster_id).or_default().push(result);
        }

        // Select the best representatives from each cluster
        let mut filtered_results = Vec::new();
        for (cluster_id, mut cluster_results) in cluster_groups {
            cluster_results.sort_by(|a, b| {
                b.final_score
                    .partial_cmp(&a.final_score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // Take the top results from each cluster
            let take_count = if cluster_id.is_some() {
                (config.k / self.clusters.read().len().max(1)).max(1)
            } else {
                cluster_results.len()
            };

            filtered_results.extend(cluster_results.into_iter().take(take_count));
        }

        Ok(filtered_results)
    }
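    // Cluster count follows the common k = sqrt(n) rule of thumb, clamped to
    // [2, n/2]: for example, 10,000 embedded nodes yield k = 100. Centroids
    // are seeded by striding evenly through the input, so the clustering is
    // deterministic for a given input order.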
    async fn build_clusters(&self, nodes: &[&CodeNode]) -> Result<()> {
        let embeddings: Vec<_> = nodes
            .iter()
            .filter_map(|&node| node.embedding.as_ref().map(|emb| (node.id, emb.clone(), node)))
            .collect();

        if embeddings.len() < 2 {
            return Ok(());
        }

        let clusters = self.perform_clustering(&embeddings).await?;

        // Update cluster mappings
        let mut clusters_guard = self.clusters.write();
        *clusters_guard = clusters;

        for (i, cluster) in clusters_guard.iter().enumerate() {
            for &node_id in &cluster.nodes {
                self.node_to_cluster.insert(node_id, i);
            }
        }

        tracing::info!("Built {} clusters", clusters_guard.len());
        Ok(())
    }

    async fn perform_clustering(
        &self,
        embeddings: &[(NodeId, Vec<f32>, &CodeNode)],
    ) -> Result<Vec<CodeCluster>> {
        let k = (embeddings.len() as f32).sqrt() as usize;
        let k = k.max(2).min(embeddings.len() / 2);

        // Simple k-means clustering implementation
        let mut clusters = Vec::new();
        let mut assignments = vec![0usize; embeddings.len()];

        // Initialize centroids by striding evenly through the input
        // (deterministic seeding, not true random initialization)
        for i in 0..k {
            let idx = i * embeddings.len() / k;
            clusters.push(CodeCluster {
                id: i,
                centroid: embeddings[idx].1.clone(),
                nodes: HashSet::new(),
                language: None,
                node_type: None,
                avg_similarity: 0.0,
            });
        }

        // Iterate toward convergence (fixed iteration cap)
        for _iteration in 0..20 {
            // Assign points to nearest centroids
            for (point_idx, (_node_id, embedding, _)) in embeddings.iter().enumerate() {
                let mut best_cluster = 0;
                let mut best_distance = f32::MAX;

                for (cluster_idx, cluster) in clusters.iter().enumerate() {
                    let distance = self.squared_distance(embedding, &cluster.centroid);
                    if distance < best_distance {
                        best_distance = distance;
                        best_cluster = cluster_idx;
                    }
                }

                assignments[point_idx] = best_cluster;
            }

            // Update centroids
            for cluster in &mut clusters {
                cluster.nodes.clear();
            }

            let mut cluster_sums: Vec<Vec<f32>> = clusters
                .iter()
                .map(|_| vec![0.0; self.dimension])
                .collect();
            let mut cluster_counts = vec![0; clusters.len()];

            for (point_idx, (node_id, embedding, node)) in embeddings.iter().enumerate() {
                let cluster_idx = assignments[point_idx];
                clusters[cluster_idx].nodes.insert(*node_id);

                for (i, &val) in embedding.iter().enumerate() {
                    cluster_sums[cluster_idx][i] += val;
                }
                cluster_counts[cluster_idx] += 1;

                // Update cluster metadata
                if clusters[cluster_idx].language.is_none() {
                    clusters[cluster_idx].language = node.language.clone();
                }
                if clusters[cluster_idx].node_type.is_none() {
                    clusters[cluster_idx].node_type = node.node_type.clone();
                }
            }

            // Update centroids
            for (cluster_idx, cluster) in clusters.iter_mut().enumerate() {
                if cluster_counts[cluster_idx] > 0 {
                    for (i, sum) in cluster_sums[cluster_idx].iter().enumerate() {
                        cluster.centroid[i] = sum / cluster_counts[cluster_idx] as f32;
                    }
                }
            }
        }

        Ok(clusters)
    }
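    // Both distance helpers below unroll the loop four elements at a time and
    // use `mul_add`, which lets the compiler emit fused multiply-add
    // instructions on targets that support them. `squared_distance` skips the
    // final `sqrt`: k-means only compares distances, and squaring is
    // monotonic, so the ordering is unchanged and the square root is avoided.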
    #[inline(always)]
    fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let mut acc = 0.0f32;
        let mut i = 0;

        while i + 4 <= len {
            let d0 = a[i] - b[i];
            let d1 = a[i + 1] - b[i + 1];
            let d2 = a[i + 2] - b[i + 2];
            let d3 = a[i + 3] - b[i + 3];
            acc = d0.mul_add(d0, acc);
            acc = d1.mul_add(d1, acc);
            acc = d2.mul_add(d2, acc);
            acc = d3.mul_add(d3, acc);
            i += 4;
        }
        while i < len {
            let d = a[i] - b[i];
            acc = d.mul_add(d, acc);
            i += 1;
        }

        acc.sqrt()
    }

    #[inline(always)]
    fn squared_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let mut acc = 0.0f32;
        let mut i = 0;

        while i + 4 <= len {
            let d0 = a[i] - b[i];
            let d1 = a[i + 1] - b[i + 1];
            let d2 = a[i + 2] - b[i + 2];
            let d3 = a[i + 3] - b[i + 3];
            acc = d0.mul_add(d0, acc);
            acc = d1.mul_add(d1, acc);
            acc = d2.mul_add(d2, acc);
            acc = d3.mul_add(d3, acc);
            i += 4;
        }
        while i < len {
            let d = a[i] - b[i];
            acc = d.mul_add(d, acc);
            i += 1;
        }

        acc
    }

    pub async fn get_cluster_info(&self) -> Vec<ClusterInfo> {
        let clusters_guard = self.clusters.read();
        clusters_guard
            .iter()
            .map(|cluster| ClusterInfo {
                id: cluster.id,
                size: cluster.nodes.len(),
                language: cluster.language.clone(),
                node_type: cluster.node_type.clone(),
                avg_similarity: cluster.avg_similarity,
            })
            .collect()
    }

    pub fn get_performance_stats(&self) -> PerformanceStats {
        let cache_stats = self.cache_manager.get_cache_stats();

        PerformanceStats {
            total_nodes: self.embeddings.len(),
            total_clusters: self.clusters.read().len(),
            index_type: if self.config.precision_recall_tradeoff > 0.8 {
                "Exact".to_string()
            } else {
                "Approximate".to_string()
            },
            max_parallel_queries: self.config.max_parallel_queries,
            cache_hit_rate: cache_stats
                .get("query_cache")
                .map(|stats| stats.hit_ratio)
                .unwrap_or(0.0),
            avg_search_latency_ms: 0.0, // Would be updated by metrics collection
        }
    }

    pub fn clear_caches(&self) {
        self.cache_manager.clear_all();
    }

    pub async fn warmup_cache(&self, sample_nodes: &[NodeId]) -> Result<()> {
        tracing::info!(
            "Starting cache warmup with {} sample nodes",
            sample_nodes.len()
        );

        let warmup_config = SearchConfig {
            k: 5,
            precision_recall_tradeoff: 0.5,
            enable_clustering: false,
            ..self.config.clone()
        };

        for &node_id in sample_nodes {
            if let Some(embedding) = self.embeddings.get(&node_id) {
                let _ = self
                    .single_similarity_search(embedding.clone(), warmup_config.clone())
                    .await;
            }
        }

        tracing::info!("Cache warmup completed");
        Ok(())
    }

    pub async fn optimize_index_configuration(&mut self) -> Result<()> {
        let node_count = self.embeddings.len();
        let avg_query_latency = self.measure_average_query_latency().await?;

        // Adjust precision/recall tradeoff based on performance
        if avg_query_latency > Duration::from_millis(100) && node_count > 10000 {
            tracing::info!("High latency detected, switching to more approximate index");
            self.config.precision_recall_tradeoff *= 0.8;
        } else if avg_query_latency < Duration::from_millis(20) && node_count < 1000 {
            tracing::info!("Low latency detected, switching to more exact index");
            self.config.precision_recall_tradeoff =
                (self.config.precision_recall_tradeoff * 1.2).min(1.0);
        }

        Ok(())
    }

    async fn measure_average_query_latency(&self) -> Result<Duration> {
        let sample_size = 10.min(self.embeddings.len());

        // Guard against an empty index; dividing by zero below would panic.
        if sample_size == 0 {
            return Ok(Duration::default());
        }

        let mut total_latency = Duration::default();
        let sample_embeddings: Vec<_> = self
            .embeddings
            .iter()
            .take(sample_size)
            .map(|entry| entry.value().clone())
            .collect();

        for embedding in sample_embeddings {
            let start = Instant::now();
            let _ = self
                .single_similarity_search(embedding, self.config.clone())
                .await?;
            total_latency += start.elapsed();
        }

        Ok(total_latency / sample_size as u32)
    }
}

#[derive(Debug, Clone)]
struct QueryContext {
    inferred_language: Option<Language>,
    inferred_node_type: Option<NodeType>,
    complexity_score: f32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterInfo {
    pub id: usize,
    pub size: usize,
    pub language: Option<Language>,
    pub node_type: Option<NodeType>,
    pub avg_similarity: f32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceStats {
    pub total_nodes: usize,
    pub total_clusters: usize,
    pub index_type: String,
    pub max_parallel_queries: usize,
    pub cache_hit_rate: f64,
    pub avg_search_latency_ms: f64,
}

// Batch search operations for high-throughput scenarios
impl OptimizedKnnEngine {
    pub async fn batch_search_similar_functions(
        &self,
        function_nodes: &[CodeNode],
        config: Option<SearchConfig>,
    ) -> Result<Vec<Vec<ContextualSearchResult>>> {
        let search_config = config.unwrap_or_else(|| self.config.clone());

        let queries: Result<Vec<_>> = function_nodes
            .iter()
            .filter(|node| matches!(node.node_type, Some(NodeType::Function)))
            .map(|node| {
                node.embedding
                    .clone()
                    .ok_or_else(|| CodeGraphError::Vector("Node missing embedding".to_string()))
            })
            .collect();
        let queries = queries?;

        self.parallel_similarity_search(queries, Some(search_config)).await
    }

    pub async fn discover_related_code_clusters(
        &self,
        seed_nodes: &[NodeId],
        expansion_factor: usize,
    ) -> Result<HashMap<usize, Vec<ContextualSearchResult>>> {
        let mut cluster_results: HashMap<usize, Vec<ContextualSearchResult>> = HashMap::new();

        for &seed_node in seed_nodes {
            if let Some(embedding) = self.embeddings.get(&seed_node) {
                let results = self
                    .single_similarity_search(
                        embedding.clone(),
                        SearchConfig {
                            k: expansion_factor,
                            enable_clustering: true,
                            ..self.config.clone()
                        },
                    )
                    .await?;

                for result in results {
                    if let Some(cluster_id) = result.cluster_id {
                        cluster_results.entry(cluster_id).or_default().push(result);
                    }
                }
            }
        }

        Ok(cluster_results)
    }
}
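To make the moving parts above concrete, here is a minimal, hypothetical usage sketch (not part of the repository). It assumes a Tokio runtime and that a `Vec<CodeNode>` with embeddings attached was produced elsewhere; the embedding dimension 384 is purely illustrative.

// Hypothetical usage sketch: build the indices once, then issue a single
// query and a parallel batch.
async fn example(nodes: Vec<CodeNode>, query: Vec<f32>) -> Result<()> {
    let engine = OptimizedKnnEngine::new(384, SearchConfig::default())?;
    engine.build_indices(&nodes).await?;

    // Single query: returns up to `k` contextually re-ranked hits.
    let hits = engine
        .single_similarity_search(query.clone(), SearchConfig { k: 5, ..Default::default() })
        .await?;
    for hit in &hits {
        println!("{:?} -> {:.3} ({})", hit.node_id, hit.final_score, hit.explanation);
    }

    // Batch path: concurrency is capped by `max_parallel_queries`.
    let batches = engine.parallel_similarity_search(vec![query], None).await?;
    assert_eq!(batches.len(), 1);
    Ok(())
}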

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Jakedismo/codegraph-rust'
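For programmatic access, the same endpoint can be queried from Rust. A minimal sketch, assuming the `reqwest` (with the `json` feature), `serde_json`, and `tokio` crates, and treating the unspecified response shape as generic JSON:

use serde_json::Value;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let url = "https://glama.ai/api/mcp/v1/servers/Jakedismo/codegraph-rust";
    // Fetch the server record and print it without assuming its schema.
    let server_info: Value = reqwest::get(url).await?.json().await?;
    println!("{}", serde_json::to_string_pretty(&server_info).unwrap());
    Ok(())
}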

If you have feedback or need assistance with the MCP directory API, please join our Discord server.