"""OpenAI embedding provider."""
import os
from openai import AsyncOpenAI
from local_deepwiki.providers.base import EmbeddingProvider
# Output vector length for each OpenAI embedding model this module knows about,
# keyed by the model name passed to the API.
OPENAI_EMBEDDING_DIMENSIONS: dict[str, int] = {
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
    "text-embedding-ada-002": 1536,
}
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider using OpenAI API."""

    # Maximum number of inputs sent in a single embeddings request. The OpenAI
    # API reference documents a 2048-element cap on the ``input`` array; larger
    # lists are split into consecutive batches of at most this size.
    _MAX_BATCH_SIZE = 2048

    def __init__(self, model: str = "text-embedding-3-small", api_key: str | None = None):
        """Initialize the OpenAI embedding provider.

        Args:
            model: OpenAI embedding model name.
            api_key: Optional API key. Uses OPENAI_API_KEY env var if not provided.
        """
        self._model = model
        self._client = AsyncOpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))
        # Models missing from the lookup table fall back to 1536 dimensions.
        self._dimension = OPENAI_EMBEDDING_DIMENSIONS.get(model, 1536)

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a list of texts.

        An empty ``texts`` list returns an empty list without contacting the
        API. Longer lists are sent in sequential batches of at most
        ``_MAX_BATCH_SIZE`` items, and the results are concatenated in input
        order.

        Args:
            texts: List of text strings to embed.

        Returns:
            List of embedding vectors, one per input text.
        """
        # The API rejects an empty input array, so short-circuit here.
        if not texts:
            return []

        embeddings: list[list[float]] = []
        # Batches are awaited one after another so that the output order
        # matches the input order.
        for start in range(0, len(texts), self._MAX_BATCH_SIZE):
            batch = texts[start:start + self._MAX_BATCH_SIZE]
            response = await self._client.embeddings.create(
                model=self._model,
                input=batch,
            )
            embeddings.extend(item.embedding for item in response.data)
        return embeddings

    def get_dimension(self) -> int:
        """Get the embedding dimension.

        Returns:
            The dimension of the embedding vectors.
        """
        return self._dimension

    @property
    def name(self) -> str:
        """Get the provider name, formatted as ``openai:<model>``."""
        return f"openai:{self._model}"