"""
LLM Integration System
Handles integration with various Language Learning Models (LLMs) including
OpenAI, Anthropic, and other providers. Provides a unified interface for
chat completions and prompt management.
"""
import asyncio
import json
import os
from typing import Any, Dict, List, Optional

import httpx
import structlog
from pydantic import BaseModel, Field
logger = structlog.get_logger(__name__)
class LLMProvider(BaseModel):
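    """Configuration for a single LLM provider backend."""
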
name: str
api_key: Optional[str] = None
base_url: Optional[str] = None
default_model: str
models: List[str]
class ChatMessage(BaseModel):
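    """A single message in a chat conversation."""
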
role: str = Field(..., description="Role of the message sender")
content: str = Field(..., description="Content of the message")
timestamp: Optional[str] = None
class LLMResponse(BaseModel):
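    """Normalized response returned by any provider call."""
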
content: str
model: str
provider: str
usage: Optional[Dict[str, Any]] = None
metadata: Optional[Dict[str, Any]] = None
class LLMIntegrationSystem:
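    """Unified async client for OpenAI, Anthropic, and local Ollama chat APIs."""
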
def __init__(self):
self.providers = {
"openai": LLMProvider(
name="openai",
api_key=os.getenv("OPENAI_API_KEY"),
base_url="https://api.openai.com/v1",
default_model="gpt-3.5-turbo",
models=["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-4o"],
),
"anthropic": LLMProvider(
name="anthropic",
api_key=os.getenv("ANTHROPIC_API_KEY"),
base_url="https://api.anthropic.com/v1",
default_model="claude-3-sonnet-20240229",
models=[
"claude-3-haiku-20240307",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
],
),
"ollama": LLMProvider(
name="ollama",
api_key=None, # Ollama doesn't require API key
base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
default_model="llama2",
models=["llama2", "llama3", "mistral", "codellama", "phi", "gemma"],
),
}
self.client = httpx.AsyncClient(timeout=120.0) # Longer timeout for local models
async def chat(
self,
message: str,
model: str = "gpt-3.5-turbo",
provider: Optional[str] = None,
conversation_history: Optional[List[ChatMessage]] = None,
**kwargs,
) -> str:
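        """Send a chat message and return the assistant's reply as plain text."""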
# Normalize model name - handle formats like "ollama/mistral:latest" or "mistral:latest"
normalized_model = self._normalize_model_name(model)
if not provider:
provider = self._get_provider_for_model(normalized_model)
# Use normalized model name for API calls
model = normalized_model
if provider not in self.providers:
raise ValueError(f"Unknown provider: {provider}")
provider_config = self.providers[provider]
# Ollama doesn't require API key
if provider != "ollama" and not provider_config.api_key:
raise ValueError(f"API key not found for provider: {provider}")
messages = []
if conversation_history:
messages.extend([{"role": msg.role, "content": msg.content} for msg in conversation_history])
messages.append({"role": "user", "content": message})
try:
if provider == "openai":
response = await self._call_openai(provider_config, model, messages, **kwargs)
elif provider == "anthropic":
response = await self._call_anthropic(provider_config, model, messages, **kwargs)
elif provider == "ollama":
response = await self._call_ollama(provider_config, model, messages, **kwargs)
else:
raise ValueError(f"Provider {provider} not implemented")
return response.content
except Exception as e:
logger.error("LLM chat failed", provider=provider, model=model, error=str(e))
raise
async def chat_stream(
self,
message: str,
model: str = "mistral:latest",
provider: Optional[str] = None,
conversation_history: Optional[List[ChatMessage]] = None,
**kwargs,
):
"""Stream chat responses from LLM providers"""
# Normalize model name - handle formats like "ollama/mistral:latest" or "mistral:latest"
normalized_model = self._normalize_model_name(model)
if not provider:
provider = self._get_provider_for_model(normalized_model)
# Use normalized model name for API calls
model = normalized_model
if provider not in self.providers:
raise ValueError(f"Unknown provider: {provider}")
provider_config = self.providers[provider]
# Ollama doesn't require API key
if provider != "ollama" and not provider_config.api_key:
raise ValueError(f"API key not found for provider: {provider}")
messages = []
if conversation_history:
messages.extend([{"role": msg.role, "content": msg.content} for msg in conversation_history])
messages.append({"role": "user", "content": message})
try:
if provider == "ollama":
async for chunk in self._stream_ollama(provider_config, model, messages, **kwargs):
yield chunk
else:
                # Streaming is currently implemented only for Ollama;
                # other providers would need an equivalent streaming handler.
raise ValueError(f"Streaming not supported for provider: {provider}")
except Exception as e:
logger.error("LLM streaming chat failed", provider=provider, model=model, error=str(e))
raise
async def _stream_ollama(
self,
provider: LLMProvider,
model: str,
messages: List[Dict[str, str]],
**kwargs,
):
"""Stream responses from Ollama"""
headers = {
"Content-Type": "application/json",
}
payload = {
"model": model,
"messages": messages,
"stream": True, # Enable streaming
"options": {
"temperature": kwargs.get("temperature", 0.7),
"num_predict": kwargs.get("max_tokens", 1000),
}
}
try:
async with self.client.stream(
"POST",
f"{provider.base_url}/api/chat",
headers=headers,
json=payload,
) as response:
response.raise_for_status()
                async for line in response.aiter_lines():
                    if not line.strip():
                        continue
                    try:
                        chunk_data = json.loads(line)
                    except json.JSONDecodeError:
                        # Skip keep-alive or otherwise non-JSON lines
                        continue
                    if chunk_data.get("message", {}).get("content"):
                        yield {
                            "content": chunk_data["message"]["content"],
                            "done": chunk_data.get("done", False),
                            "model": model,
                            "provider": provider.name,
                        }
                    if chunk_data.get("done", False):
                        break
except httpx.ConnectError:
raise ValueError("Cannot connect to Ollama. Make sure Ollama is running on localhost:11434")
except Exception as e:
logger.error("Ollama streaming API call failed", error=str(e))
raise
async def _call_openai(
self,
provider: LLMProvider,
model: str,
messages: List[Dict[str, str]],
**kwargs,
) -> LLMResponse:
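        """Call the OpenAI chat completions endpoint and normalize the response."""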
headers = {
"Authorization": f"Bearer {provider.api_key}",
"Content-Type": "application/json",
}
payload = {
"model": model,
"messages": messages,
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 1000),
}
response = await self.client.post(
f"{provider.base_url}/chat/completions",
headers=headers,
json=payload,
)
response.raise_for_status()
data = response.json()
return LLMResponse(
content=data["choices"][0]["message"]["content"],
model=model,
provider=provider.name,
usage=data.get("usage"),
metadata={"response_id": data.get("id")},
)
async def _call_anthropic(
self,
provider: LLMProvider,
model: str,
messages: List[Dict[str, str]],
**kwargs,
) -> LLMResponse:
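        """Call the Anthropic messages endpoint, passing any system prompt separately."""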
headers = {
"x-api-key": provider.api_key,
"Content-Type": "application/json",
"anthropic-version": "2023-06-01",
}
system_message = ""
user_messages = []
for msg in messages:
if msg["role"] == "system":
system_message = msg["content"]
else:
user_messages.append(msg)
payload = {
"model": model,
"max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7),
"messages": user_messages,
}
if system_message:
payload["system"] = system_message
response = await self.client.post(
f"{provider.base_url}/messages",
headers=headers,
json=payload,
)
response.raise_for_status()
data = response.json()
return LLMResponse(
content=data["content"][0]["text"],
model=model,
provider=provider.name,
usage=data.get("usage"),
metadata={"response_id": data.get("id")},
)
async def _call_ollama(
self,
provider: LLMProvider,
model: str,
messages: List[Dict[str, str]],
**kwargs,
) -> LLMResponse:
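        """Call the local Ollama chat endpoint (non-streaming)."""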
headers = {
"Content-Type": "application/json",
}
payload = {
"model": model,
"messages": messages,
"stream": False,
"options": {
"temperature": kwargs.get("temperature", 0.7),
"num_predict": kwargs.get("max_tokens", 1000),
}
}
try:
response = await self.client.post(
f"{provider.base_url}/api/chat",
headers=headers,
json=payload,
)
response.raise_for_status()
data = response.json()
return LLMResponse(
content=data["message"]["content"],
model=model,
provider=provider.name,
usage=data.get("usage"),
metadata={
"created_at": data.get("created_at"),
"done": data.get("done"),
"total_duration": data.get("total_duration"),
},
)
except httpx.ConnectError:
raise ValueError("Cannot connect to Ollama. Make sure Ollama is running on localhost:11434")
except Exception as e:
logger.error("Ollama API call failed", error=str(e))
raise
def _get_provider_for_model(self, model: str) -> str:
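        """Return the provider whose model list contains the given model, defaulting to Ollama."""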
for provider_name, provider_config in self.providers.items():
if model in provider_config.models:
return provider_name
# Default to ollama for local testing if no specific provider matches
return "ollama"
async def get_ollama_models(self) -> List[str]:
"""Get list of available models from Ollama"""
try:
ollama_config = self.providers["ollama"]
response = await self.client.get(f"{ollama_config.base_url}/api/tags")
response.raise_for_status()
data = response.json()
models = [model["name"] for model in data.get("models", [])]
            # Update the Ollama provider with the models actually installed
            if models:
                self.providers["ollama"].models = models
                if self.providers["ollama"].default_model not in models:
                    self.providers["ollama"].default_model = models[0]
return models
except Exception as e:
logger.warning("Could not fetch Ollama models", error=str(e))
return []
async def ensure_ollama_model(self, model: str) -> bool:
"""Ensure a specific model is available in Ollama"""
try:
available_models = await self.get_ollama_models()
if model in available_models:
return True
# If model not available, suggest pulling it
logger.info(f"Model {model} not found. You may need to run: ollama pull {model}")
return False
except Exception as e:
logger.error("Error checking Ollama model availability", error=str(e))
return False
def _normalize_model_name(self, model: str) -> str:
"""Normalize model name to remove provider prefix if present"""
# Handle formats like "ollama/mistral:latest" -> "mistral:latest"
if "/" in model:
parts = model.split("/", 1)
if len(parts) == 2:
provider_prefix, model_name = parts
# Only strip known provider prefixes
if provider_prefix.lower() in ["ollama", "openai", "anthropic"]:
return model_name
return model
def get_available_models(self, provider: Optional[str] = None) -> Dict[str, List[str]]:
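        """Return the configured models for a single provider, or for all providers."""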
if provider:
if provider in self.providers:
return {provider: self.providers[provider].models}
else:
return {}
else:
return {name: config.models for name, config in self.providers.items()}
async def close(self):
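        """Close the shared HTTP client."""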
await self.client.aclose()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
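
if __name__ == "__main__":
    # Minimal usage sketch (not part of the production API): assumes a local
    # Ollama server is running (or OPENAI_API_KEY / ANTHROPIC_API_KEY is set)
    # and that the referenced model has already been pulled; the model name
    # below is illustrative.
    async def _demo() -> None:
        async with LLMIntegrationSystem() as llm:
            print("Configured models:", llm.get_available_models())
            print("Installed Ollama models:", await llm.get_ollama_models())
            reply = await llm.chat("Say hello in one short sentence.", model="llama2")
            print(reply)

    asyncio.run(_demo())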