"""Workflow parsing and introspection utilities."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence

@dataclass(slots=True)
class WorkflowNode:
    """Representation of a single ComfyUI node."""

    id: int
    class_type: str
    raw: Mapping[str, Any]

    @property
    def inputs(self) -> Mapping[str, Any]:
        return self.raw.get("inputs", {})

    @property
    def metadata(self) -> Mapping[str, Any]:
        # Accept either spelling, preferring the short key when it is
        # present and non-empty.
        return self.raw.get("meta", {}) or self.raw.get("metadata", {})


@dataclass(slots=True)
class WorkflowLink:
    id: int
    from_node: int
    from_slot: int
    to_node: int
    to_slot: int
    raw: Mapping[str, Any]


@dataclass(slots=True)
class SemanticSummary:
    prompts: Dict[str, int] = field(default_factory=dict)
    checkpoint_loaders: List[int] = field(default_factory=list)
    lora_loaders: List[int] = field(default_factory=list)
    vae_loaders: List[int] = field(default_factory=list)
    samplers: List[int] = field(default_factory=list)
    output_nodes: List[int] = field(default_factory=list)
    preview_nodes: List[int] = field(default_factory=list)
    save_nodes: List[int] = field(default_factory=list)
    text_encoders: List[int] = field(default_factory=list)
    resolution_controllers: List[int] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


@dataclass(slots=True)
class WorkflowGraph:
    """Internal representation of a workflow graph."""

    nodes_by_id: Dict[int, WorkflowNode]
    nodes_by_class: Dict[str, List[WorkflowNode]]
    incoming_links: Dict[int, List[WorkflowLink]]
    outgoing_links: Dict[int, List[WorkflowLink]]
    semantic_summary: SemanticSummary

    @classmethod
    def from_payload(cls, payload: Mapping[str, Any]) -> "WorkflowGraph":
        validate_workflow(payload)
        nodes_by_id: Dict[int, WorkflowNode] = {}
        nodes_by_class: Dict[str, List[WorkflowNode]] = {}
        incoming_links: Dict[int, List[WorkflowLink]] = {}
        outgoing_links: Dict[int, List[WorkflowLink]] = {}
        for node_data in payload.get("nodes", []):
            node = WorkflowNode(
                id=int(node_data["id"]),
                # Fall back to "type" when "class_type" is absent, and to an
                # empty string rather than the literal "None".
                class_type=str(node_data.get("class_type") or node_data.get("type") or ""),
                raw=node_data,
            )
            nodes_by_id[node.id] = node
            nodes_by_class.setdefault(node.class_type, []).append(node)
        for link_data in payload.get("links", []):
            link = WorkflowLink(
                id=int(link_data["id"]),
                from_node=int(link_data["from_node"]),
                from_slot=int(link_data["from_slot"]),
                to_node=int(link_data["to_node"]),
                to_slot=int(link_data["to_slot"]),
                raw=link_data,
            )
            outgoing_links.setdefault(link.from_node, []).append(link)
            incoming_links.setdefault(link.to_node, []).append(link)
        summary = build_semantic_summary(nodes_by_id, incoming_links, outgoing_links)
        return cls(nodes_by_id, nodes_by_class, incoming_links, outgoing_links, summary)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "nodes": [node.raw for node in self.nodes_by_id.values()],
            "links": [
                link.raw
                for node_links in self.outgoing_links.values()
                for link in node_links
            ],
        }

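# Illustrative only: a minimal payload in the mapping-style link format this
# module expects (real ComfyUI exports may encode links as positional arrays,
# which this parser does not handle). Node ids, class names, and slot numbers
# below are invented for the example.
#
#     payload = {
#         "nodes": [
#             {"id": 1, "class_type": "CheckpointLoaderSimple",
#              "inputs": {"ckpt_name": "example.safetensors"}},
#             {"id": 2, "class_type": "KSampler", "inputs": {"steps": 20}},
#         ],
#         "links": [
#             {"id": 1, "from_node": 1, "from_slot": 0, "to_node": 2, "to_slot": 0},
#         ],
#     }
#     graph = WorkflowGraph.from_payload(payload)
#     graph.semantic_summary.checkpoint_loaders  # -> [1]
#     graph.semantic_summary.samplers            # -> [2]
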

@dataclass(slots=True)
class WorkflowTemplate:
    """Container for a workflow template loaded from disk."""

    name: str
    path: Path
    payload: Mapping[str, Any]
    graph: WorkflowGraph
    description: Optional[str] = None
    tags: List[str] = field(default_factory=list)

    @classmethod
    def from_file(cls, path: Path) -> "WorkflowTemplate":
        payload = json.loads(path.read_text(encoding="utf-8"))
        graph = WorkflowGraph.from_payload(payload)
        metadata = (payload.get("extra") or {}).get("metadata") or {}
        name = metadata.get("name") or path.stem
        description = metadata.get("description")
        raw_tags = metadata.get("tags")
        # A bare string is iterable, so an Iterable check would split it into
        # characters; accept only real sequences of tags.
        tags = [str(tag) for tag in raw_tags] if isinstance(raw_tags, (list, tuple)) else []
        return cls(
            name=name,
            path=path,
            payload=payload,
            graph=graph,
            description=description,
            tags=tags,
        )

    def summary(self) -> Dict[str, Any]:
        return {
            "name": self.name,
            "description": self.description,
            "tags": self.tags,
            "nodes": len(self.graph.nodes_by_id),
            "semantic_summary": self.graph.semantic_summary.to_dict(),
        }

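# Typical use (path is hypothetical):
#
#     template = WorkflowTemplate.from_file(Path("workflows/txt2img.json"))
#     template.summary()["nodes"]  # node count for a quick sanity check
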

class WorkflowDiscovery:
    """Discovery and caching for workflow templates."""

    def __init__(self, directory: Path) -> None:
        self.directory = directory
        self._templates: Dict[str, WorkflowTemplate] = {}

    def load(self) -> None:
        self._templates.clear()
        if not self.directory.exists():
            return
        for path in sorted(self.directory.glob("*.json")):
            try:
                template = WorkflowTemplate.from_file(path)
            except Exception as exc:  # pragma: no cover - defensive
                # Skip invalid workflow files while logging enough context
                # for debugging.
                logger.warning("Failed to load workflow %s: %s", path, exc)
                continue
            self._templates[template.name] = template

    def refresh(self) -> None:
        self.load()

    def list_templates(self) -> List[WorkflowTemplate]:
        if not self._templates:
            self.load()
        return sorted(self._templates.values(), key=lambda template: template.name)

    def get(self, name: str) -> Optional[WorkflowTemplate]:
        if not self._templates:
            self.load()
        return self._templates.get(name)

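# Sketch of the discovery flow, assuming a local "workflows/" directory of
# *.json templates (the directory name is an example, not a convention this
# module enforces):
#
#     discovery = WorkflowDiscovery(Path("workflows"))
#     names = [template.name for template in discovery.list_templates()]
#     template = discovery.get(names[0]) if names else None
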

def build_semantic_summary(
    nodes_by_id: Mapping[int, WorkflowNode],
    incoming: Mapping[int, List[WorkflowLink]],
    outgoing: Mapping[int, List[WorkflowLink]],
) -> SemanticSummary:
    summary = SemanticSummary()
    for node in nodes_by_id.values():
        class_type = normalize_class_name(node.class_type)
        inputs = node.inputs
        if is_checkpoint_loader(class_type, inputs):
            summary.checkpoint_loaders.append(node.id)
        if is_lora_loader(class_type, inputs):
            summary.lora_loaders.append(node.id)
        if is_vae_loader(class_type, inputs):
            summary.vae_loaders.append(node.id)
        if is_sampler_node(class_type):
            summary.samplers.append(node.id)
        if is_text_encoder_node(class_type, inputs):
            summary.text_encoders.append(node.id)
            # Only the first node seen for each prompt role is recorded.
            label = infer_prompt_role(node, nodes_by_id, incoming, outgoing)
            summary.prompts.setdefault(label, node.id)
        if is_resolution_controller(class_type, inputs):
            summary.resolution_controllers.append(node.id)
        if is_output_node(class_type):
            summary.output_nodes.append(node.id)
        if is_preview_output(class_type):
            summary.preview_nodes.append(node.id)
        if is_save_output(class_type):
            summary.save_nodes.append(node.id)
    return summary


def infer_prompt_role(
    node: WorkflowNode,
    nodes_by_id: Mapping[int, WorkflowNode],
    incoming: Mapping[int, List[WorkflowLink]],
    outgoing: Mapping[int, List[WorkflowLink]],
) -> str:
    """Infer whether a prompt encoder is positive or negative."""
    for link in outgoing.get(node.id, []):
        target_slot = link.to_slot
        target_node = nodes_by_id.get(link.to_node)
        # Samplers conventionally take positive conditioning on slot 1 and
        # negative conditioning on slot 2; a link into slot 0 of a sampler
        # is also treated as positive here.
        if target_slot == 1 or (
            target_node
            and is_sampler_node(normalize_class_name(target_node.class_type))
            and target_slot == 0
        ):
            return "prompt_positive"
        if target_slot == 2:
            return "prompt_negative"
    # Fall back to inspecting the prompt text itself.
    text = node.inputs.get("text")
    if isinstance(text, str) and any(term in text.lower() for term in ("negative", "undesired")):
        return "prompt_negative"
    return "prompt_positive"


def normalize_class_name(name: str) -> str:
    return (name or "").replace(" ", "").lower()


def _matches_keywords(value: str, keywords: Sequence[str]) -> bool:
    return any(keyword in value for keyword in keywords)


CHECKPOINT_KEYWORDS = ("checkpointloader", "loadcheckpoint", "checkpointfix", "modelloader")
LORA_KEYWORDS = ("loraloader", "applylora", "lora")
VAE_KEYWORDS = ("vaeloader", "loadvae", "vae")
SAMPLER_KEYWORDS = ("ksampler", "sampler", "samplernode")
TEXT_ENCODER_KEYWORDS = ("cliptextencode", "sdxlclipencode", "textencode", "promptencoder")
PREVIEW_KEYWORDS = ("preview", "display", "viewimage", "showimage")
SAVE_KEYWORDS = ("saveimage", "imagesaver", "imageoutput", "saveoutput")
RESOLUTION_KEYWORDS = ("emptylatentimage", "latent", "resolution", "imageinput", "tiledlatent")


def is_checkpoint_loader(class_type: str, inputs: Mapping[str, Any]) -> bool:
    return _matches_keywords(class_type, CHECKPOINT_KEYWORDS) or "ckpt_name" in inputs


def is_lora_loader(class_type: str, inputs: Mapping[str, Any]) -> bool:
    return _matches_keywords(class_type, LORA_KEYWORDS) or "lora_name" in inputs


def is_vae_loader(class_type: str, inputs: Mapping[str, Any]) -> bool:
    if _matches_keywords(class_type, VAE_KEYWORDS):
        return True
    return any(key in inputs for key in ("vae_name", "vae"))


def is_sampler_node(class_type: str) -> bool:
    return _matches_keywords(class_type, SAMPLER_KEYWORDS)


def is_text_encoder_node(class_type: str, inputs: Mapping[str, Any]) -> bool:
    if _matches_keywords(class_type, TEXT_ENCODER_KEYWORDS):
        return True
    if "text" in inputs:
        # A lone "text" input is not enough; require a companion key that
        # suggests the node produces conditioning.
        other_keys = {key for key in inputs if key != "text"}
        return bool(other_keys & {"clip", "g", "clip_g", "clip_l", "conditioning"})
    return False


def is_resolution_controller(class_type: str, inputs: Mapping[str, Any]) -> bool:
    if any(key in inputs for key in ("width", "height", "size")):
        return True
    return _matches_keywords(class_type, RESOLUTION_KEYWORDS)


def is_output_node(class_type: str) -> bool:
    return is_preview_output(class_type) or is_save_output(class_type)


def is_preview_output(class_type: str) -> bool:
    return _matches_keywords(class_type, PREVIEW_KEYWORDS)


def is_save_output(class_type: str) -> bool:
    return _matches_keywords(class_type, SAVE_KEYWORDS)


def validate_workflow(payload: Mapping[str, Any]) -> None:
    if not isinstance(payload, Mapping):
        raise ValueError("Workflow payload must be a mapping")
    nodes = payload.get("nodes")
    links = payload.get("links")
    if not isinstance(nodes, list) or not all(isinstance(node, Mapping) for node in nodes):
        raise ValueError("Workflow payload must include a list of node mappings")
    if links is not None and not isinstance(links, list):
        raise ValueError("Workflow links must be a list when provided")


__all__ = [
    "WorkflowGraph",
    "WorkflowTemplate",
    "WorkflowDiscovery",
    "WorkflowNode",
    "WorkflowLink",
    "SemanticSummary",
]
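

if __name__ == "__main__":  # pragma: no cover - illustrative smoke test
    # Minimal in-memory demo; node ids, class names, and slot numbers are
    # invented for illustration and do not come from a real workflow file.
    demo_payload = {
        "nodes": [
            {"id": 1, "class_type": "CLIPTextEncode",
             "inputs": {"text": "a scenic mountain", "clip": ["4", 1]}},
            {"id": 2, "class_type": "KSampler", "inputs": {"steps": 20}},
            {"id": 3, "class_type": "SaveImage", "inputs": {}},
        ],
        "links": [
            {"id": 1, "from_node": 1, "from_slot": 0, "to_node": 2, "to_slot": 1},
            {"id": 2, "from_node": 2, "from_slot": 0, "to_node": 3, "to_slot": 0},
        ],
    }
    demo_graph = WorkflowGraph.from_payload(demo_payload)
    print(json.dumps(demo_graph.semantic_summary.to_dict(), indent=2))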