# sse.py (7.76 kB)
"""
Server-Sent Events (SSE) transport for web-based MCP clients.
"""
import asyncio
import json
import uuid
from typing import AsyncGenerator, Dict, Set
from aiohttp import web, WSMsgType
import logging
from .base import BaseTransport, TransportMessage
# Module logger used by the SSE transport for connection lifecycle and error messages.
logger = logging.getLogger(__name__)
class SSETransport(BaseTransport):
    """SSE transport for web-based MCP clients.

    Routes served by ``self.app``:

    * ``GET /mcp``    -- long-lived Server-Sent Events stream (``handle_sse``).
    * ``POST /mcp``   -- client-to-server requests (``handle_post``).
    * ``GET /status`` -- JSON summary of transport state (``handle_status``).

    Outbound messages are written to open streams as SSE ``data:`` events.
    Inbound POSTs are either dispatched directly to the attached server (see
    ``set_server``) or queued per connection id and yielded by
    ``receive_messages``.
    """

    # Seconds between SSE comment frames written to keep idle streams open.
    KEEPALIVE_INTERVAL = 30
    # Seconds receive_messages sleeps between passes over the inbound queues.
    POLL_INTERVAL = 0.1
    # Methods handle_post forwards to the attached server's _handle_mcp_request.
    DIRECT_METHODS = ("tools/list", "tools/call")

    def __init__(self, config=None):
        """Create the aiohttp application and register the transport routes.

        ``host`` (default ``"localhost"``) and ``port`` (default ``49152``)
        are read from ``self.config`` as prepared by ``BaseTransport``.
        """
        super().__init__(config)
        self.app = web.Application()
        # Open event streams, keyed by the "sse-xxxxxxxx" id assigned in handle_sse.
        self.connections: Dict[str, web.StreamResponse] = {}
        # Inbound message queues, keyed by SSE connection id or POST connection id.
        self.pending_messages: Dict[str, asyncio.Queue] = {}
        # AppRunner / TCPSite created by start() and released by stop().
        self.runner = None
        self.site = None
        self.host = self.config.get("host", "localhost")
        self.port = self.config.get("port", 49152)
        self.connection_id = f"sse-{uuid.uuid4().hex[:8]}"
        self.server = None  # Reference to KatamariServer
        # Setup routes
        self.app.router.add_get('/mcp', self.handle_sse)
        self.app.router.add_post('/mcp', self.handle_post)
        self.app.router.add_get('/status', self.handle_status)

    async def start(self) -> None:
        """Start serving HTTP on ``self.host``:``self.port``.

        Raises whatever ``TCPSite.start`` raises (e.g. ``OSError`` when the
        port is already in use); the runner is released before re-raising.
        """
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, self.host, self.port)
        try:
            await site.start()
        except BaseException:
            # Don't leak the runner when binding fails.
            await runner.cleanup()
            raise
        self.runner = runner
        self.site = site
        self.is_running = True
        logger.info(f"SSE transport started on http://{self.host}:{self.port}")

    async def stop(self) -> None:
        """Close every event stream and shut the HTTP server down.

        Safe to call when the transport was never started.
        """
        self.is_running = False
        # Snapshot: ending a stream lets its handle_sse task resume, and that
        # task removes itself from self.connections.
        for conn_id, response in list(self.connections.items()):
            try:
                await response.write_eof()
            except Exception as e:
                # Best effort: the client may already have disconnected.
                logger.debug(f"Error closing SSE client {conn_id}: {e}")
        self.connections.clear()
        self.pending_messages.clear()
        if self.runner is not None:
            # cleanup() stops all registered sites and releases the runner's
            # resources; stopping only the site would leave the runner set up.
            await self.runner.cleanup()
            self.runner = None
        self.site = None

    async def send_message(self, message: TransportMessage) -> None:
        """Broadcast *message* as one SSE ``data:`` event to all connected clients.

        Does nothing while the transport is stopped. Clients whose write fails
        are logged and unregistered.
        """
        if not self.is_running:
            return
        frame = self._encode_event(message)
        # Iterate over a snapshot: every write may suspend, and handle_sse can
        # add or remove connections in the meantime.
        for conn_id, response in list(self.connections.items()):
            await self._write_frame(conn_id, response, frame)

    async def receive_messages(self) -> AsyncGenerator[TransportMessage, None]:
        """Yield messages queued by ``handle_post`` until the transport stops.

        Every POLL_INTERVAL seconds each inbound queue is drained completely.
        Queues not backed by a live SSE connection (such as the per-request
        ``post-…`` queues created for POSTs without ``X-Connection-ID``) are
        discarded once empty so they do not accumulate.
        """
        while self.is_running:
            for conn_id, queue in list(self.pending_messages.items()):
                while True:
                    try:
                        message = queue.get_nowait()
                    except asyncio.QueueEmpty:
                        break
                    yield message
                # No yield/await since the queue was seen empty, so no
                # producer can have refilled it before it is removed.
                if (conn_id not in self.connections
                        and self.pending_messages.get(conn_id) is queue):
                    del self.pending_messages[conn_id]
            await asyncio.sleep(self.POLL_INTERVAL)

    async def handle_sse(self, request):
        """Serve one ``GET /mcp`` event stream until the client or transport goes away.

        Registers the stream under a fresh ``sse-xxxxxxxx`` id, sends a
        ``connection.established`` event carrying that id, then writes an SSE
        comment line every KEEPALIVE_INTERVAL seconds. The stream is always
        unregistered when the handler exits.
        """
        response = web.StreamResponse(
            status=200,
            reason='OK',
            headers={
                'Content-Type': 'text/event-stream',
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'Access-Control-Allow-Origin': '*',
                'Access-Control-Allow-Headers': 'Content-Type',
            }
        )
        await response.prepare(request)
        conn_id = f"sse-{uuid.uuid4().hex[:8]}"
        self.connections[conn_id] = response
        self.pending_messages[conn_id] = asyncio.Queue()
        logger.info(f"SSE client connected: {conn_id}")
        try:
            # Tell the client which id to use for its POSTs.
            welcome_msg = TransportMessage(
                id=uuid.uuid4().hex,
                method="connection.established",
                params={"connection_id": conn_id, "transport": "sse"}
            )
            await self.send_to_client(conn_id, welcome_msg)
            # Stop once the transport shuts down, or once a failed write in
            # _write_frame has already unregistered this client.
            while self.is_running and conn_id in self.connections:
                await response.write(b": keepalive\n\n")
                await asyncio.sleep(self.KEEPALIVE_INTERVAL)
        except Exception as e:
            logger.warning(f"SSE connection error: {e}")
        finally:
            self.connections.pop(conn_id, None)
            self.pending_messages.pop(conn_id, None)
            logger.info(f"SSE client disconnected: {conn_id}")
        return response

    async def handle_post(self, request):
        """Handle a JSON request sent to ``POST /mcp``.

        With a server attached, methods in DIRECT_METHODS are dispatched to
        ``server._handle_mcp_request`` and its result is returned as the HTTP
        body (with the request ``id`` echoed back); other methods get an error
        object. Without a server, the body is validated as a
        ``TransportMessage`` and queued under the ``X-Connection-ID`` header,
        or under a fresh ``post-…`` id when the header is absent.

        Any failure is logged and answered with HTTP 400 ``{"error": ...}``.
        """
        try:
            data = await request.json()
            if not isinstance(data, dict):
                raise ValueError("Request body must be a JSON object")
            # Directly process MCP request if server is available
            if self.server is not None:
                method = data.get("method")
                message_id = data.get("id")
                if method in self.DIRECT_METHODS:
                    # NOTE(review): assumes _handle_mcp_request returns a dict,
                    # which the "id" assignment below requires.
                    response = await self.server._handle_mcp_request(
                        method, data.get("params", {})
                    )
                else:
                    response = {"error": f"Unknown method: {method}"}
                # JSON-RPC ids may legitimately be 0 or "", so test presence,
                # not truthiness.
                if message_id is not None:
                    response["id"] = message_id
                return web.json_response(response)
            # Fallback to transport message queuing
            message = TransportMessage(**data)
            conn_id = request.headers.get('X-Connection-ID', f"post-{uuid.uuid4().hex[:8]}")
            queue = self.pending_messages.get(conn_id)
            if queue is None:
                queue = self.pending_messages[conn_id] = asyncio.Queue()
            await queue.put(message)
            return web.json_response({"status": "received"})
        except Exception as e:
            logger.error(f"POST request error: {e}")
            return web.json_response({"error": str(e)}, status=400)

    async def handle_status(self, request):
        """Report whether the transport is running, its address and client count."""
        return web.json_response({
            "transport": "sse",
            "running": self.is_running,
            "connections": len(self.connections),
            "host": self.host,
            "port": self.port
        })

    async def send_to_client(self, conn_id: str, message: TransportMessage):
        """Send *message* to the single client *conn_id*; no-op if it is not connected."""
        response = self.connections.get(conn_id)
        if response is None:
            return
        await self._write_frame(conn_id, response, self._encode_event(message))

    def set_server(self, server):
        """Set server reference for direct MCP handling."""
        self.server = server

    @staticmethod
    def _encode_event(message: TransportMessage) -> bytes:
        """Serialize *message* as one SSE ``data:`` frame (UTF-8 bytes)."""
        return f"data: {message.model_dump_json(exclude_none=True)}\n\n".encode('utf-8')

    async def _write_frame(self, conn_id: str, response, frame: bytes) -> None:
        """Write an encoded frame to one client, unregistering it if the write fails."""
        try:
            # StreamResponse.write() is a coroutine: it must be awaited or the
            # data is never sent. Awaiting it also applies flow control, which
            # replaces the deprecated drain() call.
            await response.write(frame)
        except Exception as e:
            logger.warning(f"Failed to send to client {conn_id}: {e}")
            # Clean up disconnected client
            self.connections.pop(conn_id, None)
            self.pending_messages.pop(conn_id, None)