Skip to main content
Glama
mcp_server.py (14.1 kB)
"""Create a local SSE server that proxies requests to a stdio MCP server.""" import contextlib import logging from collections.abc import AsyncIterator from dataclasses import dataclass from datetime import datetime, timezone from typing import Any, Literal import uvicorn from mcp.client.session import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.server import Server as MCPServerSDK # Renamed to avoid conflict from mcp.server.sse import SseServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import BaseRoute, Mount, Route from starlette.types import ASGIApp, Receive, Scope, Send from .proxy_server import create_proxy_server logger = logging.getLogger(__name__) class AuthorizationMiddleware: """ASGI middleware to handle API token authorization.""" def __init__(self, app: ASGIApp, api_token: str | None = None, public_paths: set[str] | None = None, auth_required_paths: set[str] | None = None): self.app = app self.api_token = api_token # Public endpoints (paths that don't require token verification) self.public_paths = public_paths or { "/status", # Server status check "/docs", # API documentation "/openapi.json", # OpenAPI specification } # Paths that require authentication (if specified, only these paths will require auth) self.auth_required_paths = auth_required_paths async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """Process the ASGI request and check authorization if needed.""" if scope["type"] != "http": # Non-HTTP requests (like WebSocket) pass through await self.app(scope, receive, send) return # Skip verification if API token is not configured if not self.api_token: await self.app(scope, receive, 
send) return path = scope.get("path", "") # Skip token verification for public endpoints if path in self.public_paths: await self.app(scope, receive, send) return # If auth_required_paths is specified, only check auth for those paths if self.auth_required_paths is not None: requires_auth = any(path.startswith(auth_path) for auth_path in self.auth_required_paths) if not requires_auth: await self.app(scope, receive, send) return # Check Authorization header headers = dict(scope.get("headers", [])) auth_header = headers.get(b"authorization") if not auth_header: logger.warning("Missing Authorization header for %s", path) await self._send_error_response( send, 401, { "error": "unauthorized", "error_description": "Authorization header required", } ) return auth_header_str = auth_header.decode("latin1") if not auth_header_str.startswith("Bearer "): logger.warning("Invalid Authorization header format for %s", path) await self._send_error_response( send, 401, { "error": "invalid_request", "error_description": "Invalid authorization format. 
Use 'Bearer <token>'", } ) return token = auth_header_str[7:] if token != self.api_token: logger.warning("Invalid API token provided for %s", path) await self._send_error_response( send, 403, { "error": "invalid_token", "error_description": "Invalid API token", } ) return # Proceed with request if token is valid logger.debug("Valid API token provided for %s", path) await self.app(scope, receive, send) async def _send_error_response(self, send: Send, status_code: int, error_data: dict[str, Any]) -> None: """Send an error response.""" import json response_body = json.dumps(error_data).encode("utf-8") headers = [ [b"content-type", b"application/json"], [b"content-length", str(len(response_body)).encode("ascii")], [b"cache-control", b"no-store"], ] # Add WWW-Authenticate header for 401 responses if status_code == 401: headers.append([ b"www-authenticate", b'Bearer realm="MCP Proxy", error="invalid_token", error_description="Token authentication required"' ]) await send({ "type": "http.response.start", "status": status_code, "headers": headers, }) await send({ "type": "http.response.body", "body": response_body, }) @dataclass class MCPServerSettings: """Settings for the MCP server.""" bind_host: str port: int stateless: bool = False allow_origins: list[str] | None = None log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" api_token: str | None = None # To store last activity for multiple servers if needed, though status endpoint is global for now. 
# Module-wide state served by the /status endpoint: the UTC timestamp of the
# most recent proxied request (initialised at import time) and a
# name -> state map of configured server instances.
_global_status: dict[str, Any] = {
    "api_last_activity": datetime.now(timezone.utc).isoformat(),
    "server_instances": {},  # Could be used to store per-instance status later
}


def _update_global_activity() -> None:
    """Record the current UTC time (ISO 8601) as the latest API activity."""
    _global_status["api_last_activity"] = datetime.now(timezone.utc).isoformat()


async def _handle_status(_: Request) -> Response:
    """Global health check and service usage monitoring endpoint."""
    return JSONResponse(_global_status)


def create_single_instance_routes(
    mcp_server_instance: MCPServerSDK[object],
    *,
    stateless_instance: bool,
) -> tuple[list[BaseRoute], StreamableHTTPSessionManager]:
    """Create Starlette routes and the HTTP session manager for one MCP server instance.

    Routes (relative to wherever they are mounted):

    * ``/mcp``       -- Streamable HTTP transport (JSON responses, no event store).
    * ``/sse``       -- SSE transport stream.
    * ``/messages/`` -- POST endpoint used by SSE clients to send messages.

    The caller is expected to keep the returned manager's ``run()`` context
    entered for as long as the routes are served; ``run_mcp_server`` does this
    through its exit stack.

    Returns:
        ``(routes, http_session_manager)``.
    """
    logger.debug(
        "Creating routes for a single MCP server instance (stateless: %s)",
        stateless_instance,
    )

    sse_transport = SseServerTransport("/messages/")
    http_session_manager = StreamableHTTPSessionManager(
        app=mcp_server_instance,
        event_store=None,
        json_response=True,
        stateless=stateless_instance,
    )

    async def handle_sse_instance(request: Request) -> Response:
        # connect_sse needs the raw ASGI send callable, only reachable via the
        # private Request._send attribute.
        async with sse_transport.connect_sse(
            request.scope,
            request.receive,
            request._send,  # noqa: SLF001
        ) as (read_stream, write_stream):
            _update_global_activity()
            await mcp_server_instance.run(
                read_stream,
                write_stream,
                mcp_server_instance.create_initialization_options(),
            )
        # Return empty response after SSE connection completes
        return Response(content="", status_code=200)

    async def handle_streamable_http_instance(scope: Scope, receive: Receive, send: Send) -> None:
        _update_global_activity()
        await http_session_manager.handle_request(scope, receive, send)

    routes: list[BaseRoute] = [
        Mount("/mcp", app=handle_streamable_http_instance),
        Route("/sse", endpoint=handle_sse_instance),
        Mount("/messages/", app=sse_transport.handle_post_message),
    ]

    return routes, http_session_manager


async def _start_proxy_instance(
    stack: contextlib.AsyncExitStack,
    params: StdioServerParameters,
    *,
    stateless: bool,
) -> list[BaseRoute]:
    """Spawn one stdio backend, wrap it in a proxy server, and return its routes.

    The stdio client, the client session and the HTTP session manager's
    ``run()`` context are all entered on *stack*, so they are torn down when
    the caller's exit stack closes.
    """
    stdio_streams = await stack.enter_async_context(stdio_client(params))
    session = await stack.enter_async_context(ClientSession(*stdio_streams))
    proxy = await create_proxy_server(session)
    routes, http_manager = create_single_instance_routes(
        proxy,
        stateless_instance=stateless,
    )
    await stack.enter_async_context(http_manager.run())  # Manage lifespan by calling run()
    return routes


def _build_middleware(
    mcp_settings: MCPServerSettings,
    named_servers: dict[str, tuple[StdioServerParameters, bool]],
) -> list[Middleware]:
    """Assemble the Starlette middleware list (first entry = outermost layer).

    Auth semantics when ``api_token`` is set:

    * If any named server has ``auth_required=True``, only the
      ``/servers/<name>/`` prefixes of those servers are checked.
    * Otherwise the token is required on every non-public path.

    NOTE(review): in the second case a named server with
    ``auth_required=False`` is still protected, and in the first case the
    default server's routes are not; confirm this is the intended contract.
    """
    middleware: list[Middleware] = []

    # CORS goes first (outermost): browser preflight OPTIONS requests never
    # carry an Authorization header, so they must be answered before the auth
    # check, and 401/403 responses from the auth layer then still receive CORS
    # headers that let browser clients read them.
    if mcp_settings.allow_origins:
        middleware.append(
            Middleware(
                CORSMiddleware,
                allow_origins=mcp_settings.allow_origins,
                allow_methods=["*"],
                allow_headers=["*"],
            ),
        )

    if mcp_settings.api_token:
        auth_required_paths = {
            f"/servers/{name}/"
            for name, (_, auth_required) in named_servers.items()
            if auth_required
        }

        # If we have specific auth-required paths, use them; otherwise use global auth (for backward compatibility)
        if auth_required_paths:
            middleware.append(
                Middleware(
                    AuthorizationMiddleware,
                    api_token=mcp_settings.api_token,
                    auth_required_paths=auth_required_paths,
                ),
            )
            # sorted() keeps the log line deterministic (set order is not).
            logger.info("API token authentication enabled for paths: %s", ", ".join(sorted(auth_required_paths)))
        else:
            middleware.append(
                Middleware(
                    AuthorizationMiddleware,
                    api_token=mcp_settings.api_token,
                ),
            )
            logger.info("API token authentication enabled globally")

    return middleware


def _log_sse_urls(
    mcp_settings: MCPServerSettings,
    *,
    include_default: bool,
    named_servers: dict[str, tuple[StdioServerParameters, bool]],
) -> None:
    """Log the SSE endpoint URL of every configured server at INFO level."""
    base_url = f"http://{mcp_settings.bind_host}:{mcp_settings.port}"
    sse_urls: list[str] = [f"{base_url}/sse"] if include_default else []
    sse_urls.extend(f"{base_url}/servers/{name}/sse" for name in named_servers)

    # Emitted through the module logger (not print), so it follows log config.
    if sse_urls:
        logger.info("Serving MCP Servers via SSE:")
        for url in sse_urls:
            logger.info(" - %s", url)


async def run_mcp_server(
    mcp_settings: MCPServerSettings,
    default_server_params: StdioServerParameters | None = None,
    named_servers: dict[str, tuple[StdioServerParameters, bool]] | None = None,
) -> None:
    """Run stdio client(s) and expose an MCP server with multiple possible backends.

    The default backend (if any) is served at the root (``/sse``, ``/mcp``,
    ``/messages/``); each named backend is mounted under ``/servers/<name>``.
    ``/status`` is always available. Returns without serving (after logging an
    error) when no backend is configured.

    Args:
        mcp_settings: Server configuration
        default_server_params: Optional default server parameters
        named_servers: Dict mapping name to (params, auth_required) tuples
    """
    if named_servers is None:
        named_servers = {}

    all_routes: list[BaseRoute] = [
        Route("/status", endpoint=_handle_status),  # Global status endpoint
    ]

    # Use AsyncExitStack to manage lifecycles of multiple components
    async with contextlib.AsyncExitStack() as stack:

        @contextlib.asynccontextmanager
        async def combined_lifespan(_app: Starlette) -> AsyncIterator[None]:
            logger.info("Main application lifespan starting...")
            # All http_session_managers' .run() are already entered into the stack
            yield
            logger.info("Main application lifespan shutting down...")

        # Setup default server if configured
        if default_server_params:
            logger.info(
                "Setting up default server: %s %s",
                default_server_params.command,
                " ".join(default_server_params.args),
            )
            all_routes.extend(
                await _start_proxy_instance(
                    stack,
                    default_server_params,
                    stateless=mcp_settings.stateless,
                ),
            )
            _global_status["server_instances"]["default"] = "configured"

        # Setup named servers, each mounted under /servers/<name>/
        for name, (params, _auth_required) in named_servers.items():
            logger.info(
                "Setting up named server '%s': %s %s",
                name,
                params.command,
                " ".join(params.args),
            )
            named_routes = await _start_proxy_instance(
                stack,
                params,
                stateless=mcp_settings.stateless,
            )
            all_routes.append(Mount(f"/servers/{name}", routes=named_routes))
            _global_status["server_instances"][name] = "configured"

        if not default_server_params and not named_servers:
            logger.error("No servers configured to run.")
            return

        starlette_app = Starlette(
            debug=(mcp_settings.log_level == "DEBUG"),
            routes=all_routes,
            middleware=_build_middleware(mcp_settings, named_servers),
            lifespan=combined_lifespan,
        )

        config = uvicorn.Config(
            starlette_app,
            host=mcp_settings.bind_host,
            port=mcp_settings.port,
            log_level=mcp_settings.log_level.lower(),
        )
        http_server = uvicorn.Server(config)

        # Print out the SSE URLs for all configured servers
        _log_sse_urls(
            mcp_settings,
            include_default=bool(default_server_params),
            named_servers=named_servers,
        )

        logger.debug(
            "Serving incoming MCP requests on %s:%s",
            mcp_settings.bind_host,
            mcp_settings.port,
        )
        await http_server.serve()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/gws8820/secure-mcp-proxy'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.