Skip to main content
Glama

CodeGraph CLI MCP Server

by Jakedismo
provider_router.rs 9.28 kB
// ABOUTME: Routes LLM requests to appropriate providers based on LATS phase // ABOUTME: Supports different models for selection, expansion, evaluation, backpropagation use codegraph_ai::llm_provider::LLMProvider; use codegraph_mcp_core::config_manager::CodeGraphConfig; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum LATSPhase { Selection, Expansion, Evaluation, Backpropagation, } impl std::fmt::Display for LATSPhase { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Selection => write!(f, "selection"), Self::Expansion => write!(f, "expansion"), Self::Evaluation => write!(f, "evaluation"), Self::Backpropagation => write!(f, "backpropagation"), } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ProviderStats { pub default_provider: String, pub phase_providers: HashMap<String, String>, } /// Routes LLM requests to appropriate providers based on LATS phase /// /// The ProviderRouter maintains a mapping from LATS phases to LLM providers, /// enabling multi-provider configurations where different models can be used /// for different phases of the LATS algorithm. /// /// For Phase 2 implementation, this currently defaults to using the primary /// LLM provider for all phases. Future implementations will support reading /// phase-specific provider configurations from the config file. pub struct ProviderRouter { providers: HashMap<LATSPhase, Arc<dyn LLMProvider>>, default_provider: Arc<dyn LLMProvider>, } impl ProviderRouter { /// Create a new ProviderRouter /// /// For Phase 2, this uses the default_provider for all phases. /// Future implementation will read from config.llm.lats to create /// phase-specific providers. 
/// /// # Arguments /// * `_config` - Configuration (reserved for future use) /// * `default_provider` - The default LLM provider to use for all phases /// /// # Returns /// A new ProviderRouter instance pub fn new(_config: &CodeGraphConfig, default_provider: Arc<dyn LLMProvider>) -> Self { // Phase 2 implementation: Use default provider for all phases // TODO: In Phase 3, read config.llm.lats to create phase-specific providers: // if let Some(ref lats_config) = config.llm.lats { // // Create phase-specific providers from config // if let (Some(provider), Some(model)) = // (&lats_config.selection_provider, &lats_config.selection_model) // { // let llm_config = Self::create_phase_config(provider, model, config); // providers.insert( // LATSPhase::Selection, // create_llm_provider(&llm_config)? // ); // } // // Similar for Expansion, Evaluation, Backpropagation // } Self { providers: HashMap::new(), default_provider, } } /// Get the LLM provider for a specific LATS phase /// /// If a phase-specific provider is configured, it will be returned. /// Otherwise, the default provider is used. /// /// # Arguments /// * `phase` - The LATS phase requiring an LLM provider /// /// # Returns /// An Arc to the appropriate LLM provider pub fn get_provider(&self, phase: LATSPhase) -> Arc<dyn LLMProvider> { self.providers .get(&phase) .cloned() .unwrap_or_else(|| self.default_provider.clone()) } /// Get statistics about provider allocation /// /// Returns information about which providers are being used for which phases, /// useful for debugging and monitoring. 
pub fn stats(&self) -> ProviderStats { ProviderStats { default_provider: self.default_provider.provider_name().to_string(), phase_providers: self .providers .iter() .map(|(phase, provider)| { (format!("{}", phase), provider.provider_name().to_string()) }) .collect(), } } /// Check if a specific phase has a dedicated provider pub fn has_phase_provider(&self, phase: LATSPhase) -> bool { self.providers.contains_key(&phase) } /// Get the number of unique providers configured pub fn unique_provider_count(&self) -> usize { let mut unique_providers = std::collections::HashSet::new(); unique_providers.insert(self.default_provider.provider_name()); for provider in self.providers.values() { unique_providers.insert(provider.provider_name()); } unique_providers.len() } } #[cfg(test)] mod tests { use super::*; use async_trait::async_trait; use codegraph_ai::llm_provider::LLMProvider; use codegraph_mcp_core::config_manager::CodeGraphConfig; // Mock LLM provider for testing struct MockProvider { name: &'static str, } #[async_trait] impl LLMProvider for MockProvider { fn provider_name(&self) -> &str { self.name } fn model_name(&self) -> &str { "mock-model" } async fn is_available(&self) -> bool { true } fn characteristics(&self) -> codegraph_ai::llm_provider::ProviderCharacteristics { codegraph_ai::llm_provider::ProviderCharacteristics { max_tokens: 100000, avg_latency_ms: 100, rpm_limit: None, tpm_limit: None, supports_streaming: false, supports_functions: false, } } async fn generate_chat( &self, _messages: &[codegraph_ai::llm_provider::Message], _config: &codegraph_ai::llm_provider::GenerationConfig, ) -> codegraph_ai::llm_provider::LLMResult<codegraph_ai::llm_provider::LLMResponse> { Ok(codegraph_ai::llm_provider::LLMResponse { content: "mock response".to_string(), answer: String::new(), total_tokens: Some(10), prompt_tokens: Some(5), completion_tokens: Some(5), finish_reason: Some("stop".to_string()), model: "mock-model".to_string(), tool_calls: None, }) } } #[test] fn 
test_provider_router_creation() { let config = CodeGraphConfig::default(); let mock_provider = Arc::new(MockProvider { name: "mock" }) as Arc<dyn LLMProvider>; let router = ProviderRouter::new(&config, mock_provider); // Should use default provider for all phases assert_eq!( router.get_provider(LATSPhase::Selection).provider_name(), "mock" ); assert_eq!( router.get_provider(LATSPhase::Expansion).provider_name(), "mock" ); assert_eq!( router.get_provider(LATSPhase::Evaluation).provider_name(), "mock" ); assert_eq!( router .get_provider(LATSPhase::Backpropagation) .provider_name(), "mock" ); } #[test] fn test_provider_stats() { let config = CodeGraphConfig::default(); let mock_provider = Arc::new(MockProvider { name: "test-provider", }) as Arc<dyn LLMProvider>; let router = ProviderRouter::new(&config, mock_provider); let stats = router.stats(); assert_eq!(stats.default_provider, "test-provider"); assert_eq!(stats.phase_providers.len(), 0); // No phase-specific providers in Phase 2 } #[test] fn test_has_phase_provider() { let config = CodeGraphConfig::default(); let mock_provider = Arc::new(MockProvider { name: "mock" }) as Arc<dyn LLMProvider>; let router = ProviderRouter::new(&config, mock_provider); // Phase 2: No phase-specific providers assert!(!router.has_phase_provider(LATSPhase::Selection)); assert!(!router.has_phase_provider(LATSPhase::Expansion)); assert!(!router.has_phase_provider(LATSPhase::Evaluation)); assert!(!router.has_phase_provider(LATSPhase::Backpropagation)); } #[test] fn test_unique_provider_count() { let config = CodeGraphConfig::default(); let mock_provider = Arc::new(MockProvider { name: "mock" }) as Arc<dyn LLMProvider>; let router = ProviderRouter::new(&config, mock_provider); // Phase 2: Only default provider assert_eq!(router.unique_provider_count(), 1); } #[test] fn test_lats_phase_display() { assert_eq!(format!("{}", LATSPhase::Selection), "selection"); assert_eq!(format!("{}", LATSPhase::Expansion), "expansion"); 
assert_eq!(format!("{}", LATSPhase::Evaluation), "evaluation"); assert_eq!(format!("{}", LATSPhase::Backpropagation), "backpropagation"); } }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Jakedismo/codegraph-rust'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.