"""Workflow orchestration layer.
Splits the monolithic WorkflowManager into focused classes following SOLID principles:
- WorkflowLoader: File I/O and catalog management
- ParameterExtractor: Parse PARAM_ placeholders from workflows
- WorkflowRenderer: Parameter binding and workflow rendering
- WorkflowOrchestrator: Coordinates execution with route functions
"""
import asyncio
import copy
import json
import random
import time
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from src.orchestrators.defaults import DefaultsManager
try:
from fastmcp import Context
except ImportError:
Context = None # type: ignore
else:
try:
from fastmcp import Context
except ImportError:
Context = None
from src.auth.base import ComfyAuth
from src.models.workflow import WorkflowParameter, WorkflowToolDefinition
from src.routes.workflow import get_prompt_history, queue_workflow
from src.utils import get_global_logger
# Module-level logger shared by every class in this orchestration layer.
logger = get_global_logger("ComfyUI_MCP.orchestrators.workflow")
# Constants
# A workflow input value beginning with this prefix is treated as a
# user-facing parameter, e.g. "PARAM_PROMPT" or "PARAM_INT_STEPS".
PLACEHOLDER_PREFIX = "PARAM_"
# Optional type token that may follow the prefix (PARAM_<TYPE>_<NAME>);
# maps the token to the Python type used for coercion. Defaults to str.
PLACEHOLDER_TYPE_HINTS = {
    "STR": str,
    "STRING": str,
    "TEXT": str,
    "INT": int,
    "FLOAT": float,
    "BOOL": bool,
}
# Human-readable descriptions attached to well-known parameter names when
# building tool definitions; unknown names get a generic fallback.
PLACEHOLDER_DESCRIPTIONS = {
    "prompt": "Main text prompt used inside the workflow.",
    "seed": "Random seed for generation. If not provided, a random seed will be generated.",
    "width": "Image width in pixels. Default: 512.",
    "height": "Image height in pixels. Default: 512.",
    "model": "Checkpoint model name. Default: 'v1-5-pruned-emaonly.ckpt'.",
    "steps": "Number of sampling steps. Higher = better quality but slower. Default: 20.",
    "cfg": "Classifier-free guidance scale. Higher = more adherence to prompt. Default: 8.0.",
    "sampler_name": "Sampling method (e.g., 'euler', 'dpmpp_2m'). Default: 'euler'.",
    "scheduler": "Scheduler type (e.g., 'normal', 'karras'). Default: 'normal'.",
    "denoise": "Denoising strength (0.0-1.0). Default: 1.0.",
    "negative_prompt": "Negative prompt to avoid certain elements. Default: 'text, watermark'.",
    "tags": "Comma-separated descriptive tags for the audio model.",
    "lyrics": "Full lyric text that should drive the audio generation.",
    "seconds": "Audio duration in seconds. Default: 60.",
    "lyrics_strength": "How strongly lyrics influence audio generation (0.0-1.0). Default: 0.99.",
    "duration": "Video duration in seconds. Default: 5.",
    "fps": "Frames per second for video output. Default: 16.",
}
# Output-node keys searched (in order) when extracting the produced asset
# from a finished workflow's history entry, per media type.
DEFAULT_OUTPUT_KEYS = ("images", "image", "gifs", "gif")
AUDIO_OUTPUT_KEYS = ("audio", "audios", "sound", "files")
VIDEO_OUTPUT_KEYS = ("videos", "video", "mp4", "mov", "webm")
class ParameterExtractor:
"""Extracts parameters from workflow templates.
Single Responsibility: Parse PARAM_ placeholders and build parameter definitions.
"""
@staticmethod
def extract_parameters(workflow: dict[str, Any]) -> "OrderedDict[str, WorkflowParameter]":
"""Extract parameters from workflow JSON.
Scans workflow nodes for PARAM_ placeholders and builds parameter definitions.
Args:
workflow: Workflow JSON dictionary
Returns:
OrderedDict of parameter definitions
"""
parameters: OrderedDict[str, WorkflowParameter] = OrderedDict()
for node_id, node in workflow.items():
inputs = node.get("inputs", {})
if not isinstance(inputs, dict):
continue
for input_name, value in inputs.items():
parsed = ParameterExtractor._parse_placeholder(value)
if not parsed:
continue
param_name, annotation, placeholder_value = parsed
description = PLACEHOLDER_DESCRIPTIONS.get(param_name, f"Value for '{param_name}'.")
parameter = parameters.get(param_name)
if not parameter:
# Determine if parameter is required
optional_params = {
"seed",
"width",
"height",
"model",
"steps",
"cfg",
"sampler_name",
"scheduler",
"denoise",
"negative_prompt",
"seconds",
"lyrics_strength",
"duration",
"fps",
}
is_required = param_name not in optional_params
parameter = WorkflowParameter(
name=param_name,
placeholder=placeholder_value,
annotation=annotation,
description=description,
required=is_required,
)
parameters[param_name] = parameter
# Track binding to node input
parameter.bindings.append((node_id, input_name))
logger.debug(f"Extracted {len(parameters)} parameters from workflow")
return parameters
@staticmethod
def _parse_placeholder(value: Any) -> tuple[str, type, str] | None:
"""Parse PARAM_ placeholder to extract name and type hint.
Args:
value: Input value to check
Returns:
Tuple of (param_name, type, placeholder_value) or None
"""
if not isinstance(value, str) or not value.startswith(PLACEHOLDER_PREFIX):
return None
token = value[len(PLACEHOLDER_PREFIX) :]
annotation = str
# Check for type hint (e.g., PARAM_INT_STEPS)
if "_" in token:
type_candidate, remainder = token.split("_", 1)
type_hint = PLACEHOLDER_TYPE_HINTS.get(type_candidate.upper())
if type_hint:
annotation = type_hint
token = remainder
param_name = ParameterExtractor._normalize_name(token)
return param_name, annotation, value
@staticmethod
def _normalize_name(raw: str) -> str:
"""Normalize parameter name (lowercase, alphanumeric + underscore).
Args:
raw: Raw parameter name
Returns:
Normalized name
"""
cleaned = [(char.lower() if char.isalnum() else "_") for char in raw.strip()]
normalized = "".join(cleaned).strip("_")
return normalized or "param"
class WorkflowLoader:
"""Loads workflows from disk with caching.
Single Responsibility: File I/O and workflow catalog management.
"""
def __init__(self, workflows_dir: Path):
"""Initialize workflow loader.
Args:
workflows_dir: Directory containing workflow JSON files
"""
self.workflows_dir = Path(workflows_dir).resolve()
self._workflow_cache: dict[str, dict[str, Any]] = {}
self._tool_names: set[str] = set()
logger.info(f"Initialized WorkflowLoader with directory: {self.workflows_dir}")
def load_workflow(self, workflow_id: str) -> dict[str, Any] | None:
"""Load workflow by ID with caching.
Args:
workflow_id: Workflow identifier
Returns:
Workflow JSON dictionary or None if not found
"""
# Check cache
if workflow_id in self._workflow_cache:
logger.debug(f"Workflow cache hit for {workflow_id}")
return copy.deepcopy(self._workflow_cache[workflow_id])
# Load from disk
workflow_path = self._safe_workflow_path(workflow_id)
if not workflow_path:
logger.warning(f"Workflow {workflow_id} not found")
return None
try:
with open(workflow_path, encoding="utf-8") as f:
workflow = json.load(f)
# Cache it
self._workflow_cache[workflow_id] = workflow
logger.info(f"Loaded workflow {workflow_id} from disk")
return copy.deepcopy(workflow)
except (OSError, json.JSONDecodeError) as e:
logger.error(f"Failed to load workflow {workflow_id}: {e}")
return None
def get_workflow_catalog(self) -> list[dict[str, Any]]:
"""Get catalog of all available workflows.
Returns:
List of workflow metadata dictionaries
"""
catalog = []
if not self.workflows_dir.exists():
logger.warning(f"Workflow directory does not exist: {self.workflows_dir}")
return catalog
for workflow_path in sorted(self.workflows_dir.glob("*.json")):
# Skip metadata files
if workflow_path.name.endswith(".meta.json"):
continue
workflow_id = workflow_path.stem
try:
with open(workflow_path, encoding="utf-8") as f:
workflow = json.load(f)
except (OSError, json.JSONDecodeError) as e:
logger.warning(f"Skipping {workflow_path.name}: {e}")
continue
# Load metadata (sidecar file)
metadata = self._load_workflow_metadata(workflow_path)
# Extract parameters
parameters = ParameterExtractor.extract_parameters(workflow)
available_inputs = {
name: {
"type": param.annotation.__name__,
"required": param.required,
"description": param.description,
}
for name, param in parameters.items()
}
catalog.append(
{
"id": workflow_id,
"name": metadata.get("name", workflow_id.replace("_", " ").title()),
"description": metadata.get(
"description", f"Execute the '{workflow_id}' workflow."
),
"available_inputs": available_inputs,
"defaults": metadata.get("defaults", {}),
"updated_at": metadata.get("updated_at"),
"hash": metadata.get("hash"),
}
)
logger.debug(f"Generated catalog with {len(catalog)} workflows")
return catalog
def load_tool_definitions(self) -> list[WorkflowToolDefinition]:
"""Load all workflows as tool definitions.
Used for automatic tool registration.
Returns:
List of WorkflowToolDefinition objects
"""
definitions: list[WorkflowToolDefinition] = []
if not self.workflows_dir.exists():
logger.info(f"Workflow directory does not exist: {self.workflows_dir}")
return definitions
for workflow_path in sorted(self.workflows_dir.glob("*.json")):
# Skip metadata files
if workflow_path.name.endswith(".meta.json"):
continue
try:
with open(workflow_path, encoding="utf-8") as f:
workflow = json.load(f)
except (OSError, json.JSONDecodeError) as e:
logger.error(f"Skipping workflow {workflow_path.name}: {e}")
continue
# Extract parameters
parameters = ParameterExtractor.extract_parameters(workflow)
if not parameters:
logger.info(
f"Workflow {workflow_path.name} has no PARAM_ placeholders, "
f"skipping auto-tool registration"
)
continue
# Create tool definition
tool_name = self._dedupe_tool_name(self._derive_tool_name(workflow_path.stem))
definition = WorkflowToolDefinition(
workflow_id=workflow_path.stem,
tool_name=tool_name,
description=self._derive_description(workflow_path.stem),
template=workflow,
parameters=parameters,
output_preferences=self._guess_output_preferences(workflow),
)
logger.info(
f"Loaded tool definition '{tool_name}' with params: {list(parameters.keys())}"
)
definitions.append(definition)
return definitions
def _safe_workflow_path(self, workflow_id: str) -> Path | None:
"""Resolve workflow ID to file path with path traversal protection.
Args:
workflow_id: Workflow identifier
Returns:
Validated Path object or None if invalid/not found
"""
# Sanitize workflow_id (prevent path traversal)
safe_id = workflow_id.replace("/", "_").replace("\\", "_").replace("..", "_")
safe_id = "".join(c for c in safe_id if c.isalnum() or c in ("_", "-"))
if not safe_id:
logger.warning(f"Invalid workflow_id after sanitization: {workflow_id}")
return None
workflow_path = (self.workflows_dir / f"{safe_id}.json").resolve()
# Ensure resolved path is within workflows_dir (path traversal protection)
try:
workflow_path.relative_to(self.workflows_dir.resolve())
except ValueError:
logger.warning(f"Path traversal attempt detected: {workflow_id}")
return None
return workflow_path if workflow_path.exists() else None
def _load_workflow_metadata(self, workflow_path: Path) -> dict[str, Any]:
"""Load sidecar metadata file if it exists.
Args:
workflow_path: Path to workflow JSON file
Returns:
Metadata dictionary (empty if file doesn't exist)
"""
metadata_path = workflow_path.with_suffix(".meta.json")
if metadata_path.exists():
try:
with open(metadata_path, encoding="utf-8") as f:
return json.load(f)
except (OSError, json.JSONDecodeError) as e:
logger.warning(f"Failed to load metadata for {workflow_path.name}: {e}")
return {}
def _derive_tool_name(self, stem: str) -> str:
"""Derive tool name from workflow filename.
Args:
stem: Workflow filename stem
Returns:
Normalized tool name
"""
return ParameterExtractor._normalize_name(stem)
def _dedupe_tool_name(self, base_name: str) -> str:
"""Ensure tool name is unique by appending suffix if needed.
Args:
base_name: Base tool name
Returns:
Unique tool name
"""
name = base_name or "workflow_tool"
if name not in self._tool_names:
self._tool_names.add(name)
return name
suffix = 2
while f"{name}_{suffix}" in self._tool_names:
suffix += 1
deduped = f"{name}_{suffix}"
self._tool_names.add(deduped)
return deduped
def _derive_description(self, stem: str) -> str:
"""Derive tool description from workflow filename.
Args:
stem: Workflow filename stem
Returns:
Human-readable description
"""
readable = stem.replace("_", " ").replace("-", " ").strip()
readable = readable if readable else stem
return f"Execute the '{readable}' ComfyUI workflow."
def _guess_output_preferences(self, workflow: dict[str, Any]) -> tuple[str, ...]:
"""Guess preferred output keys based on workflow node types.
Args:
workflow: Workflow JSON dictionary
Returns:
Tuple of preferred output keys
"""
for node in workflow.values():
class_type = str(node.get("class_type", "")).lower()
if "audio" in class_type:
return AUDIO_OUTPUT_KEYS
if "video" in class_type or "savevideo" in class_type or "videocombine" in class_type:
return VIDEO_OUTPUT_KEYS
return DEFAULT_OUTPUT_KEYS
class WorkflowRenderer:
"""Renders workflow templates with parameter bindings.
Single Responsibility: Parameter binding, type coercion, seed generation.
"""
@staticmethod
def render_workflow(
definition: WorkflowToolDefinition,
provided_params: dict[str, Any],
defaults_provider: Any | None = None,
) -> dict[str, Any]:
"""Render workflow template with parameter bindings.
Args:
definition: WorkflowToolDefinition with template and parameters
provided_params: User-provided parameter values
defaults_provider: Optional defaults manager for value resolution
Returns:
Rendered workflow JSON with all parameters bound
Raises:
ValueError: If required parameters are missing
"""
workflow = copy.deepcopy(definition.template)
# Determine namespace for defaults (image, audio, video)
namespace = WorkflowRenderer._determine_namespace(definition.workflow_id)
logger.debug(
f"Rendering workflow {definition.workflow_id} with {len(provided_params)} params"
)
for param in definition.parameters.values():
# Check required parameters
if param.required and param.name not in provided_params:
raise ValueError(f"Missing required parameter '{param.name}'")
# Resolve parameter value with precedence:
# 1. Provided value
# 2. Generated value (for seed)
# 3. Default value (from defaults manager)
# 4. Skip if optional and no default
raw_value = provided_params.get(param.name)
if raw_value is None:
if param.name == "seed" and param.annotation is int:
# Generate random seed
raw_value = random.randint(0, 2**32 - 1) # nosec B311
elif defaults_provider:
# Get default from defaults manager
raw_value = defaults_provider.get_default(namespace, param.name, None)
if raw_value is not None:
logger.debug(f"Using default value for {param.name}: {raw_value}")
else:
# Skip optional parameter with no default
continue
else:
# No value available, skip
continue
# Coerce to correct type
coerced_value = WorkflowRenderer._coerce_value(raw_value, param.annotation)
# Bind to all node inputs
for node_id, input_name in param.bindings:
if node_id in workflow and "inputs" in workflow[node_id]:
workflow[node_id]["inputs"][input_name] = coerced_value
logger.debug(f"Workflow {definition.workflow_id} rendered successfully")
return workflow
@staticmethod
def _coerce_value(value: Any, annotation: type) -> Any:
"""Coerce value to specified type.
Args:
value: Raw value
annotation: Target type
Returns:
Coerced value
Raises:
ValueError: If coercion fails
"""
try:
if annotation is str:
return str(value)
if annotation is int:
return int(value)
if annotation is float:
return float(value)
if annotation is bool:
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.strip().lower() in {"1", "true", "yes", "y"}
return bool(value)
return value
except (ValueError, TypeError) as e:
raise ValueError(f"Cannot convert {value!r} to {annotation.__name__}: {e}") from e
@staticmethod
def _determine_namespace(workflow_id: str) -> str:
"""Determine namespace based on workflow ID.
Args:
workflow_id: Workflow identifier
Returns:
Namespace ("image", "audio", or "video")
"""
if "song" in workflow_id.lower() or "audio" in workflow_id.lower():
return "audio"
elif "video" in workflow_id.lower():
return "video"
else:
return "image" # Default fallback
class WorkflowOrchestrator:
    """Orchestrates workflow execution with polling and asset extraction.

    Single Responsibility: Execute workflows via routes, poll for completion,
    extract results. Coordinates WorkflowLoader, ParameterExtractor, and
    WorkflowRenderer.

    FastMCP v3 Best Practices:
    - Explicit dependency injection (no optional dependencies)
    - Context support for progress reporting and logging
    - Clear separation of concerns
    """

    def __init__(
        self,
        auth: ComfyAuth,
        workflow_loader: "WorkflowLoader",
        defaults_manager: "DefaultsManager",
    ):
        """Initialize workflow orchestrator.

        Args:
            auth: ComfyAuth instance for API access
            workflow_loader: WorkflowLoader for catalog management (required)
            defaults_manager: DefaultsManager for parameter defaults (required)
        """
        self.auth = auth
        self.workflow_loader = workflow_loader
        self.defaults_manager = defaults_manager
        logger.info("Initialized WorkflowOrchestrator")

    async def execute_and_wait(
        self,
        workflow: dict[str, Any],
        max_attempts: int = 300,
        poll_interval: float = 1.0,
        output_preferences: tuple[str, ...] | None = None,
        ctx: Any = None,  # FastMCP Context (use Any to avoid import issues)
    ) -> dict[str, Any]:
        """Execute workflow and wait for completion with polling.

        Args:
            workflow: Rendered workflow JSON
            max_attempts: Maximum polling attempts (default: 300 = 5 minutes)
            poll_interval: Seconds between polls (default: 1.0)
            output_preferences: Preferred output node keys for asset extraction
            ctx: Optional FastMCP Context for progress reporting and logging

        Returns:
            Result dictionary with asset info, prompt_id, comfy_history

        Raises:
            TimeoutError: If workflow doesn't complete within max_attempts
            RuntimeError: If workflow execution fails
        """
        # Use Context for structured logging if available; `and Context`
        # guards against the fastmcp import having failed at module load.
        if ctx and Context:
            await ctx.info("Queueing workflow for execution")
        else:
            logger.info("Queueing workflow for execution")
        start_time = time.time()
        # Queue workflow via route
        res = await queue_workflow(auth=self.auth, workflow=workflow)
        if not res.is_success:
            error_msg = f"Failed to queue workflow: HTTP {res.status}"
            if ctx and Context:
                await ctx.error(error_msg)
            else:
                logger.error(error_msg)
            raise RuntimeError(error_msg)
        # NOTE(review): assumes a successful queue response always carries
        # "prompt_id" — a 2xx response with an error payload would raise
        # KeyError here; confirm against the queue_workflow route contract.
        prompt_id = res.response["prompt_id"]
        if ctx and Context:
            await ctx.info(f"Workflow queued: {prompt_id}")
            await ctx.report_progress(0.0, "Workflow queued")
        else:
            logger.info(f"Workflow queued: {prompt_id}")
        # Poll for completion (sleep first: the job cannot finish instantly)
        iterations = 0
        while iterations < max_attempts:
            iterations += 1
            await asyncio.sleep(poll_interval)
            # Report progress milestones every 30 polls; progress is capped
            # at 0.9 so 1.0 is reserved for actual completion.
            if ctx and Context and iterations % 30 == 0:
                progress = min(0.9, iterations / max_attempts)
                elapsed = time.time() - start_time
                await ctx.report_progress(progress, f"Processing... ({elapsed:.0f}s)")
            elif iterations % 30 == 0:
                elapsed = time.time() - start_time
                logger.debug(f"Still polling... {elapsed:.0f}s elapsed")
            # Get history
            history_res = await get_prompt_history(auth=self.auth, prompt_id=prompt_id)
            if history_res.is_success and prompt_id in history_res.response:
                history_entry = history_res.response[prompt_id]
                # Check if completed.
                # NOTE(review): only the presence of "outputs" is checked; a
                # prompt that errors inside ComfyUI may never gain outputs
                # and would spin here until timeout — confirm whether the
                # history "status" field should also be inspected.
                if "outputs" in history_entry:
                    duration = time.time() - start_time
                    if ctx and Context:
                        await ctx.report_progress(1.0, "Complete")
                        await ctx.info(
                            f"Workflow completed in {duration:.2f}s after {iterations} polls"
                        )
                    else:
                        logger.info(
                            f"Workflow completed in {duration:.2f}s after {iterations} polls"
                        )
                    # Extract asset info
                    return self._extract_asset_info(
                        history_entry,
                        prompt_id,
                        workflow,
                        output_preferences or DEFAULT_OUTPUT_KEYS,
                    )
        # Timeout
        error_msg = f"Workflow timeout after {max_attempts} polls"
        if ctx and Context:
            await ctx.error(error_msg)
        else:
            logger.error(error_msg)
        raise TimeoutError(f"Workflow {prompt_id} did not complete within {max_attempts} polls")

    def _extract_asset_info(
        self,
        history_entry: dict[str, Any],
        prompt_id: str,
        submitted_workflow: dict[str, Any],
        output_preferences: tuple[str, ...],
    ) -> dict[str, Any]:
        """Extract asset information from workflow history.

        Args:
            history_entry: Workflow history entry from /history/{prompt_id}
            prompt_id: Prompt ID
            submitted_workflow: Original workflow submitted
            output_preferences: Preferred output node keys

        Returns:
            Asset info dictionary with filename, subfolder, folder_type, etc.

        Raises:
            RuntimeError: If no asset found in outputs
        """
        outputs = history_entry.get("outputs", {})
        # Try to find asset in preferred output keys. Key preference wins
        # over node order: all nodes are scanned for the first key before
        # the next key is tried.
        for output_key in output_preferences:
            for node_id, node_output in outputs.items():
                if output_key in node_output:
                    asset_list = node_output[output_key]
                    if asset_list and len(asset_list) > 0:
                        # Only the first asset is surfaced; additional files
                        # remain reachable through comfy_history.
                        asset = asset_list[0]
                        logger.debug(f"Found asset in node {node_id}, key {output_key}")
                        return {
                            "filename": asset.get("filename", ""),
                            "subfolder": asset.get("subfolder", ""),
                            "folder_type": asset.get("type", "output"),
                            "prompt_id": prompt_id,
                            "comfy_history": history_entry,
                            "submitted_workflow": submitted_workflow,
                        }
        # No asset found
        error_msg = f"No asset found in workflow outputs (checked keys: {output_preferences})"
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    # Delegating methods to WorkflowLoader
    @property
    def workflows_dir(self) -> Path:
        """Get workflows directory from loader."""
        return self.workflow_loader.workflows_dir

    def get_workflow_catalog(self) -> list[dict[str, Any]]:
        """Get workflow catalog from loader.

        Returns:
            List of workflow definitions
        """
        return self.workflow_loader.get_workflow_catalog()

    def load_workflow(self, workflow_id: str) -> dict[str, Any] | None:
        """Load workflow by ID.

        Args:
            workflow_id: Workflow identifier

        Returns:
            Workflow JSON or None if not found
        """
        return self.workflow_loader.load_workflow(workflow_id)

    def get_tool_definitions(self) -> list[WorkflowToolDefinition]:
        """Get tool definitions from workflow loader.

        Returns:
            List of workflow tool definitions
        """
        return self.workflow_loader.load_tool_definitions()

    async def execute_workflow(
        self,
        auth: ComfyAuth,
        workflow_id: str,
        parameters: dict[str, Any],
        ctx: Any = None,  # FastMCP Context (use Any to avoid import issues)
    ) -> dict[str, Any]:
        """High-level workflow execution with parameter rendering.

        Loads workflow, renders with parameters, executes, and waits for
        completion. Returns an error dict (rather than raising) when the
        workflow id is unknown.

        Args:
            auth: ComfyAuth instance
            workflow_id: Workflow identifier
            parameters: Parameter overrides
            ctx: Optional FastMCP Context for progress reporting and logging

        Returns:
            Result dictionary with asset info, or {"error": ...} if the
            workflow is not found
        """
        # NOTE(review): the `auth` argument is unused; self.auth is what
        # execute_and_wait uses — confirm whether callers rely on passing a
        # different auth here.
        if ctx and Context:
            await ctx.info(f"Loading workflow: {workflow_id}")
            await ctx.report_progress(0.1, "Loading workflow")
        # Load workflow
        workflow_template = self.workflow_loader.load_workflow(workflow_id)
        if not workflow_template:
            error_msg = f"Workflow '{workflow_id}' not found"
            if ctx and Context:
                await ctx.error(error_msg)
            return {"error": error_msg}
        if ctx and Context:
            await ctx.report_progress(0.2, "Extracting parameters")
        # Extract parameters
        workflow_params = ParameterExtractor.extract_parameters(workflow_template)
        # Create tool definition (ad hoc; tool_name is just the workflow id
        # here since this path does not register a tool)
        definition = WorkflowToolDefinition(
            workflow_id=workflow_id,
            tool_name=workflow_id,
            description=workflow_template.get("_meta", {}).get("description", ""),
            template=workflow_template,
            parameters=workflow_params,
            output_preferences=self.workflow_loader._guess_output_preferences(workflow_template),
        )
        if ctx and Context:
            await ctx.report_progress(0.3, "Rendering workflow")
        # Render workflow with parameters using instance defaults_manager
        rendered = WorkflowRenderer.render_workflow(
            definition=definition,
            provided_params=parameters,
            defaults_provider=self.defaults_manager,
        )
        if ctx and Context:
            await ctx.report_progress(0.4, "Executing workflow")
        # Execute and wait with context passthrough
        result = await self.execute_and_wait(
            workflow=rendered,
            max_attempts=300,
            poll_interval=1.0,
            output_preferences=definition.output_preferences,
            ctx=ctx,
        )
        return result