Skip to main content
Glama

Gemini MCP Server

registry.pyโ€ข19.4 kB
"""Model provider registry for managing available providers.""" import logging from typing import TYPE_CHECKING, Optional from utils.env import get_env from .base import ModelProvider from .shared import ProviderType if TYPE_CHECKING: from tools.models import ToolModelCategory class ModelProviderRegistry: """Central catalogue of provider implementations used by the MCP server. Role Holds the mapping between :class:`ProviderType` values and concrete :class:`ModelProvider` subclasses/factories. At runtime the registry is responsible for instantiating providers, caching them for reuse, and mediating lookup of providers and model names in provider priority order. Core responsibilities * Resolve API keys and other runtime configuration for each provider * Lazily create provider instances so unused backends incur no cost * Expose convenience methods for enumerating available models and locating which provider can service a requested model name or alias * Honour the project-wide provider priority policy so namespaces (or alias collisions) are resolved deterministically. """ _instance = None # Provider priority order for model selection # Native APIs first, then custom endpoints, then catch-all providers PROVIDER_PRIORITY_ORDER = [ ProviderType.GOOGLE, # Direct Gemini access ProviderType.OPENAI, # Direct OpenAI access ProviderType.AZURE, # Azure-hosted OpenAI deployments ProviderType.XAI, # Direct X.AI GROK access ProviderType.DIAL, # DIAL unified API access ProviderType.CUSTOM, # Local/self-hosted models ProviderType.OPENROUTER, # Catch-all for cloud models ] def __new__(cls): """Singleton pattern for registry.""" if cls._instance is None: logging.debug("REGISTRY: Creating new registry instance") cls._instance = super().__new__(cls) # Initialize instance dictionaries on first creation cls._instance._providers = {} cls._instance._initialized_providers = {} logging.debug(f"REGISTRY: Created instance {cls._instance}") return cls._instance @classmethod def register_provider(cls, provider_type: ProviderType, provider_class: type[ModelProvider]) -> None: """Register a new provider class. Args: provider_type: Type of the provider (e.g., ProviderType.GOOGLE) provider_class: Class that implements ModelProvider interface """ instance = cls() instance._providers[provider_type] = provider_class # Invalidate any cached instance so subsequent lookups use the new registration instance._initialized_providers.pop(provider_type, None) @classmethod def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]: """Get an initialized provider instance. Args: provider_type: Type of provider to get force_new: Force creation of new instance instead of using cached Returns: Initialized ModelProvider instance or None if not available """ instance = cls() # Return cached instance if available and not forcing new if not force_new and provider_type in instance._initialized_providers: return instance._initialized_providers[provider_type] # Check if provider class is registered if provider_type not in instance._providers: return None # Get API key from environment api_key = cls._get_api_key_for_provider(provider_type) # Get provider class or factory function provider_class = instance._providers[provider_type] # For custom providers, handle special initialization requirements if provider_type == ProviderType.CUSTOM: # Check if it's a factory function (callable but not a class) if callable(provider_class) and not isinstance(provider_class, type): # Factory function - call it with api_key parameter provider = provider_class(api_key=api_key) else: # Regular class - need to handle URL requirement custom_url = get_env("CUSTOM_API_URL", "") or "" if not custom_url: if api_key: # Key is set but URL is missing logging.warning("CUSTOM_API_KEY set but CUSTOM_API_URL missing โ€“ skipping Custom provider") return None # Use empty string as API key for custom providers that don't need auth (e.g., Ollama) # This allows the provider to be created even without CUSTOM_API_KEY being set api_key = api_key or "" # Initialize custom provider with both API key and base URL provider = provider_class(api_key=api_key, base_url=custom_url) elif provider_type == ProviderType.GOOGLE: # For Gemini, check if custom base URL is configured if not api_key: return None gemini_base_url = get_env("GEMINI_BASE_URL") provider_kwargs = {"api_key": api_key} if gemini_base_url: provider_kwargs["base_url"] = gemini_base_url logging.info(f"Initialized Gemini provider with custom endpoint: {gemini_base_url}") provider = provider_class(**provider_kwargs) elif provider_type == ProviderType.AZURE: if not api_key: return None azure_endpoint = get_env("AZURE_OPENAI_ENDPOINT") if not azure_endpoint: logging.warning("AZURE_OPENAI_ENDPOINT missing โ€“ skipping Azure OpenAI provider") return None azure_version = get_env("AZURE_OPENAI_API_VERSION") provider = provider_class( api_key=api_key, azure_endpoint=azure_endpoint, api_version=azure_version, ) else: if not api_key: return None # Initialize non-custom provider with just API key provider = provider_class(api_key=api_key) # Cache the instance instance._initialized_providers[provider_type] = provider return provider @classmethod def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]: """Get provider instance for a specific model name. Provider priority order: 1. Native APIs (GOOGLE, OPENAI) - Most direct and efficient 2. CUSTOM - For local/private models with specific endpoints 3. OPENROUTER - Catch-all for cloud models via unified API Args: model_name: Name of the model (e.g., "gemini-2.5-flash", "gpt5") Returns: ModelProvider instance that supports this model """ logging.debug(f"get_provider_for_model called with model_name='{model_name}'") # Check providers in priority order instance = cls() logging.debug(f"Registry instance: {instance}") logging.debug(f"Available providers in registry: {list(instance._providers.keys())}") for provider_type in cls.PROVIDER_PRIORITY_ORDER: if provider_type in instance._providers: logging.debug(f"Found {provider_type} in registry") # Get or create provider instance provider = cls.get_provider(provider_type) if provider and provider.validate_model_name(model_name): logging.debug(f"{provider_type} validates model {model_name}") return provider else: logging.debug(f"{provider_type} does not validate model {model_name}") else: logging.debug(f"{provider_type} not found in registry") logging.debug(f"No provider found for model {model_name}") return None @classmethod def get_available_providers(cls) -> list[ProviderType]: """Get list of registered provider types.""" instance = cls() return list(instance._providers.keys()) @classmethod def get_available_models(cls, respect_restrictions: bool = True) -> dict[str, ProviderType]: """Get mapping of all available models to their providers. Args: respect_restrictions: If True, filter out models not allowed by restrictions Returns: Dict mapping model names to provider types """ # Import here to avoid circular imports from utils.model_restrictions import get_restriction_service restriction_service = get_restriction_service() if respect_restrictions else None models: dict[str, ProviderType] = {} instance = cls() for provider_type in instance._providers: provider = cls.get_provider(provider_type) if not provider: continue try: available = provider.list_models(respect_restrictions=respect_restrictions) except NotImplementedError: logging.warning("Provider %s does not implement list_models", provider_type) continue if restriction_service and restriction_service.has_restrictions(provider_type): restricted_display = cls._collect_restricted_display_names( provider, provider_type, available, restriction_service, ) if restricted_display: for model_name in restricted_display: models[model_name] = provider_type continue for model_name in available: # ===================================================================================== # CRITICAL: Prevent double restriction filtering (Fixed Issue #98) # ===================================================================================== # Previously, both the provider AND registry applied restrictions, causing # double-filtering that resulted in "no models available" errors. # # Logic: If respect_restrictions=True, provider already filtered models, # so registry should NOT filter them again. # TEST COVERAGE: tests/test_provider_routing_bugs.py::TestOpenRouterAliasRestrictions # ===================================================================================== if ( restriction_service and not respect_restrictions # Only filter if provider didn't already filter and not restriction_service.is_allowed(provider_type, model_name) ): logging.debug("Model %s filtered by restrictions", model_name) continue models[model_name] = provider_type return models @classmethod def _collect_restricted_display_names( cls, provider: ModelProvider, provider_type: ProviderType, available: list[str], restriction_service, ) -> list[str] | None: """Derive the human-facing model list when restrictions are active.""" allowed_models = restriction_service.get_allowed_models(provider_type) if not allowed_models: return None allowed_details: list[tuple[str, int]] = [] for model_name in sorted(allowed_models): try: capabilities = provider.get_capabilities(model_name) except (AttributeError, ValueError): continue try: rank = capabilities.get_effective_capability_rank() rank_value = float(rank) except (AttributeError, TypeError, ValueError): rank_value = 0.0 allowed_details.append((model_name, rank_value)) if allowed_details: allowed_details.sort(key=lambda item: (-item[1], item[0])) return [name for name, _ in allowed_details] # Fallback: intersect the allowlist with the provider-advertised names. available_lookup = {name.lower(): name for name in available} display_names: list[str] = [] for model_name in sorted(allowed_models): lowered = model_name.lower() if lowered in available_lookup: display_names.append(available_lookup[lowered]) return display_names @classmethod def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]: """Get list of available model names, optionally filtered by provider. This respects model restrictions automatically. Args: provider_type: Optional provider to filter by Returns: List of available model names """ available_models = cls.get_available_models(respect_restrictions=True) if provider_type: # Filter by specific provider return [name for name, ptype in available_models.items() if ptype == provider_type] else: # Return all available models return list(available_models.keys()) @classmethod def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]: """Get API key for a provider from environment variables. Args: provider_type: Provider type to get API key for Returns: API key string or None if not found """ key_mapping = { ProviderType.GOOGLE: "GEMINI_API_KEY", ProviderType.OPENAI: "OPENAI_API_KEY", ProviderType.AZURE: "AZURE_OPENAI_API_KEY", ProviderType.XAI: "XAI_API_KEY", ProviderType.OPENROUTER: "OPENROUTER_API_KEY", ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth ProviderType.DIAL: "DIAL_API_KEY", } env_var = key_mapping.get(provider_type) if not env_var: return None return get_env(env_var) @classmethod def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]: """Get a list of allowed canonical model names for a given provider. Args: provider: The provider instance to get models for provider_type: The provider type for restriction checking Returns: List of model names that are both supported and allowed """ from utils.model_restrictions import get_restriction_service restriction_service = get_restriction_service() allowed_models = [] # Get the provider's supported models try: # Use list_models to get all supported models (handles both regular and custom providers) supported_models = provider.list_models(respect_restrictions=False) except (NotImplementedError, AttributeError): # Fallback to provider-declared capability maps if list_models not implemented model_map = getattr(provider, "MODEL_CAPABILITIES", None) supported_models = list(model_map.keys()) if isinstance(model_map, dict) else [] # Filter by restrictions for model_name in supported_models: if restriction_service.is_allowed(provider_type, model_name): allowed_models.append(model_name) return allowed_models @classmethod def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str: """Get the preferred fallback model based on provider priority and tool category. This method orchestrates model selection by: 1. Getting allowed models for each provider (respecting restrictions) 2. Asking providers for their preference from the allowed list 3. Falling back to first available model if no preference given Args: tool_category: Optional category to influence model selection Returns: Model name string for fallback use """ from tools.models import ToolModelCategory effective_category = tool_category or ToolModelCategory.BALANCED first_available_model = None # Ask each provider for their preference in priority order for provider_type in cls.PROVIDER_PRIORITY_ORDER: provider = cls.get_provider(provider_type) if provider: # 1. Registry filters the models first allowed_models = cls._get_allowed_models_for_provider(provider, provider_type) if not allowed_models: continue # 2. Keep track of the first available model as fallback if not first_available_model: first_available_model = sorted(allowed_models)[0] # 3. Ask provider to pick from allowed list preferred_model = provider.get_preferred_model(effective_category, allowed_models) if preferred_model: logging.debug( f"Provider {provider_type.value} selected '{preferred_model}' for category '{effective_category.value}'" ) return preferred_model # If no provider returned a preference, use first available model if first_available_model: logging.debug(f"No provider preference, using first available: {first_available_model}") return first_available_model # Ultimate fallback if no providers have models logging.warning("No models available from any provider, using default fallback") return "gemini-2.5-flash" @classmethod def get_available_providers_with_keys(cls) -> list[ProviderType]: """Get list of provider types that have valid API keys. Returns: List of ProviderType values for providers with valid API keys """ available = [] instance = cls() for provider_type in instance._providers: if cls.get_provider(provider_type) is not None: available.append(provider_type) return available @classmethod def clear_cache(cls) -> None: """Clear cached provider instances.""" instance = cls() instance._initialized_providers.clear() @classmethod def reset_for_testing(cls) -> None: """Reset the registry to a clean state for testing. This provides a safe, public API for tests to clean up registry state without directly manipulating private attributes. """ cls._instance = None if hasattr(cls, "_providers"): cls._providers = {} @classmethod def unregister_provider(cls, provider_type: ProviderType) -> None: """Unregister a provider (mainly for testing).""" instance = cls() instance._providers.pop(provider_type, None) instance._initialized_providers.pop(provider_type, None)

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/BeehiveInnovations/gemini-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server