base.py
"""Base model provider interface and data classes.""" import base64 import binascii import logging import os from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from tools.models import ToolModelCategory from utils.file_types import IMAGES, get_image_mime_type logger = logging.getLogger(__name__) class ProviderType(Enum): """Supported model provider types.""" GOOGLE = "google" OPENAI = "openai" XAI = "xai" OPENROUTER = "openrouter" CUSTOM = "custom" DIAL = "dial" KIMI = "kimi" GLM = "glm" class TemperatureConstraint(ABC): """Abstract base class for temperature constraints.""" @abstractmethod def validate(self, temperature: float) -> bool: """Check if temperature is valid.""" pass @abstractmethod def get_corrected_value(self, temperature: float) -> float: """Get nearest valid temperature.""" pass @abstractmethod def get_description(self) -> str: """Get human-readable description of constraint.""" pass @abstractmethod def get_default(self) -> float: """Get model's default temperature.""" pass class FixedTemperatureConstraint(TemperatureConstraint): """For models that only support one temperature value (e.g., O3).""" def __init__(self, value: float): self.value = value def validate(self, temperature: float) -> bool: return abs(temperature - self.value) < 1e-6 # Handle floating point precision def get_corrected_value(self, temperature: float) -> float: return self.value def get_description(self) -> str: return f"Only supports temperature={self.value}" def get_default(self) -> float: return self.value class RangeTemperatureConstraint(TemperatureConstraint): """For models supporting continuous temperature ranges.""" def __init__(self, min_temp: float, max_temp: float, default: float = None): self.min_temp = min_temp self.max_temp = max_temp self.default_temp = default or (min_temp + max_temp) / 2 def validate(self, temperature: float) -> bool: return self.min_temp <= temperature <= self.max_temp def get_corrected_value(self, temperature: float) -> float: return max(self.min_temp, min(self.max_temp, temperature)) def get_description(self) -> str: return f"Supports temperature range [{self.min_temp}, {self.max_temp}]" def get_default(self) -> float: return self.default_temp class DiscreteTemperatureConstraint(TemperatureConstraint): """For models supporting only specific temperature values.""" def __init__(self, allowed_values: list[float], default: float = None): self.allowed_values = sorted(allowed_values) self.default_temp = default or allowed_values[len(allowed_values) // 2] def validate(self, temperature: float) -> bool: return any(abs(temperature - val) < 1e-6 for val in self.allowed_values) def get_corrected_value(self, temperature: float) -> float: return min(self.allowed_values, key=lambda x: abs(x - temperature)) def get_description(self) -> str: return f"Supports temperatures: {self.allowed_values}" def get_default(self) -> float: return self.default_temp def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint: """Create temperature constraint object from configuration string. 


@dataclass
class ModelCapabilities:
    """Capabilities and constraints for a specific model."""

    provider: ProviderType
    model_name: str
    friendly_name: str  # Human-friendly name like "Gemini" or "OpenAI"
    context_window: int  # Total context window size in tokens
    max_output_tokens: int  # Maximum output tokens per request
    supports_extended_thinking: bool = False
    supports_system_prompts: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_images: bool = False  # Whether model can process images
    max_image_size_mb: float = 0.0  # Maximum total size for all images in MB
    supports_temperature: bool = True  # Whether model accepts temperature parameter in API calls

    # Additional fields for comprehensive model information
    description: str = ""  # Human-readable description of the model
    aliases: list[str] = field(default_factory=list)  # Alternative names/shortcuts for the model

    # JSON mode support (for providers that support structured output)
    supports_json_mode: bool = False

    # Thinking mode support (for models with thinking capabilities)
    max_thinking_tokens: int = 0  # Maximum thinking tokens for extended reasoning models

    # Custom model flag (for models that only work with custom endpoints)
    is_custom: bool = False  # Whether this model requires custom API endpoints

    # Temperature constraint object - defines temperature limits and behavior
    temperature_constraint: TemperatureConstraint = field(
        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
    )


@dataclass
class ModelResponse:
    """Response from a model provider."""

    content: str
    usage: dict[str, int] = field(default_factory=dict)  # input_tokens, output_tokens, total_tokens
    model_name: str = ""
    friendly_name: str = ""  # Human-friendly name like "Gemini" or "OpenAI"
    provider: ProviderType = ProviderType.GOOGLE
    metadata: dict[str, Any] = field(default_factory=dict)  # Provider-specific metadata

    @property
    def total_tokens(self) -> int:
        """Get total tokens used."""
        return self.usage.get("total_tokens", 0)
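

# --- Illustrative example (not part of the original module) -------------------
# A sketch of how a provider might describe a single model. Every value below
# is invented for illustration; real entries live in each concrete provider's
# SUPPORTED_MODELS mapping.
_EXAMPLE_CAPABILITIES = ModelCapabilities(
    provider=ProviderType.CUSTOM,
    model_name="example-model-1",  # hypothetical model name
    friendly_name="Example",
    context_window=128_000,
    max_output_tokens=8_192,
    supports_images=True,
    max_image_size_mb=20.0,
    aliases=["example", "ex1"],
    temperature_constraint=RangeTemperatureConstraint(0.0, 1.0, 0.3),
)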


class ModelProvider(ABC):
    """Abstract base class for model providers."""

    # All concrete providers must define their supported models
    SUPPORTED_MODELS: dict[str, Any] = {}

    # Default maximum image size in MB
    DEFAULT_MAX_IMAGE_SIZE_MB = 20.0

    def __init__(self, api_key: str, **kwargs):
        """Initialize the provider with API key and optional configuration."""
        self.api_key = api_key
        self.config = kwargs

    @abstractmethod
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model."""
        pass

    @abstractmethod
    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the model.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature (0-2)
            max_output_tokens: Maximum tokens to generate
            **kwargs: Provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        pass

    @abstractmethod
    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens for the given text using the specified model's tokenizer."""
        pass

    @abstractmethod
    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        pass

    @abstractmethod
    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported by this provider."""
        pass

    def get_effective_temperature(self, model_name: str, requested_temperature: float) -> Optional[float]:
        """Get the effective temperature to use for a model given a requested temperature.

        This method handles:
        - Models that don't support temperature (returns None)
        - Fixed temperature models (returns the fixed value)
        - Clamping to min/max range for models with constraints

        Args:
            model_name: The model to get temperature for
            requested_temperature: The temperature requested by the user/tool

        Returns:
            The effective temperature to use, or None if temperature shouldn't be passed
        """
        try:
            capabilities = self.get_capabilities(model_name)

            # Check if model supports temperature at all
            if not capabilities.supports_temperature:
                return None

            # Use temperature constraint to get corrected value
            corrected_temp = capabilities.temperature_constraint.get_corrected_value(requested_temperature)
            if corrected_temp != requested_temperature:
                logger.debug(
                    f"Adjusting temperature from {requested_temperature} to {corrected_temp} for model {model_name}"
                )
            return corrected_temp
        except Exception as e:
            logger.debug(f"Could not determine effective temperature for {model_name}: {e}")
            # If we can't get capabilities, return the requested temperature
            return requested_temperature

    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
        """Validate model parameters against capabilities.

        Raises:
            ValueError: If parameters are invalid
        """
        capabilities = self.get_capabilities(model_name)

        # Validate temperature using constraint
        if not capabilities.temperature_constraint.validate(temperature):
            constraint_desc = capabilities.temperature_constraint.get_description()
            raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")

    @abstractmethod
    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode."""
        pass

    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
        """Get model configurations for this provider.

        This is a hook method that subclasses can override to provide their
        model configurations from different sources.

        Returns:
            Dictionary mapping model names to their ModelCapabilities objects
        """
        # Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
        if hasattr(self, "SUPPORTED_MODELS"):
            return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
        return {}

    def get_all_model_aliases(self) -> dict[str, list[str]]:
        """Get all model aliases for this provider.

        This is a hook method that subclasses can override to provide
        aliases from different sources.

        Returns:
            Dictionary mapping model names to their list of aliases
        """
        # Default implementation extracts from ModelCapabilities objects
        aliases = {}
        for model_name, capabilities in self.get_model_configurations().items():
            if capabilities.aliases:
                aliases[model_name] = capabilities.aliases
        return aliases

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name.

        This implementation uses the hook methods to support different
        model configuration sources.

        Args:
            model_name: Model name that may be an alias

        Returns:
            Resolved model name
        """
        # Get model configurations from the hook method
        model_configs = self.get_model_configurations()

        # First check if it's already a base model name (case-sensitive exact match)
        if model_name in model_configs:
            return model_name

        # Check case-insensitively for both base models and aliases
        model_name_lower = model_name.lower()

        # Check base model names case-insensitively
        for base_model in model_configs:
            if base_model.lower() == model_name_lower:
                return base_model

        # Check aliases from the hook method
        all_aliases = self.get_all_model_aliases()
        for base_model, aliases in all_aliases.items():
            if any(alias.lower() == model_name_lower for alias in aliases):
                return base_model

        # If not found, return as-is
        return model_name

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        This implementation uses the get_model_configurations() hook to
        support different model configuration sources.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.

        Returns:
            List of model names available from this provider
        """
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service() if respect_restrictions else None
        models = []

        # Get model configurations from the hook method
        model_configs = self.get_model_configurations()

        for model_name in model_configs:
            # Check restrictions if enabled
            if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
                continue

            # Add the base model
            models.append(model_name)

        # Get aliases from the hook method
        all_aliases = self.get_all_model_aliases()
        for model_name, aliases in all_aliases.items():
            # Only add aliases for models that passed restriction check
            if model_name in models:
                models.extend(aliases)

        return models

    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

        This is used for validation purposes to ensure restriction policies
        can validate against both aliases and their target model names.

        Returns:
            List of all model names and alias targets known by this provider
        """
        all_models = set()

        # Get model configurations from the hook method
        model_configs = self.get_model_configurations()

        # Add all base model names
        for model_name in model_configs:
            all_models.add(model_name.lower())

        # Get aliases from the hook method and add them
        all_aliases = self.get_all_model_aliases()
        for _model_name, aliases in all_aliases.items():
            for alias in aliases:
                all_models.add(alias.lower())

        return list(all_models)

    def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
        """Provider-independent image validation.

        Args:
            image_path: Path to image file or data URL
            max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)

        Returns:
            Tuple of (image_bytes, mime_type)

        Raises:
            ValueError: If image is invalid

        Examples:
            # Validate a file path
            image_bytes, mime_type = provider.validate_image("/path/to/image.png")

            # Validate a data URL
            image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")

            # Validate with custom size limit
            image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
        """
        # Use default if not specified
        if max_size_mb is None:
            max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB

        if image_path.startswith("data:"):
            # Parse data URL: data:image/png;base64,iVBORw0...
            try:
                header, data = image_path.split(",", 1)
                mime_type = header.split(";")[0].split(":")[1]
            except (ValueError, IndexError) as e:
                raise ValueError(f"Invalid data URL format: {e}")

            # Validate MIME type using IMAGES constant
            valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
            if mime_type not in valid_mime_types:
                raise ValueError(
                    f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}"
                )

            # Decode base64 data
            try:
                image_bytes = base64.b64decode(data)
            except binascii.Error as e:
                raise ValueError(f"Invalid base64 data: {e}")
        else:
            # Handle file path
            # Read file first to check if it exists
            try:
                with open(image_path, "rb") as f:
                    image_bytes = f.read()
            except FileNotFoundError:
                raise ValueError(f"Image file not found: {image_path}")
            except Exception as e:
                raise ValueError(f"Failed to read image file: {e}")

            # Validate extension
            ext = os.path.splitext(image_path)[1].lower()
            if ext not in IMAGES:
                raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")

            # Get MIME type
            mime_type = get_image_mime_type(ext)

        # Validate size (applies to both data URLs and files)
        size_mb = len(image_bytes) / (1024 * 1024)
        if size_mb > max_size_mb:
            raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")

        return image_bytes, mime_type

    def close(self):
        """Clean up any resources held by the provider.

        Default implementation does nothing. Subclasses should override
        if they hold resources that need cleanup.
        """
        # Base implementation: no resources to clean up
        return

    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        """Get the preferred model from this provider for a given category.

        Env-driven preference (future-proof, cost-aware):
        - Reads <PROVIDER>_PREFERRED_MODELS (ordered, comma-separated) where PROVIDER is
          self.get_provider_type().value uppercased (e.g., KIMI, GLM)
        - Returns the first preferred model present in allowed_models
        - Falls back to None if no match

        Args:
            category: The tool category requiring a model (currently not used here)
            allowed_models: Pre-filtered list of model names that are allowed by restrictions

        Returns:
            Preferred model name, or None if no preference found
        """
        try:
            provider_key = self.get_provider_type().value.upper()
            env_var = f"{provider_key}_PREFERRED_MODELS"
            prefs = os.getenv(env_var, "").strip()
            if not prefs:
                return None
            preferred_list = [m.strip() for m in prefs.split(",") if m.strip()]
            # Preserve order; return first that is allowed (case-insensitive match)
            allowed_lower = {m.lower(): m for m in allowed_models}
            for pref in preferred_list:
                if pref.lower() in allowed_lower:
                    return allowed_lower[pref.lower()]
            return None
        except Exception:
            return None

    def get_model_registry(self) -> Optional[dict[str, Any]]:
        """Get the model registry for providers that maintain one.

        This is a hook method for providers like CustomProvider that
        maintain a dynamic model registry.

        Returns:
            Model registry dict or None if not applicable
        """
        # Default implementation - most providers don't have a registry
        return None
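

# --- Illustrative example (not part of the original module) -------------------
# A minimal concrete provider, sketched to show which abstract methods a
# subclass must implement. `EchoProvider` is hypothetical: it makes no network
# calls and simply echoes the prompt back as the model response.
class EchoProvider(ModelProvider):
    SUPPORTED_MODELS = {
        "echo-1": ModelCapabilities(
            provider=ProviderType.CUSTOM,
            model_name="echo-1",
            friendly_name="Echo",
            context_window=8_192,
            max_output_tokens=1_024,
            aliases=["echo"],
        )
    }

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        return self.SUPPORTED_MODELS[self._resolve_model_name(model_name)]

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        resolved = self._resolve_model_name(model_name)
        self.validate_parameters(resolved, temperature)
        return ModelResponse(
            content=prompt,  # echo back instead of calling a real API
            model_name=resolved,
            friendly_name="Echo",
            provider=ProviderType.CUSTOM,
        )

    def count_tokens(self, text: str, model_name: str) -> int:
        return max(1, len(text) // 4)  # rough heuristic, not a real tokenizer

    def get_provider_type(self) -> ProviderType:
        return ProviderType.CUSTOM

    def validate_model_name(self, model_name: str) -> bool:
        return self._resolve_model_name(model_name) in self.SUPPORTED_MODELS

    def supports_thinking_mode(self, model_name: str) -> bool:
        return False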

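# --- Illustrative usage (not part of the original module) ---------------------
# Assumes the hypothetical EchoProvider sketched above.
if __name__ == "__main__":
    provider = EchoProvider(api_key="not-a-real-key")
    print(provider._resolve_model_name("echo"))  # -> "echo-1" (alias resolution)
    print(provider.get_effective_temperature("echo-1", 5.0))  # -> 2.0 (clamped)
    response = provider.generate_content("hello", "echo")
    print(response.content, response.total_tokens)  # -> hello 0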