Skip to main content
Glama
proxy.pyβ€’26.2 kB
"""Proxy infrastructure for managing downstream MCP server connections.""" import asyncio import logging from typing import Any from fastmcp.client import Client from fastmcp.client.transports import StreamableHttpTransport logger = logging.getLogger(__name__) class ProxyManager: """Manages connections to downstream MCP servers. This class initializes and maintains Client instances for each configured downstream MCP server. It supports both stdio (npx/uvx) and HTTP transports, implements lazy connection strategy, and provides graceful error handling for unreachable servers. """ def __init__(self): """Initialize ProxyManager with empty client registry.""" self._clients: dict[str, Client] = {} self._connection_status: dict[str, bool] = {} self._connection_errors: dict[str, str] = {} self._current_config: dict = {} # Store current config for reload comparison def initialize_connections(self, mcp_config: dict) -> dict[str, Client]: """Initialize Client instances from MCP configuration. Creates disconnected Client instances for lazy connection strategy. Connections are established on first use via async context manager. Args: mcp_config: MCP servers configuration dictionary with structure: { "mcpServers": { "server-name": { "command": "npx", "args": [...], "env": {...} } } } Returns: Dictionary mapping server names to Client instances Raises: ValueError: If mcp_config is invalid or malformed """ if not isinstance(mcp_config, dict): raise ValueError( f"MCP configuration must be a dict, got {type(mcp_config).__name__}" ) mcp_servers = mcp_config.get("mcpServers", {}) if not isinstance(mcp_servers, dict): raise ValueError( f'"mcpServers" must be a dict, got {type(mcp_servers).__name__}' ) # Clear existing clients self._clients.clear() self._connection_status.clear() self._connection_errors.clear() # Create ProxyClient for each server for server_name, server_config in mcp_servers.items(): try: client = self._create_client(server_name, server_config) self._clients[server_name] = client self._connection_status[server_name] = False # Not yet connected self._connection_errors[server_name] = "" logger.info(f"Initialized ProxyClient for server: {server_name}") except Exception as e: logger.error(f"Failed to initialize client for {server_name}: {e}") self._connection_errors[server_name] = str(e) # Store current config for reload comparison self._current_config = mcp_config return self._clients def _create_client(self, server_name: str, server_config: dict) -> Client: """Create Client instance from server configuration. Args: server_name: Name of the server server_config: Server configuration dictionary Returns: Client instance (disconnected) Raises: ValueError: If server configuration is invalid """ # Determine transport type has_command = "command" in server_config has_url = "url" in server_config if not has_command and not has_url: raise ValueError( f'Server "{server_name}" must specify either "command" (stdio) ' f'or "url" (HTTP) transport' ) if has_command and has_url: raise ValueError( f'Server "{server_name}" cannot have both "command" and "url"' ) # Create stdio client if has_command: command = server_config["command"] args = server_config.get("args", []) env = server_config.get("env", {}) if not isinstance(command, str): raise ValueError( f'Server "{server_name}": "command" must be a string' ) if not isinstance(args, list): raise ValueError( f'Server "{server_name}": "args" must be a list' ) if not isinstance(env, dict): raise ValueError( f'Server "{server_name}": "env" must be a dict' ) logger.debug( f"Creating stdio Client for {server_name}: " f"command={command}, args={args}" ) # FastMCP Client expects MCPConfig with mcpServers key # We need to wrap the single server config client_config = { "mcpServers": { server_name: server_config } } return Client(transport=client_config) # Create HTTP client if has_url: url = server_config["url"] headers = server_config.get("headers", {}) if not isinstance(url, str): raise ValueError( f'Server "{server_name}": "url" must be a string' ) if not isinstance(headers, dict): raise ValueError( f'Server "{server_name}": "headers" must be a dict' ) # Check if Authorization header is provided (PAT or other auth) has_auth_header = headers and any( k.lower() == "authorization" for k in headers.keys() ) if has_auth_header: # User provided explicit auth - respect it, don't enable OAuth # Use StreamableHttpTransport to pass custom headers logger.info( f"Creating HTTP Client with custom authentication for {server_name}: " f"url={url}" ) transport = StreamableHttpTransport(url, headers=headers) return Client(transport) else: # No auth provided - enable OAuth auto-detection (for Notion, etc.) logger.info( f"Creating HTTP Client with OAuth support for {server_name}: " f"url={url}" ) return Client(url, auth="oauth") # Should never reach here due to earlier validation raise ValueError(f'Server "{server_name}" has invalid configuration') def get_client(self, server_name: str) -> Client: """Get Client instance for a server. Returns the disconnected Client for the specified server. The caller should use 'async with client:' to establish a connection and perform operations. Args: server_name: Name of the server Returns: Client instance for the server Raises: KeyError: If server_name is not found in initialized clients RuntimeError: If server had initialization errors """ if server_name not in self._clients: # Check if it had initialization error if server_name in self._connection_errors and self._connection_errors[server_name]: raise RuntimeError( f'Server "{server_name}" is unavailable: ' f'{self._connection_errors[server_name]}' ) raise KeyError( f'Server "{server_name}" not found in configured servers' ) return self._clients[server_name] async def test_connection( self, server_name: str, timeout_ms: int = 5000, max_retries: int = 3 ) -> bool: """Test connection to a server with retry logic. Attempts to connect to the server and retrieve its tool list. Uses exponential backoff for retries. Args: server_name: Name of the server to test timeout_ms: Connection timeout in milliseconds (default: 5000) max_retries: Maximum number of retry attempts (default: 3) Returns: True if connection successful, False otherwise Raises: KeyError: If server_name is not found in initialized clients """ client = self.get_client(server_name) timeout_sec = timeout_ms / 1000.0 base_delay = 0.5 # Start with 500ms delay for attempt in range(max_retries): try: logger.debug( f"Testing connection to {server_name} " f"(attempt {attempt + 1}/{max_retries})" ) async with asyncio.timeout(timeout_sec): async with client: # Try to list tools as connection test await client.list_tools() self._connection_status[server_name] = True self._connection_errors[server_name] = "" logger.info(f"Successfully connected to server: {server_name}") return True except asyncio.TimeoutError: delay = base_delay * (2 ** attempt) logger.warning( f"Connection timeout for {server_name} on attempt {attempt + 1}. " f"Retrying in {delay}s..." ) if attempt < max_retries - 1: await asyncio.sleep(delay) except Exception as e: delay = base_delay * (2 ** attempt) logger.warning( f"Connection error for {server_name} on attempt {attempt + 1}: {e}. " f"Retrying in {delay}s..." ) if attempt < max_retries - 1: await asyncio.sleep(delay) # All retries failed error_msg = f"Failed to connect after {max_retries} attempts" self._connection_status[server_name] = False self._connection_errors[server_name] = error_msg logger.error(f"Failed to connect to server {server_name}: {error_msg}") return False async def call_tool( self, server_name: str, tool_name: str, arguments: dict[str, Any], timeout_ms: int | None = None ) -> Any: """Call a tool on a downstream server. Establishes a fresh session for each call (automatic session isolation). Args: server_name: Name of the server tool_name: Name of the tool to call arguments: Tool arguments timeout_ms: Optional timeout in milliseconds Returns: Tool execution result Raises: KeyError: If server_name is not found RuntimeError: If server is unavailable or tool call fails asyncio.TimeoutError: If operation times out """ client = self.get_client(server_name) try: if timeout_ms: timeout_sec = timeout_ms / 1000.0 async with asyncio.timeout(timeout_sec): async with client: result = await client.call_tool(tool_name, arguments) return result else: async with client: result = await client.call_tool(tool_name, arguments) return result except asyncio.TimeoutError: logger.error( f"Timeout calling tool {tool_name} on server {server_name} " f"(timeout: {timeout_ms}ms)" ) raise except Exception as e: logger.error( f"Error calling tool {tool_name} on server {server_name}: {e}" ) raise RuntimeError( f"Failed to call tool {tool_name} on server {server_name}: {e}" ) async def list_tools(self, server_name: str) -> list[Any]: """List all tools available from a server. Args: server_name: Name of the server Returns: List of tool definitions Raises: KeyError: If server_name is not found RuntimeError: If server is unavailable """ client = self.get_client(server_name) try: async with client: tools = await client.list_tools() return tools except Exception as e: logger.error(f"Error listing tools from server {server_name}: {e}") raise RuntimeError( f"Failed to list tools from server {server_name}: {e}" ) def get_server_status(self, server_name: str) -> dict[str, Any]: """Get connection status for a server. Args: server_name: Name of the server Returns: Dictionary with status information: { "connected": bool, "error": str, "initialized": bool } """ return { "connected": self._connection_status.get(server_name, False), "error": self._connection_errors.get(server_name, ""), "initialized": server_name in self._clients } def get_all_servers(self) -> list[str]: """Get list of all initialized server names. Returns: List of server names """ return list(self._clients.keys()) def get_servers_config(self) -> dict: """Get current MCP servers configuration. Returns the mcpServers dictionary from the current configuration, reflecting any hot-reload changes. Returns: Dictionary mapping server names to server configurations """ return self._current_config.get("mcpServers", {}) def _config_changed(self, server_name: str, new_mcp_config: dict) -> bool: """Check if a server's configuration has changed. Compares the server configuration in the current config with the new config to determine if the server needs to be reloaded. Args: server_name: Name of the server to check new_mcp_config: New MCP configuration to compare against Returns: True if configuration changed, False otherwise """ old_servers = self._current_config.get("mcpServers", {}) new_servers = new_mcp_config.get("mcpServers", {}) old_config = old_servers.get(server_name, {}) new_config = new_servers.get(server_name, {}) # Compare the configurations (deep comparison) return old_config != new_config async def reload(self, new_mcp_config: dict) -> tuple[bool, str | None]: """Reload proxy client connections with new MCP server configuration. This method performs an atomic configuration update by: 1. Validating the new configuration 2. Determining which servers need to be added, removed, or updated 3. Closing connections for removed/updated servers 4. Creating new clients for added/updated servers 5. Preserving unchanged servers (no disruption) The reload follows the lazy connection strategy - new clients are created in disconnected state and will connect on first use. Args: new_mcp_config: New MCP server configuration dictionary with structure: { "mcpServers": { "server-name": { "command": "npx", "args": [...], "env": {...} } } } Returns: Tuple of (success: bool, error_message: str | None) - (True, None) if reload successful - (False, error_message) if validation failed or reload encountered errors Thread Safety: This method is NOT thread-safe. The caller must ensure that reload operations are serialized and that no concurrent operations are accessing the ProxyManager during reload. Examples: >>> manager = ProxyManager() >>> manager.initialize_connections(initial_config) >>> success, error = await manager.reload(new_config) >>> if success: ... logger.info("Reload successful") >>> else: ... logger.error(f"Reload failed: {error}") """ logger.info("ProxyManager reload initiated") # Validate new configuration structure try: if not isinstance(new_mcp_config, dict): error_msg = f"MCP configuration must be a dict, got {type(new_mcp_config).__name__}" logger.error(f"Reload validation failed: {error_msg}") return False, error_msg new_mcp_servers = new_mcp_config.get("mcpServers", {}) if not isinstance(new_mcp_servers, dict): error_msg = f'"mcpServers" must be a dict, got {type(new_mcp_servers).__name__}' logger.error(f"Reload validation failed: {error_msg}") return False, error_msg # Validate each server config before proceeding for server_name, server_config in new_mcp_servers.items(): try: # Validate by attempting to parse the config (without creating client) has_command = "command" in server_config has_url = "url" in server_config if not has_command and not has_url: raise ValueError( f'Server "{server_name}" must specify either "command" (stdio) ' f'or "url" (HTTP) transport' ) if has_command and has_url: raise ValueError( f'Server "{server_name}" cannot have both "command" and "url"' ) # Validate stdio config if has_command: command = server_config["command"] args = server_config.get("args", []) env = server_config.get("env", {}) if not isinstance(command, str): raise ValueError(f'Server "{server_name}": "command" must be a string') if not isinstance(args, list): raise ValueError(f'Server "{server_name}": "args" must be a list') if not isinstance(env, dict): raise ValueError(f'Server "{server_name}": "env" must be a dict') # Validate HTTP config if has_url: url = server_config["url"] headers = server_config.get("headers", {}) if not isinstance(url, str): raise ValueError(f'Server "{server_name}": "url" must be a string') if not isinstance(headers, dict): raise ValueError(f'Server "{server_name}": "headers" must be a dict') except Exception as e: error_msg = f"Invalid configuration for server '{server_name}': {e}" logger.error(f"Reload validation failed: {error_msg}") return False, error_msg except Exception as e: error_msg = f"Configuration validation error: {e}" logger.error(f"Reload validation failed: {error_msg}") return False, error_msg logger.info("Configuration validation passed") # Determine server changes old_servers = set(self._clients.keys()) new_servers = set(new_mcp_config.get("mcpServers", {}).keys()) servers_to_add = new_servers - old_servers servers_to_remove = old_servers - new_servers servers_to_check = old_servers & new_servers servers_to_update = [ s for s in servers_to_check if self._config_changed(s, new_mcp_config) ] servers_unchanged = [ s for s in servers_to_check if not self._config_changed(s, new_mcp_config) ] logger.info( f"Server changes: " f"+{len(servers_to_add)} (add), " f"-{len(servers_to_remove)} (remove), " f"~{len(servers_to_update)} (update), " f"={len(servers_unchanged)} (unchanged)" ) if servers_to_add: logger.info(f"Servers to add: {sorted(servers_to_add)}") if servers_to_remove: logger.info(f"Servers to remove: {sorted(servers_to_remove)}") if servers_to_update: logger.info(f"Servers to update: {sorted(servers_to_update)}") if servers_unchanged: logger.debug(f"Servers unchanged: {sorted(servers_unchanged)}") # Close connections for removed and updated servers servers_to_close = list(servers_to_remove) + servers_to_update if servers_to_close: logger.info(f"Closing connections for {len(servers_to_close)} servers") for server_name in servers_to_close: try: client = self._clients.get(server_name) if client is not None: # Since we use lazy connection strategy with context managers, # clients don't maintain persistent connections. However, we # should still clean up any resources. logger.debug(f"Cleaning up client for server: {server_name}") # Note: Client instances don't have an explicit close() method # as they use async context managers. Resources are released # when we remove the reference. # Remove from all tracking dictionaries self._clients.pop(server_name, None) self._connection_status.pop(server_name, None) self._connection_errors.pop(server_name, None) logger.debug(f"Removed client for server: {server_name}") except Exception as e: logger.warning( f"Error during cleanup for server '{server_name}': {e}", exc_info=True ) # Continue with reload despite cleanup errors # Create clients for new and updated servers servers_to_create = list(servers_to_add) + servers_to_update new_mcp_servers = new_mcp_config.get("mcpServers", {}) if servers_to_create: logger.info(f"Creating clients for {len(servers_to_create)} servers") for server_name in servers_to_create: try: server_config = new_mcp_servers[server_name] client = self._create_client(server_name, server_config) self._clients[server_name] = client self._connection_status[server_name] = False # Not yet connected self._connection_errors[server_name] = "" logger.info(f"Created client for server: {server_name}") except Exception as e: # Log error but don't fail the entire reload # This follows the lazy connection strategy - servers with # initialization errors are recorded and will error on use logger.error( f"Failed to create client for server '{server_name}': {e}", exc_info=True ) self._connection_errors[server_name] = str(e) # Update stored config self._current_config = new_mcp_config logger.info( f"ProxyManager reload completed successfully. " f"Active servers: {len(self._clients)}" ) return True, None async def close_all_connections(self): """Gracefully shutdown all downstream MCP server connections. Follows MCP spec for stdio transport: 1. Close input stream to child process 2. Wait for server to exit (2s timeout) 3. Send SIGTERM/SIGKILL if not exited This method iterates through all clients and calls close() on each, which triggers the proper transport cleanup sequence. """ logger.info("Initiating graceful shutdown of all downstream servers") server_count = len(self._clients) closed_count = 0 error_count = 0 for server_name, client in list(self._clients.items()): try: logger.debug(f"Closing client for server: {server_name}") # Client.close() triggers transport.close() -> disconnect() # For stdio transports, this terminates the subprocess await client.close() closed_count += 1 logger.info(f"Closed connection to: {server_name}") except Exception as e: error_count += 1 logger.warning(f"Error closing client for {server_name}: {e}") # Continue with other clients - one failure shouldn't prevent others # Clear internal state self._clients.clear() self._connection_status.clear() self._connection_errors.clear() logger.info( f"Graceful shutdown complete: {closed_count}/{server_count} servers closed" + (f", {error_count} errors" if error_count else "") )

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/roddutra/agent-mcp-gateway'

If you have feedback or need assistance with the MCP directory API, please join our Discord server