# mcp_transport.py (24.1 kB)
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Roger Gujord
# https://github.com/gujord/OpenAPI-MCP
"""
MCP-compliant HTTP Stream Transport implementation.
Follows the official MCP specification for HTTP streaming transport.
"""
import asyncio
import json
import logging
import time
import uuid
from typing import Dict, Any, Optional, AsyncGenerator, Callable, List
from dataclasses import dataclass, field
from enum import Enum
from starlette.applications import Starlette
from starlette.responses import StreamingResponse, JSONResponse
from starlette.requests import Request
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Route
import uvicorn
try:
    from .exceptions import RequestExecutionError, ParameterError
except ImportError:
    from exceptions import RequestExecutionError, ParameterError
@dataclass
class MCPSession:
    """State kept for one MCP client session, keyed by its session ID."""
    # Identifier the session is registered under.
    session_id: str
    # Wall-clock time (time.time()) at which the session object was built.
    created_at: float = field(default_factory=time.time)
    # Wall-clock time of the most recent update_activity() call.
    last_activity: float = field(default_factory=time.time)
    # Every message recorded through add_message(), oldest first.
    message_history: List[Dict[str, Any]] = field(default_factory=list)
    # Cleared by the transport when the session is terminated.
    active: bool = True

    def update_activity(self):
        """Stamp the session as used right now."""
        now = time.time()
        self.last_activity = now

    def add_message(self, message: Dict[str, Any]):
        """Append a copy of *message* to the history, tagged with a timestamp.

        The caller's dict is not modified; a "timestamp" key already present
        in *message* is overwritten in the stored copy.
        """
        stamp = {"timestamp": time.time()}
        self.message_history.append({**message, **stamp})

    def is_expired(self, max_age: int = 3600) -> bool:
        """Return True when the session has been idle for more than *max_age* seconds."""
        idle_seconds = time.time() - self.last_activity
        return idle_seconds > max_age
class MCPTransportMode(Enum):
    """MCP transport response modes.

    Selects how the transport answers a POSTed JSON-RPC request.
    """
    # The request is processed immediately and the JSON-RPC response is
    # returned in the HTTP response body.
    BATCH = "batch"
    # The request is only recorded; the HTTP response points the client at a
    # Server-Sent Events URL for the session.
    STREAMING = "streaming"
class MCPHttpTransport:
    """
    MCP-compliant HTTP Stream Transport.
    
    Implements the official MCP HTTP streaming transport specification:
    - Single HTTP endpoint for all MCP communication
    - JSON-RPC 2.0 message format
    - Session management with unique session IDs
    - Batch and streaming response modes
    - Server-Sent Events for streaming responses
    """
    
    def __init__(
        self,
        mcp_server,
        host: str = "127.0.0.1",
        port: int = 8000,
        cors_origins: List[str] = None,
        message_size_limit: str = "4mb",
        batch_timeout: int = 30,
        session_timeout: int = 3600
    ):
        """
        Record the transport configuration and build the ASGI application.

        Args:
            mcp_server: Server object that JSON-RPC requests are dispatched to.
            host: Host passed to the HTTP server when it is started.
            port: Port passed to the HTTP server when it is started.
            cors_origins: Origins handed to the CORS middleware. A falsy value
                (``None`` or an empty list) is replaced by ``["*"]``.
            message_size_limit: Maximum accepted request body size, written
                as a number with an optional ``kb``/``mb`` suffix.
            batch_timeout: Stored on the instance as given.
            session_timeout: Idle time in seconds after which the cleanup
                loop treats a session as expired.
        """
        # Session registry and the handle of the periodic cleanup task
        # (the task itself is only created once the transport is started).
        self.cleanup_task: Optional[asyncio.Task] = None
        self.sessions: Dict[str, MCPSession] = {}

        # Plain configuration values.
        self.mcp_server = mcp_server
        self.host, self.port = host, port
        self.message_size_limit = message_size_limit
        self.batch_timeout = batch_timeout
        self.session_timeout = session_timeout
        self.cors_origins = cors_origins if cors_origins else ["*"]

        # No HTTP server exists yet; the app must be built after
        # cors_origins is set because the middleware setup reads it.
        self.server = None
        self.app = self._create_app()

        logging.info(f"MCP HTTP Transport initialized on {host}:{port}")
    
    def _create_app(self) -> Starlette:
        """
        Create the Starlette application with MCP endpoints.

        Registers the JSON-RPC endpoint, both SSE endpoints, the session
        termination endpoint and the two health endpoints, then wraps the app
        in CORS middleware configured from ``self.cors_origins``.

        Returns:
            The configured Starlette application.
        """
        routes = [
            # Standard MCP endpoints for mcp-remote clients
            Route("/mcp", self._handle_mcp_request, methods=["POST", "OPTIONS"]),
            Route("/sse", self._handle_mcp_sse, methods=["GET"]),  # mcp-remote SSE endpoint

            # Session-based endpoints (fallback)
            Route("/mcp/sse/{session_id}", self._handle_sse_stream, methods=["GET"]),
            Route("/mcp/sessions/{session_id}", self._handle_session_delete, methods=["DELETE"]),

            # Health and info endpoints
            Route("/mcp/health", self._handle_health, methods=["GET"]),
            Route("/health", self._handle_health, methods=["GET"]),  # Alternative health endpoint
        ]

        app = Starlette(routes=routes)

        # NOTE(review): allow_credentials=True combined with the default
        # ["*"] origin list is very permissive - confirm that credentialed
        # cross-origin requests are actually required.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=self.cors_origins,
            allow_credentials=True,
            allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
            allow_headers=["*"],
            # Response headers are hidden from cross-origin browser scripts
            # unless listed here; clients need to read the session id that
            # the handlers return in this header.
            expose_headers=["Mcp-Session-Id"],
        )

        return app
    
    async def _handle_mcp_request(self, request: Request):
        """
        Handle MCP JSON-RPC requests via HTTP POST.

        This is the main endpoint for all MCP communication.
        Supports both batch and streaming response modes.

        Args:
            request: Incoming request. The optional ``Mcp-Session-Id`` header
                selects an existing session; without it a new session is
                created once the body has been validated.

        Returns:
            A ``JSONResponse``:
            - 404 when the supplied session id is unknown or inactive,
            - 413 when the body exceeds the configured size limit,
            - 400 with JSON-RPC error -32700 when the body is not valid
              UTF-8 encoded JSON,
            - 400 with JSON-RPC error -32600 when the body is valid JSON but
              not a single JSON object,
            - 500 with JSON-RPC error -32603 on any unexpected exception,
            - otherwise the JSON-RPC response (batch mode) or a pointer to
              the session's SSE stream (streaming mode), with the session id
              echoed in the ``Mcp-Session-Id`` response header.
        """
        if request.method == "OPTIONS":
            return JSONResponse({"status": "ok"})

        try:
            # Resolve a client-supplied session first so that an unknown or
            # terminated id is rejected before the body is even read.
            session_id = request.headers.get("Mcp-Session-Id")
            session = None
            if session_id:
                session = self.sessions.get(session_id)
                if not session or not session.active:
                    return JSONResponse(
                        {"error": "Invalid or expired session"},
                        status_code=404
                    )
                session.update_activity()

            # Parse JSON-RPC request
            body = await request.body()
            if len(body) > self._parse_size_limit():
                return JSONResponse(
                    {"error": "Request too large"},
                    status_code=413
                )

            try:
                rpc_request = json.loads(body.decode('utf-8'))
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                # Undecodable bytes are a client error just like malformed
                # JSON, so both map to the JSON-RPC parse error.
                return JSONResponse(
                    {
                        "jsonrpc": "2.0",
                        "id": None,
                        "error": {
                            "code": -32700,
                            "message": "Parse error",
                            "data": str(e)
                        }
                    },
                    status_code=400
                )

            # Everything below treats the request as a mapping; arrays and
            # scalars are rejected explicitly instead of failing later with
            # an internal error.
            if not isinstance(rpc_request, dict):
                return JSONResponse(
                    {
                        "jsonrpc": "2.0",
                        "id": None,
                        "error": {
                            "code": -32600,
                            "message": "Invalid Request",
                            "data": "Request body must be a single JSON object"
                        }
                    },
                    status_code=400
                )

            # Create the session only now, so rejected requests do not leave
            # an orphan session behind.
            if session is None:
                session_id = str(uuid.uuid4())
                session = MCPSession(session_id=session_id)
                self.sessions[session_id] = session

            # Add to session history
            session.add_message(rpc_request)

            # Determine response mode from request headers or query params
            response_mode = self._get_response_mode(request)

            if response_mode == MCPTransportMode.STREAMING:
                # The request is only recorded here; the client is told
                # where the session's SSE stream lives.
                return JSONResponse({
                    "jsonrpc": "2.0",
                    "id": rpc_request.get("id"),
                    "result": {
                        "session_id": session_id,
                        "stream_url": f"/mcp/sse/{session_id}",
                        "transport_mode": "streaming"
                    }
                }, headers={"Mcp-Session-Id": session_id})

            # Batch mode - process the request immediately and record the
            # response next to the request in the session history.
            response = await self._process_mcp_request(rpc_request, session)
            session.add_message(response)

            return JSONResponse(
                response,
                headers={"Mcp-Session-Id": session_id}
            )

        except Exception as e:
            # Top-level boundary: report any unexpected failure as a
            # JSON-RPC internal error rather than letting it escape.
            logging.error(f"MCP request handling error: {e}")
            return JSONResponse(
                {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32603,
                        "message": "Internal error",
                        "data": str(e)
                    }
                },
                status_code=500
            )
    
    async def _handle_sse_stream(self, request: Request):
        """
        Handle Server-Sent Events stream for MCP responses.

        Looks up the session named in the URL and, while it stays active,
        processes every request (history entry containing "method") that is
        appended to the session history after the stream was opened, sending
        each response as an SSE event. A heartbeat event is sent once per
        polling cycle (roughly every second).

        Args:
            request: Incoming request; ``session_id`` is taken from the path.

        Returns:
            A 404 ``JSONResponse`` when the session is unknown or inactive,
            otherwise a ``StreamingResponse`` of ``text/event-stream``.
        """
        session_id = request.path_params["session_id"]
        session = self.sessions.get(session_id)

        if not session or not session.active:
            return JSONResponse(
                {"error": "Invalid or expired session"},
                status_code=404
            )

        session.update_activity()

        async def event_generator():
            """Generate SSE events with MCP JSON-RPC responses."""

            # Send connection established event
            yield self._format_sse_event(
                "connected",
                {
                    "session_id": session_id,
                    "transport": "mcp-http-sse",
                    "connected_at": time.time()
                }
            )

            try:
                # NOTE(review): entries already in the history when the
                # stream opens are skipped, which appears to include the
                # request that returned this stream's URL - confirm intent.
                cursor = len(session.message_history)

                while session.active:
                    # Walk the history one entry at a time and re-read its
                    # length on every step: other handlers can append while
                    # this generator is suspended at a yield, and those
                    # entries must not be skipped.
                    while cursor < len(session.message_history):
                        message = session.message_history[cursor]
                        cursor += 1

                        # Skip responses (only process requests)
                        if "method" not in message:
                            continue

                        try:
                            response = await self._process_mcp_request(message, session)
                            session.add_message(response)

                            # Send response via SSE
                            yield self._format_sse_event("message", response)

                        except Exception as e:
                            error_response = {
                                "jsonrpc": "2.0",
                                "id": message.get("id"),
                                "error": {
                                    "code": -32603,
                                    "message": "Internal error",
                                    "data": str(e)
                                }
                            }
                            session.add_message(error_response)
                            yield self._format_sse_event("error", error_response)

                    # Send heartbeat
                    yield self._format_sse_event(
                        "heartbeat",
                        {"timestamp": time.time(), "session_id": session_id}
                    )

                    # Wait before next check
                    await asyncio.sleep(1)

            except asyncio.CancelledError:
                # The consumer is gone; propagate the cancellation instead
                # of swallowing it and trying to send more events.
                logging.info(f"SSE stream cancelled for session {session_id}")
                raise
            except Exception as e:
                logging.error(f"SSE stream error for session {session_id}: {e}")
                yield self._format_sse_event(
                    "error",
                    {"error": str(e), "session_id": session_id}
                )

            # Reached only when the loop ended on its own or after a handled
            # error. Deliberately not placed in a `finally`: yielding while
            # the generator is being closed or cancelled is an error.
            yield self._format_sse_event(
                "disconnected",
                {"session_id": session_id, "disconnected_at": time.time()}
            )

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Mcp-Session-Id": session_id
            }
        )
    
    async def _handle_session_delete(self, request: Request):
        """
        Terminate the session named in the URL path.

        Returns a 404 JSON error when no such session is registered;
        otherwise the session is removed from the registry, flagged
        inactive, and a "terminated" status is returned.
        """
        session_id = request.path_params["session_id"]

        # Remove and fetch in one step; None means the id was never
        # registered or has already been removed.
        session = self.sessions.pop(session_id, None)
        if session is None:
            return JSONResponse({"error": "Session not found"}, status_code=404)

        # Streams holding a reference to the session watch this flag.
        session.active = False
        logging.info(f"Session {session_id} terminated")
        return JSONResponse({"status": "terminated"})
    
    async def _handle_health(self, request: Request):
        """
        Health check endpoint.

        Returns:
            A ``JSONResponse`` with the transport name, the number of active
            and registered sessions, and the uptime in seconds. The uptime is
            ``0.0`` when no start time has been recorded on the instance.
        """
        # `_start_time` is only present once the transport has been started;
        # report zero rather than a value computed from two separate clock
        # reads, which could come out slightly negative.
        started_at = getattr(self, '_start_time', None)
        uptime = 0.0 if started_at is None else time.time() - started_at

        active_sessions = sum(1 for s in self.sessions.values() if s.active)

        return JSONResponse({
            "status": "healthy",
            "transport": "mcp-http-sse",
            "active_sessions": active_sessions,
            "total_sessions": len(self.sessions),
            "uptime": uptime
        })
    
    async def _handle_mcp_sse(self, request: Request):
        """
        Handle MCP SSE endpoint for mcp-remote clients.

        This is the standard /sse endpoint that mcp-remote expects for
        Server-Sent Events communication.

        A fresh session is registered for every connection. While it stays
        active, the stream relays each history entry that contains "result"
        or "error" as a "response" event and sends a heartbeat after every
        30 polling cycles (about 30 seconds). The session is removed from
        the registry when the stream ends for any reason.

        Args:
            request: Incoming request (not inspected).

        Returns:
            A ``StreamingResponse`` of ``text/event-stream`` carrying the new
            session id in the ``Mcp-Session-Id`` header.
        """
        # Create a new session for this SSE connection
        session_id = str(uuid.uuid4())
        session = MCPSession(session_id=session_id)
        self.sessions[session_id] = session
        session.update_activity()

        async def mcp_sse_generator():
            """Generate MCP-compliant SSE events for mcp-remote clients."""

            try:
                # The first yield sits inside the try so that a client
                # disconnecting right here still triggers the session
                # cleanup in `finally`.
                yield self._format_sse_event(
                    "connected",
                    {
                        "transport": "mcp-http-sse",
                        "session_id": session_id,
                        "server_info": {
                            "name": self.mcp_server.server_name,
                            "version": "1.0.0"
                        },
                        "connected_at": time.time()
                    }
                )

                # For mcp-remote, requests come via the main /mcp endpoint
                # and this SSE stream returns the responses
                last_processed = 0
                heartbeat_counter = 0

                while session.active:
                    # Snapshot the length; anything appended while this
                    # generator is suspended is picked up on the next cycle.
                    current_messages = len(session.message_history)

                    if current_messages > last_processed:
                        for message in session.message_history[last_processed:current_messages]:
                            # Only send responses (skip requests in history)
                            if "result" in message or "error" in message:
                                yield self._format_sse_event("response", message)

                        last_processed = current_messages

                    # Send heartbeat every 30 seconds
                    heartbeat_counter += 1
                    if heartbeat_counter >= 30:
                        yield self._format_sse_event(
                            "heartbeat",
                            {"timestamp": time.time(), "session_id": session_id}
                        )
                        heartbeat_counter = 0

                    # Wait before next check
                    await asyncio.sleep(1)

            except asyncio.CancelledError:
                # The consumer is gone; propagate the cancellation instead
                # of swallowing it and trying to send more events.
                logging.info(f"MCP SSE stream cancelled for session {session_id}")
                raise
            except Exception as e:
                logging.error(f"MCP SSE stream error for session {session_id}: {e}")
                yield self._format_sse_event(
                    "error",
                    {"error": str(e), "session_id": session_id}
                )
            finally:
                # Clean up session. No yield in here: yielding while the
                # generator is being closed or cancelled is an error.
                session.active = False
                self.sessions.pop(session_id, None)

            # Reached only when the loop ended on its own or after a handled
            # error, i.e. when the consumer is still reading.
            yield self._format_sse_event(
                "disconnected",
                {"session_id": session_id, "disconnected_at": time.time()}
            )

        return StreamingResponse(
            mcp_sse_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # NOTE(review): these two CORS headers are hard-coded and do
                # not follow self.cors_origins - confirm they are still
                # needed alongside the CORS middleware.
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "*",
                "Mcp-Session-Id": session_id
            }
        )
    
    async def _process_mcp_request(self, rpc_request: Dict[str, Any], session: MCPSession) -> Dict[str, Any]:
        """
        Process an MCP JSON-RPC request using the MCP server.
        
        Routes the request to the appropriate MCP server method.
        """
        method = rpc_request.get("method")
        params = rpc_request.get("params", {})
        request_id = rpc_request.get("id")
        
        try:
            # Route to appropriate MCP server method
            if method == "initialize":
                return self.mcp_server._initialize_tool(req_id=request_id, **params)
            elif method == "tools/list":
                return self.mcp_server._tools_list_tool(req_id=request_id)
            elif method == "tools/call":
                return self.mcp_server._tools_call_tool(
                    req_id=request_id,
                    name=params.get("name"),
                    arguments=params.get("arguments", {})
                )
            elif method == "resources/list":
                # Return resources list
                resources = []
                if hasattr(self.mcp_server, 'resource_manager') and self.mcp_server.resource_manager:
                    for name, data in self.mcp_server.resource_manager.registered_resources.items():
                        resources.append({
                            "uri": f"/resource/{name}",
                            "name": name,
                            "description": data["metadata"]["description"],
                            "mimeType": "application/json"
                        })
                
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": {"resources": resources}
                }
            elif method == "prompts/list":
                # Return prompts list (if implemented)
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": {"prompts": []}
                }
            else:
                # Unknown method
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {
                        "code": -32601,
                        "message": f"Method not found: {method}"
                    }
                }
                
        except Exception as e:
            logging.error(f"Error processing MCP request {method}: {e}")
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -32603,
                    "message": "Internal error",
                    "data": str(e)
                }
            }
    
    def _get_response_mode(self, request: Request) -> MCPTransportMode:
        """
        Pick the response mode for a request.

        Streaming is selected when the ``Accept`` header is exactly
        ``text/event-stream`` (a header listing several media types does not
        match) or when the ``mode`` query parameter equals ``streaming``.
        Every other request is answered in batch mode.
        """
        # The header is checked first; the query string is only consulted
        # when the header did not already select streaming.
        wants_stream = (
            request.headers.get("Accept") == "text/event-stream"
            or request.query_params.get("mode") == "streaming"
        )
        return MCPTransportMode.STREAMING if wants_stream else MCPTransportMode.BATCH
    
    def _format_sse_event(
        self,
        event_type: str,
        data: Dict[str, Any],
        event_id: Optional[str] = None
    ) -> str:
        """Format data as Server-Sent Event."""
        lines = []
        
        if event_id:
            lines.append(f"id: {event_id}")
        
        lines.append(f"event: {event_type}")
        
        # Format data as JSON
        data_json = json.dumps(data)
        for line in data_json.split('\n'):
            lines.append(f"data: {line}")
        
        lines.append("")  # Empty line ends the event
        return '\n'.join(lines) + '\n'
    
    def _parse_size_limit(self) -> int:
        """Parse message size limit from string format."""
        size_str = self.message_size_limit.lower()
        if size_str.endswith('mb'):
            return int(size_str[:-2]) * 1024 * 1024
        elif size_str.endswith('kb'):
            return int(size_str[:-2]) * 1024
        else:
            return int(size_str)
    
    async def _cleanup_sessions(self):
        """Periodic cleanup of expired sessions."""
        while True:
            try:
                current_time = time.time()
                expired_sessions = []
                
                for session_id, session in self.sessions.items():
                    if session.is_expired(self.session_timeout):
                        expired_sessions.append(session_id)
                
                for session_id in expired_sessions:
                    self.sessions[session_id].active = False
                    del self.sessions[session_id]
                    logging.info(f"Cleaned up expired session: {session_id}")
                
                await asyncio.sleep(300)  # Cleanup every 5 minutes
                
            except asyncio.CancelledError:
                break
            except Exception as e:
                logging.error(f"Session cleanup error: {e}")
                await asyncio.sleep(300)
    
    async def start(self):
        """
        Run the transport's HTTP server.

        Records the start time, launches the session cleanup loop as a
        background task, then builds the uvicorn server for ``self.app`` and
        awaits its ``serve()`` call.
        """
        self._start_time = time.time()

        # Expired-session sweeps run alongside the server.
        self.cleanup_task = asyncio.create_task(self._cleanup_sessions())

        # Build the uvicorn server for the ASGI app.
        self.server = uvicorn.Server(
            uvicorn.Config(
                self.app,
                host=self.host,
                port=self.port,
                log_level="info"
            )
        )

        logging.info(f"Starting MCP HTTP transport server on {self.host}:{self.port}")
        await self.server.serve()
    
    async def stop(self):
        """Stop the MCP HTTP transport server."""
        if self.cleanup_task:
            self.cleanup_task.cancel()
        
        if self.server:
            self.server.should_exit = True
        
        # Cleanup all sessions
        for session in self.sessions.values():
            session.active = False
        self.sessions.clear()
        
        logging.info("MCP HTTP transport server stopped")
    
    def get_transport_info(self) -> Dict[str, Any]:
        """Get transport information and status."""
        return {
            "type": "mcp-http-sse",
            "host": self.host,
            "port": self.port,
            "endpoints": {
                "mcp": "/mcp",
                "sse": "/sse",  # Standard mcp-remote endpoint
                "sse_session": "/mcp/sse/{session_id}",
                "session_delete": "/mcp/sessions/{session_id}",
                "health": "/health"
            },
            "active_sessions": len([s for s in self.sessions.values() if s.active]),
            "total_sessions": len(self.sessions),
            "message_size_limit": self.message_size_limit,
            "batch_timeout": self.batch_timeout,
            "session_timeout": self.session_timeout
        }