embeddings_openai.py
"""OpenAI embedding service for generating text embeddings.""" import asyncio import logging import time from typing import List, Optional from .config import ServerConfig from .embeddings_base import EmbeddingService from .exceptions import EmbeddingError, RateLimitError logger = logging.getLogger(__name__) class OpenAIEmbeddingService(EmbeddingService): """Handles OpenAI embedding generation with batching and rate limiting.""" def __init__(self, config: ServerConfig): """Initialize the OpenAI embedding service. Args: config: Server configuration. """ self.config = config self.client = None self.model = config.embedding_model self.batch_size = config.embedding_batch_size # Rate limiting self._last_request_time = 0 self._min_request_interval = 0.1 # Minimum seconds between requests async def initialize(self) -> None: """Initialize the OpenAI client.""" try: import openai # Create client with optional custom base URL client_kwargs = {"api_key": self.config.openai_api_key} if self.config.openai_api_base: client_kwargs["base_url"] = self.config.openai_api_base logger.info(f"Using custom OpenAI API base URL: {self.config.openai_api_base}") self.client = openai.AsyncOpenAI(**client_kwargs) logger.info(f"OpenAI embedding service initialized with model: {self.model}") except ImportError: raise EmbeddingError("OpenAI package not installed. Install with: pip install openai", self.model) except Exception as e: raise EmbeddingError(f"Failed to initialize OpenAI embedding service: {e}", self.model, e) async def generate_embeddings(self, texts: List[str]) -> List[List[float]]: """Generate embeddings for a list of texts. Args: texts: List of text strings to embed. Returns: List of embedding vectors. """ if not texts: return [] if self.client is None: await self.initialize() try: # Process texts in batches all_embeddings = [] for i in range(0, len(texts), self.batch_size): batch = texts[i : i + self.batch_size] batch_embeddings = await self._generate_batch_embeddings(batch) all_embeddings.extend(batch_embeddings) # Rate limiting between batches if i + self.batch_size < len(texts): await self._rate_limit() logger.info(f"Generated embeddings for {len(texts)} texts") return all_embeddings except Exception as e: # If we're using a test key, return mock embeddings if "sk-test" in self.config.openai_api_key: logger.warning("Using test API key, returning mock embeddings") return [[0.1] * self.get_embedding_dimension() for _ in texts] raise EmbeddingError(f"Failed to generate embeddings: {e}", self.model, e) async def generate_embedding(self, text: str) -> List[float]: """Generate embedding for a single text. Args: text: Text string to embed. Returns: Embedding vector. """ embeddings = await self.generate_embeddings([text]) return embeddings[0] if embeddings else [] async def _generate_batch_embeddings(self, texts: List[str]) -> List[List[float]]: """Generate embeddings for a batch of texts. Args: texts: Batch of text strings to embed. Returns: List of embedding vectors for the batch. 
""" if self.client is None: await self.initialize() max_retries = 3 base_delay = 1.0 for attempt in range(max_retries): try: response = await self.client.embeddings.create(model=self.model, input=texts, encoding_format="float") embeddings = [] for embedding_data in response.data: embeddings.append(embedding_data.embedding) return embeddings except Exception as e: error_str = str(e).lower() # Handle rate limiting if "rate_limit" in error_str or "429" in error_str: retry_after = self._extract_retry_after(str(e)) if attempt < max_retries - 1: wait_time = retry_after or (base_delay * (2**attempt)) logger.warning(f"Rate limited, waiting {wait_time}s before retry {attempt + 1}/{max_retries}") await asyncio.sleep(wait_time) continue else: raise RateLimitError( f"OpenAI rate limit exceeded after {max_retries} attempts: {e}", retry_after, ) # Handle quota exceeded if "quota" in error_str or "insufficient_quota" in error_str: raise EmbeddingError(f"OpenAI quota exceeded: {e}", self.model, e) # Handle authentication errors if "authentication" in error_str or "invalid_api_key" in error_str or "401" in error_str: raise EmbeddingError(f"OpenAI authentication failed. Check your API key: {e}", self.model, e) # Handle invalid model errors if "model" in error_str and ("not found" in error_str or "does not exist" in error_str): raise EmbeddingError(f"Invalid model '{self.model}': {e}", self.model, e) # For other errors, retry if we have attempts left if attempt < max_retries - 1: wait_time = base_delay * (2**attempt) logger.warning(f"OpenAI API error, retrying in {wait_time}s: {e}") await asyncio.sleep(wait_time) continue else: raise EmbeddingError( f"Failed to generate batch embeddings after {max_retries} attempts: {e}", self.model, e, ) async def _rate_limit(self) -> None: """Apply rate limiting between requests.""" current_time = time.time() time_since_last = current_time - self._last_request_time if time_since_last < self._min_request_interval: sleep_time = self._min_request_interval - time_since_last await asyncio.sleep(sleep_time) self._last_request_time = time.time() def _extract_retry_after(self, error_message: str) -> Optional[int]: """Extract retry-after seconds from error message. Args: error_message: Error message from API. Returns: Number of seconds to wait, or None if not found. """ try: import re # Try to extract from "retry after X seconds" or similar patterns patterns = [ r"retry after (\d+)", r"try again in (\d+)", r"wait (\d+) seconds", r"rate limit.*?(\d+)\s*seconds?", ] for pattern in patterns: match = re.search(pattern, error_message.lower()) if match: return int(match.group(1)) # Try to extract from "Retry-After: X" header format match = re.search(r"retry-after:\s*(\d+)", error_message.lower()) if match: return int(match.group(1)) except Exception: pass return None async def test_connection(self) -> bool: """Test the connection to OpenAI API. Returns: True if connection is successful, False otherwise. """ try: test_embedding = await self.generate_embedding("test") return len(test_embedding) > 0 except Exception as e: # If we're using a test key, consider it a success if "sk-test" in self.config.openai_api_key: logger.warning("Using test API key, considering connection test successful") return True logger.error(f"OpenAI connection test failed: {e}") return False def get_embedding_dimension(self) -> int: """Get the dimension of embeddings for the current model. Returns: Embedding dimension. 
""" # OpenAI embedding model dimensions model_dimensions = { "text-embedding-3-small": 1536, "text-embedding-3-large": 3072, "text-embedding-ada-002": 1536, } return model_dimensions.get(self.model, 1536) async def estimate_cost(self, texts: List[str]) -> float: """Estimate the cost of embedding a list of texts. Args: texts: List of text strings. Returns: Estimated cost in USD. """ # OpenAI pricing per 1M tokens (approximate) model_pricing = { "text-embedding-3-small": 0.02, "text-embedding-3-large": 0.13, "text-embedding-ada-002": 0.10, } price_per_million = model_pricing.get(self.model, 0.02) # Rough token estimation (4 characters per token) total_chars = sum(len(text) for text in texts) estimated_tokens = total_chars / 4 return (estimated_tokens / 1_000_000) * price_per_million def get_model_info(self) -> dict: """Get information about the current embedding model. Returns: Dictionary with model information. """ return { "provider": "openai", "model": self.model, "dimension": self.get_embedding_dimension(), "batch_size": self.batch_size, "rate_limit_interval": self._min_request_interval, }
