Codebase MCP

import { ZodLiteral, ZodObject, ZodType, z } from "zod";
import {
  CancelledNotificationSchema,
  ClientCapabilities,
  ErrorCode,
  JSONRPCError,
  JSONRPCNotification,
  JSONRPCRequest,
  JSONRPCResponse,
  McpError,
  Notification,
  PingRequestSchema,
  Progress,
  ProgressNotification,
  ProgressNotificationSchema,
  Request,
  RequestId,
  Result,
  ServerCapabilities,
} from "../types.js";
import { Transport } from "./transport.js";

/**
 * Callback for progress notifications.
 */
export type ProgressCallback = (progress: Progress) => void;

/**
 * Additional initialization options.
 */
export type ProtocolOptions = {
  /**
   * Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities.
   *
   * Note that this DOES NOT affect checking of _local_ side capabilities, as it is considered a logic error to mis-specify those.
   *
   * Currently this defaults to false, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to true.
   */
  enforceStrictCapabilities?: boolean;
};

/**
 * The default request timeout, in milliseconds.
 */
export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60000;

/**
 * Options that can be given per request.
 */
export type RequestOptions = {
  /**
   * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked.
   */
  onprogress?: ProgressCallback;

  /**
   * Can be used to cancel an in-flight request. When the signal aborts, the remote side is sent a
   * cancellation notification and the promise returned by request() rejects with the signal's `reason`.
   */
  signal?: AbortSignal;

  /**
   * A timeout (in milliseconds) for this request. If exceeded, an McpError with code `RequestTimeout` will be raised from request().
   *
   * If not specified, `DEFAULT_REQUEST_TIMEOUT_MSEC` will be used as the timeout.
   */
  timeout?: number;

  /**
   * If true, receiving a progress notification will reset the request timeout.
   * This is useful for long-running operations that send periodic progress updates.
   * Default: false
   */
  resetTimeoutOnProgress?: boolean;

  /**
   * Maximum total time (in milliseconds) to wait for a response.
   * If exceeded, an McpError with code `RequestTimeout` will be raised, regardless of progress notifications.
   *
   * The limit is checked whenever a progress notification would reset the timeout, so it only has an
   * effect together with `resetTimeoutOnProgress`.
   * If not specified, there is no maximum total timeout.
   */
  maxTotalTimeout?: number;
};

/**
 * Extra data given to request handlers.
 */
export type RequestHandlerExtra = {
  /**
   * An abort signal used to communicate if the request was cancelled from the sender's side.
   */
  signal: AbortSignal;
};

/**
 * Information about a request's timeout state
 */
type TimeoutInfo = {
  /** Handle of the currently armed timer. */
  timeoutId: ReturnType<typeof setTimeout>;
  /** `Date.now()` at the time the request's timeout was first armed. */
  startTime: number;
  /** Per-interval timeout in milliseconds (re-armed on each reset). */
  timeout: number;
  /** Optional cap on total elapsed time, checked on reset. */
  maxTotalTimeout?: number;
  /** Whether progress notifications for this request may reset the timer. */
  resetTimeoutOnProgress: boolean;
  /** Cancels the request; called by the timer with no argument, or with a specific error. */
  onTimeout: (reason?: Error) => void;
};

/**
 * Implements MCP protocol framing on top of a pluggable transport, including
 * features like request/response linking, notifications, and progress.
 */
export abstract class Protocol<
  SendRequestT extends Request,
  SendNotificationT extends Notification,
  SendResultT extends Result,
> {
  private _transport?: Transport;
  private _requestMessageId = 0;
  private _requestHandlers: Map<
    string,
    (
      request: JSONRPCRequest,
      extra: RequestHandlerExtra,
    ) => Promise<SendResultT>
  > = new Map();
  private _requestHandlerAbortControllers: Map<RequestId, AbortController> =
    new Map();
  private _notificationHandlers: Map<
    string,
    (notification: JSONRPCNotification) => Promise<void>
  > = new Map();
  private _responseHandlers: Map<
    number,
    (response: JSONRPCResponse | Error) => void
  > = new Map();
  private _progressHandlers: Map<number, ProgressCallback> = new Map();
  private _timeoutInfo: Map<number, TimeoutInfo> = new Map();

  /**
   * Callback for when the connection is closed for any reason.
   *
   * This is invoked when close() is called as well.
   */
  onclose?: () => void;

  /**
   * Callback for when an error occurs.
   *
   * Note that errors are not necessarily fatal; they are used for reporting any kind of exceptional condition out of band.
   */
  onerror?: (error: Error) => void;

  /**
   * A handler to invoke for any request types that do not have their own handler installed.
   */
  fallbackRequestHandler?: (request: Request) => Promise<SendResultT>;

  /**
   * A handler to invoke for any notification types that do not have their own handler installed.
   */
  fallbackNotificationHandler?: (notification: Notification) => Promise<void>;

  constructor(private _options?: ProtocolOptions) {
    this.setNotificationHandler(CancelledNotificationSchema, (notification) => {
      const controller = this._requestHandlerAbortControllers.get(
        notification.params.requestId,
      );
      controller?.abort(notification.params.reason);
    });

    this.setNotificationHandler(ProgressNotificationSchema, (notification) => {
      this._onprogress(notification as unknown as ProgressNotification);
    });

    this.setRequestHandler(
      PingRequestSchema,
      // Automatic pong by default.
      (_request) => ({}) as SendResultT,
    );
  }

  /**
   * Arms the timeout for an outgoing request and records the state needed to reset it later.
   */
  private _setupTimeout(
    messageId: number,
    timeout: number,
    maxTotalTimeout: number | undefined,
    onTimeout: (reason?: Error) => void,
    resetTimeoutOnProgress = false,
  ): void {
    this._timeoutInfo.set(messageId, {
      // Wrapped so the timer always invokes onTimeout with no argument.
      timeoutId: setTimeout(() => onTimeout(), timeout),
      startTime: Date.now(),
      timeout,
      maxTotalTimeout,
      resetTimeoutOnProgress,
      onTimeout,
    });
  }

  /**
   * Re-arms the timeout for a request.
   *
   * @returns false if no timeout is tracked for `messageId`, true otherwise.
   * @throws McpError (`RequestTimeout`) if `maxTotalTimeout` has already elapsed; the timer is
   *   disarmed and its state removed before throwing.
   */
  private _resetTimeout(messageId: number): boolean {
    const info = this._timeoutInfo.get(messageId);
    if (!info) return false;

    const totalElapsed = Date.now() - info.startTime;
    if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) {
      // Disarm the pending timer too, otherwise it would still fire later.
      clearTimeout(info.timeoutId);
      this._timeoutInfo.delete(messageId);
      throw new McpError(
        ErrorCode.RequestTimeout,
        "Maximum total timeout exceeded",
        { maxTotalTimeout: info.maxTotalTimeout, totalElapsed },
      );
    }

    clearTimeout(info.timeoutId);
    info.timeoutId = setTimeout(() => info.onTimeout(), info.timeout);
    return true;
  }

  /**
   * Disarms and forgets the timeout for a request, if one is tracked.
   */
  private _cleanupTimeout(messageId: number): void {
    const info = this._timeoutInfo.get(messageId);
    if (info) {
      clearTimeout(info.timeoutId);
      this._timeoutInfo.delete(messageId);
    }
  }

  /**
   * Attaches to the given transport, starts it, and starts listening for messages.
   *
   * The Protocol object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward.
   */
  async connect(transport: Transport): Promise<void> {
    this._transport = transport;
    this._transport.onclose = () => {
      this._onclose();
    };

    this._transport.onerror = (error: Error) => {
      this._onerror(error);
    };

    this._transport.onmessage = (message) => {
      // No "method" => response/error; "method" + "id" => request; otherwise notification.
      if (!("method" in message)) {
        this._onresponse(message);
      } else if ("id" in message) {
        this._onrequest(message);
      } else {
        this._onnotification(message);
      }
    };

    await this._transport.start();
  }

  private _onclose(): void {
    const responseHandlers = this._responseHandlers;
    this._responseHandlers = new Map();
    this._progressHandlers.clear();

    // Disarm every pending request timer. Otherwise they keep the event loop alive and
    // fire long after the connection is gone.
    for (const info of this._timeoutInfo.values()) {
      clearTimeout(info.timeoutId);
    }
    this._timeoutInfo.clear();

    this._transport = undefined;
    this.onclose?.();

    const error = new McpError(ErrorCode.ConnectionClosed, "Connection closed");
    for (const handler of responseHandlers.values()) {
      handler(error);
    }
  }

  private _onerror(error: Error): void {
    this.onerror?.(error);
  }

  private _onnotification(notification: JSONRPCNotification): void {
    const handler =
      this._notificationHandlers.get(notification.method) ??
      this.fallbackNotificationHandler;

    // Ignore notifications not being subscribed to.
    if (handler === undefined) {
      return;
    }

    // Starting with Promise.resolve() puts any synchronous errors into the monad as well.
    Promise.resolve()
      .then(() => handler(notification))
      .catch((error) =>
        this._onerror(
          new Error(`Uncaught error in notification handler: ${error}`),
        ),
      );
  }

  private _onrequest(request: JSONRPCRequest): void {
    const handler =
      this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler;

    if (handler === undefined) {
      this._transport
        ?.send({
          jsonrpc: "2.0",
          id: request.id,
          error: {
            code: ErrorCode.MethodNotFound,
            message: "Method not found",
          },
        })
        .catch((error) =>
          this._onerror(
            new Error(`Failed to send an error response: ${error}`),
          ),
        );
      return;
    }

    const abortController = new AbortController();
    this._requestHandlerAbortControllers.set(request.id, abortController);

    // Starting with Promise.resolve() puts any synchronous errors into the monad as well.
    Promise.resolve()
      .then(() => handler(request, { signal: abortController.signal }))
      .then(
        (result) => {
          // The sender cancelled this request; it no longer expects a response.
          if (abortController.signal.aborted) {
            return;
          }

          return this._transport?.send({
            result,
            jsonrpc: "2.0",
            id: request.id,
          });
        },
        (error) => {
          if (abortController.signal.aborted) {
            return;
          }

          // Handlers may reject with non-objects (e.g. undefined/null); use optional
          // access so an error response is still sent instead of throwing here.
          return this._transport?.send({
            jsonrpc: "2.0",
            id: request.id,
            error: {
              code: Number.isSafeInteger(error?.["code"])
                ? error["code"]
                : ErrorCode.InternalError,
              message: error?.message ?? "Internal error",
            },
          });
        },
      )
      .catch((error) =>
        this._onerror(new Error(`Failed to send response: ${error}`)),
      )
      .finally(() => {
        this._requestHandlerAbortControllers.delete(request.id);
      });
  }

  private _onprogress(notification: ProgressNotification): void {
    const { progressToken, ...params } = notification.params;
    const messageId = Number(progressToken);

    const handler = this._progressHandlers.get(messageId);
    if (!handler) {
      this._onerror(
        new Error(
          `Received a progress notification for an unknown token: ${JSON.stringify(notification)}`,
        ),
      );
      return;
    }

    const responseHandler = this._responseHandlers.get(messageId);
    const timeoutInfo = this._timeoutInfo.get(messageId);

    // Only extend the deadline when the caller opted in via resetTimeoutOnProgress.
    if (timeoutInfo?.resetTimeoutOnProgress && responseHandler) {
      try {
        this._resetTimeout(messageId);
      } catch (error) {
        // maxTotalTimeout elapsed. Cancel through the request's own cancel path, which
        // removes its handlers, notifies the remote side, and rejects with this error.
        timeoutInfo.onTimeout(error as Error);
        return;
      }
    }

    handler(params);
  }

  private _onresponse(response: JSONRPCResponse | JSONRPCError): void {
    const messageId = Number(response.id);
    const handler = this._responseHandlers.get(messageId);
    if (handler === undefined) {
      this._onerror(
        new Error(
          `Received a response for an unknown message ID: ${JSON.stringify(response)}`,
        ),
      );
      return;
    }

    this._responseHandlers.delete(messageId);
    this._progressHandlers.delete(messageId);
    this._cleanupTimeout(messageId);

    if ("result" in response) {
      handler(response);
    } else {
      const error = new McpError(
        response.error.code,
        response.error.message,
        response.error.data,
      );
      handler(error);
    }
  }

  get transport(): Transport | undefined {
    return this._transport;
  }

  /**
   * Closes the connection.
   */
  async close(): Promise<void> {
    await this._transport?.close();
  }

  /**
   * A method to check if a capability is supported by the remote side, for the given method to be called.
   *
   * This should be implemented by subclasses.
   */
  protected abstract assertCapabilityForMethod(
    method: SendRequestT["method"],
  ): void;

  /**
   * A method to check if a notification is supported by the local side, for the given method to be sent.
   *
   * This should be implemented by subclasses.
   */
  protected abstract assertNotificationCapability(
    method: SendNotificationT["method"],
  ): void;

  /**
   * A method to check if a request handler is supported by the local side, for the given method to be handled.
   *
   * This should be implemented by subclasses.
   */
  protected abstract assertRequestHandlerCapability(method: string): void;

  /**
   * Sends a request and waits for a response.
   *
   * Do not use this method to emit notifications! Use notification() instead.
   *
   * The returned promise rejects if not connected, if the signal is (or becomes) aborted, if the
   * request times out (McpError `RequestTimeout`), if the connection closes (McpError
   * `ConnectionClosed`), if the remote side answers with an error, or if `resultSchema` rejects
   * the result.
   */
  request<T extends ZodType<object>>(
    request: SendRequestT,
    resultSchema: T,
    options?: RequestOptions,
  ): Promise<z.infer<T>> {
    return new Promise((resolve, reject) => {
      if (!this._transport) {
        reject(new Error("Not connected"));
        return;
      }

      if (this._options?.enforceStrictCapabilities === true) {
        this.assertCapabilityForMethod(request.method);
      }

      options?.signal?.throwIfAborted();

      const messageId = this._requestMessageId++;
      const jsonrpcRequest: JSONRPCRequest = {
        ...request,
        jsonrpc: "2.0",
        id: messageId,
      };

      if (options?.onprogress) {
        this._progressHandlers.set(messageId, options.onprogress);
        jsonrpcRequest.params = {
          ...request.params,
          // Keep any caller-supplied _meta fields; only add the progress token.
          _meta: { ...(request.params?._meta ?? {}), progressToken: messageId },
        };
      }

      const signal = options?.signal;
      const onAbort = () => cancel(signal?.reason);
      // Detach once the request settles. Otherwise a later abort of a long-lived signal
      // would leak this closure and send a cancellation for an already-finished request.
      const detachAbort = () => signal?.removeEventListener("abort", onAbort);

      const cancel = (reason: unknown) => {
        detachAbort();
        this._responseHandlers.delete(messageId);
        this._progressHandlers.delete(messageId);
        this._cleanupTimeout(messageId);

        this._transport
          ?.send({
            jsonrpc: "2.0",
            method: "notifications/cancelled",
            params: {
              requestId: messageId,
              reason: String(reason),
            },
          })
          .catch((error) =>
            this._onerror(new Error(`Failed to send cancellation: ${error}`)),
          );

        reject(reason);
      };

      this._responseHandlers.set(messageId, (response) => {
        detachAbort();
        // Already rejected by cancel() when the signal aborted.
        if (options?.signal?.aborted) {
          return;
        }

        if (response instanceof Error) {
          return reject(response);
        }

        try {
          const result = resultSchema.parse(response.result);
          resolve(result);
        } catch (error) {
          reject(error);
        }
      });

      signal?.addEventListener("abort", onAbort, { once: true });

      const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC;
      // Called with no argument by the timer, or with a specific error (e.g. max total timeout).
      const timeoutHandler = (reason?: Error) =>
        cancel(
          reason ??
            new McpError(ErrorCode.RequestTimeout, "Request timed out", {
              timeout,
            }),
        );

      this._setupTimeout(
        messageId,
        timeout,
        options?.maxTotalTimeout,
        timeoutHandler,
        options?.resetTimeoutOnProgress ?? false,
      );

      this._transport.send(jsonrpcRequest).catch((error) => {
        // The request never left; drop all state registered for it.
        detachAbort();
        this._responseHandlers.delete(messageId);
        this._progressHandlers.delete(messageId);
        this._cleanupTimeout(messageId);
        reject(error);
      });
    });
  }

  /**
   * Emits a notification, which is a one-way message that does not expect a response.
   */
  async notification(notification: SendNotificationT): Promise<void> {
    if (!this._transport) {
      throw new Error("Not connected");
    }

    this.assertNotificationCapability(notification.method);

    const jsonrpcNotification: JSONRPCNotification = {
      ...notification,
      jsonrpc: "2.0",
    };

    await this._transport.send(jsonrpcNotification);
  }

  /**
   * Registers a handler to invoke when this protocol object receives a request with the given method.
   *
   * Note that this will replace any previous request handler for the same method.
   */
  setRequestHandler<
    T extends ZodObject<{
      method: ZodLiteral<string>;
    }>,
  >(
    requestSchema: T,
    handler: (
      request: z.infer<T>,
      extra: RequestHandlerExtra,
    ) => SendResultT | Promise<SendResultT>,
  ): void {
    const method = requestSchema.shape.method.value;
    this.assertRequestHandlerCapability(method);
    this._requestHandlers.set(method, (request, extra) =>
      Promise.resolve(handler(requestSchema.parse(request), extra)),
    );
  }

  /**
   * Removes the request handler for the given method.
   */
  removeRequestHandler(method: string): void {
    this._requestHandlers.delete(method);
  }

  /**
   * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed.
   */
  assertCanSetRequestHandler(method: string): void {
    if (this._requestHandlers.has(method)) {
      throw new Error(
        `A request handler for ${method} already exists, which would be overridden`,
      );
    }
  }

  /**
   * Registers a handler to invoke when this protocol object receives a notification with the given method.
   *
   * Note that this will replace any previous notification handler for the same method.
   */
  setNotificationHandler<
    T extends ZodObject<{
      method: ZodLiteral<string>;
    }>,
  >(
    notificationSchema: T,
    handler: (notification: z.infer<T>) => void | Promise<void>,
  ): void {
    this._notificationHandlers.set(
      notificationSchema.shape.method.value,
      (notification) =>
        Promise.resolve(handler(notificationSchema.parse(notification))),
    );
  }

  /**
   * Removes the notification handler for the given method.
   */
  removeNotificationHandler(method: string): void {
    this._notificationHandlers.delete(method);
  }
}

/**
 * Merges `additional` into a shallow copy of `base` (`base` itself is not mutated).
 *
 * For each entry of `additional`:
 * - a truthy object value is spread over the existing value at that key (one level deep);
 * - any other value replaces the existing one.
 */
export function mergeCapabilities<
  T extends ServerCapabilities | ClientCapabilities,
>(base: T, additional: T): T {
  return Object.entries(additional).reduce(
    (acc, [key, value]) => {
      if (value && typeof value === "object") {
        acc[key] = acc[key] ? { ...acc[key], ...value } : value;
      } else {
        acc[key] = value;
      }
      return acc;
    },
    { ...base },
  );
}