CodeGraph CLI MCP Server

by Jakedismo
onnx_provider.rs • 16.7 kB
#[cfg(feature = "onnx")]
use crate::providers::{
    BatchConfig, EmbeddingMetrics, EmbeddingProvider, MemoryUsage, ProviderCharacteristics,
};
#[cfg(feature = "onnx")]
use async_trait::async_trait;
#[cfg(feature = "onnx")]
use codegraph_core::{CodeGraphError, CodeNode, Result};
#[cfg(feature = "onnx")]
use hf_hub::api::tokio::Api;
#[cfg(feature = "onnx")]
use ndarray::{s, Array, Array2, Axis};
#[cfg(feature = "onnx")]
use ort::execution_providers::CoreMLExecutionProvider;
#[cfg(feature = "onnx")]
use ort::session::builder::GraphOptimizationLevel;
#[cfg(feature = "onnx")]
use ort::session::Session;
#[cfg(feature = "onnx")]
use ort::value::Value;
#[cfg(feature = "onnx")]
use parking_lot::Mutex;
#[cfg(feature = "onnx")]
use std::sync::Arc;
#[cfg(feature = "onnx")]
use std::time::Instant;
#[cfg(feature = "onnx")]
use tokenizers::Tokenizer;

#[cfg(feature = "onnx")]
#[derive(Debug, Clone)]
pub struct OnnxConfig {
    pub model_repo: String,         // HF repo id or path
    pub model_file: Option<String>, // specific ONNX filename if not default
    pub max_sequence_length: usize,
    pub pooling: OnnxPooling,
}

#[cfg(feature = "onnx")]
#[derive(Debug, Clone)]
pub enum OnnxPooling {
    Cls,
    Mean,
    Max,
}

#[cfg(feature = "onnx")]
pub struct OnnxEmbeddingProvider {
    session: Arc<Mutex<Session>>,
    tokenizer: Arc<Tokenizer>,
    hidden_size: usize,
    config: OnnxConfig,
}

#[cfg(feature = "onnx")]
impl OnnxEmbeddingProvider {
    pub async fn new(config: OnnxConfig) -> Result<Self> {
        use std::path::Path;

        // Resolve files from a local path when it exists, otherwise via the HF Hub.
        let (tokenizer_path, model_path) = if Path::new(&config.model_repo).exists() {
            let base = Path::new(&config.model_repo);
            let tok = base.join("tokenizer.json");
            let candidates: Vec<String> = if let Some(mf) = &config.model_file {
                vec![
                    mf.clone(),
                    "onnx/model.onnx".into(),
                    "model.onnx".into(),
                    "model_fp16.onnx".into(),
                ]
            } else {
                vec![
                    "model.onnx".into(),
                    "onnx/model.onnx".into(),
                    "model_fp16.onnx".into(),
                ]
            };
            let mut found = None;
            for cand in candidates {
                let p = base.join(&cand);
                if p.exists() {
                    found = Some(p);
                    break;
                }
            }
            let model = found.ok_or_else(|| {
                CodeGraphError::External(
                    "No ONNX model file found in local repo path (tried model.onnx, onnx/model.onnx, model_fp16.onnx)"
                        .into(),
                )
            })?;
            (tok, model)
        } else {
            let api = Api::new().map_err(|e| CodeGraphError::External(e.to_string()))?;
            let repo = api.model(config.model_repo.clone());
            let tok = repo
                .get("tokenizer.json")
                .await
                .map_err(|e| CodeGraphError::External(e.to_string()))?;
            let candidates: Vec<String> = if let Some(mf) = &config.model_file {
                vec![
                    mf.clone(),
                    "onnx/model.onnx".into(),
                    "model.onnx".into(),
                    "model_fp16.onnx".into(),
                ]
            } else {
                vec![
                    "model.onnx".into(),
                    "onnx/model.onnx".into(),
                    "model_fp16.onnx".into(),
                ]
            };
            let mut model_opt = None;
            for cand in candidates {
                match repo.get(&cand).await {
                    Ok(p) => {
                        model_opt = Some(p);
                        break;
                    }
                    Err(e) => {
                        tracing::debug!("ONNX model candidate '{}' not found: {}", cand, e);
                    }
                }
            }
            let model = model_opt.ok_or_else(|| {
                CodeGraphError::External(
                    "No ONNX model file found in HF repo (tried model.onnx, onnx/model.onnx, model_fp16.onnx)"
                        .into(),
                )
            })?;
            (tok, model)
        };

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| CodeGraphError::External(e.to_string()))?;

        // Decide execution providers based on env
        let ep = std::env::var("CODEGRAPH_ONNX_EP")
            .unwrap_or_else(|_| "cpu".into())
            .to_lowercase();
        let mut session_builder = Session::builder()
            .map_err(|e| CodeGraphError::External(e.to_string()))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| CodeGraphError::External(e.to_string()))?;
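        // Illustrative note (not in the original source): the execution provider is
        // selected purely through the environment, e.g.
        //   CODEGRAPH_ONNX_EP=coreml <your-binary>
        // Any other value, or an unset variable, keeps the default CPU provider.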
        // Register CoreML EP when requested, fall back to CPU if unavailable
        if ep == "coreml" {
            #[cfg(target_os = "macos")]
            {
                session_builder = session_builder
                    .with_execution_providers([CoreMLExecutionProvider::default().build()])
                    .map_err(|e| CodeGraphError::External(e.to_string()))?;
                tracing::info!("Using ONNX Runtime CoreML execution provider");
            }
            #[cfg(not(target_os = "macos"))]
            {
                tracing::warn!(
                    "CODEGRAPH_ONNX_EP=coreml set, but CoreML registration not supported on this platform; using CPU."
                );
            }
        }
        let session = session_builder
            .commit_from_file(&model_path)
            .map_err(|e| CodeGraphError::External(e.to_string()))?;

        // Hidden size discovery (fallback to 768)
        let hidden_size = 768;

        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            tokenizer: Arc::new(tokenizer),
            hidden_size,
            config,
        })
    }

    fn prepare_inputs(
        &self,
        texts: &[String],
    ) -> Result<(Array2<i64>, Array2<i64>, Array2<i64>, usize)> {
        let mut ids: Vec<Vec<i64>> = Vec::with_capacity(texts.len());
        let mut mask: Vec<Vec<i64>> = Vec::with_capacity(texts.len());
        let mut type_ids: Vec<Vec<i64>> = Vec::with_capacity(texts.len());
        let mut max_len = 0usize;
        for t in texts {
            let enc = self
                .tokenizer
                .encode(t.as_str(), true)
                .map_err(|e| CodeGraphError::External(e.to_string()))?;
            let mut tid: Vec<i64> = enc.get_ids().iter().map(|&x| x as i64).collect();
            let mut m: Vec<i64> = enc
                .get_attention_mask()
                .iter()
                .map(|&x| if x > 0 { 1 } else { 0 })
                .collect();
            if tid.len() > self.config.max_sequence_length {
                tid.truncate(self.config.max_sequence_length);
                m.truncate(self.config.max_sequence_length);
            }
            max_len = max_len.max(tid.len());
            // token_type_ids default to zeros for single-sequence inputs
            let tt: Vec<i64> = std::iter::repeat(0).take(tid.len()).collect();
            ids.push(tid);
            mask.push(m);
            type_ids.push(tt);
        }
        // Right-pad every sequence to the longest in the batch
        for ((tid, m), tt) in ids.iter_mut().zip(mask.iter_mut()).zip(type_ids.iter_mut()) {
            if tid.len() < max_len {
                let pad = max_len - tid.len();
                tid.extend(std::iter::repeat(0).take(pad));
                m.extend(std::iter::repeat(0).take(pad));
                tt.extend(std::iter::repeat(0).take(pad));
            }
        }
        // Build ndarray [B, L]
        let b = texts.len();
        let l = max_len;
        let mut arr_ids = Array2::<i64>::zeros((b, l));
        let mut arr_mask = Array2::<i64>::zeros((b, l));
        let mut arr_type_ids = Array2::<i64>::zeros((b, l));
        for i in 0..b {
            let tid = &ids[i];
            let m = &mask[i];
            let tt = &type_ids[i];
            arr_ids
                .slice_mut(s![i, ..])
                .assign(&ndarray::ArrayView1::from(tid.as_slice()));
            arr_mask
                .slice_mut(s![i, ..])
                .assign(&ndarray::ArrayView1::from(m.as_slice()));
            arr_type_ids
                .slice_mut(s![i, ..])
                .assign(&ndarray::ArrayView1::from(tt.as_slice()));
        }
        Ok((arr_ids, arr_mask, arr_type_ids, l))
    }

    fn pool_and_normalize(
        &self,
        last_hidden: Array2<f32>,
        _mask: &Array2<i64>,
    ) -> Result<Vec<Vec<f32>>> {
        // By the time we get here the caller has already reduced any [B, L, H]
        // output to [B, H] (see generate_embeddings_with_config), so this
        // minimal implementation treats last_hidden as [B, H].
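        // Each row is scaled to unit length (v / max(||v||_2, 1e-12)), so downstream
        // cosine similarity reduces to a plain dot product.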
        let b = last_hidden.len_of(Axis(0));
        let mut out = Vec::with_capacity(b);
        for i in 0..b {
            let mut v = last_hidden.slice(s![i, ..]).to_owned().to_vec();
            // L2 normalize
            let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
            for x in &mut v {
                *x /= norm;
            }
            out.push(v);
        }
        Ok(out)
    }
}

#[cfg(feature = "onnx")]
#[async_trait]
impl EmbeddingProvider for OnnxEmbeddingProvider {
    async fn generate_embedding(&self, node: &CodeNode) -> Result<Vec<f32>> {
        let (embs, _) = self
            .generate_embeddings_with_config(&[node.clone()], &BatchConfig::default())
            .await?;
        Ok(embs.into_iter().next().unwrap_or_default())
    }

    async fn generate_embeddings(&self, nodes: &[CodeNode]) -> Result<Vec<Vec<f32>>> {
        let (embs, _) = self
            .generate_embeddings_with_config(nodes, &BatchConfig::default())
            .await?;
        Ok(embs)
    }

    async fn generate_embeddings_with_config(
        &self,
        nodes: &[CodeNode],
        config: &BatchConfig,
    ) -> Result<(Vec<Vec<f32>>, EmbeddingMetrics)> {
        if nodes.is_empty() {
            return Ok((
                Vec::new(),
                EmbeddingMetrics::new("ONNX".into(), 0, std::time::Duration::ZERO),
            ));
        }
        let start = Instant::now();

        // Prepare texts
        let texts: Vec<String> = nodes
            .iter()
            .map(|n| {
                let mut s = String::new();
                if let Some(lang) = &n.language {
                    s.push_str(&format!("{:?} ", lang));
                }
                if let Some(nt) = &n.node_type {
                    s.push_str(&format!("{:?} ", nt));
                }
                s.push_str(n.name.as_str());
                if let Some(c) = &n.content {
                    s.push(' ');
                    s.push_str(c);
                }
                s
            })
            .collect();

        let mut all = Vec::with_capacity(nodes.len());
        for chunk in texts.chunks(config.batch_size.max(1)) {
            let (ids, mask, type_ids, _l) = self.prepare_inputs(chunk)?;

            // Run session and immediately convert output to an owned Array2<f32> to avoid borrowing issues.
            let arr: Array2<f32> = {
                let mut sess = self.session.lock();

                // Prepare values once
                let input_ids_v = Value::from_array(ids.clone().into_dyn())
                    .map_err(|e| CodeGraphError::External(e.to_string()))?;
                let attention_mask_v = Value::from_array(mask.clone().into_dyn())
                    .map_err(|e| CodeGraphError::External(e.to_string()))?;
                let token_type_ids_v = Value::from_array(type_ids.clone().into_dyn())
                    .map_err(|e| CodeGraphError::External(e.to_string()))?;

                // Build named inputs based on the session's expected names
                let mut named: Vec<(String, ort::session::SessionInputValue<'_>)> = Vec::new();
                for inp in &sess.inputs {
                    let n = inp.name.to_lowercase();
                    if n.contains("input_ids") || n == "input" {
                        // some models use a generic name
                        named.push((inp.name.clone(), input_ids_v.clone().into()));
                    } else if n.contains("attention") || n.contains("mask") {
                        named.push((inp.name.clone(), attention_mask_v.clone().into()));
                    } else if n.contains("token_type") || n.contains("segment") {
                        named.push((inp.name.clone(), token_type_ids_v.clone().into()));
                    }
                }
                // Fall back if matching by name failed to fill any inputs
                if named.is_empty() {
                    // Use common defaults by arity
                    match sess.inputs.len() {
                        3 => {
                            named.push(("input_ids".into(), input_ids_v.clone().into()));
                            named.push(("attention_mask".into(), attention_mask_v.clone().into()));
                            named.push(("token_type_ids".into(), token_type_ids_v.clone().into()));
                        }
                        2 => {
                            named.push(("input_ids".into(), input_ids_v.clone().into()));
                            named.push(("attention_mask".into(), attention_mask_v.clone().into()));
                        }
                        1 => {
                            named.push(("input".into(), input_ids_v.clone().into()));
                        }
                        _ => {}
                    }
                }
                let outputs = sess
                    .run(named)
                    .map_err(|e| CodeGraphError::External(e.to_string()))?;
                let (shape, data) = outputs[0]
                    .try_extract_tensor::<f32>()
                    .map_err(|e| CodeGraphError::External(e.to_string()))?;
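                // The first output is either [B, H] (already pooled) or [B, L, H]
                // (per-token hidden states). In the 3-D case the branch below does
                // masked mean pooling:
                //   pooled[b][h] = sum_l(hidden[b][l][h] * mask[b][l]) / max(sum_l(mask[b][l]), 1)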
                let arr_dyn = ndarray::Array::from_shape_vec(shape.to_ixdyn(), data.to_vec())
                    .map_err(|e| CodeGraphError::External(e.to_string()))?;
                // Handle both [B, H] and [B, L, H] outputs
                if arr_dyn.ndim() == 2 {
                    arr_dyn
                        .into_dimensionality::<ndarray::Ix2>()
                        .map_err(|e| CodeGraphError::External(e.to_string()))?
                } else if arr_dyn.ndim() == 3 {
                    let arr3 = arr_dyn
                        .into_dimensionality::<ndarray::Ix3>()
                        .map_err(|e| CodeGraphError::External(e.to_string()))?;
                    let b = arr3.len_of(Axis(0));
                    let _seq_len = arr3.len_of(Axis(1));
                    let h = arr3.len_of(Axis(2));
                    // Broadcast mask to [B, L, 1]
                    let mask_f = mask.map(|&x| x as f32);
                    let mask_exp = mask_f.clone().insert_axis(Axis(2));
                    // Weighted sum over L
                    let masked = arr3 * &mask_exp;
                    let sum_emb = masked.sum_axis(Axis(1)); // [B, H]
                    let counts = mask_f
                        .sum_axis(Axis(1))
                        .mapv(|x| if x <= 0.0 { 1.0 } else { x }); // [B]
                    let mut pooled = Array2::<f32>::zeros((b, h));
                    for i in 0..b {
                        let denom = counts[i];
                        pooled
                            .slice_mut(s![i, ..])
                            .assign(&(&sum_emb.slice(s![i, ..]) / denom));
                    }
                    pooled
                } else {
                    return Err(CodeGraphError::External(
                        "Unexpected ONNX output rank; expected 2D or 3D tensor".into(),
                    ));
                }
            };

            let pooled = self.pool_and_normalize(arr, &mask).unwrap_or_default();
            all.extend(pooled);
        }

        let dur = start.elapsed();
        let metrics = EmbeddingMetrics::new("ONNX".into(), nodes.len(), dur);
        Ok((all, metrics))
    }

    fn embedding_dimension(&self) -> usize {
        self.hidden_size
    }

    fn provider_name(&self) -> &str {
        "ONNX"
    }

    async fn is_available(&self) -> bool {
        true
    }

    fn performance_characteristics(&self) -> ProviderCharacteristics {
        ProviderCharacteristics {
            expected_throughput: 100.0,
            typical_latency: std::time::Duration::from_millis(10),
            max_batch_size: 256,
            supports_streaming: false,
            requires_network: false,
            memory_usage: MemoryUsage::Medium,
        }
    }
}
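For orientation, here is a minimal construction sketch (not part of the file above). It assumes the onnx feature is enabled and uses sentence-transformers/all-MiniLM-L6-v2 as an illustrative Hugging Face repo that ships an ONNX export; any repo or local directory containing tokenizer.json plus one of the candidate model filenames would resolve the same way.

#[cfg(feature = "onnx")]
async fn onnx_provider_example() -> codegraph_core::Result<()> {
    let config = OnnxConfig {
        // Hypothetical model id; a local directory path works too.
        model_repo: "sentence-transformers/all-MiniLM-L6-v2".into(),
        model_file: None, // let the default candidate list resolve the file
        max_sequence_length: 256,
        pooling: OnnxPooling::Mean, // informational today: pooling is actually picked by output rank
    };
    let provider = OnnxEmbeddingProvider::new(config).await?;
    assert_eq!(provider.provider_name(), "ONNX");
    // Reports the hardcoded fallback (768) until hidden-size discovery is implemented.
    println!("embedding dimension: {}", provider.embedding_dimension());
    Ok(())
}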
