// ABOUTME: LLM tool executor for SurrealDB graph analysis functions
// ABOUTME: Executes graph analysis tools by calling Rust SDK wrappers with validated parameters
use codegraph_graph::GraphFunctions;
use lru::LruCache;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value as JsonValue};
use std::num::NonZeroUsize;
use std::sync::Arc;
use tracing::{debug, info};
use crate::error::McpError;
use crate::graph_tool_schemas::GraphToolSchemas;
use crate::Result;
/// Statistics about LRU cache performance.
///
/// A snapshot of this struct is returned by `GraphToolExecutor::cache_stats`;
/// counters are cumulative since construction or the last `clear_cache`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CacheStats {
    /// Number of cache hits (successful lookups)
    pub hits: u64,
    /// Number of cache misses (lookups that required SurrealDB call)
    pub misses: u64,
    /// Number of entries evicted due to LRU policy
    pub evictions: u64,
    /// Current number of entries in cache
    pub current_size: usize,
    /// Maximum cache size (capacity)
    pub max_size: usize,
}
impl CacheStats {
    /// Cache hit rate as a percentage of all recorded lookups.
    ///
    /// Returns `0.0` when no lookups have been recorded yet, which avoids
    /// a division by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => (self.hits as f64 / total as f64) * 100.0,
        }
    }
}
/// Executor for graph analysis tools.
///
/// Receives tool calls from an LLM and executes the appropriate SurrealDB
/// functions via the `GraphFunctions` SDK wrappers, optionally memoizing
/// results in a bounded LRU cache.
pub struct GraphToolExecutor {
    /// SDK wrappers around the SurrealDB graph analysis functions.
    graph_functions: Arc<GraphFunctions>,
    /// LRU cache for tool results (function_name + params → result)
    cache: Arc<Mutex<LruCache<String, JsonValue>>>,
    /// Cache statistics for observability
    cache_stats: Arc<Mutex<CacheStats>>,
    /// Whether caching is enabled
    cache_enabled: bool,
}
impl GraphToolExecutor {
    /// Default LRU capacity, used by [`GraphToolExecutor::new`] and as the
    /// fallback when a zero `cache_size` is requested.
    const DEFAULT_CACHE_SIZE: usize = 100;

    /// Create a new tool executor with a `GraphFunctions` instance.
    ///
    /// Caching is enabled with the default capacity.
    pub fn new(graph_functions: Arc<GraphFunctions>) -> Self {
        Self::with_cache(graph_functions, true, Self::DEFAULT_CACHE_SIZE)
    }

    /// Create a new tool executor with custom cache configuration.
    ///
    /// # Arguments
    /// * `graph_functions` - SDK wrappers for the SurrealDB graph functions
    /// * `cache_enabled` - Whether tool results are memoized
    /// * `cache_size` - Maximum number of cached entries; `0` falls back to
    ///   the default capacity, since an LRU cache cannot be zero-sized
    pub fn with_cache(
        graph_functions: Arc<GraphFunctions>,
        cache_enabled: bool,
        cache_size: usize,
    ) -> Self {
        // LruCache requires a non-zero capacity; fall back to the default.
        let capacity = NonZeroUsize::new(cache_size).unwrap_or_else(|| {
            NonZeroUsize::new(Self::DEFAULT_CACHE_SIZE).expect("default capacity is non-zero")
        });
        let cache = Arc::new(Mutex::new(LruCache::new(capacity)));
        let cache_stats = Arc::new(Mutex::new(CacheStats {
            hits: 0,
            misses: 0,
            evictions: 0,
            current_size: 0,
            // Report the capacity actually in effect, not the raw request
            // (which may have been 0 and replaced by the default).
            max_size: capacity.get(),
        }));
        Self {
            graph_functions,
            cache,
            cache_stats,
            cache_enabled,
        }
    }

    /// Get a snapshot of the current cache statistics.
    pub fn cache_stats(&self) -> CacheStats {
        self.cache_stats.lock().clone()
    }

    /// Clear the cache and reset the hit/miss/eviction counters.
    /// The configured capacity (`max_size`) is preserved.
    pub fn clear_cache(&self) {
        self.cache.lock().clear();
        let mut stats = self.cache_stats.lock();
        stats.hits = 0;
        stats.misses = 0;
        stats.evictions = 0;
        stats.current_size = 0;
    }

    /// Generate a deterministic cache key from tool name and parameters.
    ///
    /// NOTE(review): `serde_json::Value` objects serialize keys in insertion
    /// order, so semantically equal parameter maps built in different orders
    /// yield different keys. That is harmless (it only costs a cache miss)
    /// but worth knowing when reasoning about hit rates.
    fn cache_key(tool_name: &str, parameters: &JsonValue) -> String {
        format!("{}:{}", tool_name, parameters)
    }

    /// Execute a tool call from the LLM.
    ///
    /// # Arguments
    /// * `tool_name` - Name of the tool to execute
    /// * `parameters` - JSON parameters for the tool
    ///
    /// # Returns
    /// JSON result from the tool execution (possibly served from the cache)
    ///
    /// # Errors
    /// Returns `McpError::Protocol` for unknown/unimplemented tools and for
    /// failures reported by the underlying graph functions.
    pub async fn execute(&self, tool_name: &str, parameters: JsonValue) -> Result<JsonValue> {
        info!("Executing graph tool: {}", tool_name);
        debug!("Tool parameters: {}", parameters);
        // Validate the tool exists before doing any work.
        let _schema = GraphToolSchemas::get_by_name(tool_name)
            .ok_or_else(|| McpError::Protocol(format!("Unknown tool: {}", tool_name)))?;
        // Compute the key once up front; it is reused for the insert below.
        let cache_key = self
            .cache_enabled
            .then(|| Self::cache_key(tool_name, &parameters));
        if let Some(key) = &cache_key {
            // Scoped so the cache lock is never held across an await point.
            {
                let mut cache = self.cache.lock();
                if let Some(cached_result) = cache.get(key) {
                    self.cache_stats.lock().hits += 1;
                    debug!("Cache hit for {}: {}", tool_name, key);
                    return Ok(cached_result.clone());
                }
            }
            self.cache_stats.lock().misses += 1;
            debug!("Cache miss for {}: {}", tool_name, key);
        }
        // Dispatch to the matching wrapper. Only one arm runs, so
        // `parameters` can be moved instead of cloned per arm.
        let result = match tool_name {
            "get_transitive_dependencies" => {
                self.execute_get_transitive_dependencies(parameters).await?
            }
            "detect_circular_dependencies" => {
                self.execute_detect_circular_dependencies(parameters).await?
            }
            "trace_call_chain" => self.execute_trace_call_chain(parameters).await?,
            "calculate_coupling_metrics" => {
                self.execute_calculate_coupling_metrics(parameters).await?
            }
            "get_hub_nodes" => self.execute_get_hub_nodes(parameters).await?,
            "get_reverse_dependencies" => {
                self.execute_get_reverse_dependencies(parameters).await?
            }
            _ => {
                return Err(
                    McpError::Protocol(format!("Tool not implemented: {}", tool_name)).into(),
                )
            }
        };
        // Cache the result if enabled.
        if let Some(key) = cache_key {
            let mut cache = self.cache.lock();
            // A put only evicts when the cache is full AND the key is absent;
            // replacing an existing entry must not be counted as an eviction.
            let will_evict = cache.len() == cache.cap().get() && !cache.contains(&key);
            cache.put(key, result.clone());
            let mut stats = self.cache_stats.lock();
            if will_evict {
                stats.evictions += 1;
            }
            stats.current_size = cache.len();
        }
        info!("Tool execution complete: {}", tool_name);
        Ok(result)
    }

    /// Execute `get_transitive_dependencies`.
    ///
    /// Required params: `node_id`, `edge_type`. Optional: `depth` (default 3).
    async fn execute_get_transitive_dependencies(&self, params: JsonValue) -> Result<JsonValue> {
        let node_id = params["node_id"]
            .as_str()
            .ok_or_else(|| McpError::Protocol("Missing node_id".to_string()))?;
        let edge_type = params["edge_type"]
            .as_str()
            .ok_or_else(|| McpError::Protocol("Missing edge_type".to_string()))?;
        // Truncating cast: values outside i32 range wrap; the schema is
        // assumed to bound depth to small values — TODO confirm.
        let depth = params["depth"].as_i64().unwrap_or(3) as i32;
        let result = self
            .graph_functions
            .get_transitive_dependencies(node_id, edge_type, depth)
            .await
            .map_err(|e| {
                McpError::Protocol(format!("get_transitive_dependencies failed: {}", e))
            })?;
        Ok(json!({
            "tool": "get_transitive_dependencies",
            "parameters": {
                "node_id": node_id,
                "edge_type": edge_type,
                "depth": depth
            },
            "result": result
        }))
    }

    /// Execute `detect_circular_dependencies`.
    ///
    /// Required params: `edge_type`.
    async fn execute_detect_circular_dependencies(&self, params: JsonValue) -> Result<JsonValue> {
        let edge_type = params["edge_type"]
            .as_str()
            .ok_or_else(|| McpError::Protocol("Missing edge_type".to_string()))?;
        let result = self
            .graph_functions
            .detect_circular_dependencies(edge_type)
            .await
            .map_err(|e| {
                McpError::Protocol(format!("detect_circular_dependencies failed: {}", e))
            })?;
        Ok(json!({
            "tool": "detect_circular_dependencies",
            "parameters": {
                "edge_type": edge_type
            },
            "result": result
        }))
    }

    /// Execute `trace_call_chain`.
    ///
    /// Required params: `from_node`. Optional: `max_depth` (default 5).
    async fn execute_trace_call_chain(&self, params: JsonValue) -> Result<JsonValue> {
        let from_node = params["from_node"]
            .as_str()
            .ok_or_else(|| McpError::Protocol("Missing from_node".to_string()))?;
        let max_depth = params["max_depth"].as_i64().unwrap_or(5) as i32;
        let result = self
            .graph_functions
            .trace_call_chain(from_node, max_depth)
            .await
            .map_err(|e| McpError::Protocol(format!("trace_call_chain failed: {}", e)))?;
        Ok(json!({
            "tool": "trace_call_chain",
            "parameters": {
                "from_node": from_node,
                "max_depth": max_depth
            },
            "result": result
        }))
    }

    /// Execute `calculate_coupling_metrics`.
    ///
    /// Required params: `node_id`.
    async fn execute_calculate_coupling_metrics(&self, params: JsonValue) -> Result<JsonValue> {
        let node_id = params["node_id"]
            .as_str()
            .ok_or_else(|| McpError::Protocol("Missing node_id".to_string()))?;
        let result = self
            .graph_functions
            .calculate_coupling_metrics(node_id)
            .await
            .map_err(|e| McpError::Protocol(format!("calculate_coupling_metrics failed: {}", e)))?;
        Ok(json!({
            "tool": "calculate_coupling_metrics",
            "parameters": {
                "node_id": node_id
            },
            "result": result
        }))
    }

    /// Execute `get_hub_nodes`.
    ///
    /// Optional params: `min_degree` (default 5).
    async fn execute_get_hub_nodes(&self, params: JsonValue) -> Result<JsonValue> {
        let min_degree = params["min_degree"].as_i64().unwrap_or(5) as i32;
        let result = self
            .graph_functions
            .get_hub_nodes(min_degree)
            .await
            .map_err(|e| McpError::Protocol(format!("get_hub_nodes failed: {}", e)))?;
        Ok(json!({
            "tool": "get_hub_nodes",
            "parameters": {
                "min_degree": min_degree
            },
            "result": result
        }))
    }

    /// Execute `get_reverse_dependencies`.
    ///
    /// Required params: `node_id`, `edge_type`. Optional: `depth` (default 3).
    async fn execute_get_reverse_dependencies(&self, params: JsonValue) -> Result<JsonValue> {
        let node_id = params["node_id"]
            .as_str()
            .ok_or_else(|| McpError::Protocol("Missing node_id".to_string()))?;
        let edge_type = params["edge_type"]
            .as_str()
            .ok_or_else(|| McpError::Protocol("Missing edge_type".to_string()))?;
        let depth = params["depth"].as_i64().unwrap_or(3) as i32;
        let result = self
            .graph_functions
            .get_reverse_dependencies(node_id, edge_type, depth)
            .await
            .map_err(|e| McpError::Protocol(format!("get_reverse_dependencies failed: {}", e)))?;
        Ok(json!({
            "tool": "get_reverse_dependencies",
            "parameters": {
                "node_id": node_id,
                "edge_type": edge_type,
                "depth": depth
            },
            "result": result
        }))
    }

    /// Get all available tool schemas for registration.
    pub fn get_tool_schemas() -> Vec<crate::ToolSchema> {
        GraphToolSchemas::all()
    }

    /// Get tool names for listing.
    pub fn get_tool_names() -> Vec<String> {
        GraphToolSchemas::tool_names()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All six graph tools must expose a schema for registration.
    #[test]
    fn test_tool_schemas_available() {
        let schemas = GraphToolExecutor::get_tool_schemas();
        assert_eq!(schemas.len(), 6);
    }

    /// Tool-name listing must match the schema count and include the
    /// transitive-dependency tool.
    #[test]
    fn test_tool_names() {
        let names = GraphToolExecutor::get_tool_names();
        assert_eq!(names.len(), 6);
        assert!(names.contains(&"get_transitive_dependencies".to_string()));
    }

    /// Sanity-check the JSON parameter extraction pattern used by the
    /// executor helpers.
    #[test]
    fn test_parameter_extraction() {
        let params = json!({
            "node_id": "nodes:123",
            "edge_type": "Calls",
            "depth": 5
        });
        assert_eq!(params["node_id"].as_str().unwrap(), "nodes:123");
        assert_eq!(params["edge_type"].as_str().unwrap(), "Calls");
        assert_eq!(params["depth"].as_i64().unwrap(), 5);
    }

    // === Cache Tests ===

    /// Identical (tool, params) pairs must map to the same cache key;
    /// differing params must not collide.
    #[test]
    fn test_cache_key_generation() {
        let params1 = json!({
            "node_id": "nodes:123",
            "edge_type": "Calls",
            "depth": 3
        });
        let params2 = json!({
            "node_id": "nodes:123",
            "edge_type": "Calls",
            "depth": 3
        });
        let params3 = json!({
            "node_id": "nodes:456",
            "edge_type": "Calls",
            "depth": 3
        });
        let key1 = GraphToolExecutor::cache_key("get_transitive_dependencies", &params1);
        let key2 = GraphToolExecutor::cache_key("get_transitive_dependencies", &params2);
        let key3 = GraphToolExecutor::cache_key("get_transitive_dependencies", &params3);
        // Same params should generate same key
        assert_eq!(key1, key2);
        // Different params should generate different key
        assert_ne!(key1, key3);
    }

    /// With zero recorded lookups the hit rate must be 0, not NaN.
    #[test]
    fn test_cache_stats_initialization() {
        let stats = CacheStats {
            hits: 0,
            misses: 0,
            evictions: 0,
            current_size: 0,
            max_size: 100,
        };
        assert_eq!(stats.hit_rate(), 0.0);
    }

    /// 75 hits out of 100 lookups is a 75% hit rate.
    #[test]
    fn test_cache_stats_hit_rate_calculation() {
        let stats = CacheStats {
            hits: 75,
            misses: 25,
            evictions: 5,
            current_size: 50,
            max_size: 100,
        };
        assert_eq!(stats.hit_rate(), 75.0);
    }

    /// Duplicate of the zero-lookup case with an explicit name, guarding
    /// the division-by-zero branch.
    #[test]
    fn test_cache_stats_hit_rate_no_requests() {
        let stats = CacheStats {
            hits: 0,
            misses: 0,
            evictions: 0,
            current_size: 0,
            max_size: 100,
        };
        assert_eq!(stats.hit_rate(), 0.0);
    }
}