// mcp-oauth-auth.e2e.spec.ts (20 kB)
import { INestApplication, Injectable } from '@nestjs/common';
import { Test, TestingModule } from '@nestjs/testing';
import request from 'supertest';
import { randomBytes, createHash } from 'crypto';
import { z } from 'zod';
import { Tool, Context, McpModule } from '../src';
import jwt from 'jsonwebtoken';
import { McpAuthModule } from '../src/authz/mcp-oauth.module';
import { McpAuthJwtGuard } from '../src/authz/guards/jwt-auth.guard';
import {
  OAuthProviderConfig,
  OAuthUserProfile,
} from '../src/authz/providers/oauth-provider.interface';
import {
  IOAuthStore,
  OAuthClient,
  AuthorizationCode,
  ClientRegistrationDto,
} from '../src/authz/stores/oauth-store.interface';
import { OAuthSession } from '../src/authz/providers/oauth-provider.interface';
import { createSseClient } from './utils';
// Mock OAuth Provider for testing
const MockOAuthProvider: OAuthProviderConfig = {
  name: 'mock',
  displayName: 'Mock Provider',
  strategy: class MockStrategy {
    _verify: any;
    name: string = 'mock';
    constructor(options: any, verify: any) {
      this._verify = verify;
    }
    authenticate(req: any, options?: any) {
      // Simulate immediate redirect to OAuth provider
      // In a real test, this would redirect to the provider's OAuth page
      // For our test, we'll just simulate the redirect
      this.redirect(
        `https://mock-oauth-provider.com/authorize?client_id=test&redirect_uri=${encodeURIComponent('http://localhost:3000/auth/callback')}`,
      );
    }
    redirect(url: string) {
      // This would be called by Passport to redirect the user
      // For testing, we simulate this behavior
      throw { redirect: url };
    }
  },
  strategyOptions: (options) => ({
    clientID: options.clientId,
    clientSecret: options.clientSecret,
    callbackURL: `${options.serverUrl}/auth/callback`,
  }),
  profileMapper: (profile: any): OAuthUserProfile => ({
    id: profile.id,
    username: profile.username,
    email: profile.emails?.[0]?.value,
    displayName: profile.displayName,
  }),
};
// Mock store for testing
@Injectable()
class MockOAuthStore implements IOAuthStore {
  private clients = new Map<string, OAuthClient>();
  private authCodes = new Map<string, AuthorizationCode>();
  private oauthSessions = new Map<string, OAuthSession>();
  private profilesById = new Map<
    string,
    OAuthUserProfile & { profile_id: string; provider: string }
  >();
  private providerUserKeyToId = new Map<string, string>();
  async storeClient(client: OAuthClient): Promise<OAuthClient> {
    this.clients.set(client.client_id, client);
    return client;
  }
  async getClient(client_id: string): Promise<OAuthClient | undefined> {
    return this.clients.get(client_id);
  }
  async findClient(client_name: string): Promise<OAuthClient | undefined> {
    for (const client of this.clients.values()) {
      if (client.client_name === client_name) {
        return client;
      }
    }
    return undefined;
  }
  generateClientId(client: OAuthClient): string {
    const normalizedName = client.client_name
      .toLowerCase()
      .replace(/[^a-z0-9]/g, '');
    const timestamp = Date.now().toString(36);
    return `${normalizedName}_${timestamp}`;
  }
  async storeAuthCode(code: AuthorizationCode): Promise<void> {
    this.authCodes.set(code.code, code);
  }
  async getAuthCode(code: string): Promise<AuthorizationCode | undefined> {
    return this.authCodes.get(code);
  }
  async removeAuthCode(code: string): Promise<void> {
    this.authCodes.delete(code);
  }
  async storeOAuthSession(
    sessionId: string,
    session: OAuthSession,
  ): Promise<void> {
    this.oauthSessions.set(sessionId, session);
  }
  async getOAuthSession(sessionId: string): Promise<OAuthSession | undefined> {
    return this.oauthSessions.get(sessionId);
  }
  async removeOAuthSession(sessionId: string): Promise<void> {
    this.oauthSessions.delete(sessionId);
  }
  async upsertUserProfile(
    profile: OAuthUserProfile,
    provider: string,
  ): Promise<string> {
    const key = `${provider}:${profile.id}`;
    let profileId = this.providerUserKeyToId.get(key);
    if (!profileId) {
      profileId = `${provider}_${profile.id}`;
      this.providerUserKeyToId.set(key, profileId);
    }
    this.profilesById.set(profileId, {
      ...profile,
      profile_id: profileId,
      provider,
    });
    return profileId;
  }
  async getUserProfileById(
    profileId: string,
  ): Promise<
    (OAuthUserProfile & { profile_id: string; provider: string }) | undefined
  > {
    return this.profilesById.get(profileId);
  }
}
// Test tool for protected endpoints
@Injectable()
export class TestProtectedTool {
  @Tool({
    name: 'protected-hello',
    description: 'A protected tool that requires authentication',
    parameters: z.object({
      message: z.string().default('Hello'),
    }),
  })
  async protectedHello({ message }, context: Context, request: any) {
    return {
      content: [
        {
          type: 'text',
          text: `${message} from authenticated user: ${request.user?.sub}`,
        },
      ],
    };
  }
}
describe('E2E: McpAuthModule OAuth Flow', () => {
  // Nest application under test; created and started in beforeAll, closed in afterAll.
  let app: INestApplication;
  // Port reported by the HTTP server after `app.listen(0)`; used to open SSE clients.
  let testPort: number;
  // In-memory store passed to McpAuthModule as its custom store, so the tests
  // can seed authorization codes and inspect stored clients directly.
  let mockStore: MockOAuthStore;
  // Passed to McpAuthModule as `jwtSecret`; the refresh test also uses it to verify issued tokens.
  const testJwtSecret = 'test-jwt-secret-that-is-at-least-32-characters-long';
  // Passed as `serverUrl`; the metadata test expects it as the issuer, and
  // seeded authorization codes use `${testServerUrl}/mcp` as their resource.
  const testServerUrl = 'http://localhost:3000';
  // Provider credentials passed to McpAuthModule as `clientId` / `clientSecret`.
  const testClientId = 'test-client-id';
  const testClientSecret = 'test-client-secret';
  // Reduces a decoded JWT payload to the claims worth comparing between two
  // tokens. Works on a shallow copy, so the caller's payload is left untouched:
  //   - `iat`, `exp`, `nbf` and `jti` are always dropped;
  //   - for kind 'refresh', `client_id` is moved to `azp` and `type` is forced
  //     to 'access' so the result lines up with an access-token payload.
  const normalizeJwtPayload = (payload: any, kind: 'access' | 'refresh') => {
    const comparable: any = { ...payload };
    for (const perTokenClaim of ['iat', 'exp', 'nbf', 'jti']) {
      delete comparable[perTokenClaim];
    }
    if (kind !== 'refresh') {
      return comparable;
    }
    comparable.azp = comparable.client_id;
    delete comparable.client_id;
    comparable.type = 'access';
    return comparable;
  };
  // Builds one Nest application for the whole suite: the auth module wired to
  // the mock provider and the in-memory store, plus an MCP server guarded by
  // McpAuthJwtGuard that exposes TestProtectedTool.
  beforeAll(async () => {
    mockStore = new MockOAuthStore();

    const authModule = McpAuthModule.forRoot({
      provider: MockOAuthProvider,
      clientId: testClientId,
      clientSecret: testClientSecret,
      jwtSecret: testJwtSecret,
      serverUrl: testServerUrl,
      apiPrefix: 'auth',
      cookieSecure: false, // For testing
      storeConfiguration: {
        type: 'custom',
        store: mockStore,
      },
    });
    const mcpModule = McpModule.forRoot({
      name: 'test-oauth-mcp-server',
      version: '0.0.1',
      guards: [McpAuthJwtGuard],
    });

    const testingModule: TestingModule = await Test.createTestingModule({
      imports: [authModule, mcpModule],
      providers: [TestProtectedTool],
    }).compile();

    app = testingModule.createNestApplication();
    // Port 0 asks for any free port; the one actually bound is read back below.
    await app.listen(0);
    testPort = app.getHttpServer().address().port;
  });
  // Shuts the application down once the suite finishes. The optional call
  // covers the case where `app` was never assigned (for example, module
  // compilation failed in beforeAll), so the original setup error is not
  // buried under a TypeError raised from this hook.
  afterAll(async () => {
    await app?.close();
  });
  // Covers the authorization-server metadata document.
  describe('OAuth Well-Known Endpoint', () => {
    it('should return authorization server metadata', async () => {
      // Endpoint URLs are only checked for their path suffix; the issuer must
      // equal the configured server URL exactly.
      const expectedMetadata = {
        issuer: testServerUrl,
        authorization_endpoint: expect.stringContaining('/auth/authorize'),
        token_endpoint: expect.stringContaining('/auth/token'),
        registration_endpoint: expect.stringContaining('/auth/register'),
        response_types_supported: ['code'],
        grant_types_supported: ['authorization_code', 'refresh_token'],
        code_challenge_methods_supported: ['plain', 'S256'],
      };

      const { body: metadata } = await request(app.getHttpServer())
        .get('/.well-known/oauth-authorization-server')
        .expect(200);

      expect(metadata).toMatchObject(expectedMetadata);
    });
  });
  // Covers dynamic client registration through POST /auth/register.
  describe('Client Registration', () => {
    it('should register a new OAuth client', async () => {
      const registration: ClientRegistrationDto = {
        client_name: 'Test Client',
        client_description: 'A test OAuth client',
        redirect_uris: ['http://localhost:8080/callback'],
        grant_types: ['authorization_code'],
        response_types: ['code'],
      };

      const { body: registered } = await request(app.getHttpServer())
        .post('/auth/register')
        .send(registration)
        .expect(201);

      // The response must echo every submitted field and carry a string id.
      expect(registered).toMatchObject({
        client_id: expect.any(String),
        ...registration,
      });

      // The client must also be retrievable from the backing store by that id.
      const persisted = await mockStore.getClient(registered.client_id);
      expect(persisted).toBeDefined();
      expect(persisted?.client_name).toBe('Test Client');
    });
  });
  // Covers GET /auth/authorize: one well-formed request and two that must be
  // rejected before any provider redirect is attempted.
  describe('Authorization Flow', () => {
    let registeredClient: OAuthClient;

    // Assembles the authorize path from already-formatted `key=value` pairs.
    const authorizePath = (pairs: string[]) =>
      `/auth/authorize?${pairs.join('&')}`;

    // Every test starts from a freshly registered client.
    beforeEach(async () => {
      const registration: ClientRegistrationDto = {
        client_name: 'Test Flow Client',
        redirect_uris: ['http://localhost:8080/callback'],
        grant_types: ['authorization_code'],
        response_types: ['code'],
      };
      const { body } = await request(app.getHttpServer())
        .post('/auth/register')
        .send(registration);
      registeredClient = body;
    });

    it('should initiate authorization flow with valid parameters', async () => {
      // PKCE pair: random verifier and its SHA-256 challenge, both base64url.
      const verifier = randomBytes(32).toString('base64url');
      const challenge = createHash('sha256').update(verifier).digest('base64url');

      const authUrl = authorizePath([
        'response_type=code',
        `client_id=${registeredClient.client_id}`,
        `redirect_uri=${encodeURIComponent(registeredClient.redirect_uris[0])}`,
        `code_challenge=${challenge}`,
        'code_challenge_method=S256',
        'state=test-state',
      ]);

      const { status } = await request(app.getHttpServer()).get(authUrl);

      // The mock strategy throws instead of redirecting, so a valid request may
      // surface as 500 rather than the 302 a real provider hand-off would give.
      // Only the status is checked here; session creation is not asserted.
      expect([302, 500]).toContain(status);
    });

    it('should reject authorization request with invalid client_id', async () => {
      const authUrl = authorizePath([
        'response_type=code',
        'client_id=invalid-client',
        'redirect_uri=http://localhost:8080/callback',
      ]);
      await request(app.getHttpServer()).get(authUrl).expect(400);
    });

    it('should reject authorization request with invalid redirect_uri', async () => {
      // Known client, but a redirect URI it never registered.
      const authUrl = authorizePath([
        'response_type=code',
        `client_id=${registeredClient.client_id}`,
        'redirect_uri=http://evil.com/callback',
      ]);
      await request(app.getHttpServer()).get(authUrl).expect(400);
    });
  });
  // Covers POST /auth/token with the authorization_code grant. The provider
  // round trip is bypassed: each test starts from an authorization code
  // written straight into the store.
  describe('Token Exchange', () => {
    let registeredClient: OAuthClient;
    let authCode: string;
    let codeVerifier: string;

    // Sends a token request body to the token endpoint.
    const postToken = (body: object) =>
      request(app.getHttpServer()).post('/auth/token').send(body);

    // A well-formed code exchange for the current fixtures; `overrides`
    // replaces individual fields to produce the invalid variants.
    const buildTokenRequest = (overrides: Record<string, string> = {}) => ({
      grant_type: 'authorization_code',
      code: authCode,
      code_verifier: codeVerifier,
      redirect_uri: registeredClient.redirect_uris[0],
      client_id: registeredClient.client_id,
      ...overrides,
    });

    beforeEach(async () => {
      // Fresh client for every test.
      const registration: ClientRegistrationDto = {
        client_name: 'Token Test Client',
        redirect_uris: ['http://localhost:8080/callback'],
        grant_types: ['authorization_code'],
        response_types: ['code'],
      };
      const { body } = await request(app.getHttpServer())
        .post('/auth/register')
        .send(registration);
      registeredClient = body;

      // Fresh PKCE verifier and authorization code, seeded into the store with
      // the matching S256 challenge.
      codeVerifier = randomBytes(32).toString('base64url');
      authCode = randomBytes(32).toString('base64url');
      const pkceChallenge = createHash('sha256')
        .update(codeVerifier)
        .digest('base64url');
      await mockStore.storeAuthCode({
        code: authCode,
        user_id: 'testuser',
        client_id: registeredClient.client_id,
        redirect_uri: registeredClient.redirect_uris[0],
        code_challenge: pkceChallenge,
        code_challenge_method: 'S256',
        expires_at: Date.now() + 600000, // 10 minutes
        resource: `${testServerUrl}/mcp`,
        scope: '',
      });
    });

    it('should exchange authorization code for tokens', async () => {
      const { body: tokens } = await postToken(buildTokenRequest()).expect(200);

      expect(tokens).toMatchObject({
        access_token: expect.any(String),
        refresh_token: expect.any(String),
        token_type: 'bearer',
        expires_in: expect.any(Number),
      });

      // A redeemed code must no longer be present in the store.
      expect(await mockStore.getAuthCode(authCode)).toBeUndefined();
    });

    it('should reject token exchange with invalid authorization code', async () => {
      await postToken(buildTokenRequest({ code: 'invalid-code' })).expect(400);
    });

    it('should reject token exchange with invalid PKCE verifier', async () => {
      await postToken(
        buildTokenRequest({ code_verifier: 'invalid-verifier' }),
      ).expect(400);
    });
  });
  // Checks that the MCP endpoints, which the suite registers behind
  // McpAuthJwtGuard, accept a token issued by /auth/token and refuse
  // connections that carry no token or a malformed one.
  describe('JWT Guard Protection', () => {
    let validAccessToken: string;

    // Obtains a fresh access token before every test: registers a client,
    // seeds an authorization code (with a matching S256 PKCE challenge)
    // directly into the store, and redeems it at the token endpoint.
    beforeEach(async () => {
      const clientData: ClientRegistrationDto = {
        client_name: 'Guard Test Client',
        redirect_uris: ['http://localhost:8080/callback'],
        grant_types: ['authorization_code'],
        response_types: ['code'],
      };
      const clientResponse = await request(app.getHttpServer())
        .post('/auth/register')
        .send(clientData);
      const registeredClient = clientResponse.body;

      const codeVerifier = randomBytes(32).toString('base64url');
      const codeChallenge = createHash('sha256')
        .update(codeVerifier)
        .digest('base64url');
      const authCode = randomBytes(32).toString('base64url');
      await mockStore.storeAuthCode({
        code: authCode,
        user_id: 'testuser',
        client_id: registeredClient.client_id,
        redirect_uri: registeredClient.redirect_uris[0],
        code_challenge: codeChallenge,
        code_challenge_method: 'S256',
        expires_at: Date.now() + 600000,
        resource: `${testServerUrl}/mcp`,
        scope: '',
      });

      const tokenResponse = await request(app.getHttpServer())
        .post('/auth/token')
        .send({
          grant_type: 'authorization_code',
          code: authCode,
          code_verifier: codeVerifier,
          redirect_uri: registeredClient.redirect_uris[0],
          client_id: registeredClient.client_id,
        });
      validAccessToken = tokenResponse.body.access_token;
    });

    it('should allow access to protected MCP endpoints with valid token', async () => {
      const client = await createSseClient(testPort, {
        requestInit: {
          headers: {
            Authorization: `Bearer ${validAccessToken}`,
          },
        },
      });

      // Close the client even when an assertion below throws, so a failing
      // run does not leave the SSE connection open.
      try {
        const tools = await client.listTools();
        expect(tools.tools).toHaveLength(1);
        expect(tools.tools[0].name).toBe('protected-hello');

        // The tool is expected to echo the seeded code's user id ('testuser')
        // as the authenticated subject.
        const result: any = await client.callTool({
          name: 'protected-hello',
          arguments: { message: 'Hello' },
        });
        expect(result.content[0].text).toContain(
          'Hello from authenticated user: testuser',
        );
      } finally {
        await client.close();
      }
    });

    it('should reject access to protected MCP endpoints without token', async () => {
      // No Authorization header at all: opening the client must fail.
      await expect(
        createSseClient(testPort, {
          requestInit: {
            headers: {},
          },
        }),
      ).rejects.toThrow();
    });

    it('should reject access to protected MCP endpoints with invalid token', async () => {
      // A bearer value that is not a token issued by the server.
      await expect(
        createSseClient(testPort, {
          requestInit: {
            headers: {
              Authorization: 'Bearer invalid-token',
            },
          },
        }),
      ).rejects.toThrow();
    });
  });
  // Covers POST /auth/token with the refresh_token grant, starting from a
  // token pair obtained through a seeded authorization code.
  describe('Refresh Token Flow', () => {
    let refreshToken: string;
    let initialAccessToken: string;

    // Sends a token request body to the token endpoint.
    const postToken = (body: object) =>
      request(app.getHttpServer()).post('/auth/token').send(body);

    // Verifies a token against the suite's JWT secret and returns its payload.
    const decodeToken = (token: string): any =>
      jwt.verify(token, testJwtSecret);

    beforeEach(async () => {
      // Client that is allowed to use both grants.
      const registration: ClientRegistrationDto = {
        client_name: 'Refresh Test Client',
        redirect_uris: ['http://localhost:8080/callback'],
        grant_types: ['authorization_code', 'refresh_token'],
        response_types: ['code'],
      };
      const { body: client } = await request(app.getHttpServer())
        .post('/auth/register')
        .send(registration);

      // Seed an authorization code with a matching S256 PKCE challenge.
      const verifier = randomBytes(32).toString('base64url');
      const challenge = createHash('sha256').update(verifier).digest('base64url');
      const code = randomBytes(32).toString('base64url');
      await mockStore.storeAuthCode({
        code,
        user_id: 'testuser',
        client_id: client.client_id,
        redirect_uri: client.redirect_uris[0],
        code_challenge: challenge,
        code_challenge_method: 'S256',
        expires_at: Date.now() + 600000,
        resource: `${testServerUrl}/mcp`,
        scope: '',
      });

      // Redeem the code to obtain the initial token pair.
      const { body: issued } = await postToken({
        grant_type: 'authorization_code',
        code,
        code_verifier: verifier,
        redirect_uri: client.redirect_uris[0],
        client_id: client.client_id,
      });
      initialAccessToken = issued.access_token;
      refreshToken = issued.refresh_token;
    });

    it('should refresh access token with valid refresh token', async () => {
      const { body: refreshed } = await postToken({
        grant_type: 'refresh_token',
        refresh_token: refreshToken,
      }).expect(200);

      expect(refreshed).toMatchObject({
        access_token: expect.any(String),
        refresh_token: expect.any(String),
        token_type: 'bearer',
        expires_in: expect.any(Number),
      });

      // Apart from the per-token claims, the refreshed access token must carry
      // the same claims as the one issued originally.
      const claimsBefore = decodeToken(initialAccessToken);
      const claimsAfter = decodeToken(refreshed.access_token);
      expect(claimsBefore.type).toBe('access');
      expect(claimsAfter.type).toBe('access');
      expect(normalizeJwtPayload(claimsAfter, 'access')).toEqual(
        normalizeJwtPayload(claimsBefore, 'access'),
      );
    });

    it('should reject refresh with invalid refresh token', async () => {
      await postToken({
        grant_type: 'refresh_token',
        refresh_token: 'invalid-refresh-token',
      }).expect(400);
    });
  });
});