# manager.py (8.42 kB)
"""
Transport manager for handling multiple transport protocols.
"""
import asyncio
import logging
import time
import uuid
from typing import Dict, List, Optional, Any, Union

from .base import BaseTransport, TransportMessage
from .sse import SSETransport
from .websocket import WebSocketTransport
logger = logging.getLogger(__name__)


class TransportManager:
    """Manage multiple transport protocols (SSE, WebSocket) for Katamari MCP.

    The manager builds the transports enabled in its configuration, starts and
    stops them, fans handler registrations and outgoing messages out to them,
    runs one background receive/dispatch task per active transport, and keeps
    a small in-memory session store.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Create the manager and instantiate the transports enabled in *config*.

        Args:
            config: Optional mapping. Its ``"transports"`` entry may contain
                ``"sse"`` and ``"websocket"`` sub-mappings carrying an
                ``"enabled"`` flag; each sub-mapping is passed unchanged to the
                corresponding transport constructor.
        """
        self.config = config or {}
        self.transports: Dict[str, "BaseTransport"] = {}
        self.active_transports: List[str] = []
        self.message_handlers: Dict[str, Any] = {}
        self.session_store: Dict[str, Dict[str, Any]] = {}
        self.is_running = False
        # One background receive/dispatch task per active transport. Holding
        # the references keeps the tasks from being garbage-collected while
        # they run and lets stop() cancel them.
        self._processing_tasks: Dict[str, asyncio.Task] = {}
        self._initialize_transports()

    def _initialize_transports(self):
        """Instantiate each transport whose config section sets ``enabled``.

        A missing or ``None`` section counts as disabled. An ``ImportError``
        raised while constructing the WebSocket transport is logged as a
        warning and that transport is skipped.
        """
        transport_config = self.config.get("transports") or {}

        sse_config = transport_config.get("sse") or {}
        if sse_config.get("enabled", False):
            self.transports["sse"] = SSETransport(sse_config)

        websocket_config = transport_config.get("websocket") or {}
        if websocket_config.get("enabled", False):
            try:
                self.transports["websocket"] = WebSocketTransport(websocket_config)
            except ImportError as e:
                logger.warning("WebSocket transport not available: %s", e)

    async def start(self, transport_names: Optional[List[str]] = None):
        """Start the named transports (default: every configured transport).

        Unknown names are logged and skipped; transports that are already
        active are left alone. A failure to start one transport is logged and
        does not prevent the others from starting. Every active transport
        without a live processing task gets one, which receives and
        dispatches its messages in the background.
        """
        if transport_names is None:
            transport_names = list(self.transports)

        for name in transport_names:
            if name not in self.transports:
                logger.warning("Unknown transport requested: %s", name)
                continue
            if name in self.active_transports:
                continue
            try:
                await self.transports[name].start()
            except Exception as e:
                logger.error("Failed to start transport %s: %s", name, e)
                continue
            self.active_transports.append(name)
            logger.info("Started transport: %s", name)

        self.is_running = True
        # Only transports lacking a live task get a new one, so repeated
        # start() calls never process the same transport's messages twice.
        self._spawn_processing_tasks()

    async def stop(self):
        """Cancel message processing and stop every active transport.

        Processing tasks are cancelled and awaited before the transports are
        stopped, so nothing is dispatched on a transport that is shutting
        down. An error from one transport's ``stop()`` is logged and does not
        prevent the remaining transports from being stopped.
        """
        self.is_running = False

        # Skip the task we are running in (e.g. stop() called from a message
        # handler): cancelling it would raise CancelledError at the next await
        # and abort the rest of stop().
        current = asyncio.current_task()
        tasks = [t for t in self._processing_tasks.values() if t is not current]
        self._processing_tasks.clear()
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        for name in self.active_transports:
            try:
                await self.transports[name].stop()
                logger.info("Stopped transport: %s", name)
            except Exception as e:
                logger.error("Error stopping transport %s: %s", name, e)
        self.active_transports.clear()

    def register_handler(self, method: str, handler):
        """Record *handler* for *method* and register it on every configured transport."""
        self.message_handlers[method] = handler
        for transport in self.transports.values():
            transport.register_handler(method, handler)

    async def send_message(self, message: "TransportMessage",
                           transport_names: Optional[List[str]] = None):
        """Send *message* through the named transports (default: all active ones).

        Names that are not currently active are ignored. A send failure on one
        transport is logged and does not stop delivery on the others.
        """
        # Snapshot the default target list so changes to active_transports
        # during the awaits below cannot affect this iteration.
        targets = list(self.active_transports) if transport_names is None else transport_names
        for name in targets:
            if name not in self.active_transports:
                continue
            try:
                await self.transports[name].send_message(message)
            except Exception as e:
                logger.error("Failed to send via %s: %s", name, e)

    def _spawn_processing_tasks(self) -> None:
        """Create a processing task for each active transport without a live one."""
        for name in self.active_transports:
            task = self._processing_tasks.get(name)
            if task is not None and not task.done():
                continue
            self._processing_tasks[name] = asyncio.create_task(
                self._process_transport_messages(name, self.transports[name])
            )

    async def _message_processing_loop(self):
        """Ensure every active transport has a processing task, then await them all.

        start() schedules the same per-transport tasks without awaiting them;
        this coroutine additionally waits until they finish.
        """
        self._spawn_processing_tasks()
        try:
            await asyncio.gather(*self._processing_tasks.values(), return_exceptions=True)
        except Exception as e:
            logger.error("Message processing error: %s", e)

    async def _process_transport_messages(self, name: str, transport: "BaseTransport"):
        """Receive messages from *transport*, dispatch them, and send replies back.

        A message whose handling or reply fails is logged and skipped, so one
        bad request cannot stop the transport's loop. An error raised by the
        receive iterator itself ends the loop (and is logged). Cancellation is
        always propagated.
        """
        try:
            async for message in transport.receive_messages():
                try:
                    response = await transport.handle_message(message)
                    # Compare against None rather than truthiness: falsy but
                    # valid results such as 0, False, "" or [] must still be
                    # sent back to the client.
                    if response is not None and (
                        response.result is not None or response.error is not None
                    ):
                        await transport.send_message(response)
                except asyncio.CancelledError:
                    raise
                except Exception:
                    logger.exception("Error handling message from %s", name)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error("Error processing messages from %s: %s", name, e)

    def get_transport_status(self) -> Dict[str, Any]:
        """Return a snapshot describing the manager and each configured transport.

        Per-transport entries hold ``running`` and ``connection_id``, extended
        by the transport's own ``get_connection_info()`` when it has one, or
        else by its ``host``/``port`` attributes when both exist.
        """
        status = {
            "running": self.is_running,
            # Copy so callers cannot mutate internal state and so the snapshot
            # is not emptied by a later stop().
            "active_transports": list(self.active_transports),
            "available_transports": list(self.transports.keys()),
            "transports": {},
        }
        for name, transport in self.transports.items():
            transport_info = {
                "running": name in self.active_transports,
                "connection_id": transport.connection_id,
            }
            if hasattr(transport, 'get_connection_info'):
                transport_info.update(transport.get_connection_info())
            elif hasattr(transport, 'host') and hasattr(transport, 'port'):
                transport_info.update({
                    "host": transport.host,
                    "port": transport.port,
                })
            status["transports"][name] = transport_info
        return status

    def get_connection_info(self, transport_name: Optional[str] = None) -> Dict[str, Any]:
        """Return connection info for *transport_name*, or the full status.

        When *transport_name* is falsy or unknown, the result of
        get_transport_status() is returned instead.
        """
        if not transport_name or transport_name not in self.transports:
            return self.get_transport_status()
        transport = self.transports[transport_name]
        if hasattr(transport, 'get_connection_info'):
            return transport.get_connection_info()
        return {
            "transport": transport_name,
            "connection_id": transport.connection_id,
            "running": transport_name in self.active_transports,
        }

    async def create_session(self, client_info: Optional[Dict[str, Any]] = None) -> str:
        """Create a session record and return its id (``session-`` + 12 hex chars).

        Timestamps are ``time.monotonic()`` seconds. They are only meaningful
        relative to each other, not as wall-clock time.
        """
        session_id = f"session-{uuid.uuid4().hex[:12]}"
        # One clock read so created_at and last_activity agree exactly.
        now = time.monotonic()
        self.session_store[session_id] = {
            "created_at": now,
            "client_info": client_info or {},
            "message_count": 0,
            "last_activity": now,
        }
        return session_id

    async def update_session(self, session_id: str, activity: Dict[str, Any]):
        """Touch a session: refresh last_activity, bump message_count, merge *activity*.

        Unknown session ids are ignored. *activity* is merged last, so its keys
        overwrite any same-named fields, including the bookkeeping ones.
        """
        session = self.session_store.get(session_id)
        if session is None:
            return
        session["last_activity"] = time.monotonic()
        session["message_count"] += 1
        if activity:
            session.update(activity)

    def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored record for *session_id*, or None if it does not exist."""
        return self.session_store.get(session_id)

    def cleanup_expired_sessions(self, max_age_seconds: int = 3600) -> int:
        """Drop sessions idle for more than *max_age_seconds*; return how many were removed.

        Uses time.monotonic(), so it is safe to call without a running event loop.
        """
        now = time.monotonic()
        expired = [
            sid for sid, session in self.session_store.items()
            if now - session["last_activity"] > max_age_seconds
        ]
        for sid in expired:
            del self.session_store[sid]
        return len(expired)