"""Create a local HTTP server (SSE and Streamable HTTP) that proxies requests to one or more stdio MCP servers."""
import contextlib
import hmac
import json
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal

import uvicorn
from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.server import Server as MCPServerSDK  # Renamed to avoid conflict
from mcp.server.sse import SseServerTransport
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import BaseRoute, Mount, Route
from starlette.types import ASGIApp, Receive, Scope, Send

from .proxy_server import create_proxy_server
# Logger shared by the auth middleware, route factory and server runner in this module.
logger = logging.getLogger(__name__)
class AuthorizationMiddleware:
    """ASGI middleware that enforces ``Authorization: Bearer <token>`` checks.

    A request is passed straight to the wrapped app when any of these hold:

    * it is not an HTTP request (e.g. WebSocket);
    * no API token is configured;
    * its path is one of ``public_paths``;
    * ``auth_required_paths`` is given and the path starts with none of those prefixes.

    Every other request must present a Bearer token equal to ``api_token``.
    Otherwise a JSON error is sent (401 for a missing or malformed header, 403
    for a wrong token) and the wrapped app is never called.
    """

    def __init__(
        self,
        app: ASGIApp,
        api_token: str | None = None,
        public_paths: set[str] | None = None,
        auth_required_paths: set[str] | None = None,
    ) -> None:
        """Wrap ``app`` with token authorization.

        Args:
            app: Downstream ASGI application.
            api_token: Expected bearer token; ``None`` or empty disables checking.
            public_paths: Exact paths that never require a token. ``None``
                selects the defaults below; an explicit empty set means
                "no public paths".
            auth_required_paths: If given, only paths starting with one of
                these prefixes are checked; all other paths pass through.
        """
        self.app = app
        self.api_token = api_token
        # `is not None` (not `or`) so an explicitly empty set is honored
        # instead of being silently replaced by the defaults.
        self.public_paths = (
            public_paths
            if public_paths is not None
            else {
                "/status",  # Server status check
                "/docs",  # API documentation
                "/openapi.json",  # OpenAPI specification
            }
        )
        # Paths that require authentication (if specified, only these prefixes are checked).
        self.auth_required_paths = auth_required_paths

    def _requires_auth(self, scope: Scope) -> bool:
        """Return True if the request described by ``scope`` must present a token."""
        if scope["type"] != "http":
            # Non-HTTP requests (like WebSocket) pass through.
            return False
        if not self.api_token:
            # Authorization is disabled when no token is configured.
            return False
        path = scope.get("path", "")
        if path in self.public_paths:
            return False
        if self.auth_required_paths is not None:
            return any(path.startswith(prefix) for prefix in self.auth_required_paths)
        return True

    def _token_matches(self, token: str) -> bool:
        """Compare ``token`` with the configured token in constant time."""
        expected = self.api_token or ""
        # compare_digest rejects non-ASCII str arguments with TypeError, so
        # compare encoded bytes instead. The same injective encoding on both
        # sides preserves str equality.
        return hmac.compare_digest(
            token.encode("utf-8", "surrogatepass"),
            expected.encode("utf-8", "surrogatepass"),
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Forward the request to the wrapped app, or reject it with 401/403."""
        if not self._requires_auth(scope):
            await self.app(scope, receive, send)
            return

        path = scope.get("path", "")
        # Lookup relies on ASGI servers delivering lower-cased header names as bytes.
        headers = dict(scope.get("headers", []))
        auth_header = headers.get(b"authorization")
        if not auth_header:
            logger.warning("Missing Authorization header for %s", path)
            await self._send_error_response(
                send,
                401,
                {
                    "error": "unauthorized",
                    "error_description": "Authorization header required",
                },
            )
            return

        auth_header_str = auth_header.decode("latin1")
        # The auth-scheme is case-insensitive (RFC 7235 section 2.1); everything
        # after the first space is the credential.
        scheme, separator, token = auth_header_str.partition(" ")
        if not separator or scheme.lower() != "bearer":
            logger.warning("Invalid Authorization header format for %s", path)
            await self._send_error_response(
                send,
                401,
                {
                    "error": "invalid_request",
                    "error_description": "Invalid authorization format. Use 'Bearer <token>'",
                },
            )
            return

        if not self._token_matches(token):
            logger.warning("Invalid API token provided for %s", path)
            await self._send_error_response(
                send,
                403,
                {
                    "error": "invalid_token",
                    "error_description": "Invalid API token",
                },
            )
            return

        # Proceed with request if token is valid.
        logger.debug("Valid API token provided for %s", path)
        await self.app(scope, receive, send)

    async def _send_error_response(self, send: Send, status_code: int, error_data: dict[str, Any]) -> None:
        """Send a complete JSON error response (start + body) via raw ASGI messages.

        Args:
            send: ASGI send callable for the current request.
            status_code: HTTP status to return.
            error_data: JSON-serializable error payload.
        """
        response_body = json.dumps(error_data).encode("utf-8")
        headers = [
            [b"content-type", b"application/json"],
            [b"content-length", str(len(response_body)).encode("ascii")],
            [b"cache-control", b"no-store"],
        ]
        # Add WWW-Authenticate header for 401 responses.
        if status_code == 401:
            headers.append([
                b"www-authenticate",
                b'Bearer realm="MCP Proxy", error="invalid_token", error_description="Token authentication required"',
            ])
        await send({
            "type": "http.response.start",
            "status": status_code,
            "headers": headers,
        })
        await send({
            "type": "http.response.body",
            "body": response_body,
        })
@dataclass
class MCPServerSettings:
"""Settings for the MCP server."""
bind_host: str
port: int
stateless: bool = False
allow_origins: list[str] | None = None
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
api_token: str | None = None
# To store last activity for multiple servers if needed, though status endpoint is global for now.
_global_status: dict[str, Any] = {
"api_last_activity": datetime.now(timezone.utc).isoformat(),
"server_instances": {}, # Could be used to store per-instance status later
}
def _update_global_activity() -> None:
    """Stamp the shared status dict with the current UTC time (ISO-8601) as the latest API activity."""
    timestamp = datetime.now(tz=timezone.utc).isoformat()
    _global_status["api_last_activity"] = timestamp
async def _handle_status(_: Request) -> Response:
    """Serve the shared status snapshot (last activity, configured instances) as JSON.

    Acts as a global health check / usage monitor; the incoming request is ignored.
    """
    snapshot = _global_status
    return JSONResponse(snapshot)
def create_single_instance_routes(
    mcp_server_instance: MCPServerSDK[object],
    *,
    stateless_instance: bool,
) -> tuple[list[BaseRoute], StreamableHTTPSessionManager]:  # Return the manager itself
    """Create Starlette routes and the HTTP session manager for one MCP server instance.

    The returned routes expose ``mcp_server_instance`` over both transports:

    * ``/mcp``: Streamable HTTP, served by the returned session manager.
    * ``/sse``: SSE stream endpoint.
    * ``/messages/``: POST endpoint for client-to-server SSE messages.

    Every request on these routes refreshes the global "last activity" timestamp.

    Args:
        mcp_server_instance: The (proxy) MCP server to expose.
        stateless_instance: Passed to the session manager as ``stateless``.

    Returns:
        ``(routes, manager)``. This function does not start the manager. The
        caller is expected to enter ``manager.run()`` (as ``run_mcp_server``
        does) before serving ``/mcp`` requests.
    """
    logger.debug(
        "Creating routes for a single MCP server instance (stateless: %s)",
        stateless_instance,
    )

    sse_transport = SseServerTransport("/messages/")
    http_session_manager = StreamableHTTPSessionManager(
        app=mcp_server_instance,
        event_store=None,
        json_response=True,
        stateless=stateless_instance,
    )

    async def handle_sse_instance(request: Request) -> Response:
        # The transport streams through the raw ASGI send callable for the
        # whole lifetime of the MCP session.
        async with sse_transport.connect_sse(
            request.scope,
            request.receive,
            request._send,  # noqa: SLF001
        ) as (read_stream, write_stream):
            _update_global_activity()
            await mcp_server_instance.run(
                read_stream,
                write_stream,
                mcp_server_instance.create_initialization_options(),
            )
        # Return empty response after SSE connection completes.
        return Response(content="", status_code=200)

    async def handle_sse_messages_instance(scope: Scope, receive: Receive, send: Send) -> None:
        # Client->server SSE messages are API activity too. Record them the
        # same way as Streamable HTTP requests so /status stays accurate for
        # SSE clients.
        _update_global_activity()
        await sse_transport.handle_post_message(scope, receive, send)

    async def handle_streamable_http_instance(scope: Scope, receive: Receive, send: Send) -> None:
        _update_global_activity()
        await http_session_manager.handle_request(scope, receive, send)

    routes = [
        Mount("/mcp", app=handle_streamable_http_instance),
        Route("/sse", endpoint=handle_sse_instance),
        Mount("/messages/", app=handle_sse_messages_instance),
    ]
    return routes, http_session_manager
async def _start_proxied_instance(
    stack: contextlib.AsyncExitStack,
    params: StdioServerParameters,
    *,
    stateless: bool,
) -> list[BaseRoute]:
    """Connect to one stdio MCP server, wrap it in a proxy and return its HTTP routes.

    The stdio client, client session and HTTP session manager ``run()``
    contexts are all entered on ``stack``, so they close when the stack closes.
    """
    stdio_streams = await stack.enter_async_context(stdio_client(params))
    session = await stack.enter_async_context(ClientSession(*stdio_streams))
    proxy = await create_proxy_server(session)
    instance_routes, http_manager = create_single_instance_routes(
        proxy,
        stateless_instance=stateless,
    )
    await stack.enter_async_context(http_manager.run())  # Manage lifespan by calling run()
    return instance_routes


def _build_middleware(
    mcp_settings: MCPServerSettings,
    named_servers: dict[str, tuple[StdioServerParameters, bool]],
) -> list[Middleware]:
    """Return the application's middleware list, outermost layer first.

    Auth scoping when ``api_token`` is set:

    * If any named server is flagged ``auth_required``, only the
      ``/servers/<name>/`` prefixes of those servers are protected. The default
      server and unflagged named servers stay open.
    * Otherwise every non-public path requires the token (global auth, kept
      for backward compatibility).
    """
    middleware: list[Middleware] = []

    # CORS must come FIRST: Starlette makes the first list entry the outermost
    # layer. Browser CORS preflight (OPTIONS) requests never carry an
    # Authorization header. With auth outside CORS they were rejected with 401,
    # so cross-origin clients could never authenticate, and auth errors lacked
    # CORS headers.
    if mcp_settings.allow_origins:
        middleware.append(
            Middleware(
                CORSMiddleware,
                allow_origins=mcp_settings.allow_origins,
                allow_methods=["*"],
                allow_headers=["*"],
            ),
        )

    if mcp_settings.api_token:
        auth_required_paths = {
            f"/servers/{name}/"
            for name, (_, auth_required) in named_servers.items()
            if auth_required
        }
        if auth_required_paths:
            middleware.append(
                Middleware(
                    AuthorizationMiddleware,
                    api_token=mcp_settings.api_token,
                    auth_required_paths=auth_required_paths,
                ),
            )
            logger.info(
                "API token authentication enabled for paths: %s",
                ", ".join(sorted(auth_required_paths)),
            )
        else:
            middleware.append(
                Middleware(
                    AuthorizationMiddleware,
                    api_token=mcp_settings.api_token,
                ),
            )
            logger.info("API token authentication enabled globally")

    return middleware


def _log_sse_urls(
    mcp_settings: MCPServerSettings,
    *,
    has_default_server: bool,
    named_servers: dict[str, tuple[StdioServerParameters, bool]],
) -> None:
    """Log the SSE endpoint URL of every configured server for user visibility."""
    base_url = f"http://{mcp_settings.bind_host}:{mcp_settings.port}"
    sse_urls: list[str] = []
    if has_default_server:
        sse_urls.append(f"{base_url}/sse")
    sse_urls.extend(f"{base_url}/servers/{name}/sse" for name in named_servers)

    if sse_urls:
        logger.info("Serving MCP Servers via SSE:")
        for url in sse_urls:
            logger.info(" - %s", url)


async def run_mcp_server(
    mcp_settings: MCPServerSettings,
    default_server_params: StdioServerParameters | None = None,
    named_servers: dict[str, tuple[StdioServerParameters, bool]] | None = None,
) -> None:
    """Run stdio client(s) and expose an MCP server with multiple possible backends.

    Each stdio server is proxied over SSE (``/sse`` + ``/messages/``) and
    Streamable HTTP (``/mcp``). The default server is mounted at the root and
    named servers under ``/servers/<name>/``. A global ``/status`` endpoint is
    always present. Returns once the HTTP server stops; every stdio client,
    session and session manager is closed on the way out.

    Args:
        mcp_settings: Server configuration.
        default_server_params: Optional default server parameters.
        named_servers: Dict mapping name to (params, auth_required) tuples.
    """
    if named_servers is None:
        named_servers = {}

    if not default_server_params and not named_servers:
        logger.error("No servers configured to run.")
        return

    all_routes: list[BaseRoute] = [
        Route("/status", endpoint=_handle_status),  # Global status endpoint
    ]

    @contextlib.asynccontextmanager
    async def combined_lifespan(_app: Starlette) -> AsyncIterator[None]:
        # The session managers' run() contexts are already entered on the exit
        # stack below, so the app lifespan itself only logs start/stop.
        logger.info("Main application lifespan starting...")
        yield
        logger.info("Main application lifespan shutting down...")

    # AsyncExitStack owns the lifecycles of all stdio clients, sessions and managers.
    async with contextlib.AsyncExitStack() as stack:
        if default_server_params:
            logger.info(
                "Setting up default server: %s %s",
                default_server_params.command,
                " ".join(default_server_params.args),
            )
            all_routes.extend(
                await _start_proxied_instance(
                    stack,
                    default_server_params,
                    stateless=mcp_settings.stateless,
                ),
            )
            _global_status["server_instances"]["default"] = "configured"

        for name, (params, _auth_required) in named_servers.items():
            logger.info(
                "Setting up named server '%s': %s %s",
                name,
                params.command,
                " ".join(params.args),
            )
            instance_routes = await _start_proxied_instance(
                stack,
                params,
                stateless=mcp_settings.stateless,
            )
            # Mount these routes under /servers/<name>/
            all_routes.append(Mount(f"/servers/{name}", routes=instance_routes))
            _global_status["server_instances"][name] = "configured"

        starlette_app = Starlette(
            debug=(mcp_settings.log_level == "DEBUG"),
            routes=all_routes,
            middleware=_build_middleware(mcp_settings, named_servers),
            lifespan=combined_lifespan,
        )

        config = uvicorn.Config(
            starlette_app,
            host=mcp_settings.bind_host,
            port=mcp_settings.port,
            log_level=mcp_settings.log_level.lower(),
        )
        http_server = uvicorn.Server(config)

        _log_sse_urls(
            mcp_settings,
            has_default_server=bool(default_server_params),
            named_servers=named_servers,
        )
        logger.debug(
            "Serving incoming MCP requests on %s:%s",
            mcp_settings.bind_host,
            mcp_settings.port,
        )
        await http_server.serve()