"""Workflow generation tools (auto-registered from workflow files)
Thin MCP wrapper around WorkflowOrchestrator for dynamic tool registration.
Handles tool generation, parameter coercion, and regeneration.
"""
import copy
import inspect
import random
from typing import Any
from fastmcp import Context, FastMCP
from src.auth.base import ComfyAuth
from src.models.workflow import WorkflowToolDefinition
from src.orchestrators.asset import AssetOrchestrator
from src.orchestrators.defaults import DefaultsManager
from src.orchestrators.output import OutputManager
from src.orchestrators.workflow import WorkflowOrchestrator
from src.tools.helpers import register_and_build_response
from src.utils import get_global_logger
from src.utils.context import generate_correlation_id, set_correlation_id
logger = get_global_logger("MCP_Server.tools.generation")
def register_workflow_generation_tools(
    mcp: FastMCP,
    auth: ComfyAuth,
    workflow_orchestrator: WorkflowOrchestrator,
    defaults_manager: DefaultsManager,
    asset_orchestrator: AssetOrchestrator,
    output_manager: OutputManager,
):
    """Register workflow-backed generation tools (e.g., generate_image, generate_song)

    Each workflow tool definition discovered by the orchestrator becomes a
    dynamically-signed MCP tool: the tool's signature mirrors the workflow's
    parameters, and numeric parameters are exposed as ``Any`` so that JSON-RPC
    string values can be coerced here instead of being rejected by strict
    Pydantic validation.

    Args:
        mcp: FastMCP server instance
        auth: Authentication for ComfyUI API calls
        workflow_orchestrator: Orchestrator for workflow operations
        defaults_manager: Orchestrator for defaults management
        asset_orchestrator: Orchestrator for asset operations
        output_manager: Orchestrator for auto-saving workflow outputs
    """

    def _register_workflow_tool(definition: WorkflowToolDefinition):
        """Register a single workflow tool from its definition"""

        async def _tool_impl(*args, ctx: Context | None = None, **kwargs):
            # Generate and set correlation ID for request tracing
            correlation_id = generate_correlation_id()
            set_correlation_id(correlation_id)
            logger.info(f"[{correlation_id}] Workflow generation started: {definition.workflow_id}")
            if ctx:
                await ctx.info(f"Starting workflow: {definition.workflow_id}")
                await ctx.report_progress(0.0, "Initializing workflow")
            # Extract return_inline_preview and web_optimize (not workflow parameters)
            return_inline_preview = kwargs.pop("return_inline_preview", False)
            web_optimize = kwargs.pop("web_optimize", True)
            # Session tracking can be added via request context in the future
            session_id = None
            # Coerce parameter types before signature binding
            # MCP/JSON-RPC may pass numbers as strings, so we need to convert them
            coerced_kwargs = {}
            param_dict = {p.name: p for p in definition.parameters.values()}
            for key, value in kwargs.items():
                if key in param_dict:
                    param = param_dict[key]
                    # Coerce to correct type if needed
                    if value is not None:
                        try:
                            # Handle string representations of numbers
                            if param.annotation is int:
                                # Strip an optional sign before the digit test so
                                # signed strings like "-1" (seed sentinel) coerce
                                # too; plain .isdigit() rejects them.
                                if isinstance(value, str) and value.strip().lstrip("+-").isdigit():
                                    coerced_kwargs[key] = int(value)
                                elif isinstance(value, int | float):
                                    coerced_kwargs[key] = int(value)
                                else:
                                    coerced_kwargs[key] = value
                            elif param.annotation is float:
                                if isinstance(value, str):
                                    coerced_kwargs[key] = float(value)
                                elif isinstance(value, int | float):
                                    coerced_kwargs[key] = float(value)
                                else:
                                    coerced_kwargs[key] = value
                            else:
                                coerced_kwargs[key] = value
                        except (ValueError, TypeError) as e:
                            # If coercion fails, use original value and let validation handle it
                            logger.warning(
                                f"Failed to coerce {key}={value!r} to {param.annotation.__name__}: {e}"
                            )
                            coerced_kwargs[key] = value
                    else:
                        coerced_kwargs[key] = None
                else:
                    # Unknown parameter, pass through
                    coerced_kwargs[key] = value
            # Bind against the dynamically-built signature (set below on this
            # function object) so defaults for omitted parameters are applied.
            bound = _tool_impl.__signature__.bind(*args, ctx=ctx, **coerced_kwargs)
            bound.apply_defaults()
            try:
                if ctx:
                    await ctx.report_progress(0.1, "Validating parameters")
                # Execute workflow using orchestrator (with polling and context support)
                result = await workflow_orchestrator.execute_workflow(
                    auth=auth,
                    workflow_id=definition.workflow_id,
                    parameters=dict(bound.arguments),
                    ctx=ctx,  # Pass context for progress reporting
                )
                if ctx:
                    await ctx.report_progress(0.9, "Processing results")
                if "error" in result:
                    return result
                # Track saved path for workflow snapshot timestamp correlation
                saved_path = None
                # Auto-save workflow output (non-blocking: a save failure never
                # fails the tool call, since generation itself succeeded)
                try:
                    success, saved_path, error = await output_manager.save_workflow_output(
                        asset_info=result,
                        workflow_id=definition.workflow_id,
                        web_optimize=web_optimize,
                    )
                    if success:
                        logger.info(f"Auto-saved output to {saved_path}")
                        if ctx:
                            await ctx.debug(f"Auto-saved output to {saved_path}")
                    else:
                        logger.error(f"Auto-save failed (workflow succeeded): {error}")
                        if ctx:
                            await ctx.warning(f"Auto-save failed: {error}")
                except Exception as e:
                    logger.error(f"Auto-save exception (workflow succeeded): {e}", exc_info=True)
                    if ctx:
                        await ctx.warning(f"Auto-save exception: {e}")
                if ctx:
                    await ctx.report_progress(1.0, "Complete")
                    await ctx.info(f"Workflow {definition.workflow_id} completed successfully")
                # Register asset and build response
                return register_and_build_response(
                    result,
                    definition.workflow_id,
                    asset_orchestrator,
                    tool_name=definition.tool_name,
                    return_inline_preview=return_inline_preview,
                    session_id=session_id,
                )
            except Exception as exc:
                logger.exception("Workflow '%s' failed", definition.workflow_id)
                if ctx:
                    await ctx.error(f"Workflow failed: {exc}")
                return {"error": str(exc)}

        # Build function signature from workflow definition
        # Separate required and optional parameters to ensure correct ordering
        required_params = []
        optional_params = []
        annotations: dict[str, Any] = {}
        for param in definition.parameters.values():
            # For numeric types, use Any to allow string coercion from JSON-RPC
            # FastMCP/Pydantic validation is strict, so we accept Any and validate/coerce ourselves
            if param.annotation in (int, float):
                annotation_type = Any
            else:
                annotation_type = param.annotation
            if param.required:
                parameter = inspect.Parameter(
                    name=param.name,
                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    annotation=annotation_type,
                )
                required_params.append(parameter)
            else:
                # Optional parameter with default value
                if param.annotation in (int, float):
                    final_annotation = Any
                else:
                    final_annotation = annotation_type | None
                parameter = inspect.Parameter(
                    name=param.name,
                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    annotation=final_annotation,
                    default=None,
                )
                optional_params.append(parameter)
            # __annotations__ keeps the real (pre-Any) type for introspection
            annotations[param.name] = param.annotation
        # Add Context parameter (FastMCP v3 injects this)
        optional_params.append(
            inspect.Parameter(
                name="ctx",
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=Context | None,
                default=None,
            )
        )
        annotations["ctx"] = Context | None
        # Add return_inline_preview and web_optimize as optional parameters
        optional_params.append(
            inspect.Parameter(
                name="return_inline_preview",
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=bool,
                default=False,
            )
        )
        annotations["return_inline_preview"] = bool
        optional_params.append(
            inspect.Parameter(
                name="web_optimize",
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=bool,
                default=True,
            )
        )
        annotations["web_optimize"] = bool
        # Combine: required parameters first, then optional
        parameters = required_params + optional_params
        annotations["return"] = dict
        _tool_impl.__signature__ = inspect.Signature(parameters, return_annotation=dict)
        _tool_impl.__annotations__ = annotations
        _tool_impl.__name__ = f"tool_{definition.tool_name}"
        _tool_impl.__doc__ = definition.description
        # Register with MCP
        mcp.tool(name=definition.tool_name, description=definition.description)(_tool_impl)
        logger.info(
            "Registered MCP tool '%s' for workflow '%s'",
            definition.tool_name,
            definition.workflow_id,
        )

    # Register all workflow-backed tools
    tool_definitions = workflow_orchestrator.get_tool_definitions()
    if tool_definitions:
        for tool_definition in tool_definitions:
            _register_workflow_tool(tool_definition)
    else:
        logger.info(
            "No workflow placeholders found in %s; add %s markers to enable auto tools",
            workflow_orchestrator.workflows_dir,
            "PARAM_",
        )
def _update_workflow_params(workflow: dict, param_overrides: dict) -> dict:
"""Update workflow node inputs with parameter overrides.
Searches through all nodes to find inputs that match parameter names
and updates them with override values.
Args:
workflow: The workflow dict
param_overrides: Dict of parameter overrides
Returns:
Updated workflow
"""
# Map parameter names to node search patterns
param_mappings = {
"prompt": {"class_type": "CLIPTextEncode", "input_key": "text", "is_negative": False},
"negative_prompt": {
"class_type": "CLIPTextEncode",
"input_key": "text",
"is_negative": True,
},
"steps": {"class_type": "KSampler", "input_key": "steps"},
"cfg": {"class_type": "KSampler", "input_key": "cfg"},
"sampler_name": {"class_type": "KSampler", "input_key": "sampler_name"},
"scheduler": {"class_type": "KSampler", "input_key": "scheduler"},
"denoise": {"class_type": "KSampler", "input_key": "denoise"},
"width": {"class_type": "EmptyLatentImage", "input_key": "width"},
"height": {"class_type": "EmptyLatentImage", "input_key": "height"},
"model": {"class_type": "CheckpointLoaderSimple", "input_key": "ckpt_name"},
# Audio-specific (adjust based on actual node types in workflows)
"tags": {"class_type": None, "input_key": "tags"},
"lyrics": {"class_type": None, "input_key": "lyrics"},
"seconds": {"class_type": None, "input_key": "seconds"},
"lyrics_strength": {"class_type": None, "input_key": "lyrics_strength"},
}
for param_name, override_value in param_overrides.items():
if param_name not in param_mappings:
# Log warning but continue - maybe it's a valid but unknown param
logger.warning(f"Unknown parameter '{param_name}' in regenerate, skipping")
continue
mapping = param_mappings[param_name]
target_class = mapping.get("class_type")
target_input = mapping["input_key"]
is_negative = mapping.get("is_negative", False)
# Search workflow for matching nodes
updated = False
for node_id, node_data in workflow.items():
if not isinstance(node_data, dict):
continue
# Match by class_type if specified
if target_class and node_data.get("class_type") != target_class:
continue
# Check if this node has the target input
inputs = node_data.get("inputs", {})
if target_input not in inputs:
continue
# Special handling for negative prompt
if param_name == "negative_prompt" and is_negative:
if "negative" in str(node_data).lower() or "neg" in str(node_id).lower():
inputs[target_input] = override_value
updated = True
elif param_name == "prompt" and not is_negative:
if "negative" not in str(node_data).lower() and "neg" not in str(node_id).lower():
inputs[target_input] = override_value
updated = True
else:
# Direct parameter update
inputs[target_input] = override_value
updated = True
if not updated:
logger.warning(f"Could not find node to update parameter '{param_name}' in workflow")
return workflow
def _update_seed(workflow: dict, seed: int | None) -> dict:
"""Update the seed in KSampler nodes.
Args:
workflow: The workflow dict
seed: New seed value, or None to generate random, or -1 to keep original
Returns:
Updated workflow
"""
if seed == -1:
# Keep original seed - no changes needed
return workflow
# Generate random seed if not specified
if seed is None:
seed = random.randint(0, 0xFFFFFFFFFFFFFFFF) # nosec B311
# Find and update all KSampler nodes
for _node_id, node_data in workflow.items():
if not isinstance(node_data, dict):
continue
if node_data.get("class_type") == "KSampler":
inputs = node_data.get("inputs", {})
inputs["seed"] = seed
return workflow
def register_regenerate_tool(mcp: FastMCP, auth: ComfyAuth, asset_orchestrator: AssetOrchestrator):
    """Register the regenerate tool for iterating on existing assets.

    Args:
        mcp: FastMCP server instance
        auth: Authentication for ComfyUI API calls
        asset_orchestrator: Orchestrator for asset operations
    """

    def _pick_output_file(outputs: dict, preferences: tuple | None) -> dict | None:
        """Pick the first output file dict from a ComfyUI history outputs map.

        Preferred keys (e.g. "images", "audio") are tried first, per node and in
        order. If no preferred key yields a file — or no preferences are known
        for this workflow type — fall back to the first non-empty list value
        found, so unknown workflow kinds resolve instead of timing out.
        """
        for node_outputs in outputs.values():
            if not isinstance(node_outputs, dict):
                continue
            if preferences:
                for pref in preferences:
                    files = node_outputs.get(pref)
                    if files:
                        return files[0]
        # Fallback: any non-empty list of outputs from any node
        for node_outputs in outputs.values():
            if not isinstance(node_outputs, dict):
                continue
            for files in node_outputs.values():
                if isinstance(files, list) and files:
                    return files[0]
        return None

    @mcp.tool()
    async def regenerate(
        asset_id: str,
        seed: int | None = None,
        return_inline_preview: bool = False,
        param_overrides: dict[str, Any] | None = None,
    ) -> dict:
        """Regenerate an existing asset with optional parameter overrides.
        Note: This function generates a new correlation ID for the regeneration request,
        but the original correlation ID is preserved in the parent asset metadata.
        Retrieves the original workflow and parameters from the asset's provenance
        data, applies any overrides, and re-submits to ComfyUI.
        Args:
            asset_id: ID of the asset to regenerate
            seed: New random seed (None = generate new random seed, -1 = use original seed)
            return_inline_preview: If True, include a small thumbnail base64 in response
            param_overrides: Dict of workflow parameters to override (e.g., {"steps": 30, "cfg": 8.0, "prompt": "new prompt"})
        Returns:
            dict: New asset information with same structure as generate_* tools
        Examples:
            # Regenerate with different seed
            regenerate(asset_id="abc123")
            # Regenerate with higher quality settings
            regenerate(asset_id="abc123", param_overrides={"steps": 30, "cfg": 10.0})
            # Modify the prompt
            regenerate(asset_id="abc123", param_overrides={"prompt": "a beautiful sunset, oil painting style"})
            # Use exact same parameters (deterministic)
            regenerate(asset_id="abc123", seed=-1)
        """
        try:
            # Step 1: Retrieve original asset metadata
            asset = asset_orchestrator.get_asset_record(asset_id)
            if not asset:
                return {
                    "error": f"Asset {asset_id} not found (registry is in-memory and resets on restart). "
                    "Generate a new asset to regenerate."
                }
            # Extract the stored workflow
            original_workflow = asset.submitted_workflow
            if not original_workflow:
                return {"error": "No workflow data stored for this asset. Cannot regenerate."}
            # Step 2: Deep copy workflow to avoid mutating the stored one
            workflow = copy.deepcopy(original_workflow)
            # Step 3: Apply parameter overrides
            if param_overrides:
                workflow = _update_workflow_params(workflow, param_overrides)
            # Step 4: Update seed
            workflow = _update_seed(workflow, seed)
            # Step 5: Determine output preferences from original workflow.
            # None means "unknown kind" — _pick_output_file will then accept
            # whatever output list the workflow produces.
            output_preferences = None
            if asset.workflow_id:
                if "image" in asset.workflow_id.lower():
                    output_preferences = ("images", "image", "gifs", "gif")
                elif "audio" in asset.workflow_id.lower() or "song" in asset.workflow_id.lower():
                    output_preferences = ("audio", "audios", "sound", "files")
                elif "video" in asset.workflow_id.lower():
                    output_preferences = ("videos", "video", "mp4", "mov", "webm")
            # Step 6: Submit to ComfyUI using routes (imported lazily to avoid
            # import cycles at module load time)
            from src.routes.queue import get_prompt_history
            from src.routes.workflow import queue_workflow

            queue_result = await queue_workflow(auth=auth, workflow=workflow)
            if not queue_result.success:
                return {"error": queue_result.error or "Failed to queue workflow"}
            prompt_id = queue_result.data.get("prompt_id")
            if not prompt_id:
                return {"error": "No prompt_id returned from workflow submission"}
            # Step 7: Poll for completion (simplified, could use workflow orchestrator polling logic)
            import asyncio

            max_wait = 300  # 5 minutes timeout
            poll_interval = 2  # 2 seconds between polls
            elapsed = 0
            while elapsed < max_wait:
                await asyncio.sleep(poll_interval)
                elapsed += poll_interval
                history_result = await get_prompt_history(auth=auth, prompt_id=prompt_id)
                if not history_result.success:
                    # Transient history-fetch failure: keep polling
                    continue
                history = history_result.data
                if prompt_id in history:
                    prompt_data = history[prompt_id]
                    # Check for errors
                    if "error" in prompt_data:
                        return {"error": f"Workflow execution failed: {prompt_data['error']}"}
                    # Check for outputs
                    if "outputs" in prompt_data and prompt_data["outputs"]:
                        output_file = _pick_output_file(prompt_data["outputs"], output_preferences)
                        if output_file is not None:
                            result = {
                                "output": output_file,
                                "prompt_id": prompt_id,
                                "workflow": workflow,
                                "history": prompt_data,
                            }
                            # Register and return
                            return register_and_build_response(
                                result,
                                asset.workflow_id,
                                asset_orchestrator,
                                tool_name="regenerate",
                                return_inline_preview=return_inline_preview,
                                session_id=asset.session_id,
                            )
            return {"error": f"Workflow execution timed out after {max_wait} seconds"}
        except Exception as e:
            logger.exception(f"Failed to regenerate asset {asset_id}")
            return {"error": f"Failed to regenerate: {str(e)}"}