"""Ollama embedder service with batching and retry logic.
Generates embeddings using Ollama HTTP API with batching, retry logic, and
comprehensive error handling for production use.
Constitutional Compliance:
- Principle IV: Performance (batching, connection pooling, <30s for 1000 chunks)
- Principle V: Production quality (retry logic, timeout handling, validation)
- Principle VIII: Type safety (full mypy --strict compliance)
Key Features:
- Direct HTTP calls to Ollama using httpx async client
- Batch embedding generation (50-100 texts per batch)
- Exponential backoff retry logic (3 attempts)
- Timeout handling (30s per request)
- Model validation on startup
- Connection pooling for performance
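
Example (illustrative; assumes Ollama is running locally and that the import
path matches your package layout):

    import asyncio

    from src.services.embedder import generate_embeddings, validate_ollama_connection

    async def main() -> None:
        await validate_ollama_connection()
        vectors = await generate_embeddings(["first chunk", "second chunk"])
        print(len(vectors), len(vectors[0]))  # 2 768

    asyncio.run(main())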
"""
from __future__ import annotations

import asyncio
from collections.abc import Sequence
from typing import Final

import httpx
from pydantic import BaseModel, Field, field_validator

from src.config.settings import get_settings
from src.mcp.mcp_logging import get_logger
# ==============================================================================
# Constants
# ==============================================================================

logger = get_logger(__name__)

# Retry configuration
MAX_RETRIES: Final[int] = 3
INITIAL_RETRY_DELAY: Final[float] = 1.0  # seconds
MAX_RETRY_DELAY: Final[float] = 8.0  # seconds
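# With MAX_RETRIES = 3 the effective waits are 1s after attempt 1 and 2s after
# attempt 2 (delay = min(INITIAL_RETRY_DELAY * 2**(attempt - 1), MAX_RETRY_DELAY));
# the 8s cap only binds if MAX_RETRIES is raised.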

# Timeout configuration
REQUEST_TIMEOUT: Final[float] = 30.0  # seconds
CONNECTION_TIMEOUT: Final[float] = 5.0  # seconds

# Expected embedding dimension for nomic-embed-text
EXPECTED_EMBEDDING_DIM: Final[int] = 768


# ==============================================================================
# Pydantic Models
# ==============================================================================


class EmbeddingRequest(BaseModel):
    """Request model for the Ollama embedding API.

    Attributes:
        model: Embedding model name (e.g., "nomic-embed-text")
        prompt: Text to embed
    """

    model: str = Field(..., min_length=1, description="Embedding model name")
    prompt: str = Field(..., min_length=1, description="Text to embed")
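
    # Note: model_dump() produces the JSON body sent to POST /api/embeddings,
    # e.g. {"model": "nomic-embed-text", "prompt": "hello world"}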
model_config = {"frozen": True}


class EmbeddingResponse(BaseModel):
    """Response model for the Ollama embedding API.

    Attributes:
        embedding: List of float values representing the embedding vector
    """

    embedding: list[float] = Field(..., min_length=1, description="Embedding vector")
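
    # Ollama responds with e.g. {"embedding": [0.1, -0.2, ...]}; for
    # nomic-embed-text the vector has 768 elements (validated below)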
@field_validator("embedding")
@classmethod
def validate_embedding_dimension(cls, v: list[float]) -> list[float]:
"""Validate embedding dimension matches expected size."""
if len(v) != EXPECTED_EMBEDDING_DIM:
raise ValueError(
f"Expected embedding dimension {EXPECTED_EMBEDDING_DIM}, got {len(v)}"
)
return v

    model_config = {"frozen": True}


class OllamaModelInfo(BaseModel):
    """Model information from the Ollama API.

    Attributes:
        name: Model name
        modified_at: Last modification timestamp
        size: Model size in bytes
    """

    name: str
    modified_at: str
    size: int

    model_config = {"frozen": True}


class OllamaModelsResponse(BaseModel):
    """Response model for the Ollama list-models API.

    Attributes:
        models: List of available models
    """

    models: list[OllamaModelInfo]

    model_config = {"frozen": True}


# ==============================================================================
# Custom Exceptions
# ==============================================================================


class OllamaError(Exception):
    """Base exception for Ollama-related errors."""


class OllamaConnectionError(OllamaError):
    """Raised when unable to connect to the Ollama server."""


class OllamaModelNotFoundError(OllamaError):
    """Raised when the requested model is not available."""


class OllamaTimeoutError(OllamaError):
    """Raised when a request times out."""


class OllamaValidationError(OllamaError):
    """Raised when response validation fails."""


# ==============================================================================
# Embedder Client
# ==============================================================================


class OllamaEmbedder:
    """Client for Ollama embedding generation with retry logic.

    Singleton pattern with connection pooling for performance.
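
    Example (illustrative, inside an async context):
        embedder = OllamaEmbedder()  # repeated calls return the same instance
        vector = await embedder.generate_embedding("hello world")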
"""

    _instance: OllamaEmbedder | None = None
    _client: httpx.AsyncClient | None = None

    def __new__(cls) -> OllamaEmbedder:
        """Ensure singleton instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Initialize embedder with settings."""
        if self._client is not None:
            return  # Already initialized

        settings = get_settings()
        self.base_url = str(settings.ollama_base_url)
        self.model = settings.ollama_embedding_model
        self.batch_size = settings.embedding_batch_size

        # Create async HTTP client with connection pooling
        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=httpx.Timeout(REQUEST_TIMEOUT, connect=CONNECTION_TIMEOUT),
            limits=httpx.Limits(max_connections=20, max_keepalive_connections=10),
        )

        logger.info(
            "Ollama embedder initialized",
            extra={
                "context": {
                    "base_url": self.base_url,
                    "model": self.model,
                    "batch_size": self.batch_size,
                }
            },
        )

    async def close(self) -> None:
        """Close the HTTP client and clean up resources."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None
            logger.info("Ollama embedder closed")

    async def _request_with_retry(
        self, request: EmbeddingRequest, attempt: int = 1
    ) -> list[float]:
        """Make an embedding request with exponential backoff retry.

        Args:
            request: Embedding request
            attempt: Current attempt number (1-indexed)

        Returns:
            Embedding vector

        Raises:
            OllamaError: If all retries fail
        """
        if self._client is None:
            raise OllamaError("Client not initialized")

try:
response = await self._client.post(
"/api/embeddings", json=request.model_dump()
)
response.raise_for_status()
# Parse and validate response
embedding_response = EmbeddingResponse(**response.json())
return embedding_response.embedding
except httpx.TimeoutException as e:
logger.warning(
f"Request timeout (attempt {attempt}/{MAX_RETRIES})",
extra={"context": {"attempt": attempt, "error": str(e)}},
)
if attempt >= MAX_RETRIES:
raise OllamaTimeoutError(
f"Request timed out after {MAX_RETRIES} attempts"
) from e
        except httpx.HTTPStatusError as e:
            logger.warning(
                f"HTTP error {e.response.status_code} (attempt {attempt}/{MAX_RETRIES})",
                extra={
                    "context": {
                        "attempt": attempt,
                        "status_code": e.response.status_code,
                        "error": str(e),
                    }
                },
            )
            # A 404 means the model is not available; retrying cannot succeed
            if e.response.status_code == 404:
                raise OllamaModelNotFoundError(
                    f"Model '{request.model}' not found: {e}"
                ) from e
            if attempt >= MAX_RETRIES:
                raise OllamaError(f"HTTP error: {e}") from e
except httpx.ConnectError as e:
logger.warning(
f"Connection error (attempt {attempt}/{MAX_RETRIES})",
extra={"context": {"attempt": attempt, "error": str(e)}},
)
if attempt >= MAX_RETRIES:
raise OllamaConnectionError(
f"Unable to connect to Ollama at {self.base_url}"
) from e
except ValueError as e:
# Pydantic validation error
logger.error(
"Response validation failed",
extra={"context": {"error": str(e)}},
)
raise OllamaValidationError(f"Invalid response format: {e}") from e
except Exception as e:
logger.error(
f"Unexpected error (attempt {attempt}/{MAX_RETRIES})",
extra={"context": {"attempt": attempt, "error": str(e)}},
)
if attempt >= MAX_RETRIES:
raise OllamaError(f"Unexpected error: {e}") from e
# Calculate exponential backoff delay
delay = min(INITIAL_RETRY_DELAY * (2 ** (attempt - 1)), MAX_RETRY_DELAY)
logger.info(
f"Retrying after {delay}s delay",
extra={"context": {"attempt": attempt, "delay_seconds": delay}},
)
await asyncio.sleep(delay)
# Retry
return await self._request_with_retry(request, attempt + 1)

    async def generate_embedding(self, text: str) -> list[float]:
        """Generate embedding for single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector (768 dimensions)

        Raises:
            ValueError: If text is empty
            OllamaError: If embedding generation fails
        """
        if not text:
            raise ValueError("Text cannot be empty")

        request = EmbeddingRequest(model=self.model, prompt=text)
        return await self._request_with_retry(request)

    async def generate_embeddings(self, texts: Sequence[str]) -> list[list[float]]:
        """Generate embeddings for batch of texts.

        Args:
            texts: Sequence of texts to embed

        Returns:
            List of embedding vectors, in the same order as the input texts

        Raises:
            ValueError: If texts is empty or contains empty strings
            OllamaError: If embedding generation fails

        Performance:
            Processes texts in batches of batch_size; requests within each
            batch run in parallel via asyncio.gather.
        """
        if not texts:
            raise ValueError("Texts cannot be empty")
        if any(not text for text in texts):
            raise ValueError("All texts must be non-empty")

        logger.debug(
            f"Generating embeddings for {len(texts)} texts",
            extra={"context": {"text_count": len(texts)}},
        )
        start_time = asyncio.get_running_loop().time()

        # Process in batches so the concurrent load on Ollama stays bounded
        embeddings: list[list[float]] = []
        for start in range(0, len(texts), self.batch_size):
            batch = [
                EmbeddingRequest(model=self.model, prompt=text)
                for text in texts[start : start + self.batch_size]
            ]
            tasks = [self._request_with_retry(req) for req in batch]
            embeddings.extend(await asyncio.gather(*tasks))
        elapsed_ms = (asyncio.get_running_loop().time() - start_time) * 1000
logger.info(
f"Generated {len(embeddings)} embeddings",
extra={
"context": {
"text_count": len(texts),
"embedding_count": len(embeddings),
"duration_ms": elapsed_ms,
"avg_ms_per_embedding": elapsed_ms / len(embeddings),
}
},
)
return embeddings


# ==============================================================================
# Public API
# ==============================================================================


async def validate_ollama_connection() -> bool:
    """Check if Ollama is running and model is available.

    Returns:
        True if connection successful and model available

    Raises:
        OllamaConnectionError: If unable to connect to Ollama
        OllamaModelNotFoundError: If requested model not available
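
    Example (illustrative; typically called once at startup):
        await validate_ollama_connection()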
"""
settings = get_settings()
base_url = str(settings.ollama_base_url)
model_name = settings.ollama_embedding_model
logger.info(
"Validating Ollama connection",
extra={"context": {"base_url": base_url, "model": model_name}},
)
async with httpx.AsyncClient(base_url=base_url, timeout=10.0) as client:
try:
# Check if Ollama is running
response = await client.get("/api/tags")
response.raise_for_status()
# Parse models list
models_response = OllamaModelsResponse(**response.json())
            # Check if the requested model is available; Ollama tag names often
            # include a ":latest" suffix, so accept either form
            available_models = {model.name for model in models_response.models}
            if (
                model_name not in available_models
                and f"{model_name}:latest" not in available_models
            ):
error_msg = (
f"Model '{model_name}' not found in Ollama. "
f"Available models: {available_models}\n"
f"Run: ollama pull {model_name}"
)
logger.error(
"Model not found",
extra={
"context": {
"model": model_name,
"available_models": list(available_models),
}
},
)
raise OllamaModelNotFoundError(error_msg)
logger.info(
"Ollama connection validated",
extra={
"context": {
"base_url": base_url,
"model": model_name,
"available_models": list(available_models),
}
},
)
return True
except httpx.ConnectError as e:
error_msg = (
f"Unable to connect to Ollama at {base_url}. "
"Ensure Ollama is running: ollama serve"
)
logger.error(
"Connection failed",
extra={"context": {"base_url": base_url, "error": str(e)}},
)
raise OllamaConnectionError(error_msg) from e
except httpx.HTTPStatusError as e:
logger.error(
f"HTTP error {e.response.status_code}",
extra={
"context": {
"base_url": base_url,
"status_code": e.response.status_code,
}
},
)
raise OllamaError(f"HTTP error: {e}") from e


async def generate_embedding(text: str) -> list[float]:
    """Generate embedding for single text.

    Convenience function using the singleton embedder instance.

    Args:
        text: Text to embed

    Returns:
        Embedding vector (768 dimensions)

    Raises:
        ValueError: If text is empty
        OllamaError: If embedding generation fails
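
    Example (illustrative; requires a running Ollama instance):
        vector = await generate_embedding("hello world")
        assert len(vector) == 768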
"""
embedder = OllamaEmbedder()
return await embedder.generate_embedding(text)


async def generate_embeddings(texts: Sequence[str]) -> list[list[float]]:
    """Generate embeddings for batch of texts.

    Convenience function using the singleton embedder instance.

    Args:
        texts: Sequence of texts to embed

    Returns:
        List of embedding vectors

    Raises:
        ValueError: If texts is empty or contains empty strings
        OllamaError: If embedding generation fails
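
    Example (illustrative; requires a running Ollama instance):
        vectors = await generate_embeddings(["chunk one", "chunk two"])
        assert len(vectors) == 2 and all(len(v) == 768 for v in vectors)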
"""
embedder = OllamaEmbedder()
return await embedder.generate_embeddings(texts)


# ==============================================================================
# Module Exports
# ==============================================================================

__all__ = [
    "OllamaEmbedder",
    "OllamaError",
    "OllamaConnectionError",
    "OllamaModelNotFoundError",
    "OllamaTimeoutError",
    "OllamaValidationError",
    "validate_ollama_connection",
    "generate_embedding",
    "generate_embeddings",
]