Skip to main content
Glama
websocket-manager.ts (24.1 kB)
import { Server as SocketIOServer, Socket } from 'socket.io';
import { logger } from '../utils/logger';
import { config } from '../config/config';
import { authenticateSocket } from '../auth/middleware';
import { SecurityMiddleware } from '../security/middleware';
import { mcpSecurityValidator } from './mcp-security';
import { redis } from '../database/redis';
import { EventEmitter } from 'events';
import { v4 as uuidv4 } from 'uuid';

// MCP SDK types - stubbed for build compatibility
type JSONRPCMessage = any;
type JSONRPCRequest = any;
type JSONRPCResponse = any;
type JSONRPCError = any;
type InitializeRequest = any;
type InitializeResult = any;
type ListToolsRequest = any;
type ListToolsResult = any;
type CallToolRequest = any;
type CallToolResult = any;
type ListResourcesRequest = any;
type ListResourcesResult = any;
type ReadResourceRequest = any;
type ReadResourceResult = any;
type Tool = any;
type Resource = any;
type TextContent = any;
type ImageContent = any;

/** How often the idle-connection sweep runs. */
const CLEANUP_INTERVAL_MS = 60000;
/** A connection with no inbound message for longer than this is closed by the sweep (5 minutes). */
const IDLE_TIMEOUT_MS = 300000;
/** Length of the in-memory per-connection message rate-limit window (1 minute). */
const RATE_LIMIT_WINDOW_MS = 60000;

/** Per-socket MCP session state tracked by the manager. */
interface MCPConnection {
  id: string;
  socket: Socket;
  userId: string;
  sessionId: string;
  capabilities: string[];
  clientInfo: {
    name: string;
    version: string;
  };
  serverInfo: {
    name: string;
    version: string;
  };
  initialized: boolean;
  /** When the socket was registered; used for the disconnect log's `duration`. */
  connectedAt: Date;
  /** Time of the last inbound message; drives the idle sweep. */
  lastActivity: Date;
  rateLimitCount: number;
  rateLimitWindow: Date;
}

/** A tool definition plus its server-side handler and required permissions. */
interface MCPTool extends Tool {
  name: string;
  description: string;
  inputSchema: any;
  handler: (args: any) => Promise<any>;
  permissions?: string[];
}

/** A resource definition plus its loader and required permissions. */
interface MCPResource extends Resource {
  uri: string;
  name: string;
  description?: string;
  mimeType?: string;
  loader: () => Promise<TextContent | ImageContent>;
  permissions?: string[];
}

/**
 * Serves the MCP JSON-RPC protocol over socket.io.
 *
 * Registers authentication, per-IP rate-limit and connection-limit middleware, tracks
 * one {@link MCPConnection} per socket, and dispatches `initialize`, `tools/*`,
 * `resources/*` and `ping` requests.
 *
 * Events emitted:
 * - `'disconnect'` with `{ connection, reason }` when a connection is deregistered.
 * - `'error'` with `{ connectionId, error }` on socket errors (only when a listener is attached).
 */
export class WebSocketManager extends EventEmitter {
  private io: SocketIOServer;
  private connections: Map<string, MCPConnection> = new Map();
  private tools: Map<string, MCPTool> = new Map();
  private resources: Map<string, MCPResource> = new Map();
  private messageCount: number = 0;
  private errorCount: number = 0;
  /** Handle of the idle sweep; null until `initialize()` starts it or after `shutdown()`. */
  private cleanupTimer: ReturnType<typeof setInterval> | null = null;

  constructor(io: SocketIOServer) {
    super();
    this.io = io;
    this.setupMiddleware();
    this.setupEventHandlers();
  }

  /**
   * Live connection count, derived from the registry.
   * A separate counter could drift when a connection is removed by more than one path.
   */
  private get connectionCount(): number {
    return this.connections.size;
  }

  /** Loads tools and resources and starts the idle-connection sweep (at most once). */
  public async initialize(): Promise<void> {
    await this.loadTools();
    await this.loadResources();
    this.startCleanupInterval();
    logger.info('WebSocket manager initialized', {
      tools: this.tools.size,
      resources: this.resources.size,
    });
  }

  /** Stops the idle-connection sweep so the timer no longer keeps the process alive. */
  public shutdown(): void {
    if (this.cleanupTimer !== null) {
      clearInterval(this.cleanupTimer);
      this.cleanupTimer = null;
    }
  }

  /**
   * Registers handshake middleware, in order: authentication, per-IP rate limit,
   * global connection limit.
   * NOTE(review): authentication runs before the rate limiter, so connection floods still
   * reach authenticateSocket; consider registering the rate limiter first.
   */
  private setupMiddleware(): void {
    // Authentication middleware
    this.io.use(async (socket, next) => {
      try {
        const user = await authenticateSocket(socket);
        socket.data.user = user;
        next();
      } catch (error) {
        logger.error('Socket authentication failed', { error, socketId: socket.id });
        next(new Error('Authentication failed'));
      }
    });

    // Rate limiting middleware (connection attempts per client IP, 60 s Redis window).
    // NOTE(review): this reuses the per-connection *message* limit for connection attempts — confirm intended.
    this.io.use(async (socket, next) => {
      const clientIp = socket.handshake.address;
      const rateLimitKey = `ws_rate_limit:${clientIp}`;
      try {
        const current = await redis.incr(rateLimitKey);
        if (current === 1) {
          await redis.expire(rateLimitKey, 60); // 1 minute window
        }
        if (current > config.mcp.rateLimitPerConnection) {
          logger.warn('WebSocket rate limit exceeded', { clientIp, current });
          next(new Error('Rate limit exceeded'));
          return;
        }
        next();
      } catch (error) {
        logger.error('Rate limiting error', { error, clientIp });
        next(error);
      }
    });

    // Connection limit middleware
    this.io.use((socket, next) => {
      if (this.connectionCount >= config.websocket.maxConnections) {
        logger.warn('Connection limit exceeded', {
          current: this.connectionCount,
          max: config.websocket.maxConnections,
        });
        next(new Error('Connection limit exceeded'));
        return;
      }
      next();
    });
  }

  private setupEventHandlers(): void {
    this.io.on('connection', (socket: Socket) => {
      this.handleConnection(socket);
    });
  }

  /** Registers a freshly authenticated socket and wires its message/disconnect/error handlers. */
  private handleConnection(socket: Socket): void {
    const connectionId = uuidv4();
    const userId = socket.data.user.id;

    logger.info('New WebSocket connection', {
      connectionId,
      socketId: socket.id,
      userId,
      ip: socket.handshake.address,
      userAgent: socket.handshake.headers['user-agent'],
    });

    const connection: MCPConnection = {
      id: connectionId,
      socket,
      userId,
      sessionId: uuidv4(),
      capabilities: [],
      clientInfo: { name: '', version: '' },
      serverInfo: {
        name: 'secure-mcp-server',
        version: process.env.npm_package_version || '1.0.0',
      },
      initialized: false,
      connectedAt: new Date(),
      lastActivity: new Date(),
      rateLimitCount: 0,
      rateLimitWindow: new Date(),
    };

    this.connections.set(connectionId, connection);

    // Set up message handling. handleMessage catches its own errors; `void` marks the
    // promise as intentionally not awaited by the emitter.
    socket.on('message', (data: unknown) => {
      void this.handleMessage(connectionId, data);
    });

    socket.on('disconnect', (reason: string) => {
      this.handleDisconnection(connectionId, reason);
    });

    socket.on('error', (error: Error) => {
      this.handleError(connectionId, error);
    });

    // Send server capabilities.
    // NOTE(review): in MCP the *client* sends notifications/initialized; kept for existing clients.
    this.sendMessage(connection, {
      jsonrpc: '2.0',
      method: 'notifications/initialized',
      params: {
        protocolVersion: config.mcp.protocolVersion,
        capabilities: {
          tools: { listChanged: true },
          resources: { listChanged: true, subscribe: true },
          logging: {},
        },
        serverInfo: connection.serverInfo,
      },
    });
  }

  /**
   * Validates one inbound frame and dispatches it as a request or a response.
   * Rejects (with a JSON-RPC error) rate-limited, non-text, oversized, invalid or shapeless frames.
   */
  private async handleMessage(connectionId: string, data: unknown): Promise<void> {
    const connection = this.connections.get(connectionId);
    if (!connection) {
      logger.error('Message received for unknown connection', { connectionId });
      return;
    }

    try {
      // Update activity timestamp
      connection.lastActivity = new Date();

      // Check rate limiting
      if (!this.checkRateLimit(connection)) {
        this.sendError(connection, null, -32000, 'Rate limit exceeded');
        return;
      }

      // socket.io forwards whatever the client emitted; only text/bytes are JSON-RPC frames.
      if (typeof data !== 'string' && !Buffer.isBuffer(data)) {
        this.sendError(connection, null, -32600, 'Invalid Request');
        return;
      }

      // Validate message size in bytes, before decoding, so a huge Buffer is never
      // turned into a string. Assumes maxMessageSize is a byte count — confirm.
      const byteLength = typeof data === 'string' ? Buffer.byteLength(data, 'utf8') : data.length;
      if (byteLength > config.websocket.maxMessageSize) {
        this.sendError(connection, null, -32600, 'Message too large');
        return;
      }

      const messageStr = typeof data === 'string' ? data : data.toString('utf8');

      // SECURITY: Enhanced MCP protocol validation
      const validation = await mcpSecurityValidator.validateMessage(messageStr);
      if (!validation.valid) {
        logger.warn('Invalid MCP message', {
          error: validation.error,
          connectionId,
          userId: connection.userId,
        });
        this.sendError(connection, null, -32600, validation.error || 'Invalid Request');
        return;
      }

      const message = validation.sanitized;

      // Only single JSON-RPC objects are supported (no primitives, no batches).
      if (message === null || typeof message !== 'object' || Array.isArray(message)) {
        this.sendError(connection, null, -32600, 'Invalid Request');
        return;
      }

      // Validate request ID if present
      if ('id' in message && !mcpSecurityValidator.validateRequestId(message.id)) {
        this.sendError(connection, null, -32600, 'Invalid request ID');
        return;
      }

      this.messageCount++;

      // Handle different message types
      if ('method' in message) {
        await this.handleRequest(connection, message as JSONRPCRequest);
      } else if ('result' in message || 'error' in message) {
        await this.handleResponse(connection, message as JSONRPCResponse);
      } else {
        this.sendError(connection, null, -32600, 'Invalid Request');
      }
    } catch (error) {
      this.errorCount++;
      logger.error('Error handling WebSocket message', {
        error,
        connectionId,
        userId: connection.userId,
      });
      this.sendError(connection, null, -32603, 'Internal error');
    }
  }

  /** Routes a request to its method handler. Notifications (no `id`) are never answered. */
  private async handleRequest(connection: MCPConnection, request: JSONRPCRequest): Promise<void> {
    // JSON-RPC 2.0: a request without an `id` member is a notification; the server must not reply.
    if (!('id' in request)) {
      this.handleNotification(connection, request);
      return;
    }

    try {
      switch (request.method) {
        case 'initialize':
          await this.handleInitialize(connection, request as InitializeRequest);
          break;
        case 'tools/list':
          await this.handleListTools(connection, request as ListToolsRequest);
          break;
        case 'tools/call':
          await this.handleCallTool(connection, request as CallToolRequest);
          break;
        case 'resources/list':
          await this.handleListResources(connection, request as ListResourcesRequest);
          break;
        case 'resources/read':
          await this.handleReadResource(connection, request as ReadResourceRequest);
          break;
        case 'ping':
          this.sendMessage(connection, {
            jsonrpc: '2.0',
            id: request.id,
            result: { status: 'pong', timestamp: new Date().toISOString() },
          });
          break;
        default:
          this.sendError(connection, request.id, -32601, `Method not found: ${request.method}`);
      }
    } catch (error) {
      logger.error('Error handling request', {
        error,
        method: request.method,
        connectionId: connection.id,
      });
      this.sendError(connection, request.id, -32603, 'Internal error');
    }
  }

  /** Records a client notification (e.g. `notifications/initialized`); no response is sent. */
  private handleNotification(connection: MCPConnection, notification: JSONRPCRequest): void {
    logger.debug('Received notification from client', {
      method: notification.method,
      connectionId: connection.id,
    });
  }

  /** Sends a "not initialized" error unless `initialize` has completed on this connection. */
  private ensureInitialized(connection: MCPConnection, request: JSONRPCRequest): boolean {
    if (connection.initialized) {
      return true;
    }
    this.sendError(connection, request.id, -32002, 'Connection not initialized');
    return false;
  }

  /** Validates the protocol version, stores client info/capabilities and marks the session initialized. */
  private async handleInitialize(connection: MCPConnection, request: InitializeRequest): Promise<void> {
    const params = request.params;
    if (!params || typeof params !== 'object') {
      this.sendError(connection, request.id, -32602, 'Invalid params');
      return;
    }

    // Validate protocol version
    if (params.protocolVersion !== config.mcp.protocolVersion) {
      this.sendError(connection, request.id, -32602,
        `Unsupported protocol version. Expected: ${config.mcp.protocolVersion}, Got: ${params.protocolVersion}`);
      return;
    }

    // Store client info, copying only the documented string fields from untrusted input.
    const clientInfo = params.clientInfo;
    connection.clientInfo = {
      name: typeof clientInfo?.name === 'string' ? clientInfo.name : '',
      version: typeof clientInfo?.version === 'string' ? clientInfo.version : '',
    };
    const capabilities = params.capabilities;
    connection.capabilities =
      capabilities && typeof capabilities === 'object' ? Object.keys(capabilities) : [];
    connection.initialized = true;

    const result: InitializeResult = {
      protocolVersion: config.mcp.protocolVersion,
      capabilities: {
        tools: {
          listChanged: true,
        },
        resources: {
          listChanged: true,
          subscribe: true,
        },
        logging: {},
      },
      serverInfo: connection.serverInfo,
    };

    this.sendMessage(connection, {
      jsonrpc: '2.0',
      id: request.id,
      result,
    });

    logger.info('MCP connection initialized', {
      connectionId: connection.id,
      clientInfo: connection.clientInfo,
      capabilities: connection.capabilities,
    });
  }

  /** Lists the tools the connection's user is permitted to see. */
  private async handleListTools(connection: MCPConnection, request: ListToolsRequest): Promise<void> {
    if (!this.ensureInitialized(connection, request)) {
      return;
    }

    const userPermissions = await this.getUserPermissions(connection.userId);
    const availableTools = Array.from(this.tools.values())
      .filter(tool => this.hasPermission(userPermissions, tool.permissions))
      .map(tool => ({
        name: tool.name,
        description: tool.description,
        inputSchema: tool.inputSchema,
      }));

    const result: ListToolsResult = {
      tools: availableTools,
    };

    this.sendMessage(connection, {
      jsonrpc: '2.0',
      id: request.id,
      result,
    });
  }

  /** Validates, permission-checks and runs a tool, returning its sanitized output as text content. */
  private async handleCallTool(connection: MCPConnection, request: CallToolRequest): Promise<void> {
    if (!this.ensureInitialized(connection, request)) {
      return;
    }

    const params = request.params;
    if (!params || typeof params !== 'object' || typeof params.name !== 'string') {
      this.sendError(connection, request.id, -32602, 'Invalid params');
      return;
    }
    const { name, arguments: args } = params;

    // SECURITY: Validate tool execution request
    const toolValidation = mcpSecurityValidator.validateToolExecution(name, args);
    if (!toolValidation.valid) {
      logger.warn('Invalid tool execution request', {
        error: toolValidation.error,
        toolName: name,
        connectionId: connection.id,
        userId: connection.userId,
      });
      this.sendError(connection, request.id, -32602, toolValidation.error || 'Invalid tool request');
      return;
    }

    const tool = this.tools.get(name);
    if (!tool) {
      this.sendError(connection, request.id, -32601, `Tool not found: ${name}`);
      return;
    }

    // Check permissions
    const userPermissions = await this.getUserPermissions(connection.userId);
    if (!this.hasPermission(userPermissions, tool.permissions)) {
      this.sendError(connection, request.id, -32000, 'Insufficient permissions');
      return;
    }

    try {
      // Use sanitized arguments
      const toolResult = await tool.handler(toolValidation.sanitizedArgs || {});

      // SECURITY: Sanitize output before sending
      const sanitizedResult = mcpSecurityValidator.sanitizeOutput(toolResult);

      const result: CallToolResult = {
        content: [
          {
            type: 'text',
            text: typeof sanitizedResult === 'string' ? sanitizedResult : JSON.stringify(sanitizedResult),
          },
        ],
      };

      this.sendMessage(connection, {
        jsonrpc: '2.0',
        id: request.id,
        result,
      });

      logger.info('Tool executed successfully', {
        toolName: name,
        connectionId: connection.id,
        userId: connection.userId,
      });
    } catch (error) {
      logger.error('Tool execution failed', {
        error,
        toolName: name,
        connectionId: connection.id,
        userId: connection.userId,
      });
      this.sendError(connection, request.id, -32000, 'Tool execution failed');
    }
  }

  /** Lists the resources the connection's user is permitted to see. */
  private async handleListResources(connection: MCPConnection, request: ListResourcesRequest): Promise<void> {
    if (!this.ensureInitialized(connection, request)) {
      return;
    }

    const userPermissions = await this.getUserPermissions(connection.userId);
    const availableResources = Array.from(this.resources.values())
      .filter(resource => this.hasPermission(userPermissions, resource.permissions))
      .map(resource => ({
        uri: resource.uri,
        name: resource.name,
        description: resource.description,
        mimeType: resource.mimeType,
      }));

    const result: ListResourcesResult = {
      resources: availableResources,
    };

    this.sendMessage(connection, {
      jsonrpc: '2.0',
      id: request.id,
      result,
    });
  }

  /** Validates, permission-checks and loads a resource, sanitizing any text content. */
  private async handleReadResource(connection: MCPConnection, request: ReadResourceRequest): Promise<void> {
    if (!this.ensureInitialized(connection, request)) {
      return;
    }

    const params = request.params;
    if (!params || typeof params !== 'object' || typeof params.uri !== 'string') {
      this.sendError(connection, request.id, -32602, 'Invalid params');
      return;
    }
    const { uri } = params;

    // SECURITY: Validate resource access request
    const resourceValidation = mcpSecurityValidator.validateResourceAccess(uri);
    if (!resourceValidation.valid) {
      logger.warn('Invalid resource access request', {
        error: resourceValidation.error,
        uri,
        connectionId: connection.id,
        userId: connection.userId,
      });
      this.sendError(connection, request.id, -32602, resourceValidation.error || 'Invalid resource request');
      return;
    }

    const resource = this.resources.get(uri);
    if (!resource) {
      this.sendError(connection, request.id, -32601, `Resource not found: ${uri}`);
      return;
    }

    // Check permissions
    const userPermissions = await this.getUserPermissions(connection.userId);
    if (!this.hasPermission(userPermissions, resource.permissions)) {
      this.sendError(connection, request.id, -32000, 'Insufficient permissions');
      return;
    }

    try {
      const content = await resource.loader();

      // SECURITY: Sanitize resource content before sending. MCP resource contents carry
      // their uri/mimeType, so default them from the definition; loader fields win.
      const sanitizedContent = {
        uri: resource.uri,
        mimeType: resource.mimeType,
        ...content,
        text: typeof content.text === 'string' ? mcpSecurityValidator.sanitizeOutput(content.text) : undefined,
      };

      const result: ReadResourceResult = {
        contents: [sanitizedContent],
      };

      this.sendMessage(connection, {
        jsonrpc: '2.0',
        id: request.id,
        result,
      });

      logger.info('Resource read successfully', {
        resourceUri: uri,
        connectionId: connection.id,
        userId: connection.userId,
      });
    } catch (error) {
      logger.error('Resource read failed', {
        error,
        resourceUri: uri,
        connectionId: connection.id,
        userId: connection.userId,
      });
      this.sendError(connection, request.id, -32000, 'Resource read failed');
    }
  }

  private async handleResponse(connection: MCPConnection, response: JSONRPCResponse): Promise<void> {
    // Handle responses to requests sent by the server (if any)
    logger.debug('Received response from client', {
      id: response.id,
      connectionId: connection.id,
      hasError: 'error' in response,
    });
  }

  /**
   * Deregisters a connection, logs it and emits `'disconnect'`.
   * Idempotent: a second call for the same id is a no-op, so the socket's own
   * 'disconnect' event and server-initiated closes cannot double-remove.
   */
  private handleDisconnection(connectionId: string, reason: string): void {
    const connection = this.connections.get(connectionId);
    if (!connection) {
      return;
    }
    this.connections.delete(connectionId);

    const now = Date.now();
    logger.info('WebSocket disconnected', {
      connectionId,
      userId: connection.userId,
      reason,
      duration: now - connection.connectedAt.getTime(),
      idleMs: now - connection.lastActivity.getTime(),
    });

    this.emit('disconnect', { connection, reason });
  }

  /** Counts and logs a socket error; re-emits it only if someone listens (EventEmitter throws otherwise). */
  private handleError(connectionId: string, error: Error): void {
    this.errorCount++;
    logger.error('WebSocket error', {
      error,
      connectionId,
    });
    if (this.listenerCount('error') > 0) {
      this.emit('error', { connectionId, error });
    }
  }

  /** Serializes and sends a JSON-RPC message; failures are logged, never thrown. */
  private sendMessage(connection: MCPConnection, message: any): void {
    try {
      const messageStr = JSON.stringify(message);
      connection.socket.send(messageStr);
    } catch (error) {
      logger.error('Failed to send WebSocket message', {
        error,
        connectionId: connection.id,
        messageType: message.method || 'response',
      });
    }
  }

  private sendError(connection: MCPConnection, id: any, code: number, message: string): void {
    const errorResponse: JSONRPCError = {
      jsonrpc: '2.0',
      id,
      error: {
        code,
        message,
      },
    };
    this.sendMessage(connection, errorResponse);
  }

  private isValidJSONRPC(message: any): boolean {
    if (!message || typeof message !== 'object') return false;
    if (message.jsonrpc !== '2.0') return false;

    // Check if it's a request, response, or notification
    if ('method' in message) {
      return typeof message.method === 'string';
    } else if ('result' in message || 'error' in message) {
      return message.id !== undefined;
    }

    return false;
  }

  /** Fixed-window message limiter per connection; returns false once the window's budget is spent. */
  private checkRateLimit(connection: MCPConnection): boolean {
    const now = new Date();
    const windowStart = new Date(now.getTime() - RATE_LIMIT_WINDOW_MS);

    if (connection.rateLimitWindow < windowStart) {
      // Reset window
      connection.rateLimitWindow = now;
      connection.rateLimitCount = 0;
    }

    connection.rateLimitCount++;
    return connection.rateLimitCount <= config.mcp.rateLimitPerConnection;
  }

  /** Reads the user's permission set from Redis; returns [] on error (fails closed). */
  private async getUserPermissions(userId: string): Promise<string[]> {
    try {
      const permissions = await redis.smembers(`user_permissions:${userId}`);
      return permissions;
    } catch (error) {
      logger.error('Failed to get user permissions', { error, userId });
      return [];
    }
  }

  /** Any-of check: true when nothing is required or the user holds at least one required permission. */
  private hasPermission(userPermissions: string[], requiredPermissions?: string[]): boolean {
    if (!requiredPermissions || requiredPermissions.length === 0) return true;
    return requiredPermissions.some(permission => userPermissions.includes(permission));
  }

  private async loadTools(): Promise<void> {
    // Example tools - in production, these would be loaded from a configuration or registry
    this.tools.set('echo', {
      name: 'echo',
      description: 'Echo back the input text',
      inputSchema: {
        type: 'object',
        properties: {
          text: { type: 'string' },
        },
        required: ['text'],
      },
      handler: async (args: { text: string }) => {
        return `Echo: ${args.text}`;
      },
      permissions: ['basic'],
    });

    this.tools.set('calculate', {
      name: 'calculate',
      description: 'Perform basic mathematical calculations',
      inputSchema: {
        type: 'object',
        properties: {
          expression: { type: 'string' },
        },
        required: ['expression'],
      },
      handler: async (args: { expression: string }) => {
        // SECURITY: Enhanced expression validation
        if (!args.expression || typeof args.expression !== 'string') {
          throw new Error('Invalid expression');
        }

        // Only allow basic math operations
        const allowedPattern = /^[0-9+\-*/().\s]+$/;
        if (!allowedPattern.test(args.expression)) {
          throw new Error('Expression contains invalid characters');
        }

        // Limit expression length
        if (args.expression.length > 100) {
          throw new Error('Expression too long');
        }

        // SECURITY(review): this still compiles client input with the Function constructor
        // (eval-equivalent). The whitelist above restricts it to digits, + - * / . ( ) and
        // whitespace, but a dedicated arithmetic parser would remove the dynamic-code path.
        try {
          const result = Function('"use strict"; return (' + args.expression + ')')();

          // Validate result
          if (!Number.isFinite(result)) {
            throw new Error('Invalid calculation result');
          }

          return `Result: ${result}`;
        } catch (error) {
          throw new Error('Calculation failed');
        }
      },
      permissions: ['calculate'],
    });
  }

  private async loadResources(): Promise<void> {
    // Example resources - in production, these would be loaded from a configuration or registry
    this.resources.set('server-info', {
      uri: 'server-info',
      name: 'Server Information',
      description: 'Information about the server',
      mimeType: 'application/json',
      loader: async (): Promise<TextContent> => ({
        type: 'text',
        text: JSON.stringify({
          name: 'secure-mcp-server',
          version: process.env.npm_package_version || '1.0.0',
          environment: config.env,
          uptime: process.uptime(),
          connections: this.connectionCount,
        }, null, 2),
      }),
      permissions: ['basic'],
    });
  }

  /** Starts the periodic idle sweep; repeated calls do not start a second timer. */
  private startCleanupInterval(): void {
    if (this.cleanupTimer !== null) {
      return;
    }
    this.cleanupTimer = setInterval(() => this.sweepIdleConnections(), CLEANUP_INTERVAL_MS);
  }

  /** Closes every connection idle for longer than IDLE_TIMEOUT_MS. */
  private sweepIdleConnections(): void {
    const now = Date.now();
    // Snapshot first: closing a connection mutates the map.
    const idle = Array.from(this.connections.values())
      .filter(connection => now - connection.lastActivity.getTime() > IDLE_TIMEOUT_MS);

    for (const connection of idle) {
      logger.info('Cleaning up inactive connection', {
        connectionId: connection.id,
        lastActivity: connection.lastActivity,
      });
      this.closeConnection(connection, 'idle timeout');
    }
  }

  /**
   * Force-closes a socket and makes sure it is deregistered exactly once.
   * If disconnect(true) already triggered the socket's 'disconnect' handler, the
   * follow-up handleDisconnection call is a no-op; otherwise it deregisters here.
   */
  private closeConnection(connection: MCPConnection, reason: string): void {
    connection.socket.disconnect(true);
    this.handleDisconnection(connection.id, reason);
  }

  // Public methods for management and monitoring
  public getConnectionCount(): number {
    return this.connectionCount;
  }

  public getMessageCount(): number {
    return this.messageCount;
  }

  public getErrorCount(): number {
    return this.errorCount;
  }

  /** Snapshot of connection state with the socket object stripped. */
  public getConnections(): MCPConnection[] {
    return Array.from(this.connections.values()).map(conn => ({
      ...conn,
      socket: undefined, // Don't expose socket object
    } as any));
  }

  /** Sends a JSON-RPC notification to every initialized connection. */
  public async broadcastNotification(method: string, params: any): Promise<void> {
    const notification = {
      jsonrpc: '2.0',
      method,
      params,
    };

    for (const connection of this.connections.values()) {
      if (connection.initialized) {
        this.sendMessage(connection, notification);
      }
    }
  }

  /** Force-disconnects every connection belonging to `userId`. */
  public async disconnectUser(userId: string): Promise<void> {
    // Snapshot first: closing a connection mutates the map.
    const targets = Array.from(this.connections.values())
      .filter(connection => connection.userId === userId);
    for (const connection of targets) {
      this.closeConnection(connection, 'disconnected by server');
    }
  }
}

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/perfecxion-ai/secure-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.