"""
Embedding service providers for semantic search.
This module provides embedding generation capabilities using OpenAI and Ollama APIs,
replacing the heavy sentence-transformers dependency with lightweight API clients.
"""
import asyncio
import logging
import time
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
import hashlib
import openai
import ollama
logger = logging.getLogger(__name__)
class EmbeddingProvider(ABC):
"""Abstract base class for embedding providers."""
@abstractmethod
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for a list of texts."""
pass
@abstractmethod
def get_embedding_dimension(self) -> int:
"""Return the embedding vector dimension."""
pass
@abstractmethod
def get_provider_name(self) -> str:
"""Return provider identifier."""
pass
@abstractmethod
def get_model_name(self) -> str:
"""Return model identifier."""
pass
async def embed_single_text(self, text: str) -> List[float]:
"""Convenience method for single text embedding."""
results = await self.embed_texts([text])
if not results:
logger.error("No embeddings returned for single text")
# Return zero vector with default dimension
dimension = self.get_embedding_dimension()
return [0.0] * dimension
return results[0]
def get_model_hash(self) -> str:
"""Generate a hash to detect model changes."""
model_info = f"{self.get_provider_name()}:{self.get_model_name()}"
return hashlib.md5(model_info.encode()).hexdigest()
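# A minimal sketch of a third-party provider built on this interface
# (illustrative only; "StaticEmbeddingProvider" is a hypothetical name, not
# part of this module):
#
#     class StaticEmbeddingProvider(EmbeddingProvider):
#         """Returns fixed-size zero vectors; handy as a test double."""
#         async def embed_texts(self, texts: List[str]) -> List[List[float]]:
#             return [[0.0] * 8 for _ in texts]
#         def get_embedding_dimension(self) -> int:
#             return 8
#         def get_provider_name(self) -> str:
#             return "static"
#         def get_model_name(self) -> str:
#             return "zeros-v0"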
class RateLimiter:
"""Simple rate limiter for API calls."""
def __init__(self, requests_per_minute: int = 3000):
self.requests_per_minute = requests_per_minute
self.min_interval = 60.0 / requests_per_minute
self.last_request = 0.0
async def wait_if_needed(self):
"""Wait if necessary to respect rate limits."""
current_time = time.time()
time_since_last = current_time - self.last_request
if time_since_last < self.min_interval:
wait_time = self.min_interval - time_since_last
await asyncio.sleep(wait_time)
self.last_request = time.time()
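# Usage sketch, mirroring how the providers below consume the limiter: await
# it immediately before each API call so requests stay at least min_interval
# apart (values here are illustrative):
#
#     limiter = RateLimiter(requests_per_minute=60)  # ~1 request/second
#     await limiter.wait_if_needed()                 # before every API call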
class OpenAIEmbeddingProvider(EmbeddingProvider):
"""OpenAI embedding provider using their API."""
def __init__(self, api_key: str, model: str = "text-embedding-3-small",
batch_size: int = 100, rate_limit_rpm: int = 3000,
base_url: Optional[str] = None):
"""
Initialize OpenAI embedding provider.
Args:
api_key: OpenAI API key
model: Embedding model name
batch_size: Number of texts to process per API call
rate_limit_rpm: Requests per minute rate limit
base_url: Optional custom base URL for OpenAI-compatible gateways
"""
# Keep credentials for optional sync detection
self._api_key = api_key
self._base_url = base_url
# Support custom base URL for OpenAI-compatible gateways
        # base_url=None falls through to the client's default endpoint
        self.client = openai.AsyncOpenAI(api_key=api_key, base_url=base_url)
self.model = model
self.batch_size = batch_size
self.rate_limiter = RateLimiter(rate_limit_rpm)
# Model dimension mapping
self.model_dimensions = {
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
"text-embedding-ada-002": 1536
}
# Cache for detected dimensions (useful for OpenAI-compatible custom models)
self._dimension_cache: Optional[int] = self.model_dimensions.get(self.model)
async def ensure_model_dimension_cached(self) -> int:
"""Ensure embedding dimension is cached by probing the provider."""
if self._dimension_cache is not None:
return self._dimension_cache
try:
await self.rate_limiter.wait_if_needed()
response = await self.client.embeddings.create(
model=self.model,
input="dimension probe"
)
if response and getattr(response, "data", None):
embedding = getattr(response.data[0], "embedding", None)
if embedding:
self._dimension_cache = len(embedding)
self.model_dimensions[self.model] = self._dimension_cache
logger.info(
f"Detected OpenAI embedding dimension: {self._dimension_cache}"
)
return self._dimension_cache
except Exception as e:
logger.debug(f"Could not auto-detect OpenAI embedding dimension: {e}")
fallback_dimension = self.model_dimensions.get(self.model, 1536)
if self._dimension_cache is None:
self._dimension_cache = fallback_dimension
return fallback_dimension
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings using OpenAI API with rate limiting and batching."""
if not texts:
return []
all_embeddings = []
# Process in batches to respect API limits
for i in range(0, len(texts), self.batch_size):
batch = texts[i:i + self.batch_size]
# Wait for rate limiting
await self.rate_limiter.wait_if_needed()
try:
response = await self.client.embeddings.create(
model=self.model,
input=batch
)
batch_embeddings = [data.embedding for data in response.data]
# Cache dimension if unknown and we got a response
if batch_embeddings and self._dimension_cache is None:
self._dimension_cache = len(batch_embeddings[0])
all_embeddings.extend(batch_embeddings)
logger.debug(f"Generated {len(batch_embeddings)} embeddings "
f"(batch {i//self.batch_size + 1})")
except Exception as e:
logger.error(f"OpenAI embedding error for batch {i//self.batch_size + 1}: {e}")
                # Return zero vectors for the failed batch; build one fresh
                # list per text so callers never share a single aliased list
                dimension = self.get_embedding_dimension()
                failed_embeddings = [[0.0] * dimension for _ in batch]
all_embeddings.extend(failed_embeddings)
return all_embeddings
def get_embedding_dimension(self) -> int:
"""Return dimension based on cached or detected value."""
if self._dimension_cache is not None:
return self._dimension_cache
try:
            sync_client = openai.OpenAI(api_key=self._api_key, base_url=self._base_url)
resp = sync_client.embeddings.create(model=self.model, input="dimension probe")
if resp and resp.data and hasattr(resp.data[0], 'embedding'):
embedding = resp.data[0].embedding
if embedding:
self._dimension_cache = len(embedding)
self.model_dimensions[self.model] = self._dimension_cache
return self._dimension_cache
except Exception as e:
logger.debug(f"Could not detect OpenAI embedding dimension synchronously: {e}")
fallback_dimension = self.model_dimensions.get(self.model, 1536)
self._dimension_cache = fallback_dimension
return fallback_dimension
def get_provider_name(self) -> str:
"""Return provider identifier."""
return "openai"
def get_model_name(self) -> str:
"""Return model identifier."""
return self.model
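# Direct usage sketch (assumes a valid key is exported as OPENAI_API_KEY and
# that `os` is imported by the caller; values are illustrative):
#
#     provider = OpenAIEmbeddingProvider(api_key=os.environ["OPENAI_API_KEY"])
#     vectors = await provider.embed_texts(["hello", "world"])
#     assert len(vectors[0]) == provider.get_embedding_dimension()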
class OllamaEmbeddingProvider(EmbeddingProvider):
"""Ollama embedding provider for local and remote models."""
def __init__(self, host: str = "http://localhost:11434",
model: str = "nomic-embed-text", timeout: int = 60):
"""
Initialize Ollama embedding provider.
Args:
            host: Ollama server URL or host:port (scheme defaults to http://)
model: Embedding model name
timeout: Request timeout in seconds
"""
# Handle host formatting
if not host.startswith(('http://', 'https://')):
host = f"http://{host}"
self.host = host
self.model = model
        self.timeout = timeout
        # The timeout is forwarded to the underlying httpx client; this shared
        # handle is kept for convenience, but the methods below create fresh
        # clients per request to avoid event-loop reuse issues.
        self.client = ollama.AsyncClient(host=host, timeout=timeout)
        self._dimension_cache: Optional[int] = None
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings using Ollama API."""
if not texts:
return []
embeddings = []
for i, text in enumerate(texts):
try:
                # Fresh client per request avoids reusing a client across
                # event loops; the timeout is forwarded to httpx
                client = ollama.AsyncClient(host=self.host, timeout=self.timeout)
response = await client.embeddings(
model=self.model,
prompt=text
)
if 'embedding' in response:
embeddings.append(response['embedding'])
else:
logger.error(f"No embedding in Ollama response for text {i}")
# Return zero vector on error
dimension = await self._get_model_dimension()
embeddings.append([0.0] * dimension)
logger.debug(f"Generated embedding {i+1}/{len(texts)}")
except Exception as e:
logger.error(f"Ollama embedding error for text {i}: {e}")
# Return zero vector on error
dimension = await self._get_model_dimension()
embeddings.append([0.0] * dimension)
return embeddings
async def _get_model_dimension(self) -> int:
"""Query Ollama for model dimension and cache result."""
if self._dimension_cache is not None:
return self._dimension_cache
try:
            # Fresh client avoids event-loop reuse issues
            client = ollama.AsyncClient(host=self.host, timeout=self.timeout)
response = await client.embeddings(
model=self.model,
prompt="test"
)
if 'embedding' in response:
self._dimension_cache = len(response['embedding'])
logger.info(f"Detected Ollama model dimension: {self._dimension_cache}")
return self._dimension_cache
else:
logger.warning("Could not determine Ollama model dimension, using 768")
self._dimension_cache = 768
return self._dimension_cache
except Exception as e:
logger.error(f"Error detecting Ollama model dimension: {e}")
self._dimension_cache = 768 # Common embedding dimension
return self._dimension_cache
def get_embedding_dimension(self) -> int:
"""Return embedding dimension (may require async call for detection)."""
if self._dimension_cache is not None:
return self._dimension_cache
# For sync calls, return common dimension and log warning
logger.warning("Embedding dimension not cached, returning default 768")
return 768
def get_provider_name(self) -> str:
"""Return provider identifier."""
return "ollama"
def get_model_name(self) -> str:
"""Return model identifier."""
return self.model
async def test_connection(self) -> bool:
"""Test connection to Ollama server."""
try:
            # Fresh client avoids event-loop reuse issues
            client = ollama.AsyncClient(host=self.host, timeout=self.timeout)
response = await client.list()
logger.info(f"Ollama connection successful: {len(response.get('models', []))} models available")
return True
except Exception as e:
logger.error(f"Ollama connection failed: {e}")
return False
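# Usage sketch against a local server (assumes the default model has been
# pulled, e.g. via `ollama pull nomic-embed-text`):
#
#     provider = OllamaEmbeddingProvider(host="http://localhost:11434")
#     if await provider.test_connection():
#         vectors = await provider.embed_texts(["hello"])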
class EmbeddingServiceError(Exception):
"""Custom exception for embedding service errors."""
pass
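# Expected configuration shape for create_embedding_provider below. The keys
# match the lookups in the factory; the values shown are examples, and every
# key except openai.api_key has a default:
#
#     {
#         "provider": "openai",  # or "ollama"
#         "openai": {
#             "api_key": "sk-...",  # required for the openai provider
#             "model": "text-embedding-3-small",
#             "batch_size": 100,
#             "rate_limit_rpm": 3000,
#             "base_url": None,  # optional OpenAI-compatible gateway
#         },
#         "ollama": {
#             "host": "http://localhost:11434",
#             "model": "nomic-embed-text",
#             "timeout": 60,
#         },
#     }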
async def create_embedding_provider(config: Dict[str, Any]) -> EmbeddingProvider:
"""
Create embedding provider from configuration.
Args:
config: Embedding configuration dictionary
Returns:
Configured embedding provider instance
Raises:
EmbeddingServiceError: If provider creation fails
"""
provider_type = config.get("provider", "openai").lower()
try:
if provider_type == "openai":
openai_config = config.get("openai", {})
api_key = openai_config.get("api_key")
if not api_key:
raise EmbeddingServiceError(
"OpenAI API key is required but not provided"
)
provider = OpenAIEmbeddingProvider(
api_key=api_key,
model=openai_config.get("model", "text-embedding-3-small"),
batch_size=openai_config.get("batch_size", 100),
rate_limit_rpm=openai_config.get("rate_limit_rpm", 3000),
base_url=openai_config.get("base_url")
)
await provider.ensure_model_dimension_cached()
logger.info(
f"Created OpenAI embedding provider with model {provider.get_model_name()} "
f"({provider.get_embedding_dimension()}D)"
)
return provider
elif provider_type == "ollama":
ollama_config = config.get("ollama", {})
provider = OllamaEmbeddingProvider(
host=ollama_config.get("host", "http://localhost:11434"),
model=ollama_config.get("model", "nomic-embed-text"),
timeout=ollama_config.get("timeout", 60)
)
# Test connection
if await provider.test_connection():
logger.info(f"Created Ollama embedding provider with model {provider.get_model_name()}")
# Cache dimension
await provider._get_model_dimension()
return provider
else:
raise EmbeddingServiceError(
f"Could not connect to Ollama server at {provider.host}"
)
else:
raise EmbeddingServiceError(
f"Unknown embedding provider: {provider_type}. "
"Supported providers: openai, ollama"
)
except Exception as e:
if isinstance(e, EmbeddingServiceError):
raise
else:
raise EmbeddingServiceError(f"Error creating {provider_type} provider: {e}")
def detect_model_change(current_provider: Optional[EmbeddingProvider],
new_config: Dict[str, Any]) -> bool:
"""
Detect if embedding model configuration has changed.
Args:
current_provider: Current embedding provider (None if first setup)
new_config: New embedding configuration
Returns:
True if model has changed and database rebuild is needed
"""
if current_provider is None:
return False # First setup, not a change
# Get new provider info
new_provider_type = new_config.get("provider", "openai").lower()
if new_provider_type == "openai":
new_model = new_config.get("openai", {}).get("model", "text-embedding-3-small")
elif new_provider_type == "ollama":
new_model = new_config.get("ollama", {}).get("model", "nomic-embed-text")
else:
return True # Unknown provider, assume change
# Compare with current provider
current_provider_type = current_provider.get_provider_name()
current_model = current_provider.get_model_name()
    # Compare model names only: a provider switch that keeps the same model is
    # deliberately treated as no change to avoid unnecessary rebuilds
    changed = new_model != current_model
if changed:
logger.info(f"Model change detected: {current_provider_type}:{current_model} "
f"-> {new_provider_type}:{new_model}")
return changed
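# Usage sketch: compare the live provider against freshly loaded config before
# deciding whether stored vectors must be re-embedded ("rebuild_vector_store"
# is a hypothetical caller-side helper):
#
#     if detect_model_change(current_provider, new_config):
#         rebuild_vector_store()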
# Synchronous wrapper for backward compatibility
def create_embedding_provider_sync(config: Dict[str, Any]) -> EmbeddingProvider:
"""
Synchronous wrapper for create_embedding_provider.
    Note: asyncio.run() raises RuntimeError when called from a running event
    loop, so prefer create_embedding_provider directly in async contexts.
"""
return asyncio.run(create_embedding_provider(config))
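# Minimal smoke test: a sketch assuming a reachable local Ollama server with
# the default nomic-embed-text model pulled; swap in an "openai" config block
# (see the shape documented above create_embedding_provider) to probe that
# provider instead.
if __name__ == "__main__":
    async def _demo() -> None:
        provider = await create_embedding_provider({"provider": "ollama"})
        vector = await provider.embed_single_text("hello embeddings")
        print(provider.get_provider_name(), provider.get_model_name(), len(vector))
    asyncio.run(_demo())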