mutations.py•8.61 kB
"""Workflow mutation helpers."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Sequence
from .workflow import WorkflowGraph, WorkflowTemplate
@dataclass
class MutationDiff:
changes: Dict[str, Dict[str, Any]] = field(default_factory=dict)
def record(self, category: str, before: Any, after: Any) -> None:
self.changes[category] = {"before": before, "after": after}
class WorkflowMutator:
"""Apply high-level mutations to a workflow template."""
def __init__(self, template: WorkflowTemplate) -> None:
self.template = template
self.payload = _deep_copy(template.payload)
self.graph = WorkflowGraph.from_payload(self.payload)
self.diff = MutationDiff()
def update_prompt(self, role: str, text: str) -> None:
node_id = self.graph.semantic_summary.prompts.get(role)
if node_id is None:
raise KeyError(f"Workflow does not define a prompt node for role '{role}'")
node = self._get_node_payload(node_id)
before = node.get("inputs", {}).get("text")
node.setdefault("inputs", {})["text"] = text
self.diff.record(f"prompt.{role}", before, text)
def set_checkpoint(self, checkpoint_name: str) -> None:
loaders = self.graph.semantic_summary.checkpoint_loaders
if not loaders:
raise KeyError("Workflow does not include a checkpoint loader")
for node_id in loaders:
node = self._get_node_payload(node_id)
before = node.get("inputs", {}).get("ckpt_name")
node.setdefault("inputs", {})["ckpt_name"] = checkpoint_name
self.diff.record(f"checkpoint.{node_id}", before, checkpoint_name)
def set_cfg(self, value: float) -> None:
for node_id in self.graph.semantic_summary.samplers:
node = self._get_node_payload(node_id)
before = node.get("inputs", {}).get("cfg")
node.setdefault("inputs", {})["cfg"] = value
self.diff.record(f"sampler.{node_id}.cfg", before, value)
def set_steps(self, value: int) -> None:
for node_id in self.graph.semantic_summary.samplers:
node = self._get_node_payload(node_id)
before = node.get("inputs", {}).get("steps")
node.setdefault("inputs", {})["steps"] = value
self.diff.record(f"sampler.{node_id}.steps", before, value)
def set_seed(self, value: int) -> None:
for node_id in self.graph.semantic_summary.samplers:
node = self._get_node_payload(node_id)
before = node.get("inputs", {}).get("seed")
node.setdefault("inputs", {})["seed"] = value
self.diff.record(f"sampler.{node_id}.seed", before, value)
def set_vae(self, vae_name: str) -> None:
loaders = self.graph.semantic_summary.vae_loaders
if not loaders:
raise KeyError("Workflow does not include a VAE loader")
for node_id in loaders:
node = self._get_node_payload(node_id)
inputs = node.setdefault("inputs", {})
key = _find_existing_key(inputs, ("vae_name", "vae", "name")) or "vae_name"
before = inputs.get(key)
inputs[key] = vae_name
self.diff.record(f"vae.{node_id}.{key}", before, vae_name)
def configure_loras(self, specs: Sequence[Mapping[str, Any]]) -> None:
if not specs:
return
loaders = list(self.graph.semantic_summary.lora_loaders)
if not loaders:
raise KeyError("Workflow does not include a LoRA loader")
remaining = loaders.copy()
for index, spec in enumerate(specs):
if not isinstance(spec, Mapping):
raise TypeError("LoRA configuration entries must be mappings")
target_id = _resolve_target_node(spec, remaining, loaders)
if target_id in remaining:
remaining.remove(target_id)
node = self._get_node_payload(target_id)
inputs = node.setdefault("inputs", {})
name_value = spec.get("name") or spec.get("lora_name")
if name_value is not None:
key = _find_existing_key(inputs, ("lora_name", "name")) or "lora_name"
before = inputs.get(key)
inputs[key] = str(name_value)
self.diff.record(f"lora.{target_id}.{key}", before, inputs[key])
if "strength" in spec:
strength = float(spec["strength"])
before = inputs.get("strength_model")
inputs["strength_model"] = strength
self.diff.record(f"lora.{target_id}.strength_model", before, strength)
clip_strength = _coalesce(spec, ("clip_strength", "strength_clip"))
if clip_strength is not None:
value = float(clip_strength)
before = inputs.get("strength_clip")
inputs["strength_clip"] = value
self.diff.record(f"lora.{target_id}.strength_clip", before, value)
def set_resolution(self, *, width: Optional[int] = None, height: Optional[int] = None) -> None:
if width is None and height is None:
return
targets: Iterable[int]
if self.graph.semantic_summary.resolution_controllers:
targets = list(dict.fromkeys(self.graph.semantic_summary.resolution_controllers))
else:
targets = [
node.id
for node in self.graph.nodes_by_id.values()
if any(key in node.inputs for key in ("width", "height", "size"))
]
if not targets:
raise KeyError("Workflow does not include a node controlling resolution")
for node_id in targets:
node = self._get_node_payload(node_id)
inputs = node.setdefault("inputs", {})
if width is not None:
before = inputs.get("width")
inputs["width"] = width
self.diff.record(f"resolution.{node_id}.width", before, width)
if height is not None:
before = inputs.get("height")
inputs["height"] = height
self.diff.record(f"resolution.{node_id}.height", before, height)
if "size" in inputs and isinstance(inputs["size"], (list, tuple)):
before_size = list(inputs["size"])
new_size = list(before_size)
if width is not None:
_ensure_length(new_size, 2)
new_size[0] = width
if height is not None:
_ensure_length(new_size, 2)
new_size[1] = height
inputs["size"] = new_size
self.diff.record(f"resolution.{node_id}.size", before_size, new_size)
def apply(self) -> WorkflowTemplate:
return WorkflowTemplate(
name=self.template.name,
path=self.template.path,
payload=self.payload,
graph=WorkflowGraph.from_payload(self.payload),
description=self.template.description,
tags=self.template.tags,
)
def _get_node_payload(self, node_id: int) -> MutableMapping[str, Any]:
for node in self.payload["nodes"]:
if int(node["id"]) == node_id:
return node
raise KeyError(f"Node {node_id} not found in workflow payload")
def _deep_copy(payload: Mapping[str, Any]) -> Dict[str, Any]:
import copy
return copy.deepcopy(payload)
def _find_existing_key(inputs: Mapping[str, Any], candidates: Sequence[str]) -> Optional[str]:
for key in candidates:
if key in inputs:
return key
return None
def _resolve_target_node(spec: Mapping[str, Any], remaining: Sequence[int], all_nodes: Sequence[int]) -> int:
target = spec.get("node") or spec.get("node_id")
if target is not None:
node_id = int(target)
if node_id not in all_nodes:
raise KeyError(f"Workflow does not include LoRA node {node_id}")
return node_id
if not remaining:
raise ValueError("More LoRA configurations provided than available LoRA nodes")
return remaining[0]
def _coalesce(spec: Mapping[str, Any], keys: Sequence[str]) -> Optional[Any]:
for key in keys:
if key in spec:
return spec[key]
return None
def _ensure_length(values: list, size: int) -> None:
while len(values) < size:
values.append(0)
__all__ = ["WorkflowMutator", "MutationDiff"]