"""Base classes for providers."""
from abc import ABC, abstractmethod
from typing import AsyncIterator
class EmbeddingProvider(ABC):
    """Interface that every embedding backend must implement.

    Concrete subclasses convert text into numeric vectors, report the
    dimensionality of those vectors, and expose a provider name.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the identifier of this embedding provider."""

    @abstractmethod
    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts.

        Args:
            texts: The strings to convert into vectors.

        Returns:
            The embedding vectors for ``texts``.
        """

    @abstractmethod
    def get_dimension(self) -> int:
        """Report the size of the vectors this provider produces.

        Returns:
            The embedding dimension.
        """
class LLMProvider(ABC):
"""Abstract base class for LLM providers."""
@abstractmethod
async def generate(
self,
prompt: str,
system_prompt: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
) -> str:
"""Generate text from a prompt.
Args:
prompt: The user prompt.
system_prompt: Optional system prompt.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
Returns:
Generated text.
"""
pass
@abstractmethod
async def generate_stream(
self,
prompt: str,
system_prompt: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
) -> AsyncIterator[str]:
"""Generate text from a prompt with streaming.
Args:
prompt: The user prompt.
system_prompt: Optional system prompt.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
Yields:
Generated text chunks.
"""
pass
@property
@abstractmethod
def name(self) -> str:
"""Get the provider name."""
pass