"""AI model management system for Keyboard Maestro MCP tools.
This module provides comprehensive AI model management with support for
multiple providers, automatic model selection, and usage optimization.
Implements enterprise-grade AI infrastructure with security and cost control.
Security: All API keys encrypted and model access controlled.
Performance: Intelligent model selection and response caching.
Type Safety: Complete integration with AI type system.
"""
import hashlib
import json
import os
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from src.core.ai_integration import (
DEFAULT_AI_MODELS,
AIModel,
AIModelId,
AIModelType,
AIOperation,
AIRequest,
AIResponse,
CostAmount,
ProcessingMode,
TokenCount,
)
from src.core.either import Either
from src.core.logging import get_logger
logger = get_logger(__name__)
class AIError(Exception):
"""AI processing error with detailed context."""
    def __init__(
        self,
        error_type: str,
        message: str,
        details: dict[str, Any] | None = None,
    ):
self.error_type = error_type
self.message = message
self.details = details or {}
super().__init__(f"{error_type}: {message}")
@classmethod
def initialization_failed(cls, details: str) -> "AIError":
return cls(
"initialization_failed",
f"AI system initialization failed: {details}",
)
@classmethod
def no_suitable_model(cls, operation: AIOperation) -> "AIError":
return cls(
"no_suitable_model",
f"No suitable model found for operation: {operation.value}",
)
@classmethod
def model_selection_failed(cls, details: str) -> "AIError":
return cls("model_selection_failed", f"Model selection failed: {details}")
@classmethod
def api_call_failed(cls, model: str, details: str) -> "AIError":
return cls("api_call_failed", f"API call to {model} failed: {details}")
@classmethod
def rate_limit_exceeded(cls, model: str) -> "AIError":
return cls("rate_limit_exceeded", f"Rate limit exceeded for model: {model}")
@classmethod
def invalid_response(cls, details: str) -> "AIError":
return cls("invalid_response", f"Invalid AI response: {details}")
@classmethod
def cost_limit_exceeded(cls, estimated: float, limit: float) -> "AIError":
return cls(
"cost_limit_exceeded",
f"Cost estimate {estimated} exceeds limit {limit}",
)
@dataclass
class ModelUsageTracker:
"""Track usage statistics for individual models."""
model_id: AIModelId
requests_this_minute: int = 0
requests_this_hour: int = 0
total_requests: int = 0
total_tokens: TokenCount = TokenCount(0)
total_cost: CostAmount = CostAmount(0.0)
last_request_time: datetime | None = None
last_minute_reset: datetime = field(default_factory=lambda: datetime.now(UTC))
last_hour_reset: datetime = field(default_factory=lambda: datetime.now(UTC))
def record_request(self, tokens: TokenCount, cost: CostAmount) -> None:
"""Record a new request and update statistics."""
now = datetime.now(UTC)
# Reset counters if needed
if (now - self.last_minute_reset).total_seconds() >= 60:
self.requests_this_minute = 0
self.last_minute_reset = now
if (now - self.last_hour_reset).total_seconds() >= 3600:
self.requests_this_hour = 0
self.last_hour_reset = now
# Update counters
self.requests_this_minute += 1
self.requests_this_hour += 1
self.total_requests += 1
self.total_tokens = TokenCount(self.total_tokens + tokens)
self.total_cost = CostAmount(self.total_cost + cost)
self.last_request_time = now
    def can_make_request(self, model: AIModel) -> bool:
        """Check whether a request fits within the model's per-minute rate limit.

        The minute window is reset here as well: record_request only runs
        after this check passes, so relying on its lazy reset alone would
        leave a rate-limited model blocked indefinitely. Hourly counts are
        tracked for reporting only.
        """
        now = datetime.now(UTC)
        if (now - self.last_minute_reset).total_seconds() >= 60:
            self.requests_this_minute = 0
            self.last_minute_reset = now
        return self.requests_this_minute < model.rate_limit_per_minute
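# Illustrative sketch, not executed at import time: the per-minute window in
# action, assuming a hypothetical model whose rate_limit_per_minute is 2.
#
#     tracker = ModelUsageTracker(model_id=AIModelId("demo-model"))
#     tracker.record_request(TokenCount(500), CostAmount(0.01))
#     tracker.record_request(TokenCount(500), CostAmount(0.01))
#     tracker.can_make_request(model)  # False: 2 requests this minute
#     # ...a minute later the window resets and the check passes again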
@dataclass
class ModelCacheEntry:
"""Cache entry for AI model responses."""
request_hash: str
response: AIResponse
timestamp: datetime
access_count: int = 1
def is_expired(self, ttl_seconds: int = 3600) -> bool:
"""Check if cache entry is expired."""
return (datetime.now(UTC) - self.timestamp).total_seconds() > ttl_seconds
def record_access(self) -> None:
"""Record cache access."""
self.access_count += 1
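# Expiry sketch: a fresh entry survives its TTL check and counts accesses
# (`response` stands in for any AIResponse instance):
#
#     entry = ModelCacheEntry(
#         request_hash="abc123", response=response, timestamp=datetime.now(UTC),
#     )
#     entry.is_expired(ttl_seconds=3600)  # False right after creation
#     entry.record_access()               # access_count: 1 -> 2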
class AIModelManager:
"""Comprehensive AI model management and selection system."""
def __init__(self, cache_ttl: int = 3600, max_cache_size: int = 1000):
self.available_models: dict[AIModelId, AIModel] = {}
self.usage_trackers: dict[AIModelId, ModelUsageTracker] = {}
self.model_cache: dict[str, ModelCacheEntry] = {}
self.cache_ttl = cache_ttl
self.max_cache_size = max_cache_size
self.initialized = False
self.api_keys: dict[AIModelType, str] = {}
self.provider_registry = None # Will be initialized in _initialize_providers
# Load default models
        for model in DEFAULT_AI_MODELS.values():
self.available_models[model.model_id] = model
self.usage_trackers[model.model_id] = ModelUsageTracker(model.model_id)
async def initialize(
self,
api_keys: dict[str, str] | None = None,
) -> Either[AIError, None]:
"""Initialize AI model manager with API keys and validation."""
try:
logger.info("Initializing AI model manager")
# Load API keys from environment or provided dict
self._load_api_keys(api_keys or {})
# Validate available models
validation_results = await self._validate_models()
if validation_results.is_left():
return validation_results
# Initialize provider connections
await self._initialize_providers()
self.initialized = True
logger.info(
f"AI model manager initialized with {len(self.available_models)} models",
)
return Either.right(None)
except Exception as e:
logger.error(f"AI model manager initialization failed: {e}")
return Either.left(AIError.initialization_failed(str(e)))
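    # Usage sketch (assumes an active event loop and at least one provider
    # key, e.g. OPENAI_API_KEY, exported in the environment):
    #
    #     manager = AIModelManager()
    #     result = await manager.initialize()
    #     if result.is_left():
    #         ...  # handle the AIError; the manager stays uninitialized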
def _load_api_keys(self, provided_keys: dict[str, str]) -> None:
"""Load API keys from environment variables and provided dictionary."""
# Environment variables take precedence
env_keys = {
AIModelType.OPENAI: os.getenv("OPENAI_API_KEY"),
AIModelType.AZURE_OPENAI: os.getenv("AZURE_OPENAI_API_KEY"),
AIModelType.GOOGLE_AI: os.getenv("GOOGLE_AI_API_KEY"),
AIModelType.ANTHROPIC: os.getenv("ANTHROPIC_API_KEY"),
}
# Add non-None environment keys
for model_type, key in env_keys.items():
if key:
self.api_keys[model_type] = key
# Add provided keys (don't override env keys)
for key_name, key_value in provided_keys.items():
try:
model_type = AIModelType(key_name.lower())
if model_type not in self.api_keys:
self.api_keys[model_type] = key_value
except ValueError:
logger.warning(f"Unknown API key type: {key_name}")
async def _validate_models(self) -> Either[AIError, None]:
"""Validate that models have required API keys."""
try:
missing_keys = []
            for model in self.available_models.values():
if model.api_key_required and model.model_type not in self.api_keys:
missing_keys.append(
f"{model.model_type.value} (for {model.display_name})",
)
if missing_keys:
logger.warning(f"Missing API keys for: {', '.join(missing_keys)}")
# Don't fail initialization - some models might still work
return Either.right(None)
except Exception as e:
return Either.left(
AIError.initialization_failed(f"Model validation failed: {e}"),
)
async def _initialize_providers(self) -> None:
"""Initialize connections to AI providers."""
try:
# Initialize provider factory and clients
from .providers import get_provider_registry, initialize_providers_from_env
# Initialize providers from environment variables
initialized_providers = initialize_providers_from_env()
if initialized_providers:
logger.info(
f"Successfully initialized AI providers: {initialized_providers}",
)
# Get provider registry for request processing
self.provider_registry = get_provider_registry()
# Test provider health
healthy_provider = await self.provider_registry.get_healthy_provider()
if healthy_provider:
logger.info(
f"Healthy provider available: {healthy_provider.provider_name}",
)
else:
logger.warning("No healthy providers available")
else:
logger.warning(
"No AI providers initialized - check API key configuration",
)
self.provider_registry = None
except Exception as e:
logger.error(f"Provider initialization failed: {e}")
self.provider_registry = None
def select_best_model(
self,
operation: AIOperation,
mode: ProcessingMode,
input_size: int = 0,
cost_limit: CostAmount | None = None,
preferred_providers: set[AIModelType] | None = None,
) -> Either[AIError, AIModel]:
"""Select the best model for given operation and constraints."""
try:
if not self.initialized:
return Either.left(
AIError.initialization_failed("Manager not initialized"),
)
# Filter suitable models
suitable_models = []
for model in self.available_models.values():
# Check if model can handle operation
if not model.can_handle_operation(operation, input_size):
continue
# Check if we have API key
if model.api_key_required and model.model_type not in self.api_keys:
continue
# Check provider preference
if preferred_providers and model.model_type not in preferred_providers:
continue
# Check rate limits
tracker = self.usage_trackers[model.model_id]
if not tracker.can_make_request(model):
continue
suitable_models.append(model)
if not suitable_models:
return Either.left(AIError.no_suitable_model(operation))
            # Apply cost filter (rough estimate: ~4 characters per token,
            # with a floor of 100 tokens)
            if cost_limit:
                estimated_tokens = TokenCount(max(input_size // 4, 100))
                estimates = {
                    model.model_id: model.estimate_cost(
                        estimated_tokens,
                        estimated_tokens,
                    )
                    for model in suitable_models
                }
                suitable_models = [
                    model
                    for model in suitable_models
                    if estimates[model.model_id] <= cost_limit
                ]
                if not suitable_models:
                    # Report the cheapest estimate instead of a hard-coded 0
                    # so the error reflects the actual shortfall
                    return Either.left(
                        AIError.cost_limit_exceeded(
                            float(min(estimates.values())),
                            float(cost_limit),
                        ),
                    )
# Select based on processing mode
selected_model = self._select_by_mode(suitable_models, mode, operation)
logger.debug(
f"Selected model {selected_model.display_name} for {operation.value} in {mode.value} mode",
)
return Either.right(selected_model)
except Exception as e:
return Either.left(AIError.model_selection_failed(str(e)))
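    # Selection sketch: choose a model for a text-generation call under a
    # cost cap. AIOperation.GENERATE_TEXT and the unwrapping step are
    # assumptions about the surrounding type system:
    #
    #     result = manager.select_best_model(
    #         operation=AIOperation.GENERATE_TEXT,
    #         mode=ProcessingMode.COST_EFFECTIVE,
    #         input_size=4000,  # ~1000 tokens at roughly 4 chars per token
    #         cost_limit=CostAmount(0.05),
    #     )
    #     if result.is_right():
    #         ...  # unwrap the selected AIModel via the Either accessor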
def _select_by_mode(
self,
models: list[AIModel],
mode: ProcessingMode,
_operation: AIOperation,
) -> AIModel:
"""Select model based on processing mode preferences."""
        if mode == ProcessingMode.FAST:
            # Prefer smaller, cheaper models; max_tokens serves as a rough
            # proxy for response latency
            return min(models, key=lambda m: (m.max_tokens, m.cost_per_input_token))
if mode == ProcessingMode.ACCURATE:
# Prefer larger, more capable models
return max(models, key=lambda m: (m.context_window, m.max_tokens))
if mode == ProcessingMode.COST_EFFECTIVE:
# Prefer lowest cost models
return min(
models,
key=lambda m: m.cost_per_input_token + m.cost_per_output_token,
)
if mode == ProcessingMode.CREATIVE:
# Prefer models known for creativity (bias toward GPT-4)
gpt4_models = [m for m in models if "gpt-4" in m.model_name.lower()]
if gpt4_models:
return max(gpt4_models, key=lambda m: m.max_tokens)
return max(models, key=lambda m: m.max_tokens)
        # BALANCED: heuristic trade-off between capability, cost, and features
        scored_models = []
        for model in models:
            # Capability: context window normalized to thousands of tokens
            capability_score = model.context_window / 1000
            # Cost efficiency: inverse per-token input cost; the epsilon
            # avoids division by zero for free models. This term is
            # unbounded, so very cheap models dominate the balanced score.
            cost_efficiency = 1.0 / (model.cost_per_input_token + 0.000001)
            # Feature bonuses, with vision weighted highest
            feature_score = (
                (2 if model.supports_vision else 0)
                + (1 if model.supports_function_calling else 0)
                + (1 if model.supports_streaming else 0)
            )
            total_score = capability_score + cost_efficiency + feature_score
            scored_models.append((total_score, model))
        return max(scored_models, key=lambda x: x[0])[1]
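    # Worked example of the BALANCED score with hypothetical numbers: a model
    # with a 128,000-token context window, $0.00001 per input token, vision
    # and streaming support but no function calling scores
    #
    #     capability_score = 128000 / 1000            = 128.0
    #     cost_efficiency  = 1 / (0.00001 + 0.000001) ≈ 90909.1
    #     feature_score    = 2 + 0 + 1                = 3
    #     total_score      ≈ 91040.1
    #
    # The cost term dominates, which is why BALANCED tends to favor cheap
    # models unless prices are similar.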
def get_model_by_id(self, model_id: AIModelId) -> Either[AIError, AIModel]:
"""Get model by ID with validation."""
if model_id not in self.available_models:
return Either.left(
AIError.model_selection_failed(f"Model not found: {model_id}"),
)
model = self.available_models[model_id]
if model.api_key_required and model.model_type not in self.api_keys:
return Either.left(
AIError.model_selection_failed(f"No API key for model: {model_id}"),
)
return Either.right(model)
def list_available_models(
self,
operation: AIOperation | None = None,
provider: AIModelType | None = None,
) -> list[dict[str, Any]]:
"""List available models with metadata."""
models = []
for model in self.available_models.values():
# Filter by operation capability
if operation and not model.can_handle_operation(operation):
continue
# Filter by provider
if provider and model.model_type != provider:
continue
# Get usage statistics
tracker = self.usage_trackers[model.model_id]
models.append(
{
"model_id": model.model_id,
"model_name": model.model_name,
"display_name": model.display_name,
"provider": model.model_type.value,
"supports_vision": model.supports_vision,
"supports_functions": model.supports_function_calling,
"context_window": model.context_window,
"cost_per_1k_input": model.cost_per_input_token * 1000,
"cost_per_1k_output": model.cost_per_output_token * 1000,
"rate_limit": model.rate_limit_per_minute,
"available": (
not model.api_key_required or model.model_type in self.api_keys
),
"usage_stats": {
"total_requests": tracker.total_requests,
"total_tokens": tracker.total_tokens,
"total_cost": tracker.total_cost,
"requests_this_minute": tracker.requests_this_minute,
},
},
)
return sorted(models, key=lambda x: x["display_name"])
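    # Shape of one returned entry (values illustrative, not real pricing):
    #
    #     {
    #         "model_id": "gpt-4o",
    #         "display_name": "GPT-4o",
    #         "provider": "openai",
    #         "cost_per_1k_input": 0.005,  # cost_per_input_token * 1000
    #         "available": True,
    #         "usage_stats": {"total_requests": 12, ...},
    #     }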
def record_usage(self, model_id: AIModelId, response: AIResponse) -> None:
"""Record usage statistics for a model."""
if model_id in self.usage_trackers:
tracker = self.usage_trackers[model_id]
tracker.record_request(response.tokens_used, response.cost_estimate)
def get_usage_statistics(self, model_id: AIModelId | None = None) -> dict[str, Any]:
"""Get usage statistics for model(s)."""
if model_id:
if model_id not in self.usage_trackers:
return {}
tracker = self.usage_trackers[model_id]
model = self.available_models[model_id]
return {
"model_id": model_id,
"model_name": model.display_name,
"total_requests": tracker.total_requests,
"total_tokens": tracker.total_tokens,
"total_cost": tracker.total_cost,
"requests_this_minute": tracker.requests_this_minute,
"requests_this_hour": tracker.requests_this_hour,
"last_request": tracker.last_request_time.isoformat()
if tracker.last_request_time
else None,
"average_cost_per_request": (
tracker.total_cost / tracker.total_requests
if tracker.total_requests > 0
else 0.0
),
"can_make_request": tracker.can_make_request(model),
}
# Aggregate statistics
total_requests = sum(t.total_requests for t in self.usage_trackers.values())
total_tokens = sum(t.total_tokens for t in self.usage_trackers.values())
total_cost = sum(t.total_cost for t in self.usage_trackers.values())
return {
"total_models": len(self.available_models),
"available_models": len(
[
m
for m in self.available_models.values()
if not m.api_key_required or m.model_type in self.api_keys
],
),
"total_requests": total_requests,
"total_tokens": total_tokens,
"total_cost": total_cost,
"cache_size": len(self.model_cache),
"cache_hit_ratio": self._calculate_cache_hit_ratio(),
}
    def _calculate_cache_hit_ratio(self) -> float:
        """Calculate an approximate cache hit ratio.

        Each entry starts with access_count == 1 when stored, so hits are
        access_count - 1 per entry. The ratio only covers entries currently
        cached; evicted entries and outright misses are not counted.
        """
        if not self.model_cache:
            return 0.0
        total_accesses = sum(entry.access_count for entry in self.model_cache.values())
        cache_hits = sum(entry.access_count - 1 for entry in self.model_cache.values())
        return cache_hits / total_accesses if total_accesses > 0 else 0.0
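    # Worked example: two cached entries with access_count 3 and 1 give
    # total_accesses = 3 + 1 = 4 and cache_hits = (3 - 1) + (1 - 1) = 2,
    # so the ratio is 2 / 4 = 0.5.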
def get_cached_response(self, request: AIRequest) -> AIResponse | None:
"""Get cached response for request if available."""
request_hash = self._hash_request(request)
if request_hash not in self.model_cache:
return None
entry = self.model_cache[request_hash]
# Check if expired
if entry.is_expired(self.cache_ttl):
del self.model_cache[request_hash]
return None
# Record access and return response
entry.record_access()
logger.debug(f"Cache hit for request {request.request_id}")
return entry.response
def cache_response(self, request: AIRequest, response: AIResponse) -> None:
"""Cache response for future use."""
request_hash = self._hash_request(request)
# Ensure cache size limit
if len(self.model_cache) >= self.max_cache_size:
self._evict_cache_entries()
# Cache the response
self.model_cache[request_hash] = ModelCacheEntry(
request_hash=request_hash,
response=response,
timestamp=datetime.now(UTC),
)
logger.debug(f"Cached response for request {request.request_id}")
def _hash_request(self, request: AIRequest) -> str:
"""Generate hash for request caching."""
cache_key = {
"model": request.model.model_id,
"operation": request.operation.value,
"input": request.prepare_input_for_model(),
"temperature": request.temperature,
"max_tokens": request.get_effective_max_tokens(),
"system_prompt": request.system_prompt,
}
key_string = json.dumps(cache_key, sort_keys=True)
return hashlib.sha256(key_string.encode()).hexdigest()
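    # Determinism sketch: the key is canonical JSON (sort_keys=True), so
    # hashing the same request twice yields the same digest, while changing
    # any one field (e.g. temperature) produces a different one:
    #
    #     h1 = manager._hash_request(request)
    #     h2 = manager._hash_request(request)
    #     assert h1 == h2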
def _evict_cache_entries(self) -> None:
"""Evict least recently used cache entries."""
# Remove expired entries first
expired_keys = [
key
for key, entry in self.model_cache.items()
if entry.is_expired(self.cache_ttl)
]
for key in expired_keys:
del self.model_cache[key]
        # If still over limit, drop the oldest, least-accessed entries
        # (sorted by creation time, not last access, so this is not true LRU)
if len(self.model_cache) >= self.max_cache_size:
sorted_entries = sorted(
self.model_cache.items(),
key=lambda x: (x[1].timestamp, x[1].access_count),
)
# Remove 25% of entries
remove_count = max(1, len(sorted_entries) // 4)
for key, _ in sorted_entries[:remove_count]:
del self.model_cache[key]
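    # Eviction sketch: with max_cache_size=1000 and a full cache of
    # unexpired entries, the oldest max(1, 1000 // 4) = 250 entries are
    # removed, leaving headroom before the next insertion.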
def clear_cache(self) -> None:
"""Clear the response cache."""
self.model_cache.clear()
logger.info("Model response cache cleared")