"""Rate limiting middleware."""
import asyncio
import time
from collections import defaultdict, deque
from typing import Any, Callable, Dict, List
import mcp.types as types
from .base import BaseMiddleware
from ..utils.logging import get_logger
# Module-level logger named after this module. The code visible in this file
# logs through RateLimitMiddleware's own ``self._logger`` rather than this one.
logger = get_logger(__name__)
class RateLimiter:
    """Token bucket rate limiter.

    The bucket starts full with ``max_tokens`` tokens and regains
    ``refill_rate`` tokens per second, never exceeding ``max_tokens``. Every
    granted :meth:`acquire` spends one token. Timestamps of granted requests
    from the last ``window_size`` seconds are kept for statistics only; they
    do not influence whether a request is granted.

    All timestamps are taken from ``time.monotonic()`` so that wall-clock
    adjustments can neither drain the bucket (clock stepped back) nor refill
    it instantly (clock stepped forward).
    """

    def __init__(self, max_tokens: int, refill_rate: float, window_size: int = 60) -> None:
        """Initialize rate limiter.

        Args:
            max_tokens: Maximum number of tokens in bucket (burst capacity)
            refill_rate: Tokens per second refill rate
            window_size: Time window in seconds for request tracking
        """
        self.max_tokens = max_tokens
        self.refill_rate = refill_rate
        self.window_size = window_size
        # Starts as an int but becomes a float once fractional refills occur.
        self.tokens: float = max_tokens
        # Monotonic timestamp of the most recent refill.
        self.last_refill = time.monotonic()
        # Monotonic timestamps of granted requests, oldest first.
        self.requests: deque = deque()

    def _refill(self, now: float) -> None:
        """Credit tokens accrued since the last refill, capped at ``max_tokens``."""
        elapsed = now - self.last_refill
        self.tokens = min(self.max_tokens, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now

    def _prune(self, now: float) -> None:
        """Drop recorded request timestamps older than ``window_size`` seconds."""
        cutoff = now - self.window_size
        while self.requests and self.requests[0] < cutoff:
            self.requests.popleft()

    async def acquire(self) -> bool:
        """Acquire a token.

        Never blocks: the request is either granted immediately or refused.

        Returns:
            True if token acquired, False if rate limited
        """
        now = time.monotonic()
        self._refill(now)
        self._prune(now)

        if self.tokens < 1:
            return False

        self.tokens -= 1
        self.requests.append(now)
        return True

    def get_stats(self) -> Dict[str, Any]:
        """Get rate limiter statistics.

        Both the token count and the request window are brought up to date
        first, so the snapshot reflects the current moment rather than the
        state left behind by the last ``acquire`` call.

        Returns:
            Statistics dictionary with keys ``current_tokens``, ``max_tokens``,
            ``refill_rate``, ``requests_in_window`` and ``window_size``
        """
        now = time.monotonic()
        self._refill(now)
        self._prune(now)
        return {
            "current_tokens": self.tokens,
            "max_tokens": self.max_tokens,
            "refill_rate": self.refill_rate,
            "requests_in_window": len(self.requests),
            "window_size": self.window_size,
        }
class RateLimitExceededError(Exception):
    """Raised when a request is rejected because no rate-limit token is available.

    Derives directly from ``Exception`` so handlers written against the
    generic exception type keep catching it.
    """


class RateLimitMiddleware(BaseMiddleware):
    """Middleware for rate limiting requests.

    Tool calls, resource reads and prompt gets are throttled independently,
    each through its own ``RateLimiter`` token bucket. The ``*_limit`` values
    are passed as the bucket capacity and ``refill_rate`` as the shared refill
    speed. With ``per_client=True`` a separate bucket is created lazily for
    each client id; otherwise one bucket per category is shared by everyone.
    """

    # Request categories that own a limiter (or a per-client limiter mapping).
    _LIMITER_TYPES = ("tool", "resource", "prompt")

    def __init__(
        self,
        tool_limit: int = 100,
        resource_limit: int = 200,
        prompt_limit: int = 50,
        refill_rate: float = 1.0,
        window_size: int = 60,
        per_client: bool = False,
    ) -> None:
        """Initialize rate limiting middleware.

        Args:
            tool_limit: Bucket capacity for tool calls
            resource_limit: Bucket capacity for resource reads
            prompt_limit: Bucket capacity for prompt gets
            refill_rate: Tokens per second refill rate, shared by all buckets
            window_size: Time window in seconds, forwarded to each limiter
            per_client: Whether to apply limits per client (requires client identification)
        """
        super().__init__("rate_limit")
        self.tool_limit = tool_limit
        self.resource_limit = resource_limit
        self.prompt_limit = prompt_limit
        self.refill_rate = refill_rate
        self.window_size = window_size
        self.per_client = per_client
        self._logger = get_logger(f"{__name__}.RateLimitMiddleware")

        if per_client:
            # One mapping per category; a client's bucket is created on first use.
            self.tool_limiters: Dict[str, RateLimiter] = defaultdict(
                lambda: RateLimiter(tool_limit, refill_rate, window_size)
            )
            self.resource_limiters: Dict[str, RateLimiter] = defaultdict(
                lambda: RateLimiter(resource_limit, refill_rate, window_size)
            )
            self.prompt_limiters: Dict[str, RateLimiter] = defaultdict(
                lambda: RateLimiter(prompt_limit, refill_rate, window_size)
            )
        else:
            self.tool_limiter = RateLimiter(tool_limit, refill_rate, window_size)
            self.resource_limiter = RateLimiter(resource_limit, refill_rate, window_size)
            self.prompt_limiter = RateLimiter(prompt_limit, refill_rate, window_size)

    def _get_client_id(self) -> str:
        """Get client identifier.

        Returns:
            Client identifier (for now, always 'default', so per-client mode
            currently behaves like a single shared bucket per category)
        """
        # In a real implementation, this would extract client info
        # from request context, headers, authentication, etc.
        return "default"

    def _get_limiter(self, limiter_type: str) -> RateLimiter:
        """Get appropriate rate limiter.

        Args:
            limiter_type: Type of limiter ('tool', 'resource', 'prompt')

        Returns:
            Rate limiter instance; in per-client mode, the one belonging to
            the current client id

        Raises:
            ValueError: If ``limiter_type`` is not one of the known types
        """
        if limiter_type not in self._LIMITER_TYPES:
            raise ValueError(f"Unknown limiter type: {limiter_type}")
        if self.per_client:
            # Attributes are named "<type>_limiters" in per-client mode.
            return getattr(self, f"{limiter_type}_limiters")[self._get_client_id()]
        # Attributes are named "<type>_limiter" in shared mode.
        return getattr(self, f"{limiter_type}_limiter")

    def get_stats(self) -> Dict[str, Any]:
        """Get rate limiting statistics.

        Returns:
            Statistics dictionary. In per-client mode it holds the number of
            clients seen per category; otherwise it holds each shared
            limiter's own statistics.
        """
        if self.per_client:
            return {
                "per_client": True,
                "tool_clients": len(self.tool_limiters),
                "resource_clients": len(self.resource_limiters),
                "prompt_clients": len(self.prompt_limiters),
            }
        return {
            "per_client": False,
            "tool_limiter": self.tool_limiter.get_stats(),
            "resource_limiter": self.resource_limiter.get_stats(),
            "prompt_limiter": self.prompt_limiter.get_stats(),
        }

    async def _enforce_limit(self, limiter_type: str, operation: str, target: Any, limit: int) -> None:
        """Spend one token for a request or reject it.

        Args:
            limiter_type: Limiter category ('tool', 'resource', 'prompt')
            operation: Singular operation label used in messages (e.g. 'tool call')
            target: Tool name, resource URI or prompt name, for the log line
            limit: Configured limit quoted in the error message

        Raises:
            RateLimitExceededError: If the limiter refuses the request
        """
        limiter = self._get_limiter(limiter_type)
        if await limiter.acquire():
            return
        self._logger.warning(f"Rate limit exceeded for {operation}: {target}")
        # NOTE(review): the message quotes the limit "per window", while the
        # limiter is given it as its bucket capacity - confirm the wording.
        raise RateLimitExceededError(
            f"Rate limit exceeded for {operation}s. Max: {limit} per {self.window_size}s"
        )

    async def process_tool_call(
        self, name: str, arguments: Dict[str, Any], next_handler: Callable[[str, Dict[str, Any]], Any]
    ) -> List[types.ContentBlock]:
        """Process tool call with rate limiting.

        Raises:
            RateLimitExceededError: If the tool-call limiter has no token left
        """
        await self._enforce_limit("tool", "tool call", name, self.tool_limit)
        return await next_handler(name, arguments)

    async def process_resource_read(self, uri: Any, next_handler: Callable[[Any], Any]) -> str:
        """Process resource read with rate limiting.

        Raises:
            RateLimitExceededError: If the resource-read limiter has no token left
        """
        await self._enforce_limit("resource", "resource read", uri, self.resource_limit)
        return await next_handler(uri)

    async def process_prompt_get(
        self, name: str, arguments: Dict[str, str] | None, next_handler: Callable[[str, Dict[str, str] | None], Any]
    ) -> types.GetPromptResult:
        """Process prompt get with rate limiting.

        Raises:
            RateLimitExceededError: If the prompt-get limiter has no token left
        """
        await self._enforce_limit("prompt", "prompt get", name, self.prompt_limit)
        return await next_handler(name, arguments)