"""Node definition orchestration.
Manages ComfyUI node schema discovery, caching, and querying.
Enables AI agents to understand available nodes and their parameters.
Follows FastMCP v3 best practices:
- Async methods for proper Context integration
- Progress reporting during long operations
- Context-aware logging
"""
import time
from typing import Any
from fastmcp import Context
from src.auth.base import ComfyAuth
from src.models.node import NodeDefinition, NodeInputSpec, NodeOutputSpec
from src.routes.models import get_node_definitions
from src.utils import get_global_logger
logger = get_global_logger("MCP_Server.orchestrators.node")
class NodeOrchestrator:
    """Orchestrator for node definition management.

    Coordinates node schema discovery, parsing, and caching against a
    ComfyUI server's /object_info endpoint. Provides search and query
    capabilities for workflow construction.

    Attributes:
        auth: Authentication for ComfyUI API calls
        _cache: In-memory cache of node definitions, keyed by class_type
        _cache_timestamp: Epoch seconds when cache was last populated
        _cache_ttl: Cache time-to-live in seconds (default: 3600 = 1 hour)
    """

    def __init__(self, auth: ComfyAuth, cache_ttl: int = 3600):
        """Initialize NodeOrchestrator.

        Args:
            auth: ComfyAuth instance for API access
            cache_ttl: Cache TTL in seconds (default: 1 hour)
        """
        self.auth = auth
        self._cache: dict[str, NodeDefinition] = {}
        self._cache_timestamp: float = 0.0
        self._cache_ttl = cache_ttl
        self.logger = get_global_logger(f"{__name__}.{self.__class__.__name__}")

    def _is_cache_valid(self) -> bool:
        """Check if cache is populated and younger than the TTL."""
        if not self._cache:
            return False
        age = time.time() - self._cache_timestamp
        return age < self._cache_ttl

    def _parse_node_definition(self, class_type: str, raw_data: dict[str, Any]) -> NodeDefinition:
        """Parse raw ComfyUI node data into NodeDefinition.

        Args:
            class_type: Node class identifier
            raw_data: Raw node data from /object_info endpoint

        Returns:
            Parsed NodeDefinition object
        """
        # Extract basic metadata; fall back to sensible defaults so a
        # sparsely-described node still parses.
        display_name = raw_data.get("display_name", class_type)
        category = raw_data.get("category", "")
        description = raw_data.get("description", "")
        python_module = raw_data.get("python_module", "")

        # Parse input types, preserving the required/optional grouping.
        input_types_raw = raw_data.get("input", {})
        input_types: dict[str, list[NodeInputSpec]] = {}
        for req_type in ["required", "optional"]:
            if req_type in input_types_raw:
                specs = []
                for input_name, input_spec in input_types_raw[req_type].items():
                    # Parse input specification.
                    # Format: (type_list, {options}) or just type_list
                    if isinstance(input_spec, tuple | list):
                        type_info = input_spec[0] if input_spec else []
                        options_dict = input_spec[1] if len(input_spec) > 1 else {}
                    else:
                        type_info = input_spec
                        options_dict = {}
                    # Some custom nodes put a non-dict value in the options
                    # slot; treat that as "no options" rather than letting
                    # .get("default") raise and drop the whole node.
                    if not isinstance(options_dict, dict):
                        options_dict = {}
                    # Determine type and options
                    if isinstance(type_info, list):
                        # Enum/select with predefined choices
                        param_type = "select"
                        options = type_info
                    else:
                        # Single type name (e.g. "INT", "STRING")
                        param_type = str(type_info)
                        options = None
                    # Default value, when the node declares one
                    default = options_dict.get("default")
                    spec = NodeInputSpec(
                        name=input_name,
                        type=param_type,
                        required=(req_type == "required"),
                        default=default,
                        options=options,
                        description=f"Input parameter '{input_name}'",
                    )
                    specs.append(spec)
                input_types[req_type] = specs

        # Parse output types; the names list may be shorter than the
        # types list, so missing names default to "".
        output_types: list[NodeOutputSpec] = []
        return_types = raw_data.get("output", [])
        return_names = raw_data.get("output_name", [])
        for idx, output_type in enumerate(return_types):
            output_name = return_names[idx] if idx < len(return_names) else ""
            output_types.append(NodeOutputSpec(index=idx, type=output_type, name=output_name))

        return NodeDefinition(
            class_type=class_type,
            display_name=display_name,
            category=category,
            description=description,
            input_types=input_types,
            output_types=output_types,
            return_types=return_types,
            return_names=return_names,
            python_module=python_module,
        )

    async def _fetch_and_cache_nodes(self, ctx: Context | None = None) -> None:
        """Fetch all nodes from ComfyUI and populate cache.

        Best-effort: on HTTP or unexpected errors the method logs (and
        reports via ctx) and returns, leaving the previous cache state.

        Args:
            ctx: Optional FastMCP Context for progress reporting
        """
        self.logger.info("Fetching node definitions from ComfyUI")
        if ctx:
            await ctx.info("Fetching node definitions from ComfyUI")
            await ctx.report_progress(0.0, "Starting node discovery")
        start_time = time.time()
        try:
            if ctx:
                await ctx.report_progress(0.1, "Requesting node definitions")
            res = await get_node_definitions(auth=self.auth, node_class=None)
            if not res.is_success:
                error_msg = f"Failed to fetch nodes: HTTP {res.status}"
                self.logger.error(error_msg)
                if ctx:
                    await ctx.error(error_msg)
                return
            if ctx:
                await ctx.report_progress(0.3, "Parsing node definitions")
            # Parse all node definitions
            raw_nodes = res.response
            self._cache.clear()
            total_nodes = len(raw_nodes)
            parsed_count = 0
            for class_type, node_data in raw_nodes.items():
                try:
                    definition = self._parse_node_definition(class_type, node_data)
                    self._cache[class_type] = definition
                    parsed_count += 1
                    # Report progress every 50 nodes
                    if ctx and parsed_count % 50 == 0:
                        # max(..., 1) guards the divisor defensively
                        progress = 0.3 + (0.6 * parsed_count / max(total_nodes, 1))
                        await ctx.report_progress(
                            progress, f"Parsed {parsed_count}/{total_nodes} nodes"
                        )
                except Exception as e:
                    # A single malformed node must not abort the whole sync.
                    self.logger.warning(f"Failed to parse node '{class_type}': {e}", exc_info=False)
            self._cache_timestamp = time.time()
            duration = time.time() - start_time
            success_msg = f"Cached {len(self._cache)} node definitions in {duration:.2f}s"
            self.logger.info(success_msg)
            if ctx:
                await ctx.report_progress(1.0, "Node discovery complete")
                await ctx.info(success_msg)
        except Exception as e:
            error_msg = f"Error fetching nodes: {e}"
            self.logger.error(error_msg, exc_info=True)
            if ctx:
                await ctx.error(error_msg)

    async def get_all_nodes(
        self, force_refresh: bool = False, ctx: Context | None = None
    ) -> dict[str, NodeDefinition]:
        """Get all available node definitions.

        Args:
            force_refresh: Force cache refresh even if valid
            ctx: Optional FastMCP Context for progress reporting

        Returns:
            Dictionary mapping class_type -> NodeDefinition (a copy, so
            callers cannot mutate the internal cache)
        """
        if force_refresh or not self._is_cache_valid():
            self.logger.debug("Cache miss or refresh requested, fetching nodes")
            if ctx:
                await ctx.info("Cache miss or refresh requested")
            await self._fetch_and_cache_nodes(ctx=ctx)
        else:
            cache_age = time.time() - self._cache_timestamp
            self.logger.debug(f"Cache hit (age: {cache_age:.0f}s)")
            if ctx:
                await ctx.debug(f"Using cached nodes (age: {cache_age:.0f}s)")
        return self._cache.copy()

    async def get_node_by_class(
        self, class_type: str, ctx: Context | None = None
    ) -> NodeDefinition | None:
        """Get definition for a specific node class.

        Args:
            class_type: Node class identifier (e.g., "KSampler")
            ctx: Optional FastMCP Context for progress reporting

        Returns:
            NodeDefinition if found, None otherwise
        """
        # Ensure cache is populated
        if not self._is_cache_valid():
            await self.get_all_nodes(ctx=ctx)
        return self._cache.get(class_type)

    async def search_nodes(
        self,
        query: str | None = None,
        category: str | None = None,
        has_category: bool = False,
        ctx: Context | None = None,
    ) -> list[NodeDefinition]:
        """Search for nodes by query string or category.

        Args:
            query: Search string (matches class_type, display_name,
                category, description; case-insensitive substring)
            category: Filter by exact category match
            has_category: If True (and no category given), only return
                nodes with a non-empty category
            ctx: Optional FastMCP Context for progress reporting

        Returns:
            List of matching NodeDefinition objects
        """
        # Ensure cache is populated
        if not self._is_cache_valid():
            await self.get_all_nodes(ctx=ctx)
        results = list(self._cache.values())
        # Filter by category
        if category:
            results = [n for n in results if n.category == category]
        elif has_category:
            results = [n for n in results if n.category]
        # Filter by query (case-insensitive substring match)
        if query:
            query_lower = query.lower()
            results = [
                n
                for n in results
                if query_lower in n.class_type.lower()
                or query_lower in n.display_name.lower()
                or query_lower in n.category.lower()
                or query_lower in n.description.lower()
            ]
        self.logger.debug(f"Search returned {len(results)} nodes")
        if ctx:
            await ctx.debug(f"Search returned {len(results)} nodes")
        return results

    async def get_categories(self, ctx: Context | None = None) -> list[str]:
        """Get list of all unique node categories.

        Args:
            ctx: Optional FastMCP Context for progress reporting

        Returns:
            Sorted list of non-empty category names
        """
        # Ensure cache is populated
        if not self._is_cache_valid():
            await self.get_all_nodes(ctx=ctx)
        categories = {n.category for n in self._cache.values() if n.category}
        return sorted(categories)

    async def get_node_inputs(
        self, class_type: str, ctx: Context | None = None
    ) -> dict[str, list[NodeInputSpec]]:
        """Get input specifications for a node.

        Args:
            class_type: Node class identifier
            ctx: Optional FastMCP Context for progress reporting

        Returns:
            Dictionary with 'required' and 'optional' input lists
            (empty lists when the node is unknown)
        """
        node = await self.get_node_by_class(class_type, ctx=ctx)
        if not node:
            self.logger.warning(f"Node '{class_type}' not found")
            if ctx:
                await ctx.warning(f"Node '{class_type}' not found")
            return {"required": [], "optional": []}
        # Shallow-copy so callers cannot mutate the cached definition
        # (mirrors the defensive copy in get_all_nodes).
        return {req_type: list(specs) for req_type, specs in node.input_types.items()}

    async def get_nodes_by_output_type(
        self, output_type: str, ctx: Context | None = None
    ) -> list[NodeDefinition]:
        """Find nodes that produce a specific output type.

        Args:
            output_type: Output type string (e.g., "IMAGE", "LATENT", "MODEL")
            ctx: Optional FastMCP Context for progress reporting

        Returns:
            List of nodes that output the specified type
        """
        # Ensure cache is populated
        if not self._is_cache_valid():
            await self.get_all_nodes(ctx=ctx)
        results = [n for n in self._cache.values() if output_type in n.return_types]
        self.logger.debug(f"Found {len(results)} nodes with output type '{output_type}'")
        if ctx:
            await ctx.debug(f"Found {len(results)} nodes with output type '{output_type}'")
        return results

    def clear_cache(self) -> None:
        """Manually clear the node definition cache."""
        self.logger.info("Clearing node definition cache")
        self._cache.clear()
        self._cache_timestamp = 0.0