Skip to main content
Glama

PyTorch Documentation Search Tool

embedding.py (12.6 kB)
""" Embedding generation module for PyTorch Documentation Search Tool. Handles generating embeddings with OpenAI API and basic caching. """ import os import json import hashlib import time from typing import List, Dict, Any, Optional from openai import OpenAI from ptsearch.utils import logger from ptsearch.utils.error import APIError, ConfigError from ptsearch.config import settings class EmbeddingGenerator: """Generates embeddings using OpenAI API with caching support.""" def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None, use_cache: bool = True, cache_dir: Optional[str] = None): """Initialize embedding generator with OpenAI API and basic caching.""" self.model = model or settings.embedding_model self.api_key = api_key or settings.openai_api_key self.use_cache = use_cache self.cache_dir = cache_dir or settings.cache_dir self.stats = {"hits": 0, "misses": 0} # Validate API key early if not self.api_key: error_msg = "OPENAI_API_KEY not found. Please set this key in your .env file or environment." 
logger.error(error_msg) raise ConfigError(error_msg) # Initialize OpenAI client with compatibility handling self._initialize_client() # Initialize cache if enabled if use_cache: os.makedirs(self.cache_dir, exist_ok=True) logger.info(f"Embedding cache initialized", path=self.cache_dir) def _initialize_client(self): """Initialize OpenAI client with error handling for compatibility.""" try: # Standard initialization self.client = OpenAI(api_key=self.api_key) logger.info("OpenAI client initialized successfully") except TypeError as e: # Handle proxies parameter error if "unexpected keyword argument 'proxies'" in str(e): import httpx logger.info("Creating custom HTTP client for OpenAI compatibility") http_client = httpx.Client(timeout=60.0) self.client = OpenAI(api_key=self.api_key, http_client=http_client) else: error_msg = f"Unexpected error initializing OpenAI client: {e}" logger.error(error_msg) raise APIError(error_msg) def generate_embedding(self, text: str) -> List[float]: """Generate embedding for a single text with caching.""" if not text: logger.warning("Empty text provided for embedding generation") return [0.0] * settings.embedding_dimensions if self.use_cache: # Check cache first cached_embedding = self._get_from_cache(text) if cached_embedding: self.stats["hits"] += 1 return cached_embedding self.stats["misses"] += 1 # Generate embedding via API try: response = self.client.embeddings.create( input=text, model=self.model ) embedding = response.data[0].embedding # Cache the result if self.use_cache: self._save_to_cache(text, embedding) return embedding except Exception as e: error_msg = f"Error generating embedding: {e}" logger.error(error_msg) # Return zeros as fallback rather than failing completely return [0.0] * settings.embedding_dimensions def generate_embeddings(self, texts: List[str], batch_size: int = 20) -> List[List[float]]: """Generate embeddings for multiple texts with batching.""" if not texts: logger.warning("Empty text list provided for batch 
embedding generation") return [] all_embeddings = [] # Process in batches for i in range(0, len(texts), batch_size): batch_texts = texts[i:i+batch_size] batch_embeddings = [] # Check cache first uncached_texts = [] uncached_indices = [] if self.use_cache: for j, text in enumerate(batch_texts): cached_embedding = self._get_from_cache(text) if cached_embedding: self.stats["hits"] += 1 batch_embeddings.append(cached_embedding) else: self.stats["misses"] += 1 uncached_texts.append(text) uncached_indices.append(j) else: uncached_texts = batch_texts uncached_indices = list(range(len(batch_texts))) self.stats["misses"] += len(batch_texts) # Process uncached texts if uncached_texts: try: response = self.client.embeddings.create( input=uncached_texts, model=self.model ) api_embeddings = [item.embedding for item in response.data] # Cache results if self.use_cache: for text, embedding in zip(uncached_texts, api_embeddings): self._save_to_cache(text, embedding) # Place embeddings in correct order for idx, embedding in zip(uncached_indices, api_embeddings): while len(batch_embeddings) <= idx: batch_embeddings.append(None) batch_embeddings[idx] = embedding except Exception as e: error_msg = f"Error generating batch embeddings: {e}" logger.error(error_msg, batch=i//batch_size) # Use zeros as fallback for idx in uncached_indices: while len(batch_embeddings) <= idx: batch_embeddings.append(None) batch_embeddings[idx] = [0.0] * settings.embedding_dimensions # Ensure all positions have embeddings for j in range(len(batch_texts)): if j >= len(batch_embeddings) or batch_embeddings[j] is None: batch_embeddings.append([0.0] * settings.embedding_dimensions) all_embeddings.extend(batch_embeddings[:len(batch_texts)]) # Respect API rate limits if i + batch_size < len(texts): time.sleep(0.5) # Log cache stats once at the end total_processed = self.stats["hits"] + self.stats["misses"] if self.use_cache and total_processed > 0: hit_rate = self.stats["hits"] / total_processed 
logger.info(f"Embedding cache statistics", hits=self.stats["hits"], misses=self.stats["misses"], hit_rate=f"{hit_rate:.2%}") return all_embeddings def embed_chunks(self, chunks: List[Dict[str, Any]], batch_size: int = 20) -> List[Dict[str, Any]]: """Generate embeddings for a list of chunks.""" # Extract texts from chunks texts = [chunk["text"] for chunk in chunks] logger.info(f"Generating embeddings for chunks", count=len(texts), model=self.model, batch_size=batch_size) # Generate embeddings embeddings = self.generate_embeddings(texts, batch_size) # Add embeddings to chunks for i, embedding in enumerate(embeddings): chunks[i]["embedding"] = embedding return chunks def process_file(self, input_file: str, output_file: Optional[str] = None) -> List[Dict[str, Any]]: """Process a file containing chunks and add embeddings.""" logger.info(f"Loading chunks from file", path=input_file) # Load chunks try: with open(input_file, 'r', encoding='utf-8') as f: chunks = json.load(f) logger.info(f"Loaded chunks from file", count=len(chunks)) # Generate embeddings chunks_with_embeddings = self.embed_chunks(chunks) # Save to file if output_file is provided if output_file: with open(output_file, 'w', encoding='utf-8') as f: json.dump(chunks_with_embeddings, f) logger.info(f"Saved chunks with embeddings to file", count=len(chunks_with_embeddings), path=output_file) return chunks_with_embeddings except Exception as e: error_msg = f"Error processing file: {e}" logger.error(error_msg) raise APIError(error_msg, details={"input_file": input_file}) def _get_cache_path(self, text: str) -> str: """Generate cache file path for a text.""" text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest() return os.path.join(self.cache_dir, f"{text_hash}.json") def _get_from_cache(self, text: str) -> Optional[List[float]]: """Get embedding from cache.""" cache_path = self._get_cache_path(text) if os.path.exists(cache_path): try: with open(cache_path, 'r') as f: data = json.load(f) return 
data.get("embedding") except Exception as e: logger.error(f"Error reading from cache", path=cache_path, error=str(e)) return None def _save_to_cache(self, text: str, embedding: List[float]) -> None: """Save embedding to cache.""" cache_path = self._get_cache_path(text) try: with open(cache_path, 'w') as f: json.dump({ "text_preview": text[:100] + "..." if len(text) > 100 else text, "model": self.model, "embedding": embedding, "timestamp": time.time() }, f) # Manage cache size (simple LRU) self._manage_cache_size() except Exception as e: logger.error(f"Error writing to cache", path=cache_path, error=str(e)) def _manage_cache_size(self) -> None: """Manage cache size using LRU strategy.""" max_size_bytes = int(settings.max_cache_size_gb * 1024 * 1024 * 1024) # Get all cache files with their info cache_files = [] for filename in os.listdir(self.cache_dir): if filename.endswith('.json'): filepath = os.path.join(self.cache_dir, filename) try: stats = os.stat(filepath) cache_files.append({ 'path': filepath, 'size': stats.st_size, 'last_access': stats.st_atime }) except Exception: pass # Calculate total size total_size = sum(f['size'] for f in cache_files) # If over limit, remove oldest files if total_size > max_size_bytes: # Sort by last access time (oldest first) cache_files.sort(key=lambda x: x['last_access']) # Remove files until under limit bytes_to_remove = total_size - max_size_bytes bytes_removed = 0 removed_count = 0 for file_info in cache_files: if bytes_removed >= bytes_to_remove: break try: os.remove(file_info['path']) bytes_removed += file_info['size'] removed_count += 1 except Exception: pass mb_removed = bytes_removed / 1024 / 1024 logger.info(f"Cache cleanup completed", files_removed=removed_count, mb_removed=f"{mb_removed:.2f}", total_files=len(cache_files))

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/seanmichaelmcgee/pytorch-docs-refactored'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.