"""Nexos.ai API client for image generation."""
import asyncio
import logging
from typing import Any
import httpx
from Imagen_MCP.config import ImagenMCPConfig
from Imagen_MCP.exceptions import (
AuthenticationError,
ConfigurationError,
GenerationError,
InvalidRequestError,
RateLimitError,
)
from Imagen_MCP.models.generation import GenerateImageRequest, GenerateImageResponse
# Module-scoped logger, named after this module so its level and handlers can
# be configured through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)
class NexosClient:
    """Async client for the Nexos.ai image-generation API.

    The underlying ``httpx.AsyncClient`` is created lazily on the first request
    and released by :meth:`close`, which also runs when the client is used as an
    async context manager (``async with NexosClient(...) as client: ...``).
    """

    DEFAULT_TIMEOUT = 120.0  # Image generation can take a while
    MAX_RETRIES = 3
    RETRY_DELAY = 1.0  # Base delay in seconds for exponential backoff

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str = "https://api.nexos.ai/v1",
        timeout: float | None = None,
    ):
        """Initialize the Nexos.ai client.

        Args:
            api_key: API key for authentication. If omitted, accessing
                :attr:`api_key` (which happens when the HTTP client is first
                created) raises ``ConfigurationError``. Use :meth:`from_config`
                or :meth:`from_env` to take the key from configuration.
            base_url: Base URL for the API; a trailing slash is stripped.
            timeout: Request timeout in seconds. ``None`` (or any other falsy
                value) falls back to ``DEFAULT_TIMEOUT``.
        """
        self._api_key = api_key
        self._base_url = base_url.rstrip("/")
        self._timeout = timeout or self.DEFAULT_TIMEOUT
        self._client: httpx.AsyncClient | None = None

    @classmethod
    def from_config(cls, config: ImagenMCPConfig) -> "NexosClient":
        """Create a client using the API key and base URL from ``config``."""
        return cls(
            api_key=config.nexos_api_key,
            base_url=config.nexos_api_base_url,
        )

    @classmethod
    def from_env(cls) -> "NexosClient":
        """Create a client from the configuration loaded via ``ImagenMCPConfig.from_env()``."""
        config = ImagenMCPConfig.from_env()
        return cls.from_config(config)

    @property
    def api_key(self) -> str:
        """Return the API key.

        Raises:
            ConfigurationError: If no (or an empty) API key was provided.
        """
        if not self._api_key:
            raise ConfigurationError("NEXOS_API_KEY is not configured")
        return self._api_key

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it if missing or closed.

        The Authorization header is built from :attr:`api_key` at creation
        time, so a missing key surfaces here as ``ConfigurationError``.
        """
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                timeout=self._timeout,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    async def close(self) -> None:
        """Close the HTTP client if it is open; safe to call repeatedly."""
        if self._client is not None and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    async def __aenter__(self) -> "NexosClient":
        """Async context manager entry; returns the client itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit; closes the HTTP client."""
        await self.close()

    @staticmethod
    def _extract_error_message(response: httpx.Response) -> str:
        """Return the most useful human-readable message from an error response.

        Uses ``{"error": {"message": ...}}`` when present, also accepts a
        plain ``{"error": "..."}`` string, and otherwise falls back to the raw
        response body text.
        """
        try:
            payload = response.json()
        except ValueError:  # body is not JSON (JSONDecodeError / bad encoding)
            return response.text
        error = payload.get("error") if isinstance(payload, dict) else None
        if isinstance(error, dict):
            message = error.get("message")
            if message:
                return str(message)
        elif isinstance(error, str) and error:
            return error
        return response.text

    @staticmethod
    def _parse_retry_after(value: str | None) -> int | None:
        """Convert a ``Retry-After`` header value into whole seconds.

        The header may carry either a number of seconds or an HTTP-date.
        Both forms are handled. An absent or unparseable value yields ``None``
        instead of raising, so a malformed header cannot mask the rate-limit
        error being reported.
        """
        import math
        from datetime import datetime, timezone
        from email.utils import parsedate_to_datetime

        if value is None:
            return None
        text = value.strip()
        if not text:
            return None
        # Delay-seconds form (also tolerates fractional values such as "1.5").
        try:
            return max(0, math.ceil(float(text)))
        except (ValueError, OverflowError):  # not numeric, NaN, or infinite
            pass
        # HTTP-date form, e.g. "Wed, 21 Oct 2015 07:28:00 GMT".
        try:
            retry_at = parsedate_to_datetime(text)
        except (TypeError, ValueError):
            return None
        if retry_at.tzinfo is None:
            retry_at = retry_at.replace(tzinfo=timezone.utc)
        remaining = (retry_at - datetime.now(timezone.utc)).total_seconds()
        return max(0, math.ceil(remaining))

    def _handle_error_response(self, response: httpx.Response) -> None:
        """Raise the exception matching an unsuccessful API response.

        This method always raises; it never returns normally.

        Raises:
            InvalidRequestError: For HTTP 400.
            AuthenticationError: For HTTP 401 and 403.
            RateLimitError: For HTTP 429, with ``retry_after`` taken from the
                ``Retry-After`` header when it can be parsed.
            GenerationError: For 5xx ("Server error") and every other status
                ("API error").
        """
        status_code = response.status_code
        error_message = self._extract_error_message(response)

        if status_code == 400:
            raise InvalidRequestError(f"Invalid request: {error_message}")
        if status_code == 401:
            raise AuthenticationError(f"Authentication failed: {error_message}")
        if status_code == 403:
            raise AuthenticationError(f"Access forbidden: {error_message}")
        if status_code == 429:
            retry_seconds = self._parse_retry_after(response.headers.get("retry-after"))
            raise RateLimitError(
                f"Rate limit exceeded: {error_message}", retry_after=retry_seconds
            )
        if status_code >= 500:
            raise GenerationError(f"Server error ({status_code}): {error_message}")
        raise GenerationError(f"API error ({status_code}): {error_message}")

    def _retry_delay(self, attempt: int) -> float:
        """Return the exponential-backoff delay after the 0-based ``attempt``."""
        return self.RETRY_DELAY * (2**attempt)

    @staticmethod
    def _parse_success_body(response: httpx.Response) -> Any:
        """Decode a successful response as JSON, wrapping decode failures."""
        try:
            return response.json()
        except ValueError as exc:
            raise GenerationError(
                f"Invalid JSON in API response ({response.status_code}): {exc}"
            ) from exc

    async def _request_with_retry(
        self,
        method: str,
        endpoint: str,
        json_data: dict | None = None,
        max_retries: int | None = None,
    ) -> dict[str, Any]:
        """Make a request, retrying connection errors, timeouts and 5xx responses.

        Retries use exponential backoff (``RETRY_DELAY * 2**attempt``). Rate
        limit, authentication and validation errors are never retried.

        Args:
            method: HTTP method (GET, POST, etc.).
            endpoint: API endpoint, relative to the base URL.
            json_data: JSON body to send, if any.
            max_retries: Maximum number of retries after the first attempt.
                Defaults to ``MAX_RETRIES``; negative values are treated as 0,
                so at least one request is always made.

        Returns:
            The decoded JSON response body.

        Raises:
            ConfigurationError: If the API key is not configured.
            InvalidRequestError, AuthenticationError, RateLimitError: As mapped
                by :meth:`_handle_error_response`.
            GenerationError: For server/API errors once retries are exhausted,
                repeated connection failures/timeouts, or an undecodable body.
        """
        retries = self.MAX_RETRIES if max_retries is None else max_retries
        retries = max(0, retries)  # guarantee at least one attempt
        total_attempts = retries + 1

        for attempt in range(total_attempts):
            client = await self._get_client()
            try:
                response = await client.request(
                    method=method,
                    url=endpoint,
                    json=json_data,
                )
            except (httpx.ConnectError, httpx.TimeoutException) as exc:
                # NOTE(review): retrying a timed-out POST may re-submit a
                # generation the server already started; confirm whether the
                # endpoint is safe to repeat.
                if attempt >= retries:
                    raise GenerationError(
                        f"Request failed after {total_attempts} attempts: {exc}"
                    ) from exc
                delay = self._retry_delay(attempt)
                logger.warning(
                    "Request failed (attempt %d/%d), retrying in %ss: %s",
                    attempt + 1,
                    total_attempts,
                    delay,
                    exc,
                )
                await asyncio.sleep(delay)
                continue

            if response.is_success:
                return self._parse_success_body(response)

            # Only 5xx responses are treated as transient. Decide on the status
            # code itself rather than on the wording of the error message,
            # which comes from the server and may contain anything.
            if response.status_code >= 500 and attempt < retries:
                delay = self._retry_delay(attempt)
                logger.warning(
                    "Server error (attempt %d/%d), retrying in %ss: "
                    "Server error (%d): %s",
                    attempt + 1,
                    total_attempts,
                    delay,
                    response.status_code,
                    self._extract_error_message(response),
                )
                await asyncio.sleep(delay)
                continue

            # Always raises the exception mapped from the status code.
            self._handle_error_response(response)

        # Defensive guard: every iteration returns, raises or continues, and
        # the final iteration can never continue.
        raise GenerationError("Request failed for unknown reason")

    async def generate_image(
        self, request: GenerateImageRequest
    ) -> GenerateImageResponse:
        """Generate images from a text prompt.

        Args:
            request: Image generation request.

        Returns:
            Image generation response with the generated images.

        Raises:
            ConfigurationError: If the API key is not configured.
            AuthenticationError: If authentication fails.
            InvalidRequestError: If the request is invalid.
            RateLimitError: If the rate limit is exceeded.
            GenerationError: If image generation fails.
        """
        logger.info(
            "Generating image with model=%s, size=%s, n=%s",
            request.model,
            request.size,
            request.n,
        )
        response_data = await self._request_with_retry(
            method="POST",
            endpoint="/images/generations",
            json_data=request.to_api_payload(),
        )
        response = GenerateImageResponse.from_api_response(response_data, request)
        logger.info("Generated %d image(s)", len(response.images))
        return response

    async def generate_image_simple(
        self,
        prompt: str,
        model: str = "imagen-4",
        size: str = "1024x1024",
        quality: str = "standard",
        style: str = "vivid",
    ) -> GenerateImageResponse:
        """Generate a single image (``n=1``) from plain keyword arguments.

        Args:
            prompt: Text description of the image.
            model: Model to use.
            size: Image size.
            quality: Image quality.
            style: Image style.

        Returns:
            Image generation response.
        """
        request = GenerateImageRequest(
            prompt=prompt,
            model=model,
            size=size,  # type: ignore
            quality=quality,  # type: ignore
            style=style,  # type: ignore
            n=1,
        )
        return await self.generate_image(request)

    async def list_models(self) -> list[dict[str, Any]]:
        """List available models.

        Returns:
            The ``"data"`` list from the ``/models`` response, or an empty list
            when the key is missing or null.
        """
        response_data = await self._request_with_retry(
            method="GET",
            endpoint="/models",
        )
        return response_data.get("data") or []