import logging
from typing import AsyncGenerator, Dict, Any
from openai import AsyncOpenAI
from app.protocol.models import CompletionRequest, CompletionChunk
from app.core.config import config
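# Settings expected on `config` (as referenced below): USE_AZURE_OPENAI,
# AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_DEPLOYMENT, AZURE_OPENAI_API_VERSION,
# OPENAI_API_KEY, OPENAI_MODEL and OPENAI_TEMPERATURE.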
logger = logging.getLogger(__name__)
class LLMService:
"""Service for LLM completions using OpenAI or Azure OpenAI"""
def __init__(self):
if config.USE_AZURE_OPENAI:
# Azure OpenAI configuration - create custom HTTP client with API version
import httpx
http_client = httpx.AsyncClient(
params={"api-version": config.AZURE_OPENAI_API_VERSION}
)
base_url = f"{config.AZURE_OPENAI_ENDPOINT.rstrip('/')}/openai/deployments/{config.AZURE_OPENAI_DEPLOYMENT}"
self.client = AsyncOpenAI(
api_key=config.OPENAI_API_KEY,
base_url=base_url,
default_headers={"api-key": config.OPENAI_API_KEY},
http_client=http_client
)
self.model = config.AZURE_OPENAI_DEPLOYMENT
self.api_version = config.AZURE_OPENAI_API_VERSION
self.is_azure = True
else:
self.client = AsyncOpenAI(api_key=config.OPENAI_API_KEY)
self.model = config.OPENAI_MODEL
self.is_azure = False
self.temperature = config.OPENAI_TEMPERATURE
    async def complete(self, request: CompletionRequest) -> Dict[str, Any]:
        """Get a non-streaming completion."""
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": request.prompt}],
                max_tokens=request.max_tokens,
                # Fall back to the service default only when no temperature was given;
                # an explicit 0.0 is a valid value and must not be overridden.
                temperature=request.temperature if request.temperature is not None else self.temperature,
                stream=False
            )
            content = response.choices[0].message.content
            return {
                "content": content,
                "model": self.model,
                "usage": {
                    # The API may omit usage data; default each counter to 0 in that case.
                    "prompt_tokens": getattr(response.usage, 'prompt_tokens', 0) if response.usage else 0,
                    "completion_tokens": getattr(response.usage, 'completion_tokens', 0) if response.usage else 0,
                    "total_tokens": getattr(response.usage, 'total_tokens', 0) if response.usage else 0
                }
            }
        except Exception as e:
            logger.error(f"LLM completion error: {str(e)}")
            raise Exception(f"LLM completion failed: {str(e)}") from e

    async def stream_complete(self, request: CompletionRequest) -> AsyncGenerator[CompletionChunk, None]:
        """Get a streaming completion."""
        try:
            stream = await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": request.prompt}],
                max_tokens=request.max_tokens,
                temperature=request.temperature if request.temperature is not None else self.temperature,
                stream=True
            )
            async for chunk in stream:
                # Skip keep-alive chunks that carry no choices or an empty delta.
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    yield CompletionChunk(content=content, finished=False)
                # A non-null finish_reason marks the final chunk of the stream.
                if chunk.choices and chunk.choices[0].finish_reason:
                    yield CompletionChunk(content="", finished=True)
                    break
        except Exception as e:
            logger.error(f"LLM streaming error: {str(e)}")
            # Surface the error to the consumer as a final chunk rather than raising mid-stream.
            yield CompletionChunk(
                content=f"Error: {str(e)}",
                finished=True
            )

# Global service instance
llm_service = LLMService()
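
# Illustrative usage sketch (not part of the service): from any async context, such as
# a request handler, the shared instance can be awaited directly. The CompletionRequest
# field names below mirror how they are accessed in this module; the handler itself is
# hypothetical.
#
#     async def handle(prompt: str) -> str:
#         result = await llm_service.complete(
#             CompletionRequest(prompt=prompt, max_tokens=256, temperature=0.2)
#         )
#         return result["content"]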