"""GPTNB API client for Midjourney operations."""
import asyncio
import logging
from typing import Optional, Dict, Any, List
from urllib.parse import quote

import httpx

from config import Config
from models import (
    TaskResponse, TaskDetail, ImagineRequest, BlendRequest,
    DescribeRequest, ChangeRequest, SwapFaceRequest, ModalRequest
)
from exceptions import (
    APIError, AuthenticationError, RateLimitError,
    TaskNotFoundError, NetworkError, TimeoutError
)
# Module-level logger, named after this module so its output can be
# filtered/configured independently of other loggers.
logger = logging.getLogger(__name__)
class GPTNBClient:
    """GPTNB API client for Midjourney operations.

    The underlying ``httpx.AsyncClient`` is created lazily on first use and
    released by ``close()``; the client can also be used with ``async with``.
    """

    def __init__(self, config: Config):
        """Initialize the GPTNB client.

        Args:
            config: Configuration object. This class reads ``gptnb_base_url``,
                ``gptnb_api_key``, ``timeout``, ``max_retries``,
                ``retry_delay`` and ``notify_hook`` from it.
        """
        self.config = config
        # Endpoint paths start with "/", so drop any trailing slash here to
        # avoid building URLs containing "//".
        self.base_url = config.gptnb_base_url.rstrip('/')
        self.headers = {
            "Authorization": f"Bearer {config.gptnb_api_key}",
            "Content-Type": "application/json",
            "User-Agent": "midjourney-mcp/0.2.0"
        }
        # Created lazily by _ensure_client(); reset to None by close().
        self._client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self):
        """Async context manager entry: make sure the HTTP client exists."""
        await self._ensure_client()
        return self

    async def __aexit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[object]) -> None:
        """Async context manager exit: close the HTTP client."""
        await self.close()

    async def _ensure_client(self) -> None:
        """Create the HTTP client if it does not exist yet."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(self.config.timeout),
                headers=self.headers
            )

    async def close(self) -> None:
        """Close the HTTP client, if one is open.

        The client reference is detached before ``aclose()`` is awaited so a
        failing close never leaves a half-closed client attached for reuse.
        """
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    @staticmethod
    def _response_body(response: httpx.Response) -> Any:
        """Best-effort decode of a response body for error reporting.

        Returns:
            The parsed JSON value; ``{}`` for an empty body; or
            ``{"raw": <text>}`` when the body is not valid JSON (e.g. an HTML
            error page), so building an error object never raises.
        """
        if not response.content:
            return {}
        try:
            return response.json()
        except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
            return {"raw": response.text}

    def _error_for_response(self, response: httpx.Response) -> Exception:
        """Map a non-200 HTTP response to the matching exception instance."""
        status = response.status_code
        body = self._response_body(response)
        if status == 401:
            return AuthenticationError(
                "Invalid API key or authentication failed",
                status_code=status,
                response_data=body
            )
        if status == 429:
            return RateLimitError(
                "Rate limit exceeded",
                status_code=status,
                response_data=body
            )
        if status == 404:
            return TaskNotFoundError(
                "Task not found",
                status_code=status,
                response_data=body
            )
        return APIError(
            f"API request failed with status {status}",
            status_code=status,
            response_data=body
        )

    async def _make_request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Make HTTP request with error handling and retries.

        Timeouts, network errors, other transport failures and 5xx responses
        are retried with exponential backoff
        (``retry_delay * 2 ** attempt``) up to ``max_retries`` times. 4xx
        responses are raised immediately with their status code and body.

        Args:
            method: HTTP method
            endpoint: API endpoint path (appended to ``base_url``)
            data: Request body data, sent as JSON
            params: Query parameters

        Returns:
            The decoded JSON body of a 200 response (an object for most
            endpoints; ``get_tasks`` expects a list).

        Raises:
            AuthenticationError: On HTTP 401 (not retried).
            RateLimitError: On HTTP 429 (not retried).
            TaskNotFoundError: On HTTP 404 (not retried).
            APIError: On other non-200 statuses, on a 200 response with an
                invalid JSON body, or on unexpected errors after retries.
            NetworkError: On network errors after the last retry.
            TimeoutError: On timeouts after the last retry.
        """
        await self._ensure_client()
        url = f"{self.base_url}{endpoint}"
        max_retries = self.config.max_retries

        for attempt in range(max_retries + 1):
            is_last_attempt = attempt == max_retries
            logger.debug("Making %s request to %s (attempt %d)", method, url, attempt + 1)
            try:
                response = await self._client.request(
                    method=method,
                    url=url,
                    json=data,
                    params=params
                )
            except httpx.TimeoutException as e:
                if is_last_attempt:
                    raise TimeoutError(f"Request timed out after {self.config.timeout} seconds") from e
                logger.warning("Request timeout on attempt %d, retrying...", attempt + 1)
            except httpx.NetworkError as e:
                if is_last_attempt:
                    raise NetworkError(f"Network error: {str(e)}") from e
                logger.warning("Network error on attempt %d, retrying...", attempt + 1)
            except Exception as e:
                # Any other failure while sending (e.g. protocol errors).
                if is_last_attempt:
                    raise APIError(f"Unexpected error: {str(e)}") from e
                logger.warning("Unexpected error on attempt %d, retrying...", attempt + 1)
            else:
                # Status handling lives outside the try so the APIError
                # family raised here is never caught by the generic handler
                # above (which would retry it and discard status/body).
                if response.status_code == 200:
                    try:
                        return response.json()
                    except ValueError as e:
                        # The server already processed the request; retrying
                        # a POST submit here could create a duplicate task.
                        raise APIError(
                            "API returned an invalid JSON response",
                            status_code=response.status_code,
                            response_data={"raw": response.text}
                        ) from e
                error = self._error_for_response(response)
                # Client errors (4xx) will not succeed on retry; only server
                # errors (5xx) are worth another attempt.
                if response.status_code < 500 or is_last_attempt:
                    raise error
                logger.warning(
                    "Server error %d on attempt %d, retrying...",
                    response.status_code, attempt + 1
                )

            # Wait before retry (exponential backoff)
            if not is_last_attempt:
                wait_time = self.config.retry_delay * (2 ** attempt)
                logger.debug("Waiting %s seconds before retry...", wait_time)
                await asyncio.sleep(wait_time)

        # Only reachable when max_retries is negative (no attempts were made).
        raise APIError("Max retries exceeded")

    def _prepare_request_data(self, request: Any) -> Dict[str, Any]:
        """Prepare request data with notify hook if configured.

        Args:
            request: Request object with a ``model_dump`` method

        Returns:
            The request dumped without ``None`` fields; ``notifyHook`` is
            filled from config only when the request did not set a truthy one.
        """
        data = request.model_dump(exclude_none=True)
        if self.config.notify_hook and not data.get('notifyHook'):
            data['notifyHook'] = self.config.notify_hook
        return data

    async def _submit(self, endpoint: str, request: Any) -> TaskResponse:
        """POST a prepared request to a submit endpoint and parse the reply.

        Args:
            endpoint: Submit endpoint path, e.g. ``/mj/submit/imagine``
            request: Request object with a ``model_dump`` method

        Returns:
            Task response
        """
        data = self._prepare_request_data(request)
        response_data = await self._make_request("POST", endpoint, data)
        return TaskResponse(**response_data)

    async def submit_imagine(self, request: ImagineRequest) -> TaskResponse:
        """Submit an imagine task.

        Args:
            request: Imagine request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/imagine", request)

    async def submit_blend(self, request: BlendRequest) -> TaskResponse:
        """Submit a blend task.

        Args:
            request: Blend request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/blend", request)

    async def submit_describe(self, request: DescribeRequest) -> TaskResponse:
        """Submit a describe task.

        Args:
            request: Describe request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/describe", request)

    async def submit_change(self, request: ChangeRequest) -> TaskResponse:
        """Submit a change task (upscale, variation, reroll).

        Args:
            request: Change request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/change", request)

    async def submit_swap_face(self, request: SwapFaceRequest) -> TaskResponse:
        """Submit a swap face task.

        Args:
            request: Swap face request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/swap-face", request)

    async def submit_modal(self, request: ModalRequest) -> TaskResponse:
        """Submit a modal task (zoom, pan, inpainting).

        Args:
            request: Modal request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/modal", request)

    async def get_task(self, task_id: str) -> TaskDetail:
        """Get task details by ID.

        Args:
            task_id: Task ID

        Returns:
            Task details
        """
        # Percent-encode the ID so characters such as "/", "?" or "#" in
        # caller-supplied input cannot change which endpoint is requested.
        safe_id = quote(str(task_id), safe="")
        response_data = await self._make_request("GET", f"/mj/task/{safe_id}/fetch")
        return TaskDetail.from_api_response(response_data)

    async def get_tasks(self, task_ids: List[str]) -> List[TaskDetail]:
        """Get multiple tasks by IDs.

        Args:
            task_ids: List of task IDs

        Returns:
            List of task details, parsed the same way as ``get_task``
        """
        data = {"ids": task_ids}
        response_data = await self._make_request("POST", "/mj/task/list-by-condition", data)
        return [TaskDetail.from_api_response(task) for task in response_data]