"""Tests for the ComfyUI MCP Server tool handlers."""
import json
from unittest.mock import AsyncMock, patch
import pytest
from src.models import GenerationResult, ModelInfo, QueueStatus
from src.server import generate_image, get_generation, get_queue_status, list_models
@pytest.fixture()
def mock_client() -> AsyncMock:
    """Provide an AsyncMock that stands in for ComfyUIClient.

    The mock is wired as an async context manager whose ``__aenter__``
    resolves to the mock itself, so ``async with client as c`` binds ``c``
    to the very object the test configures; ``__aexit__`` resolves to None,
    so exceptions raised inside the ``async with`` body are not suppressed.
    """
    fake = AsyncMock()
    fake.__aexit__ = AsyncMock(return_value=None)
    fake.__aenter__ = AsyncMock(return_value=fake)
    return fake
# --- generate_image ---
@patch("src.server._get_client")
async def test_generate_image_calls_client(
    mock_get_client: AsyncMock, mock_client: AsyncMock
) -> None:
    """generate_image submits a workflow, polls it, and returns the result as JSON."""
    mock_client.submit_workflow.return_value = "prompt-abc"
    mock_client.poll_until_complete.return_value = GenerationResult(
        prompt_id="prompt-abc",
        status="completed",
        images=["output_001.png"],
        elapsed_seconds=3.5,
    )
    mock_get_client.return_value = mock_client

    result = json.loads(await generate_image(prompt="a sunset over mountains"))

    assert result["prompt_id"] == "prompt-abc"
    assert result["status"] == "completed"
    assert result["images"] == ["output_001.png"]
    assert result["elapsed_seconds"] == 3.5
    # assert_awaited_* rather than assert_called_*: calling an AsyncMock
    # records a call even if the returned coroutine is never awaited.
    mock_client.submit_workflow.assert_awaited_once()
    mock_client.poll_until_complete.assert_awaited_once_with("prompt-abc")
@patch("src.server._get_client")
async def test_generate_image_passes_workflow_params(
    mock_get_client: AsyncMock, mock_client: AsyncMock
) -> None:
    """Every caller-supplied generation parameter lands in the submitted workflow."""
    mock_client.submit_workflow.return_value = "prompt-xyz"
    mock_client.poll_until_complete.return_value = GenerationResult(
        prompt_id="prompt-xyz",
        status="completed",
        images=[],
        elapsed_seconds=1.0,
    )
    mock_get_client.return_value = mock_client

    await generate_image(
        prompt="test",
        negative_prompt="ugly",
        model="custom.safetensors",
        width=768,
        height=768,
        steps=30,
        cfg_scale=9.0,
        seed=42,
    )

    # Fail with a clear mock assertion (rather than a TypeError on a None
    # await_args) if the workflow was never submitted.
    mock_client.submit_workflow.assert_awaited_once()
    # The workflow dict is built by build_txt2img_workflow and passed positionally.
    workflow = mock_client.submit_workflow.await_args.args[0]

    # Node ids follow ComfyUI's API-format txt2img graph:
    # "4" = checkpoint loader, "5" = empty latent image, "3" = KSampler.
    assert workflow["4"]["inputs"]["ckpt_name"] == "custom.safetensors"
    assert workflow["5"]["inputs"]["width"] == 768
    assert workflow["5"]["inputs"]["height"] == 768
    assert workflow["3"]["inputs"]["seed"] == 42
    assert workflow["3"]["inputs"]["steps"] == 30
    assert workflow["3"]["inputs"]["cfg"] == 9.0

    # Prompt node ids are not pinned here; just require that some text input
    # carries each prompt.
    prompt_texts = [
        node.get("inputs", {}).get("text")
        for node in workflow.values()
        if isinstance(node, dict)
    ]
    text_inputs = [text for text in prompt_texts if isinstance(text, str)]
    assert any("test" in text for text in text_inputs)
    assert any("ugly" in text for text in text_inputs)
# --- list_models ---
@patch("src.server._get_client")
async def test_list_models_returns_formatted(
    mock_get_client: AsyncMock, mock_client: AsyncMock
) -> None:
    """list_models serializes each checkpoint as a JSON object, preserving order."""
    mock_client.list_checkpoints.return_value = [
        ModelInfo(name="sd_xl_base.safetensors", filename="sd_xl_base.safetensors", type="checkpoint"),
        ModelInfo(name="v1-5-pruned.safetensors", filename="v1-5-pruned.safetensors", type="checkpoint"),
    ]
    mock_get_client.return_value = mock_client

    result = json.loads(await list_models())

    assert isinstance(result, list)
    assert len(result) == 2
    assert result[0]["name"] == "sd_xl_base.safetensors"
    assert result[1]["name"] == "v1-5-pruned.safetensors"
    assert all(m["type"] == "checkpoint" for m in result)
@patch("src.server._get_client")
async def test_list_models_empty(
    mock_get_client: AsyncMock, mock_client: AsyncMock
) -> None:
    """list_models returns an empty JSON array when no checkpoints exist."""
    mock_client.list_checkpoints.return_value = []
    mock_get_client.return_value = mock_client

    result = json.loads(await list_models())

    assert result == []
# --- get_queue_status ---
@patch("src.server._get_client")
async def test_get_queue_status_formatted(
    mock_get_client: AsyncMock, mock_client: AsyncMock
) -> None:
    """get_queue_status reports pending/running counts and no 'completed' key."""
    mock_client.get_queue.return_value = QueueStatus(pending=2, running=1)
    mock_get_client.return_value = mock_client

    result = json.loads(await get_queue_status())

    assert result["pending"] == 2
    assert result["running"] == 1
    assert "completed" not in result
@patch("src.server._get_client")
async def test_get_queue_status_empty(
    mock_get_client: AsyncMock, mock_client: AsyncMock
) -> None:
    """get_queue_status reports zero counts for an idle queue."""
    mock_client.get_queue.return_value = QueueStatus(pending=0, running=0)
    mock_get_client.return_value = mock_client

    result = json.loads(await get_queue_status())

    assert result["pending"] == 0
    assert result["running"] == 0
# --- get_generation ---
@patch("src.server._get_client")
async def test_get_generation_found(
    mock_get_client: AsyncMock, mock_client: AsyncMock
) -> None:
    """A history entry with image outputs is reported as completed with its filenames."""
    mock_client.get_history.return_value = {
        "outputs": {
            "9": {
                "images": [
                    {"filename": "ComfyUI-MCP_00001_.png", "subfolder": "", "type": "output"}
                ]
            }
        }
    }
    mock_get_client.return_value = mock_client

    result = json.loads(await get_generation(prompt_id="test-id-123"))

    assert result["prompt_id"] == "test-id-123"
    assert result["status"] == "completed"
    assert "ComfyUI-MCP_00001_.png" in result["images"]
@patch("src.server._get_client")
async def test_get_generation_not_found(
    mock_get_client: AsyncMock, mock_client: AsyncMock
) -> None:
    """A missing history entry (None) is reported with status 'not_found'."""
    mock_client.get_history.return_value = None
    mock_get_client.return_value = mock_client

    result = json.loads(await get_generation(prompt_id="missing-id"))

    assert result["prompt_id"] == "missing-id"
    assert result["status"] == "not_found"
@patch("src.server._get_client")
async def test_get_generation_pending_no_outputs(
    mock_get_client: AsyncMock, mock_client: AsyncMock
) -> None:
    """A history entry with empty outputs is reported as pending with no images."""
    mock_client.get_history.return_value = {"outputs": {}}
    mock_get_client.return_value = mock_client

    result = json.loads(await get_generation(prompt_id="pending-id"))

    assert result["prompt_id"] == "pending-id"
    assert result["status"] == "pending"
    assert result["images"] == []