"""
HTTP client for LLM API calls with retry logic.
This module handles the transport layer for AI API calls, implementing:
- HTTP-level retries for network issues (timeouts, 5xx errors, connection failures)
- Exponential backoff with jitter to prevent thundering herd
- Proper error classification and handling
This is separate from content-level retries which are handled in tools/cleaner.py
for AI output quality issues (invalid JSON, schema validation failures).
"""
import asyncio
import logging
import random
from typing import List, Dict, Optional
import httpx
from httpx import TimeoutException, HTTPStatusError, RequestError
logger = logging.getLogger(__name__)
class LLMTimeoutError(Exception):
"""Raised when LLM request times out"""
class LLMHttpError(Exception):
"""Raised when LLM returns HTTP error"""
class LLMNetworkError(Exception):
"""Raised when network error occurs"""
class LLMClient:
"""Async HTTP client for LLM API calls with retry logic"""
def __init__(
self,
endpoint: str,
api_key: Optional[str] = None,
model: str = "gpt-4",
timeout: float = 60.0,
max_retries: int = 3,
):
self.endpoint = endpoint
self.api_key = api_key
self.model = model
self.timeout = timeout
self.max_retries = max_retries
async def chat_completions(
self,
messages: List[Dict[str, str]],
temperature: float = 0.2,
max_tokens: int = 600,
request_id: str = "",
) -> str:
"""
Make chat completions request with retry logic.
Args:
messages: List of message dicts with 'role' and 'content'
temperature: Sampling temperature
max_tokens: Maximum tokens to generate
request_id: Request ID for tracking
Returns:
Content of the first choice message
Raises:
LLMTimeoutError: If request times out
LLMHttpError: If HTTP error occurs
LLMNetworkError: If network error occurs
"""
headers = {"Content-Type": "application/json", "X-Request-ID": request_id}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
payload = {
"model": self.model,
"temperature": temperature,
"max_tokens": max_tokens,
"messages": messages,
}
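        # The resulting JSON body follows the OpenAI-style chat schema, e.g.:
        #   {"model": "gpt-4", "temperature": 0.2, "max_tokens": 600,
        #    "messages": [{"role": "user", "content": "..."}]}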
last_exception = None
for attempt in range(self.max_retries + 1):
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
self.endpoint, headers=headers, json=payload
)
response.raise_for_status()
data = response.json()
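                    # Expected OpenAI-style success shape:
                    #   {"choices": [{"message": {"content": "..."}}], ...}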
if "choices" not in data or not data["choices"]:
raise LLMHttpError("Invalid response format: no choices")
content = data["choices"][0]["message"]["content"]
if not content:
raise LLMHttpError("Empty response content")
return content
except TimeoutException as e:
last_exception = LLMTimeoutError(
f"Request timed out after {self.timeout}s: {e}"
)
logger.warning(
"Request timeout (attempt %d/%d)", attempt + 1, self.max_retries + 1
)
except HTTPStatusError as e:
if e.response.status_code >= 500 and attempt < self.max_retries:
# Retry on 5xx errors
last_exception = LLMHttpError(
f"HTTP {e.response.status_code}: {e.response.text}"
)
logger.warning(
"HTTP error %d (attempt %d/%d)",
e.response.status_code,
attempt + 1,
self.max_retries + 1,
)
else:
# Don't retry on 4xx errors or if max retries reached
raise LLMHttpError(
f"HTTP {e.response.status_code}: {e.response.text}"
) from e
except RequestError as e:
last_exception = LLMNetworkError(f"Network error: {e}")
logger.warning(
"Network error (attempt %d/%d): %s",
attempt + 1,
self.max_retries + 1,
e,
)
            except LLMHttpError:
                # Malformed-response errors raised in the try block above
                # propagate as-is instead of being re-wrapped as "unexpected"
                raise
            except Exception as e:
                # Unexpected error, don't retry
                raise LLMHttpError(f"Unexpected error: {e}") from e
# Calculate backoff delay for retries
if attempt < self.max_retries:
                # Exponential backoff with random jitter: 2^attempt seconds
                # plus up to 1s of randomness, so concurrent clients don't
                # retry in lockstep (thundering herd)
                delay = 2**attempt + random.uniform(0.0, 1.0)
logger.info("Retrying in %.1fs...", delay)
await asyncio.sleep(delay)
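                # With the default max_retries=3 this sleeps roughly 1s, 2s,
                # and 4s (plus jitter) between the four total attempts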
# If we get here, all retries failed
if last_exception:
raise last_exception
else:
raise LLMNetworkError("All retry attempts failed")
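

# ---------------------------------------------------------------------------
# Usage sketch. The endpoint URL, API key, and prompt below are placeholders
# for illustration only; point the client at whatever OpenAI-compatible
# /v1/chat/completions endpoint you actually use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    async def _demo() -> None:
        client = LLMClient(
            endpoint="https://api.example.com/v1/chat/completions",  # placeholder
            api_key="YOUR_API_KEY",  # placeholder
        )
        try:
            reply = await client.chat_completions(
                messages=[{"role": "user", "content": "Say hello."}],
                request_id="demo-001",
            )
            print(reply)
        except (LLMTimeoutError, LLMHttpError, LLMNetworkError) as exc:
            logger.error("LLM call failed: %s", exc)

    logging.basicConfig(level=logging.INFO)
    asyncio.run(_demo())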