# embedding_service.py
"""
Embedding Service
Supports local models (sentence-transformers) and cloud (OpenAI)
"""
import logging
from typing import Optional, Union
logger = logging.getLogger(__name__)
class EmbeddingService:
    """
    Service for generating text/code embeddings.

    Supports two implementations:
    - Local: sentence-transformers (free)
    - Cloud: OpenAI embeddings (low cost)
    """

    # Model used when the caller does not pass ``model_name``.
    _DEFAULT_LOCAL_MODEL = "all-MiniLM-L6-v2"
    _DEFAULT_OPENAI_MODEL = "text-embedding-3-small"

    # Default output sizes of the OpenAI embedding models, per OpenAI's
    # embeddings documentation. Names not listed here fall back to the
    # substring heuristic in ``get_dimension``.
    _OPENAI_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    # Returned by ``get_dimension`` when no model is available to ask.
    _FALLBACK_DIMENSION = 384

    def __init__(
        self,
        provider: str = "local",
        model_name: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """
        Store the configuration and load the embedding backend immediately.

        Args:
            provider: 'local' or 'openai'
            model_name: Model name (default: 'all-MiniLM-L6-v2' for local,
                'text-embedding-3-small' for openai)
            api_key: API key (OpenAI only)

        Raises:
            ImportError: The package required by ``provider`` is not installed.
            ValueError: ``provider`` is unsupported, or ``provider`` is
                'openai' and no ``api_key`` was given.
        """
        self.provider = provider
        self.model_name = model_name
        self.api_key = api_key
        self._model = None
        self._initialize_model()

    def _initialize_model(self):
        """
        Initialize the embedding backend selected by ``self.provider``.

        The third-party packages are imported lazily so that only the
        backend actually in use has to be installed.

        Raises:
            ImportError: The backend package is missing.
            ValueError: Unsupported provider, or missing OpenAI API key.
        """
        if self.provider == "local":
            # Only the import is guarded: an ImportError raised while the
            # model itself is being constructed must not be reported as
            # "package not installed".
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError as exc:
                raise ImportError(
                    "sentence-transformers not installed. "
                    "Run: pip install sentence-transformers"
                ) from exc
            self._model = SentenceTransformer(self.model_name or self._DEFAULT_LOCAL_MODEL)
        elif self.provider == "openai":
            try:
                import openai
            except ImportError as exc:
                raise ImportError("openai not installed. Run: pip install openai") from exc
            if not self.api_key:
                raise ValueError("OpenAI API key is required")
            # NOTE(review): this sets the key on the ``openai`` module itself,
            # so it is shared by everything in the process that uses the
            # module-level client; the most recently constructed service wins.
            openai.api_key = self.api_key
            self._model = openai
            self.model_name = self.model_name or self._DEFAULT_OPENAI_MODEL
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def encode(self, texts: Union[str, list[str]], normalize: bool = True) -> list[list[float]]:
        """
        Generate embeddings for text(s).

        Args:
            texts: Single string or list of strings
            normalize: Normalize vectors (recommended for similarity).
                Only forwarded to the local model; the OpenAI branch does
                not use it (OpenAI documents its embeddings as already
                normalized to length 1).

        Returns:
            List of embeddings (vectors), one per input text. An empty list
            is returned for empty input, or when no backend is loaded.
        """
        batch = [texts] if isinstance(texts, str) else list(texts)

        # Nothing to embed: skip the backend call entirely (the OpenAI
        # endpoint would otherwise be sent an empty ``input``).
        if not batch or self._model is None:
            return []

        if self.provider == "local":
            vectors = self._model.encode(
                batch, normalize_embeddings=normalize, show_progress_bar=False
            )
            local_result: list[list[float]] = vectors.tolist()
            return local_result

        if self.provider == "openai":
            response = self._model.embeddings.create(
                input=batch, model=self.model_name or self._DEFAULT_OPENAI_MODEL
            )
            remote_result: list[list[float]] = [item.embedding for item in response.data]
            return remote_result

        return []

    def get_dimension(self) -> int:
        """
        Return the embedding dimension of the configured model.

        For the local provider the loaded model is queried. For OpenAI the
        size is looked up by model name, using the same default model as
        ``encode`` when no name is set; unknown names fall back to the
        "small" -> 1536 / otherwise 3072 heuristic. 384 is returned when
        nothing else applies.
        """
        if self.provider == "local" and self._model is not None:
            dim: int = self._model.get_sentence_embedding_dimension()
            return dim

        if self.provider == "openai":
            name = self.model_name or self._DEFAULT_OPENAI_MODEL
            known = self._OPENAI_DIMENSIONS.get(name)
            if known is not None:
                return known
            # Unknown model name: keep the original name-based guess.
            return 1536 if "small" in name else 3072

        return self._FALLBACK_DIMENSION  # fallback