import { Inject, Injectable } from '@nestjs/common';
import { randomBytes } from 'crypto';
import * as jwt from 'jsonwebtoken';
import type { OAuthModuleOptions } from '../providers/oauth-provider.interface';
/**
 * Decoded claims of a token issued by JwtTokenService.
 *
 * Access, refresh and user tokens share this shape; which optional claims are
 * present depends on `type`.
 */
export interface JwtPayload {
  sub: string; // user_id
  azp?: string; // authorized party (client_id for access tokens)
  client_id?: string; // only for refresh tokens
  scope?: string;
  resource?: string; // MCP server resource identifier
  type: 'access' | 'refresh' | 'user';
  user_data?: any;
  user_profile_id?: string;
  jti?: string; // JWT ID; set on refresh and user tokens
  iss?: string; // issuer
  aud?: string | string[]; // audience (the resource for access/refresh tokens)
  iat?: number; // issued-at, seconds since epoch
  exp?: number; // expiry, seconds since epoch
}
/** Token endpoint response body returned by JwtTokenService.generateTokenPair. */
export interface TokenPair {
  /** Token type label sent to the client (this service uses 'bearer'). */
  token_type: string;
  /** Signed access token. */
  access_token: string;
  /** Access token lifetime in seconds. */
  expires_in: number;
  /** Optional space-delimited scope string. */
  scope?: string;
  /** Signed refresh token, present only when refresh tokens are enabled. */
  refresh_token?: string;
}
/**
 * Issues and verifies HS256-signed JWTs: access/refresh token pairs bound to
 * a resource, and standalone user tokens.
 */
@Injectable()
export class JwtTokenService {
  private readonly jwtSecret: string;
  private readonly issuer: string;
  private readonly accessTokenExpiresIn: string;
  private readonly refreshTokenExpiresIn: string;
  private readonly enableRefreshTokens: boolean;

  /**
   * @param options Module options injected under 'OAUTH_MODULE_OPTIONS'.
   * @throws Error when `options.jwtSecret` is missing or empty.
   */
  constructor(@Inject('OAUTH_MODULE_OPTIONS') options: OAuthModuleOptions) {
    // The secret comes from the module options; the error message refers to
    // JWT_SECRET, presumably because callers populate it from the environment.
    const jwtSecret = options.jwtSecret;
    if (!jwtSecret) {
      throw new Error('JWT_SECRET must be set in environment variables.');
    }
    this.jwtSecret = jwtSecret;
    this.issuer =
      options.jwtIssuer || options.serverUrl || 'https://localhost:3000';
    this.accessTokenExpiresIn = options.jwtAccessTokenExpiresIn;
    this.refreshTokenExpiresIn = options.jwtRefreshTokenExpiresIn;
    this.enableRefreshTokens = options.enableRefreshTokens;
  }

  /**
   * Signs an access token (and, when enabled, a refresh token) for a user.
   *
   * @param userId Subject (`sub`) of both tokens.
   * @param clientId OAuth client; stored as `azp` in the access token and
   *   `client_id` in the refresh token.
   * @param scope Space-delimited scope; always written to the access token.
   * @param resource Resource identifier, used as both `aud` and `resource`.
   * @param extras Optional profile id / user data copied into the access token
   *   (only `user_profile_id` is copied into the refresh token).
   * @returns The token endpoint response body.
   * @throws Error when `resource` is missing or empty.
   */
  generateTokenPair(
    userId: string,
    clientId: string,
    scope = '',
    resource?: string,
    extras?: { user_profile_id?: string; user_data?: unknown },
  ): TokenPair {
    if (!resource) {
      throw new Error('Resource is required for token generation');
    }

    const accessTokenPayload: Record<string, unknown> = {
      sub: userId,
      azp: clientId, // Use azp instead of client_id
      iss: this.issuer,
      aud: resource,
      resource, // Always include resource
      type: 'access',
    };
    if (extras?.user_profile_id) {
      accessTokenPayload.user_profile_id = extras.user_profile_id;
    }
    if (extras?.user_data) {
      accessTokenPayload.user_data = extras.user_data;
    }
    // Always include scope to ensure parity with refresh token claims
    accessTokenPayload.scope = scope || '';

    const accessToken = jwt.sign(accessTokenPayload, this.jwtSecret, {
      algorithm: 'HS256',
      expiresIn: this.accessTokenExpiresIn,
    });

    let refreshToken: string | undefined;
    if (this.enableRefreshTokens) {
      const jti = randomBytes(16).toString('hex'); // JWT ID for tracking
      const refreshTokenPayload: Record<string, unknown> = {
        sub: userId,
        client_id: clientId,
        scope,
        resource,
        type: 'refresh',
        jti: `refresh_${jti}`,
        iss: this.issuer,
        aud: resource,
      };
      if (extras?.user_profile_id) {
        refreshTokenPayload.user_profile_id = extras.user_profile_id;
      }
      refreshToken = jwt.sign(refreshTokenPayload, this.jwtSecret, {
        algorithm: 'HS256',
        expiresIn: this.refreshTokenExpiresIn,
      });
    }

    return {
      access_token: accessToken,
      token_type: 'bearer',
      expires_in: this.lifetimeInSeconds(
        accessToken,
        this.accessTokenExpiresIn,
      ),
      ...(refreshToken ? { refresh_token: refreshToken } : {}),
    };
  }

  /**
   * Verifies a token's HS256 signature (and the expiry checks jsonwebtoken
   * performs by default) and returns its claims.
   *
   * `iss`, `aud` and `type` are NOT checked here; callers must check `type`.
   *
   * @returns The decoded claims, or null if verification fails or the payload
   *   is not a JSON object.
   */
  validateToken(token: string): JwtPayload | null {
    try {
      const decoded = jwt.verify(token, this.jwtSecret, {
        algorithms: ['HS256'],
      });
      // A bare-string payload can never satisfy JwtPayload; reject it rather
      // than casting it.
      if (typeof decoded !== 'object' || decoded === null) {
        return null;
      }
      return decoded as JwtPayload;
    } catch {
      return null;
    }
  }

  /**
   * Exchanges a valid refresh token for a fresh token pair.
   *
   * @returns The new pair, or null when refresh tokens are disabled, the token
   *   is invalid or not a refresh token, or it lacks the claims needed to
   *   re-issue (`sub`, `client_id`, `resource`).
   */
  refreshAccessToken(refreshToken: string): TokenPair | null {
    if (!this.enableRefreshTokens) {
      return null;
    }
    const payload = this.validateToken(refreshToken);
    if (!payload || payload.type !== 'refresh') {
      return null;
    }
    // Reject instead of signing for an undefined client, or letting
    // generateTokenPair throw on a missing resource.
    if (!payload.sub || !payload.client_id || !payload.resource) {
      return null;
    }
    // Refresh tokens are signed without user_data (see generateTokenPair), so
    // payload.user_data is normally undefined and the new access token omits it.
    return this.generateTokenPair(
      payload.sub,
      payload.client_id,
      payload.scope,
      payload.resource,
      {
        user_profile_id: payload.user_profile_id,
        user_data: payload.user_data,
      },
    );
  }

  /**
   * Signs a 24-hour user token with audience 'mcp-client'.
   *
   * Uses the same issuer as access/refresh tokens (from module options)
   * rather than reading process.env directly.
   */
  generateUserToken(userId: string, userData: unknown): string {
    const jti = randomBytes(16).toString('hex');
    const payload = {
      sub: userId,
      type: 'user',
      user_data: userData,
      jti: `user_${jti}`,
      iss: this.issuer,
      aud: 'mcp-client',
    };
    return jwt.sign(payload, this.jwtSecret, {
      algorithm: 'HS256',
      expiresIn: '24h',
    });
  }

  /**
   * Lifetime in seconds of a token this service just signed.
   *
   * Reads `exp - iat` from the token, so the advertised value always matches
   * the real expiry, whatever timespan format jsonwebtoken accepted. Falls
   * back to parsing the configured duration if either claim is absent.
   */
  private lifetimeInSeconds(token: string, configured: string): number {
    const decoded = jwt.decode(token);
    if (
      decoded !== null &&
      typeof decoded === 'object' &&
      typeof decoded.exp === 'number' &&
      typeof decoded.iat === 'number'
    ) {
      return decoded.exp - decoded.iat;
    }
    return this.parseDurationToSeconds(configured);
  }

  /**
   * Converts a simple duration such as '15m' or '7d' to seconds.
   *
   * @throws Error when the string is not `<digits><s|m|h|d>`.
   */
  private parseDurationToSeconds(duration: string): number {
    const match = duration.match(/^(\d+)([smhd])$/);
    if (!match) {
      throw new Error(`Invalid duration format: ${duration}`);
    }
    const value = parseInt(match[1], 10);
    const unit = match[2];
    switch (unit) {
      case 's':
        return value;
      case 'm':
        return value * 60;
      case 'h':
        return value * 60 * 60;
      case 'd':
        return value * 60 * 60 * 24;
      default:
        throw new Error(`Unsupported duration unit: ${unit}`);
    }
  }
}