"""
Embedding generation using Sentence Transformers.
Supports local models for privacy and offline use.
"""
from sentence_transformers import SentenceTransformer
from typing import List, Optional
import numpy as np
import logging
from functools import lru_cache
from app.config import settings
# Configure the root logger at INFO level as a side effect of importing this module.
# NOTE(review): basicConfig() only takes effect if the root logger has no handlers
# yet, and it applies process-wide — confirm this module (rather than the
# application entry point) is the intended place for logging setup.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by everything in this file.
logger = logging.getLogger(__name__)
class EmbeddingService:
    """
    Generates embeddings for text using Sentence Transformers.

    The underlying model is loaded lazily on first use (see ``model``), so
    constructing the service is cheap and does not touch the model files.
    """

    def __init__(self, model_name: str = settings.EMBEDDING_MODEL):
        """
        Initialize the embedding service.

        Args:
            model_name: Name of the Sentence Transformer model to use.
                Defaults to ``settings.EMBEDDING_MODEL`` (read once, when the
                class is defined).
        """
        self.model_name = model_name
        # Populated on first access of the ``model`` property.
        self._model: Optional[SentenceTransformer] = None
        logger.info("Embedding service initialized with model: %s", model_name)

    @property
    def model(self) -> SentenceTransformer:
        """Return the SentenceTransformer instance, loading it on first access."""
        if self._model is None:
            logger.info("Loading embedding model: %s", self.model_name)
            self._model = SentenceTransformer(self.model_name)
            logger.info(
                "Model loaded. Embedding dimension: %s",
                self._model.get_sentence_embedding_dimension(),
            )
        return self._model

    def embed_text(self, text: str) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed.

        Returns:
            List of floats representing the embedding vector.

        Raises:
            ValueError: If ``text`` is ``None``, empty, or whitespace-only.
        """
        # ``not text`` also covers None, so callers get the documented
        # ValueError rather than an AttributeError from ``.strip()``.
        if not text or not text.strip():
            raise ValueError("Cannot embed empty text")
        embedding = self.model.encode(text, convert_to_numpy=True)
        return embedding.tolist()

    def embed_texts(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
        """
        Generate embeddings for multiple texts.

        Empty, whitespace-only, and ``None`` entries are dropped before
        encoding. The result therefore contains one vector per *non-blank*
        input, in input order, and may be shorter than ``texts`` — positions
        in the result do not line up with positions in ``texts`` when
        anything was filtered out.

        Args:
            texts: List of texts to embed.
            batch_size: Batch size passed to the model's ``encode`` call.

        Returns:
            List of embedding vectors; empty if there is nothing to embed.
        """
        if not texts:
            return []

        # Filter out blank entries (None-safe).
        valid_texts = [t for t in texts if t and t.strip()]
        if len(valid_texts) != len(texts):
            logger.warning(f"Filtered out {len(texts) - len(valid_texts)} empty texts")

        # Nothing left to encode: skip loading the model and avoid handing
        # an empty batch to the encoder.
        if not valid_texts:
            return []

        logger.info("Generating embeddings for %d texts", len(valid_texts))
        embeddings = self.model.encode(
            valid_texts,
            batch_size=batch_size,
            # Only show a progress bar for larger jobs.
            show_progress_bar=len(valid_texts) > 100,
            convert_to_numpy=True,
        )
        return embeddings.tolist()

    def embed_query(self, query: str) -> List[float]:
        """
        Generate embedding for a search query.

        Currently this simply delegates to ``embed_text``; it is kept as a
        separate entry point so that models which treat queries differently
        from documents can be supported without changing callers.

        Args:
            query: Search query to embed.

        Returns:
            Embedding vector for the query.

        Raises:
            ValueError: If ``query`` is ``None``, empty, or whitespace-only.
        """
        return self.embed_text(query)

    def get_dimension(self) -> int:
        """Return the model's sentence embedding dimension (loads the model if needed)."""
        return self.model.get_sentence_embedding_dimension()

    def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding vector.
            embedding2: Second embedding vector.

        Returns:
            Cosine similarity in the range [-1, 1]. Returns 0.0 if either
            vector has zero norm (including empty vectors).

        Raises:
            ValueError: If the two embeddings do not have the same shape.
        """
        vec1 = np.array(embedding1)
        vec2 = np.array(embedding2)

        # Fail with a clear message instead of numpy's "shapes not aligned".
        if vec1.shape != vec2.shape:
            raise ValueError(
                f"Embedding shapes do not match: {vec1.shape} vs {vec2.shape}"
            )

        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        # Cosine similarity is undefined for a zero vector; report 0.0.
        if norm1 == 0 or norm2 == 0:
            return 0.0

        similarity = np.dot(vec1, vec2) / (norm1 * norm2)
        # Floating-point rounding can push the ratio marginally outside
        # [-1, 1]; clamp so downstream code (e.g. arccos) stays safe.
        return float(np.clip(similarity, -1.0, 1.0))
# Process-wide EmbeddingService singleton; stays None until first requested.
_embedding_service: Optional[EmbeddingService] = None


def get_embedding_service() -> EmbeddingService:
    """Return the shared EmbeddingService, constructing it on the first call."""
    global _embedding_service
    # Fast path: the instance was already built by an earlier call.
    if _embedding_service is not None:
        return _embedding_service
    _embedding_service = EmbeddingService()
    return _embedding_service