// embedding_provider_tests.rs (11.1 kB)
use codegraph_core::{CodeNode, Language, Location, NodeType};
use codegraph_vector::{BatchConfig, EmbeddingProvider, FallbackStrategy, HybridEmbeddingPipeline};
use std::time::Duration;
/// Builds a `CodeNode` fixture for the embedding tests.
///
/// The node is a Rust function called `name`, located in `test.rs` at line 1,
/// spanning columns 1 to 10. When `content` is `Some`, it is attached through
/// `with_content`; otherwise the node is returned as constructed.
fn create_test_node(name: &str, content: Option<String>) -> CodeNode {
    let location = Location {
        file_path: "test.rs".to_string(),
        line: 1,
        column: 1,
        end_line: Some(1),
        end_column: Some(10),
    };
    let node = CodeNode::new(
        name.to_string(),
        Some(NodeType::Function),
        Some(Language::Rust),
        location,
    );
    match content {
        Some(text) => node.with_content(text),
        None => node,
    }
}
/// Configurable embedding provider double used by the tests in this file.
struct MockEmbeddingProvider {
    /// Name this provider reports about itself.
    name: String,
    /// Number of components in every generated embedding.
    dimension: usize,
    /// When `true`, embedding generation is expected to fail.
    fail_on_generate: bool,
    /// Artificial delay applied to each embedding request.
    latency: Duration,
}

impl MockEmbeddingProvider {
    /// Creates a provider that succeeds and uses a 10 ms latency.
    fn new(name: &str, dimension: usize) -> Self {
        let name = name.to_string();
        let latency = Duration::from_millis(10);
        Self {
            name,
            dimension,
            fail_on_generate: false,
            latency,
        }
    }

    /// Returns the same provider with failure mode switched on.
    fn with_failure(self) -> Self {
        Self {
            fail_on_generate: true,
            ..self
        }
    }

    /// Returns the same provider with its latency replaced by `latency`.
    fn with_latency(self, latency: Duration) -> Self {
        Self { latency, ..self }
    }
}
// `EmbeddingProvider` implementation for the mock: embeddings are derived
// deterministically from the node name, so equal names give equal vectors.
#[async_trait::async_trait]
impl EmbeddingProvider for MockEmbeddingProvider {
    /// Produces one L2-normalized embedding of `self.dimension` components.
    ///
    /// Sleeps for the configured latency first, then returns
    /// `CodeGraphError::External` if the provider is in failure mode.
    /// Otherwise the vector is generated from a hash of `node.name`.
    async fn generate_embedding(&self, node: &CodeNode) -> codegraph_core::Result<Vec<f32>> {
        tokio::time::sleep(self.latency).await;
        if self.fail_on_generate {
            return Err(codegraph_core::CodeGraphError::External(
                "Mock provider failure".to_string(),
            ));
        }

        // Seed a linear congruential generator with the name hash so the
        // output is deterministic but varies between names.
        let mut rng_state = simple_hash(&node.name);
        let mut embedding: Vec<f32> = (0..self.dimension)
            .map(|_| {
                rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
                // Map the 32-bit state onto roughly [-1.0, 1.0].
                ((rng_state as f32 / u32::MAX as f32) - 0.5) * 2.0
            })
            .collect();

        // L2 normalize; skipped when the norm is zero (e.g. dimension 0) to
        // avoid dividing by zero.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut embedding {
                *x /= norm;
            }
        }
        Ok(embedding)
    }

    /// Embeds `nodes` one after another, preserving input order.
    ///
    /// Stops at, and returns, the first error from `generate_embedding`.
    async fn generate_embeddings(
        &self,
        nodes: &[CodeNode],
    ) -> codegraph_core::Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(nodes.len());
        for node in nodes {
            embeddings.push(self.generate_embedding(node).await?);
        }
        Ok(embeddings)
    }

    /// Same as `generate_embeddings`, additionally returning metrics built
    /// from the provider name, the node count and the elapsed wall-clock time.
    ///
    /// The batch configuration is ignored by this mock.
    async fn generate_embeddings_with_config(
        &self,
        nodes: &[CodeNode],
        _config: &BatchConfig,
    ) -> codegraph_core::Result<(Vec<Vec<f32>>, codegraph_vector::EmbeddingMetrics)> {
        let start = std::time::Instant::now();
        let embeddings = self.generate_embeddings(nodes).await?;
        let duration = start.elapsed();
        let metrics =
            codegraph_vector::EmbeddingMetrics::new(self.name.clone(), nodes.len(), duration);
        Ok((embeddings, metrics))
    }

    /// Number of components in each generated embedding.
    fn embedding_dimension(&self) -> usize {
        self.dimension
    }

    /// Name given to the provider at construction.
    fn provider_name(&self) -> &str {
        &self.name
    }

    /// The mock reports itself available exactly when it is not in failure mode.
    async fn is_available(&self) -> bool {
        !self.fail_on_generate
    }

    /// Fixed characteristics for the mock; only `typical_latency` reflects
    /// the configured latency.
    fn performance_characteristics(&self) -> codegraph_vector::ProviderCharacteristics {
        codegraph_vector::ProviderCharacteristics {
            expected_throughput: 100.0,
            typical_latency: self.latency,
            max_batch_size: 32,
            supports_streaming: false,
            requires_network: false,
            memory_usage: codegraph_vector::MemoryUsage::Low,
        }
    }
}
/// Computes a 32-bit djb2-style hash of `text`.
///
/// Starts from 5381 and, for each byte, multiplies the running value by 33
/// and adds the byte, using wrapping `u32` arithmetic throughout.
fn simple_hash(text: &str) -> u32 {
    text.bytes().fold(5381u32, |acc, byte| {
        acc.wrapping_mul(33).wrapping_add(u32::from(byte))
    })
}
/// A single embedding has the configured length and unit L2 norm.
#[tokio::test]
async fn test_mock_provider_single_embedding() {
    let node = create_test_node("test_function", Some("fn test() {}".to_string()));
    let provider = MockEmbeddingProvider::new("test", 384);

    let embedding = provider.generate_embedding(&node).await.unwrap();
    assert_eq!(embedding.len(), 384);

    // The vector must come back L2 normalized.
    let squared_sum: f32 = embedding.iter().map(|v| v * v).sum();
    let norm = squared_sum.sqrt();
    assert!(
        (norm - 1.0).abs() < 1e-5,
        "Embedding should be L2 normalized"
    );
}
/// Batch generation returns one normalized embedding per node, and nodes
/// with different names get different embeddings.
#[tokio::test]
async fn test_mock_provider_batch_embeddings() {
    let provider = MockEmbeddingProvider::new("test", 384);
    let nodes = vec![
        create_test_node("function1", Some("fn one() {}".to_string())),
        create_test_node("function2", Some("fn two() {}".to_string())),
        create_test_node("function3", Some("fn three() {}".to_string())),
    ];

    let embeddings = provider.generate_embeddings(&nodes).await.unwrap();
    assert_eq!(embeddings.len(), nodes.len());
    assert_eq!(embeddings.len(), 3);

    for embedding in &embeddings {
        assert_eq!(embedding.len(), 384);
        // Check L2 normalization
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "Embedding should be L2 normalized"
        );
    }

    // Embeddings should be different for different inputs. Compare every
    // pair, not just neighbours, so that (0, 2) is covered as well.
    for (i, first) in embeddings.iter().enumerate() {
        for (j, second) in embeddings.iter().enumerate().skip(i + 1) {
            assert_ne!(
                first, second,
                "Embeddings {} and {} should differ for different inputs",
                i, j
            );
        }
    }
}
/// Embedding the same node twice yields identical vectors.
#[tokio::test]
async fn test_deterministic_embeddings() {
    let node = create_test_node("test_function", Some("fn test() {}".to_string()));
    let provider = MockEmbeddingProvider::new("test", 384);

    let first = provider.generate_embedding(&node).await.unwrap();
    let second = provider.generate_embedding(&node).await.unwrap();

    // Identical input must map to an identical embedding.
    assert_eq!(first, second);
}
/// The config-aware batch call returns the embeddings together with metrics
/// carrying the provider name, the processed count and a positive throughput.
#[tokio::test]
async fn test_batch_config_metrics() {
    let config = BatchConfig {
        batch_size: 10,
        max_concurrent: 2,
        timeout: Duration::from_secs(5),
        retry_attempts: 3,
    };
    let nodes = vec![
        create_test_node("function1", None),
        create_test_node("function2", None),
    ];
    let provider = MockEmbeddingProvider::new("test", 384);

    let outcome = provider
        .generate_embeddings_with_config(&nodes, &config)
        .await;
    let (embeddings, metrics) = outcome.unwrap();

    assert_eq!(embeddings.len(), 2);
    assert_eq!(metrics.texts_processed, 2);
    assert_eq!(metrics.provider_name, "test");
    assert!(metrics.throughput > 0.0);
}
/// A provider in failure mode returns an error and reports itself unavailable.
#[tokio::test]
async fn test_provider_failure_handling() {
    let node = create_test_node("test_function", None);
    let failing = MockEmbeddingProvider::new("failing", 384).with_failure();

    let outcome = failing.generate_embedding(&node).await;
    assert!(outcome.is_err());

    let available = failing.is_available().await;
    assert!(!available);
}
/// With a healthy primary, a sequential pipeline yields an embedding of the
/// expected size and reports its provider name as `HybridPipeline`.
#[tokio::test]
async fn test_hybrid_pipeline_primary_success() {
    let pipeline = HybridEmbeddingPipeline::new(
        Box::new(MockEmbeddingProvider::new("primary", 384)),
        FallbackStrategy::Sequential,
    )
    .add_fallback(Box::new(MockEmbeddingProvider::new("fallback", 384)));

    let node = create_test_node("test_function", None);
    let embedding = pipeline.generate_embedding(&node).await.unwrap();

    assert_eq!(embedding.len(), 384);
    assert_eq!(pipeline.provider_name(), "HybridPipeline");
}
/// With a failing primary, the sequential pipeline is expected to still
/// produce an embedding of the right size via its fallback provider.
#[tokio::test]
async fn test_hybrid_pipeline_fallback_on_failure() {
    let failing_primary = MockEmbeddingProvider::new("primary", 384).with_failure();
    let healthy_fallback = MockEmbeddingProvider::new("fallback", 384);

    let pipeline = HybridEmbeddingPipeline::new(
        Box::new(failing_primary),
        FallbackStrategy::Sequential,
    )
    .add_fallback(Box::new(healthy_fallback));

    let node = create_test_node("test_function", None);
    let embedding = pipeline.generate_embedding(&node).await.unwrap();

    assert_eq!(embedding.len(), 384);
}
/// With `FastestFirst`, a 100 ms primary and a 10 ms fallback, the embedding
/// is expected to arrive in under 50 ms of wall-clock time.
#[tokio::test]
async fn test_hybrid_pipeline_fastest_first_strategy() {
    let slow_latency = Duration::from_millis(100);
    let fast_latency = Duration::from_millis(10);
    let slow_primary = MockEmbeddingProvider::new("slow", 384).with_latency(slow_latency);
    let fast_fallback = MockEmbeddingProvider::new("fast", 384).with_latency(fast_latency);

    let pipeline = HybridEmbeddingPipeline::new(
        Box::new(slow_primary),
        FallbackStrategy::FastestFirst,
    )
    .add_fallback(Box::new(fast_fallback));

    let node = create_test_node("test_function", None);

    // Time only the embedding call itself.
    let started = std::time::Instant::now();
    let embedding = pipeline.generate_embedding(&node).await.unwrap();
    let elapsed = started.elapsed();

    assert_eq!(embedding.len(), 384);
    // Finishing this quickly is only possible via the faster provider.
    assert!(elapsed < Duration::from_millis(50));
}
/// Embedding an empty batch succeeds and yields no embeddings.
#[tokio::test]
async fn test_empty_input_handling() {
    let no_nodes: Vec<CodeNode> = Vec::new();
    let provider = MockEmbeddingProvider::new("test", 384);

    let result = provider.generate_embeddings(&no_nodes).await;
    let embeddings = result.unwrap();

    assert!(embeddings.is_empty());
}
/// Measures throughput of the mock over 100 nodes at 5 ms per embedding.
#[tokio::test]
async fn test_performance_target_throughput() {
    // 5ms of simulated latency per embedding.
    let per_item_latency = Duration::from_millis(5);
    let provider = MockEmbeddingProvider::new("test", 384).with_latency(per_item_latency);

    // Build 100 distinct test nodes.
    let mut nodes: Vec<CodeNode> = Vec::with_capacity(100);
    for i in 0..100 {
        nodes.push(create_test_node(&format!("function_{}", i), None));
    }

    let started = std::time::Instant::now();
    let embeddings = provider.generate_embeddings(&nodes).await.unwrap();
    let elapsed = started.elapsed();
    assert_eq!(embeddings.len(), 100);

    let throughput = nodes.len() as f64 / elapsed.as_secs_f64();
    println!("Achieved throughput: {:.2} texts/s", throughput);

    // Stated target: ≥100 texts/s. At 5ms each, sequential processing tops
    // out around 200 texts/s.
    // NOTE(review): the assertion floor is 50 texts/s, below the stated
    // target — presumably to absorb timer jitter; confirm this is intended.
    assert!(
        throughput >= 50.0,
        "Throughput too low: {:.2} texts/s",
        throughput
    );
}
/// Pushes 1000 nodes through the config-aware batch call and checks the
/// count, the metrics and the 30 s time budget.
#[tokio::test]
async fn test_large_batch_processing() {
    let config = BatchConfig {
        batch_size: 32,
        max_concurrent: 4,
        timeout: Duration::from_secs(30),
        retry_attempts: 3,
    };
    let provider = MockEmbeddingProvider::new("test", 384);

    // 1000 distinct nodes for the batch run.
    let mut nodes: Vec<CodeNode> = Vec::with_capacity(1000);
    for i in 0..1000 {
        nodes.push(create_test_node(&format!("node_{}", i), None));
    }

    let started = std::time::Instant::now();
    let outcome = provider
        .generate_embeddings_with_config(&nodes, &config)
        .await;
    let (embeddings, metrics) = outcome.unwrap();
    let duration = started.elapsed();

    assert_eq!(embeddings.len(), 1000);
    assert_eq!(metrics.texts_processed, 1000);

    // Budget: 1k texts must finish within 30s.
    let budget = Duration::from_secs(30);
    assert!(
        duration <= budget,
        "Batch processing too slow: {:?}",
        duration
    );

    println!(
        "1000 text batch processed in {:?} ({:.2} texts/s)",
        duration, metrics.throughput
    );
}