"""AI model router for managing multiple AI providers."""
import asyncio
import logging
import time
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
import openai
import anthropic
import google.generativeai as genai
import ollama
from .config import AIProviderConfig, Settings
# Module-level logger named after this module, per standard logging convention.
logger = logging.getLogger(__name__)
@dataclass
class AIResponse:
    """Standardized AI response.

    Uniform result envelope returned by every provider so callers can
    treat success and failure the same way: on failure `content` is empty
    and `error` carries the exception text instead of an exception being
    raised.
    """
    content: str  # generated text; "" when the request failed
    provider: str  # provider name (from its AIProviderConfig)
    model: str  # model identifier used for the request
    tokens_used: int = 0  # total tokens; 0 when the provider reports none
    cost: float = 0.0  # tokens_used * cost_per_token; 0.0 for local models
    response_time: float = 0.0  # wall-clock seconds spent on the request
    error: Optional[str] = None  # error message, or None on success
class AIProviderError(Exception):
    """Raised for provider-level failures in the AI routing layer."""
class AIProvider:
    """Base AI provider interface.

    Concrete providers override `generate`; this base supplies simple
    fixed-window rate limiting (`config.rate_limit` requests per 60 s).

    Rate-limit state:
        request_count     -- requests recorded in the current window
        last_request_time -- start time of the current 60-second window

    Bug fix: the previous implementation updated `last_request_time` on
    *every* request, turning the window into "60 s since the most recent
    request".  Under steady sub-minute traffic the counter never reset,
    so a provider became permanently blocked after `rate_limit` total
    requests.  Anchoring the timestamp at the window start fixes that,
    and `can_handle_request` no longer mutates state.
    """

    WINDOW_SECONDS = 60.0  # length of one rate-limit window

    def __init__(self, config: "AIProviderConfig"):
        self.config = config
        self.request_count = 0
        # Start of the current rate window; 0.0 means "no window yet",
        # which always reads as elapsed.
        self.last_request_time = 0.0

    async def generate(self, prompt: str, **kwargs: Any) -> "AIResponse":
        """Generate response from AI provider (must be overridden)."""
        raise NotImplementedError

    def can_handle_request(self) -> bool:
        """Check if provider can handle a new request (rate limiting).

        Pure predicate: does not modify the counters.
        """
        if time.time() - self.last_request_time >= self.WINDOW_SECONDS:
            # Window elapsed; the next record_request() starts a fresh one.
            return True
        return self.request_count < self.config.rate_limit

    def record_request(self) -> None:
        """Record a new request, opening a new window if the old one elapsed."""
        now = time.time()
        if now - self.last_request_time >= self.WINDOW_SECONDS:
            # New window: anchor it at `now` and reset the counter.
            self.last_request_time = now
            self.request_count = 0
        self.request_count += 1
class OpenAIProvider(AIProvider):
    """OpenAI provider implementation."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self.client = openai.AsyncOpenAI(
            api_key=config.api_key,
            base_url=config.base_url,
        )

    async def generate(self, prompt: str, **kwargs: Any) -> AIResponse:
        """Generate a completion via the OpenAI chat-completions API.

        Failures are reported through AIResponse.error rather than raised.
        """
        started = time.time()
        self.record_request()
        try:
            completion = await self.client.chat.completions.create(
                model=self.config.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=kwargs.get("max_tokens", self.config.max_tokens),
                temperature=kwargs.get("temperature", self.config.temperature),
                timeout=self.config.timeout,
            )
            elapsed = time.time() - started
            # Usage may be absent on some responses; treat that as zero tokens.
            total_tokens = completion.usage.total_tokens if completion.usage else 0
            return AIResponse(
                content=completion.choices[0].message.content or "",
                provider=self.config.name,
                model=self.config.model,
                tokens_used=total_tokens,
                cost=total_tokens * self.config.cost_per_token,
                response_time=elapsed,
            )
        except Exception as exc:
            # Surface the failure inside the response envelope.
            return AIResponse(
                content="",
                provider=self.config.name,
                model=self.config.model,
                response_time=time.time() - started,
                error=str(exc),
            )
class AnthropicProvider(AIProvider):
    """Anthropic (Claude) provider implementation."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self.client = anthropic.AsyncAnthropic(
            api_key=config.api_key,
            base_url=config.base_url,
        )

    async def generate(self, prompt: str, **kwargs: Any) -> AIResponse:
        """Generate a completion via the Anthropic Messages API.

        Failures are reported through AIResponse.error rather than raised.
        """
        started = time.time()
        self.record_request()
        try:
            message = await self.client.messages.create(
                model=self.config.model,
                max_tokens=kwargs.get("max_tokens", self.config.max_tokens),
                temperature=kwargs.get("temperature", self.config.temperature),
                messages=[{"role": "user", "content": prompt}],
            )
            elapsed = time.time() - started
            total_tokens = message.usage.input_tokens + message.usage.output_tokens
            # Pull the text out of the first content block, falling back to a
            # plain string conversion when the block carries no text attribute.
            text = ""
            if message.content and len(message.content) > 0:
                try:
                    first_block = message.content[0]
                    text = getattr(first_block, 'text', str(first_block))
                except (AttributeError, IndexError):
                    text = str(message.content)
            return AIResponse(
                content=text,
                provider=self.config.name,
                model=self.config.model,
                tokens_used=total_tokens,
                cost=total_tokens * self.config.cost_per_token,
                response_time=elapsed,
            )
        except Exception as exc:
            return AIResponse(
                content="",
                provider=self.config.name,
                model=self.config.model,
                response_time=time.time() - started,
                error=str(exc),
            )
class GoogleProvider(AIProvider):
    """Google Gemini provider implementation."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        # Only configure the SDK when a key was supplied (e.g. ADC setups).
        if config.api_key:
            genai.configure(api_key=config.api_key)  # type: ignore
        self.model = genai.GenerativeModel(config.model)  # type: ignore

    async def generate(self, prompt: str, **kwargs: Any) -> AIResponse:
        """Generate a completion via the Gemini SDK.

        Failures are reported through AIResponse.error rather than raised.
        """
        started = time.time()
        self.record_request()
        try:
            gen_config = genai.GenerationConfig(  # type: ignore
                max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens),
                temperature=kwargs.get("temperature", self.config.temperature),
            )
            result = await self.model.generate_content_async(  # type: ignore
                prompt,
                generation_config=gen_config,
            )
            elapsed = time.time() - started
            # Token accounting is only available when the SDK attaches
            # usage metadata to the response object.
            token_total = 0
            if hasattr(result, 'usage_metadata'):
                token_total = result.usage_metadata.total_token_count
            return AIResponse(
                content=result.text,
                provider=self.config.name,
                model=self.config.model,
                tokens_used=token_total,
                cost=token_total * self.config.cost_per_token,
                response_time=elapsed,
            )
        except Exception as exc:
            return AIResponse(
                content="",
                provider=self.config.name,
                model=self.config.model,
                response_time=time.time() - started,
                error=str(exc),
            )
class OllamaProvider(AIProvider):
    """Ollama local provider implementation."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self.client = ollama.AsyncClient(host=config.base_url)

    async def generate(self, prompt: str, **kwargs: Any) -> AIResponse:
        """Generate a response using a local Ollama model.

        Fix: the previous version hard-coded tokens_used=0 claiming Ollama
        has no token counts, but the generate response does report
        `prompt_eval_count` (input tokens) and `eval_count` (output tokens);
        sum them, defaulting to 0 when a field is absent.  Cost stays 0.0
        because the model runs locally.

        Failures are reported through AIResponse.error rather than raised.
        """
        start_time = time.time()
        self.record_request()
        try:
            response = await self.client.generate(
                model=self.config.model,
                prompt=prompt,
                options={
                    "num_predict": kwargs.get("max_tokens", self.config.max_tokens),
                    "temperature": kwargs.get("temperature", self.config.temperature)
                }
            )
            response_time = time.time() - start_time
            # Ollama reports input/output token counts separately; either
            # field may be missing on some responses.
            tokens_used = (
                response.get("prompt_eval_count", 0)
                + response.get("eval_count", 0)
            )
            return AIResponse(
                content=response['response'],
                provider=self.config.name,
                model=self.config.model,
                tokens_used=tokens_used,
                cost=0.0,  # Local model, no per-token cost
                response_time=response_time
            )
        except Exception as e:
            return AIResponse(
                content="",
                provider=self.config.name,
                model=self.config.model,
                response_time=time.time() - start_time,
                error=str(e)
            )
class AIRouter:
    """AI model router for managing multiple providers.

    Dispatches generate() calls to configured providers, preferring lower
    `priority` values, retrying each provider with exponential backoff
    and — when no specific provider was requested — falling back to the
    next enabled provider on failure.

    Refactor: the enabled-and-sorted-by-priority config listing was
    duplicated in get_provider_by_priority() and _fallback_generate();
    it now lives in the single helper _sorted_enabled_configs().
    """

    def __init__(self, settings: "Settings"):
        self.settings = settings
        self.providers: Dict[str, "AIProvider"] = {}
        self._initialize_providers()

    def _initialize_providers(self) -> None:
        """Instantiate one provider object for each enabled config entry."""
        # Dispatch table keeps the type -> class mapping in one place.
        factories = {
            "openai": OpenAIProvider,
            "anthropic": AnthropicProvider,
            "google": GoogleProvider,
            "ollama": OllamaProvider,
        }
        for config in self.settings.ai_providers:
            if not config.enabled:
                continue
            factory = factories.get(config.type)
            if factory is None:
                logger.warning(f"Unknown provider type: {config.type}")
                continue
            try:
                self.providers[config.name] = factory(config)
                logger.info(f"Initialized provider: {config.name} ({config.type})")
            except Exception as e:
                # A broken provider must not prevent the others from loading.
                logger.error(f"Failed to initialize provider {config.name}: {e}")

    def _sorted_enabled_configs(self, exclude: Optional[str] = None) -> List["AIProviderConfig"]:
        """Return enabled provider configs sorted by ascending priority.

        Args:
            exclude: optional provider name to leave out (used when falling
                back after that provider has already failed).
        """
        return sorted(
            (c for c in self.settings.ai_providers
             if c.enabled and c.name != exclude),
            key=lambda c: c.priority,
        )

    def get_available_providers(self) -> List[str]:
        """Get list of provider names currently able to accept a request."""
        return [
            name for name, provider in self.providers.items()
            if provider.can_handle_request()
        ]

    def get_provider_by_priority(self) -> Optional["AIProvider"]:
        """Get the highest priority available provider, or None."""
        for config in self._sorted_enabled_configs():
            provider = self.providers.get(config.name)
            if provider is not None and provider.can_handle_request():
                return provider
        return None

    async def generate(
        self,
        prompt: str,
        provider_name: Optional[str] = None,
        **kwargs: Any
    ) -> "AIResponse":
        """Generate a response using the named provider or the best available.

        Args:
            prompt: text prompt forwarded to the provider.
            provider_name: force a specific provider; when None the
                highest-priority available provider is used and the other
                providers serve as fallbacks.
            **kwargs: per-call overrides (e.g. max_tokens, temperature).

        Returns:
            AIResponse; failures are reported via its `error` field
            rather than by raising.
        """
        if provider_name:
            if provider_name not in self.providers:
                return AIResponse(
                    content="",
                    provider=provider_name,
                    model="",
                    error=f"Provider '{provider_name}' not available"
                )
            provider = self.providers[provider_name]
        else:
            provider = self.get_provider_by_priority()
            if not provider:
                return AIResponse(
                    content="",
                    provider="none",
                    model="",
                    error="No available providers"
                )
        # Retry the chosen provider with exponential backoff.
        for attempt in range(provider.config.retry_attempts):
            try:
                response = await provider.generate(prompt, **kwargs)
                if not response.error:
                    return response
                logger.warning(
                    f"Attempt {attempt + 1} failed for {provider.config.name}: {response.error}"
                )
                if attempt < provider.config.retry_attempts - 1:
                    await asyncio.sleep(2 ** attempt)  # Exponential backoff
            except Exception as e:
                logger.error(f"Provider {provider.config.name} error: {e}")
                if attempt < provider.config.retry_attempts - 1:
                    await asyncio.sleep(2 ** attempt)
        # All retries failed; only fall back when the caller did not pin
        # a specific provider.
        if not provider_name:
            return await self._fallback_generate(prompt, provider.config.name, **kwargs)
        return AIResponse(
            content="",
            provider=provider.config.name,
            model=provider.config.model,
            error="All retry attempts failed"
        )

    async def _fallback_generate(
        self,
        prompt: str,
        failed_provider: str,
        **kwargs: Any
    ) -> "AIResponse":
        """Try the remaining providers in priority order after one failed."""
        for config in self._sorted_enabled_configs(exclude=failed_provider):
            if config.name in self.providers:
                provider = self.providers[config.name]
                if provider.can_handle_request():
                    try:
                        response = await provider.generate(prompt, **kwargs)
                        if not response.error:
                            logger.info(f"Fallback successful with {config.name}")
                            return response
                    except Exception as e:
                        logger.error(f"Fallback provider {config.name} failed: {e}")
                        continue
        return AIResponse(
            content="",
            provider="none",
            model="",
            error="All providers failed"
        )

    async def health_check(self) -> Dict[str, Dict[str, Any]]:
        """Check health of every provider with a tiny test prompt."""
        health_status: Dict[str, Dict[str, Any]] = {}
        for name, provider in self.providers.items():
            try:
                response = await provider.generate("Hello", max_tokens=10)
                health_status[name] = {
                    "status": "healthy" if not response.error else "error",
                    "error": response.error,
                    "response_time": response.response_time,
                    "can_handle_request": provider.can_handle_request()
                }
            except Exception as e:
                health_status[name] = {
                    "status": "error",
                    "error": str(e),
                    "response_time": 0.0,
                    "can_handle_request": False
                }
        return health_status
# Global router instance (module-level singleton), managed by
# get_ai_router() / reset_ai_router() below.
_router: Optional[AIRouter] = None
def get_ai_router(settings: Optional[Settings] = None) -> AIRouter:
    """Return the process-wide AIRouter, creating it on first access.

    When no settings are supplied, defaults are loaded lazily from the
    config module the first time the router is built.
    """
    global _router
    if _router is not None:
        return _router
    if settings is None:
        from .config import get_default_settings
        settings = get_default_settings()
    _router = AIRouter(settings)
    return _router
def reset_ai_router():
    """Reset global AI router instance.

    Mainly useful in tests: forces the next get_ai_router() call to
    rebuild the router with fresh settings.
    """
    global _router
    _router = None