loaders.rs•13.9 kB
use async_graphql::dataloader::{DataLoader, Loader};
use async_trait::async_trait;
use codegraph_core::{CodeGraphError, CodeNode, NodeId};
use codegraph_vector::rag::{ContextRetriever, RetrievalResult};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{debug, instrument, warn};
use uuid::Uuid;
use crate::graphql::types::{GraphQLCodeNode, GraphQLEdge};
use crate::state::AppState;
/// DataLoader for efficient batch loading of code nodes by ID
pub struct NodeLoader {
pub state: Arc<AppState>,
}
#[async_trait]
impl Loader<NodeId> for NodeLoader {
type Value = GraphQLCodeNode;
type Error = CodeGraphError;
#[instrument(skip(self, keys), fields(batch_size = keys.len()))]
async fn load(&self, keys: &[NodeId]) -> std::result::Result<HashMap<NodeId, GraphQLCodeNode>, CodeGraphError> {
let start_time = Instant::now();
debug!("Loading batch of {} nodes", keys.len());
// Deduplicate keys to avoid redundant database queries
let unique_keys: HashSet<NodeId> = keys.iter().cloned().collect();
let mut result_map = HashMap::new();
// In a real implementation, this would be a batch query to RocksDB
// For now, we'll simulate batch loading
let graph = self.state.graph.read().await;
for &node_id in &unique_keys {
// Simulate database batch query - in real implementation this would be optimized
// to fetch multiple nodes in a single RocksDB batch operation
if let Some(node) = self.simulate_node_fetch(node_id).await? {
result_map.insert(node_id, node.into());
}
}
let elapsed = start_time.elapsed();
debug!(
"Loaded {} nodes in {}ms",
result_map.len(),
elapsed.as_millis()
);
// Log performance warning if batch takes too long
if elapsed.as_millis() > 50 {
warn!(
"Node batch loading took {}ms (>50ms threshold)",
elapsed.as_millis()
);
}
Ok(result_map)
}
}
impl NodeLoader {
pub fn new(state: Arc<AppState>) -> Self {
Self { state }
}
async fn simulate_node_fetch(
&self,
node_id: NodeId,
) -> std::result::Result<Option<CodeNode>, CodeGraphError> {
// Simulate fetching from database
// In real implementation, this would be a RocksDB get operation
// For demonstration, we create mock nodes
let now = chrono::Utc::now();
let node = CodeNode {
id: node_id,
name: format!("node_{}", node_id.simple().to_string()[..8].to_string()),
node_type: Some(codegraph_core::NodeType::Function),
language: Some(codegraph_core::Language::Rust),
content: Some("fn example() {}".to_string()),
embedding: None,
location: codegraph_core::Location {
file_path: "test.rs".to_string(),
line: 1,
column: 1,
end_line: None,
end_column: None,
},
metadata: codegraph_core::Metadata {
attributes: HashMap::new(),
created_at: now,
updated_at: now,
},
complexity: Some(1.0),
};
Ok(Some(node))
}
}
/// DataLoader for batch loading edges by source node ID
pub struct EdgesBySourceLoader {
pub state: Arc<AppState>,
}
#[async_trait]
impl Loader<NodeId> for EdgesBySourceLoader {
type Value = Vec<GraphQLEdge>;
type Error = CodeGraphError;
#[instrument(skip(self, keys), fields(batch_size = keys.len()))]
async fn load(&self, keys: &[NodeId]) -> std::result::Result<HashMap<NodeId, Vec<GraphQLEdge>>, CodeGraphError> {
let start_time = Instant::now();
debug!("Loading edges for {} source nodes", keys.len());
let unique_keys: HashSet<NodeId> = keys.iter().cloned().collect();
let mut result_map = HashMap::new();
// Batch query for edges - in real implementation this would be optimized
for &source_id in &unique_keys {
let edges = self.simulate_edge_fetch(source_id).await?;
result_map.insert(source_id, edges);
}
let elapsed = start_time.elapsed();
debug!(
"Loaded edges for {} sources in {}ms",
result_map.len(),
elapsed.as_millis()
);
if elapsed.as_millis() > 50 {
warn!(
"Edge batch loading took {}ms (>50ms threshold)",
elapsed.as_millis()
);
}
Ok(result_map)
}
}
impl EdgesBySourceLoader {
pub fn new(state: Arc<AppState>) -> Self {
Self { state }
}
async fn simulate_edge_fetch(
&self,
source_id: NodeId,
) -> std::result::Result<Vec<GraphQLEdge>, CodeGraphError> {
// Simulate fetching edges from database
let now = chrono::Utc::now();
let edges = vec![GraphQLEdge {
id: async_graphql::ID(Uuid::new_v4().to_string()),
source_id: async_graphql::ID(source_id.to_string()),
target_id: async_graphql::ID(Uuid::new_v4().to_string()),
edge_type: crate::graphql::types::GraphQLEdgeType::Calls,
weight: Some(1.0),
attributes: HashMap::new(),
created_at: now,
}];
Ok(edges)
}
}
/// DataLoader for semantic search results
pub struct SemanticSearchLoader {
pub state: Arc<AppState>,
}
#[async_trait]
impl Loader<String> for SemanticSearchLoader {
type Value = Vec<RetrievalResult>;
type Error = CodeGraphError;
#[instrument(skip(self, keys), fields(batch_size = keys.len()))]
async fn load(&self, keys: &[String]) -> std::result::Result<HashMap<String, Vec<RetrievalResult>>, CodeGraphError> {
let start_time = Instant::now();
debug!("Semantic search batch for {} queries", keys.len());
let mut result_map = HashMap::new();
// Batch semantic search - leverage the RAG system's context retriever
for query in keys {
let results = self
.state
.semantic_search
.search_by_text(query, 10)
.await
.map_err(|_| CodeGraphError::Vector("Semantic search failed".to_string()))?;
// Convert semantic search results to retrieval results
let retrieval_results: Vec<RetrievalResult> = results
.into_iter()
.map(|res| {
let snippet = res
.node
.as_ref()
.and_then(|n| n.content.as_ref())
.map(|s| {
let s = s.as_str();
s[..100.min(s.len())].to_string()
})
.unwrap_or_default();
RetrievalResult {
node_id: res.node_id,
node: res.node,
relevance_score: res.score,
retrieval_method: codegraph_vector::rag::RetrievalMethod::SemanticSimilarity,
context_snippet: snippet,
}
})
.collect();
result_map.insert(query.clone(), retrieval_results);
}
let elapsed = start_time.elapsed();
debug!(
"Semantic search batch completed in {}ms",
elapsed.as_millis()
);
if elapsed.as_millis() > 100 {
warn!(
"Semantic search batch took {}ms (>100ms threshold)",
elapsed.as_millis()
);
}
Ok(result_map)
}
}
impl SemanticSearchLoader {
pub fn new(state: Arc<AppState>) -> Self {
Self { state }
}
}
/// DataLoader for graph traversal caching
pub struct GraphTraversalLoader {
pub state: Arc<AppState>,
}
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub struct TraversalKey {
pub start_node: NodeId,
pub max_depth: i32,
pub edge_types: Vec<String>, // Serialized edge types for hashing
pub direction: String,
}
#[async_trait]
impl Loader<TraversalKey> for GraphTraversalLoader {
type Value = Vec<GraphQLCodeNode>;
type Error = CodeGraphError;
#[instrument(skip(self, keys), fields(batch_size = keys.len()))]
async fn load(
&self,
keys: &[TraversalKey],
) -> std::result::Result<HashMap<TraversalKey, Vec<GraphQLCodeNode>>, CodeGraphError> {
let start_time = Instant::now();
debug!("Graph traversal batch for {} queries", keys.len());
let mut result_map = HashMap::new();
for key in keys {
// Perform graph traversal - in real implementation this would use
// the actual graph data structure with optimized traversal algorithms
let traversal_nodes = self.simulate_traversal(key).await?;
result_map.insert(key.clone(), traversal_nodes);
}
let elapsed = start_time.elapsed();
debug!(
"Graph traversal batch completed in {}ms",
elapsed.as_millis()
);
if elapsed.as_millis() > 200 {
warn!(
"Graph traversal batch took {}ms (>200ms threshold)",
elapsed.as_millis()
);
}
Ok(result_map)
}
}
impl GraphTraversalLoader {
pub fn new(state: Arc<AppState>) -> Self {
Self { state }
}
async fn simulate_traversal(
&self,
key: &TraversalKey,
) -> std::result::Result<Vec<GraphQLCodeNode>, CodeGraphError> {
// Simulate graph traversal with mock data
let now = chrono::Utc::now();
let mut nodes = vec![];
for i in 0..key.max_depth.min(5) as usize {
let node = GraphQLCodeNode {
id: async_graphql::ID(Uuid::new_v4().to_string()),
name: format!("traversed_node_{}", i),
node_type: Some(crate::graphql::types::GraphQLNodeType::Function),
language: Some(crate::graphql::types::GraphQLLanguage::Rust),
location: crate::graphql::types::GraphQLLocation {
file_path: "traversal.rs".to_string(),
line: i as u32 + 1,
column: 1,
end_line: None,
end_column: None,
},
content: Some(format!("fn traversed_function_{}() {{}}", i)),
complexity: Some(1.0 + i as f32 * 0.1),
created_at: now,
updated_at: now,
attributes: HashMap::new(),
};
nodes.push(node);
}
Ok(nodes)
}
}
/// Centralized DataLoader factory for creating all loaders with shared state
pub struct LoaderFactory {
state: Arc<AppState>,
}
impl LoaderFactory {
pub fn new(state: Arc<AppState>) -> Self {
Self { state }
}
pub fn create_node_loader(&self) -> DataLoader<NodeLoader> {
DataLoader::new(NodeLoader::new(self.state.clone()), tokio::spawn)
.max_batch_size(100) // Batch up to 100 nodes at once
.delay(std::time::Duration::from_millis(1)) // 1ms delay for batching
}
pub fn create_edges_loader(&self) -> DataLoader<EdgesBySourceLoader> {
DataLoader::new(EdgesBySourceLoader::new(self.state.clone()), tokio::spawn)
.max_batch_size(50)
.delay(std::time::Duration::from_millis(1))
}
pub fn create_semantic_search_loader(&self) -> DataLoader<SemanticSearchLoader> {
DataLoader::new(SemanticSearchLoader::new(self.state.clone()), tokio::spawn)
.max_batch_size(20) // Semantic search is more expensive
.delay(std::time::Duration::from_millis(5)) // Slightly longer delay for batching
}
pub fn create_traversal_loader(&self) -> DataLoader<GraphTraversalLoader> {
DataLoader::new(GraphTraversalLoader::new(self.state.clone()), tokio::spawn)
.max_batch_size(10) // Traversals are expensive
.delay(std::time::Duration::from_millis(2))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::state::AppState;
#[tokio::test]
async fn test_node_loader_batch() {
let state = Arc::new(AppState::new().await.unwrap());
let loader = NodeLoader::new(state);
let node_ids = vec![Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()];
let result = loader.load(&node_ids).await;
assert!(result.is_ok());
let nodes = result.unwrap();
assert!(nodes.len() <= node_ids.len()); // May be less due to filtering
}
#[tokio::test]
async fn test_semantic_search_loader_batch() {
let state = Arc::new(AppState::new().await.unwrap());
let loader = SemanticSearchLoader::new(state);
let queries = vec![
"function implementation".to_string(),
"error handling".to_string(),
];
let result = loader.load(&queries).await;
assert!(result.is_ok());
let search_results = result.unwrap();
assert_eq!(search_results.len(), queries.len());
}
#[tokio::test]
async fn test_loader_factory() {
let state = Arc::new(AppState::new().await.unwrap());
let factory = LoaderFactory::new(state);
let _node_loader = factory.create_node_loader();
let _edges_loader = factory.create_edges_loader();
let _semantic_loader = factory.create_semantic_search_loader();
let _traversal_loader = factory.create_traversal_loader();
// Test that loaders are created successfully
assert!(true);
}
}