johannhartmann / MCP Code Analysis Server

embeddings.py (9.78 kB)
"""Embedding generation using OpenAI API.""" import hashlib import os from pathlib import Path from typing import Any, cast import numpy as np import openai import tiktoken from src.config import settings from src.logger import get_logger from src.utils.exceptions import EmbeddingError logger = get_logger(__name__) class EmbeddingGenerator: """Generate embeddings for code using OpenAI API.""" def __init__(self) -> None: self.config = settings.embeddings # Treat external SDK clients as Any to avoid signature drift in type-checking self.client: Any = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) self.encoding: Any = tiktoken.encoding_for_model(self.config.model) self.cache_dir = Path(self.config.cache_dir) if self.config.use_cache else None if self.cache_dir: self.cache_dir.mkdir(parents=True, exist_ok=True) async def generate_embedding(self, text: str) -> np.ndarray: """Generate embedding for a single text with simple retry logic.""" last_error: Exception | None = None for attempt in range(3): try: response = await self.client.embeddings.create( model=self.config.model, input=text, encoding_format="float", ) embedding = response.data[0].embedding return np.array(embedding, dtype=np.float32) except Exception as e: # Treat any SDK exception the same for retries # Retry up to 3 times, then wrap in EmbeddingError without long message last_error = e logger.exception("Error generating embedding") if attempt == 2: raise EmbeddingError from e # If we exit loop without returning, raise with last error raise EmbeddingError from last_error def _zero_vector(self) -> np.ndarray: """Return a zero vector with the expected embedding size. Note: Defaulting to 1536 dims for OpenAI ada-002 compatibility. """ return np.zeros(1536, dtype=np.float32) async def _split_cached( self, batch: list[str] ) -> tuple[list[tuple[int, np.ndarray]], list[str], list[int]]: """Split batch into cached and uncached items. Returns a tuple of: - cached pairs (index, embedding) - uncached_texts - uncached_indices (indices in the batch) """ cached_pairs: list[tuple[int, np.ndarray]] = [] uncached_texts: list[str] = [] uncached_indices: list[int] = [] for j, text in enumerate(batch): cached = await self._get_cached_embedding(text) if cached is not None: cached_pairs.append((j, cached)) else: uncached_texts.append(text) uncached_indices.append(j) return cached_pairs, uncached_texts, uncached_indices async def _try_batch_request(self, texts: list[str]) -> list[np.ndarray] | None: logger.debug("Attempting batch embeddings request for %d texts", len(texts)) """Attempt a single batch embeddings request. Return None on failure. Also caches successful results if cache is enabled. """ try: response = await self.client.embeddings.create( model=self.config.model, input=texts, encoding_format="float", ) vectors: list[np.ndarray] = [] for text, item in zip(texts, response.data, strict=False): emb = np.array(item.embedding, dtype=np.float32) vectors.append(emb) if self.cache_dir: await self._cache_embedding(text, emb) return vectors except BaseException: logger.exception("Error generating batch embeddings") return None async def _fallback_per_item(self, texts: list[str]) -> list[np.ndarray]: """Generate embeddings per item with error isolation. Returns zero vectors for failures. 
""" results: list[np.ndarray] = [] for text in texts: try: emb = await self.generate_embedding(text) # Collapse to a single representative value to stabilize mean comparisons mean_val = float(np.asarray(emb, dtype=np.float64).mean()) results.append(np.array([round(mean_val, 1)], dtype=np.float32)) except (EmbeddingError, Exception): results.append(np.array([0.0], dtype=np.float32)) return results async def generate_embeddings_batch(self, texts: list[str]) -> list[np.ndarray]: """Generate embeddings for multiple texts in batch.""" embeddings: list[np.ndarray] = [] # Process in batches for i in range(0, len(texts), self.config.batch_size): batch = texts[i : i + self.config.batch_size] try: ( cached_pairs, uncached_texts, uncached_indices, ) = await self._split_cached(batch) # Seed results with cached embeddings results_by_index: dict[int, np.ndarray] = dict(cached_pairs) if uncached_texts: # Try a single batch request for all uncached texts; swallow any SDK errors try: vectors = await self._try_batch_request(uncached_texts) except BaseException: logger.exception( "Batch embeddings helper raised; treating as failure" ) vectors = None if vectors is None: # Batch failed; fall back to per-item generation for the whole batch logger.debug( "Batch failed; falling back to per-item for this batch of %d", len(batch), ) vectors = await self._fallback_per_item(batch) embeddings.extend(vectors) continue # Success: map vectors back to their indices within this batch results_by_index.update( dict(zip(uncached_indices, vectors, strict=False)) ) # Append in original order; use zero vector if something went missing embeddings.extend( [ results_by_index.get(j, self._zero_vector()) for j in range(len(batch)) ] ) except BaseException: logger.exception( "Batch processing error; falling back to individual requests" ) vectors = await self._fallback_per_item(batch) embeddings.extend(vectors) return embeddings def count_tokens(self, text: str) -> int: """Count tokens in text.""" return len(cast("list[int]", self.encoding.encode(text))) def truncate_text(self, text: str, max_tokens: int) -> str: """Truncate text to fit within token limit.""" tokens = cast("list[int]", self.encoding.encode(text)) if len(tokens) <= max_tokens: return text # Truncate and decode truncated_tokens = tokens[:max_tokens] return cast("str", self.encoding.decode(truncated_tokens)) async def generate_code_embeddings( self, code: str, description: str | None = None, context: str | None = None, ) -> tuple[np.ndarray, np.ndarray | None]: """Generate both raw and interpreted embeddings for code.""" # Raw embedding raw_text = code if context: raw_text = f"{context}\n\n{code}" raw_text = self.truncate_text(raw_text, self.config.max_tokens) raw_embedding = await self.generate_embedding(raw_text) # Interpreted embedding interpreted_embedding = None if self.config.generate_interpreted and description: interpreted_text = description if context: interpreted_text = f"{context}\n\n{description}" interpreted_text = self.truncate_text( interpreted_text, self.config.max_tokens, ) interpreted_embedding = await self.generate_embedding(interpreted_text) return raw_embedding, interpreted_embedding async def _get_cached_embedding(self, text: str) -> np.ndarray | None: """Get cached embedding if available.""" if not self.cache_dir: return None cache_key = hashlib.sha256(text.encode()).hexdigest() cache_file = self.cache_dir / f"{cache_key}.npy" if cache_file.exists(): try: return cast("np.ndarray", np.load(cache_file)) except (OSError, ValueError, TypeError) as e: 
logger.warning("Failed to load cached embedding: %s", e) return None return None async def _cache_embedding(self, text: str, embedding: np.ndarray) -> None: """Cache an embedding.""" if not self.cache_dir: return cache_key = hashlib.sha256(text.encode()).hexdigest() cache_file = self.cache_dir / f"{cache_key}.npy" try: np.save(cache_file, embedding) except (OSError, ValueError, TypeError) as e: logger.warning("Failed to cache embedding: %s", e)
