import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import { setupCapabilities } from '@src/core/capabilities/capabilityManager.js';
import { LazyLoadingOrchestrator } from '@src/core/capabilities/lazyLoadingOrchestrator.js';
import { byCapabilities } from '@src/core/filtering/clientFiltering.js';
import { FilteringService } from '@src/core/filtering/filteringService.js';
import type { OutboundConnections } from '@src/core/types/client.js';
import { InboundConnection, InboundConnectionConfig, OperationOptions, ServerStatus } from '@src/core/types/index.js';
import {
type ClientConnection,
PresetNotificationService,
} from '@src/domains/preset/services/presetNotificationService.js';
import logger, { debugIf } from '@src/logger/logger.js';
import { enhanceServerWithLogging } from '@src/logger/mcpLoggingEnhancer.js';
import type { ContextData } from '@src/types/context.js';
import { executeOperation } from '@src/utils/core/operationExecution.js';
/**
* Manages transport connection lifecycle and inbound connections
*/
export class ConnectionManager {
private inboundConns: Map<string, InboundConnection> = new Map();
private connectionSemaphore: Map<string, Promise<void>> = new Map();
private disconnectingIds: Set<string> = new Set();
private lazyLoadingOrchestrator?: LazyLoadingOrchestrator;
constructor(
private serverConfig: { name: string; version: string },
private serverCapabilities: { capabilities: Record<string, unknown> },
private outboundConns: OutboundConnections,
lazyLoadingOrchestrator?: LazyLoadingOrchestrator,
) {
this.lazyLoadingOrchestrator = lazyLoadingOrchestrator;
}
/**
* Set the lazy loading orchestrator
*/
public setLazyLoadingOrchestrator(orchestrator: LazyLoadingOrchestrator): void {
this.lazyLoadingOrchestrator = orchestrator;
}
/**
* Get the lazy loading orchestrator
*/
public getLazyLoadingOrchestrator(): LazyLoadingOrchestrator | undefined {
return this.lazyLoadingOrchestrator;
}
/**
* Connect a transport with the given session ID and configuration
*/
public async connectTransport(
transport: Transport,
sessionId: string,
opts: InboundConnectionConfig,
context?: ContextData,
filteredInstructions?: string,
): Promise<void> {
// Check if a connection is already in progress for this session
const existingConnection = this.connectionSemaphore.get(sessionId);
if (existingConnection) {
logger.warn(`Connection already in progress for session ${sessionId}, waiting...`);
await existingConnection;
return;
}
// Check if transport is already connected
if (this.inboundConns.has(sessionId)) {
logger.warn(`Transport already connected for session ${sessionId}`);
return;
}
// Create connection promise to prevent race conditions
const connectionPromise = this.performConnection(transport, sessionId, opts, context, filteredInstructions);
this.connectionSemaphore.set(sessionId, connectionPromise);
try {
await connectionPromise;
} finally {
// Clean up the semaphore entry
this.connectionSemaphore.delete(sessionId);
}
}
/**
* Disconnect a transport by session ID
*/
public async disconnectTransport(sessionId: string, forceClose: boolean = false): Promise<void> {
// Prevent recursive disconnection calls
if (this.disconnectingIds.has(sessionId)) {
return;
}
const server = this.inboundConns.get(sessionId);
if (server) {
this.disconnectingIds.add(sessionId);
try {
// Update status to Disconnected
server.status = ServerStatus.Disconnected;
// Only close the transport if explicitly requested
if (forceClose && server.server.transport) {
try {
server.server.transport.close();
} catch (error) {
logger.error(`Error closing transport for session ${sessionId}:`, error);
}
}
// Untrack client from preset notification service
const notificationService = PresetNotificationService.getInstance();
notificationService.untrackClient(sessionId);
debugIf(() => ({ message: 'Untracked client from preset notifications', meta: { sessionId } }));
this.inboundConns.delete(sessionId);
logger.info(`Disconnected transport for session ${sessionId}`);
} finally {
this.disconnectingIds.delete(sessionId);
}
}
}
/**
* Get transport by session ID
*/
public getTransport(sessionId: string): Transport | undefined {
return this.inboundConns.get(sessionId)?.server.transport;
}
/**
* Get all active transports
*/
public getTransports(): Map<string, Transport> {
const transports = new Map<string, Transport>();
for (const [id, server] of this.inboundConns.entries()) {
if (server.server.transport) {
transports.set(id, server.server.transport);
}
}
return transports;
}
/**
* Get server connection by session ID
*/
public getServer(sessionId: string): InboundConnection | undefined {
return this.inboundConns.get(sessionId);
}
/**
* Get all inbound connections
*/
public getInboundConnections(): Map<string, InboundConnection> {
return this.inboundConns;
}
/**
* Get count of active transports
*/
public getActiveTransportsCount(): number {
return this.inboundConns.size;
}
/**
* Execute a server operation with error handling
*/
public async executeServerOperation<T>(
inboundConn: InboundConnection,
operation: (inboundConn: InboundConnection) => Promise<T>,
options: OperationOptions = {},
): Promise<T> {
// Check connection status before executing operation
if (inboundConn.status !== ServerStatus.Connected || !inboundConn.server.transport) {
throw new Error(`Cannot execute operation: server status is ${inboundConn.status}`);
}
return executeOperation(() => operation(inboundConn), 'server', options);
}
/**
* Perform the actual connection
*/
private async performConnection(
transport: Transport,
sessionId: string,
opts: InboundConnectionConfig,
context?: ContextData,
filteredInstructions?: string,
): Promise<void> {
// Set connection timeout
const connectionTimeoutMs = 30000; // 30 seconds
const timeoutPromise = new Promise<never>((_, reject) => {
setTimeout(() => reject(new Error(`Connection timeout for session ${sessionId}`)), connectionTimeoutMs);
});
try {
await Promise.race([this.doConnect(transport, sessionId, opts, context, filteredInstructions), timeoutPromise]);
} catch (error) {
// Update status to Error if connection exists
const connection = this.inboundConns.get(sessionId);
if (connection) {
connection.status = ServerStatus.Error;
connection.lastError = error instanceof Error ? error : new Error(String(error));
}
logger.error(`Failed to connect transport for session ${sessionId}:`, error);
throw error;
}
}
/**
* Do the actual connection work
*/
private async doConnect(
transport: Transport,
sessionId: string,
opts: InboundConnectionConfig,
context?: ContextData,
filteredInstructions?: string,
): Promise<void> {
// Create server capabilities with filtered instructions
const serverOptionsWithInstructions = {
...this.serverCapabilities,
instructions: filteredInstructions || undefined,
};
// Create a new server instance for this transport
const server = new Server(this.serverConfig, serverOptionsWithInstructions);
// Create server info object, merging context if provided
// CRITICAL: Ensure sessionId is always set in context for session-scoped filtering
// This is needed for lazy loading to store and retrieve session filters correctly
// Priority: opts.context.sessionId > context.sessionId > transport sessionId (fallback)
const mergedContext = {
...(context || {}),
...(opts.context || {}),
// Use opts.context.sessionId if provided, otherwise fall back to transport sessionId
sessionId: opts.context?.sessionId || context?.sessionId || sessionId,
};
const serverInfo: InboundConnection = {
server,
status: ServerStatus.Connecting,
connectedAt: new Date(),
...opts,
context: mergedContext,
};
// Enhance server with logging middleware
enhanceServerWithLogging(server);
// Set up capabilities for this server instance
await setupCapabilities(this.outboundConns, serverInfo, this.lazyLoadingOrchestrator);
// Store the server instance
this.inboundConns.set(sessionId, serverInfo);
// Connect the transport to the new server instance
await server.connect(transport);
// Update status to Connected after successful connection
serverInfo.status = ServerStatus.Connected;
serverInfo.lastConnected = new Date();
// Initialize session filter for lazy loading if enabled
if (this.lazyLoadingOrchestrator?.isEnabled()) {
await this.initializeSessionFilter(sessionId, serverInfo);
}
// Register client with preset notification service if preset is used
if (opts.presetName) {
await this.registerClientForPresets(sessionId, opts.presetName, serverInfo);
}
logger.info(`Connected transport for session ${sessionId}`);
}
/**
* Initialize session filter for lazy loading
* This applies tag/preset filtering to the session so that tool_list meta-tool
* returns only tools from filtered servers
*/
private async initializeSessionFilter(sessionId: string, serverInfo: InboundConnection): Promise<void> {
if (!this.lazyLoadingOrchestrator) {
return;
}
// Apply the same filtering chain as used in tools/list handler
const sessionFilteredConns = this.outboundConns;
const capabilityFilteredConns = byCapabilities({ tools: {} })(sessionFilteredConns);
const filteredConns = FilteringService.getFilteredConnections(capabilityFilteredConns, serverInfo);
// Get the set of filtered server names using clean names from connection.name
// This ensures consistency with how the tool registry is built
// Template servers use hash-suffixed keys (e.g., "serena:abc123") but
// connection.name contains the clean name (e.g., "serena")
const filteredServerNames = new Set(Array.from(filteredConns.values()).map((conn) => conn.name));
debugIf(() => ({
message: 'Initializing session filter for lazy loading',
meta: {
sessionId,
totalServers: this.outboundConns.size,
filteredServers: filteredServerNames.size,
serverNames: Array.from(filteredServerNames),
mapKeys: Array.from(filteredConns.keys()), // For debugging key mismatch
tagFilterMode: serverInfo.tagFilterMode,
tags: serverInfo.tags,
},
}));
// Store the session filter so tool_list returns filtered results
await this.lazyLoadingOrchestrator.getCapabilitiesForFilteredServers(filteredServerNames, sessionId);
}
/**
* Register client with preset notification service
*/
private async registerClientForPresets(
sessionId: string,
presetName: string,
serverInfo: InboundConnection,
): Promise<void> {
const notificationService = PresetNotificationService.getInstance();
const clientConnection: ClientConnection = {
id: sessionId,
presetName,
sendNotification: async (method: string, params?: Record<string, unknown>) => {
try {
if (serverInfo.status === ServerStatus.Connected && serverInfo.server.transport) {
await serverInfo.server.notification({ method, params: params || {} });
debugIf(() => ({ message: 'Sent notification to client', meta: { sessionId, method } }));
} else {
logger.warn('Cannot send notification to disconnected client', { sessionId, method });
}
} catch (error) {
logger.error('Failed to send notification to client', {
sessionId,
method,
error: error instanceof Error ? error.message : 'Unknown error',
});
throw error;
}
},
isConnected: () => serverInfo.status === ServerStatus.Connected && !!serverInfo.server.transport,
};
notificationService.trackClient(clientConnection, presetName);
logger.info('Registered client for preset notifications', {
sessionId,
presetName,
});
}
/**
* Clean up all connections (for shutdown)
*/
public async cleanup(): Promise<void> {
// Clean up existing connections with forced close
for (const [sessionId] of this.inboundConns) {
await this.disconnectTransport(sessionId, true);
}
this.inboundConns.clear();
this.connectionSemaphore.clear();
this.disconnectingIds.clear();
}
}