# -*- coding: utf-8 -*-
"""
MCP Session Pool Implementation.
Provides session pooling for MCP ClientSessions to reduce per-request overhead.
Sessions are isolated per user/tenant via identity hashing to prevent session collision.
Performance Impact:
- Without pooling: 20-23ms per tool call (new session each time)
- With pooling: 1-2ms per tool call (10-20x improvement)
Security:
Sessions are isolated by (url, identity_hash, transport_type) to prevent:
- Cross-user session sharing
- Cross-tenant data leakage
- Authentication bypass
Copyright 2026
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti
"""
# flake8: noqa: DAR101, DAR201, DAR401
# Future
from __future__ import annotations
# Standard
import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from enum import Enum
import hashlib
import logging
import os
import re
import socket
import time
from typing import Any, Callable, Dict, Optional, Set, Tuple, TYPE_CHECKING
import uuid
# Third-Party
import anyio
import httpx
from mcp import ClientSession, McpError
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.session import RequestResponder
import mcp.types as mcp_types
import orjson
# First-Party
from mcpgateway.config import settings
from mcpgateway.utils.url_auth import sanitize_url_for_logging
# JSON-RPC standard error code for method not found
METHOD_NOT_FOUND = -32601
# Shared session-id validation (downstream MCP session IDs used for affinity).
# Intentionally strict: protects Redis key/channel construction and log lines.
_MCP_SESSION_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
# Worker ID for multi-worker session affinity
# Uses hostname + PID to be unique across Docker containers (each container has PID 1)
# and across gunicorn workers within the same container
WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
def _get_cleanup_timeout() -> float:
"""Get session cleanup timeout from config (lazy import to avoid circular deps).
This timeout controls how long to wait for session/transport __aexit__ calls
when closing sessions. It prevents CPU spin loops when internal tasks don't
respond to cancellation (anyio's _deliver_cancellation issue).
Returns:
Cleanup timeout in seconds (default: 5.0)
"""
try:
# Lazy import to avoid circular dependency during startup
return settings.mcp_session_pool_cleanup_timeout
except Exception:
return 5.0 # Fallback default
if TYPE_CHECKING:
# Standard
from collections.abc import AsyncIterator # pragma: no cover
logger = logging.getLogger(__name__)
class TransportType(Enum):
"""Supported MCP transport types."""
SSE = "sse"
STREAMABLE_HTTP = "streamablehttp"
@dataclass(eq=False) # eq=False makes instances hashable by object identity
class PooledSession:
"""A pooled MCP session with metadata for lifecycle management.
Note: eq=False is required because we store these in Sets for active session
tracking. This makes instances hashable by their object id (identity).
"""
session: ClientSession
transport_context: Any # The transport context manager (kept open)
url: str
transport_type: TransportType
headers: Dict[str, str] # Original headers (for reconnection)
identity_key: str # Identity hash component for headers
user_identity: str = "anonymous" # for user isolation
gateway_id: str = "" # Gateway ID for notification attribution
created_at: float = field(default_factory=time.time)
last_used: float = field(default_factory=time.time)
use_count: int = 0
_closed: bool = field(default=False, repr=False)
@property
def age_seconds(self) -> float:
"""Return session age in seconds.
Returns:
float: Session age in seconds since creation.
"""
return time.time() - self.created_at
@property
def idle_seconds(self) -> float:
"""Return seconds since last use.
Returns:
float: Seconds since last use of this session.
"""
return time.time() - self.last_used
@property
def is_closed(self) -> bool:
"""Return whether this session has been closed.
Returns:
bool: True if session is closed, False otherwise.
"""
return self._closed
def mark_closed(self) -> None:
"""Mark this session as closed."""
self._closed = True
# Type aliases
# Pool key includes transport type and gateway_id to prevent returning wrong transport for same URL
# and to ensure correct notification attribution when notifications are enabled
PoolKey = Tuple[str, str, str, str, str] # (user_identity_hash, url, identity_hash, transport_type, gateway_id)
# Session affinity mapping key: (mcp_session_id, url, transport_type, gateway_id)
SessionMappingKey = Tuple[str, str, str, str]
HttpxClientFactory = Callable[
[Optional[Dict[str, str]], Optional[httpx.Timeout], Optional[httpx.Auth]],
httpx.AsyncClient,
]
# Type alias for identity extractor callback
# Extracts stable identity from headers (e.g., decode JWT to get user_id)
IdentityExtractor = Callable[[Dict[str, str]], Optional[str]]
# Type alias for message handler factory
# Factory that creates message handlers given URL and optional gateway_id
# The handler receives ServerNotification, ServerRequest responders, or Exceptions
MessageHandlerFactory = Callable[
[str, Optional[str]], # (url, gateway_id)
Callable[
[RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] | mcp_types.ServerNotification | Exception],
Any, # Coroutine
],
]
class MCPSessionPool: # pylint: disable=too-many-instance-attributes
"""
Pool of MCP ClientSessions keyed by (user_identity, server URL, identity hash, transport type, gateway_id).
Thread-Safety:
This pool is designed for asyncio concurrency. It uses asyncio.Lock
for synchronization, which is safe for coroutine-based concurrency
but NOT for multi-threaded access.
Session Isolation:
Sessions are isolated per user/tenant to prevent session collision.
The identity hash is derived from authentication headers ensuring
that different users never share MCP sessions.
Transport Isolation:
Sessions are also isolated by transport type (SSE vs STREAMABLE_HTTP).
The same URL with different transports will use separate pools.
Gateway Isolation:
Sessions are isolated by gateway_id for correct notification attribution.
When notifications are enabled, each gateway gets its own pooled sessions
even if they share the same URL and authentication.
Features:
- Session reuse across requests (10-20x latency improvement)
- Per-user/tenant session isolation (prevents session collision)
- Per-transport session isolation (prevents transport mismatch)
- TTL-based expiration with configurable lifetime
- Health checks on acquire for stale sessions
- Configurable pool size per URL+identity+transport
- Circuit breaker for failing endpoints
- Idle pool key eviction to prevent unbounded growth
- Custom identity extractor for rotating tokens (e.g., JWT decode)
- Metrics for monitoring (hits, misses, evictions)
- Graceful shutdown with close_all()
Usage:
pool = MCPSessionPool()
# Use as context manager for lifecycle management
async with pool:
pooled = await pool.acquire(url, headers)
try:
result = await pooled.session.call_tool("my_tool", {})
finally:
await pool.release(pooled)
# With custom identity extractor for JWT tokens:
def extract_user_id(headers: dict) -> str:
token = headers.get("Authorization", "").replace("Bearer ", "")
claims = jwt.decode(token, options={"verify_signature": False})
return claims.get("sub") or claims.get("user_id")
pool = MCPSessionPool(identity_extractor=extract_user_id)
"""
# Headers that contribute to session identity (case-insensitive)
DEFAULT_IDENTITY_HEADERS: frozenset[str] = frozenset(
[
"authorization",
"x-tenant-id",
"x-user-id",
"x-api-key",
"cookie",
"x-mcp-session-id",
]
)
def __init__(
self,
max_sessions_per_key: int = 10,
session_ttl_seconds: float = 300.0,
health_check_interval_seconds: float = 60.0,
acquire_timeout_seconds: float = 30.0,
session_create_timeout_seconds: float = 30.0,
circuit_breaker_threshold: int = 5,
circuit_breaker_reset_seconds: float = 60.0,
identity_headers: Optional[frozenset[str]] = None,
identity_extractor: Optional[IdentityExtractor] = None,
idle_pool_eviction_seconds: float = 600.0,
default_transport_timeout_seconds: float = 30.0,
health_check_methods: Optional[list[str]] = None,
health_check_timeout_seconds: float = 5.0,
message_handler_factory: Optional[MessageHandlerFactory] = None,
):
"""
Initialize the session pool.
Args:
max_sessions_per_key: Maximum pooled sessions per (URL, identity, transport).
session_ttl_seconds: Session TTL in seconds before forced expiration.
health_check_interval_seconds: Seconds of idle time before health check.
acquire_timeout_seconds: Timeout for waiting when pool is exhausted.
session_create_timeout_seconds: Timeout for creating new sessions.
circuit_breaker_threshold: Consecutive failures before circuit opens.
circuit_breaker_reset_seconds: Seconds before circuit breaker resets.
identity_headers: Headers that contribute to identity hash.
identity_extractor: Optional callback to extract stable identity from headers.
Use this when tokens rotate frequently (e.g., short-lived JWTs).
Should return a stable user/tenant ID string.
idle_pool_eviction_seconds: Evict empty pool keys after this many seconds of no use.
default_transport_timeout_seconds: Default timeout for transport connections.
health_check_methods: Ordered list of health check methods to try.
Options: ping, list_tools, list_prompts, list_resources, skip.
Default: ["ping", "skip"] (try ping, skip if unsupported).
health_check_timeout_seconds: Timeout for each health check attempt.
message_handler_factory: Optional factory for creating message handlers.
Called with (url, gateway_id) to create handlers for
each new session. Enables notification handling.
"""
# Configuration
self._max_sessions = max_sessions_per_key
self._session_ttl = session_ttl_seconds
self._health_check_interval = health_check_interval_seconds
self._acquire_timeout = acquire_timeout_seconds
self._session_create_timeout = session_create_timeout_seconds
self._circuit_breaker_threshold = circuit_breaker_threshold
self._circuit_breaker_reset = circuit_breaker_reset_seconds
self._identity_headers = identity_headers or self.DEFAULT_IDENTITY_HEADERS
self._identity_extractor = identity_extractor
self._idle_pool_eviction = idle_pool_eviction_seconds
self._default_transport_timeout = default_transport_timeout_seconds
self._health_check_methods = health_check_methods or ["ping", "skip"]
self._health_check_timeout = health_check_timeout_seconds
self._message_handler_factory = message_handler_factory
# State - protected by _global_lock for creation, per-key locks for access
self._global_lock = asyncio.Lock()
self._pools: Dict[PoolKey, asyncio.Queue[PooledSession]] = {}
self._active: Dict[PoolKey, Set[PooledSession]] = {}
self._locks: Dict[PoolKey, asyncio.Lock] = {}
self._semaphores: Dict[PoolKey, asyncio.Semaphore] = {}
self._pool_last_used: Dict[PoolKey, float] = {} # Track last use time per pool key
# Circuit breaker state
self._failures: Dict[str, int] = {} # url -> consecutive failure count
self._circuit_open_until: Dict[str, float] = {} # url -> timestamp
# Eviction throttling - only run eviction once per interval
self._last_eviction_run: float = 0.0
self._eviction_run_interval: float = 60.0 # Run eviction at most every 60 seconds
# Metrics
self._hits = 0
self._misses = 0
self._evictions = 0
self._health_check_failures = 0
self._circuit_breaker_trips = 0
self._pool_keys_evicted = 0
self._sessions_reaped = 0 # Sessions closed during background eviction
self._anonymous_identity_count = 0 # Count of requests with no identity headers
# Lifecycle
self._closed = False
# Pre-registered session mappings for session affinity
# Mapping from (mcp_session_id, url, transport_type, gateway_id) -> pool_key
# Set by broadcast() before acquire() is called to enable session affinity lookup
self._mcp_session_mapping: Dict[SessionMappingKey, PoolKey] = {}
self._mcp_session_mapping_lock = asyncio.Lock()
# Multi-worker session affinity via Redis pub/sub
# Track pending responses for forwarded RPC requests
self._rpc_listener_task: Optional[asyncio.Task[None]] = None
self._pending_responses: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
# Session affinity metrics
self._session_affinity_local_hits = 0
self._session_affinity_redis_hits = 0
self._session_affinity_misses = 0
self._forwarded_requests = 0
self._forwarded_request_failures = 0
self._forwarded_request_timeouts = 0
async def __aenter__(self) -> "MCPSessionPool":
"""Async context manager entry.
Returns:
MCPSessionPool: This pool instance.
"""
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Async context manager exit - closes all sessions.
Args:
exc_type: Exception type if an exception was raised.
exc_val: Exception value if an exception was raised.
exc_tb: Exception traceback if an exception was raised.
"""
await self.close_all()
def _compute_identity_hash(self, headers: Optional[Dict[str, str]]) -> str:
"""
Compute a hash of identity-relevant headers.
This ensures sessions are isolated per user/tenant. Different users
with different Authorization headers will have different identity hashes
and thus separate session pools.
Identity resolution order:
1. Custom identity_extractor (if configured) - for rotating tokens like JWTs
2. x-mcp-session-id header (if present) - for session affinity, ensures
requests with the same downstream session ID get the same upstream
session even when JWT tokens rotate (different jti values)
3. Configured identity headers - fallback to hashing all identity headers
Args:
headers: Request headers dict.
Returns:
Identity hash string, or "anonymous" if no identity headers present.
"""
if not headers:
self._anonymous_identity_count += 1
logger.debug("Session pool identity collapsed to 'anonymous' (no headers provided). " + "Sessions will be shared. Ensure this is intentional for stateless MCP servers.")
return "anonymous"
# Try custom identity extractor first (for rotating tokens like JWTs)
if self._identity_extractor:
try:
extracted = self._identity_extractor(headers)
if extracted:
return hashlib.sha256(extracted.encode()).hexdigest()
except Exception as e:
logger.debug(f"Identity extractor failed, falling back to header hash: {e}")
# Normalize headers for case-insensitive lookup
headers_lower = {k.lower(): v for k, v in headers.items()}
# Session affinity: prioritize x-mcp-session-id for stable identity
# When present, use ONLY the session ID for identity hash. This ensures
# requests with the same downstream session ID get the same upstream session,
# even when JWT tokens rotate (different jti values per request).
if settings.mcpgateway_session_affinity_enabled:
session_id = headers_lower.get("x-mcp-session-id")
if session_id:
logger.debug(f"Using x-mcp-session-id for session affinity: {session_id[:8]}...")
return hashlib.sha256(session_id.encode()).hexdigest()
# Fallback: extract identity from configured headers
identity_parts = []
for header in sorted(self._identity_headers):
if header in headers_lower:
identity_parts.append(f"{header}:{headers_lower[header]}")
if not identity_parts:
self._anonymous_identity_count += 1
logger.debug(
"Session pool identity collapsed to 'anonymous' (no identity headers found). " + "Expected headers: %s. Sessions will be shared.",
list(self._identity_headers),
)
return "anonymous"
# Create a stable, deterministic hash using JSON serialization
# Prevents delimiter-collision or injection issues present in string joining
serialized_identity = orjson.dumps(identity_parts)
return hashlib.sha256(serialized_identity).hexdigest()
def _make_pool_key(
self,
url: str,
headers: Optional[Dict[str, str]],
transport_type: TransportType,
user_identity: str,
gateway_id: Optional[str] = None,
) -> PoolKey:
"""Create composite pool key from URL, identity, transport type, user identity, and gateway_id.
Including gateway_id ensures correct notification attribution when multiple gateways
share the same URL/auth. Sessions are isolated per gateway for proper event routing.
"""
identity_hash = self._compute_identity_hash(headers)
# Anonymize user identity by hashing it (unless it's commonly "anonymous")
# Use full hash for collision resistance - truncate only for display in logs/metrics
if user_identity == "anonymous":
user_hash = "anonymous"
else:
user_hash = hashlib.sha256(user_identity.encode()).hexdigest()
# Use empty string for None gateway_id to maintain consistent key type
gw_id = gateway_id or ""
return (user_hash, url, identity_hash, transport_type.value, gw_id)
async def _get_or_create_lock(self, pool_key: PoolKey) -> asyncio.Lock:
"""Get or create a lock for the given pool key (thread-safe)."""
async with self._global_lock:
if pool_key not in self._locks:
self._locks[pool_key] = asyncio.Lock()
return self._locks[pool_key]
async def _get_or_create_pool(self, pool_key: PoolKey) -> asyncio.Queue[PooledSession]:
"""Get or create a pool queue for the given key (thread-safe)."""
async with self._global_lock:
if pool_key not in self._pools:
self._pools[pool_key] = asyncio.Queue(maxsize=self._max_sessions)
self._active[pool_key] = set()
self._semaphores[pool_key] = asyncio.Semaphore(self._max_sessions)
return self._pools[pool_key]
def _is_circuit_open(self, url: str) -> bool:
"""Check if circuit breaker is open for a URL."""
if url not in self._circuit_open_until:
return False
if time.time() >= self._circuit_open_until[url]:
# Circuit breaker reset
del self._circuit_open_until[url]
self._failures[url] = 0
logger.info(f"Circuit breaker reset for {sanitize_url_for_logging(url)}")
return False
return True
def _record_failure(self, url: str) -> None:
"""Record a failure and potentially trip circuit breaker."""
self._failures[url] = self._failures.get(url, 0) + 1
if self._failures[url] >= self._circuit_breaker_threshold:
self._circuit_open_until[url] = time.time() + self._circuit_breaker_reset
self._circuit_breaker_trips += 1
logger.warning(f"Circuit breaker opened for {sanitize_url_for_logging(url)} after {self._failures[url]} failures. " f"Will reset in {self._circuit_breaker_reset}s")
def _record_success(self, url: str) -> None:
"""Record a success, resetting failure count."""
self._failures[url] = 0
@staticmethod
def is_valid_mcp_session_id(session_id: str) -> bool:
"""Validate downstream MCP session ID format for affinity.
Used for:
- Redis key construction (ownership + mapping)
- Pub/Sub channel naming
- Avoiding log spam / injection
"""
if not session_id:
return False
return bool(_MCP_SESSION_ID_PATTERN.match(session_id))
def _sanitize_redis_key_component(self, value: str) -> str:
"""Sanitize a value for use in Redis key construction.
Replaces any characters that could cause key collision or injection.
Args:
value: The value to sanitize.
Returns:
Sanitized value safe for Redis key construction.
"""
if not value:
return ""
# Replace problematic characters with underscores
return re.sub(r"[^a-zA-Z0-9_-]", "_", value)
def _session_mapping_redis_key(self, mcp_session_id: str, url: str, transport_type: str, gateway_id: str) -> str:
"""Compute a bounded Redis key for session mapping.
The URL is hashed to keep keys small and avoid special character issues.
"""
sanitized_session_id = self._sanitize_redis_key_component(mcp_session_id)
url_hash = hashlib.sha256(url.encode()).hexdigest()[:16]
return f"mcpgw:session_mapping:{sanitized_session_id}:{url_hash}:{transport_type}:{gateway_id}"
@staticmethod
def _pool_owner_key(mcp_session_id: str) -> str:
"""Return Redis key for session ownership tracking."""
return f"mcpgw:pool_owner:{mcp_session_id}"
async def register_session_mapping(
self,
mcp_session_id: str,
url: str,
gateway_id: str,
transport_type: str,
user_email: Optional[str] = None,
) -> None:
"""Pre-register session mapping for session affinity.
Called from respond() to set up mapping BEFORE acquire() is called.
This ensures acquire() can find the correct pool key for session affinity.
The mapping stores the relationship between an incoming MCP session ID
and the pool key that should be used for upstream connections. This
enables session affinity even when JWT tokens rotate (different jti values
per request).
For multi-worker deployments, the mapping is also stored in Redis with TTL
so that any worker can look it up during acquire().
Args:
mcp_session_id: The downstream MCP session ID from x-mcp-session-id header.
url: The upstream MCP server URL.
gateway_id: The gateway ID.
transport_type: The transport type (sse, streamablehttp).
user_email: The email of the authenticated user (or "system" for unauthenticated).
"""
if not settings.mcpgateway_session_affinity_enabled:
return
# Validate mcp_session_id to prevent Redis key injection
if not self.is_valid_mcp_session_id(mcp_session_id):
logger.warning(f"Invalid mcp_session_id format, skipping session mapping: {mcp_session_id[:20]}...")
return
# Use user email for user_identity, or "anonymous" if not provided
user_identity = user_email or "anonymous"
# Normalize gateway_id to empty string if None for consistent key matching
normalized_gateway_id = gateway_id or ""
mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type, normalized_gateway_id)
# Compute what the pool_key will be for this session
# Use mcp_session_id as the identity basis for affinity
identity_hash = hashlib.sha256(mcp_session_id.encode()).hexdigest()
# Hash user identity for privacy (unless it's "anonymous")
if user_identity == "anonymous":
user_hash = "anonymous"
else:
user_hash = hashlib.sha256(user_identity.encode()).hexdigest()
pool_key: PoolKey = (user_hash, url, identity_hash, transport_type, normalized_gateway_id)
# Store in local memory
async with self._mcp_session_mapping_lock:
self._mcp_session_mapping[mapping_key] = pool_key
logger.debug(f"Session affinity pre-registered (local): {mcp_session_id[:8]}... → {url}, user={user_identity}")
# Store in Redis for multi-worker support AND register ownership atomically
# Registering ownership HERE (during mapping) instead of in acquire() prevents
# a race condition where two workers could both start creating sessions before
# either registers ownership
try:
# First-Party
from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
redis = await get_redis_client()
if redis:
redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type, normalized_gateway_id)
# Store pool_key as JSON for easy deserialization
pool_key_data = {
"user_hash": user_hash,
"url": url,
"identity_hash": identity_hash,
"transport_type": transport_type,
"gateway_id": normalized_gateway_id,
}
await redis.setex(redis_key, settings.mcpgateway_session_affinity_ttl, orjson.dumps(pool_key_data)) # TTL from config
# CRITICAL: Register ownership atomically with mapping.
# This claims ownership BEFORE any session creation attempt, preventing
# the race condition where two workers both start creating sessions
owner_key = self._pool_owner_key(mcp_session_id)
# Atomic claim with TTL (avoids the SETNX/EXPIRE crash window).
was_set = await redis.set(owner_key, WORKER_ID, nx=True, ex=settings.mcpgateway_session_affinity_ttl)
if was_set:
logger.debug(f"Session ownership claimed (SET NX): {mcp_session_id[:8]}... → worker {WORKER_ID}")
else:
# Another worker already claimed ownership
existing_owner = await redis.get(owner_key)
owner_id = existing_owner.decode() if isinstance(existing_owner, bytes) else existing_owner
logger.debug(f"Session ownership already claimed by {owner_id}: {mcp_session_id[:8]}...")
logger.debug(f"Session affinity pre-registered (Redis): {mcp_session_id[:8]}... TTL={settings.mcpgateway_session_affinity_ttl}s")
except Exception as e:
# Redis failure is non-fatal - local mapping still works for same-worker requests
logger.debug(f"Failed to store session mapping in Redis: {e}")
async def acquire(
self,
url: str,
headers: Optional[Dict[str, str]] = None,
transport_type: TransportType = TransportType.STREAMABLE_HTTP,
httpx_client_factory: Optional[HttpxClientFactory] = None,
timeout: Optional[float] = None,
user_identity: Optional[str] = None,
gateway_id: Optional[str] = None,
) -> PooledSession:
"""
Acquire a session for the given URL, identity, and transport type.
Sessions are isolated by identity (derived from auth headers) AND
transport type. Returns an initialized, healthy session ready for tool calls.
Args:
url: The MCP server URL.
headers: Request headers (used for identity hashing and passed to server).
transport_type: The transport type (SSE or STREAMABLE_HTTP).
httpx_client_factory: Optional factory for creating httpx clients
(for custom SSL/timeout configuration).
timeout: Optional timeout in seconds for transport connection.
gateway_id: Optional gateway ID for notification handler context.
Returns:
PooledSession ready for use.
Raises:
asyncio.TimeoutError: If acquire times out waiting for available session.
RuntimeError: If pool is closed or circuit breaker is open.
Exception: If session creation fails.
"""
if self._closed:
raise RuntimeError("Session pool is closed")
if self._is_circuit_open(url):
raise RuntimeError(f"Circuit breaker open for {url}")
# Use default timeout if not provided
effective_timeout = timeout if timeout is not None else self._default_transport_timeout
user_id = user_identity or "anonymous"
pool_key: Optional[PoolKey] = None
# Check pre-registered mapping first (set by respond() for session affinity)
if settings.mcpgateway_session_affinity_enabled and headers:
headers_lower = {k.lower(): v for k, v in headers.items()}
mcp_session_id = headers_lower.get("x-mcp-session-id")
if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
normalized_gateway_id = gateway_id or ""
mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type.value, normalized_gateway_id)
# Check local memory first (fast path - same worker)
async with self._mcp_session_mapping_lock:
pool_key = self._mcp_session_mapping.get(mapping_key)
if pool_key:
self._session_affinity_local_hits += 1
logger.debug(f"Session affinity hit (local): {mcp_session_id[:8]}...")
# If not in local memory, check Redis (multi-worker support)
if pool_key is None:
try:
# First-Party
from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
redis = await get_redis_client()
if redis:
redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type.value, normalized_gateway_id)
pool_key_data = await redis.get(redis_key)
if pool_key_data:
# Deserialize pool_key from JSON
data = orjson.loads(pool_key_data)
pool_key = (
data["user_hash"],
data["url"],
data["identity_hash"],
data["transport_type"],
data["gateway_id"],
)
# Cache in local memory for future requests
async with self._mcp_session_mapping_lock:
self._mcp_session_mapping[mapping_key] = pool_key
self._session_affinity_redis_hits += 1
logger.debug(f"Session affinity hit (Redis): {mcp_session_id[:8]}...")
except Exception as e:
logger.debug(f"Failed to check Redis for session mapping: {e}")
# Fallback to normal pool key computation
if pool_key is None:
self._session_affinity_misses += 1
pool_key = self._make_pool_key(url, headers, transport_type, user_id, gateway_id)
pool = await self._get_or_create_pool(pool_key)
# Update pool key last used time IMMEDIATELY after getting pool
# This prevents race with eviction removing keys between awaits
self._pool_last_used[pool_key] = time.time()
lock = await self._get_or_create_lock(pool_key)
# Guard semaphore access - eviction may have removed it between awaits
# If so, re-create the pool structures
if pool_key not in self._semaphores:
pool = await self._get_or_create_pool(pool_key)
self._pool_last_used[pool_key] = time.time()
semaphore = self._semaphores[pool_key]
# Throttled eviction - only run if enough time has passed (inline, not spawned)
await self._maybe_evict_idle_pool_keys()
# Try to get from pool first (quick path, no lock needed for queue get)
while True:
try:
pooled = pool.get_nowait()
except asyncio.QueueEmpty:
break
# Validate the session outside the lock
if await self._validate_session(pooled):
pooled.last_used = time.time()
pooled.use_count += 1
self._hits += 1
async with lock:
self._active[pool_key].add(pooled)
logger.debug(f"Pool hit for {sanitize_url_for_logging(url)} (identity={pool_key[2][:8]}, transport={transport_type.value})")
return pooled
# Session invalid, close it
await self._close_session(pooled)
self._evictions += 1
semaphore.release() # Free up a slot
# No valid session in pool - try to create one or wait
try:
# Use semaphore with timeout to limit concurrent sessions
acquired = await asyncio.wait_for(semaphore.acquire(), timeout=self._acquire_timeout)
if not acquired:
raise asyncio.TimeoutError("Failed to acquire session slot")
except asyncio.TimeoutError:
raise asyncio.TimeoutError(f"Timeout waiting for available session for {sanitize_url_for_logging(url)}") from None
# Create new session (semaphore acquired)
try:
# Verify we own this session before creating (prevents race condition)
# If another worker already claimed ownership, we should not create a new session
# Note: Ownership is registered atomically in register_session_mapping() using SETNX
if settings.mcpgateway_session_affinity_enabled and headers:
headers_lower = {k.lower(): v for k, v in headers.items()}
mcp_session_id = headers_lower.get("x-mcp-session-id")
if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
owner = await self._get_pool_session_owner(mcp_session_id)
if owner and owner != WORKER_ID:
# Another worker claimed ownership - should have been forwarded
# Release semaphore and raise to trigger forwarding
semaphore.release()
logger.warning(f"Session {mcp_session_id[:8]}... owned by worker {owner}, not us ({WORKER_ID})")
raise RuntimeError(f"Session owned by another worker: {owner}")
pooled = await asyncio.wait_for(
self._create_session(url, headers, transport_type, httpx_client_factory, effective_timeout, gateway_id),
timeout=self._session_create_timeout,
)
# Store identity components for key reconstruction
pooled.identity_key = pool_key[2]
pooled.user_identity = user_id
# Note: Ownership is now registered atomically in register_session_mapping()
# before acquire() is called, so we don't need to register it here
self._misses += 1
self._record_success(url)
async with lock:
self._active[pool_key].add(pooled)
logger.debug(f"Pool miss for {sanitize_url_for_logging(url)} - created new session (transport={transport_type.value})")
return pooled
except BaseException as e:
# Release semaphore on ANY failure (including CancelledError)
semaphore.release()
if not isinstance(e, asyncio.CancelledError):
self._record_failure(url)
logger.warning(f"Failed to create session for {sanitize_url_for_logging(url)}: {e}")
raise
async def release(self, pooled: PooledSession) -> None:
"""
Return a session to the pool for reuse.
Args:
pooled: The session to release.
"""
if pooled.is_closed:
logger.warning("Attempted to release already-closed session")
return
# Pool key includes transport type, user identity, and gateway_id
# Re-compute user hash from stored raw identity (full hash for collision resistance)
user_hash = "anonymous"
if pooled.user_identity != "anonymous":
user_hash = hashlib.sha256(pooled.user_identity.encode()).hexdigest()
pool_key = (user_hash, pooled.url, pooled.identity_key, pooled.transport_type.value, pooled.gateway_id)
lock = await self._get_or_create_lock(pool_key)
pool = await self._get_or_create_pool(pool_key)
async with lock:
# Update last-used FIRST to prevent eviction race:
# If eviction runs between removing from _active and putting back in pool,
# it would see key as idle + inactive and evict it. By updating last-used
# while still holding the lock and before removing from _active, we ensure
# eviction sees recent activity.
self._pool_last_used[pool_key] = time.time()
self._active.get(pool_key, set()).discard(pooled)
# Check if session should be returned to pool
if self._closed or pooled.age_seconds > self._session_ttl:
await self._close_session(pooled)
if pool_key in self._semaphores:
self._semaphores[pool_key].release()
if pooled.age_seconds > self._session_ttl:
self._evictions += 1
return
# Return to pool (pool may have been evicted in edge case, recreate if needed)
if pool_key not in self._pools:
pool = await self._get_or_create_pool(pool_key)
self._pool_last_used[pool_key] = time.time()
try:
pool.put_nowait(pooled)
logger.debug(f"Session returned to pool for {sanitize_url_for_logging(pooled.url)}")
except asyncio.QueueFull:
# Pool full (shouldn't happen with semaphore), close session
await self._close_session(pooled)
if pool_key in self._semaphores:
self._semaphores[pool_key].release()
async def _maybe_evict_idle_pool_keys(self) -> None:
"""
Reap stale sessions and evict idle pool keys.
This method is throttled - it only runs eviction if enough time has
passed since the last run (default: 60 seconds). This prevents:
- Unbounded task spawning on every acquire
- Lock contention under high load
Two-phase cleanup:
1. Close expired/stale sessions parked in idle pools (frees connections)
2. Evict pool keys that are now empty and have no active sessions
This prevents unbounded connection and pool key growth when using
rotating tokens (e.g., short-lived JWTs with unique identifiers).
"""
if self._closed:
return
now = time.time()
# Throttle: only run eviction once per interval
if now - self._last_eviction_run < self._eviction_run_interval:
return
self._last_eviction_run = now
# Collect sessions to close and keys to evict (minimize time holding lock)
sessions_to_close: list[PooledSession] = []
keys_to_evict: list[PoolKey] = []
async with self._global_lock:
for pool_key, last_used in list(self._pool_last_used.items()):
# Skip recently-used pools
if now - last_used < self._idle_pool_eviction:
continue
pool = self._pools.get(pool_key)
active = self._active.get(pool_key, set())
# Skip if there are active sessions (in use)
if active:
continue
if pool:
# Phase 1: Drain and collect expired/stale sessions from idle pools
while not pool.empty():
try:
session = pool.get_nowait()
# Close if expired OR idle too long (defense in depth)
if session.age_seconds > self._session_ttl or session.idle_seconds > self._idle_pool_eviction:
sessions_to_close.append(session)
# Release semaphore slot for this session
if pool_key in self._semaphores:
self._semaphores[pool_key].release()
else:
# Session still valid, put it back
pool.put_nowait(session)
break # Stop draining if we find a valid session
except asyncio.QueueEmpty:
break
# Phase 2: Evict pool key if now empty
if pool.empty():
keys_to_evict.append(pool_key)
# Remove evicted keys from all tracking dicts
for pool_key in keys_to_evict:
self._pools.pop(pool_key, None)
self._active.pop(pool_key, None)
self._locks.pop(pool_key, None)
self._semaphores.pop(pool_key, None)
self._pool_last_used.pop(pool_key, None)
self._pool_keys_evicted += 1
logger.debug(f"Evicted idle pool key: {pool_key[0][:8]}|{pool_key[1]}|{pool_key[2][:8]}")
# Close sessions outside the lock (I/O operations)
for session in sessions_to_close:
await self._close_session(session)
self._sessions_reaped += 1
logger.debug(f"Reaped stale session for {sanitize_url_for_logging(session.url)} (age={session.age_seconds:.1f}s)")
async def _validate_session(self, pooled: PooledSession) -> bool:
"""
Validate a session is still usable.
Checks TTL and performs health check if session is stale.
Args:
pooled: The session to validate.
Returns:
True if session is valid, False otherwise.
"""
if pooled.is_closed:
return False
# Check TTL
if pooled.age_seconds > self._session_ttl:
logger.debug(f"Session expired (age={pooled.age_seconds:.1f}s)")
return False
# Health check if stale
if pooled.idle_seconds > self._health_check_interval:
return await self._run_health_check_chain(pooled)
return True
async def _run_health_check_chain(self, pooled: PooledSession) -> bool:
"""
Run health check methods in configured order until one succeeds.
The health check chain allows configuring which methods to try and in what order.
This supports both modern servers (with ping support) and legacy servers
(that may only support list_tools or no health check at all).
Args:
pooled: The session to health check.
Returns:
True if any health check method succeeds, False if all fail.
"""
for method in self._health_check_methods:
try:
if method == "ping":
await asyncio.wait_for(pooled.session.send_ping(), timeout=self._health_check_timeout)
logger.debug(f"Health check passed: ping (url={sanitize_url_for_logging(pooled.url)})")
return True
if method == "list_tools":
await asyncio.wait_for(pooled.session.list_tools(), timeout=self._health_check_timeout)
logger.debug(f"Health check passed: list_tools (url={sanitize_url_for_logging(pooled.url)})")
return True
if method == "list_prompts":
await asyncio.wait_for(pooled.session.list_prompts(), timeout=self._health_check_timeout)
logger.debug(f"Health check passed: list_prompts (url={sanitize_url_for_logging(pooled.url)})")
return True
if method == "list_resources":
await asyncio.wait_for(pooled.session.list_resources(), timeout=self._health_check_timeout)
logger.debug(f"Health check passed: list_resources (url={sanitize_url_for_logging(pooled.url)})")
return True
if method == "skip":
logger.debug(f"Health check skipped per configuration (url={sanitize_url_for_logging(pooled.url)})")
return True
logger.warning(f"Unknown health check method '{method}', skipping")
continue
except McpError as e:
# METHOD_NOT_FOUND (-32601) means the method isn't supported - try next
if e.error.code == METHOD_NOT_FOUND:
logger.debug(f"Health check method '{method}' not supported by server, trying next")
continue
# Other MCP errors are real failures
logger.debug(f"Health check '{method}' failed with MCP error: {e}")
self._health_check_failures += 1
return False
except asyncio.TimeoutError:
logger.debug(f"Health check '{method}' timed out after {self._health_check_timeout}s, trying next")
continue
except Exception as e:
logger.debug(f"Health check '{method}' failed: {e}")
self._health_check_failures += 1
return False
# All methods failed or were unsupported
logger.warning(f"All health check methods failed or unsupported (methods={self._health_check_methods})")
self._health_check_failures += 1
return False
async def _create_session(
self,
url: str,
headers: Optional[Dict[str, str]],
transport_type: TransportType,
httpx_client_factory: Optional[HttpxClientFactory],
timeout: Optional[float] = None,
gateway_id: Optional[str] = None,
) -> PooledSession:
"""
Create a new initialized MCP session.
Args:
url: Server URL.
headers: Request headers.
transport_type: Transport type to use.
httpx_client_factory: Optional factory for httpx clients.
timeout: Optional timeout in seconds for transport connection.
gateway_id: Optional gateway ID for notification handler context.
Returns:
Initialized PooledSession.
Raises:
RuntimeError: If session creation or initialization fails.
asyncio.CancelledError: If cancelled during creation.
"""
# Merge headers with defaults
merged_headers = {"Accept": "application/json, text/event-stream"}
if headers:
merged_headers.update(headers)
# Strip gateway-internal session affinity headers before sending to upstream
# x-mcp-session-id is our internal representation, mcp-session-id is the MCP protocol header
# Neither should be forwarded to upstream servers
keys_to_remove = [k for k in merged_headers if k.lower() in ("x-mcp-session-id", "mcp-session-id")]
for k in keys_to_remove:
del merged_headers[k]
identity_key = self._compute_identity_hash(headers)
transport_ctx = None
session = None
success = False
try:
# Create transport context
if transport_type == TransportType.SSE:
if httpx_client_factory:
transport_ctx = sse_client(url=url, headers=merged_headers, httpx_client_factory=httpx_client_factory, timeout=timeout)
else:
transport_ctx = sse_client(url=url, headers=merged_headers, timeout=timeout)
# pylint: disable=unnecessary-dunder-call,no-member
streams = await transport_ctx.__aenter__() # Must call directly for manual lifecycle management
read_stream, write_stream = streams[0], streams[1]
else: # STREAMABLE_HTTP
if httpx_client_factory:
transport_ctx = streamablehttp_client(url=url, headers=merged_headers, httpx_client_factory=httpx_client_factory, timeout=timeout)
else:
transport_ctx = streamablehttp_client(url=url, headers=merged_headers, timeout=timeout)
# pylint: disable=unnecessary-dunder-call,no-member
read_stream, write_stream, _ = await transport_ctx.__aenter__() # Must call directly for manual lifecycle management
# Create message handler if factory is configured
message_handler = None
if self._message_handler_factory:
try:
message_handler = self._message_handler_factory(url, gateway_id)
logger.debug(f"Created message handler for session {sanitize_url_for_logging(url)} (gateway={gateway_id})")
except Exception as e:
logger.warning(f"Failed to create message handler for {sanitize_url_for_logging(url)}: {e}")
# Create and initialize session
session = ClientSession(read_stream, write_stream, message_handler=message_handler)
# pylint: disable=unnecessary-dunder-call
await session.__aenter__() # Must call directly for manual lifecycle management
await session.initialize()
logger.info(f"Created new MCP session for {sanitize_url_for_logging(url)} (transport={transport_type.value})")
success = True
return PooledSession(
session=session,
transport_context=transport_ctx,
url=url,
transport_type=transport_type,
headers=merged_headers,
identity_key=identity_key,
gateway_id=gateway_id or "",
)
except asyncio.CancelledError: # pylint: disable=try-except-raise
# Re-raise CancelledError after cleanup (handled in finally)
raise
except Exception as e:
raise RuntimeError(f"Failed to create MCP session for {url}: {e}") from e
finally:
# Clean up on ANY failure (Exception, CancelledError, etc.)
# Only clean up if we didn't succeed
# Use anyio.move_on_after instead of asyncio.wait_for to properly propagate
# cancellation through anyio's cancel scope system (prevents orphaned spinning tasks)
if not success:
cleanup_timeout = _get_cleanup_timeout()
if session is not None:
with anyio.move_on_after(cleanup_timeout):
try:
await session.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call
except Exception: # nosec B110 - Best effort cleanup on connection failure
pass
if transport_ctx is not None:
with anyio.move_on_after(cleanup_timeout):
try:
await transport_ctx.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call
except Exception: # nosec B110 - Best effort cleanup on connection failure
pass
async def _close_session(self, pooled: PooledSession) -> None:
"""
Close a session and its transport.
Uses timeouts to prevent indefinite blocking if session/transport tasks
don't respond to cancellation. This prevents CPU spin loops in anyio's
_deliver_cancellation which can occur when async iterators or blocking
operations don't properly handle CancelledError.
Args:
pooled: The session to close.
"""
if pooled.is_closed:
return
pooled.mark_closed()
# Use anyio's move_on_after instead of asyncio.wait_for to properly propagate
# cancellation through anyio's cancel scope system. asyncio.wait_for() creates
# orphaned anyio tasks that keep spinning in _deliver_cancellation.
cleanup_timeout = _get_cleanup_timeout()
# Close session with anyio timeout
with anyio.move_on_after(cleanup_timeout) as session_scope:
try:
await pooled.session.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call
except Exception as e:
logger.debug(f"Error closing session: {e}")
if session_scope.cancelled_caught:
logger.warning(f"Session cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway")
# Close transport with anyio timeout
with anyio.move_on_after(cleanup_timeout) as transport_scope:
try:
await pooled.transport_context.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call
except Exception as e:
logger.debug(f"Error closing transport: {e}")
if transport_scope.cancelled_caught:
logger.warning(f"Transport cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway")
logger.debug(f"Closed session for {sanitize_url_for_logging(pooled.url)} (uses={pooled.use_count})")
# Clean up pool_owner key in Redis for session affinity
if settings.mcpgateway_session_affinity_enabled and pooled.headers:
headers_lower = {k.lower(): v for k, v in pooled.headers.items()}
mcp_session_id = headers_lower.get("x-mcp-session-id")
if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
await self._cleanup_pool_session_owner(mcp_session_id)
async def _cleanup_pool_session_owner(self, mcp_session_id: str) -> None:
"""Clean up pool_owner key in Redis when session is closed.
Only deletes the key if this worker owns it (to prevent removing other workers' ownership).
Args:
mcp_session_id: The MCP session ID from x-mcp-session-id header.
"""
try:
# First-Party
from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
redis = await get_redis_client()
if redis:
key = self._pool_owner_key(mcp_session_id)
# Only delete if we own it
owner = await redis.get(key)
if owner:
owner_id = owner.decode() if isinstance(owner, bytes) else owner
if owner_id == WORKER_ID:
await redis.delete(key)
logger.debug(f"Cleaned up pool session owner: {mcp_session_id[:8]}...")
except Exception as e:
# Cleanup failure is non-fatal
logger.debug(f"Failed to cleanup pool session owner in Redis: {e}")
async def close_all(self) -> None:
"""
Gracefully close all pooled and active sessions.
Should be called during application shutdown.
"""
self._closed = True
logger.info("Closing all pooled sessions...")
async with self._global_lock:
# Close all pooled sessions
for _pool_key, pool in list(self._pools.items()):
while not pool.empty():
try:
pooled = pool.get_nowait()
await self._close_session(pooled)
except asyncio.QueueEmpty:
break
# Close all active sessions
for _pool_key, active_set in list(self._active.items()):
for pooled in list(active_set):
await self._close_session(pooled)
self._pools.clear()
self._active.clear()
self._locks.clear()
self._semaphores.clear()
# Stop RPC listener if running
if self._rpc_listener_task and not self._rpc_listener_task.done():
self._rpc_listener_task.cancel()
try:
await self._rpc_listener_task
except asyncio.CancelledError:
pass
self._rpc_listener_task = None
logger.info("All sessions closed")
async def register_pool_session_owner(self, mcp_session_id: str) -> None:
"""Register this worker as owner of a pool session in Redis.
This enables multi-worker session affinity by tracking which worker owns
which pool session. When a request with x-mcp-session-id arrives at a
different worker, it can forward the request to the owner worker.
Note: This method is now primarily used for refreshing TTL on existing ownership.
Initial ownership is claimed atomically in register_session_mapping() using SETNX.
Args:
mcp_session_id: The MCP session ID from x-mcp-session-id header.
"""
if not settings.mcpgateway_session_affinity_enabled:
return
if not self.is_valid_mcp_session_id(mcp_session_id):
logger.debug("Invalid mcp_session_id for owner registration, skipping")
return
try:
# First-Party
from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
redis = await get_redis_client()
if redis:
key = self._pool_owner_key(mcp_session_id)
# Do not steal ownership: only claim if missing, or refresh TTL if we already own.
# Lua keeps this atomic.
script = """
local cur = redis.call('GET', KEYS[1])
if not cur then
redis.call('SET', KEYS[1], ARGV[1], 'EX', ARGV[2])
return 1
end
if cur == ARGV[1] then
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 2
end
return 0
"""
ttl = int(settings.mcpgateway_session_affinity_ttl)
outcome = await redis.eval(script, 1, key, WORKER_ID, ttl)
logger.debug(f"Owner registration outcome={outcome} for session {mcp_session_id[:8]}...")
except Exception as e:
# Redis failure is non-fatal - single worker mode still works
logger.debug(f"Failed to register pool session owner in Redis: {e}")
async def _get_pool_session_owner(self, mcp_session_id: str) -> Optional[str]:
"""Get the worker ID that owns a pool session.
Args:
mcp_session_id: The MCP session ID from x-mcp-session-id header.
Returns:
The worker ID that owns this session, or None if not found or Redis unavailable.
"""
if not settings.mcpgateway_session_affinity_enabled:
return None
if not self.is_valid_mcp_session_id(mcp_session_id):
return None
try:
# First-Party
from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
redis = await get_redis_client()
if redis:
key = self._pool_owner_key(mcp_session_id)
owner = await redis.get(key)
if owner:
decoded = owner.decode() if isinstance(owner, bytes) else owner
return decoded
except Exception as e:
logger.debug(f"Failed to get pool session owner from Redis: {e}")
return None
async def forward_request_to_owner(
self,
mcp_session_id: str,
request_data: Dict[str, Any],
timeout: Optional[float] = None,
) -> Optional[Dict[str, Any]]:
"""Forward RPC request to the worker that owns the pool session.
This method checks Redis to find which worker owns the pool session for
the given mcp_session_id. If owned by another worker, it forwards the
request via Redis pub/sub and waits for the response.
Args:
mcp_session_id: The MCP session ID from x-mcp-session-id header.
request_data: The RPC request data to forward.
timeout: Optional timeout in seconds (default from config).
Returns:
The response from the owner worker, or None if we own the session
(caller should execute locally) or if Redis is unavailable.
Raises:
asyncio.TimeoutError: If the forwarded request times out.
"""
if not settings.mcpgateway_session_affinity_enabled:
return None
if not self.is_valid_mcp_session_id(mcp_session_id):
return None
effective_timeout = timeout if timeout is not None else settings.mcpgateway_pool_rpc_forward_timeout
try:
# First-Party
from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
redis = await get_redis_client()
if not redis:
return None # Execute locally - no Redis
# Check who owns this session
owner = await redis.get(self._pool_owner_key(mcp_session_id))
method = request_data.get("method", "unknown")
if not owner:
logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | No owner → execute locally (new session)")
return None # No owner registered - execute locally (new session)
owner_id = owner.decode() if isinstance(owner, bytes) else owner
if owner_id == WORKER_ID:
logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | We own it → execute locally")
return None # We own it - execute locally
logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Owner: {owner_id} → forwarding")
# Forward to owner worker via pub/sub
response_id = str(uuid.uuid4())
response_channel = f"mcpgw:pool_rpc_response:{response_id}"
# Subscribe to response channel
pubsub = redis.pubsub()
await pubsub.subscribe(response_channel)
try:
# Prepare request with response channel
forward_data = {
"type": "rpc_forward",
**request_data,
"response_channel": response_channel,
"mcp_session_id": mcp_session_id,
}
# Publish request to owner's channel
await redis.publish(f"mcpgw:pool_rpc:{owner_id}", orjson.dumps(forward_data))
self._forwarded_requests += 1
logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Published to worker {owner_id}")
# Wait for response
async with asyncio.timeout(effective_timeout):
while True:
msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
if msg and msg["type"] == "message":
return orjson.loads(msg["data"])
finally:
await pubsub.unsubscribe(response_channel)
except asyncio.TimeoutError:
self._forwarded_request_timeouts += 1
logger.warning(f"Timeout forwarding request to owner for session {mcp_session_id[:8]}...")
raise
except Exception as e:
self._forwarded_request_failures += 1
logger.debug(f"Error forwarding request to owner: {e}")
return None # Execute locally on error
async def start_rpc_listener(self) -> None:
"""Start listening for forwarded RPC and HTTP requests on this worker's channels.
This method subscribes to Redis pub/sub channels specific to this worker
and processes incoming forwarded requests from other workers:
- mcpgw:pool_rpc:{WORKER_ID} - for SSE transport JSON-RPC forwards
- mcpgw:pool_http:{WORKER_ID} - for Streamable HTTP request forwards
"""
if not settings.mcpgateway_session_affinity_enabled:
return
try:
# First-Party
from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
redis = await get_redis_client()
if not redis:
logger.debug("Redis not available, RPC listener not started")
return
rpc_channel = f"mcpgw:pool_rpc:{WORKER_ID}"
http_channel = f"mcpgw:pool_http:{WORKER_ID}"
pubsub = redis.pubsub()
await pubsub.subscribe(rpc_channel, http_channel)
logger.info(f"RPC/HTTP listener started for worker {WORKER_ID} on channels: {rpc_channel}, {http_channel}")
while not self._closed:
try:
msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if msg and msg["type"] == "message":
request = orjson.loads(msg["data"])
forward_type = request.get("type")
response_channel = request.get("response_channel")
if response_channel:
if forward_type == "rpc_forward":
# Execute forwarded RPC request for SSE transport
response = await self._execute_forwarded_request(request)
await redis.publish(response_channel, orjson.dumps(response))
logger.debug(f"Processed forwarded RPC request, response sent to {response_channel}")
elif forward_type == "http_forward":
# Execute forwarded HTTP request for Streamable HTTP transport
await self._execute_forwarded_http_request(request, redis)
else:
logger.warning(f"Unknown forward type: {forward_type}")
except asyncio.CancelledError:
break
except Exception as e:
logger.warning(f"Error processing forwarded request: {e}")
await pubsub.unsubscribe(rpc_channel, http_channel)
logger.info(f"RPC/HTTP listener stopped for worker {WORKER_ID}")
except Exception as e:
logger.warning(f"RPC/HTTP listener failed: {e}")
async def _execute_forwarded_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Execute a forwarded RPC request locally via internal HTTP call.
This method handles RPC requests that were forwarded from another worker.
Instead of handling specific methods here, we make an internal HTTP call
to the local /rpc endpoint which reuses ALL existing method handling logic.
The x-forwarded-internally header prevents infinite forwarding loops.
Args:
request: The forwarded RPC request containing method, params, headers, req_id, etc.
Returns:
The JSON-RPC response from the local endpoint.
"""
try:
method = request.get("method")
params = request.get("params", {})
headers = request.get("headers", {})
req_id = request.get("req_id", 1)
mcp_session_id = request.get("mcp_session_id", "unknown")
session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Received forwarded request, executing locally")
# Make internal HTTP call to local /rpc endpoint
# This reuses ALL existing method handling logic without duplication
async with httpx.AsyncClient() as client:
# Build headers for internal request - forward original headers
# but add x-forwarded-internally to prevent infinite loops
internal_headers = dict(headers)
internal_headers["x-forwarded-internally"] = "true"
# Ensure content-type is set
internal_headers["content-type"] = "application/json"
response = await client.post(
f"http://127.0.0.1:{settings.port}/rpc",
json={"jsonrpc": "2.0", "method": method, "params": params, "id": req_id},
headers=internal_headers,
timeout=settings.mcpgateway_pool_rpc_forward_timeout,
)
# Parse response
response_data = response.json()
# Extract result or error from JSON-RPC response
if "error" in response_data:
logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed with error")
return {"error": response_data["error"]}
logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed successfully")
return {"result": response_data.get("result", {})}
except httpx.TimeoutException:
logger.warning(f"Timeout executing forwarded request: {request.get('method')}")
return {"error": {"code": -32603, "message": "Internal request timeout"}}
except Exception as e:
logger.warning(f"Error executing forwarded request: {e}")
return {"error": {"code": -32603, "message": str(e)}}
async def _execute_forwarded_http_request(self, request: Dict[str, Any], redis: Any) -> None:
"""Execute a forwarded HTTP request locally and return response via Redis.
This method handles full HTTP requests forwarded from other workers for
Streamable HTTP transport session affinity. It reconstructs the HTTP request,
makes an internal call to the appropriate endpoint, and publishes the response
back through Redis.
Args:
request: Serialized HTTP request data from Redis Pub/Sub containing:
- type: "http_forward"
- response_channel: Redis channel to publish response to
- mcp_session_id: Session identifier
- method: HTTP method (GET, POST, DELETE)
- path: Request path (e.g., /mcp)
- query_string: Query parameters
- headers: Request headers dict
- body: Hex-encoded request body
redis: Redis client for publishing response
"""
response_channel = None
try:
response_channel = request.get("response_channel")
method = request.get("method")
path = request.get("path")
query_string = request.get("query_string", "")
headers = request.get("headers", {})
body_hex = request.get("body", "")
mcp_session_id = request.get("mcp_session_id")
# Decode hex body back to bytes
body = bytes.fromhex(body_hex) if body_hex else b""
session_short = mcp_session_id[:8] if mcp_session_id and len(mcp_session_id) >= 8 else "unknown"
logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Received forwarded HTTP request: {method} {path}")
# Add internal forwarding headers to prevent loops
internal_headers = dict(headers)
internal_headers["x-forwarded-internally"] = "true"
internal_headers["x-original-worker"] = request.get("original_worker", "unknown")
# Make internal HTTP request to local endpoint
url = f"http://127.0.0.1:{settings.port}{path}"
if query_string:
url = f"{url}?{query_string}"
async with httpx.AsyncClient() as client:
response = await client.request(
method=method,
url=url,
headers=internal_headers,
content=body,
timeout=settings.mcpgateway_pool_rpc_forward_timeout,
)
logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Executed locally: {response.status_code}")
# Serialize response for Redis transport
response_data = {
"status": response.status_code,
"headers": dict(response.headers),
"body": response.content.hex(), # Hex encode binary response
}
# Publish response back to requesting worker
if redis and response_channel:
await redis.publish(response_channel, orjson.dumps(response_data))
logger.debug(f"[HTTP_AFFINITY] Published HTTP response to Redis channel: {response_channel}")
except Exception as e:
logger.error(f"Error executing forwarded HTTP request: {e}")
# Try to send error response if possible
if redis and response_channel:
error_response = {
"status": 500,
"headers": {"content-type": "application/json"},
"body": orjson.dumps({"error": "Internal forwarding error"}).hex(),
}
try:
await redis.publish(response_channel, orjson.dumps(error_response))
except Exception as publish_error:
logger.debug(f"Failed to publish error response via Redis: {publish_error}")
async def get_streamable_http_session_owner(self, mcp_session_id: str) -> Optional[str]:
"""Get the worker ID that owns a Streamable HTTP session.
This is a public wrapper around _get_pool_session_owner for use by
streamablehttp_transport to check session ownership before handling requests.
Args:
mcp_session_id: The MCP session ID from mcp-session-id header.
Returns:
Worker ID if found, None otherwise.
"""
return await self._get_pool_session_owner(mcp_session_id)
async def forward_streamable_http_to_owner(
self,
owner_worker_id: str,
mcp_session_id: str,
method: str,
path: str,
headers: Dict[str, str],
body: bytes,
query_string: str = "",
) -> Optional[Dict[str, Any]]:
"""Forward a Streamable HTTP request to the worker that owns the session via Redis Pub/Sub.
This method forwards the entire HTTP request to another worker using Redis
Pub/Sub channels, similar to forward_request_to_owner() for SSE transport.
This ensures session affinity works correctly in single-host multi-worker
deployments where hostname-based routing fails.
Args:
owner_worker_id: The worker ID that owns the session.
mcp_session_id: The MCP session ID.
method: HTTP method (GET, POST, DELETE).
path: Request path (e.g., /mcp).
headers: Request headers.
body: Request body bytes.
query_string: Query string if any.
Returns:
Dict with 'status', 'headers', and 'body' from the owner worker's response,
or None if forwarding fails.
"""
if not settings.mcpgateway_session_affinity_enabled:
return None
if not self.is_valid_mcp_session_id(mcp_session_id):
return None
session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | {method} {path} | Forwarding to worker {owner_worker_id}")
try:
# First-Party
from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
redis = await get_redis_client()
if not redis:
logger.warning("Redis unavailable for HTTP forwarding, executing locally")
return None # Fall back to local execution
# Generate unique response channel for this request
response_uuid = uuid.uuid4().hex
response_channel = f"mcpgw:pool_http_response:{response_uuid}"
# Serialize HTTP request for Redis transport
forward_data = {
"type": "http_forward",
"response_channel": response_channel,
"mcp_session_id": mcp_session_id,
"method": method,
"path": path,
"query_string": query_string,
"headers": headers,
"body": body.hex() if body else "", # Hex encode binary body
"original_worker": WORKER_ID,
"timestamp": time.time(),
}
# Subscribe to response channel BEFORE publishing request (prevent race)
pubsub = redis.pubsub()
await pubsub.subscribe(response_channel)
try:
# Publish forwarded request to owner worker's HTTP channel
owner_channel = f"mcpgw:pool_http:{owner_worker_id}"
await redis.publish(owner_channel, orjson.dumps(forward_data))
logger.debug(f"[HTTP_AFFINITY] Published HTTP request to Redis channel: {owner_channel}")
# Wait for response with timeout
timeout = settings.mcpgateway_pool_rpc_forward_timeout
async with asyncio.timeout(timeout):
while True:
msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
if msg and msg["type"] == "message":
response_data = orjson.loads(msg["data"])
logger.debug(f"[HTTP_AFFINITY] Received HTTP response via Redis: status={response_data.get('status')}")
# Decode hex body back to bytes
body_hex = response_data.get("body", "")
response_data["body"] = bytes.fromhex(body_hex) if body_hex else b""
self._forwarded_requests += 1
return response_data
finally:
await pubsub.unsubscribe(response_channel)
except asyncio.TimeoutError:
self._forwarded_request_timeouts += 1
logger.warning(f"Timeout forwarding HTTP request to owner {owner_worker_id}")
return None
except Exception as e:
self._forwarded_request_failures += 1
logger.warning(f"Error forwarding HTTP request via Redis: {e}")
return None
def get_metrics(self) -> Dict[str, Any]:
"""
Return pool metrics for monitoring.
Returns:
Dict with hits, misses, evictions, hit_rate, and per-pool stats.
"""
total_requests = self._hits + self._misses
total_affinity_requests = self._session_affinity_local_hits + self._session_affinity_redis_hits + self._session_affinity_misses
return {
"hits": self._hits,
"misses": self._misses,
"evictions": self._evictions,
"health_check_failures": self._health_check_failures,
"circuit_breaker_trips": self._circuit_breaker_trips,
"pool_keys_evicted": self._pool_keys_evicted,
"sessions_reaped": self._sessions_reaped,
"anonymous_identity_count": self._anonymous_identity_count,
"hit_rate": self._hits / total_requests if total_requests > 0 else 0.0,
"pool_key_count": len(self._pools),
# Session affinity metrics
"session_affinity": {
"local_hits": self._session_affinity_local_hits,
"redis_hits": self._session_affinity_redis_hits,
"misses": self._session_affinity_misses,
"hit_rate": (self._session_affinity_local_hits + self._session_affinity_redis_hits) / total_affinity_requests if total_affinity_requests > 0 else 0.0,
"forwarded_requests": self._forwarded_requests,
"forwarded_failures": self._forwarded_request_failures,
"forwarded_timeouts": self._forwarded_request_timeouts,
},
"pools": {
f"{url}|{identity[:8]}|{transport}|{user}|{gw_id[:8] if gw_id else 'none'}": {
"available": pool.qsize(),
"active": len(self._active.get((user, url, identity, transport, gw_id), set())),
"max": self._max_sessions,
}
for (user, url, identity, transport, gw_id), pool in self._pools.items()
},
"circuit_breakers": {
url: {
"failures": self._failures.get(url, 0),
"open_until": self._circuit_open_until.get(url),
}
for url in set(self._failures.keys()) | set(self._circuit_open_until.keys())
},
}
@asynccontextmanager
async def session(
self,
url: str,
headers: Optional[Dict[str, str]] = None,
transport_type: TransportType = TransportType.STREAMABLE_HTTP,
httpx_client_factory: Optional[HttpxClientFactory] = None,
timeout: Optional[float] = None,
user_identity: Optional[str] = None,
gateway_id: Optional[str] = None,
) -> "AsyncIterator[PooledSession]":
"""
Context manager for acquiring and releasing a session.
Usage:
async with pool.session(url, headers) as pooled:
result = await pooled.session.call_tool("my_tool", {})
Args:
url: The MCP server URL.
headers: Request headers.
transport_type: Transport type to use.
httpx_client_factory: Optional factory for httpx clients.
timeout: Optional timeout in seconds for transport connection.
user_identity: Optional user identity for strict isolation.
gateway_id: Optional gateway ID for notification handler context.
Yields:
PooledSession ready for use.
"""
pooled = await self.acquire(url, headers, transport_type, httpx_client_factory, timeout, user_identity, gateway_id)
try:
yield pooled
finally:
await self.release(pooled)
# Global pool instance - initialized by FastAPI lifespan
_mcp_session_pool: Optional[MCPSessionPool] = None
def get_mcp_session_pool() -> MCPSessionPool:
"""Get the global MCP session pool instance.
Returns:
The global MCPSessionPool instance.
Raises:
RuntimeError: If pool has not been initialized.
"""
if _mcp_session_pool is None:
raise RuntimeError("MCP session pool not initialized. Call init_mcp_session_pool() first.")
return _mcp_session_pool
def init_mcp_session_pool(
max_sessions_per_key: int = 10,
session_ttl_seconds: float = 300.0,
health_check_interval_seconds: float = 60.0,
acquire_timeout_seconds: float = 30.0,
session_create_timeout_seconds: float = 30.0,
circuit_breaker_threshold: int = 5,
circuit_breaker_reset_seconds: float = 60.0,
identity_headers: Optional[frozenset[str]] = None,
identity_extractor: Optional[IdentityExtractor] = None,
idle_pool_eviction_seconds: float = 600.0,
default_transport_timeout_seconds: float = 30.0,
health_check_methods: Optional[list[str]] = None,
health_check_timeout_seconds: float = 5.0,
message_handler_factory: Optional[MessageHandlerFactory] = None,
enable_notifications: bool = True,
notification_debounce_seconds: float = 5.0,
) -> MCPSessionPool:
"""Initialize the global MCP session pool.
Args:
See MCPSessionPool.__init__ for argument descriptions.
enable_notifications: Enable automatic notification service for list_changed events.
notification_debounce_seconds: Debounce interval for notification-triggered refreshes.
Returns:
The initialized MCPSessionPool instance.
"""
global _mcp_session_pool # pylint: disable=global-statement
# Auto-create notification service if enabled and no custom handler provided
effective_handler_factory = message_handler_factory
if enable_notifications and message_handler_factory is None:
# First-Party
from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel
init_notification_service,
)
# Initialize notification service (will be started during acquire with gateway context)
notification_svc = init_notification_service(debounce_seconds=notification_debounce_seconds)
# Create default handler factory that uses notification service
def default_handler_factory(url: str, gateway_id: Optional[str]):
"""Create a message handler for MCP session notifications.
Args:
url: The MCP server URL for the session.
gateway_id: Optional gateway ID for attribution, falls back to URL if not provided.
Returns:
A message handler that forwards notifications to the notification service.
"""
return notification_svc.create_message_handler(gateway_id or url, url)
effective_handler_factory = default_handler_factory
logger.info("MCP notification service created (debounce=%ss)", notification_debounce_seconds)
_mcp_session_pool = MCPSessionPool(
max_sessions_per_key=max_sessions_per_key,
session_ttl_seconds=session_ttl_seconds,
health_check_interval_seconds=health_check_interval_seconds,
acquire_timeout_seconds=acquire_timeout_seconds,
session_create_timeout_seconds=session_create_timeout_seconds,
circuit_breaker_threshold=circuit_breaker_threshold,
circuit_breaker_reset_seconds=circuit_breaker_reset_seconds,
identity_headers=identity_headers,
identity_extractor=identity_extractor,
idle_pool_eviction_seconds=idle_pool_eviction_seconds,
default_transport_timeout_seconds=default_transport_timeout_seconds,
health_check_methods=health_check_methods,
health_check_timeout_seconds=health_check_timeout_seconds,
message_handler_factory=effective_handler_factory,
)
logger.info("MCP session pool initialized")
return _mcp_session_pool
async def close_mcp_session_pool() -> None:
"""Close the global MCP session pool and notification service."""
global _mcp_session_pool # pylint: disable=global-statement
if _mcp_session_pool is not None:
await _mcp_session_pool.close_all()
_mcp_session_pool = None
logger.info("MCP session pool closed")
# Close notification service if it was initialized
try:
# First-Party
from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel
close_notification_service,
)
await close_notification_service()
except (ImportError, RuntimeError):
pass # Notification service not initialized
async def start_pool_notification_service(gateway_service: Any = None) -> None:
"""Start the notification service background worker.
Call this after gateway_service is initialized to enable event-driven refresh.
Args:
gateway_service: Optional GatewayService instance for triggering refreshes.
"""
try:
# First-Party
from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel
get_notification_service,
)
notification_svc = get_notification_service()
await notification_svc.initialize(gateway_service)
logger.info("MCP notification service started")
except RuntimeError:
logger.debug("Notification service not configured, skipping start")
def register_gateway_capabilities_for_notifications(gateway_id: str, capabilities: Dict[str, Any]) -> None:
"""Register gateway capabilities for notification handling.
Call this after gateway initialization to enable list_changed notifications.
Args:
gateway_id: The gateway ID.
capabilities: Server capabilities from initialization response.
"""
try:
# First-Party
from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel
get_notification_service,
)
notification_svc = get_notification_service()
notification_svc.register_gateway_capabilities(gateway_id, capabilities)
except RuntimeError:
pass # Notification service not initialized
def unregister_gateway_from_notifications(gateway_id: str) -> None:
"""Unregister a gateway from notification handling.
Call this when a gateway is deleted.
Args:
gateway_id: The gateway ID to unregister.
"""
try:
# First-Party
from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel
get_notification_service,
)
notification_svc = get_notification_service()
notification_svc.unregister_gateway(gateway_id)
except RuntimeError:
pass # Notification service not initialized