import { Request, Response, NextFunction } from 'express';
import { RateLimiterRedis, RateLimiterMemory, RateLimiterRes } from 'rate-limiter-flexible';
import { createClient } from 'redis';
import { rateLimitConfig, redisConfig } from '../config';
/**
 * Express `Request` carrying the identifiers that rate-limit keys are built from.
 */
export interface RateLimitRequest extends Request {
  /** Caller IP as resolved by `extractClientIp`; used when `userId` is absent. */
  clientIp?: string;
  /**
   * Authenticated user id; takes precedence over the IP when building limiter keys.
   * NOTE(review): presumably populated by auth middleware elsewhere — confirm.
   */
  userId?: string;
}
/**
 * Owns the general API limiter and the stricter tool-call limiter, and exposes
 * them as Express middleware.
 *
 * In-memory limiters are installed synchronously in the constructor, so the
 * middleware works from the very first request. `initializeRateLimiters` then
 * tries Redis in the background. On success it swaps in Redis-backed limiters
 * that fall back to the in-memory ones (`insuranceLimiter`) if Redis errors later.
 */
export class RateLimitManager {
  /** Tool calls allowed per window; also reported in the tool rate-limit headers. */
  private static readonly TOOL_CALL_POINTS = 50;

  private rateLimiter: RateLimiterRedis | RateLimiterMemory;
  private toolCallLimiter: RateLimiterRedis | RateLimiterMemory;

  constructor() {
    // Assign working limiters before returning. Requests that arrive while Redis
    // is still connecting (or if it never connects) are then still limited,
    // instead of calling `consume` on an undefined field.
    this.rateLimiter = new RateLimiterMemory(RateLimitManager.apiLimiterOptions());
    this.toolCallLimiter = new RateLimiterMemory(RateLimitManager.toolLimiterOptions());
    // Fire-and-forget: initializeRateLimiters catches and logs its own failures.
    void this.initializeRateLimiters();
  }

  /** Store-independent options for the general API limiter. */
  private static apiLimiterOptions() {
    return {
      keyPrefix: 'api_limit',
      points: rateLimitConfig.maxRequests,
      duration: rateLimitConfig.windowMs / 1000, // Convert ms to seconds
      blockDuration: 60, // Block for 1 minute after limit exceeded
    };
  }

  /** Store-independent options for the (more restrictive) tool-call limiter. */
  private static toolLimiterOptions() {
    return {
      keyPrefix: 'tool_limit',
      points: RateLimitManager.TOOL_CALL_POINTS,
      duration: rateLimitConfig.windowMs / 1000,
      blockDuration: 300, // Block for 5 minutes
    };
  }

  /**
   * Attempts to move both limiters onto Redis. On any failure, the in-memory
   * limiters installed by the constructor stay in place.
   */
  private async initializeRateLimiters(): Promise<void> {
    const memoryApiLimiter = this.rateLimiter;
    const memoryToolLimiter = this.toolCallLimiter;
    try {
      const redisClient = createClient({ url: redisConfig.url });
      // Node throws an EventEmitter 'error' event that has no listener as an
      // uncaught exception. Register one so Redis connection problems are
      // logged rather than crashing the process.
      redisClient.on('error', (err: unknown) => {
        console.error('Redis client error (rate limiting):', err);
      });
      await redisClient.connect();

      // NOTE(review): newer rate-limiter-flexible releases document a
      // `useRedisPackage: true` option for node-redis v4+ clients. Confirm
      // whether the installed version requires it.
      const redisApiLimiter = new RateLimiterRedis({
        ...RateLimitManager.apiLimiterOptions(),
        storeClient: redisClient,
        // If Redis fails at runtime, keep limiting in memory instead of rejecting.
        insuranceLimiter: memoryApiLimiter,
      });
      const redisToolLimiter = new RateLimiterRedis({
        ...RateLimitManager.toolLimiterOptions(),
        storeClient: redisClient,
        insuranceLimiter: memoryToolLimiter,
      });
      // Swap both only after both were constructed, so we never end up half-migrated.
      this.rateLimiter = redisApiLimiter;
      this.toolCallLimiter = redisToolLimiter;
      console.log('Rate limiting initialized with Redis');
    } catch (error) {
      console.warn('Redis unavailable for rate limiting, falling back to memory:', error);
      console.log('Rate limiting initialized with memory store');
    }
  }

  /**
   * General API rate-limiting middleware. Consumes one point per request, keyed
   * by `getRateLimitKey(req)`, and sets `X-RateLimit-*` headers. Once the limit
   * is exhausted it responds 429 with `Retry-After`.
   *
   * NOTE(review): `rateLimitConfig.skipSuccessfulRequests` is not consulted here.
   * The point is consumed before the route handler runs.
   */
  apiRateLimit = async (req: RateLimitRequest, res: Response, next: NextFunction): Promise<void> => {
    await this.enforceLimit(this.rateLimiter, this.getRateLimitKey(req), res, next, {
      limit: rateLimitConfig.maxRequests,
      headerPrefix: 'X-RateLimit',
      error: 'Too Many Requests',
      message: (secs) => `Rate limit exceeded. Try again in ${secs} seconds.`,
    });
  };

  /**
   * Tool-call rate-limiting middleware. Behaves like `apiRateLimit`, but uses the
   * tool limiter, `tool:`-prefixed keys and `X-Tool-RateLimit-*` headers.
   */
  toolCallRateLimit = async (req: RateLimitRequest, res: Response, next: NextFunction): Promise<void> => {
    await this.enforceLimit(this.toolCallLimiter, this.getRateLimitKey(req, 'tool'), res, next, {
      limit: RateLimitManager.TOOL_CALL_POINTS,
      headerPrefix: 'X-Tool-RateLimit',
      error: 'Tool Rate Limit Exceeded',
      message: (secs) => `Too many tool calls. Try again in ${secs} seconds.`,
    });
  };

  /**
   * Consumes one point from `limiter` for `key`.
   *
   * - Success: sets `<headerPrefix>-Limit/-Remaining/-Reset` and calls `next()`.
   * - Limit exceeded (rejection is a `RateLimiterRes`): sets the same headers
   *   plus `Retry-After`, then sends a 429 JSON body.
   * - Any other rejection (e.g. a store `Error`): forwarded via `next(err)`.
   */
  private async enforceLimit(
    limiter: RateLimiterRedis | RateLimiterMemory,
    key: string,
    res: Response,
    next: NextFunction,
    policy: {
      limit: number;
      headerPrefix: string;
      error: string;
      message: (secs: number) => string;
    },
  ): Promise<void> {
    let outcome: RateLimiterRes;
    try {
      outcome = await limiter.consume(key);
    } catch (rejection: unknown) {
      if (!(rejection instanceof RateLimiterRes)) {
        // Not an exceeded limit: don't fabricate a 429 from a missing msBeforeNext.
        next(rejection);
        return;
      }
      // Round up so clients are never told to retry before the window/block ends.
      const secs = Math.max(1, Math.ceil(rejection.msBeforeNext / 1000));
      res.set({
        [`${policy.headerPrefix}-Limit`]: policy.limit.toString(),
        [`${policy.headerPrefix}-Remaining`]: '0',
        [`${policy.headerPrefix}-Reset`]: new Date(Date.now() + rejection.msBeforeNext).toISOString(),
        'Retry-After': secs.toString(),
      });
      res.status(429).json({
        error: policy.error,
        message: policy.message(secs),
        retryAfter: secs,
      });
      return;
    }
    res.set({
      [`${policy.headerPrefix}-Limit`]: policy.limit.toString(),
      [`${policy.headerPrefix}-Remaining`]: outcome.remainingPoints.toString(),
      [`${policy.headerPrefix}-Reset`]: new Date(Date.now() + outcome.msBeforeNext).toISOString(),
    });
    next();
  }

  /**
   * Builds the limiter key `<prefix>:<identifier>`. The identifier is the first
   * available of: the authenticated user id, `req.clientIp`, Express's `req.ip`.
   */
  private getRateLimitKey(req: RateLimitRequest, prefix: string = 'api'): string {
    const identifier = req.userId || req.clientIp || req.ip || 'unknown';
    return `${prefix}:${identifier}`;
  }

  /**
   * Reports the general API limiter state for a user without consuming a point.
   * Fails open: on a lookup error it logs and reports "not limited".
   *
   * @param userId - the id used as the key identifier (matches `getRateLimitKey`).
   * @returns `limited` is true once no points remain in the current window.
   */
  async checkRateLimit(userId: string): Promise<{
    limited: boolean;
    remainingHits: number;
    msBeforeNext: number;
  }> {
    const notLimited = {
      limited: false,
      remainingHits: rateLimitConfig.maxRequests,
      msBeforeNext: 0,
    };
    try {
      const state = await this.rateLimiter.get(`api:${userId}`);
      if (!state) {
        return notLimited;
      }
      return {
        limited: state.remainingPoints <= 0,
        remainingHits: Math.max(0, state.remainingPoints),
        msBeforeNext: Math.max(0, state.msBeforeNext),
      };
    } catch (error) {
      console.error('Rate limit check error:', error);
      return notLimited;
    }
  }

  /**
   * Admin helper: clears a user's counters in both limiters. Errors are logged,
   * not thrown.
   */
  async resetRateLimit(userId: string): Promise<void> {
    try {
      // Issue both deletes up front so a failure in one doesn't skip the other.
      await Promise.all([
        this.rateLimiter.delete(`api:${userId}`),
        this.toolCallLimiter.delete(`tool:${userId}`),
      ]);
    } catch (error) {
      console.error('Rate limit reset error:', error);
    }
  }
}
/**
 * Returns the first comma-separated entry of a request header value, trimmed.
 * Returns `undefined` when the header is absent or that entry is empty. The
 * Node header type allows `string[]`; in that case the first element is used.
 */
function firstHeaderToken(raw: string | string[] | undefined): string | undefined {
  const value = Array.isArray(raw) ? raw[0] : raw;
  const token = value?.split(',')[0]?.trim();
  return token || undefined;
}

/**
 * IP extraction middleware: stores the caller's IP on `req.clientIp` for use
 * as a rate-limit key, then calls `next()`.
 *
 * Precedence:
 * 1. `cf-connecting-ip` (Cloudflare)
 * 2. `x-real-ip`
 * 3. the first entry of `x-forwarded-for`
 * 4. Express's `req.ip`
 *
 * SECURITY NOTE(review): these headers are read verbatim from the request.
 * Unless a trusted proxy overwrites them, any client can send arbitrary values
 * and obtain a fresh IP-based rate-limit bucket per request. Confirm the
 * deployment strips/overwrites them, or prefer `req.ip` with Express's
 * `trust proxy` setting.
 */
export const extractClientIp = (req: RateLimitRequest, res: Response, next: NextFunction): void => {
  req.clientIp =
    firstHeaderToken(req.headers['cf-connecting-ip']) ??
    firstHeaderToken(req.headers['x-real-ip']) ??
    firstHeaderToken(req.headers['x-forwarded-for']) ??
    req.ip;
  next();
};
/**
 * Express middleware that marks requests with a successful response.
 *
 * When `rateLimitConfig.skipSuccessfulRequests` is enabled, it sets
 * `skipRateLimit = true` on the request once its response finishes with a 2xx
 * status. It always calls `next()`.
 *
 * NOTE(review): the flag is only set after the response has finished. The rate
 * limiter in this file consumes its point before the handler runs and never
 * reads `skipRateLimit`, so on its own this exempts nothing. Confirm a consumer
 * elsewhere uses the flag (e.g. to refund the point).
 */
export const skipSuccessfulRequests = (req: Request, res: Response, next: NextFunction): void => {
  if (rateLimitConfig.skipSuccessfulRequests) {
    // Observe 'finish' instead of wrapping res.end. This avoids monkey-patching
    // and forwarding untyped `...args` through end()'s overloaded signature.
    res.once('finish', () => {
      if (res.statusCode >= 200 && res.statusCode < 300) {
        (req as Request & { skipRateLimit?: boolean }).skipRateLimit = true;
      }
    });
  }
  next();
};
/**
 * Shared `RateLimitManager` instance for the application.
 *
 * It is instantiated when this module is first imported. Construction calls
 * `initializeRateLimiters`, which creates a Redis client and attempts to
 * connect, so importing this module has that side effect.
 */
export const rateLimitManager = new RateLimitManager();