NestJS MCP Server Module
by rekog-labs
- MCP-Nest
- tests
import { INestApplication, Injectable } from "@nestjs/common";
import { Test, TestingModule } from "@nestjs/testing";
import { z } from "zod";
import { Tool } from "../src";
import { McpModule } from "../src/mcp.module";
import { Context } from "../src/services/mcp-tools.discovery";
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { Progress } from "@modelcontextprotocol/sdk/types.js";
// Mock UserRepository for testing
@Injectable()
class MockUserRepository {
async findOne() {
return {
id: "user123",
name: "Test User",
orgMemberships: [
{
orgId: "org123",
organization: {
name: "Test Org",
},
},
],
};
}
}
@Injectable()
export class GreetingTool {
constructor(private readonly userRepository: MockUserRepository) {}
@Tool({
name: "hello-world",
description: "A sample tool that returns a greeting",
parameters: z.object({
name: z.string().default("World"),
}),
})
async sayHello({ name }, context: Context) {
const user = await this.userRepository.findOne();
const greeting = `Hello, ${name}! I'm ${user.name} from ${
user.orgMemberships[0].organization.name
}.`;
for (let i = 0; i < 5; i++) {
await new Promise((resolve) => setTimeout(resolve, 500));
await context.reportProgress({
progress: (i+1) * 20,
total: 100,
} as Progress);
}
return {
content: [
{
type: "text",
text: greeting,
},
],
};
}
}
describe("E2E: MCP Server via SSE", () => {
let app: INestApplication;
let testPort: number;
beforeAll(async () => {
const moduleFixture: TestingModule = await Test.createTestingModule({
imports: [
McpModule.forRoot({
name: "test-mcp-server",
version: "0.0.1",
capabilities: {
tools: {
"hello-world": {
description: "A sample tool that returns a greeting",
input: {
name: {
type: "string",
default: "World",
},
},
},
"progress-test": {
description: "A tool that simulates progress",
},
},
},
}),
],
providers: [GreetingTool, MockUserRepository],
}).compile();
app = moduleFixture.createNestApplication();
await app.listen(0);
const server = app.getHttpServer();
testPort = server.address().port;
});
afterAll(async () => {
await app.close();
});
it("should list tools", async () => {
const client = new Client(
{ name: "example-client", version: "1.0.0" },
{ capabilities: {} },
);
const sseUrl = new URL(`http://localhost:${testPort}/sse`);
const transport = new SSEClientTransport(sseUrl);
await client.connect(transport);
const tools = await client.listTools();
expect(tools.tools.length).toBeGreaterThan(0);
await client.close();
});
it('should inject dependencies into the tool and call the "hello-world" tool via SSE', async () => {
const client = new Client(
{ name: "example-client", version: "1.0.0" },
{ capabilities: {} },
);
const sseUrl = new URL(`http://localhost:${testPort}/sse`);
const transport = new SSEClientTransport(sseUrl);
await client.connect(transport);
const tools = await client.listTools();
expect(tools.tools.find((t) => t.name == "hello-world")).toBeDefined();
let countProgress=1;
const result = await client.callTool(
{
name: "hello-world",
arguments: { name: "World" },
},
undefined,
{
onprogress: (progress) => {
console.log(progress);
countProgress++;
},
},
);
expect(countProgress).toBe(5)
expect(result).toEqual({
content: [
{
type: "text",
text: "Hello, World! I'm Test User from Test Org.",
},
],
});
await client.close();
});
});