# sse_server.py
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Roger Gujord
# https://github.com/gujord/OpenAPI-MCP
import asyncio
import json
import logging
import signal
import sys
from contextlib import asynccontextmanager
from typing import Dict, Any, Optional
from starlette.applications import Starlette
from starlette.responses import StreamingResponse, JSONResponse, Response
from starlette.routing import Route
from starlette.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
import uvicorn
try:
    from .sse_handler import SSEManager, SSEEvent, SSEEventType
    from .exceptions import RequestExecutionError
except ImportError:
    from sse_handler import SSEManager, SSEEvent, SSEEventType
    from exceptions import RequestExecutionError
class SSEHTTPServer:
    """HTTP server for serving SSE endpoints.

    Wraps a Starlette app served by uvicorn. Exposes:

    * ``GET /sse/stream/{connection_id}`` - event stream for a registered connection
    * ``GET /sse/connections`` - info about connections known to the manager
    * ``GET /sse/health`` - health check
    * ``POST /sse/broadcast`` - broadcast a DATA event to all connections
    """

    def __init__(self, sse_manager: SSEManager, host: str = "127.0.0.1", port: int = 8000):
        self.sse_manager = sse_manager
        self.host = host
        self.port = port
        self.app = None
        self.server = None
        # Set by the signal handler in start_server(); nothing in this class awaits it.
        self._shutdown_event = asyncio.Event()

    def create_app(self) -> Starlette:
        """Create the Starlette application with SSE routes and CORS middleware."""

        @asynccontextmanager
        async def lifespan(app):
            """Run the stale-connection cleanup loop for the app's lifetime."""
            logging.info(f"SSE HTTP Server starting on {self.host}:{self.port}")
            cleanup_task = asyncio.create_task(self._cleanup_loop())
            try:
                yield
            finally:
                logging.info("SSE HTTP Server shutting down")
                cleanup_task.cancel()
                # Await the cancelled task so it actually finishes before the
                # loop closes (avoids a "Task was destroyed but it is pending" warning).
                try:
                    await cleanup_task
                except asyncio.CancelledError:
                    pass
                # Disconnect all connections
                for connection in list(self.sse_manager.connections.values()):
                    await connection.disconnect()

        routes = [
            Route("/sse/stream/{connection_id}", self.sse_stream_endpoint),
            Route("/sse/connections", self.sse_connections_endpoint),
            Route("/sse/health", self.health_endpoint),
            Route("/sse/broadcast", self.broadcast_endpoint, methods=["POST"]),
        ]

        app = Starlette(routes=routes, lifespan=lifespan)

        # NOTE(review): wildcard origins combined with allow_credentials=True lets
        # any web page make credentialed cross-origin requests (including POST
        # /sse/broadcast). Consider restricting origins or disabling credentials.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        return app

    async def sse_stream_endpoint(self, request):
        """Stream events for ``connection_id``.

        Returns 404 if the manager has no such connection and 410 if it exists
        but is no longer connected. When the stream ends (normally or on
        error), the connection is removed from the manager.
        """
        connection_id = request.path_params["connection_id"]

        connection = self.sse_manager.connections.get(connection_id)
        if not connection:
            return JSONResponse(
                {"error": f"Connection {connection_id} not found"},
                status_code=404
            )

        if not connection.connected:
            return JSONResponse(
                {"error": f"Connection {connection_id} already disconnected"},
                status_code=410
            )

        async def event_generator():
            """Relay the connection's event stream, emitting an ERROR event on failure."""
            try:
                async for event_data in connection.event_stream():
                    yield event_data
            except Exception as e:
                logging.error(f"SSE stream error for {connection_id}: {e}")
                error_event = SSEEvent(
                    type=SSEEventType.ERROR,
                    data={"error": str(e)}
                )
                yield error_event.to_sse_format()
            finally:
                await self.sse_manager.remove_connection(connection_id)

        # NOTE(review): if event_stream()/to_sse_format() yield strings that are
        # already in SSE wire format, EventSourceResponse will wrap them again
        # as "data: ..." lines. Confirm against sse_handler; StreamingResponse
        # with media_type="text/event-stream" may be what was intended.
        return EventSourceResponse(event_generator())

    async def sse_connections_endpoint(self, request):
        """Return a JSON summary of every connection held by the manager."""
        connections_info = [
            {
                "connection_id": connection_id,
                "connected": connection.connected,
                "last_heartbeat": connection.last_heartbeat,
                "heartbeat_interval": connection.heartbeat_interval
            }
            for connection_id, connection in self.sse_manager.connections.items()
        ]

        return JSONResponse({
            "active_connections": len(connections_info),
            "connections": connections_info
        })

    async def health_endpoint(self, request):
        """Health check endpoint."""
        return JSONResponse({
            "status": "healthy",
            "active_connections": self.sse_manager.get_connection_count(),
            "server": "SSE HTTP Server",
            "version": "1.0.0"
        })

    async def broadcast_endpoint(self, request):
        """Broadcast a DATA event to all connected clients.

        Expects a JSON object body with optional ``data`` (defaults to ``{}``)
        and optional ``id``.

        Returns:
            200 on success; 400 if the body is not valid JSON or not a JSON
            object; 500 if building or broadcasting the event fails.
        """
        try:
            payload = await request.json()
        except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
            logging.error(f"Broadcast error: {e}")
            return JSONResponse({
                "error": f"Invalid JSON body: {e}"
            }, status_code=400)

        if not isinstance(payload, dict):
            # Previously surfaced as an opaque "'list' object has no attribute 'get'".
            return JSONResponse({
                "error": "Request body must be a JSON object"
            }, status_code=400)

        try:
            event = SSEEvent(
                type=SSEEventType.DATA,
                data=payload.get("data", {}),
                id=payload.get("id")
            )
            await self.sse_manager.broadcast_to_all(event)
        except Exception as e:
            # Failures past validation are server-side, not client errors.
            logging.error(f"Broadcast error: {e}")
            return JSONResponse({
                "error": str(e)
            }, status_code=500)

        return JSONResponse({
            "success": True,
            "broadcasted_to": self.sse_manager.get_connection_count(),
            "message": "Event broadcasted successfully"
        })

    async def _cleanup_loop(self):
        """Call the manager's stale-connection cleanup every 60 seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(60)
                await self.sse_manager.cleanup_stale_connections()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; a single failed sweep is not fatal.
                logging.error(f"Cleanup error: {e}")

    async def start_server(self):
        """Build the app and serve it with uvicorn until it exits.

        Installs SIGINT/SIGTERM handlers that ask uvicorn to exit. Handlers are
        skipped when not running on the main thread, where ``signal.signal``
        raises ``ValueError``. ``_shutdown()`` always runs when serving stops.
        """
        self.app = self.create_app()

        config = uvicorn.Config(
            self.app,
            host=self.host,
            port=self.port,
            log_level="info",
            access_log=True
        )

        self.server = uvicorn.Server(config)

        def signal_handler(signum, frame):
            logging.info(f"Received signal {signum}, shutting down...")
            self._shutdown_event.set()
            # Setting the event alone stopped nothing (no one awaits it);
            # tell uvicorn to leave its main loop.
            if self.server is not None:
                self.server.should_exit = True

        try:
            signal.signal(signal.SIGINT, signal_handler)
            signal.signal(signal.SIGTERM, signal_handler)
        except ValueError:
            # signal.signal() may only be called from the main thread.
            logging.warning("Not on the main thread; SSE server signal handlers not installed")

        try:
            await self.server.serve()
        except Exception as e:
            logging.error(f"Server error: {e}")
        finally:
            await self._shutdown()

    async def _shutdown(self):
        """Ask uvicorn to exit and disconnect every SSE connection."""
        if self.server:
            self.server.should_exit = True

        for connection in list(self.sse_manager.connections.values()):
            await connection.disconnect()

        logging.info("SSE HTTP Server shutdown complete")

    def run(self):
        """Run the server in a fresh event loop, blocking until it stops.

        Exits the process with status 1 if an exception escapes ``start_server``.
        """
        try:
            asyncio.run(self.start_server())
        except KeyboardInterrupt:
            logging.info("Server interrupted by user")
        except Exception as e:
            logging.error(f"Server failed: {e}")
            sys.exit(1)
class SSEServerManager:
    """Manages the SSE server lifecycle alongside the MCP server.

    Runs ``SSEHTTPServer.start_server()`` as a background asyncio task in the
    current event loop.
    """

    def __init__(self, sse_manager: SSEManager, host: str = "127.0.0.1", port: int = 8000):
        self.sse_manager = sse_manager
        self.sse_server = SSEHTTPServer(sse_manager, host, port)
        self.server_task = None
        self.running = False

    async def start(self):
        """Start the SSE server as a background task; no-op if already running."""
        if not self.running:
            self.server_task = asyncio.create_task(self.sse_server.start_server())
            self.running = True
            logging.info(f"SSE server started on {self.sse_server.host}:{self.sse_server.port}")

    async def stop(self, timeout: float = 5.0):
        """Stop the SSE server, preferring a graceful exit.

        Asks uvicorn to exit (so the app's lifespan shutdown runs) and waits up
        to ``timeout`` seconds for the server task to finish. The task is
        cancelled only if it does not finish in time. No-op if not running.

        Args:
            timeout: Seconds to wait for a graceful exit before cancelling.
        """
        if not (self.running and self.server_task):
            return
        task = self.server_task
        try:
            # Sets ``server.should_exit`` and disconnects SSE clients.
            await self.sse_server._shutdown()
            if self.sse_server.server is None:
                # start_server() has not created the uvicorn server yet, so there
                # is nothing to drain; cancel it before it binds.
                task.cancel()
            elif not task.done():
                try:
                    # wait_for cancels the task itself if the timeout expires.
                    await asyncio.wait_for(task, timeout)
                except asyncio.TimeoutError:
                    logging.warning(
                        f"SSE server did not stop within {timeout}s; task cancelled"
                    )
        finally:
            self.running = False
        logging.info("SSE server stopped")

    def get_stream_url(self, connection_id: str) -> str:
        """Get the SSE stream URL for a connection."""
        return f"http://{self.sse_server.host}:{self.sse_server.port}/sse/stream/{connection_id}"

    def get_connections_url(self) -> str:
        """Get the SSE connections info URL."""
        return f"http://{self.sse_server.host}:{self.sse_server.port}/sse/connections"

    def get_health_url(self) -> str:
        """Get the health check URL."""
        return f"http://{self.sse_server.host}:{self.sse_server.port}/sse/health"
def create_sse_enabled_server(
    sse_manager: SSEManager,
    host: str = "127.0.0.1",
    port: int = 8000
) -> SSEServerManager:
    """Build an ``SSEServerManager`` bound to ``host``:``port``.

    The returned manager is not started; call its ``start()`` coroutine to
    launch the HTTP server.
    """
    manager = SSEServerManager(sse_manager=sse_manager, host=host, port=port)
    return manager