rate_limiter.py•3.47 kB
"""Rate limiting implementation using token bucket algorithm."""
import time
from typing import Dict
from src.utils.logger import get_logger
# Module-level logger from the project's get_logger helper. RateLimiter.check_limit
# calls .debug/.warning on it with an event name plus keyword fields.
logger = get_logger(__name__)
class RateLimiter:
    """
    Token bucket rate limiter for per-tool rate limiting.

    Each tool gets two independent buckets:

    - a per-minute bucket holding up to ``requests_per_minute`` tokens;
    - a per-hour bucket holding up to ``requests_per_hour`` tokens.

    Each bucket refills continuously at ``capacity / window_seconds`` tokens per
    second. A request is allowed only when both buckets hold at least one token.
    When it is allowed, one token is taken from each bucket.

    Bucket timestamps come from ``time.monotonic()``. Wall-clock adjustments
    therefore cannot drain or over-fill a bucket.
    """

    _MINUTE_SECONDS = 60.0
    _HOUR_SECONDS = 3600.0

    def __init__(self, requests_per_minute: int = 60, requests_per_hour: int = 1000):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests per minute (minute bucket capacity)
            requests_per_hour: Maximum requests per hour (hour bucket capacity)
        """
        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour
        # Store buckets per tool: {tool_name: {'minute': {...}, 'hour': {...}}}
        # Each inner dict is {"tokens": float, "last_update": time.monotonic() seconds}.
        self.buckets: Dict[str, Dict[str, Dict[str, float]]] = {}

    @staticmethod
    def _refill(
        bucket: Dict[str, float], capacity: float, window_seconds: float, now: float
    ) -> None:
        """Credit ``bucket`` with the tokens accrued since its last update, capped at ``capacity``."""
        # Clamp at 0 so a timestamp in the future can never remove tokens.
        elapsed = max(0.0, now - bucket["last_update"])
        rate = capacity / window_seconds
        bucket["tokens"] = min(float(capacity), bucket["tokens"] + elapsed * rate)
        bucket["last_update"] = now

    @staticmethod
    def _seconds_until_token(
        bucket: Dict[str, float], capacity: float, window_seconds: float, now: float
    ) -> int:
        """Return whole seconds until ``bucket`` holds one token (0 if it already does)."""
        rate = capacity / window_seconds
        elapsed = max(0.0, now - bucket["last_update"])
        # Project the refill without mutating the bucket: this is a read-only query.
        tokens = min(float(capacity), bucket["tokens"] + elapsed * rate)
        if tokens >= 1.0:
            return 0
        if rate <= 0:
            # A zero-capacity bucket never refills. Report the full window as a
            # best-effort hint instead of dividing by zero.
            return int(window_seconds)
        # int() truncates, so the +1 rounds up and never tells the caller to retry early.
        return int((1.0 - tokens) / rate) + 1

    async def check_limit(self, tool_name: str) -> bool:
        """
        Check if request is allowed under rate limits, consuming a token if so.

        Args:
            tool_name: Name of the tool being called

        Returns:
            True if request is allowed, False if rate limit exceeded
        """
        # There is no await in this method. Under a single event loop, the
        # read-modify-write of the buckets therefore runs without interleaving.
        now = time.monotonic()

        # Initialize buckets for tool if not exists (both start full)
        tool_buckets = self.buckets.get(tool_name)
        if tool_buckets is None:
            tool_buckets = {
                "minute": {"tokens": float(self.requests_per_minute), "last_update": now},
                "hour": {"tokens": float(self.requests_per_hour), "last_update": now},
            }
            self.buckets[tool_name] = tool_buckets

        minute_bucket = tool_buckets["minute"]
        hour_bucket = tool_buckets["hour"]
        self._refill(minute_bucket, self.requests_per_minute, self._MINUTE_SECONDS, now)
        self._refill(hour_bucket, self.requests_per_hour, self._HOUR_SECONDS, now)

        # Consume only when BOTH windows can afford the request, so a denial
        # never burns a token from either bucket.
        if minute_bucket["tokens"] >= 1.0 and hour_bucket["tokens"] >= 1.0:
            minute_bucket["tokens"] -= 1.0
            hour_bucket["tokens"] -= 1.0
            logger.debug(
                "rate_limit_check_passed",
                tool_name=tool_name,
                minute_tokens=minute_bucket["tokens"],
                hour_tokens=hour_bucket["tokens"],
            )
            return True

        logger.warning(
            "rate_limit_exceeded",
            tool_name=tool_name,
            minute_tokens=minute_bucket["tokens"],
            hour_tokens=hour_bucket["tokens"],
        )
        return False

    def get_retry_after(self, tool_name: str) -> int:
        """
        Get seconds to wait before retrying.

        Both the minute and the hour bucket are considered, and each is
        projected forward to the current time before computing its wait.

        Args:
            tool_name: Name of the tool

        Returns:
            Whole seconds until both buckets hold at least one token, rounded up.
            Returns 0 for an unknown tool or when a request would currently be
            allowed.
        """
        tool_buckets = self.buckets.get(tool_name)
        if tool_buckets is None:
            return 0
        now = time.monotonic()
        return max(
            self._seconds_until_token(
                tool_buckets["minute"], self.requests_per_minute, self._MINUTE_SECONDS, now
            ),
            self._seconds_until_token(
                tool_buckets["hour"], self.requests_per_hour, self._HOUR_SECONDS, now
            ),
        )