Skip to main content
Glama
manager.py (13.5 kB)
"""Transport connection manager for MCP protocol"""

import asyncio
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Tuple, Union
from datetime import datetime
from contextlib import asynccontextmanager

from .base import Transport, TransportError, ConnectionState
from .stdio import StdioTransport
from .process_stdio import ProcessStdioTransport
from .sse import SSETransport
from ..utils.logging import get_logger
from ..utils.notifications import NotificationCenter, NotificationType

logger = get_logger(__name__)

# Sentinel returned by ``TransportManager._next_message`` when a transport's
# message iterator is exhausted. A dedicated object is used so that a
# legitimately falsy message can never be mistaken for "end of stream".
_EXHAUSTED = object()


class TransportManager:
    """Manages transport connections and routing

    Supports multiple simultaneous transport connections and provides
    unified interface for message handling.

    The first transport registered becomes the *primary* transport, which is
    the target of ``send_message``/``request``/``notify`` when no transport
    name is given.
    """

    def __init__(self, notification_center: Optional[NotificationCenter] = None):
        """Create an empty manager.

        Args:
            notification_center: Receiver for error/lifecycle notifications.
                A new ``NotificationCenter`` is created when omitted.
        """
        self._transports: Dict[str, Transport] = {}
        self._primary_transport: Optional[str] = None
        self._message_handlers: Dict[str, Callable] = {}
        self._notification_center = notification_center or NotificationCenter()
        self._stats = {
            "total_connections": 0,
            "active_connections": 0,
            "total_messages_sent": 0,
            "total_messages_received": 0,
            "total_errors": 0,
        }

    # ------------------------------------------------------------------
    # Internal lookup helpers
    # ------------------------------------------------------------------

    def _ensure_name_available(self, name: str) -> None:
        """Raise ``TransportError`` if *name* is already registered."""
        if name in self._transports:
            raise TransportError(f"Transport '{name}' already exists")

    def _require_transport(self, name: str) -> Transport:
        """Return the transport registered as *name* or raise ``TransportError``."""
        try:
            return self._transports[name]
        except KeyError:
            raise TransportError(f"Transport '{name}' not found") from None

    def _resolve_transport(self, name: Optional[str]) -> Transport:
        """Return the named transport, or the primary one when *name* is falsy.

        Raises:
            TransportError: If *name* is unknown, or if no name was given and
                no primary transport is set.
        """
        if name:
            return self._require_transport(name)
        if not self._primary_transport:
            raise TransportError("No primary transport available")
        return self._transports[self._primary_transport]

    def _count_active(self) -> int:
        """Count registered transports whose state is ``CONNECTED``."""
        return sum(
            1 for t in self._transports.values()
            if t.state == ConnectionState.CONNECTED
        )

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    async def add_stdio_transport(self, name: str = "stdio", **kwargs) -> StdioTransport:
        """Add a STDIO transport

        Extra keyword arguments are forwarded to ``StdioTransport``.

        Raises:
            TransportError: If a transport called *name* already exists.
        """
        self._ensure_name_available(name)
        transport = StdioTransport(name=name, **kwargs)
        await self._register_transport(name, transport)
        return transport

    async def add_process_stdio_transport(
        self,
        name: str,
        command: str,
        args: Optional[List[str]] = None,
        env: Optional[Dict[str, str]] = None,
        cwd: Optional[str] = None,
    ) -> ProcessStdioTransport:
        """Add a process-based STDIO transport

        ``command``, ``args``, ``env`` and ``cwd`` are passed through to
        ``ProcessStdioTransport`` unchanged.

        Raises:
            TransportError: If a transport called *name* already exists.
        """
        self._ensure_name_available(name)
        transport = ProcessStdioTransport(
            command=command,
            args=args,
            env=env,
            cwd=cwd,
            name=name,
        )
        await self._register_transport(name, transport)
        return transport

    async def add_sse_transport(self, name: str, base_url: str, **kwargs) -> SSETransport:
        """Add an SSE transport

        Extra keyword arguments are forwarded to ``SSETransport``.

        Raises:
            TransportError: If a transport called *name* already exists.
        """
        self._ensure_name_available(name)
        transport = SSETransport(base_url=base_url, name=name, **kwargs)
        await self._register_transport(name, transport)
        return transport

    async def _register_transport(self, name: str, transport: Transport) -> None:
        """Register a transport with the manager

        Hooks the manager's error/connect/close callbacks into the transport,
        replays every message handler registered so far, and makes the
        transport primary when it is the first one.
        """
        # NOTE(review): the handlers below are coroutine functions, so the
        # lambdas return un-awaited coroutines. This assumes Transport awaits
        # callback results - confirm against base.Transport.
        transport.on_error(self._handle_transport_error)
        transport.on_connect(lambda: self._handle_transport_connect(name))
        transport.on_close(lambda: self._handle_transport_close(name))

        # Replay handlers registered before this transport existed.
        for method, handler in self._message_handlers.items():
            transport.on_message(method, handler)

        self._transports[name] = transport

        # Set as primary if first transport
        if self._primary_transport is None:
            self._primary_transport = name

        logger.info(f"Registered transport: {name}")

    async def remove_transport(self, name: str) -> None:
        """Remove a transport

        A connected transport is disconnected first. If the removed transport
        was primary, the next registered transport (in insertion order)
        becomes primary, or ``None`` when none remain.

        Raises:
            TransportError: If *name* is not registered.
        """
        transport = self._require_transport(name)

        # Disconnect if connected
        if transport.state == ConnectionState.CONNECTED:
            await transport.disconnect()

        del self._transports[name]

        # Update primary if needed
        if self._primary_transport == name:
            self._primary_transport = next(iter(self._transports), None)

        logger.info(f"Removed transport: {name}")

    # ------------------------------------------------------------------
    # Connection control
    # ------------------------------------------------------------------

    async def _run_on_all(self, action: str) -> None:
        """Call ``transport.<action>()`` concurrently on every transport.

        Failures are best-effort: they do not propagate, but each one is
        logged so it is not silently lost.
        """
        names = list(self._transports)
        if not names:
            return
        results = await asyncio.gather(
            *(getattr(self._transports[n], action)() for n in names),
            return_exceptions=True,
        )
        for transport_name, result in zip(names, results):
            if isinstance(result, BaseException):
                logger.error(
                    f"Failed to {action} transport '{transport_name}': {result}"
                )

    async def connect(self, name: Optional[str] = None) -> None:
        """Connect transport(s)

        With *name*, connects that transport and lets its errors propagate.
        Without, connects all transports concurrently; individual failures
        are logged rather than raised.

        Raises:
            TransportError: If *name* is given but not registered.
        """
        if name:
            await self._require_transport(name).connect()
        else:
            await self._run_on_all("connect")

    async def disconnect(self, name: Optional[str] = None) -> None:
        """Disconnect transport(s)

        With *name*, disconnects that transport and lets its errors
        propagate. Without, disconnects all transports concurrently;
        individual failures are logged rather than raised.

        Raises:
            TransportError: If *name* is given but not registered.
        """
        if name:
            await self._require_transport(name).disconnect()
        else:
            await self._run_on_all("disconnect")

    # ------------------------------------------------------------------
    # Outbound messaging
    # ------------------------------------------------------------------

    async def send_message(self, message: Dict[str, Any],
                           transport: Optional[str] = None) -> None:
        """Send a message through transport(s)

        Uses the named transport, or the primary transport when *transport*
        is omitted. ``total_messages_sent`` is incremented only after the
        send completes without raising.

        Raises:
            TransportError: If the transport is unknown or no primary exists.
        """
        await self._resolve_transport(transport).send_message(message)
        self._stats["total_messages_sent"] += 1

    async def broadcast_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Broadcast message to all connected transports

        Returns:
            A mapping of transport name to ``{"success": True}`` or
            ``{"success": False, "error": <str>}``. Transports that are not
            in the ``CONNECTED`` state are skipped and do not appear.
        """
        results: Dict[str, Any] = {}
        # Snapshot the registry: the awaits below yield control, and a
        # concurrent add/remove would otherwise break dict iteration.
        for name, transport in list(self._transports.items()):
            if transport.state != ConnectionState.CONNECTED:
                continue
            try:
                await transport.send_message(message)
            except Exception as e:
                results[name] = {"success": False, "error": str(e)}
            else:
                results[name] = {"success": True}
                self._stats["total_messages_sent"] += 1
        return results

    async def request(self, method: str, params: Optional[Dict[str, Any]] = None,
                      transport: Optional[str] = None,
                      timeout: Optional[float] = None) -> Any:
        """Send request and wait for response

        Delegates to ``request(method, params, timeout)`` on the named
        transport, or on the primary transport when *transport* is omitted.

        Raises:
            TransportError: If the transport is unknown or no primary exists.
        """
        return await self._resolve_transport(transport).request(method, params, timeout)

    async def notify(self, method: str, params: Optional[Dict[str, Any]] = None,
                     transport: Optional[str] = None) -> None:
        """Send notification

        Delegates to ``notify(method, params)`` on the named transport, or on
        the primary transport when *transport* is omitted.

        Raises:
            TransportError: If the transport is unknown or no primary exists.
        """
        await self._resolve_transport(transport).notify(method, params)

    # ------------------------------------------------------------------
    # Inbound messaging
    # ------------------------------------------------------------------

    def on_message(self, method: str, handler: Callable) -> None:
        """Register a message handler for all transports

        The handler is stored so that transports added later receive it too,
        and is registered immediately on every existing transport. A later
        registration for the same *method* replaces the stored handler.
        """
        self._message_handlers[method] = handler

        # Register with existing transports
        for transport in self._transports.values():
            transport.on_message(method, handler)

    async def receive_messages(
        self, transport: Optional[str] = None
    ) -> AsyncIterator[Tuple[str, Any]]:
        """Receive messages from transport(s)

        Yields ``(transport_name, message)`` tuples.

        With *transport*, every message from that transport is yielded until
        its stream ends. Without, messages from all transports that are
        ``CONNECTED`` at call time are multiplexed; falsy messages are
        skipped in this mode. A transport leaves the multiplex when its
        stream ends, when receiving from it raises (the error is logged), or
        when it is removed from the manager. Iteration ends once no transport
        is left.

        Raises:
            TransportError: If *transport* is given but not registered.
        """
        if transport:
            source = self._require_transport(transport)
            async for message in source.receive_messages():
                self._stats["total_messages_received"] += 1
                yield (transport, message)
            return

        # One long-lived iterator per transport, so no stream is restarted
        # or abandoned between messages.
        iterators = {
            name: t.receive_messages().__aiter__()
            for name, t in self._transports.items()
            if t.state == ConnectionState.CONNECTED
        }
        # In-flight "fetch next message" task -> owning transport name.
        pending = {
            asyncio.ensure_future(self._next_message(iterator)): name
            for name, iterator in iterators.items()
        }

        try:
            while pending:
                done, _ = await asyncio.wait(
                    list(pending), return_when=asyncio.FIRST_COMPLETED
                )
                for task in done:
                    # Always forget the finished task; leaving it in place
                    # would make asyncio.wait return immediately forever.
                    name = pending.pop(task)
                    try:
                        message = task.result()
                    except Exception as e:
                        logger.error(f"Error receiving message: {e}")
                        await self._close_iterator(iterators.pop(name))
                        continue

                    if message is _EXHAUSTED:
                        iterators.pop(name)
                        continue

                    if message:
                        self._stats["total_messages_received"] += 1
                        yield (name, message)

                    # Re-arm only while the transport is still registered.
                    if name in self._transports:
                        next_task = asyncio.ensure_future(
                            self._next_message(iterators[name])
                        )
                        pending[next_task] = name
                    else:
                        await self._close_iterator(iterators.pop(name))
        finally:
            # Runs on normal exit, on error, and when the consumer stops
            # iterating early: make sure no background task outlives us.
            leftovers = list(pending)
            for task in leftovers:
                task.cancel()
            if leftovers:
                await asyncio.gather(*leftovers, return_exceptions=True)
            for iterator in iterators.values():
                await self._close_iterator(iterator)

    @staticmethod
    async def _next_message(iterator: Any) -> Any:
        """Return the next item of *iterator*, or ``_EXHAUSTED`` at its end."""
        try:
            return await iterator.__anext__()
        except StopAsyncIteration:
            return _EXHAUSTED

    @staticmethod
    async def _close_iterator(iterator: Any) -> None:
        """Close *iterator* if it exposes ``aclose``; log failures and move on."""
        aclose = getattr(iterator, "aclose", None)
        if aclose is None:
            return
        try:
            await aclose()
        except Exception as e:
            logger.debug(f"Error closing message iterator: {e}")

    async def _queue_messages(self, name: str, transport: Transport):
        """Queue messages from a transport

        Returns ``(name, message)`` for the first message produced by a fresh
        ``transport.receive_messages()`` iteration, or ``(name, None)`` when
        that iteration yields nothing. No longer used by
        ``receive_messages``; retained for backward compatibility.
        """
        async for message in transport.receive_messages():
            return (name, message)
        return (name, None)

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def set_primary_transport(self, name: str) -> None:
        """Set the primary transport

        Raises:
            TransportError: If *name* is not registered.
        """
        self._require_transport(name)
        self._primary_transport = name
        logger.info(f"Set primary transport: {name}")

    def get_transport(self, name: str) -> Optional[Transport]:
        """Get a specific transport

        Returns ``None`` (rather than raising) when *name* is not registered.
        """
        return self._transports.get(name)

    def list_transports(self) -> List[Dict[str, Any]]:
        """List all transports with their status

        Each entry holds the transport's ``name``, class name (``type``),
        ``state`` value, whether it ``is_primary``, and its own ``stats``.
        """
        return [
            {
                "name": name,
                "type": transport.__class__.__name__,
                "state": transport.state.value,
                "is_primary": name == self._primary_transport,
                "stats": transport.get_stats(),
            }
            for name, transport in self._transports.items()
        ]

    # ------------------------------------------------------------------
    # Transport callbacks
    # ------------------------------------------------------------------

    async def _handle_transport_error(self, error: Exception) -> None:
        """Handle transport error

        Increments ``total_errors`` and emits an ``ERROR`` notification.
        """
        self._stats["total_errors"] += 1
        await self._notification_center.notify(
            NotificationType.ERROR,
            f"Transport error: {error}",
            {"error": str(error)},
        )

    async def _handle_transport_connect(self, name: str) -> None:
        """Handle transport connection

        Increments ``total_connections``, recomputes ``active_connections``
        from current transport states, and emits a ``LIFECYCLE`` notification.
        """
        self._stats["total_connections"] += 1
        self._stats["active_connections"] = self._count_active()
        await self._notification_center.notify(
            NotificationType.LIFECYCLE,
            f"Transport connected: {name}",
            {"transport": name},
        )

    async def _handle_transport_close(self, name: str) -> None:
        """Handle transport disconnection

        Recomputes ``active_connections`` from current transport states and
        emits a ``LIFECYCLE`` notification.
        """
        self._stats["active_connections"] = self._count_active()
        await self._notification_center.notify(
            NotificationType.LIFECYCLE,
            f"Transport disconnected: {name}",
            {"transport": name},
        )

    # ------------------------------------------------------------------
    # Session helper
    # ------------------------------------------------------------------

    @asynccontextmanager
    async def session(self, transports: Optional[List[str]] = None):
        """Context manager for transport session

        Connects the named transports (or all transports when *transports*
        is empty or omitted), yields the manager, and disconnects on exit.

        If connecting one of the named transports fails, the ones already
        connected are disconnected before the error propagates. On exit,
        every named transport is given a disconnect attempt even if an
        earlier one fails; the first disconnect error is re-raised afterwards.
        """
        connected: List[str] = []
        try:
            if transports:
                for name in transports:
                    await self.connect(name)
                    connected.append(name)
            else:
                await self.connect()
            yield self
        finally:
            if transports:
                first_error: Optional[BaseException] = None
                for name in connected:
                    try:
                        await self.disconnect(name)
                    except Exception as e:
                        logger.error(f"Failed to disconnect transport '{name}': {e}")
                        if first_error is None:
                            first_error = e
                if first_error is not None:
                    raise first_error
            else:
                await self.disconnect()

    def get_stats(self) -> Dict[str, Any]:
        """Get manager statistics

        Returns the manager's counters merged with a ``"transports"`` key
        holding the output of ``list_transports()``.
        """
        return {
            **self._stats,
            "transports": self.list_transports(),
        }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/krzemienski/shannon-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.