"""OpenAI LLM provider."""
import os
from typing import AsyncIterator
from openai import AsyncOpenAI
from local_deepwiki.providers.base import LLMProvider
class OpenAILLMProvider(LLMProvider):
    """LLM provider backed by the OpenAI chat completions API."""

    def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
        """Initialize the OpenAI provider.

        Args:
            model: OpenAI model name.
            api_key: Optional API key. Falls back to the OPENAI_API_KEY
                environment variable when not provided (or empty).
        """
        self._model = model
        self._client = AsyncOpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))

    @staticmethod
    def _build_messages(prompt: str, system_prompt: str | None) -> list[dict[str, str]]:
        """Build the chat message list: optional system message, then the user prompt."""
        messages: list[dict[str, str]] = []
        # A missing or empty system prompt is omitted entirely rather than
        # being sent as an empty system message.
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    async def generate(
        self,
        prompt: str,
        system_prompt: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> str:
        """Generate text from a prompt.

        Args:
            prompt: The user prompt.
            system_prompt: Optional system prompt.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.

        Returns:
            The content of the first returned choice, or an empty string
            when that choice carries no content.
        """
        response = await self._client.chat.completions.create(
            model=self._model,
            messages=self._build_messages(prompt, system_prompt),
            max_tokens=max_tokens,
            temperature=temperature,
        )
        # `content` may be None; normalise to "" so callers always get a str.
        return response.choices[0].message.content or ""

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> AsyncIterator[str]:
        """Generate text from a prompt with streaming.

        Args:
            prompt: The user prompt.
            system_prompt: Optional system prompt.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.

        Yields:
            Non-empty text chunks, in the order they arrive.
        """
        stream = await self._client.chat.completions.create(
            model=self._model,
            messages=self._build_messages(prompt, system_prompt),
            max_tokens=max_tokens,
            temperature=temperature,
            stream=True,
        )
        async for chunk in stream:
            # The chunk schema allows an empty `choices` list (documented for
            # usage-only chunks; presumably also emitted by some compatible
            # endpoints). Indexing [0] on such a chunk would raise IndexError
            # and abort the whole stream, so skip it instead.
            if not chunk.choices:
                continue
            text = chunk.choices[0].delta.content
            # Role-only and finish chunks have no content; do not yield them.
            if text:
                yield text

    @property
    def name(self) -> str:
        """Get the provider name, formatted as ``openai:<model>``."""
        return f"openai:{self._model}"