"""OpenAI embeddings generation for semantic search.
This module generates embeddings for message text using OpenAI's embeddings
API, with support for batching, retry/backoff error handling, and incremental
updates.
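
Example (illustrative; assumes OPENAI_API_KEY is set in the environment):

    vectors = generate_embeddings(["hello world", "semantic search"])
    # vectors is a list of two float32 numpy arrays, in input order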
"""
import logging
import os
import sqlite3
import time
from typing import Any
import numpy as np
from openai import (
APIConnectionError,
APIError,
APIStatusError,
APITimeoutError,
OpenAI,
RateLimitError,
)
from ..db.search_index import get_sync_metadata, set_sync_metadata
from ..exceptions import EmbeddingError
logger = logging.getLogger(__name__)
# Default embedding model
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
# Retry configuration for API errors
MAX_RETRIES = 5
INITIAL_BACKOFF = 1.0 # seconds
MAX_BACKOFF = 60.0 # seconds
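# With these defaults the sleeps between attempts are 1s, 2s, 4s, 8s (four
# sleeps across five attempts); MAX_BACKOFF only caps the schedule if
# MAX_RETRIES is raised past 7.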
# Minimum text length for meaningful embeddings (skip very short texts)
MIN_TEXT_LENGTH = 3
def get_openai_client() -> OpenAI | None:
"""Initialize OpenAI client from OPENAI_API_KEY environment variable.
Returns:
OpenAI client instance if API key is set, None otherwise.
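
    Example (illustrative):

        client = get_openai_client()
        if client is None:
            raise SystemExit("Set OPENAI_API_KEY to enable embeddings")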
"""
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
return None
return OpenAI(api_key=api_key)
def generate_embeddings(
texts: list[str],
model: str = DEFAULT_EMBEDDING_MODEL,
client: OpenAI | None = None,
) -> list[np.ndarray]:
"""Generate embeddings for a batch of texts using OpenAI's API.
    Rate limits, network/timeout failures, and 5xx server errors are retried
    with exponential backoff; 4xx client errors fail immediately.
Args:
texts: List of text strings to generate embeddings for.
model: OpenAI embedding model to use (default: text-embedding-3-small).
client: Optional OpenAI client instance. If None, creates a new client.
Returns:
List of numpy arrays (float32) containing embeddings for each input text.
The order matches the input texts.
Raises:
EmbeddingError: If API call fails after all retries or other errors occur.
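
    Example (illustrative; assumes OPENAI_API_KEY is set):

        vectors = generate_embeddings(["first message", "second message"])
        for vec in vectors:
            print(vec.shape)  # (1536,) for text-embedding-3-small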
"""
if not texts:
return []
# Get or create OpenAI client
if client is None:
client = get_openai_client()
if client is None:
raise EmbeddingError(
"OPENAI_API_KEY environment variable not set. "
"Set it to use embedding generation.",
is_retryable=False,
)
# Retry loop with exponential backoff
backoff = INITIAL_BACKOFF
last_error: EmbeddingError | None = None
for attempt in range(MAX_RETRIES):
try:
logger.debug(
f"Generating embeddings for {len(texts)} texts (attempt {attempt + 1}/{MAX_RETRIES})"
)
# Call OpenAI embeddings API
response = client.embeddings.create(
input=texts,
model=model,
)
# Extract embeddings and convert to numpy arrays
embeddings = [
np.array(item.embedding, dtype=np.float32) for item in response.data
]
logger.info(f"Successfully generated {len(embeddings)} embeddings")
return embeddings
except RateLimitError as e:
# Rate limit hit - use exponential backoff
logger.warning(f"Rate limit error on attempt {attempt + 1}/{MAX_RETRIES}: {e}")
last_error = EmbeddingError(
f"Rate limit exceeded from OpenAI API: {e}",
is_retryable=True,
)
if attempt < MAX_RETRIES - 1:
sleep_time = min(backoff, MAX_BACKOFF)
logger.info(f"Waiting {sleep_time:.1f}s before retry...")
time.sleep(sleep_time)
backoff *= 2 # Exponential backoff
else:
# Final retry failed
logger.error(f"Rate limit error persisted after {MAX_RETRIES} attempts")
raise last_error
except APIStatusError as e:
# Handle specific API status errors (4xx, 5xx)
logger.warning(
f"API status error on attempt {attempt + 1}/{MAX_RETRIES}: "
f"status={e.status_code}, message={e.message}"
)
# Check if this is a retryable server error (5xx)
is_server_error = e.status_code >= 500
if is_server_error and attempt < MAX_RETRIES - 1:
sleep_time = min(backoff, MAX_BACKOFF)
logger.info(
f"Server error (HTTP {e.status_code}), retrying in {sleep_time:.1f}s..."
)
time.sleep(sleep_time)
backoff *= 2
last_error = EmbeddingError(
f"OpenAI API returned {e.status_code}: {e.message}",
is_retryable=True,
)
else:
# Client error (4xx) or final attempt
is_retryable = is_server_error
logger.error(
f"API error (HTTP {e.status_code}): {e.message} "
f"(retryable={is_retryable})"
)
                raise EmbeddingError(
                    f"OpenAI API error (HTTP {e.status_code}): {e.message}",
                    is_retryable=is_retryable,
                ) from e
except (APIConnectionError, APITimeoutError) as e:
# Network or timeout errors - retryable
logger.warning(
f"Network/timeout error on attempt {attempt + 1}/{MAX_RETRIES}: {e}"
)
last_error = EmbeddingError(
f"Network or timeout error connecting to OpenAI API: {e}",
is_retryable=True,
)
if attempt < MAX_RETRIES - 1:
sleep_time = min(backoff, MAX_BACKOFF)
logger.info(f"Retrying connection in {sleep_time:.1f}s...")
time.sleep(sleep_time)
backoff *= 2
else:
logger.error(f"Network error persisted after {MAX_RETRIES} attempts")
raise last_error
except APIError as e:
# Other API errors
logger.error(f"Unexpected OpenAI API error on attempt {attempt + 1}/{MAX_RETRIES}: {e}")
            raise EmbeddingError(
                f"Unexpected OpenAI API error: {e}",
                is_retryable=False,
            ) from e
except Exception as e:
# Unexpected errors
logger.error(f"Unexpected error on attempt {attempt + 1}/{MAX_RETRIES}: {e}")
            raise EmbeddingError(
                f"Unexpected error during embedding generation: {e}",
                is_retryable=False,
            ) from e
    # Defensive: every branch above returns or raises, but fail loudly if not
    if last_error:
        raise last_error
    raise EmbeddingError(
        f"Embedding generation did not complete after {MAX_RETRIES} attempts",
        is_retryable=True,
    )
def _is_text_too_short(text: str | None) -> bool:
"""Check if text is too short or empty for meaningful embeddings.
Args:
text: Text to check.
Returns:
True if text is None, empty, or shorter than MIN_TEXT_LENGTH.
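
    Example:

        >>> _is_text_too_short(None)
        True
        >>> _is_text_too_short("  a ")
        True
        >>> _is_text_too_short("hello")
        False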
"""
if not text:
return True
if len(text.strip()) < MIN_TEXT_LENGTH:
return True
return False
def generate_pending_embeddings(
index_conn: sqlite3.Connection,
batch_size: int = 100,
model: str = DEFAULT_EMBEDDING_MODEL,
) -> dict[str, Any]:
"""Generate embeddings for messages that don't have them yet.
This function performs incremental embedding generation by:
1. Querying messages where embedding IS NULL and rowid > last_embedded_rowid
2. Skipping empty or very short texts (< MIN_TEXT_LENGTH characters)
3. Generating embeddings in batches
4. Storing embeddings as BLOBs in the database
5. Updating last_embedded_rowid after each successful batch
6. Storing embedding_model in sync_metadata on first use
Args:
        index_conn: SQLite connection to the search index database. Rows are
            read by name (row["rowid"], row["text"]), so the connection's
            row_factory must be sqlite3.Row or a compatible mapping factory.
batch_size: Number of messages to process per batch (default: 100).
model: OpenAI embedding model to use (default: text-embedding-3-small).
Returns:
Dictionary with:
- total_processed: Total number of messages processed
- total_embedded: Total number of embeddings generated
- total_skipped: Total number of messages skipped (too short)
- batches: Number of batches processed
- model: Embedding model used
- errors: List of error messages encountered (if any)
Raises:
EmbeddingError: If a critical error prevents processing.
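
    Example (illustrative; the database path is hypothetical):

        conn = sqlite3.connect("search_index.db")
        conn.row_factory = sqlite3.Row
        stats = generate_pending_embeddings(conn, batch_size=50)
        print(f"{stats['total_embedded']} embedded, {stats['total_skipped']} skipped")
        # Stored blobs can be decoded with np.frombuffer(blob, dtype=np.float32)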
"""
logger.info(f"Starting pending embeddings generation (batch_size={batch_size})")
# Check for OpenAI API key
client = get_openai_client()
if client is None:
raise EmbeddingError(
"OPENAI_API_KEY environment variable not set. "
"Set it to use embedding generation.",
is_retryable=False,
)
# Get last embedded rowid from metadata
try:
last_embedded_str = get_sync_metadata(index_conn, "last_embedded_rowid")
last_embedded_rowid = int(last_embedded_str) if last_embedded_str else 0
except Exception as e:
logger.error(f"Error reading embedding metadata: {e}")
        raise EmbeddingError(
            f"Failed to read embedding metadata: {e}",
            is_retryable=True,
        ) from e
# Store embedding model in metadata on first use
try:
current_model = get_sync_metadata(index_conn, "embedding_model")
if not current_model:
set_sync_metadata(index_conn, "embedding_model", model)
except Exception as e:
logger.warning(f"Could not store embedding model in metadata: {e}")
# Track statistics
total_processed = 0
total_embedded = 0
total_skipped = 0
batches_processed = 0
errors: list[str] = []
try:
while True:
try:
# Query messages without embeddings after last_embedded_rowid
cursor = index_conn.execute(
"""
SELECT rowid, text
FROM message_index
WHERE embedding IS NULL AND rowid > ?
ORDER BY rowid ASC
LIMIT ?
""",
(last_embedded_rowid, batch_size),
)
rows = cursor.fetchall()
if not rows:
# No more messages to process
logger.info("No more messages to process")
break
logger.debug(f"Processing batch of {len(rows)} messages")
# Separate texts and rowids, filtering out empty/short texts
batch_texts: list[str] = []
batch_rowids: list[int] = []
skipped_rowids: list[int] = []
for row in rows:
rowid = row["rowid"]
text = row["text"]
if _is_text_too_short(text):
skipped_rowids.append(rowid)
else:
batch_texts.append(text)
batch_rowids.append(rowid)
# Generate embeddings for non-empty texts
if batch_texts:
try:
embeddings = generate_embeddings(batch_texts, model=model, client=client)
# Store embeddings in database
for rowid, embedding in zip(batch_rowids, embeddings, strict=True):
embedding_blob = embedding.tobytes()
index_conn.execute(
"UPDATE message_index SET embedding = ? WHERE rowid = ?",
(embedding_blob, rowid),
)
total_embedded += len(embeddings)
logger.debug(f"Generated embeddings for {len(embeddings)} messages")
except EmbeddingError as e:
error_msg = f"Embedding batch failed: {e}"
logger.error(error_msg)
errors.append(error_msg)
# Continue with next batch rather than crashing
if not e.is_retryable:
# Non-retryable error, stop processing
raise
                # Mark skipped messages with an empty blob (rather than NULL)
                # so the incremental query won't pick them up again
if skipped_rowids:
for rowid in skipped_rowids:
try:
index_conn.execute(
"UPDATE message_index SET embedding = ? WHERE rowid = ?",
(b"", rowid), # Empty blob
)
except Exception as e:
error_msg = f"Failed to mark skipped message {rowid}: {e}"
logger.warning(error_msg)
errors.append(error_msg)
total_skipped += len(skipped_rowids)
total_processed += len(rows)
batches_processed += 1
# Update last_embedded_rowid to the highest rowid in this batch
# This ensures we make progress even if some embeddings fail
try:
max_rowid = max(row["rowid"] for row in rows)
set_sync_metadata(index_conn, "last_embedded_rowid", str(max_rowid))
last_embedded_rowid = max_rowid
except Exception as e:
error_msg = f"Failed to update embedding progress: {e}"
logger.warning(error_msg)
errors.append(error_msg)
# Commit after each batch for incremental progress
try:
index_conn.commit()
except sqlite3.DatabaseError as e:
error_msg = f"Database error while committing batch: {e}"
logger.error(error_msg)
errors.append(error_msg)
                    # Don't stop on database errors; they might be transient
except EmbeddingError as e:
if not e.is_retryable:
# Non-retryable error, propagate it
raise
error_msg = f"Retryable error during batch processing: {e}"
logger.warning(error_msg)
errors.append(error_msg)
                # Loop again; the same batch is re-queried if progress was not recorded
    except EmbeddingError:
        # Already classified; propagate as-is so a non-retryable error is not
        # re-wrapped as retryable below
        raise
    except Exception as e:
        logger.error(f"Critical error during pending embeddings generation: {e}")
        raise EmbeddingError(
            f"Failed to generate pending embeddings: {e}",
            is_retryable=True,
        ) from e
result = {
"total_processed": total_processed,
"total_embedded": total_embedded,
"total_skipped": total_skipped,
"batches": batches_processed,
"model": model,
}
if errors:
result["errors"] = errors
logger.warning(f"Completed with {len(errors)} errors")
logger.info(
f"Pending embeddings generation complete: "
f"processed={total_processed}, embedded={total_embedded}, skipped={total_skipped}"
)
return result