"""Model management tools for ComfyUI MCP Server
Thin MCP wrapper for model upload/download functionality.
NO business logic - delegates to route functions.
"""
from fastmcp import FastMCP
from src.routes.model_upload import (
ModelType,
delete_model,
download_model_from_url,
list_installed_models,
)
from src.utils import get_global_logger
logger = get_global_logger("MCP_Server.tools.model_management")
def register_model_management_tools(
    mcp: FastMCP,
    comfyui_path: str | None = None,
):
    """Register model management tools with the MCP server.

    Thin MCP wrapper layer: each tool validates its inputs and then delegates
    to the corresponding route function. No business logic lives here.

    Args:
        mcp: FastMCP server instance
        comfyui_path: Optional default ComfyUI path (can be overridden per call)
    """

    def _parse_model_type(model_type: str) -> tuple[ModelType | None, dict | None]:
        """Map a user-supplied model_type string to the ModelType enum.

        Centralizes the validation previously duplicated in each tool.

        Args:
            model_type: Case-insensitive model type string (e.g. "checkpoint").

        Returns:
            (enum, None) on success, or (None, error_dict) when the string
            is not a valid model type.
        """
        try:
            return ModelType(model_type.lower()), None
        except ValueError:
            valid_types = [t.value for t in ModelType]
            return None, {
                "error": f"Invalid model_type '{model_type}'. Must be one of: {', '.join(valid_types)}"
            }

    @mcp.tool()
    async def download_model(
        url: str,
        filename: str | None = None,
        model_type: str = "checkpoint",
        overwrite: bool = False,
        comfyui_path_override: str | None = None,
    ) -> dict:
        """Download a model from a URL and install it to ComfyUI.

        Downloads models from URLs (e.g., HuggingFace, CivitAI) and places them
        in the appropriate ComfyUI models directory.

        Args:
            url: URL to download the model from (e.g., "https://huggingface.co/.../model.safetensors")
            filename: Optional custom filename. If not provided, extracted from URL.
            model_type: Type of model - one of: "checkpoint", "lora", "vae", "controlnet",
                "upscaler", "embedding", "hypernetwork", "clip", "clip_vision", "unet", "style_models"
            overwrite: If True, overwrite existing file. If False, error if file exists. Default: False
            comfyui_path_override: Optional ComfyUI installation path. If not provided, uses
                COMFYUI_PATH or COMFYUI_MODELS_PATH env var, or auto-detects.

        Returns:
            Dict with download results including path, size, and model type.

        Examples:
            Download SD XL base model:
            >>> download_model(
            ...     url="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors",
            ...     model_type="checkpoint"
            ... )

            Download a LoRA:
            >>> download_model(
            ...     url="https://civitai.com/api/download/models/12345",
            ...     filename="custom_lora.safetensors",
            ...     model_type="lora"
            ... )
        """
        model_type_enum, error = _parse_model_type(model_type)
        if error is not None:
            return error
        try:
            result = await download_model_from_url(
                url=url,
                filename=filename,
                model_type=model_type_enum,
                overwrite=overwrite,
                # Per-call override wins over the registration-time default.
                comfyui_path=comfyui_path_override or comfyui_path,
            )
            return result
        except Exception as e:
            # Surface failures as an error dict rather than raising, so the
            # MCP client always receives a structured response.
            logger.error(f"Failed to download model: {e}", exc_info=True)
            return {"error": str(e)}

    @mcp.tool()
    async def list_models_by_type(
        model_type: str | None = None,
        comfyui_path_override: str | None = None,
    ) -> dict:
        """List all installed models in ComfyUI, optionally filtered by type.

        Scans the ComfyUI models directory and returns information about
        installed models including filename, size, and modification time.

        Args:
            model_type: Optional filter by model type - one of: "checkpoint", "lora", "vae",
                "controlnet", "upscaler", "embedding", "hypernetwork", "clip",
                "clip_vision", "unet", "style_models". If not provided, lists all types.
            comfyui_path_override: Optional ComfyUI installation path. If not provided, uses
                COMFYUI_PATH or COMFYUI_MODELS_PATH env var, or auto-detects.

        Returns:
            Dict with models organized by type, total count, and list of types with models.

        Examples:
            List all checkpoints:
            >>> list_models_by_type(model_type="checkpoint")

            List all installed models:
            >>> list_models_by_type()
        """
        model_type_enum = None
        if model_type:
            model_type_enum, error = _parse_model_type(model_type)
            if error is not None:
                return error
        try:
            result = await list_installed_models(
                model_type=model_type_enum,
                comfyui_path=comfyui_path_override or comfyui_path,
            )
            return result
        except Exception as e:
            logger.error(f"Failed to list models: {e}", exc_info=True)
            return {"error": str(e)}

    @mcp.tool()
    async def remove_model(
        filename: str,
        model_type: str,
        comfyui_path_override: str | None = None,
    ) -> dict:
        """Delete a model file from ComfyUI.

        Removes the specified model file from the ComfyUI models directory.
        Use with caution - this action cannot be undone.

        Args:
            filename: Name of the file to delete (e.g., "old_model.safetensors")
            model_type: Type of model - one of: "checkpoint", "lora", "vae", "controlnet",
                "upscaler", "embedding", "hypernetwork", "clip", "clip_vision", "unet", "style_models"
            comfyui_path_override: Optional ComfyUI installation path. If not provided, uses
                COMFYUI_PATH or COMFYUI_MODELS_PATH env var, or auto-detects.

        Returns:
            Dict with deletion status and deleted file path.

        Example:
            >>> remove_model(
            ...     filename="old_checkpoint.safetensors",
            ...     model_type="checkpoint"
            ... )
        """
        model_type_enum, error = _parse_model_type(model_type)
        if error is not None:
            return error
        try:
            result = await delete_model(
                filename=filename,
                model_type=model_type_enum,
                comfyui_path=comfyui_path_override or comfyui_path,
            )
            return result
        except Exception as e:
            logger.error(f"Failed to delete model: {e}", exc_info=True)
            return {"error": str(e)}

    @mcp.tool()
    def get_model_types() -> dict:
        """Get a list of supported model types.

        Returns all supported model types that can be used with the model management tools.

        Returns:
            Dict with list of supported model types and their descriptions.
        """
        # NOTE(review): keep these keys in sync with the ModelType enum in
        # src.routes.model_upload — they are maintained by hand here.
        types_info = {
            "checkpoint": "Main Stable Diffusion checkpoint models (.safetensors, .ckpt)",
            "lora": "LoRA (Low-Rank Adaptation) models for fine-tuning",
            "vae": "VAE (Variational AutoEncoder) models",
            "controlnet": "ControlNet models for guided generation",
            "upscaler": "Upscaling models (ESRGAN, Real-ESRGAN, etc.)",
            "embedding": "Text embeddings (textual inversion)",
            "hypernetwork": "Hypernetwork models",
            "clip": "CLIP text encoder models",
            "clip_vision": "CLIP vision models",
            "unet": "U-Net models",
            "style_models": "Style transfer models",
        }
        return {
            "supported_types": list(types_info.keys()),
            "descriptions": types_info,
            "count": len(types_info),
        }