"""Persistent, SQLite-backed rate limiting for Gemini API calls."""

import asyncio
import logging
import sqlite3
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


@dataclass
class RateLimitConfig:
    """Configuration for rate limiting.

    Official Gemini API free-tier limits (Gemini 2.0 Flash-Lite):
    - RPM: 30 requests per minute
    - TPM: 1,000,000 tokens per minute
    - RPD: 200 requests per day

    Paid-tier limits:
    - RPM: 60+ requests per minute
    - TPM: 1,000,000+ tokens per minute
    - RPD: 1,000+ requests per day
    """

    rpm_limit: int = 30  # Requests per minute (free tier)
    tpm_limit: int = 1_000_000  # Tokens per minute (free tier)
    rpd_limit: int = 200  # Requests per day (free tier)
    token_estimate_per_request: int = 1000  # Fallback estimate when the real count is unknown
    safety_margin: float = 0.8  # Use 80% of the published limits to be safe
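
# The defaults above match the free tier. For a paid tier, individual fields
# can be overridden (illustrative values only; check the current Gemini
# rate-limit documentation before relying on them):
#
#     paid_config = RateLimitConfig(rpm_limit=60, rpd_limit=1_000)
#     limiter = PersistentRateLimiter(config=paid_config, db_path="db/paid_tier.db")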


class PersistentRateLimiter:
    """Persistent rate limiter that tracks API usage in SQLite across restarts."""

    def __init__(self, config: Optional[RateLimitConfig] = None, db_path: str = "db/rate_limit_tracker.db"):
        self.config = config or RateLimitConfig()
        self.db_path = db_path
        # Apply the safety margin so we stay below the hard limits.
        self.safe_rpm = int(self.config.rpm_limit * self.config.safety_margin)
        self.safe_tpm = int(self.config.tpm_limit * self.config.safety_margin)
        self.safe_rpd = int(self.config.rpd_limit * self.config.safety_margin)
        # Initialize the backing database.
        self._init_database()
        logger.info(
            f"Persistent rate limiter initialized with safe limits: "
            f"RPM={self.safe_rpm}, TPM={self.safe_tpm}, RPD={self.safe_rpd}"
        )

    def _init_database(self):
        """Initialize the SQLite database used to track API usage."""
        try:
            # Ensure the parent directory exists; sqlite3 will not create it.
            Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                # One row per API request; `timestamp` is a Unix epoch float.
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS api_requests (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        timestamp REAL NOT NULL,
                        tokens_used INTEGER DEFAULT 0,
                        endpoint TEXT DEFAULT 'gemini',
                        created_at DATETIME DEFAULT CURRENT_TIMESTAMP
                    )
                """)
                # Indexes for the per-minute and per-day window queries.
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_timestamp ON api_requests(timestamp)")
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_created_at ON api_requests(created_at)")
                conn.commit()
            logger.info("Rate limit database initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize rate limit database: {e}")
            raise

    def _cleanup_old_entries(self):
        """Remove entries older than 24 hours from the database."""
        try:
            cutoff_time = time.time() - 86400  # 24 hours ago
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("DELETE FROM api_requests WHERE timestamp < ?", (cutoff_time,))
                deleted_count = cursor.rowcount
                if deleted_count > 0:
                    logger.debug(f"Cleaned up {deleted_count} old rate limit entries")
                conn.commit()
        except Exception as e:
            logger.error(f"Failed to cleanup old entries: {e}")

    def _get_recent_requests(self, seconds: int) -> List[Tuple[float, int]]:
        """Get (timestamp, tokens_used) rows from the last `seconds` seconds, newest first."""
        try:
            cutoff_time = time.time() - seconds
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    SELECT timestamp, tokens_used FROM api_requests
                    WHERE timestamp >= ?
                    ORDER BY timestamp DESC
                """, (cutoff_time,))
                return cursor.fetchall()
        except Exception as e:
            logger.error(f"Failed to get recent requests: {e}")
            return []

    def _get_daily_requests(self) -> List[Tuple[float, int]]:
        """Get (timestamp, tokens_used) rows from the last 24 hours, newest first."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                # `created_at` and datetime('now') are both UTC in SQLite.
                cursor.execute("""
                    SELECT timestamp, tokens_used FROM api_requests
                    WHERE created_at >= datetime('now', '-1 day')
                    ORDER BY timestamp DESC
                """)
                return cursor.fetchall()
        except Exception as e:
            logger.error(f"Failed to get daily requests: {e}")
            return []

    def _estimate_tokens(self, content: str) -> int:
        """Estimate the token count for content (rough approximation)."""
        # Rough heuristic: ~4 characters per token for English text.
        return len(content) // 4
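
    # For example, _estimate_tokens("x" * 400) returns 400 // 4 = 100. Real
    # tokenizers vary by model and language, so treat this only as a
    # budgeting heuristic, not an exact count.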

    def can_make_request(self, estimated_tokens: Optional[int] = None) -> Tuple[bool, str]:
        """Check whether a request can be made without exceeding the safe limits.

        Args:
            estimated_tokens: Estimated token count for the request.

        Returns:
            A (can_proceed, reason) tuple.
        """
        self._cleanup_old_entries()
        # Requests in the last minute (newest first).
        recent_requests = self._get_recent_requests(60)
        current_rpm = len(recent_requests)
        # Requests in the last 24 hours.
        daily_requests = self._get_daily_requests()
        current_rpd = len(daily_requests)
        # Tokens consumed in the last minute.
        current_tpm = sum(tokens for _, tokens in recent_requests)
        # Check the RPM limit.
        if current_rpm >= self.safe_rpm:
            if recent_requests:
                oldest_request = recent_requests[-1][0]  # Oldest timestamp in the window
                wait_time = max(60 - (time.time() - oldest_request), 0.0)
                return False, f"RPM limit exceeded. Wait {wait_time:.1f} seconds"
            return False, "RPM limit exceeded"
        # Check the RPD limit.
        if current_rpd >= self.safe_rpd:
            return False, "Daily request limit exceeded"
        # Check the TPM limit, falling back to the configured per-request estimate.
        tokens = estimated_tokens if estimated_tokens is not None else self.config.token_estimate_per_request
        if current_tpm + tokens > self.safe_tpm:
            return False, f"TPM limit would be exceeded. Current: {current_tpm}, Request: {tokens}"
        return True, "OK"
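
    # A caller can pre-check before issuing a request (sketch; `limiter` is
    # any PersistentRateLimiter instance):
    #
    #     ok, reason = limiter.can_make_request(estimated_tokens=500)
    #     if not ok:
    #         logger.warning(f"Backing off: {reason}")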

    async def wait_if_needed(self, estimated_tokens: Optional[int] = None) -> str:
        """Wait as long as necessary to respect the rate limits.

        Args:
            estimated_tokens: Estimated token count for the request.

        Returns:
            "OK" once the request may proceed.
        """
        while True:
            can_proceed, reason = self.can_make_request(estimated_tokens)
            if can_proceed:
                break
            # Extract the wait time if the reason is a time-based limit.
            if "Wait" in reason:
                try:
                    wait_time = float(reason.split("Wait ")[1].split(" ")[0])
                    logger.info(f"Rate limit hit: {reason}. Waiting {wait_time:.1f} seconds...")
                    await asyncio.sleep(wait_time + 0.1)  # Small buffer past the window edge
                except (IndexError, ValueError):
                    logger.warning(f"Rate limit hit: {reason}")
                    await asyncio.sleep(2)  # Default wait
            else:
                logger.warning(f"Rate limit hit: {reason}")
                await asyncio.sleep(5)  # Wait longer for non-time-based limits
        return reason

    def record_request(self, actual_tokens: Optional[int] = None):
        """Record a completed request in the database.

        Args:
            actual_tokens: Actual token count used (if known).
        """
        try:
            current_time = time.time()
            # Fall back to the configured estimate when the true count is unknown.
            tokens_used = actual_tokens if actual_tokens is not None else self.config.token_estimate_per_request
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "INSERT INTO api_requests (timestamp, tokens_used, endpoint) VALUES (?, ?, ?)",
                    (current_time, tokens_used, "gemini"),
                )
                conn.commit()
            logger.debug(f"Recorded request: tokens={tokens_used}, timestamp={current_time}")
        except Exception as e:
            logger.error(f"Failed to record request: {e}")

    def get_status(self) -> Dict[str, Any]:
        """Get the current rate limiting status."""
        self._cleanup_old_entries()
        recent_requests = self._get_recent_requests(60)
        daily_requests = self._get_daily_requests()
        current_rpm = len(recent_requests)
        current_rpd = len(daily_requests)
        current_tpm = sum(tokens for _, tokens in recent_requests)
        return {
            "current_rpm": current_rpm,
            "current_rpd": current_rpd,
            "current_tpm": current_tpm,
            "safe_rpm": self.safe_rpm,
            "safe_rpd": self.safe_rpd,
            "safe_tpm": self.safe_tpm,
            "rpm_available": self.safe_rpm - current_rpm,
            "rpd_available": self.safe_rpd - current_rpd,
            "tpm_available": self.safe_tpm - current_tpm,
            "database_path": self.db_path,
        }
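
    # With the default config, a status snapshot might look like
    # (illustrative values only):
    #
    #     {"current_rpm": 3, "current_rpd": 41, "current_tpm": 2900,
    #      "safe_rpm": 24, "safe_rpd": 160, "safe_tpm": 800000,
    #      "rpm_available": 21, "rpd_available": 119, "tpm_available": 797100,
    #      "database_path": "db/rate_limit_tracker.db"}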

    def get_usage_history(self, hours: int = 24) -> Dict[str, Any]:
        """Get usage history for monitoring."""
        try:
            # Bind the interval as a parameter rather than formatting it into the SQL.
            interval = f"-{int(hours)} hours"
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                # Hourly breakdown of requests and tokens.
                cursor.execute("""
                    SELECT
                        strftime('%Y-%m-%d %H:00:00', created_at) AS hour,
                        COUNT(*) AS requests,
                        SUM(tokens_used) AS tokens
                    FROM api_requests
                    WHERE created_at >= datetime('now', ?)
                    GROUP BY hour
                    ORDER BY hour DESC
                """, (interval,))
                hourly_data = cursor.fetchall()
                # Totals over the same window.
                cursor.execute("""
                    SELECT COUNT(*) AS total_requests, SUM(tokens_used) AS total_tokens
                    FROM api_requests
                    WHERE created_at >= datetime('now', ?)
                """, (interval,))
                total_data = cursor.fetchone()
            return {
                "hourly_breakdown": [
                    {"hour": hour, "requests": reqs, "tokens": tokens}
                    for hour, reqs, tokens in hourly_data
                ],
                "total_requests": total_data[0] if total_data else 0,
                "total_tokens": (total_data[1] or 0) if total_data else 0,
                "hours_analyzed": hours,
            }
        except Exception as e:
            logger.error(f"Failed to get usage history: {e}")
            return {"error": str(e)}


# Global persistent rate limiter instance
persistent_rate_limiter = PersistentRateLimiter()


async def persistent_rate_limited_gemini_call(func, *args, **kwargs):
    """Wrap a Gemini API call with persistent rate-limit tracking.

    Args:
        func: The (synchronous) Gemini API function to call.
        *args, **kwargs: Arguments forwarded to the function.

    Returns:
        The result of the function call.
    """
    # Estimate tokens from the request content if it is available.
    estimated_tokens = None
    if "contents" in kwargs:
        content = kwargs["contents"]
        if isinstance(content, str):
            estimated_tokens = persistent_rate_limiter._estimate_tokens(content)
        elif isinstance(content, list):
            # Concatenate the text of any parts that expose it, whether as an
            # attribute or as a dict key.
            total_content = ""
            for part in content:
                if hasattr(part, "text"):
                    total_content += part.text
                elif isinstance(part, dict) and "text" in part:
                    total_content += part["text"]
            estimated_tokens = persistent_rate_limiter._estimate_tokens(total_content)
    # Block until the request fits within the safe limits.
    await persistent_rate_limiter.wait_if_needed(estimated_tokens)
    # Make the API call and record it on success.
    try:
        result = func(*args, **kwargs)
        persistent_rate_limiter.record_request(estimated_tokens)
        return result
    except Exception as e:
        logger.error(f"Gemini API call failed: {e}")
        raise
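

if __name__ == "__main__":
    # Minimal self-contained demo (no real API call): `fake_gemini_call` is a
    # hypothetical stand-in for an actual Gemini client function, used only to
    # show how the wrapper, wait_if_needed, and record_request fit together.
    logging.basicConfig(level=logging.INFO)

    def fake_gemini_call(contents: str) -> str:
        return f"echo: {contents}"

    async def demo():
        result = await persistent_rate_limited_gemini_call(fake_gemini_call, contents="Hello, world!")
        print(result)
        print(persistent_rate_limiter.get_status())
        print(persistent_rate_limiter.get_usage_history(hours=1))

    asyncio.run(demo())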