# test_rate_limiter.py (12 kB)
"""Unit tests for rate limiter."""
import asyncio
import pytest
import time
from src.utils.rate_limiter import RateLimiter
class TestRateLimiter:
    """Test suite for the rate limiter.

    Behavior pinned down here: each tool name gets its own lazily created
    entry in ``limiter.buckets`` holding a ``"minute"`` and an ``"hour"``
    bucket (each with ``"tokens"`` and ``"last_update"``), and every allowed
    request consumes one token from both buckets.

    Several tests take a ``rate_limiter`` fixture that is not defined in this
    module (presumably in ``conftest.py`` -- confirm); they assume it allows
    at least 10 requests per minute and per hour.

    Refill tests simulate elapsed time by backdating ``last_update`` with
    ``time.time()``, so they assume the limiter timestamps buckets with the
    wall clock.
    """

    def test_initialization(self):
        """Test rate limiter initialization."""
        limiter = RateLimiter(requests_per_minute=60, requests_per_hour=1000)
        assert limiter.requests_per_minute == 60
        assert limiter.requests_per_hour == 1000
        assert limiter.buckets == {}

    @pytest.mark.asyncio
    async def test_first_request_allowed(self, rate_limiter):
        """Test that first request is always allowed."""
        result = await rate_limiter.check_limit("test_tool")
        assert result is True
        assert "test_tool" in rate_limiter.buckets

    @pytest.mark.asyncio
    async def test_bucket_initialization(self, rate_limiter):
        """Test that buckets are initialized correctly."""
        await rate_limiter.check_limit("test_tool")
        bucket = rate_limiter.buckets["test_tool"]
        assert "minute" in bucket
        assert "hour" in bucket
        assert "tokens" in bucket["minute"]
        assert "last_update" in bucket["minute"]
        assert "tokens" in bucket["hour"]
        assert "last_update" in bucket["hour"]

    @pytest.mark.asyncio
    async def test_multiple_requests_within_limit(self, rate_limiter):
        """Test multiple requests within rate limit."""
        tool_name = "test_tool"
        # Make 10 requests
        for _ in range(10):
            result = await rate_limiter.check_limit(tool_name)
            assert result is True

    @pytest.mark.asyncio
    async def test_minute_rate_limit_exceeded(self):
        """Test that minute rate limit is enforced."""
        limiter = RateLimiter(requests_per_minute=5, requests_per_hour=1000)
        tool_name = "test_tool"
        # Make requests up to the limit
        for _ in range(5):
            result = await limiter.check_limit(tool_name)
            assert result is True
        # Next request should be denied
        result = await limiter.check_limit(tool_name)
        assert result is False

    @pytest.mark.asyncio
    async def test_hour_rate_limit_exceeded(self):
        """Test that hour rate limit is enforced."""
        limiter = RateLimiter(requests_per_minute=1000, requests_per_hour=10)
        tool_name = "test_tool"
        # Make requests up to the limit
        for _ in range(10):
            result = await limiter.check_limit(tool_name)
            assert result is True
        # Next request should be denied
        result = await limiter.check_limit(tool_name)
        assert result is False

    @pytest.mark.asyncio
    async def test_token_refill_over_time(self):
        """Test that a depleted minute bucket refills as time passes."""
        limiter = RateLimiter(requests_per_minute=5, requests_per_hour=1000)
        tool_name = "test_tool"
        # Exhaust the minute bucket first: with spare tokens left over, the
        # request below would be admitted whether or not any refill happened.
        for _ in range(5):
            result = await limiter.check_limit(tool_name)
            assert result is True
        result = await limiter.check_limit(tool_name)
        assert result is False
        # Simulate a full minute passing (enough to refill all 5 tokens).
        bucket = limiter.buckets[tool_name]["minute"]
        bucket["last_update"] = time.time() - 60
        # Only the refill can make this request succeed.
        result = await limiter.check_limit(tool_name)
        assert result is True

    @pytest.mark.asyncio
    async def test_different_tools_independent_limits(self, rate_limiter):
        """Test that different tools have independent rate limits."""
        tool1 = "tool1"
        tool2 = "tool2"
        # Make requests for tool1
        for _ in range(5):
            await rate_limiter.check_limit(tool1)
        # tool2 should still have capacity despite tool1's usage
        result = await rate_limiter.check_limit(tool2)
        assert result is True
        # Verify separate buckets (distinct objects, not merely unequal contents)
        assert tool1 in rate_limiter.buckets
        assert tool2 in rate_limiter.buckets
        assert rate_limiter.buckets[tool1] is not rate_limiter.buckets[tool2]
        # tool1 spent 5 tokens and tool2 only 1, so tool2 must have more left.
        tool1_tokens = rate_limiter.buckets[tool1]["minute"]["tokens"]
        tool2_tokens = rate_limiter.buckets[tool2]["minute"]["tokens"]
        assert tool2_tokens > tool1_tokens

    @pytest.mark.asyncio
    async def test_token_count_decreases(self):
        """Test that token count decreases with each request."""
        limiter = RateLimiter(requests_per_minute=60, requests_per_hour=1000)
        tool_name = "test_tool"
        # First request
        await limiter.check_limit(tool_name)
        tokens_after_first = limiter.buckets[tool_name]["minute"]["tokens"]
        # Second request
        await limiter.check_limit(tool_name)
        tokens_after_second = limiter.buckets[tool_name]["minute"]["tokens"]
        # Tokens should decrease
        assert tokens_after_second < tokens_after_first

    @pytest.mark.asyncio
    async def test_both_buckets_checked(self):
        """Test that both minute and hour buckets are checked."""
        limiter = RateLimiter(requests_per_minute=5, requests_per_hour=5)
        tool_name = "test_tool"
        # Use up all tokens
        for _ in range(5):
            result = await limiter.check_limit(tool_name)
            assert result is True
        # Both buckets should be depleted
        minute_tokens = limiter.buckets[tool_name]["minute"]["tokens"]
        hour_tokens = limiter.buckets[tool_name]["hour"]["tokens"]
        assert minute_tokens < 1.0
        assert hour_tokens < 1.0
        # Next request should fail
        result = await limiter.check_limit(tool_name)
        assert result is False

    def test_get_retry_after_no_bucket(self, rate_limiter):
        """Test retry_after for non-existent tool."""
        retry_after = rate_limiter.get_retry_after("nonexistent_tool")
        assert retry_after == 0

    @pytest.mark.asyncio
    async def test_get_retry_after_with_depleted_bucket(self):
        """Test retry_after calculation for depleted bucket."""
        limiter = RateLimiter(requests_per_minute=5, requests_per_hour=1000)
        tool_name = "test_tool"
        # Deplete minute bucket
        for _ in range(5):
            await limiter.check_limit(tool_name)
        retry_after = limiter.get_retry_after(tool_name)
        # Should return a positive number of seconds
        assert retry_after > 0
        assert retry_after <= 60  # Should be within a minute

    @pytest.mark.asyncio
    async def test_get_retry_after_with_available_tokens(self, rate_limiter):
        """Test retry_after when tokens are available."""
        tool_name = "test_tool"
        # Make one request
        await rate_limiter.check_limit(tool_name)
        retry_after = rate_limiter.get_retry_after(tool_name)
        # Should return default value when tokens available.
        # NOTE(review): 60 (rather than 0) is what this test expects while
        # tokens remain; confirm against RateLimiter.get_retry_after that this
        # default is intended.
        assert retry_after == 60

    @pytest.mark.asyncio
    async def test_concurrent_requests_same_tool(self):
        """Test concurrent requests for the same tool."""
        limiter = RateLimiter(requests_per_minute=10, requests_per_hour=100)
        tool_name = "test_tool"
        # Make concurrent requests
        results = await asyncio.gather(*[
            limiter.check_limit(tool_name) for _ in range(5)
        ])
        # All should succeed (within limit)
        assert all(results)

    @pytest.mark.asyncio
    async def test_token_cap_at_maximum(self):
        """Test that tokens don't exceed maximum capacity."""
        limiter = RateLimiter(requests_per_minute=60, requests_per_hour=1000)
        tool_name = "test_tool"
        # Make a request to initialize bucket
        await limiter.check_limit(tool_name)
        # Simulate a lot of time passing
        bucket = limiter.buckets[tool_name]["minute"]
        bucket["last_update"] = time.time() - 3600  # 1 hour ago
        # Check limit again
        await limiter.check_limit(tool_name)
        # Tokens should be capped at maximum
        tokens = limiter.buckets[tool_name]["minute"]["tokens"]
        assert tokens <= limiter.requests_per_minute

    @pytest.mark.asyncio
    async def test_fractional_token_consumption(self):
        """Test that each request consumes exactly one (float) token per bucket."""
        limiter = RateLimiter(requests_per_minute=60, requests_per_hour=1000)
        tool_name = "test_tool"
        # Make requests
        for _ in range(3):
            await limiter.check_limit(tool_name)
        minute_tokens = limiter.buckets[tool_name]["minute"]["tokens"]
        hour_tokens = limiter.buckets[tool_name]["hour"]["tokens"]
        # Should have consumed exactly 3 tokens from each bucket. The tolerance
        # only absorbs the tiny refill accrued while the test runs.
        assert minute_tokens == pytest.approx(limiter.requests_per_minute - 3, abs=0.5)
        assert hour_tokens == pytest.approx(limiter.requests_per_hour - 3, abs=0.5)

    @pytest.mark.asyncio
    async def test_rapid_successive_requests(self):
        """Test rapid successive requests."""
        limiter = RateLimiter(requests_per_minute=100, requests_per_hour=1000)
        tool_name = "test_tool"
        # Make 50 rapid requests
        results = []
        for _ in range(50):
            result = await limiter.check_limit(tool_name)
            results.append(result)
        # All should succeed (within limit)
        assert all(results)

    @pytest.mark.asyncio
    async def test_edge_case_exactly_at_limit(self):
        """Test behavior when exactly at the rate limit."""
        limiter = RateLimiter(requests_per_minute=5, requests_per_hour=100)
        tool_name = "test_tool"
        # Make exactly 5 requests
        for i in range(5):
            result = await limiter.check_limit(tool_name)
            assert result is True, f"Request {i+1} should succeed"
        # 6th request should fail
        result = await limiter.check_limit(tool_name)
        assert result is False

    @pytest.mark.asyncio
    async def test_multiple_tools_concurrent(self):
        """Test multiple tools making concurrent requests."""
        limiter = RateLimiter(requests_per_minute=10, requests_per_hour=100)
        # Make concurrent requests for different tools
        results = await asyncio.gather(*[
            limiter.check_limit(f"tool_{i}") for i in range(5)
        ])
        # All should succeed
        assert all(results)
        # Verify separate buckets created
        assert len(limiter.buckets) == 5

    @pytest.mark.asyncio
    async def test_token_refill_rate_minute(self):
        """Test that minute bucket refills at correct rate."""
        limiter = RateLimiter(requests_per_minute=60, requests_per_hour=1000)
        tool_name = "test_tool"
        # Initialize bucket
        await limiter.check_limit(tool_name)
        # Manually adjust last_update to simulate 1 second passing
        bucket = limiter.buckets[tool_name]["minute"]
        initial_tokens = bucket["tokens"]
        bucket["last_update"] = time.time() - 1.0
        # Check limit again
        await limiter.check_limit(tool_name)
        # Should have refilled approximately 1 token (60 per minute = 1 per second)
        # Account for the token consumed by the check
        final_tokens = limiter.buckets[tool_name]["minute"]["tokens"]
        refilled = (final_tokens + 1) - initial_tokens  # +1 for consumed token
        # Should be close to 1 token refilled
        assert 0.5 < refilled < 1.5

    @pytest.mark.asyncio
    async def test_zero_tokens_blocks_request(self):
        """Test that request is blocked when tokens reach zero."""
        limiter = RateLimiter(requests_per_minute=1, requests_per_hour=100)
        tool_name = "test_tool"
        # First request should succeed
        result1 = await limiter.check_limit(tool_name)
        assert result1 is True
        # Second request should fail (no tokens left)
        result2 = await limiter.check_limit(tool_name)
        assert result2 is False