"""Async HTTP client for the ComfyUI REST API."""
import asyncio
import logging
import time
import httpx
from .config import settings
from .models import GenerationResult, ModelInfo, QueueStatus
# Module-level logger: records client warnings (server unreachable, failed
# health checks) and generation failures reported by ComfyUI.
logger = logging.getLogger(__name__)
class ComfyUIClient:
    """Wraps ComfyUI's REST API with async HTTP calls.

    The client owns one ``httpx.AsyncClient``. Use it as an async context
    manager (``async with ComfyUIClient() as client: ...``) or call
    :meth:`close` explicitly to release it.

    Connection failures are re-raised as ``httpx.ConnectError`` with a message
    naming the configured base URL. Other ``httpx`` errors propagate unchanged,
    e.g. ``httpx.HTTPStatusError`` from ``raise_for_status`` on non-2xx
    responses.
    """

    def __init__(self, base_url: str | None = None, timeout: float = 30.0) -> None:
        """Create a client.

        Args:
            base_url: ComfyUI server URL. Falls back to ``settings.url``; a
                trailing slash is stripped.
            timeout: Per-request timeout in seconds handed to httpx
                (default 30.0).
        """
        self.base_url = (base_url or settings.url).rstrip("/")
        self._client = httpx.AsyncClient(base_url=self.base_url, timeout=timeout)

    def _connect_error(self, exc: httpx.ConnectError) -> httpx.ConnectError:
        """Wrap a ConnectError with a user-friendly message, preserving the chain.

        The original request (when one is attached) is carried over, so
        ``err.request`` still works on the wrapped exception.
        """
        try:
            request = exc.request
        except RuntimeError:
            # httpx raises RuntimeError from ``.request`` when none was set.
            request = None
        new = httpx.ConnectError(
            f"Cannot connect to ComfyUI at {self.base_url}. Is it running?",
            request=request,
        )
        new.__cause__ = exc
        return new

    async def health_check(self) -> bool:
        """Check if ComfyUI is running by hitting /system_stats.

        Returns:
            True only for an HTTP 200 response. Any httpx error is logged as a
            warning and reported as False rather than raised.
        """
        try:
            resp = await self._client.get("/system_stats")
            return resp.status_code == 200
        # ConnectError subclasses HTTPError, so it must be caught first to get
        # its dedicated log message.
        except httpx.ConnectError:
            logger.warning("ComfyUI not reachable at %s", self.base_url)
            return False
        except httpx.HTTPError as exc:
            logger.warning("Health check failed: %s", exc)
            return False

    async def submit_workflow(
        self, workflow: dict, client_id: str = "comfyui-mcp"
    ) -> str:
        """Submit a workflow for execution.

        Args:
            workflow: The ComfyUI workflow/prompt dictionary.
            client_id: Value sent as ``client_id`` alongside the prompt
                (default ``"comfyui-mcp"``).

        Returns:
            The prompt_id assigned by ComfyUI.

        Raises:
            httpx.HTTPStatusError: If ComfyUI rejects the workflow.
            httpx.ConnectError: If ComfyUI is not reachable.
            KeyError: If the response body has no ``prompt_id``.
        """
        try:
            resp = await self._client.post(
                "/prompt",
                json={"prompt": workflow, "client_id": client_id},
            )
            resp.raise_for_status()
            data = resp.json()
            return data["prompt_id"]
        except httpx.ConnectError as exc:
            raise self._connect_error(exc) from exc

    async def get_queue(self) -> QueueStatus:
        """Get the current queue status.

        Returns:
            QueueStatus holding the lengths of the ``queue_pending`` and
            ``queue_running`` lists (missing keys count as empty).

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
            httpx.ConnectError: If ComfyUI is not reachable.
        """
        try:
            resp = await self._client.get("/queue")
            resp.raise_for_status()
            data = resp.json()
            pending = len(data.get("queue_pending", []))
            running = len(data.get("queue_running", []))
            return QueueStatus(pending=pending, running=running)
        except httpx.ConnectError as exc:
            raise self._connect_error(exc) from exc

    async def get_history(self, prompt_id: str) -> dict | None:
        """Get generation history for a specific prompt_id.

        Returns:
            The history entry dict, or None if the response has no entry
            keyed by ``prompt_id`` (e.g. the prompt has not finished yet).

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
            httpx.ConnectError: If ComfyUI is not reachable.
        """
        try:
            resp = await self._client.get(f"/history/{prompt_id}")
            resp.raise_for_status()
            data = resp.json()
            return data.get(prompt_id)
        except httpx.ConnectError as exc:
            raise self._connect_error(exc) from exc

    async def poll_until_complete(
        self,
        prompt_id: str,
        timeout: float = 120.0,
        interval: float = 1.0,
    ) -> GenerationResult:
        """Poll /history/{prompt_id} until output appears, error, or timeout.

        Args:
            prompt_id: The prompt ID to poll for.
            timeout: Maximum seconds to wait.
            interval: Seconds between polls (the last sleep is shortened so the
                wait never runs past ``timeout``).

        Returns:
            GenerationResult with status ``"completed"``, ``"error"`` or
            ``"timeout"``, the image filenames (completed only), and the
            elapsed seconds since polling started.
        """
        start = time.monotonic()
        while (time.monotonic() - start) < timeout:
            history = await self.get_history(prompt_id)
            if history:
                # ``or {}`` also guards against an explicit ``"status": null``.
                status_info = history.get("status") or {}
                status_str = status_info.get("status_str")
                if status_str == "error":
                    elapsed = time.monotonic() - start
                    messages = status_info.get("messages", [])
                    error_msg = str(messages) if messages else "ComfyUI execution error"
                    logger.error("Generation %s failed: %s", prompt_id, error_msg)
                    return GenerationResult(
                        prompt_id=prompt_id,
                        status="error",
                        elapsed_seconds=round(elapsed, 2),
                    )
                outputs = history.get("outputs")
                # A run reported as successful but producing no outputs would
                # otherwise keep polling until the timeout expires.
                if outputs or status_str == "success":
                    elapsed = time.monotonic() - start
                    return GenerationResult(
                        prompt_id=prompt_id,
                        status="completed",
                        images=extract_image_filenames(outputs or {}),
                        elapsed_seconds=round(elapsed, 2),
                    )
            remaining = timeout - (time.monotonic() - start)
            if remaining <= 0:
                break
            # Cap the sleep so we do not overshoot the deadline by an interval.
            await asyncio.sleep(min(interval, remaining))
        elapsed = time.monotonic() - start
        return GenerationResult(
            prompt_id=prompt_id,
            status="timeout",
            elapsed_seconds=round(elapsed, 2),
        )

    async def list_checkpoints(self) -> list[ModelInfo]:
        """List available checkpoint models from ComfyUI.

        Queries /object_info and extracts the ``ckpt_name`` choices from the
        CheckpointLoaderSimple node definition.

        Returns:
            One ModelInfo per checkpoint name, or an empty list when the node
            or its input spec is missing or unrecognised.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
            httpx.ConnectError: If ComfyUI is not reachable.
        """
        try:
            resp = await self._client.get("/object_info")
            resp.raise_for_status()
            data = resp.json()
        except httpx.ConnectError as exc:
            raise self._connect_error(exc) from exc

        loader_info = data.get("CheckpointLoaderSimple")
        if not loader_info:
            return []
        inputs = loader_info.get("input", {}).get("required", {})
        return [
            ModelInfo(name=name, filename=name, type="checkpoint")
            for name in self._combo_options(inputs.get("ckpt_name"))
        ]

    @staticmethod
    def _combo_options(spec: object) -> list:
        """Return the choice list from a ComfyUI combo input spec.

        Accepts the legacy form ``[[name, ...], {...}]`` / ``[[name, ...]]``
        and the newer ``["COMBO", {"options": [name, ...]}]`` form; anything
        else yields an empty list.
        """
        if not isinstance(spec, (list, tuple)) or not spec:
            return []
        head = spec[0]
        if isinstance(head, list):
            return head
        if head == "COMBO" and len(spec) > 1 and isinstance(spec[1], dict):
            options = spec[1].get("options")
            if isinstance(options, list):
                return options
        return []

    async def get_image_data(
        self,
        filename: str,
        subfolder: str = "",
        img_type: str = "output",
    ) -> bytes:
        """Download a generated image from ComfyUI.

        Args:
            filename: The image filename.
            subfolder: Subfolder within the output directory.
            img_type: Image type (usually "output").

        Returns:
            Raw image bytes.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
            httpx.ConnectError: If ComfyUI is not reachable.
        """
        try:
            resp = await self._client.get(
                "/view",
                params={
                    "filename": filename,
                    "subfolder": subfolder,
                    "type": img_type,
                },
            )
            resp.raise_for_status()
            return resp.content
        except httpx.ConnectError as exc:
            raise self._connect_error(exc) from exc

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> "ComfyUIClient":
        return self

    async def __aexit__(self, *args: object) -> None:
        await self.close()
def extract_image_filenames(outputs: dict) -> list[str]:
    """Collect every non-empty image filename from ComfyUI history outputs.

    Each value of *outputs* is a per-node mapping that may carry an
    ``"images"`` list, e.g.::

        {"node_id": {"images": [{"filename": "img.png", "subfolder": "", "type": "output"}, ...]}}

    Nodes without ``"images"`` contribute nothing. Entries whose
    ``"filename"`` is missing or empty are skipped. Order follows node
    order, then image order.
    """
    return [
        name
        for per_node in outputs.values()
        for entry in per_node.get("images", [])
        if (name := entry.get("filename"))
    ]