# /src/llm/llm_client.py
import logging
import threading
import time
from typing import Any, Dict, Optional, Type, Union

from .clients.azure_openai_client import AzureOpenAIClient
from .clients.gemini_client import GeminiClient
from .clients.openai_client import OpenAIClient

logger = logging.getLogger(__name__)


class LLMClient:
"""
    Handles interactions with LLM APIs (Google Gemini, OpenAI, or Azure
    OpenAI via the OpenAI SDK) with client-side rate limiting.
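
    Example (a minimal sketch; assumes the provider's API key is already
    configured for the underlying client, e.g. via environment variables):

        client = LLMClient(provider="gemini")
        reply = client.generate_text("Summarize this repository in one sentence.")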
"""
    # Client-side rate limiting: minimum delay between successive requests.
    # Adjust to the active provider's quota, and consider making this
    # provider-specific. Gemini's free tier allows 15 RPM (one request
    # every 4 s); OpenAI and Azure limits depend on the account tier.
    MIN_REQUEST_INTERVAL_SECONDS = 3.0
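    # Worked example: interval = 60 / RPM, so a 15 RPM quota implies
    # 60 / 15 = 4.0 s between requests; the 3.0 s default assumes a
    # somewhat higher quota.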

    def __init__(self, provider: str):
"""
        Initializes the LLM client for the specified provider.

        Args:
            provider: The LLM provider to use ('gemini', 'openai', or 'azure').
"""
self.provider = provider.lower()
self.client = None
if self.provider == 'gemini':
self.client = GeminiClient()
elif self.provider == 'openai':
self.client = OpenAIClient()
elif self.provider == 'azure':
self.client = AzureOpenAIClient()
else:
raise ValueError(f"Unsupported provider: {provider}. Choose 'gemini' or 'openai' or 'azure'.")
# Common initialization
self._last_request_time = 0.0
self._lock = threading.Lock() # Lock for rate limiting
logger.info(f"LLMClient initialized for provider '{self.provider}' with {self.MIN_REQUEST_INTERVAL_SECONDS}s request interval.")

    def _wait_for_rate_limit(self):
"""Waits if necessary to maintain the minimum request interval."""
with self._lock: # Ensure thread-safe access
now = time.monotonic()
elapsed = now - self._last_request_time
wait_time = self.MIN_REQUEST_INTERVAL_SECONDS - elapsed
if wait_time > 0:
logger.debug(f"Rate limiting: Waiting for {wait_time:.2f} seconds...")
time.sleep(wait_time)
self._last_request_time = time.monotonic() # Update after potential wait
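
    # Example trace (with MIN_REQUEST_INTERVAL_SECONDS = 3.0): if one thread
    # issues a request at t = 10.0 s and another calls at t = 11.2 s, the
    # second waits 3.0 - 1.2 = 1.8 s, keeping calls at least 3 s apart.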

    def generate_text(self, prompt: str) -> str:
"""Generates text using the configured LLM provider, respecting rate limits."""
self._wait_for_rate_limit() # Wait before making the API call
return self.client.generate_text(prompt)

    def generate_multimodal(self, prompt: str, image_bytes: bytes) -> str:
"""Generates text based on a prompt and an image, respecting rate limits."""
self._wait_for_rate_limit() # Wait before making the API call
return self.client.generate_multimodal(prompt, image_bytes)

    def generate_json(self, Schema_Class: Type, prompt: str, image_bytes: Optional[bytes] = None) -> Union[Dict[str, Any], str]:
"""
Generates structured JSON output based on a prompt, an optional image,
and a defined schema, respecting rate limits.
        For Gemini, Schema_Class should be a Pydantic BaseModel or a compatible
        schema type; for OpenAI and Azure it must be a Pydantic BaseModel.
Returns:
A dictionary representing the parsed JSON on success, or an error string.
"""
self._wait_for_rate_limit()
return self.client.generate_json(Schema_Class, prompt, image_bytes)
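

# A minimal usage sketch (not part of the original API surface): demonstrates
# the structured-JSON path with a hypothetical Pydantic schema. Assumes the
# provider's credentials are configured for the underlying client.
if __name__ == "__main__":
    from pydantic import BaseModel

    class Recipe(BaseModel):  # hypothetical example schema
        name: str
        ingredients: list[str]

    client = LLMClient(provider="gemini")
    result = client.generate_json(Recipe, "Suggest a simple pancake recipe.")
    print(result)  # a dict matching Recipe on success, or an error string

    # Multimodal calls take raw image bytes, e.g.:
    #   with open("photo.jpg", "rb") as f:
    #       caption = client.generate_multimodal("Describe this image.", f.read())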