// restartableStdioTransport.test.ts (8.74 kB)
import { StdioServerParameters } from '@modelcontextprotocol/sdk/client/stdio.js';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { RestartableStdioTransport } from './restartableStdioTransport.js';
// Stand-in for the SDK's StdioClientTransport. vi.hoisted lifts this above the
// vi.mock factories so they can reference it. Each construction yields a fresh
// stub whose start/close/send resolve immediately and whose event hooks start
// unset, leaving the wrapper under test free to attach its own handlers.
const mockStdioClientTransport = vi.hoisted(() => {
  const buildTransportStub = () => ({
    start: vi.fn().mockResolvedValue(undefined),
    close: vi.fn().mockResolvedValue(undefined),
    send: vi.fn().mockResolvedValue(undefined),
    stderr: null,
    pid: 1234,
    onclose: undefined,
    onerror: undefined,
    onmessage: undefined,
  });
  return vi.fn().mockImplementation(buildTransportStub);
});
// Replace the SDK stdio module: StdioClientTransport becomes the hoisted
// mockStdioClientTransport, and getDefaultEnvironment reports a fixed, minimal env.
vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
  const fakeDefaultEnvironment = { HOME: '/home/user', PATH: '/usr/bin' };
  return {
    StdioClientTransport: mockStdioClientTransport,
    getDefaultEnvironment: vi.fn().mockReturnValue(fakeDefaultEnvironment),
  };
});
// Silence the project logger: every log level is an independent spy, and the
// debugIf helper is stubbed as well.
vi.mock('@src/logger/logger.js', () => {
  const silentLogger = {
    debug: vi.fn(),
    info: vi.fn(),
    warn: vi.fn(),
    error: vi.fn(),
  };
  return { default: silentLogger, debugIf: vi.fn() };
});
describe('RestartableStdioTransport', () => {
const mockServerParams: StdioServerParameters = {
  command: 'test-command',
  args: ['arg1', 'arg2'],
  env: { NODE_ENV: 'test' },
};
let transport: RestartableStdioTransport;
// Fresh transport for every test: restarts enabled, capped at 3 attempts, with a
// 100 ms delay (tests wait 150 ms when they need a scheduled restart to run).
beforeEach(() => {
  vi.clearAllMocks();
  transport = new RestartableStdioTransport(mockServerParams, {
    restartOnExit: true,
    maxRestarts: 3,
    restartDelay: 100,
  });
});
// Close the transport after each test and wait for it. close() is async, so
// awaiting it keeps an in-flight shutdown or pending restart from bleeding into
// the next test, and turns a rejection into a failure of this test rather than
// an unhandled rejection.
afterEach(async () => {
  if (transport) {
    await transport.close();
  }
});
describe('Constructor', () => {
  it('should create RestartableStdioTransport instance', () => {
    expect(transport).toBeInstanceOf(RestartableStdioTransport);
  });
  // A freshly constructed transport has not restarted and is not mid-restart.
  it('should initialize with zero restart count', () => {
    const { restartCount, isRestarting } = transport.getRestartStats();
    expect(restartCount).toBe(0);
    expect(isRestarting).toBe(false);
  });
});
describe('Transport Interface', () => {
  // Minimal JSON-RPC request used by the send() tests.
  const pingRequest = { jsonrpc: '2.0' as const, id: 1, method: 'test' };

  it('should start transport successfully', async () => {
    await expect(transport.start()).resolves.toBeUndefined();
    expect(transport.getRestartStats().isRestarting).toBe(false);
  });
  it('should throw error when starting already started transport', async () => {
    await transport.start();
    const secondStart = transport.start();
    await expect(secondStart).rejects.toThrow('RestartableStdioTransport already started!');
  });
  it('should send messages through underlying transport', async () => {
    await transport.start();
    await expect(transport.send(pingRequest)).resolves.toBeUndefined();
  });
  it('should throw error when sending message on unstarted transport', async () => {
    const pendingSend = transport.send(pingRequest);
    await expect(pendingSend).rejects.toThrow('Transport not started');
  });
  it('should close transport successfully', async () => {
    await transport.start();
    const pendingClose = transport.close();
    await expect(pendingClose).resolves.toBeUndefined();
  });
});
describe('Restart Functionality', () => {
  // Longer than the 100 ms restartDelay configured in beforeEach, so a scheduled
  // restart has had the chance to run before the test asserts.
  const waitPastRestartDelay = () => new Promise<void>((resolve) => setTimeout(resolve, 150));

  it('should not restart when restartOnExit is false', async () => {
    const noRestartTransport = new RestartableStdioTransport(mockServerParams, {
      restartOnExit: false,
    });
    const onCloseMock = vi.fn();
    noRestartTransport.onclose = onCloseMock;
    try {
      await noRestartTransport.start();
      // Simulate the underlying transport closing on its own.
      const mockTransport = (noRestartTransport as any)._currentTransport;
      mockTransport.onclose();
      // With restarts disabled the close is forwarded to the caller...
      expect(onCloseMock).toHaveBeenCalled();
      // ...and no restart is counted or in progress.
      const stats = noRestartTransport.getRestartStats();
      expect(stats.restartCount).toBe(0);
      expect(stats.isRestarting).toBe(false);
    } finally {
      // Always release this extra transport, even when an assertion above fails.
      await noRestartTransport.close();
    }
  });
  it('should attempt restart on unexpected close', async () => {
    await transport.start();
    expect(transport.getRestartStats().restartCount).toBe(0);
    // Simulate unexpected transport close
    const mockTransport = (transport as any)._currentTransport;
    mockTransport.onclose();
    await waitPastRestartDelay();
    expect(transport.getRestartStats().restartCount).toBe(1);
  });
  it('should respect max restart limit', async () => {
    const onErrorMock = vi.fn();
    transport.onerror = onErrorMock;
    await transport.start();
    // A restart constructs a new underlying transport, so re-read the live one on
    // every iteration; closing a stale, already-replaced instance would not drive
    // the real restart path. With maxRestarts: 3, the fourth close should exceed it.
    for (let attempt = 0; attempt < 4; attempt++) {
      const liveTransport = (transport as any)._currentTransport;
      liveTransport.onclose();
      await waitPastRestartDelay();
    }
    expect(onErrorMock).toHaveBeenCalledWith(
      expect.objectContaining({
        message: expect.stringContaining('Transport failed after 3 restart attempts'),
      }),
    );
  });
  it('should not restart when closing intentionally', async () => {
    await transport.start();
    // Capture the underlying transport before close(): close() may drop the
    // reference, which would otherwise silently skip the simulated event below.
    const mockTransport = (transport as any)._currentTransport;
    const restartSpy = vi.spyOn(transport as any, 'attemptRestart');
    // Close intentionally
    await transport.close();
    // Simulate the close event the underlying transport emits after shutdown.
    mockTransport.onclose?.();
    // Give any (erroneously) scheduled restart time to fire before asserting.
    await waitPastRestartDelay();
    expect(restartSpy).not.toHaveBeenCalled();
    // Only the initial start() should have constructed an underlying transport.
    expect(mockStdioClientTransport).toHaveBeenCalledTimes(1);
  });
});
describe('Event Forwarding', () => {
  // The mocked underlying transport that the wrapper created during start().
  const underlying = () => (transport as any)._currentTransport;

  it('should forward error events from underlying transport', async () => {
    const onErrorMock = vi.fn();
    transport.onerror = onErrorMock;
    await transport.start();
    const testError = new Error('Test transport error');
    underlying().onerror(testError);
    expect(onErrorMock).toHaveBeenCalledWith(testError);
  });
  it('should forward message events from underlying transport', async () => {
    const onMessageMock = vi.fn();
    transport.onmessage = onMessageMock;
    await transport.start();
    const testMessage = { jsonrpc: '2.0' as const, id: 1, result: 'test' };
    underlying().onmessage(testMessage);
    expect(onMessageMock).toHaveBeenCalledWith(testMessage);
  });
});
describe('Properties', () => {
  it('should return stderr from underlying transport', async () => {
    await transport.start();
    // The mocked underlying transport exposes stderr as null.
    expect(transport.stderr).toBeNull();
  });
  it('should return pid from underlying transport', async () => {
    await transport.start();
    // The mocked underlying transport reports pid 1234.
    expect(transport.pid).toBe(1234);
  });
  it('should return null for stderr and pid when not started', () => {
    const { stderr, pid } = transport;
    expect(stderr).toBeNull();
    expect(pid).toBeNull();
  });
  it('should support setting timeout and tags properties', () => {
    const timeoutMs = 5000;
    const tagList = ['test', 'mcp'];
    transport.timeout = timeoutMs;
    transport.tags = tagList;
    expect(transport.timeout).toBe(timeoutMs);
    expect(transport.tags).toEqual(['test', 'mcp']);
  });
});
describe('Edge Cases', () => {
  // Longer than the 100 ms restartDelay configured in beforeEach.
  const waitPastRestartDelay = () => new Promise<void>((resolve) => setTimeout(resolve, 150));

  it('should handle restart failure gracefully', async () => {
    const onErrorMock = vi.fn();
    transport.onerror = onErrorMock;
    await transport.start();
    // Make the next construction (the restart) yield a transport whose start() rejects.
    mockStdioClientTransport.mockImplementationOnce(() => ({
      start: vi.fn().mockRejectedValue(new Error('Start failed')),
      close: vi.fn().mockResolvedValue(undefined),
      send: vi.fn().mockResolvedValue(undefined),
      stderr: null,
      pid: 1234,
      onclose: undefined,
      onerror: undefined,
      onmessage: undefined,
    }));
    // Trigger restart
    const mockTransport = (transport as any)._currentTransport;
    mockTransport.onclose();
    await waitPastRestartDelay();
    expect(onErrorMock).toHaveBeenCalledWith(
      expect.objectContaining({
        message: expect.stringContaining('Start failed'),
      }),
    );
  });
  it('should clear restart timer when closed', async () => {
    await transport.start();
    // Simulate an unexpected close; this should schedule a restart after restartDelay.
    const mockTransport = (transport as any)._currentTransport;
    mockTransport.onclose();
    // Close before the restart delay elapses; close() is async, so assert on its promise.
    await expect(transport.close()).resolves.toBeUndefined();
    // If the timer had survived close(), the restart would construct a second
    // underlying transport once the delay elapsed.
    await waitPastRestartDelay();
    expect(mockStdioClientTransport).toHaveBeenCalledTimes(1);
    // Closing again must also settle cleanly. A synchronous not.toThrow() check
    // cannot observe a rejected promise, so await it instead.
    await expect(transport.close()).resolves.toBeUndefined();
  });
});
});