"""ComfyUI backend for production image generation with SDXL support."""
import os
import asyncio
import json
import uuid
import random
import sys
from typing import Optional, Literal
import httpx
import websockets
from .base import BaseBackend
# Model type detection keywords
SDXL_KEYWORDS = ["sdxl", "sd_xl", "xl_base", "xl-base", "playground", "juggernaut_xl", "realvis_xl"]
SD15_KEYWORDS = ["v1-5", "v1.5", "sd15", "sd_1.5", "dreamshaper", "deliberate", "realistic_vision"]
# Fallback checkpoint used when none is configured and auto-detection fails
DEFAULT_CHECKPOINT = "sd_xl_base_1.0.safetensors"
class ComfyUIBackend(BaseBackend):
"""Backend that connects to ComfyUI for image generation with SDXL optimization."""
def __init__(
self,
host: str = "127.0.0.1",
port: int = 8188,
client_id: Optional[str] = None,
checkpoint: Optional[str] = None
):
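        """
        Args:
            host: ComfyUI server host.
            port: ComfyUI server HTTP/WebSocket port.
            client_id: ID used to correlate WebSocket events with this client;
                a random UUID is generated when omitted.
            checkpoint: Checkpoint filename to load; falls back to the
                COMFYUI_CHECKPOINT environment variable, then to auto-detection.
        """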
self.host = host
self.port = port
self.client_id = client_id or str(uuid.uuid4())
self.base_url = f"http://{host}:{port}"
self.ws_url = f"ws://{host}:{port}/ws?clientId={self.client_id}"
self.name = "comfyui"
        self.checkpoint = checkpoint or os.getenv("COMFYUI_CHECKPOINT")  # None -> auto-detect
self._available_checkpoints = None
self._model_type_cache = {} # Cache model type detection
    async def _get_available_checkpoints(self) -> list:
        """Get the list of checkpoint models available to ComfyUI (cached after first call)."""
        if self._available_checkpoints is not None:
            return self._available_checkpoints
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{self.base_url}/object_info/CheckpointLoaderSimple",
                    timeout=5.0
                )
                if response.status_code == 200:
                    data = response.json()
                    # /object_info reports ckpt_name as [choices, config]; the first
                    # element is the list of installed checkpoint filenames.
                    self._available_checkpoints = (
                        data.get("CheckpointLoaderSimple", {})
                        .get("input", {})
                        .get("required", {})
                        .get("ckpt_name", [[]])[0]
                    )
                    return self._available_checkpoints
        except Exception:
            pass  # An unreachable server is treated as "no checkpoints known"
        return []
async def _get_checkpoint(self) -> str:
"""Get checkpoint to use, auto-detecting if needed."""
if self.checkpoint:
return self.checkpoint
checkpoints = await self._get_available_checkpoints()
if checkpoints:
return checkpoints[0]
return "sd_xl_base_1.0.safetensors"
def _detect_model_type(self, checkpoint: str) -> Literal["sdxl", "sd15"]:
"""Detect if checkpoint is SDXL or SD1.5 based on filename."""
if checkpoint in self._model_type_cache:
return self._model_type_cache[checkpoint]
        ckpt_lower = checkpoint.lower()
        if any(keyword in ckpt_lower for keyword in SDXL_KEYWORDS):
            model_type = "sdxl"
        elif any(keyword in ckpt_lower for keyword in SD15_KEYWORDS):
            model_type = "sd15"
        else:
            # Default to SDXL for unknown models (safer for modern checkpoints)
            model_type = "sdxl"
        self._model_type_cache[checkpoint] = model_type
        return model_type
def _build_workflow(
self,
prompt: str,
negative_prompt: str = "",
width: int = 1024,
height: int = 1024,
seed: Optional[int] = None,
steps: int = 25,
cfg_scale: float = 7.0,
checkpoint: Optional[str] = None,
sampler: str = "euler",
scheduler: str = "normal",
**kwargs
) -> dict:
"""Build a ComfyUI workflow for image generation.
Automatically selects SDXL or SD1.5 workflow based on checkpoint.
"""
if seed is None:
seed = random.randint(0, 2**32 - 1)
        # Use the explicit checkpoint, then the configured one, then the fallback
        ckpt = checkpoint or self.checkpoint or DEFAULT_CHECKPOINT
model_type = self._detect_model_type(ckpt)
print(f"[ComfyUI] Using {model_type.upper()} workflow for {ckpt}", file=sys.stderr)
if model_type == "sdxl":
return self._build_sdxl_workflow(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
seed=seed,
steps=steps,
cfg_scale=cfg_scale,
checkpoint=ckpt,
sampler=sampler,
scheduler=scheduler
)
else:
return self._build_sd15_workflow(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
seed=seed,
steps=steps,
cfg_scale=cfg_scale,
checkpoint=ckpt,
sampler=sampler,
scheduler=scheduler
)
def _build_sdxl_workflow(
self,
prompt: str,
negative_prompt: str = "",
width: int = 1024,
height: int = 1024,
seed: int = 0,
steps: int = 25,
cfg_scale: float = 7.0,
checkpoint: str = "sd_xl_base_1.0.safetensors",
sampler: str = "euler",
scheduler: str = "normal"
) -> dict:
"""Build SDXL-optimized workflow with proper settings."""
        # Clamp to a reasonable SDXL range and snap to multiples of 8
        # (the latent-space stride of the VAE)
        width = max(512, min(2048, (width // 8) * 8))
        height = max(512, min(2048, (height // 8) * 8))
        # Upgrade the generic defaults to settings that generally suit SDXL;
        # explicitly chosen samplers/schedulers pass through unchanged
        if sampler == "euler":
            sampler = "euler_ancestral"
        if scheduler == "normal":
            scheduler = "karras"
workflow = {
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": checkpoint
}
},
"5": {
"class_type": "EmptyLatentImage",
"inputs": {
"batch_size": 1,
"height": height,
"width": width
}
},
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": ["4", 1],
"text": prompt
}
},
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": ["4", 1],
"text": negative_prompt
}
},
"3": {
"class_type": "KSampler",
"inputs": {
"cfg": cfg_scale,
"denoise": 1.0,
"latent_image": ["5", 0],
"model": ["4", 0],
"negative": ["7", 0],
"positive": ["6", 0],
"sampler_name": sampler,
"scheduler": scheduler,
"seed": seed,
"steps": steps
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["3", 0],
"vae": ["4", 2]
}
},
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "mcp_sdxl",
"images": ["8", 0]
}
}
}
return workflow
def _build_sd15_workflow(
self,
prompt: str,
negative_prompt: str = "",
width: int = 512,
height: int = 512,
seed: int = 0,
steps: int = 20,
cfg_scale: float = 7.0,
checkpoint: str = "v1-5-pruned-emaonly.safetensors",
sampler: str = "euler",
scheduler: str = "normal"
) -> dict:
"""Build SD1.5 workflow with standard settings."""
        # SD1.5 was trained around 512x512; clamp to a nearby range and snap
        # to multiples of 8
        width = max(256, min(768, (width // 8) * 8))
        height = max(256, min(768, (height // 8) * 8))
workflow = {
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": checkpoint
}
},
"5": {
"class_type": "EmptyLatentImage",
"inputs": {
"batch_size": 1,
"height": height,
"width": width
}
},
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": ["4", 1],
"text": prompt
}
},
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": ["4", 1],
"text": negative_prompt
}
},
"3": {
"class_type": "KSampler",
"inputs": {
"cfg": cfg_scale,
"denoise": 1.0,
"latent_image": ["5", 0],
"model": ["4", 0],
"negative": ["7", 0],
"positive": ["6", 0],
"sampler_name": sampler,
"scheduler": scheduler,
"seed": seed,
"steps": steps
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["3", 0],
"vae": ["4", 2]
}
},
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "mcp_sd15",
"images": ["8", 0]
}
}
}
return workflow
async def _queue_prompt(self, workflow: dict) -> str:
"""Queue a prompt in ComfyUI and return the prompt ID."""
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
f"{self.base_url}/prompt",
json={
"prompt": workflow,
"client_id": self.client_id
}
)
response.raise_for_status()
return response.json()["prompt_id"]
async def _wait_for_completion(self, prompt_id: str, timeout: float = 600.0) -> dict:
"""Wait for the prompt to complete via WebSocket."""
async with websockets.connect(self.ws_url) as ws:
loop = asyncio.get_running_loop()
start_time = loop.time()
while True:
if loop.time() - start_time > timeout:
raise TimeoutError(f"Generation timed out after {timeout}s")
try:
                    message = await asyncio.wait_for(ws.recv(), timeout=5.0)
                    # ComfyUI also streams binary preview frames on this socket;
                    # only text frames carry the JSON status messages we parse.
                    if isinstance(message, bytes):
                        continue
                    data = json.loads(message)
if data.get("type") == "executing":
exec_data = data.get("data", {})
if exec_data.get("prompt_id") == prompt_id:
if exec_data.get("node") is None:
# Execution complete
break
elif data.get("type") == "execution_error":
error_data = data.get("data", {})
if error_data.get("prompt_id") == prompt_id:
raise RuntimeError(f"ComfyUI error: {error_data}")
except asyncio.TimeoutError:
continue
# Get the history to find the output
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.get(f"{self.base_url}/history/{prompt_id}")
response.raise_for_status()
return response.json()
async def _get_image(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes:
"""Download a generated image from ComfyUI."""
async with httpx.AsyncClient(timeout=60.0) as client:
params = {
"filename": filename,
"subfolder": subfolder,
"type": folder_type
}
response = await client.get(f"{self.base_url}/view", params=params)
response.raise_for_status()
return response.content
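    async def _extract_output_image(self, prompt_id: str, history: dict) -> bytes:
        """Download the first image saved by a completed prompt.
        Shared by all generate_* methods to avoid duplicating output parsing.
        """
        outputs = history.get(prompt_id, {}).get("outputs", {})
        for node_output in outputs.values():
            if "images" in node_output:
                image_info = node_output["images"][0]
                return await self._get_image(
                    filename=image_info["filename"],
                    subfolder=image_info.get("subfolder", ""),
                    folder_type=image_info.get("type", "output")
                )
        raise RuntimeError("No image output found in ComfyUI response")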
async def _upload_image(self, image_bytes: bytes, filename: str = "input.png") -> dict:
"""Upload an image to ComfyUI for img2img."""
async with httpx.AsyncClient(timeout=60.0) as client:
files = {"image": (filename, image_bytes, "image/png")}
data = {"overwrite": "true"}
response = await client.post(
f"{self.base_url}/upload/image",
files=files,
data=data
)
response.raise_for_status()
return response.json()
def _build_img2img_workflow(
self,
prompt: str,
negative_prompt: str = "",
image_filename: str = "input.png",
denoise: float = 0.35,
seed: Optional[int] = None,
steps: int = 25,
cfg_scale: float = 6.0,
checkpoint: Optional[str] = None,
sampler: str = "euler_ancestral",
scheduler: str = "karras"
) -> dict:
"""Build img2img workflow for character animation consistency."""
if seed is None:
seed = random.randint(0, 2**32 - 1)
        ckpt = checkpoint or self.checkpoint or DEFAULT_CHECKPOINT
print(f"[ComfyUI] Img2Img workflow: denoise={denoise}, seed={seed}", file=sys.stderr)
workflow = {
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": ckpt
}
},
"10": {
"class_type": "LoadImage",
"inputs": {
"image": image_filename
}
},
"11": {
"class_type": "VAEEncode",
"inputs": {
"pixels": ["10", 0],
"vae": ["4", 2]
}
},
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": ["4", 1],
"text": prompt
}
},
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": ["4", 1],
"text": negative_prompt
}
},
"3": {
"class_type": "KSampler",
"inputs": {
"cfg": cfg_scale,
"denoise": denoise,
"latent_image": ["11", 0],
"model": ["4", 0],
"negative": ["7", 0],
"positive": ["6", 0],
"sampler_name": sampler,
"scheduler": scheduler,
"seed": seed,
"steps": steps
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["3", 0],
"vae": ["4", 2]
}
},
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "mcp_img2img",
"images": ["8", 0]
}
}
}
return workflow
def _build_controlnet_workflow(
self,
prompt: str,
negative_prompt: str = "",
control_image_filename: str = "control.png",
controlnet_model: str = "diffusers_xl_depth_full.safetensors",
control_strength: float = 0.8,
width: int = 1024,
height: int = 1024,
seed: Optional[int] = None,
steps: int = 30,
cfg_scale: float = 6.0,
checkpoint: Optional[str] = None,
sampler: str = "euler_ancestral",
scheduler: str = "karras"
) -> dict:
"""Build ControlNet workflow for precise viewpoint/structure control.
Args:
control_image_filename: Uploaded control image (depth/canny/etc)
controlnet_model: ControlNet model name
control_strength: How strongly to follow control image (0.0-1.0)
"""
if seed is None:
seed = random.randint(0, 2**32 - 1)
        ckpt = checkpoint or self.checkpoint or DEFAULT_CHECKPOINT
# Ensure dimensions are multiples of 8
width = max(512, min(2048, (width // 8) * 8))
height = max(512, min(2048, (height // 8) * 8))
print(f"[ComfyUI] ControlNet workflow: model={controlnet_model}, strength={control_strength}", file=sys.stderr)
workflow = {
# Load checkpoint
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": ckpt
}
},
# Load ControlNet model
"12": {
"class_type": "ControlNetLoader",
"inputs": {
"control_net_name": controlnet_model
}
},
# Load control image
"13": {
"class_type": "LoadImage",
"inputs": {
"image": control_image_filename
}
},
# Empty latent
"5": {
"class_type": "EmptyLatentImage",
"inputs": {
"batch_size": 1,
"height": height,
"width": width
}
},
# Positive prompt
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": ["4", 1],
"text": prompt
}
},
# Negative prompt
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": ["4", 1],
"text": negative_prompt
}
},
# Apply ControlNet to positive conditioning
"14": {
"class_type": "ControlNetApply",
"inputs": {
"conditioning": ["6", 0],
"control_net": ["12", 0],
"image": ["13", 0],
"strength": control_strength
}
},
# Apply ControlNet to negative conditioning (helps lock viewpoint vs CFG pulling back to priors)
"15": {
"class_type": "ControlNetApply",
"inputs": {
"conditioning": ["7", 0],
"control_net": ["12", 0],
"image": ["13", 0],
"strength": control_strength
}
},
# KSampler
"3": {
"class_type": "KSampler",
"inputs": {
"cfg": cfg_scale,
"denoise": 1.0,
"latent_image": ["5", 0],
"model": ["4", 0],
"negative": ["15", 0],
"positive": ["14", 0], # Use ControlNet-applied conditioning
"sampler_name": sampler,
"scheduler": scheduler,
"seed": seed,
"steps": steps
}
},
# VAE Decode
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["3", 0],
"vae": ["4", 2]
}
},
# Save Image
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "mcp_controlnet",
"images": ["8", 0]
}
}
}
return workflow
async def generate_with_controlnet(
self,
prompt: str,
control_image: bytes,
controlnet_model: str = "diffusers_xl_depth_full.safetensors",
control_strength: float = 0.8,
negative_prompt: str = "",
width: int = 1024,
height: int = 1024,
seed: Optional[int] = None,
steps: int = 30,
cfg_scale: float = 6.0,
**kwargs
) -> bytes:
"""Generate an image with ControlNet guidance.
Args:
prompt: Text description
control_image: Control image (depth map, canny edges, etc.) as PNG bytes
controlnet_model: ControlNet model to use
control_strength: How strongly to follow control (0.0-1.0)
negative_prompt: Things to avoid
width: Output width
height: Output height
seed: Random seed
steps: Generation steps
cfg_scale: CFG scale
Returns:
Generated PNG image bytes
"""
# Upload control image
filename = f"ctrl_{uuid.uuid4().hex[:8]}.png"
upload_result = await self._upload_image(control_image, filename)
uploaded_name = upload_result.get("name") or upload_result.get("filename") or filename
uploaded_subfolder = upload_result.get("subfolder") or ""
load_image_name = uploaded_name
if uploaded_subfolder:
load_image_name = f"{uploaded_subfolder}/{uploaded_name}"
print(f"[ComfyUI] Uploaded control image: {load_image_name}", file=sys.stderr)
        # Fall back to server-side auto-detection when no checkpoint is configured
        if not self.checkpoint and "checkpoint" not in kwargs:
            kwargs["checkpoint"] = await self._get_checkpoint()
        workflow = self._build_controlnet_workflow(
prompt=prompt,
negative_prompt=negative_prompt,
control_image_filename=load_image_name,
controlnet_model=controlnet_model,
control_strength=control_strength,
width=width,
height=height,
seed=seed,
steps=steps,
cfg_scale=cfg_scale,
**kwargs
)
prompt_id = await self._queue_prompt(workflow)
history = await self._wait_for_completion(prompt_id)
        return await self._extract_output_image(prompt_id, history)
async def generate_image(
self,
prompt: str,
negative_prompt: str = "",
width: int = 512,
height: int = 512,
seed: Optional[int] = None,
steps: int = 20,
cfg_scale: float = 7.0,
**kwargs
) -> bytes:
"""Generate an image using ComfyUI."""
        # Fall back to server-side auto-detection when no checkpoint is configured
        if not self.checkpoint and "checkpoint" not in kwargs:
            kwargs["checkpoint"] = await self._get_checkpoint()
        workflow = self._build_workflow(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
seed=seed,
steps=steps,
cfg_scale=cfg_scale,
**kwargs
)
prompt_id = await self._queue_prompt(workflow)
history = await self._wait_for_completion(prompt_id)
        return await self._extract_output_image(prompt_id, history)
async def generate_img2img(
self,
reference_image: bytes,
prompt: str,
negative_prompt: str = "",
denoise: float = 0.35,
seed: Optional[int] = None,
steps: int = 25,
cfg_scale: float = 6.0,
**kwargs
) -> bytes:
"""Generate an image using img2img from a reference image."""
# Upload reference image
filename = f"ref_{uuid.uuid4().hex[:8]}.png"
upload_result = await self._upload_image(reference_image, filename)
uploaded_name = upload_result.get("name") or upload_result.get("filename") or filename
uploaded_subfolder = upload_result.get("subfolder") or ""
load_image_name = uploaded_name
if uploaded_subfolder:
load_image_name = f"{uploaded_subfolder}/{uploaded_name}"
print(f"[ComfyUI] Uploaded reference: {load_image_name}", file=sys.stderr)
        # Fall back to server-side auto-detection when no checkpoint is configured
        if not self.checkpoint and "checkpoint" not in kwargs:
            kwargs["checkpoint"] = await self._get_checkpoint()
        workflow = self._build_img2img_workflow(
prompt=prompt,
negative_prompt=negative_prompt,
image_filename=load_image_name,
denoise=denoise,
seed=seed,
steps=steps,
cfg_scale=cfg_scale,
**kwargs
)
prompt_id = await self._queue_prompt(workflow)
history = await self._wait_for_completion(prompt_id)
        return await self._extract_output_image(prompt_id, history)
async def health_check(self) -> bool:
"""Check if ComfyUI is running and accessible."""
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.base_url}/system_stats", timeout=5.0)
return response.status_code == 200
except Exception:
return False
def get_name(self) -> str:
return self.name
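if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming a local ComfyUI server is running
    # on the default 127.0.0.1:8188 with at least one checkpoint installed.
    # Because this module uses a relative import, run it from the parent
    # package (e.g. `python -m <package>.comfyui`; the module path is illustrative).
    async def _demo() -> None:
        backend = ComfyUIBackend()
        if not await backend.health_check():
            print("ComfyUI is not reachable", file=sys.stderr)
            return
        png = await backend.generate_image(
            prompt="a watercolor fox in an autumn forest",
            negative_prompt="blurry, low quality",
            width=1024,
            height=1024,
        )
        with open("demo_output.png", "wb") as f:
            f.write(png)
        print("Wrote demo_output.png", file=sys.stderr)
    asyncio.run(_demo())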