"""FastAPI application with SSE transport for MCP."""
import asyncio
import logging
import signal
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from mcp.server.sse import SseServerTransport
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from starlette.types import Receive, Scope, Send
from jana_mcp import __version__
from pydantic import SecretStr
from jana_mcp.client import JanaClient
from jana_mcp.config import Settings, get_settings
from jana_mcp.constants import (
HTTP_SERVICE_UNAVAILABLE,
HTTP_TOO_MANY_REQUESTS,
RATE_LIMIT_DEFAULT,
)
from jana_mcp.server import create_mcp_server, create_mcp_server_with_client
from jana_mcp.utils.logging import setup_logging
if TYPE_CHECKING:
from mcp.server import Server
logger = logging.getLogger(__name__)

# Module-wide slowapi limiter that buckets requests by the client's remote
# address; create_app() attaches it to each application's state.
limiter: Limiter = Limiter(key_func=get_remote_address)
def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Turn a slowapi ``RateLimitExceeded`` into a JSON error response.

    Logs the client address and request path at WARNING level, then answers
    with status ``HTTP_TOO_MANY_REQUESTS`` and a body carrying the limit detail.
    """
    client_address = get_remote_address(request)
    logger.warning("Rate limit exceeded for %s on %s", client_address, request.url.path)
    body = {"error": "Rate limit exceeded", "detail": str(exc.detail)}
    return JSONResponse(status_code=HTTP_TOO_MANY_REQUESTS, content=body)
def _extract_auth_token(scope: Scope) -> tuple[str | None, str]:
"""
Extract authentication token from request.
Checks in order:
1. Authorization header (Token or Bearer prefix)
2. Query parameter (?token=xxx)
Args:
scope: ASGI scope containing request headers and query string
Returns:
Tuple of (token string or None, auth method used)
"""
from urllib.parse import parse_qsl
# First, try Authorization header (preferred, more secure)
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if auth_header.startswith("Token "):
return auth_header[6:], "header" # Remove "Token " prefix
elif auth_header.startswith("Bearer "):
return auth_header[7:], "header" # Remove "Bearer " prefix
# Fallback: check query parameters (less secure but works with all clients)
query_string = scope.get("query_string", b"").decode()
if query_string:
params = dict(parse_qsl(query_string))
token = params.get("token")
if token:
return token, "query_param"
return None, "none"
def _create_session_client(
    user_token: str | None,
    settings: Settings,
    shared_client: JanaClient,
) -> tuple[JanaClient, bool]:
    """
    Choose the backend client used by one SSE session.

    Args:
        user_token: Token supplied by the caller, or None/empty if absent
        settings: Application settings to copy connection options from
        shared_client: Process-wide client used when no token was supplied

    Returns:
        ``(client, owned)`` where ``owned`` is True only when a new client was
        built for this session; the caller must close such a client.
    """
    if not user_token:
        # No caller credentials: reuse the shared client (backward compatibility).
        logger.debug("Using shared client (no user token provided)")
        return shared_client, False

    # Same backend connection options as the app, but authenticated as the caller.
    per_session_settings = Settings(
        jana_backend_url=settings.jana_backend_url,
        jana_token=SecretStr(user_token),
        jana_timeout=settings.jana_timeout,
        jana_host_header=settings.jana_host_header,
    )
    logger.debug("Creating per-session client with user token")
    return JanaClient(per_session_settings), True
async def _graceful_shutdown(
    app: FastAPI,
    sig: signal.Signals,
    *,
    drain_timeout: float = 30.0,
) -> None:
    """Drain tracked SSE connections after a shutdown signal.

    Sets ``app.state.shutdown_event`` (when present) so new SSE connections are
    rejected, then waits up to ``drain_timeout`` seconds for the tasks in
    ``app.state.active_connections`` to finish. Tasks still unfinished after
    the timeout are cancelled.

    NOTE(review): this coroutine only drains connections; it does not itself
    stop the server or exit the process.

    Args:
        app: Application whose ``state`` holds the shutdown event and the set
            of active connection tasks.
        sig: Signal that triggered the shutdown (used for logging).
        drain_timeout: Seconds to wait for connections to close before
            cancelling them.
    """
    logger.info("Received signal %s, initiating graceful shutdown...", sig.name)
    shutdown_event: asyncio.Event | None = getattr(app.state, "shutdown_event", None)
    if shutdown_event is not None:
        shutdown_event.set()

    # Snapshot the live set: connection handlers remove themselves from it as
    # they finish, so iterate a stable copy.
    pending: list[asyncio.Task[Any]] = list(getattr(app.state, "active_connections", set()))
    if pending:
        logger.info("Waiting for %d active connections to close...", len(pending))
        try:
            await asyncio.wait_for(
                asyncio.gather(*pending, return_exceptions=True),
                timeout=drain_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning("Timeout waiting for connections, forcing shutdown")
            # wait_for cancels the gather on timeout, which in turn cancels its
            # children; explicitly cancel anything still unfinished as a safeguard.
            for task in pending:
                if not task.done():
                    task.cancel()
    logger.info("Graceful shutdown complete")
def _create_shutdown_handler(app: FastAPI, sig: signal.Signals) -> Callable[[], None]:
    """Build the zero-argument callback registered for ``sig``.

    Each invocation schedules ``_graceful_shutdown(app, sig)`` as a new task on
    the running event loop.
    """
    # The event loop keeps only weak references to tasks, so a task whose
    # result is discarded can be garbage-collected before it finishes. Hold a
    # strong reference here until each shutdown task completes.
    running: set[asyncio.Task[None]] = set()

    def handler() -> None:
        task = asyncio.create_task(_graceful_shutdown(app, sig))
        running.add(task)
        task.add_done_callback(running.discard)

    return handler
def _is_uvicorn_reload_mode() -> bool:
"""Check if running under uvicorn's reload mode."""
import sys
# Check for --reload in command line args
return any("--reload" in arg for arg in sys.argv)
def _setup_signal_handlers(app: FastAPI, loop: asyncio.AbstractEventLoop) -> None:
    """Register SIGTERM and SIGINT callbacks on ``loop`` that start a graceful shutdown.

    Nothing is registered when ``_is_uvicorn_reload_mode()`` reports reload
    mode, so uvicorn's reloader keeps control of these signals. Any exception
    raised by ``loop.add_signal_handler`` propagates to the caller.
    """
    if not _is_uvicorn_reload_mode():
        for signum in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(signum, _create_shutdown_handler(app, signum))
        return
    logger.info("Skipping custom signal handlers (uvicorn reload mode)")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan handler for startup/shutdown.

    Startup:
    - configures logging;
    - creates per-app shutdown state;
    - installs signal handlers (best effort);
    - stores the shared MCP server, Jana client and SSE transport on ``app.state``.

    Shutdown: cancels tracked SSE connection tasks and closes the shared
    client. Shutdown runs in ``finally``, so it also happens when the lifespan
    is exited with an exception.
    """
    # Prefer settings stored on the app (e.g. the ones given to create_app) so
    # the shared client agrees with per-session clients; else read the env.
    settings: Settings | None = getattr(app.state, "settings", None)
    if settings is None:
        settings = get_settings()

    # Setup logging
    setup_logging(settings.log_level)
    logger.info("Starting Jana MCP Server v%s", __version__)
    logger.info("Backend URL: %s", settings.jana_backend_url)

    # Initialize shutdown event and active connections in app.state (not module-level!)
    # This ensures they're properly scoped per app instance and reset on reload
    app.state.shutdown_event = asyncio.Event()
    app.state.active_connections = set()

    # Setup signal handlers for graceful shutdown (best effort)
    try:
        loop = asyncio.get_running_loop()
        _setup_signal_handlers(app, loop)
        logger.info("Signal handlers configured for graceful shutdown")
    except (ValueError, RuntimeError, NotImplementedError) as e:
        # NotImplementedError: event loops without Unix signal support (e.g.
        # the Windows proactor loop) do not implement add_signal_handler.
        logger.warning("Could not setup signal handlers: %s", e)

    # Create MCP server and client
    mcp_server, jana_client = create_mcp_server(settings)

    # Create SSE transport
    sse_transport = SseServerTransport("/messages")

    # Store in app.state for access in endpoints
    app.state.mcp_server = mcp_server
    app.state.jana_client = jana_client
    app.state.sse_transport = sse_transport
    logger.info("MCP server initialized and ready")

    try:
        yield
    finally:
        # Cleanup
        logger.info("Shutting down Jana MCP Server")
        # Cancel any remaining active connections
        active_connections: set[asyncio.Task[Any]] = getattr(app.state, "active_connections", set())
        for task in active_connections:
            task.cancel()
        if jana_client:
            await jana_client.close()
class _MCPRouteMiddleware:
    """Pure ASGI middleware that diverts the MCP transport endpoints.

    HTTP ``GET /sse`` and ``POST /messages`` requests go to raw ASGI handlers,
    because the MCP SSE transport sends ASGI messages itself. Every other
    scope, including ``lifespan``, is passed to the wrapped application.
    """

    def __init__(
        self,
        app: Callable[[Scope, Receive, Send], Any],
        *,
        sse_handler: Callable[[Scope, Receive, Send], Any],
        messages_handler: Callable[[Scope, Receive, Send], Any],
    ) -> None:
        self.app = app
        self._sse_handler = sse_handler
        self._messages_handler = messages_handler

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":
            path = scope.get("path", "")
            method = scope.get("method", "GET")
            if path == "/sse" and method == "GET":
                await self._sse_handler(scope, receive, send)
                return
            if path == "/messages" and method == "POST":
                await self._messages_handler(scope, receive, send)
                return
        # All other requests go through FastAPI
        await self.app(scope, receive, send)


def create_app(settings: Settings | None = None) -> FastAPI:
    """
    Create the FastAPI application.

    ``GET /sse`` and ``POST /messages`` are served by raw ASGI handlers that are
    installed as middleware. Everything else is routed by FastAPI as usual.

    Args:
        settings: Application settings. If None, uses environment settings.
            The resolved settings are stored on ``app.state.settings``.

    Returns:
        Configured FastAPI application
    """
    if settings is None:
        settings = get_settings()

    app = FastAPI(
        title="Jana MCP Server",
        description="MCP Server for Jana Environmental Data Platform",
        version=__version__,
        lifespan=lifespan,
    )

    # Make the resolved settings reachable from the application object.
    app.state.settings = settings

    # Add rate limiter to app state
    app.state.limiter = limiter

    # Add rate limit exceeded handler
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)  # type: ignore[arg-type]

    # Root endpoint
    @app.get("/")
    @limiter.limit(RATE_LIMIT_DEFAULT)
    async def root(request: Request) -> dict[str, Any]:
        """Server information endpoint."""
        return {
            "name": "Jana MCP Server",
            "version": __version__,
            "description": "Environmental data access via Model Context Protocol",
            "endpoints": {
                "sse": "/sse",
                "messages": "/messages",
                "health": "/health",
            },
        }

    # Health check endpoint (no rate limit - monitoring needs access)
    @app.get("/health")
    async def health() -> dict[str, Any]:
        """Health check endpoint."""
        backend_status = "unknown"
        overall_status = "healthy"
        jana_client: JanaClient | None = getattr(app.state, "jana_client", None)
        if jana_client:
            try:
                health_result = await jana_client.check_health()
                backend_status = health_result.get("status", "unknown")
                if backend_status in ("unreachable", "unhealthy", "error"):
                    overall_status = "degraded"
            except httpx.RequestError as e:
                logger.warning("Backend health check failed: %s", e, exc_info=True)
                backend_status = "unreachable"
                overall_status = "degraded"
            except (KeyError, ValueError, TypeError) as e:
                logger.warning("Backend health check data error: %s", e, exc_info=True)
                backend_status = "error"
                overall_status = "degraded"
        active_connections: set[asyncio.Task[Any]] = getattr(app.state, "active_connections", set())
        return {
            "status": overall_status,
            "version": __version__,
            "active_connections": len(active_connections),
            "backend": backend_status,
        }

    # Raw ASGI handler for SSE connections
    # This bypasses FastAPI's response handling - MCP's connect_sse handles ASGI directly
    async def handle_sse(scope: Scope, receive: Receive, send: Send) -> None:
        """Raw ASGI handler for SSE endpoint.

        MCP's SseServerTransport.connect_sse() handles the entire ASGI lifecycle.
        We use a raw handler to avoid FastAPI trying to send a second response.

        Supports per-user authentication via the Authorization header or a
        ``token`` query parameter:
        - If user provides token: creates per-session client with user's credentials
        - If no token: uses shared client (backward compatibility)
        """
        logger.info("New SSE connection")
        sse_transport: SseServerTransport | None = getattr(app.state, "sse_transport", None)
        shared_client: JanaClient | None = getattr(app.state, "jana_client", None)
        if not sse_transport or not shared_client:
            logger.error("MCP server not initialized")
            response = JSONResponse(
                status_code=HTTP_SERVICE_UNAVAILABLE,
                content={"error": "MCP server not initialized"},
            )
            await response(scope, receive, send)
            return

        # Check for shutdown
        shutdown_event: asyncio.Event | None = getattr(app.state, "shutdown_event", None)
        if shutdown_event and shutdown_event.is_set():
            logger.info("Shutdown in progress, rejecting new connection")
            response = JSONResponse(
                status_code=HTTP_SERVICE_UNAVAILABLE,
                content={"error": "Server shutting down"},
            )
            await response(scope, receive, send)
            return

        # Extract user token from header or query param (per-user auth)
        user_token, auth_method = _extract_auth_token(scope)
        if user_token:
            if auth_method == "header":
                logger.info("Per-user auth via Authorization header")
            else:
                logger.info("Per-user auth via query parameter (less secure)")
        else:
            logger.debug("No user token, using shared client")

        # Create session-specific client (or use shared)
        session_client, is_session_client = _create_session_client(
            user_token, settings, shared_client
        )

        # Create per-session MCP server with this client. If that fails, close
        # a per-session client before propagating so it is not leaked.
        try:
            session_mcp_server = create_mcp_server_with_client(session_client)
        except Exception:
            if is_session_client:
                await session_client.close()
            raise

        # Track this connection so shutdown can wait for or cancel it
        active_connections: set[asyncio.Task[Any]] = getattr(app.state, "active_connections", set())
        current_task = asyncio.current_task()
        if current_task:
            active_connections.add(current_task)

        try:
            # MCP's connect_sse handles the entire SSE response lifecycle
            # It sends ASGI messages directly - no response object needed
            logger.info("Starting MCP SSE connection...")
            async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
                logger.info("SSE connected, starting MCP server run...")
                await session_mcp_server.run(
                    read_stream,
                    write_stream,
                    session_mcp_server.create_initialization_options(),
                )
                logger.info("MCP server run completed")
        except asyncio.CancelledError:
            logger.info("SSE connection cancelled (graceful shutdown)")
        except (httpx.RequestError, ConnectionError, OSError) as e:
            logger.warning("SSE connection error: %s", e)
        except Exception as e:
            logger.exception("Unexpected SSE error: %s", e)
        finally:
            if current_task:
                active_connections.discard(current_task)
            # Clean up per-session client (don't close shared client)
            if is_session_client:
                logger.debug("Closing per-session client")
                await session_client.close()

    # Raw ASGI handler for MCP messages
    async def handle_messages(scope: Scope, receive: Receive, send: Send) -> None:
        """Raw ASGI handler for messages endpoint.

        MCP's handle_post_message handles the entire ASGI lifecycle.
        We use a raw handler to avoid FastAPI trying to send a second response.
        """
        sse_transport: SseServerTransport | None = getattr(app.state, "sse_transport", None)
        if not sse_transport:
            response = JSONResponse(
                status_code=HTTP_SERVICE_UNAVAILABLE,
                content={"error": "MCP server not initialized"},
            )
            await response(scope, receive, send)
            return
        # MCP's handle_post_message handles the response directly via ASGI send
        await sse_transport.handle_post_message(scope, receive, send)

    # Intercept /sse and /messages before FastAPI routing. Registering a
    # middleware, instead of replacing ``app.router`` with a bare function,
    # keeps ``app.router`` a real router. As a result, ``app.routes``, the
    # OpenAPI schema and /docs, and later route registration keep working.
    app.add_middleware(
        _MCPRouteMiddleware,
        sse_handler=handle_sse,
        messages_handler=handle_messages,
    )

    return app
# Module-level ASGI application built from environment settings at import
# time, so an ASGI server such as uvicorn can load it by import path.
app: FastAPI = create_app()