"""Unit tests for batch generation tools."""
import os
import pytest
from unittest.mock import AsyncMock, MagicMock
from Imagen_MCP.tools.batch_generate import (
start_image_batch,
get_next_image,
get_batch_status,
StartBatchOutput,
GetNextImageOutput,
GetBatchStatusOutput,
)
from Imagen_MCP.services.session_manager import SessionManager
from Imagen_MCP.models.session import GenerationSession
from Imagen_MCP.exceptions import SessionNotFoundError, SessionExpiredError
# Base64 text of a tiny 1x1 PNG. Mocked ``get_next_image`` calls hand it back
# under the ``b64_json`` key, standing in for a real generated image.
TEST_B64_PNG = (
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ"
    "AAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
)
class TestStartImageBatchValidation:
    """Input-validation failures reported by start_image_batch."""

    @staticmethod
    async def _start(tmp_path, **overrides):
        """Call start_image_batch with the common test arguments plus ``overrides``."""
        call_kwargs = {
            "prompt": "A beautiful sunset",
            "output_dir": str(tmp_path),
            "count": 3,
        }
        call_kwargs.update(overrides)
        return await start_image_batch(**call_kwargs)

    @staticmethod
    def _assert_failed_with(outcome, fragment):
        """Check that ``outcome`` is a failure whose error text contains ``fragment``."""
        assert outcome.success is False
        assert outcome.error is not None
        assert fragment in outcome.error

    @pytest.mark.asyncio
    async def test_invalid_model_returns_error(self, tmp_path):
        """An unknown model name is rejected with an 'Invalid model' error."""
        outcome = await self._start(tmp_path, model="nonexistent-model")
        self._assert_failed_with(outcome, "Invalid model")

    @pytest.mark.asyncio
    async def test_invalid_size_returns_error(self, tmp_path):
        """A size the chosen model cannot produce is rejected."""
        # 1792x1024 is not offered by the fast model.
        outcome = await self._start(tmp_path, model="imagen-4-fast", size="1792x1024")
        self._assert_failed_with(outcome, "not supported")

    @pytest.mark.asyncio
    async def test_hd_quality_unsupported_returns_error(self, tmp_path):
        """Requesting HD on a model without HD support is rejected."""
        outcome = await self._start(tmp_path, model="imagen-4-fast", quality="hd")
        self._assert_failed_with(outcome, "HD quality not supported")
class TestStartImageBatchSuccess:
    """Tests for successful batch start."""

    # Magic bytes every PNG file begins with.
    _PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"

    @staticmethod
    def _make_manager(next_image, status, completed, pending):
        """Build a SessionManager mock for a 3-image session "test-session-123".

        ``next_image`` is what the mocked ``get_next_image`` resolves to (a dict
        with ``b64_json``, or ``None``); ``status``/``completed``/``pending``
        populate the dict returned by ``get_session_status``.
        """
        manager = MagicMock(spec=SessionManager)
        session = MagicMock(spec=GenerationSession)
        session.id = "test-session-123"
        manager.create_session.return_value = session
        manager.start_generation = AsyncMock()
        manager.get_next_image = AsyncMock(return_value=next_image)
        manager.get_session_status.return_value = {
            "session_id": "test-session-123",
            "status": status,
            "completed_count": completed,
            "pending_count": pending,
            "total_count": 3,
            "errors": [],
        }
        return manager

    @pytest.mark.asyncio
    async def test_start_batch_success(self, tmp_path):
        """start_image_batch should return session ID and first image file path."""
        mock_manager = self._make_manager(
            {"b64_json": TEST_B64_PNG}, status="partial", completed=1, pending=2
        )
        result = await start_image_batch(
            prompt="A beautiful sunset",
            output_dir=str(tmp_path),
            model="imagen-4",
            count=3,
            session_manager=mock_manager,
        )
        assert result.success is True
        assert result.session_id == "test-session-123"
        assert result.first_image_path is not None
        assert result.first_image_size_bytes is not None
        assert result.first_image_size_bytes > 0
        assert result.pending_count == 2

        # Verify the file was created inside output_dir.
        assert os.path.exists(result.first_image_path)
        out_dir = os.path.realpath(tmp_path)
        image_path = os.path.realpath(result.first_image_path)
        assert os.path.commonpath([image_path, out_dir]) == out_dir
        # The reported size must match what was actually written.
        assert os.path.getsize(result.first_image_path) == result.first_image_size_bytes
        # The content must be a decoded PNG, not the raw base64 text.
        with open(result.first_image_path, "rb") as image_file:
            assert image_file.read(len(self._PNG_SIGNATURE)) == self._PNG_SIGNATURE

    @pytest.mark.asyncio
    async def test_start_batch_first_image_timeout(self, tmp_path):
        """start_image_batch should handle timeout for first image."""
        # get_next_image resolving to None simulates no image arriving in time.
        mock_manager = self._make_manager(
            None, status="generating", completed=0, pending=3
        )
        result = await start_image_batch(
            prompt="A beautiful sunset",
            output_dir=str(tmp_path),
            model="imagen-4",
            count=3,
            session_manager=mock_manager,
        )
        assert result.success is False
        assert result.error is not None
        assert "timed out" in result.error
class TestGetNextImage:
    """Tests for get_next_image tool."""

    # Magic bytes every PNG file begins with.
    _PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"

    @staticmethod
    def _status(status, completed, pending):
        """Return a get_session_status payload for a 3-image "test-session"."""
        return {
            "session_id": "test-session",
            "status": status,
            "completed_count": completed,
            "pending_count": pending,
            "total_count": 3,
            "errors": [],
        }

    @pytest.mark.asyncio
    async def test_get_next_image_success(self, tmp_path):
        """get_next_image should save image to the requested file when available."""
        output_path = str(tmp_path / "image_002.png")
        mock_manager = MagicMock(spec=SessionManager)
        mock_manager.get_next_image = AsyncMock(return_value={"b64_json": TEST_B64_PNG})
        mock_manager.get_session_status.return_value = self._status("partial", 2, 1)
        result = await get_next_image(
            session_id="test-session",
            output_path=output_path,
            timeout=30.0,
            session_manager=mock_manager,
        )
        assert result.success is True
        assert result.file_path is not None
        assert result.file_size_bytes is not None
        assert result.file_size_bytes > 0
        assert result.has_more is True
        assert result.pending_count == 1

        # The image must land at the caller-supplied path.
        assert os.path.exists(result.file_path)
        assert os.path.realpath(result.file_path) == os.path.realpath(output_path)
        # The reported size must match the bytes on disk.
        assert os.path.getsize(result.file_path) == result.file_size_bytes
        # The content must be a decoded PNG, not the raw base64 text.
        with open(result.file_path, "rb") as image_file:
            assert image_file.read(len(self._PNG_SIGNATURE)) == self._PNG_SIGNATURE

    @pytest.mark.asyncio
    async def test_get_next_image_no_more_images(self, tmp_path):
        """get_next_image should indicate when no more images, without writing a file."""
        output_path = str(tmp_path / "image.png")
        mock_manager = MagicMock(spec=SessionManager)
        mock_manager.get_next_image = AsyncMock(return_value=None)
        mock_manager.get_session_status.return_value = self._status("completed", 3, 0)
        result = await get_next_image(
            session_id="test-session",
            output_path=output_path,
            session_manager=mock_manager,
        )
        assert result.success is True
        assert result.file_path is None
        assert result.has_more is False
        # With no image to save, nothing should be written.
        assert not os.path.exists(output_path)

    @pytest.mark.asyncio
    async def test_get_next_image_session_not_found(self, tmp_path):
        """get_next_image should handle session not found."""
        output_path = str(tmp_path / "image.png")
        mock_manager = MagicMock(spec=SessionManager)
        mock_manager.get_next_image = AsyncMock(
            side_effect=SessionNotFoundError("Session not found")
        )
        result = await get_next_image(
            session_id="nonexistent",
            output_path=output_path,
            session_manager=mock_manager,
        )
        assert result.success is False
        assert result.error is not None
        assert "not found" in result.error
        assert not os.path.exists(output_path)

    @pytest.mark.asyncio
    async def test_get_next_image_session_expired(self, tmp_path):
        """get_next_image should handle session expired."""
        output_path = str(tmp_path / "image.png")
        mock_manager = MagicMock(spec=SessionManager)
        mock_manager.get_next_image = AsyncMock(
            side_effect=SessionExpiredError("Session expired")
        )
        result = await get_next_image(
            session_id="expired-session",
            output_path=output_path,
            session_manager=mock_manager,
        )
        assert result.success is False
        assert result.error is not None
        assert "expired" in result.error
        assert not os.path.exists(output_path)
class TestGetBatchStatus:
    """Behaviour of the get_batch_status tool against a mocked SessionManager."""

    @staticmethod
    def _manager_reporting(**fields):
        """Mock whose get_session_status returns a "test-session" dict built from ``fields``."""
        manager = MagicMock(spec=SessionManager)
        manager.get_session_status.return_value = dict(session_id="test-session", **fields)
        return manager

    @pytest.mark.asyncio
    async def test_get_batch_status_success(self):
        """The tool mirrors the status and counters reported by the manager."""
        manager = self._manager_reporting(
            status="partial",
            completed_count=2,
            pending_count=1,
            total_count=3,
            errors=[],
        )
        report = await get_batch_status(session_id="test-session", session_manager=manager)

        assert report.success is True
        assert report.session_id == "test-session"
        assert report.status == "partial"
        assert (report.completed_count, report.pending_count, report.total_count) == (2, 1, 3)

    @pytest.mark.asyncio
    async def test_get_batch_status_with_errors(self):
        """Per-image errors from the manager are passed through to the caller."""
        manager = self._manager_reporting(
            status="completed",
            completed_count=2,
            pending_count=0,
            total_count=3,
            errors=[{"index": 2, "error": "Rate limit exceeded"}],
        )
        report = await get_batch_status(session_id="test-session", session_manager=manager)

        assert report.success is True
        assert len(report.errors) == 1
        assert report.errors[0]["error"] == "Rate limit exceeded"

    @pytest.mark.asyncio
    async def test_get_batch_status_session_not_found(self):
        """A missing session is reported as a failure instead of raising."""
        manager = MagicMock(spec=SessionManager)
        manager.get_session_status.side_effect = SessionNotFoundError("Session not found")
        report = await get_batch_status(session_id="nonexistent", session_manager=manager)

        assert report.success is False
        assert report.error is not None
        assert "not found" in report.error
class TestOutputModels:
    """Tests for output model schemas: constructor fields should read back unchanged."""

    def test_start_batch_output_success(self):
        """StartBatchOutput should handle success case."""
        output = StartBatchOutput(
            success=True,
            session_id="session-123",
            first_image_path="/path/to/image.png",
            first_image_size_bytes=12345,
            pending_count=2,
        )
        assert output.success is True
        assert output.session_id == "session-123"
        assert output.first_image_path == "/path/to/image.png"
        assert output.first_image_size_bytes == 12345
        # Previously passed to the constructor but never checked.
        assert output.pending_count == 2
        assert output.error is None

    def test_start_batch_output_error(self):
        """StartBatchOutput should handle error case."""
        output = StartBatchOutput(
            success=False,
            error="Something went wrong",
        )
        assert output.success is False
        assert output.session_id is None
        assert output.first_image_path is None
        assert output.error == "Something went wrong"

    def test_get_next_image_output(self):
        """GetNextImageOutput should handle all fields."""
        output = GetNextImageOutput(
            success=True,
            file_path="/path/to/image.png",
            file_size_bytes=12345,
            has_more=True,
            pending_count=2,
        )
        assert output.success is True
        assert output.file_path == "/path/to/image.png"
        assert output.file_size_bytes == 12345
        assert output.has_more is True
        assert output.pending_count == 2

    def test_get_batch_status_output(self):
        """GetBatchStatusOutput should handle all fields."""
        output = GetBatchStatusOutput(
            success=True,
            session_id="session-123",
            status="partial",
            completed_count=2,
            pending_count=1,
            total_count=3,
            errors=[],
        )
        assert output.success is True
        assert output.session_id == "session-123"
        assert output.status == "partial"
        assert output.completed_count == 2
        assert output.pending_count == 1
        assert output.total_count == 3
        assert output.errors == []