Skip to main content
Glama
embedding.py6.38 kB
"""Embedding model management for MCP-RAG.""" import logging from abc import ABC, abstractmethod from typing import List, Optional from .config import settings logger = logging.getLogger(__name__) try: import httpx HTTPX_AVAILABLE = True except ImportError: HTTPX_AVAILABLE = False logger.warning("httpx not available, Doubao embedding will not work") try: from sentence_transformers import SentenceTransformer SENTENCE_TRANSFORMERS_AVAILABLE = True except ImportError: SENTENCE_TRANSFORMERS_AVAILABLE = False SentenceTransformer = None logger.warning("sentence-transformers not available, local embedding models will not work. Install with: pip install mcp-rag[local-embeddings]") DOUBAO_AVAILABLE = HTTPX_AVAILABLE class EmbeddingModel(ABC): """Abstract base class for embedding models.""" @abstractmethod async def encode(self, texts: List[str]) -> List[List[float]]: """Encode texts to embeddings.""" pass @abstractmethod async def encode_single(self, text: str) -> List[float]: """Encode single text to embedding.""" pass class SentenceTransformerModel(EmbeddingModel): """SentenceTransformer-based embedding model.""" def __init__(self, model_name: str = "m3e-small", device: str = "cpu", cache_dir: Optional[str] = None): if not SENTENCE_TRANSFORMERS_AVAILABLE: raise RuntimeError("sentence-transformers not available. Install with: pip install mcp-rag[local-embeddings]") self.model_name = model_name self.device = device self.cache_dir = cache_dir self.model = None async def initialize(self) -> None: """Initialize the embedding model.""" try: # Map model names to actual model identifiers model_mapping = { "m3e-small": "moka-ai/m3e-small", "e5-small": "intfloat/e5-small-v2" } actual_model_name = model_mapping.get(self.model_name, self.model_name) self.model = SentenceTransformer( actual_model_name, device=self.device, cache_folder=self.cache_dir ) logger.info(f"Initialized embedding model: {actual_model_name} on {self.device}") except Exception as e: logger.error(f"Failed to initialize embedding model {self.model_name}: {e}") raise async def encode(self, texts: List[str]) -> List[List[float]]: """Encode multiple texts to embeddings.""" if not self.model: raise RuntimeError("Model not initialized") try: embeddings = self.model.encode(texts, convert_to_numpy=True).tolist() return embeddings except Exception as e: logger.error(f"Failed to encode texts: {e}") raise async def encode_single(self, text: str) -> List[float]: """Encode single text to embedding.""" embeddings = await self.encode([text]) return embeddings[0] class DoubaoEmbeddingModel(EmbeddingModel): """Doubao (豆包) embedding model.""" def __init__(self, api_key: Optional[str] = None, base_url: str = "https://ark.cn-beijing.volces.com/api/v3", model: str = "doubao-embedding-text-240715"): if not DOUBAO_AVAILABLE: raise RuntimeError("httpx not available. Please install it with: pip install httpx") self.api_key = api_key self.base_url = base_url.rstrip('/') self.model = model self.client: Optional[httpx.AsyncClient] = None async def initialize(self) -> None: """Initialize the Doubao client.""" try: if not self.api_key: raise ValueError("API key is required for Doubao embedding") self.client = httpx.AsyncClient( base_url=self.base_url, headers={ "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" }, timeout=30.0 ) logger.info(f"Initialized Doubao embedding model: {self.model}") except Exception as e: logger.error(f"Failed to initialize Doubao embedding model: {e}") raise async def encode(self, texts: List[str]) -> List[List[float]]: """Encode multiple texts to embeddings.""" if not self.client: raise RuntimeError("Client not initialized") try: response = await self.client.post( "/embeddings", json={ "model": self.model, "input": texts, "encoding_format": "float" } ) if response.status_code != 200: raise RuntimeError(f"Doubao API error: {response.status_code} - {response.text}") data = response.json() embeddings = [item["embedding"] for item in data["data"]] return embeddings except Exception as e: logger.error(f"Failed to encode texts with Doubao: {e}") raise async def encode_single(self, text: str) -> List[float]: """Encode single text to embedding.""" embeddings = await self.encode([text]) return embeddings[0] async def close(self) -> None: """Close the HTTP client.""" if self.client: await self.client.aclose() # Global embedding model instance embedding_model: Optional[EmbeddingModel] = None async def get_embedding_model() -> EmbeddingModel: """Get the global embedding model instance.""" global embedding_model if embedding_model is None: if settings.embedding_provider == "doubao": embedding_model = DoubaoEmbeddingModel( api_key=settings.embedding_api_key, base_url=settings.embedding_base_url, model=settings.embedding_model ) else: # Local SentenceTransformer models embedding_model = SentenceTransformerModel( model_name=settings.embedding_model, device=settings.embedding_device, cache_dir=settings.embedding_cache_dir ) await embedding_model.initialize() return embedding_model

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/kalicyh/mcp-rag'

If you have feedback or need assistance with the MCP directory API, please join our Discord server