"""HTTP transport wrapper for the Stepstone MCP server.

This module exposes a Starlette application that serves the existing
``stepstone_server`` MCP server (otherwise run over STDIO) through the MCP
Streamable HTTP transport. The adapter lets browsers and remote clients talk
to the server while satisfying the CORS requirements enforced by hosted MCP
environments.
"""
from __future__ import annotations
import logging
import os
from contextlib import asynccontextmanager
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route
from starlette.types import Receive, Scope, Send
import stepstone_server
# Module-level logger. It uses an explicit name rather than __name__, so log
# configuration must target "stepstone-http-server" to adjust this output.
logger = logging.getLogger("stepstone-http-server")
def _cors_headers(
    origin: bytes | None = None,
    allow_headers: bytes | None = None,
) -> list[tuple[bytes, bytes]]:
    """Build the CORS headers applied to every HTTP response.

    Args:
        origin: Raw value of the request's ``Origin`` header, if any. When
            present (non-empty) it is echoed back in
            ``Access-Control-Allow-Origin`` and credentials are allowed;
            otherwise the wildcard ``*`` is used.
        allow_headers: Raw value to advertise in
            ``Access-Control-Allow-Headers`` (callers pass the preflight's
            ``Access-Control-Request-Headers``); defaults to ``*``.

    Returns:
        A list of ``(name, value)`` byte pairs in ASGI header format.
    """
    headers: list[tuple[bytes, bytes]] = [
        (b"access-control-allow-origin", origin or b"*"),
        (b"access-control-allow-methods", b"GET,POST,DELETE,OPTIONS"),
        (b"access-control-allow-headers", allow_headers or b"*"),
        (b"access-control-expose-headers", b"mcp-session-id"),
    ]
    if origin:
        # NOTE(review): any Origin is reflected together with
        # allow-credentials, which lets every site make credentialed requests
        # and read the responses. Confirm this is intended, or add an allowlist.
        headers.append((b"access-control-allow-credentials", b"true"))
    # The allow-origin value depends on the request's Origin header (echoed
    # vs "*"). Shared caches must therefore key on Origin for every response,
    # including responses to requests that carried no Origin at all.
    headers.append((b"vary", b"origin"))
    return headers
def _cors_preflight_headers(
    origin: bytes | None,
    allow_headers: bytes | None,
) -> list[tuple[bytes, bytes]]:
    """Build the header list for a CORS preflight (OPTIONS) response.

    Starts from the regular CORS headers and adds an
    ``Access-Control-Max-Age`` of 86400 seconds (one day), so browsers may
    reuse the preflight result.
    """
    max_age = (b"access-control-max-age", b"86400")
    return [*_cors_headers(origin, allow_headers), max_age]
def _merge_headers(
    existing: list[tuple[bytes, bytes]], additions: list[tuple[bytes, bytes]]
) -> list[tuple[bytes, bytes]]:
    """Merge HTTP headers, letting ``additions`` override ``existing``.

    Header names are compared case-insensitively. A header in ``additions``
    replaces *all* headers of the same name in ``existing``. The replacement
    takes the position of the first header it replaces; names that are new
    are appended in order. If ``additions`` repeats a name, its last value
    wins.

    Headers in ``existing`` that are not overridden are kept verbatim,
    including repeated fields such as several ``Set-Cookie`` lines, which
    must not be collapsed into one.

    Returns:
        A new list of ``(name, value)`` tuples; the inputs are not modified.
    """
    overrides: dict[bytes, tuple[bytes, bytes]] = {}
    for key, value in additions:
        overrides[key.lower()] = (key, value)

    merged: list[tuple[bytes, bytes]] = []
    placed: set[bytes] = set()
    for key, value in existing:
        lowered = key.lower()
        if lowered not in overrides:
            merged.append((key, value))
        elif lowered not in placed:
            # First occurrence of an overridden name: put the replacement here
            # and drop any later occurrences of the same name.
            merged.append(overrides[lowered])
            placed.add(lowered)

    merged.extend(
        header for lowered, header in overrides.items() if lowered not in placed
    )
    return merged
def _update_accept_header(headers: list[tuple[bytes, bytes]]) -> None:
    """Ensure the Accept header explicitly lists both JSON and SSE.

    ``headers`` is modified in place. If no ``Accept`` header exists, one is
    appended. Otherwise, any required media type missing from the *first*
    ``Accept`` header is appended to its value.

    Only an exact media-type match (ignoring parameters such as ``;q=0.9``)
    counts as present. Wildcards like ``*/*``, ``application/*`` or
    ``text/*`` do not. The MCP Streamable HTTP transport appears to check for
    the literal ``application/json`` / ``text/event-stream`` entries (TODO
    confirm against the installed ``mcp`` version). Appending an explicit
    entry next to a wildcard does not change what the client accepts.
    """
    required_types = ("application/json", "text/event-stream")
    header_index = next(
        (idx for idx, (name, _) in enumerate(headers) if name.lower() == b"accept"),
        None,
    )
    if header_index is None:
        headers.append((b"accept", ", ".join(required_types).encode("latin-1")))
        return

    name, value = headers[header_index]
    media_types = [
        token.strip() for token in value.decode("latin-1").split(",") if token.strip()
    ]
    # Compare the bare media type only; parameters such as q-values are ignored.
    present = {token.split(";", 1)[0].strip() for token in media_types}
    missing = [media_type for media_type in required_types if media_type not in present]
    if missing:
        new_value = ", ".join([*media_types, *missing]).encode("latin-1")
        headers[header_index] = (name, new_value)
def _ensure_required_headers(scope: Scope) -> Scope:
    """Return a shallow copy of ``scope`` with normalised request headers.

    The original scope is left untouched. In the copy:

    * the header list is rebuilt as ``(name, value)`` tuples;
    * the Accept header is adjusted by ``_update_accept_header``;
    * a POST request with no Content-Type header gets
      ``content-type: application/json`` appended.
    """
    normalised = [(name, value) for name, value in scope.get("headers", [])]
    _update_accept_header(normalised)

    is_post = scope.get("method", "").upper() == "POST"
    if is_post and all(name.lower() != b"content-type" for name, _ in normalised):
        normalised.append((b"content-type", b"application/json"))

    return {**scope, "headers": normalised}
def create_app() -> Starlette:
    """Build the Starlette application that serves the MCP server over HTTP.

    Routes:
        ``/``                     -- JSON discovery document (GET) and CORS
                                     preflight (OPTIONS).
        ``/health``, ``/healthz`` -- health checks returning ``{"status": "ok"}``.
        ``/mcp``, ``/mcp/``       -- MCP endpoint handled by a
                                     ``StreamableHTTPSessionManager``
                                     (GET/POST/DELETE/OPTIONS).

    The session manager wraps ``stepstone_server.server`` with
    ``json_response=True`` and ``stateless=False``. It is started and stopped
    by the application lifespan.
    """
    sessions = StreamableHTTPSessionManager(
        stepstone_server.server,
        json_response=True,
        stateless=False,
    )

    def first_header(scope: Scope, wanted: bytes) -> bytes | None:
        """Return the first raw value in ``scope`` whose lower-cased name is ``wanted``."""
        for raw_name, raw_value in scope.get("headers", []):
            if raw_name.lower() == wanted:
                return raw_value
        return None

    def apply_headers(
        response: Response, pairs: list[tuple[bytes, bytes]], encoding: str
    ) -> Response:
        """Copy byte header pairs onto ``response``, decoding values with ``encoding``."""
        for raw_name, raw_value in pairs:
            response.headers[raw_name.decode("ascii")] = raw_value.decode(encoding)
        return response

    class StreamableEndpoint:
        """ASGI app that puts CORS handling in front of the session manager."""

        def __init__(self, manager: StreamableHTTPSessionManager):
            self._manager = manager

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            if scope["type"] != "http":
                raise RuntimeError("Unsupported ASGI scope type")

            origin = first_header(scope, b"origin")
            requested = first_header(scope, b"access-control-request-headers")

            if scope.get("method", "GET").upper() == "OPTIONS":
                # Preflights are answered here and never reach the session manager.
                logger.debug("Handling CORS preflight for %s", scope.get("path"))
                await send(
                    {
                        "type": "http.response.start",
                        "status": 204,
                        "headers": _cors_preflight_headers(origin, requested),
                    }
                )
                await send({"type": "http.response.body", "body": b""})
                return

            async def send_with_cors(message):
                # Only the response-start message carries headers; body
                # messages are forwarded unchanged.
                if message["type"] == "http.response.start":
                    message["headers"] = _merge_headers(
                        message.setdefault("headers", []), _cors_headers(origin)
                    )
                await send(message)

            await self._manager.handle_request(
                _ensure_required_headers(scope), receive, send_with_cors
            )

    async def homepage(request: Request):
        origin = request.headers.get("origin")
        requested = request.headers.get("access-control-request-headers")
        origin_raw = origin.encode("latin-1") if origin else None
        requested_raw = requested.encode("latin-1") if requested else None

        if request.method != "OPTIONS":
            # Advertise the absolute MCP URL, because some hosted scanners only
            # read that, plus the relative path kept for backwards compatibility.
            payload = {
                "status": "ok",
                "message": "Stepstone MCP HTTP endpoint",
                "endpoints": {"mcp": str(request.url_for("mcp")), "mcpPath": "/mcp"},
            }
            return apply_headers(
                JSONResponse(payload), _cors_headers(origin_raw), "latin-1"
            )

        return apply_headers(
            Response(status_code=204),
            _cors_preflight_headers(origin_raw, requested_raw),
            "latin-1",
        )

    async def healthcheck(request: Request):
        return apply_headers(JSONResponse({"status": "ok"}), _cors_headers(), "ascii")

    @asynccontextmanager
    async def lifespan(app):
        # The session manager's task group lives for the whole application.
        async with sessions.run():
            yield

    endpoint = StreamableEndpoint(sessions)
    mcp_methods = ["GET", "POST", "DELETE", "OPTIONS"]
    return Starlette(
        routes=[
            Route("/", homepage, methods=["GET", "OPTIONS"]),
            Route("/health", healthcheck, methods=["GET"]),
            Route("/healthz", healthcheck, methods=["GET"]),
            Route("/mcp", endpoint, methods=mcp_methods, name="mcp"),
            Route("/mcp/", endpoint, methods=mcp_methods, name="mcp_slash"),
        ],
        lifespan=lifespan,
    )
# Module-level ASGI application, built at import time. It is what main() hands
# to uvicorn; presumably it is also exposed so external ASGI servers can import
# it directly (e.g. ``module:app``).
app = create_app()
def main() -> None:
    """Run the HTTP server with uvicorn.

    Configuration comes from environment variables:

    * ``HOST`` -- bind address (default ``0.0.0.0``).
    * ``PORT`` -- TCP port (default ``8000``); must be an integer.
    * ``LOG_LEVEL`` -- uvicorn log level (default ``info``), case-insensitive.

    Raises:
        ValueError: If ``PORT`` is not a valid integer.
    """
    import uvicorn

    host = os.environ.get("HOST", "0.0.0.0")
    raw_port = os.environ.get("PORT", "8000")
    try:
        port = int(raw_port)
    except ValueError as exc:
        raise ValueError(f"PORT must be an integer, got {raw_port!r}") from exc
    # uvicorn.run() expects lowercase level names ("info", "debug", ...),
    # whereas logging.basicConfig in the __main__ guard needs uppercase ones.
    # Lowercasing here lets a single LOG_LEVEL value serve both.
    log_level = os.environ.get("LOG_LEVEL", "info").lower()
    logger.info("Starting Stepstone MCP HTTP server on %s:%s", host, port)
    uvicorn.run(app, host=host, port=port, log_level=log_level)
if __name__ == "__main__":
    # logging.basicConfig only accepts uppercase level names ("INFO"), so
    # normalise LOG_LEVEL in case it is given in lower case (e.g. "info").
    logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO").upper())
    main()