"""Defaults management for workflow parameters.
Implements precedence chain: per-call > runtime > config > env > hardcoded
Decoupled from ComfyUIClient using dependency injection for model validation.
"""
import json
import os
from collections.abc import Callable
from pathlib import Path
from typing import Any
from src.utils import get_global_logger
# Module-level logger used by every class in this file. The dotted name places
# it under "ComfyUI_MCP.orchestrators"; presumably get_global_logger applies the
# project's shared logging configuration — confirm in src.utils.
logger = get_global_logger("ComfyUI_MCP.orchestrators.defaults")
# Persistent user configuration: ~/.config/comfy-mcp/config.json.
# DefaultsManager reads its "defaults" section and writes it back when persisting.
CONFIG_DIR = Path.home().joinpath(".config", "comfy-mcp")
CONFIG_FILE = CONFIG_DIR.joinpath("config.json")
class NullDefaultsManager:
    """Do-nothing stand-in for DefaultsManager (Null Object pattern).

    Lets callers hold a defaults manager unconditionally instead of an
    optional dependency that must be checked for None. Nothing is stored
    and nothing is validated:

    - lookups hand back the caller's fallback,
    - writes report success but change nothing,
    - every model is treated as valid.
    """

    def get_default(self, namespace: str, key: str, fallback: Any = None) -> Any:
        """Echo ``fallback``; ``namespace`` and ``key`` are not consulted."""
        return fallback

    def get_all_defaults(self) -> dict[str, dict[str, Any]]:
        """Report an empty mapping for each of the image/audio/video namespaces."""
        return {namespace: {} for namespace in ("image", "audio", "video")}

    def set_defaults(
        self, namespace: str, defaults: dict[str, Any], validate_models: bool = True
    ) -> dict[str, Any]:
        """Pretend to set runtime defaults; nothing is recorded.

        Every argument is ignored. Returns a success result whose
        "updated" mapping is always empty.
        """
        return dict(success=True, updated={})

    def persist_defaults(self, namespace: str, defaults: dict[str, Any]) -> dict[str, Any]:
        """Pretend to persist defaults; nothing is written.

        Every argument is ignored. Returns a success result whose
        "persisted" mapping is always empty.
        """
        return dict(success=True, persisted={})

    def validate_default_model(self, namespace: str) -> tuple[bool, str, str]:
        """Report "no default model" as valid: ``(True, "", "none")``."""
        valid, model_name, source = True, "", "none"
        return (valid, model_name, source)

    def validate_all_defaults(self) -> None:
        """Nothing to validate; returns immediately."""
        return None

    def is_model_valid(self, namespace: str, model: str) -> bool:
        """Accept any model name for any namespace."""
        return True
class DefaultsManager:
"""Manages default parameter values with precedence chain.
Precedence Order (highest to lowest):
1. per-call: Values provided in tool calls
2. runtime: Values set via set_defaults()
3. config: Values from ~/.config/comfy-mcp/config.json
4. env: Values from environment variables
5. hardcoded: Built-in default values
Design: Dependency Inversion Principle
- Depends on model_validator: Callable[[str], bool] instead of ComfyUIClient
- Validator function injected at initialization
- Enables testing without ComfyUI connection
"""
def __init__(self, model_validator: Callable[[str], bool] | None = None):
"""Initialize defaults manager.
Args:
model_validator: Optional function to validate model names.
Should return True if model exists, False otherwise.
If None, model validation is skipped.
"""
self.model_validator = model_validator
# Runtime defaults (highest priority after per-call)
self._runtime_defaults: dict[str, dict[str, Any]] = {"image": {}, "audio": {}, "video": {}}
# Config file defaults
self._config_defaults = self._load_config_defaults()
# Hardcoded defaults (fallback)
self._hardcoded_defaults = {
"image": {
"width": 512,
"height": 512,
"steps": 20,
"cfg": 8.0,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1.0,
"model": "v1-5-pruned-emaonly.ckpt",
"negative_prompt": "text, watermark",
},
"audio": {
"steps": 50,
"cfg": 5.0,
"sampler_name": "euler",
"scheduler": "simple",
"denoise": 1.0,
"seconds": 60,
"lyrics_strength": 0.99,
"model": "ace_step_v1_3.5b.safetensors",
},
"video": {
"width": 1280,
"height": 720,
"steps": 20,
"cfg": 8.0,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1.0,
"negative_prompt": "text, watermark",
"duration": 5,
"fps": 16,
},
}
# Validation state
self._invalid_models: dict[str, str] = {} # namespace -> invalid model name
logger.info("Initialized DefaultsManager")
# Validate defaults at startup if validator provided
if self.model_validator:
self.validate_all_defaults()
def _load_config_defaults(self) -> dict[str, dict[str, Any]]:
"""Load defaults from config file.
Returns:
Dictionary of namespaced defaults
"""
defaults = {"image": {}, "audio": {}, "video": {}}
if CONFIG_FILE.exists():
try:
with open(CONFIG_FILE, encoding="utf-8") as f:
config = json.load(f)
defaults["image"] = config.get("defaults", {}).get("image", {})
defaults["audio"] = config.get("defaults", {}).get("audio", {})
defaults["video"] = config.get("defaults", {}).get("video", {})
logger.debug(f"Loaded config defaults from {CONFIG_FILE}")
except (OSError, json.JSONDecodeError) as e:
logger.warning(f"Failed to load config file {CONFIG_FILE}: {e}")
return defaults
def _get_env_defaults(self) -> dict[str, dict[str, Any]]:
"""Load defaults from environment variables.
Environment variables:
- COMFY_MCP_DEFAULT_IMAGE_MODEL
- COMFY_MCP_DEFAULT_AUDIO_MODEL
- COMFY_MCP_DEFAULT_VIDEO_MODEL
Returns:
Dictionary of namespaced defaults from environment
"""
defaults = {"image": {}, "audio": {}, "video": {}}
image_model = os.getenv("COMFY_MCP_DEFAULT_IMAGE_MODEL")
audio_model = os.getenv("COMFY_MCP_DEFAULT_AUDIO_MODEL")
video_model = os.getenv("COMFY_MCP_DEFAULT_VIDEO_MODEL")
if image_model:
defaults["image"]["model"] = image_model
if audio_model:
defaults["audio"]["model"] = audio_model
if video_model:
defaults["video"]["model"] = video_model
return defaults
def get_default(self, namespace: str, key: str, provided_value: Any = None) -> Any:
"""Get default value with precedence chain.
Precedence: provided > runtime > config > env > hardcoded
Args:
namespace: Parameter namespace ("image", "audio", or "video")
key: Parameter key (e.g., "model", "steps")
provided_value: Value provided by caller (highest priority)
Returns:
Resolved value or None if not found in any source
"""
# 1. Provided value (highest priority)
if provided_value is not None:
return provided_value
# 2. Runtime defaults
if key in self._runtime_defaults.get(namespace, {}):
return self._runtime_defaults[namespace][key]
# 3. Config file defaults
if key in self._config_defaults.get(namespace, {}):
return self._config_defaults[namespace][key]
# 4. Environment variables
env_defaults = self._get_env_defaults()
if key in env_defaults.get(namespace, {}):
return env_defaults[namespace][key]
# 5. Hardcoded defaults (lowest priority)
if key in self._hardcoded_defaults.get(namespace, {}):
return self._hardcoded_defaults[namespace][key]
return None
def get_all_defaults(self) -> dict[str, dict[str, Any]]:
"""Get all effective defaults (merged from all sources).
Returns merged defaults for each namespace with precedence applied.
Returns:
Dictionary with "image", "audio", "video" namespaces
"""
env_defaults = self._get_env_defaults()
result = {"image": {}, "audio": {}, "video": {}}
for namespace in ["image", "audio", "video"]:
# Build from lowest to highest priority
result[namespace] = self._hardcoded_defaults[namespace].copy()
result[namespace].update(env_defaults.get(namespace, {}))
result[namespace].update(self._config_defaults.get(namespace, {}))
result[namespace].update(self._runtime_defaults.get(namespace, {}))
return result
def set_defaults(
self, namespace: str, defaults: dict[str, Any], validate_models: bool = True
) -> dict[str, Any]:
"""Set runtime defaults for a namespace.
Args:
namespace: Parameter namespace ("image", "audio", or "video")
defaults: Dictionary of key-value pairs to set
validate_models: Whether to validate model names (default: True)
Returns:
Result dictionary with success/error status
"""
# Validate namespace
if namespace not in ["image", "audio", "video"]:
error_msg = f"Invalid namespace: {namespace}. Must be 'image', 'audio', or 'video'"
logger.error(error_msg)
return {"error": error_msg}
# Validate model names if requested and validator available
if validate_models and "model" in defaults and self.model_validator:
model_name = defaults["model"]
if not self.model_validator(model_name):
error_msg = f"Model '{model_name}' not found in ComfyUI checkpoints"
logger.warning(error_msg)
return {"error": error_msg}
# Update runtime defaults
if namespace not in self._runtime_defaults:
self._runtime_defaults[namespace] = {}
self._runtime_defaults[namespace].update(defaults)
# Clear invalid model flag if setting valid model
if "model" in defaults and validate_models:
if namespace in self._invalid_models:
del self._invalid_models[namespace]
logger.info(f"Set runtime defaults for {namespace}: {list(defaults.keys())}")
return {"success": True, "updated": defaults}
def persist_defaults(self, namespace: str, defaults: dict[str, Any]) -> dict[str, Any]:
"""Persist defaults to config file.
Saves defaults to ~/.config/comfy-mcp/config.json and reloads.
Args:
namespace: Parameter namespace ("image", "audio", or "video")
defaults: Dictionary of key-value pairs to persist
Returns:
Result dictionary with success/error status
"""
# Ensure config directory exists
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
logger.debug(f"Ensured config directory exists: {CONFIG_DIR}")
# Load existing config
config = {}
if CONFIG_FILE.exists():
try:
with open(CONFIG_FILE, encoding="utf-8") as f:
config = json.load(f)
except (OSError, json.JSONDecodeError) as e:
logger.warning(f"Failed to load existing config: {e}")
config = {}
# Update defaults
if "defaults" not in config:
config["defaults"] = {}
if namespace not in config["defaults"]:
config["defaults"][namespace] = {}
config["defaults"][namespace].update(defaults)
# Save config
try:
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
# Reload config defaults
self._config_defaults = self._load_config_defaults()
logger.info(f"Persisted defaults for {namespace} to {CONFIG_FILE}")
return {"success": True, "persisted": defaults}
except OSError as e:
error_msg = f"Failed to write config file: {e}"
logger.error(error_msg)
return {"error": error_msg}
def validate_default_model(self, namespace: str) -> tuple[bool, str, str]:
"""Validate the default model for a namespace.
Args:
namespace: Parameter namespace ("image", "audio", or "video")
Returns:
Tuple of (is_valid, model_name, source)
"""
model_name = self.get_default(namespace, "model")
if not model_name:
return (True, "", "none") # No model default is valid
source = self._get_default_source(namespace, "model")
# Validate model if validator available
if self.model_validator:
is_valid = self.model_validator(model_name)
return (is_valid, model_name, source)
# No validator, assume valid
return (True, model_name, source)
def validate_all_defaults(self) -> None:
"""Validate all default models at startup.
Logs warnings for invalid models but doesn't raise exceptions.
"""
if not self.model_validator:
logger.debug("No model validator provided, skipping validation")
return
logger.info("Validating default models")
for namespace in ["image", "audio", "video"]:
is_valid, model_name, source = self.validate_default_model(namespace)
if not is_valid and model_name:
logger.warning(
f"Default model '{model_name}' (from {source} defaults) "
f"for {namespace} namespace not found in ComfyUI checkpoints. "
f"Set a valid model via `set_defaults`, config file, or env var."
)
self._invalid_models[namespace] = model_name
def is_model_valid(self, namespace: str, model: str) -> bool:
"""Check if a model is valid for a namespace.
Args:
namespace: Parameter namespace ("image", "audio", or "video")
model: Model name to validate
Returns:
True if valid, False otherwise
"""
if not model:
return True # Empty model is valid (will use default)
# Check invalid models cache
if namespace in self._invalid_models and self._invalid_models[namespace] == model:
return False
# Validate with validator if available
if self.model_validator:
return self.model_validator(model)
# No validator, assume valid
return True
def _get_default_source(self, namespace: str, key: str) -> str:
"""Determine where a default value came from.
Args:
namespace: Parameter namespace
key: Parameter key
Returns:
Source name ("runtime", "config", "env", "hardcoded", or "unknown")
"""
# Check runtime defaults (highest priority)
if key in self._runtime_defaults.get(namespace, {}):
return "runtime"
# Check config file defaults
if key in self._config_defaults.get(namespace, {}):
return "config"
# Check environment variables
env_defaults = self._get_env_defaults()
if key in env_defaults.get(namespace, {}):
return "env"
# Check hardcoded defaults (lowest priority)
if key in self._hardcoded_defaults.get(namespace, {}):
return "hardcoded"
return "unknown"