Skip to main content
Glama
async_utils.py15.5 kB
"""Async utilities and patterns for better performance and reliability.""" import asyncio import time from contextlib import asynccontextmanager from typing import Any, Callable, Dict, List, Optional, TypeVar from fastmcp import Client from .errors import ( ConnectionPoolError, ConnectionPoolExhaustedError, ServerConnectionError, ) from .logging import get_logger, with_correlation_id from .types import OperationStatus, Result logger = get_logger("AsyncUtils") T = TypeVar("T") class ConnectionPool: """Manages persistent connections to MCP servers.""" def __init__( self, max_connections_per_server: int = 5, connection_timeout: float = 30.0, idle_timeout: float = 300.0, # 5 minutes ): self.max_connections = max_connections_per_server self.connection_timeout = connection_timeout self.idle_timeout = idle_timeout # Pool storage: server_name -> Queue of available connections self._pools: Dict[str, asyncio.Queue] = {} self._active_connections: Dict[str, int] = {} self._connection_configs: Dict[str, Dict[str, Any]] = {} self._last_used: Dict[str, Dict[Client, float]] = {} # Cleanup task self._cleanup_task: Optional[asyncio.Task] = None self._shutdown = False async def initialize(self): """Initialize the connection pool.""" if not self._cleanup_task: self._cleanup_task = asyncio.create_task(self._cleanup_idle_connections()) async def register_server( self, server_name: str, connection_config: Dict[str, Any] ): """Register a server configuration for connection pooling.""" self._connection_configs[server_name] = connection_config if server_name not in self._pools: self._pools[server_name] = asyncio.Queue(maxsize=self.max_connections) self._active_connections[server_name] = 0 self._last_used[server_name] = {} @asynccontextmanager async def get_connection(self, server_name: str): """Get a connection from the pool or create a new one.""" if server_name not in self._connection_configs: raise ConnectionPoolError(f"Server {server_name} not registered") connection = None try: connection = await self._acquire_connection(server_name) # Use the FastMCP client's context manager to ensure proper connection async with connection as connected_client: yield connected_client finally: if connection: await self._release_connection(server_name, connection) async def _acquire_connection(self, server_name: str) -> Client: """Acquire a connection from the pool.""" pool = self._pools[server_name] # Try to get an existing connection from the pool try: connection = pool.get_nowait() self._last_used[server_name][connection] = time.time() logger.debug(f"Reused connection for {server_name}") return connection except asyncio.QueueEmpty: pass # Check if we can create a new connection if self._active_connections[server_name] >= self.max_connections: # Wait for a connection to become available try: connection = await asyncio.wait_for( pool.get(), timeout=self.connection_timeout ) self._last_used[server_name][connection] = time.time() logger.debug(f"Got pooled connection for {server_name}") return connection except asyncio.TimeoutError: raise ConnectionPoolExhaustedError( f"No connections available for {server_name} within timeout" ) # Create a new connection connection = await self._create_connection(server_name) self._active_connections[server_name] += 1 self._last_used[server_name][connection] = time.time() logger.debug(f"Created new connection for {server_name}") return connection async def _create_connection(self, server_name: str) -> Client: """Create a new connection to the server.""" config = self._connection_configs[server_name] try: if config.get("type") == "stdio": # For stdio connections import shlex import urllib.parse command = config.get("command", "") args = config.get("args", []) if args: full_command = shlex.join([command] + args) encoded_command = urllib.parse.quote(full_command, safe="") client = Client(f"stdio://{encoded_command}") else: encoded_command = urllib.parse.quote(command, safe="") client = Client(f"stdio://{encoded_command}") else: # For HTTP/SSE connections url = config.get("url", "") client = Client(url) # Don't test the connection here - let the context manager handle it # The client will be connected when used in the context manager return client except Exception as e: raise ServerConnectionError( f"Failed to create connection to {server_name}", server_name=server_name, cause=e, ) async def _release_connection(self, server_name: str, connection: Client): """Release a connection back to the pool.""" pool = self._pools[server_name] try: # Put the connection back in the pool if there's space pool.put_nowait(connection) logger.debug(f"Released connection for {server_name}") except asyncio.QueueFull: # Pool is full, close the connection await self._close_connection(server_name, connection) async def _close_connection(self, server_name: str, connection: Client): """Close a connection and update counters.""" try: # FastMCP clients don't have a direct close method # They are managed through context managers # Just clean up our tracking pass except Exception as e: logger.warning(f"Error closing connection for {server_name}: {e}") finally: self._active_connections[server_name] -= 1 if connection in self._last_used[server_name]: del self._last_used[server_name][connection] async def _cleanup_idle_connections(self): """Periodically clean up idle connections.""" while not self._shutdown: try: await asyncio.sleep(60) # Check every minute current_time = time.time() for server_name, last_used_times in self._last_used.items(): pool = self._pools[server_name] connections_to_close = [] # Check for idle connections for connection, last_used in last_used_times.items(): if current_time - last_used > self.idle_timeout: connections_to_close.append(connection) # Close idle connections for connection in connections_to_close: try: # Remove from pool if present temp_connections = [] while not pool.empty(): try: conn = pool.get_nowait() if conn != connection: temp_connections.append(conn) except asyncio.QueueEmpty: break # Put back non-idle connections for conn in temp_connections: try: pool.put_nowait(conn) except asyncio.QueueFull: await self._close_connection(server_name, conn) # Close the idle connection await self._close_connection(server_name, connection) logger.debug(f"Closed idle connection for {server_name}") except Exception as e: logger.warning( f"Error during cleanup for {server_name}: {e}" ) except Exception as e: logger.error(f"Error in connection cleanup: {e}") async def close_all(self): """Close all connections and shutdown the pool.""" self._shutdown = True if self._cleanup_task: self._cleanup_task.cancel() try: await self._cleanup_task except asyncio.CancelledError: pass # Close all connections for server_name, pool in self._pools.items(): connections = [] while not pool.empty(): try: connections.append(pool.get_nowait()) except asyncio.QueueEmpty: break for connection in connections: try: await self._close_connection(server_name, connection) except Exception as e: logger.warning(f"Error closing connection for {server_name}: {e}") logger.info("Connection pool closed") class RetryManager: """Manages retry logic with exponential backoff.""" def __init__( self, max_attempts: int = 3, base_delay: float = 1.0, max_delay: float = 60.0, exponential_base: float = 2.0, jitter: bool = True, ): self.max_attempts = max_attempts self.base_delay = base_delay self.max_delay = max_delay self.exponential_base = exponential_base self.jitter = jitter async def execute_with_retry( self, operation: Callable[[], Any], retryable_exceptions: tuple = (Exception,), operation_name: str = "operation", ) -> Result[Any]: """Execute an operation with retry logic.""" last_error = None for attempt in range(1, self.max_attempts + 1): try: logger.debug( f"Executing {operation_name}, attempt {attempt}/{self.max_attempts}" ) result = await operation() if attempt > 1: logger.info(f"{operation_name} succeeded on attempt {attempt}") return Result(status=OperationStatus.SUCCESS, data=result) except retryable_exceptions as e: last_error = e if attempt == self.max_attempts: logger.error( f"{operation_name} failed after {attempt} attempts", error=last_error, ) break # Calculate delay with exponential backoff delay = min( self.base_delay * (self.exponential_base ** (attempt - 1)), self.max_delay, ) # Add jitter to prevent thundering herd if self.jitter: import random delay *= 0.5 + random.random() * 0.5 logger.warning( f"{operation_name} failed on attempt {attempt}, retrying in {delay:.2f}s", error=e, ) await asyncio.sleep(delay) except Exception as e: # Non-retryable exception logger.error( f"{operation_name} failed with non-retryable error", error=e ) return Result( status=OperationStatus.FAILED, error=str(e), error_code=type(e).__name__, ) # All retries exhausted return Result( status=OperationStatus.FAILED, error=f"Operation failed after {self.max_attempts} attempts: {last_error}", error_code="RETRY_EXHAUSTED", ) @with_correlation_id async def run_concurrent_operations( operations: List[Callable[[], Any]], max_concurrent: int = 5, operation_names: Optional[List[str]] = None, ) -> List[Result[Any]]: """Run multiple operations concurrently with controlled concurrency.""" if operation_names is None: operation_names = [f"operation_{i}" for i in range(len(operations))] # Handle edge case where max_concurrent is 0 or negative - use unlimited concurrency if max_concurrent <= 0: max_concurrent = len(operations) or 1 semaphore = asyncio.Semaphore(max_concurrent) async def run_single_operation(operation: Callable, name: str) -> Result[Any]: async with semaphore: try: start_time = time.time() result = await operation() duration = (time.time() - start_time) * 1000 return Result( status=OperationStatus.SUCCESS, data=result, duration_ms=duration ) except Exception as e: logger.error(f"Operation {name} failed", error=e) return Result( status=OperationStatus.FAILED, error=str(e), error_code=type(e).__name__, ) tasks = [ run_single_operation(op, name) for op, name in zip(operations, operation_names) ] results = await asyncio.gather(*tasks, return_exceptions=True) # Convert any exceptions to Result objects final_results: List[Result[Any]] = [] for i, result in enumerate(results): if isinstance(result, Exception): final_results.append( Result( status=OperationStatus.FAILED, error=str(result), error_code=type(result).__name__, ) ) elif isinstance(result, Result): final_results.append(result) else: # This shouldn't happen, but handle it gracefully final_results.append( Result( status=OperationStatus.FAILED, error=f"Unexpected result type: {type(result)}", error_code="UNEXPECTED_RESULT_TYPE", ) ) return final_results # Global connection pool instance _connection_pool: Optional[ConnectionPool] = None async def get_connection_pool() -> ConnectionPool: """Get the global connection pool instance.""" global _connection_pool if _connection_pool is None: _connection_pool = ConnectionPool() await _connection_pool.initialize() return _connection_pool async def shutdown_connection_pool(): """Shutdown the global connection pool.""" global _connection_pool if _connection_pool: await _connection_pool.close_all() _connection_pool = None

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/lightfastai/lightfast-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server