"""
Rate-Limited Gemini API Client.
Wraps the Google Gemini API client with automatic rate limiting and persistent tracking.
"""
import asyncio
import logging
from typing import Optional, Dict, Any, List
import google.generativeai as genai
from google.generativeai import types
from .persistent_rate_limiter import persistent_rate_limiter, RateLimitConfig
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class RateLimitedGeminiClient:
    """Rate-limited wrapper around the Gemini API client.

    Each call to :meth:`generate_content` first waits on the module-level
    ``persistent_rate_limiter`` and then records the attempt with it, whether
    the API call succeeded or failed.
    """

    # Model used when the caller does not request a specific one.
    DEFAULT_MODEL = "gemini-2.0-flash-lite"

    # Rough heuristic used for token estimation: ~4 characters per token.
    _CHARS_PER_TOKEN = 4

    # Estimate returned for content that is neither a string nor a list.
    _FALLBACK_TOKEN_ESTIMATE = 1000

    def __init__(self, api_key: str, config: Optional[RateLimitConfig] = None):
        """
        Initialize the rate-limited Gemini client.

        Args:
            api_key: Google API key for Gemini
            config: Rate limiting configuration
        """
        self.api_key = api_key
        # NOTE(review): ``config`` is stored but never read by this class; all
        # limiting goes through the shared ``persistent_rate_limiter``.
        self.config = config

        # Configure credentials before building the model object.
        genai.configure(api_key=api_key)
        self.client = genai.GenerativeModel(self.DEFAULT_MODEL)
        logger.info("Rate-limited Gemini client initialized")

    def _model_for(self, model: str):
        """Return the shared default model, or a new model object for another name."""
        if model == self.DEFAULT_MODEL:
            return self.client
        return genai.GenerativeModel(model)

    async def generate_content(
        self,
        contents: str | List[str],
        model: str = DEFAULT_MODEL,
        config: Optional[types.GenerationConfig] = None,
        **kwargs
    ):
        """
        Generate content with automatic rate limiting.

        Args:
            contents: The content to generate from
            model: The model to use. The default reuses the client built in
                ``__init__``; any other name gets its own ``GenerativeModel``.
            config: Generation configuration, forwarded as ``generation_config``
                unless the caller already supplied that keyword in ``kwargs``.
            **kwargs: Additional arguments forwarded to the underlying
                ``generate_content`` call.

        Returns:
            Response from Gemini API

        Raises:
            Exception: Whatever the underlying API call raises; it is logged
                and re-raised unchanged.
        """
        estimated_tokens = self._estimate_tokens(contents)

        # Wait if necessary to respect rate limits
        await persistent_rate_limiter.wait_if_needed(estimated_tokens)

        if config is not None:
            # An explicit ``generation_config`` keyword keeps precedence so
            # callers that already relied on it see no change.
            kwargs.setdefault("generation_config", config)

        gemini_model = self._model_for(model)
        try:
            # The SDK call is synchronous; run it in a worker thread so the
            # event loop stays responsive while the request is in flight.
            response = await asyncio.to_thread(
                gemini_model.generate_content,
                contents=contents,
                **kwargs,
            )
        except Exception as e:
            logger.error("Gemini API call failed: %s", e)
            raise
        finally:
            # Record every attempt exactly once, successful or not.
            persistent_rate_limiter.record_request(estimated_tokens)

        logger.debug(
            "Generated content successfully. Estimated tokens: %s",
            estimated_tokens,
        )
        return response

    @staticmethod
    def _text_length(part) -> int:
        """Return the number of text characters found in one content part.

        Handles plain strings, objects with a ``text`` attribute, and objects
        with a ``parts`` iterable whose items may carry ``text``. Text values
        that are not strings (for example ``None``) count as zero.
        """
        if isinstance(part, str):
            return len(part)
        if hasattr(part, "text"):
            text = part.text
            return len(text) if isinstance(text, str) else 0
        if hasattr(part, "parts"):
            return sum(
                len(sub.text)
                for sub in part.parts
                if isinstance(getattr(sub, "text", None), str)
            )
        return 0

    def _estimate_tokens(self, content) -> int:
        """Estimate token count for content.

        Strings and lists are estimated at roughly four characters per token;
        any other input yields a fixed default of 1000.
        """
        if isinstance(content, str):
            return len(content) // self._CHARS_PER_TOKEN
        if isinstance(content, list):
            total_chars = sum(self._text_length(part) for part in content)
            return total_chars // self._CHARS_PER_TOKEN
        # Default estimation
        return self._FALLBACK_TOKEN_ESTIMATE

    def get_rate_limit_status(self) -> Dict[str, Any]:
        """Get current rate limiting status from the shared rate limiter."""
        return persistent_rate_limiter.get_status()

    def get_usage_history(self, hours: int = 24) -> Dict[str, Any]:
        """Get usage history for monitoring.

        Args:
            hours: Passed through to the shared rate limiter's
                ``get_usage_history``.
        """
        return persistent_rate_limiter.get_usage_history(hours)

    def can_make_request(self, estimated_tokens: Optional[int] = None) -> tuple[bool, str]:
        """Check if a request can be made without exceeding limits.

        Delegates to the shared rate limiter's ``can_make_request``.
        """
        return persistent_rate_limiter.can_make_request(estimated_tokens)

    async def wait_if_needed(self, estimated_tokens: Optional[int] = None) -> str:
        """Wait if necessary to respect rate limits.

        Delegates to the shared rate limiter's ``wait_if_needed``.
        """
        return await persistent_rate_limiter.wait_if_needed(estimated_tokens)
# Global rate-limited client instance.
# Starts as None; get_rate_limited_client() assigns it on first use and
# reset_client() sets it back to None.
_rate_limited_client: Optional[RateLimitedGeminiClient] = None
def get_rate_limited_client(api_key: str, config: Optional[RateLimitConfig] = None) -> RateLimitedGeminiClient:
    """
    Get or create a global rate-limited Gemini client.

    The client is built on the first call only. Later calls return that same
    instance and do not apply their arguments; if those arguments differ from
    the ones the existing client was built with, a warning is logged.

    Args:
        api_key: Google API key for Gemini
        config: Rate limiting configuration

    Returns:
        RateLimitedGeminiClient instance
    """
    global _rate_limited_client
    if _rate_limited_client is None:
        _rate_limited_client = RateLimitedGeminiClient(api_key, config)
    elif api_key != _rate_limited_client.api_key or (
        config is not None and config != _rate_limited_client.config
    ):
        # Deliberately omit the key values from the message: they are secrets.
        logger.warning(
            "get_rate_limited_client() was called with arguments that differ "
            "from the existing global client; returning the existing client "
            "unchanged. Call reset_client() first to rebuild it."
        )
    return _rate_limited_client
def reset_client():
    """Drop the cached global client (useful for testing).

    After this call the module-level ``_rate_limited_client`` is ``None``.
    """
    global _rate_limited_client
    # Rebinding to None is the whole reset; nothing else is touched here.
    _rate_limited_client = None