Skip to main content
Glama
gateway.py (23.6 kB)
"""Gateway server for Agent MCP Gateway.""" import asyncio import fnmatch from typing import Annotated, Any, Optional from fastmcp import FastMCP from fastmcp.exceptions import ToolError from pydantic import BaseModel, Field from .policy import PolicyEngine from .proxy import ProxyManager # Output schemas for gateway tools class ServerInfo(BaseModel): """Server information returned by list_servers.""" name: Annotated[str, Field(description="Server name (use in get_server_tools and execute_tool)")] description: Annotated[Optional[str], Field(description="What this server provides (from config or null if not configured)")] = None transport: Annotated[Optional[str], Field(description="How server communicates: stdio or http (only if include_metadata=true)")] = None command: Annotated[Optional[str], Field(description="Command that runs this server (only if include_metadata=true and transport=stdio)")] = None url: Annotated[Optional[str], Field(description="Server endpoint (only if include_metadata=true and transport=http)")] = None class ToolDefinition(BaseModel): """Tool definition from downstream server.""" name: Annotated[str, Field(description="Tool name (use in execute_tool)")] description: Annotated[str, Field(description="What this tool does")] inputSchema: Annotated[dict, Field(description="JSON Schema defining required/optional parameters for execute_tool args")] class GetServerToolsResponse(BaseModel): """Response from get_server_tools.""" tools: Annotated[list[ToolDefinition], Field(description="Tool definitions you can access")] server: Annotated[str, Field(description="Queried server name")] total_available: Annotated[int, Field(description="Total tools on server (may exceed returned if filtered by permissions/criteria)")] returned: Annotated[int, Field(description="Count of tools returned (less than total_available is normal due to filtering)")] tokens_used: Annotated[Optional[int], Field(description="Tokens used in schemas (if max_schema_tokens was 
set)")] = None error: Annotated[Optional[str], Field(description="Error message if request failed")] = None class ToolExecutionResponse(BaseModel): """Response from execute_tool.""" content: Annotated[list[dict], Field(description="Result from the downstream tool (format varies by tool)")] isError: Annotated[bool, Field(description="True if the downstream tool returned an error")] class GatewayStatusResponse(BaseModel): """Response from get_gateway_status (debug tool).""" reload_status: Annotated[Optional[dict], Field(description="Hot reload history with timestamps and errors")] policy_state: Annotated[dict, Field(description="Policy engine configuration (agent count, defaults)")] available_servers: Annotated[list[str], Field(description="All configured server names")] config_paths: Annotated[dict, Field(description="File paths to gateway configuration")] message: Annotated[str, Field(description="Summary status message")] # Create FastMCP instance gateway = FastMCP(name="Agent MCP Gateway") # Module-level storage for configurations (set by main.py) _policy_engine: PolicyEngine | None = None _mcp_config: dict | None = None _proxy_manager: ProxyManager | None = None _check_config_changes_fn: Any | None = None # Fallback reload checker _get_reload_status_fn: Any | None = None # Reload status getter for diagnostics _default_agent_id: str | None = None # Default agent for fallback chain _debug_mode: bool = False # Debug mode flag def initialize_gateway( policy_engine: PolicyEngine, mcp_config: dict, proxy_manager: ProxyManager | None = None, check_config_changes_fn: Any = None, get_reload_status_fn: Any = None, default_agent_id: str | None = None, debug_mode: bool = False ): """Initialize gateway with policy engine, MCP config, and proxy manager. This must be called before the gateway starts accepting requests. 
Args: policy_engine: PolicyEngine instance for access control mcp_config: MCP servers configuration dictionary proxy_manager: Optional ProxyManager instance (required for get_server_tools) check_config_changes_fn: Optional function to check for config changes (fallback mechanism) get_reload_status_fn: Optional function to get reload status for diagnostics default_agent_id: Optional default agent ID from GATEWAY_DEFAULT_AGENT env var for fallback chain debug_mode: Whether debug mode is enabled (exposes get_gateway_status tool) """ global _policy_engine, _mcp_config, _proxy_manager, _check_config_changes_fn, _get_reload_status_fn, _default_agent_id, _debug_mode _policy_engine = policy_engine _mcp_config = mcp_config _proxy_manager = proxy_manager _check_config_changes_fn = check_config_changes_fn _get_reload_status_fn = get_reload_status_fn _default_agent_id = default_agent_id _debug_mode = debug_mode # Conditionally register debug tools based on debug mode if debug_mode: _register_debug_tools() def get_default_agent_id() -> str | None: """Get the default agent ID from gateway configuration. Returns: Default agent ID from GATEWAY_DEFAULT_AGENT env var, or None if not set """ return _default_agent_id def update_mcp_config(new_mcp_config: dict) -> None: """Update the gateway's MCP configuration after hot reload. NOTE: This function is deprecated and no longer needed since list_servers now queries ProxyManager directly. It's retained for backward compatibility with existing tests. 
Args: new_mcp_config: New MCP servers configuration dictionary """ import logging logger = logging.getLogger(__name__) # Defensive validation if not isinstance(new_mcp_config, dict): raise TypeError(f"Invalid config type: {type(new_mcp_config)}") if "mcpServers" not in new_mcp_config: raise KeyError("Config missing 'mcpServers' key") global _mcp_config old_servers = set(_mcp_config.get("mcpServers", {}).keys()) if _mcp_config else set() new_servers = set(new_mcp_config.get("mcpServers", {}).keys()) # Log changes added = new_servers - old_servers removed = old_servers - new_servers if added: logger.info(f"MCP config update: Added servers: {added}") if removed: logger.info(f"MCP config update: Removed servers: {removed}") # Update the module-level config _mcp_config = new_mcp_config def _register_debug_tools(): """Register debug-only tools when debug mode is enabled. This function is called by initialize_gateway() when debug_mode=True. It registers additional diagnostic tools that should only be available in debug/development environments. """ # Register get_gateway_status tool # Note: The function itself is always defined (for testing), but only # registered as a gateway tool when debug mode is enabled gateway.tool(get_gateway_status) @gateway.tool async def list_servers( agent_id: Annotated[Optional[str], "Your agent identifier (leave empty if not provided to you)"] = None, include_metadata: Annotated[bool, "Include technical details (transport, command, url)"] = False ) -> list[dict]: """Discover downstream MCP servers available through this gateway. Your access is determined by gateway policy rules. 
Workflow: 1) Call list_servers to discover servers, 2) Call get_server_tools to see available tools, 3) Call execute_tool to use them.""" # Defensive check (middleware should have resolved agent_id) if agent_id is None: raise ToolError("Internal error: agent_id not resolved by middleware") # Get configurations from module-level storage policy_engine = _policy_engine proxy_manager = _proxy_manager if not policy_engine: raise RuntimeError("PolicyEngine not initialized in gateway state") if not proxy_manager: raise RuntimeError("ProxyManager not initialized in gateway state") # Get servers this agent can access allowed_servers = policy_engine.get_allowed_servers(agent_id) # Get current server list from ProxyManager (reflects hot-reload changes) all_servers = proxy_manager.get_servers_config() # Build response server_list = [] # Handle wildcard access if allowed_servers == ["*"]: # Agent has wildcard access - return all servers allowed_servers = list(all_servers.keys()) for server_name in allowed_servers: if server_name in all_servers: server_config = all_servers[server_name] # Determine transport type transport = "stdio" if "command" in server_config else "http" # Build ServerInfo object - always include name and description server_info_kwargs = { "name": server_name, "description": server_config.get("description") # Include description always (None if not in config) } # Add technical metadata if requested if include_metadata: server_info_kwargs["transport"] = transport # Add transport-specific metadata if transport == "stdio": server_info_kwargs["command"] = server_config.get("command") elif transport == "http": server_info_kwargs["url"] = server_config.get("url") server_list.append(ServerInfo(**server_info_kwargs)) return [server.model_dump() for server in server_list] def _matches_pattern(tool_name: str, pattern: str) -> bool: """Check if tool name matches wildcard pattern. Uses glob-style pattern matching: - * matches any sequence of characters - ? 
matches any single character - [seq] matches any character in seq - [!seq] matches any character not in seq Args: tool_name: Name of the tool to match pattern: Pattern with wildcards (e.g., "get_*", "*_user") Returns: True if tool_name matches pattern, False otherwise Example: >>> _matches_pattern("get_user", "get_*") True >>> _matches_pattern("delete_user", "get_*") False >>> _matches_pattern("list_items", "*_items") True """ return fnmatch.fnmatch(tool_name, pattern) def _estimate_tool_tokens(tool: Any) -> int: """Estimate token count for a tool definition. Estimates tokens based on name, description, and input schema JSON length. Uses rough approximation: characters / 4 = tokens (typical for English text). Args: tool: Tool object with name, description, and inputSchema attributes Returns: Estimated token count for the tool definition Example: >>> tool = Tool(name="get_user", description="Get user by ID", inputSchema={...}) >>> _estimate_tool_tokens(tool) 42 """ # Count name length name_len = len(tool.name) if hasattr(tool, 'name') and tool.name else 0 # Count description length desc_len = len(tool.description) if hasattr(tool, 'description') and tool.description else 0 # Count input schema length (convert to string for estimation) schema_len = 0 if hasattr(tool, 'inputSchema') and tool.inputSchema: # Convert schema dict to string for rough character count import json try: schema_len = len(json.dumps(tool.inputSchema)) except Exception: # If serialization fails, use a default estimate schema_len = 100 # Total characters / 4 = rough token estimate total_chars = name_len + desc_len + schema_len return max(1, total_chars // 4) @gateway.tool async def get_server_tools( agent_id: Annotated[Optional[str], "Your agent identifier (leave empty if not provided to you)"] = None, server: Annotated[str, "Server name from list_servers"] = "", names: Annotated[Optional[str], "Filter: comma-separated tool names"] = None, pattern: Annotated[Optional[str], "Filter: wildcard 
pattern (e.g., 'get_*')"] = None, max_schema_tokens: Annotated[Optional[int], "Limit total tokens in returned schemas"] = None ) -> dict: """Discover tools available on a downstream MCP server accessed through this gateway. Returns only tools you have permission to use (filtered by policy rules). Use the returned tool definitions to call execute_tool.""" # Defensive check (middleware should have resolved agent_id) if agent_id is None: raise ToolError("Internal error: agent_id not resolved by middleware") # Check for config changes (fallback mechanism for when file watching doesn't work) if _check_config_changes_fn: try: _check_config_changes_fn() except Exception: pass # Don't let config check errors break tool execution # Parse comma-separated names string into list names_list: Optional[list[str]] = None if names is not None and names.strip(): # Split by comma and trim whitespace from each name names_list = [name.strip() for name in names.split(",") if name.strip()] # If we ended up with an empty list after filtering, treat as None if not names_list: names_list = None # Get configurations from module-level storage policy_engine = _policy_engine proxy_manager = _proxy_manager if not policy_engine: return GetServerToolsResponse( tools=[], server=server, total_available=0, returned=0, tokens_used=None, error="PolicyEngine not initialized in gateway state" ).model_dump() if not proxy_manager: return GetServerToolsResponse( tools=[], server=server, total_available=0, returned=0, tokens_used=None, error="ProxyManager not initialized in gateway state" ).model_dump() # Validate agent can access server if not policy_engine.can_access_server(agent_id, server): return GetServerToolsResponse( tools=[], server=server, total_available=0, returned=0, tokens_used=None, error=f"Access denied: Agent '{agent_id}' cannot access server '{server}'" ).model_dump() # Get tools from downstream server try: all_tools = await proxy_manager.list_tools(server) except KeyError: return 
GetServerToolsResponse( tools=[], server=server, total_available=0, returned=0, tokens_used=None, error=f"Server '{server}' not found in configured servers" ).model_dump() except RuntimeError as e: return GetServerToolsResponse( tools=[], server=server, total_available=0, returned=0, tokens_used=None, error=f"Server unavailable: {str(e)}" ).model_dump() except Exception as e: return GetServerToolsResponse( tools=[], server=server, total_available=0, returned=0, tokens_used=None, error=f"Failed to retrieve tools: {str(e)}" ).model_dump() total_available = len(all_tools) # Filter tools based on criteria filtered_tools = [] token_count = 0 for tool in all_tools: tool_name = tool.name if hasattr(tool, 'name') else str(tool) # Filter by explicit names list if names_list is not None and tool_name not in names_list: continue # Filter by wildcard pattern if pattern is not None and not _matches_pattern(tool_name, pattern): continue # Filter by policy permissions if not policy_engine.can_access_tool(agent_id, server, tool_name): continue # Check token budget limit if max_schema_tokens is not None: tool_tokens = _estimate_tool_tokens(tool) if token_count + tool_tokens > max_schema_tokens: # Stop adding tools - budget exceeded break token_count += tool_tokens # Convert tool to ToolDefinition tool_definition = ToolDefinition( name=tool_name, description=tool.description if hasattr(tool, 'description') and tool.description else "", inputSchema=tool.inputSchema if hasattr(tool, 'inputSchema') else {} ) filtered_tools.append(tool_definition) return GetServerToolsResponse( tools=filtered_tools, server=server, total_available=total_available, returned=len(filtered_tools), tokens_used=token_count if max_schema_tokens is not None else None ).model_dump() @gateway.tool async def execute_tool( agent_id: Annotated[Optional[str], "Your agent identifier (leave empty if not provided to you)"] = None, server: Annotated[str, "Server name from list_servers"] = "", tool: Annotated[str, "Tool 
name from get_server_tools"] = "", args: Annotated[dict, "Arguments matching tool's inputSchema"] = {}, timeout_ms: Annotated[Optional[int], "Execution timeout in milliseconds"] = None ) -> dict: """Execute a tool on a downstream MCP server accessed through this gateway. Gateway validates permissions then forwards your request to the server. Returns the server's response directly.""" # Defensive check (middleware should have resolved agent_id) if agent_id is None: raise ToolError("Internal error: agent_id not resolved by middleware") # Get configurations from module-level storage policy_engine = _policy_engine proxy_manager = _proxy_manager if not policy_engine: raise ToolError("PolicyEngine not initialized in gateway state") if not proxy_manager: raise ToolError("ProxyManager not initialized in gateway state") # 1. Validate agent can access server if not policy_engine.can_access_server(agent_id, server): raise ToolError(f"Agent '{agent_id}' cannot access server '{server}'") # 2. Validate agent can access tool if not policy_engine.can_access_tool(agent_id, server, tool): raise ToolError(f"Agent '{agent_id}' not authorized to call tool '{tool}' on server '{server}'") # 3. Execute tool on downstream server try: result = await proxy_manager.call_tool(server, tool, args, timeout_ms) # 4. 
Return result transparently # Handle both ToolResult objects and dict responses if hasattr(result, 'content'): # ToolResult object - serialize content items if they're Pydantic models content = result.content # Convert Pydantic models to dicts if isinstance(content, list): serialized_content = [] for item in content: if hasattr(item, 'model_dump'): # Pydantic v2 serialized_content.append(item.model_dump()) elif hasattr(item, 'dict'): # Pydantic v1 serialized_content.append(item.dict()) elif isinstance(item, dict): # Already a dict serialized_content.append(item) else: # Fallback: convert to text content serialized_content.append({"type": "text", "text": str(item)}) content = serialized_content return ToolExecutionResponse( content=content, isError=getattr(result, "isError", False) ).model_dump() elif isinstance(result, dict): # Already a dict - ensure it has the expected structure return ToolExecutionResponse( content=result.get("content", [{"type": "text", "text": str(result)}]), isError=result.get("isError", False) ).model_dump() else: # Wrap other return types return ToolExecutionResponse( content=[{"type": "text", "text": str(result)}], isError=False ).model_dump() except asyncio.TimeoutError: raise ToolError(f"Tool execution timed out after {timeout_ms}ms") except KeyError as e: # Server not found raise ToolError(f"Server '{server}' not found in configured servers") except RuntimeError as e: # Server unavailable or tool execution failed error_msg = str(e) if "not found" in error_msg.lower() or "unavailable" in error_msg.lower(): raise ToolError(error_msg) else: raise ToolError(f"Tool execution failed: {error_msg}") except Exception as e: # Other errors raise ToolError(f"Tool execution failed: {str(e)}") async def get_gateway_status( agent_id: Annotated[Optional[str], "Your agent identifier (leave empty if not provided to you)"] = None ) -> dict: """Get gateway status, configuration state, and hot reload diagnostics. 
NOTE: Only available when debug mode is enabled.""" # Defensive check (middleware should have resolved agent_id) if agent_id is None: raise ToolError("Internal error: agent_id not resolved by middleware") # Get reload status if available reload_status = None if _get_reload_status_fn: try: reload_status = _get_reload_status_fn() # Convert datetime objects to ISO strings for JSON serialization if reload_status: for config_type in ["mcp_config", "gateway_rules"]: if config_type in reload_status: for key in ["last_attempt", "last_success"]: if reload_status[config_type].get(key): reload_status[config_type][key] = reload_status[config_type][key].isoformat() except Exception: reload_status = {"error": "Failed to retrieve reload status"} # Get PolicyEngine state policy_state = {} if _policy_engine: try: policy_state = { "total_agents": len(_policy_engine.agents), "agent_ids": list(_policy_engine.agents.keys()), "defaults": _policy_engine.defaults, } except Exception: policy_state = {"error": "Failed to retrieve policy state"} # Get available servers from ProxyManager (reflects hot-reload changes) available_servers = [] if _proxy_manager: try: available_servers = _proxy_manager.get_all_servers() except Exception: available_servers = [] # Get config file paths from src/config.py config_paths = {} try: from src.config import get_stored_config_paths mcp_path, rules_path = get_stored_config_paths() config_paths = { "mcp_config": mcp_path, "gateway_rules": rules_path, } except Exception: config_paths = {"error": "Failed to retrieve config paths"} return GatewayStatusResponse( reload_status=reload_status, policy_state=policy_state, available_servers=available_servers, config_paths=config_paths, message="Gateway is operational. Check reload_status for hot reload health." ).model_dump()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/roddutra/agent-mcp-gateway'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.