"""
Dynamic Model Information Manager
Queries AI provider APIs to get actual model capabilities
"""
import os
import logging
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
@dataclass
class ModelInfo:
"""Model capability information"""
name: str
provider: str
context_window: int
max_output_tokens: int
input_cost_per_1k: float = 0.0
output_cost_per_1k: float = 0.0
supports_vision: bool = False
supports_functions: bool = False
supports_streaming: bool = True
supports_json_mode: bool = False
rate_limit_rpm: int = 0 # Requests per minute
rate_limit_tpm: int = 0 # Tokens per minute
last_updated: Optional[datetime] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def usable_context(self) -> int:
"""Usable context after reserving for output"""
return self.context_window - self.max_output_tokens
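    # e.g. a model with a 200,000-token window and 4,096 max output tokens
    # leaves 200,000 - 4,096 = 195,904 tokens of usable input context.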
def to_dict(self) -> Dict:
"""Convert to dictionary for serialization"""
return {
'name': self.name,
'provider': self.provider,
'context_window': self.context_window,
'max_output_tokens': self.max_output_tokens,
'input_cost_per_1k': self.input_cost_per_1k,
'output_cost_per_1k': self.output_cost_per_1k,
'supports_vision': self.supports_vision,
'supports_functions': self.supports_functions,
'supports_streaming': self.supports_streaming,
'supports_json_mode': self.supports_json_mode,
'rate_limit_rpm': self.rate_limit_rpm,
'rate_limit_tpm': self.rate_limit_tpm,
'last_updated': self.last_updated.isoformat() if self.last_updated else None,
'metadata': self.metadata
}
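    # The result is JSON-serializable as long as metadata values are;
    # last_updated is flattened to an ISO-8601 string above.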
class ModelInfoManager:
"""Manage model information from various providers"""
# Fallback defaults if API calls fail (conservative values)
FALLBACK_LIMITS = {
'anthropic': {
'claude-3-opus': {'context': 200000, 'output': 4096, 'vision': True, 'json': True},
'claude-3-sonnet': {'context': 200000, 'output': 4096, 'vision': True, 'json': True},
'claude-3-haiku': {'context': 200000, 'output': 4096, 'vision': True, 'json': True},
'claude-2.1': {'context': 200000, 'output': 4096, 'vision': False, 'json': False},
'claude-2.0': {'context': 100000, 'output': 4096, 'vision': False, 'json': False},
'claude-instant': {'context': 100000, 'output': 4096, 'vision': False, 'json': False}
},
'openai': {
'gpt-4-turbo': {'context': 128000, 'output': 4096, 'vision': False, 'json': True},
'gpt-4-vision': {'context': 128000, 'output': 4096, 'vision': True, 'json': True},
'gpt-4-1106': {'context': 128000, 'output': 4096, 'vision': False, 'json': True},
'gpt-4-32k': {'context': 32768, 'output': 4096, 'vision': False, 'json': False},
'gpt-4': {'context': 8192, 'output': 4096, 'vision': False, 'json': False},
'gpt-3.5-turbo': {'context': 16384, 'output': 4096, 'vision': False, 'json': True},
'gpt-3.5-turbo-16k': {'context': 16384, 'output': 4096, 'vision': False, 'json': True}
},
'google': {
'gemini-1.5-pro': {'context': 1000000, 'output': 8192, 'vision': True, 'json': True},
'gemini-1.5-flash': {'context': 1000000, 'output': 8192, 'vision': True, 'json': True},
'gemini-pro': {'context': 32768, 'output': 2048, 'vision': False, 'json': True},
'gemini-pro-vision': {'context': 32768, 'output': 2048, 'vision': True, 'json': True}
},
'cohere': {
'command-r': {'context': 128000, 'output': 4000, 'vision': False, 'json': True},
'command-r-plus': {'context': 128000, 'output': 4000, 'vision': False, 'json': True},
'command': {'context': 4000, 'output': 4000, 'vision': False, 'json': False}
}
}
# Known pricing (as of 2024, in USD per 1K tokens)
PRICING = {
'claude-3-opus': {'input': 0.015, 'output': 0.075},
'claude-3-sonnet': {'input': 0.003, 'output': 0.015},
'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
'gpt-4-turbo': {'input': 0.01, 'output': 0.03},
'gpt-4': {'input': 0.03, 'output': 0.06},
'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015},
'gemini-1.5-pro': {'input': 0.007, 'output': 0.021},
'gemini-1.5-flash': {'input': 0.00035, 'output': 0.00105}
}
def __init__(self, cache_duration: timedelta = timedelta(hours=1)):
"""
Initialize model info manager
Args:
cache_duration: How long to cache model info
"""
self.cache_duration = cache_duration
self.model_cache: Dict[str, ModelInfo] = {}
# Get API keys from environment
self.api_keys = {
'anthropic': os.getenv('ANTHROPIC_API_KEY'),
'openai': os.getenv('OPENAI_API_KEY'),
'google': os.getenv('GOOGLE_API_KEY', os.getenv('GEMINI_API_KEY')),
'cohere': os.getenv('COHERE_API_KEY')
}
# Track whether we have aiohttp available
self.has_aiohttp = False
try:
import aiohttp
self.has_aiohttp = True
except ImportError:
logger.warning("aiohttp not available, using fallback model info only")
def detect_provider(self, model_name: str) -> str:
"""
Detect provider from model name
Args:
model_name: Name of the model
Returns:
Provider name
"""
model_lower = model_name.lower()
if 'claude' in model_lower:
return 'anthropic'
elif 'gpt' in model_lower:
return 'openai'
elif 'gemini' in model_lower or 'palm' in model_lower:
return 'google'
elif 'command' in model_lower:
return 'cohere'
elif 'llama' in model_lower or 'mistral' in model_lower:
return 'local'
else:
return 'unknown'
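    # Illustrative matches: "claude-3-opus-20240229" -> "anthropic",
    # "gpt-4-turbo" -> "openai", "gemini-1.5-pro" -> "google",
    # "mistral-7b-instruct" -> "local"; anything unrecognized -> "unknown".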
    async def get_model_info(self, model_name: str, provider: Optional[str] = None) -> ModelInfo:
"""
Get model information, trying API first, then cache, then fallback
Args:
model_name: Name of the model
provider: Provider name (auto-detected if not specified)
Returns:
ModelInfo object
"""
# Determine provider from model name if not specified
if not provider:
provider = self.detect_provider(model_name)
# Check cache first
cache_key = f"{provider}:{model_name}"
if cache_key in self.model_cache:
cached = self.model_cache[cache_key]
if cached.last_updated and \
datetime.now() - cached.last_updated < self.cache_duration:
logger.debug(f"Using cached model info for {model_name}")
return cached
# Try to fetch from API if aiohttp is available
if self.has_aiohttp:
try:
if provider == 'anthropic':
info = await self._fetch_anthropic_info(model_name)
elif provider == 'openai':
info = await self._fetch_openai_info(model_name)
elif provider == 'google':
info = await self._fetch_google_info(model_name)
elif provider == 'cohere':
info = await self._fetch_cohere_info(model_name)
else:
info = self._get_fallback_info(model_name, provider)
# Cache the result
self.model_cache[cache_key] = info
return info
            except Exception as e:
                logger.warning(f"Failed to fetch model info from API: {e}")
        # Fall back to known values and cache them so repeated lookups stay cheap
        info = self._get_fallback_info(model_name, provider)
        self.model_cache[cache_key] = info
        return info
async def _fetch_anthropic_info(self, model_name: str) -> ModelInfo:
"""
Fetch model info from Anthropic API
        Note: Anthropic does not expose capability metadata (context window,
        output limits) via its API, so we rely on known values here
"""
if not self.api_keys['anthropic']:
return self._get_fallback_info(model_name, 'anthropic')
# For now, use known values for Anthropic models
# In future, could make a test request to verify access
return self._get_fallback_info(model_name, 'anthropic')
async def _fetch_openai_info(self, model_name: str) -> ModelInfo:
"""
Fetch model info from OpenAI API
OpenAI has a models endpoint but it doesn't return context info,
so we use known values
"""
if not self.api_keys['openai']:
return self._get_fallback_info(model_name, 'openai')
# For now, use known values
return self._get_fallback_info(model_name, 'openai')
async def _fetch_google_info(self, model_name: str) -> ModelInfo:
"""
Fetch model info from Google/Vertex AI
        Google's Gemini API reports token limits in its model metadata,
        though for now we use known values
"""
if not self.api_keys['google']:
return self._get_fallback_info(model_name, 'google')
# For now, use known values
return self._get_fallback_info(model_name, 'google')
async def _fetch_cohere_info(self, model_name: str) -> ModelInfo:
"""Fetch model info from Cohere API"""
if not self.api_keys['cohere']:
return self._get_fallback_info(model_name, 'cohere')
# For now, use known values
return self._get_fallback_info(model_name, 'cohere')
def _get_fallback_info(self, model_name: str, provider: str) -> ModelInfo:
"""
Get fallback information when API is unavailable
Args:
model_name: Name of the model
provider: Provider name
Returns:
ModelInfo with conservative defaults
"""
# Check hardcoded fallbacks
if provider in self.FALLBACK_LIMITS:
provider_limits = self.FALLBACK_LIMITS[provider]
# Find matching model
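            # Matching is first-hit substring, so more specific keys
            # (e.g. 'gpt-4-32k') must precede their prefixes (e.g. 'gpt-4')
            # within each provider's FALLBACK_LIMITS dict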
for model_key, limits in provider_limits.items():
if model_key in model_name.lower():
# Get pricing if available
pricing = self.PRICING.get(model_key, {'input': 0, 'output': 0})
return ModelInfo(
name=model_name,
provider=provider,
context_window=limits.get('context', 8192),
max_output_tokens=limits.get('output', 2048),
input_cost_per_1k=pricing['input'],
output_cost_per_1k=pricing['output'],
supports_vision=limits.get('vision', False),
supports_functions=True, # Most modern models support functions
supports_streaming=True,
supports_json_mode=limits.get('json', False),
last_updated=datetime.now()
)
# Ultimate fallback - very conservative
logger.warning(f"Using ultimate fallback for {model_name}")
return ModelInfo(
name=model_name,
provider=provider,
context_window=8192, # Conservative default
max_output_tokens=2048,
input_cost_per_1k=0.01, # Assume moderate cost
output_cost_per_1k=0.03,
supports_functions=False,
supports_streaming=True,
last_updated=datetime.now()
)
    def get_model_info_sync(self, model_name: str, provider: Optional[str] = None) -> ModelInfo:
"""
Synchronous version of get_model_info
Args:
model_name: Name of the model
provider: Provider name
Returns:
ModelInfo object
"""
# Try to get from cache first
if not provider:
provider = self.detect_provider(model_name)
cache_key = f"{provider}:{model_name}"
if cache_key in self.model_cache:
cached = self.model_cache[cache_key]
if cached.last_updated and \
datetime.now() - cached.last_updated < self.cache_duration:
return cached
# Use fallback for sync version
info = self._get_fallback_info(model_name, provider)
self.model_cache[cache_key] = info
return info
def estimate_cost(self, model_name: str, input_tokens: int,
output_tokens: int) -> float:
"""
Estimate cost for a model operation
Args:
model_name: Name of the model
input_tokens: Number of input tokens
output_tokens: Number of output tokens
Returns:
Estimated cost in USD
"""
info = self.get_model_info_sync(model_name)
input_cost = (input_tokens / 1000) * info.input_cost_per_1k
output_cost = (output_tokens / 1000) * info.output_cost_per_1k
return input_cost + output_cost
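    # e.g. claude-3-sonnet with 10,000 input and 1,000 output tokens:
    # (10 * 0.003) + (1 * 0.015) = $0.045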
def can_fit_in_context(self, model_name: str, tokens: int) -> bool:
"""
Check if tokens fit in model's context window
Args:
model_name: Name of the model
tokens: Number of tokens to check
Returns:
True if tokens fit
"""
info = self.get_model_info_sync(model_name)
return tokens <= info.usable_context
    def suggest_model_for_size(self, tokens: int,
                               prefer_provider: Optional[str] = None) -> Optional[str]:
"""
Suggest a model that can handle the token count
Args:
tokens: Number of tokens needed
prefer_provider: Preferred provider
Returns:
Model name or None if no suitable model
"""
candidates = []
# Check all known models
for provider, models in self.FALLBACK_LIMITS.items():
if prefer_provider and provider != prefer_provider:
continue
            for model_key, limits in models.items():
                context = limits.get('context', 0)
                # Reserve room for output, mirroring can_fit_in_context()
                if context - limits.get('output', 0) >= tokens:
                    # Rank by cost; unknown models get a high score so known pricing wins
                    pricing = self.PRICING.get(model_key, {'input': 1, 'output': 1})
                    cost_score = pricing['input'] + pricing['output']
candidates.append({
'model': model_key,
'provider': provider,
'context': context,
'cost': cost_score
})
if not candidates:
return None
# Sort by cost (cheapest first) then by context (smallest sufficient)
candidates.sort(key=lambda x: (x['cost'], x['context']))
return candidates[0]['model']
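    # e.g. with the tables above, suggest_model_for_size(50000) returns
    # 'gemini-1.5-flash': the cheapest known model whose context covers 50k tokens.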
def get_available_models(self) -> Dict[str, List[str]]:
"""
        Get known models for each provider that has a configured API key
Returns:
Dictionary of provider -> list of model names
"""
available = {}
for provider, models in self.FALLBACK_LIMITS.items():
# Check if we have API key for this provider
if self.api_keys.get(provider):
available[provider] = list(models.keys())
return available
def compare_models(self, model1: str, model2: str) -> Dict[str, Any]:
"""
Compare two models side by side
Args:
model1: First model name
model2: Second model name
Returns:
Comparison dictionary
"""
info1 = self.get_model_info_sync(model1)
info2 = self.get_model_info_sync(model2)
return {
'models': {
model1: info1.to_dict(),
model2: info2.to_dict()
},
'comparison': {
'context_ratio': info1.context_window / info2.context_window,
                'cost_ratio': {
                    # None (rather than 0) when the second model's price is unknown
                    'input': info1.input_cost_per_1k / info2.input_cost_per_1k if info2.input_cost_per_1k else None,
                    'output': info1.output_cost_per_1k / info2.output_cost_per_1k if info2.output_cost_per_1k else None
                },
                'larger_context': model1 if info1.context_window > info2.context_window else model2,
                # Compared on input cost only; output pricing can rank models differently
                'cheaper': model1 if info1.input_cost_per_1k < info2.input_cost_per_1k else model2,
'vision_support': {
model1: info1.supports_vision,
model2: info2.supports_vision
}
}
}
# Singleton instance for easy access
_manager_instance = None
def get_model_manager() -> ModelInfoManager:
"""Get or create singleton model manager instance"""
global _manager_instance
if _manager_instance is None:
_manager_instance = ModelInfoManager()
return _manager_instance
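if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): resolve capabilities for a model,
    # estimate a request cost, and pick a model for a given payload size.
    # The model names here are examples; substitute whatever your account can access.
    import asyncio
    manager = get_model_manager()
    info = asyncio.run(manager.get_model_info('claude-3-sonnet'))
    print(f"{info.name}: {info.context_window} context / {info.max_output_tokens} output tokens")
    print(f"10k in / 1k out costs ~${manager.estimate_cost('claude-3-sonnet', 10000, 1000):.4f}")
    print(f"Fits 150k tokens: {manager.can_fit_in_context('claude-3-sonnet', 150000)}")
    print(f"Suggested for 50k tokens: {manager.suggest_model_for_size(50000)}")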