"""Unit tests for the generate_image tool."""
import os
import pytest
from unittest.mock import AsyncMock, MagicMock
from Imagen_MCP.tools.generate_image import generate_image, GenerateImageOutput
from Imagen_MCP.models.image import GeneratedImage
from Imagen_MCP.models.generation import GenerateImageResponse
from Imagen_MCP.exceptions import (
InvalidRequestError,
RateLimitError,
AuthenticationError,
)
class TestGenerateImageValidation:
    """Tests for input validation.

    Every test injects a mock client and checks afterwards that it was never
    awaited. This proves invalid input is rejected before any API request is
    made. It also guarantees these unit tests can never reach the real
    service, even if validation ordering regresses.
    """

    @staticmethod
    def _unused_client():
        """Return a mock client whose ``generate_image`` must not be awaited."""
        client = MagicMock()
        client.generate_image = AsyncMock()
        return client

    @pytest.mark.asyncio
    async def test_invalid_model_returns_error(self, tmp_path):
        """generate_image should return error for invalid model."""
        output_path = str(tmp_path / "test.png")
        client = self._unused_client()
        result = await generate_image(
            prompt="A beautiful sunset",
            output_path=output_path,
            model="nonexistent-model",
            client=client,
        )
        assert result.success is False
        assert result.error is not None
        assert "Invalid model" in result.error
        assert result.model_used == "nonexistent-model"
        client.generate_image.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_invalid_size_for_model_returns_error(self, tmp_path):
        """generate_image should return error for unsupported size."""
        output_path = str(tmp_path / "test.png")
        client = self._unused_client()
        result = await generate_image(
            prompt="A beautiful sunset",
            output_path=output_path,
            model="imagen-4-fast",  # Doesn't support 1792x1024
            size="1792x1024",
            client=client,
        )
        assert result.success is False
        assert result.error is not None
        assert "not supported" in result.error
        assert "1792x1024" in result.error
        client.generate_image.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_hd_quality_unsupported_model_returns_error(self, tmp_path):
        """generate_image should return error for HD on unsupported model."""
        output_path = str(tmp_path / "test.png")
        client = self._unused_client()
        result = await generate_image(
            prompt="A beautiful sunset",
            output_path=output_path,
            model="imagen-4-fast",  # Doesn't support HD
            quality="hd",
            client=client,
        )
        assert result.success is False
        assert result.error is not None
        assert "HD quality not supported" in result.error
        client.generate_image.assert_not_awaited()
# Minimal valid PNG (1x1 pixel), base64-encoded; used as the ``b64_json``
# payload of mocked API responses. The adjacent literals are joined by the
# compiler into a single string -- the split is only for line length.
TEST_B64_PNG = (
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ"
    "AAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
)
class TestGenerateImageSuccess:
    """Tests for successful image generation (and the URL-only rejection)."""

    @staticmethod
    def _client_returning(*images):
        """Build a mock client whose ``generate_image`` resolves to *images*.

        The images are wrapped in a ``GenerateImageResponse`` with a fixed
        ``created`` timestamp.
        """
        client = MagicMock()
        client.generate_image = AsyncMock(
            return_value=GenerateImageResponse(
                created=1234567890,
                images=list(images),
            )
        )
        return client

    @pytest.mark.asyncio
    async def test_generate_image_success_with_b64(self, tmp_path):
        """generate_image should save image to file on success."""
        output_path = str(tmp_path / "test.png")
        mock_client = self._client_returning(
            GeneratedImage(
                b64_json=TEST_B64_PNG,
                revised_prompt="A beautiful sunset over the ocean",
            )
        )
        result = await generate_image(
            prompt="A beautiful sunset",
            output_path=output_path,
            model="imagen-4",
            client=mock_client,
        )
        assert result.success is True
        assert result.file_path is not None
        assert result.file_size_bytes is not None
        assert result.file_size_bytes > 0
        assert result.revised_prompt == "A beautiful sunset over the ocean"
        assert result.model_used == "imagen-4"
        assert result.error is None
        mock_client.generate_image.assert_awaited_once()
        # Verify the file was created and the reported size matches the disk.
        file_path = result.file_path
        assert os.path.exists(file_path)
        assert result.file_size_bytes == os.path.getsize(file_path)

    @pytest.mark.asyncio
    async def test_generate_image_url_only_returns_error(self, tmp_path):
        """generate_image should return error when only URL is provided (no b64)."""
        output_path = str(tmp_path / "test.png")
        mock_client = self._client_returning(
            GeneratedImage(url="https://example.com/image.png")
        )
        result = await generate_image(
            prompt="A beautiful sunset",
            output_path=output_path,
            model="imagen-4",
            client=mock_client,
        )
        assert result.success is False
        assert result.error is not None
        assert "URL-based responses not supported" in result.error
        # A rejected response must not leave a file behind.
        assert not os.path.exists(output_path)

    @pytest.mark.asyncio
    async def test_generate_image_with_all_parameters(self, tmp_path):
        """generate_image should pass all parameters to client."""
        output_path = str(tmp_path / "test.png")
        mock_client = self._client_returning(GeneratedImage(b64_json=TEST_B64_PNG))
        result = await generate_image(
            prompt="A mountain landscape",
            output_path=output_path,
            model="imagen-4",
            size="1792x1024",
            quality="hd",
            style="natural",
            client=mock_client,
        )
        assert result.success is True
        # assert_awaited_once (not just called) proves the coroutine ran.
        mock_client.generate_image.assert_awaited_once()
        request = mock_client.generate_image.call_args.args[0]
        assert request.prompt == "A mountain landscape"
        # The model in the request should be the API ID, not the internal ID.
        assert request.model == "Imagen 4 (Public)"
        assert request.size == "1792x1024"
        assert request.quality == "hd"
        assert request.style == "natural"

    @pytest.mark.asyncio
    async def test_generate_image_creates_parent_directories(self, tmp_path):
        """generate_image should create parent directories if they don't exist."""
        output_path = str(tmp_path / "nested" / "dir" / "test.png")
        mock_client = self._client_returning(GeneratedImage(b64_json=TEST_B64_PNG))
        result = await generate_image(
            prompt="A beautiful sunset",
            output_path=output_path,
            model="imagen-4",
            client=mock_client,
        )
        assert result.success is True
        assert os.path.isdir(tmp_path / "nested" / "dir")
        file_path = result.file_path
        assert file_path is not None
        assert os.path.exists(file_path)
class TestGenerateImageErrorHandling:
    """Tests for how generate_image reports client failures and bad responses."""

    @staticmethod
    def _client_raising(exc):
        """Return a mock client whose ``generate_image`` raises *exc* when awaited."""
        client = MagicMock()
        client.generate_image = AsyncMock(side_effect=exc)
        return client

    @staticmethod
    async def _generate(tmp_path, client):
        """Run generate_image with the standard prompt/model against *client*."""
        return await generate_image(
            prompt="A beautiful sunset",
            output_path=str(tmp_path / "test.png"),
            model="imagen-4",
            client=client,
        )

    @pytest.mark.asyncio
    async def test_generate_image_invalid_request_error(self, tmp_path):
        """An InvalidRequestError from the client surfaces as 'Invalid request'."""
        client = self._client_raising(InvalidRequestError("Invalid prompt content"))
        outcome = await self._generate(tmp_path, client)
        assert outcome.success is False
        assert outcome.error is not None
        assert "Invalid request" in outcome.error
        assert outcome.model_used == "imagen-4"

    @pytest.mark.asyncio
    async def test_generate_image_rate_limit_error(self, tmp_path):
        """A RateLimitError from the client surfaces as 'Generation failed'."""
        client = self._client_raising(RateLimitError("Rate limit exceeded"))
        outcome = await self._generate(tmp_path, client)
        assert outcome.success is False
        assert outcome.error is not None
        assert "Generation failed" in outcome.error
        assert outcome.model_used == "imagen-4"

    @pytest.mark.asyncio
    async def test_generate_image_authentication_error(self, tmp_path):
        """An AuthenticationError from the client surfaces as 'Generation failed'."""
        client = self._client_raising(AuthenticationError("Invalid API key"))
        outcome = await self._generate(tmp_path, client)
        assert outcome.success is False
        assert outcome.error is not None
        assert "Generation failed" in outcome.error

    @pytest.mark.asyncio
    async def test_generate_image_empty_response(self, tmp_path):
        """A response with zero images is reported as 'No images returned'."""
        client = MagicMock()
        client.generate_image = AsyncMock(
            return_value=GenerateImageResponse(created=1234567890, images=[])
        )
        outcome = await self._generate(tmp_path, client)
        assert outcome.success is False
        assert outcome.error is not None
        assert "No images returned" in outcome.error

    @pytest.mark.asyncio
    async def test_generate_image_generic_exception(self, tmp_path):
        """Any other exception is reported with its message preserved."""
        client = self._client_raising(Exception("Unexpected error"))
        outcome = await self._generate(tmp_path, client)
        assert outcome.success is False
        assert outcome.error is not None
        assert "Generation failed" in outcome.error
        assert "Unexpected error" in outcome.error
class TestGenerateImageOutput:
    """Tests for constructing the GenerateImageOutput result model directly."""

    def test_output_with_all_fields(self):
        """Every field passed to GenerateImageOutput is stored unchanged."""
        supplied = {
            "file_path": "/path/to/image.png",
            "file_size_bytes": 12345,
            "model_used": "imagen-4",
            "revised_prompt": "Revised prompt",
        }
        output = GenerateImageOutput(success=True, error=None, **supplied)
        assert output.success is True
        for field_name, expected in supplied.items():
            assert getattr(output, field_name) == expected
        assert output.error is None

    def test_output_error_case(self):
        """Omitted file fields default to None on a failure result."""
        output = GenerateImageOutput(
            success=False,
            model_used="imagen-4",
            error="Something went wrong",
        )
        assert output.success is False
        for omitted in ("file_path", "file_size_bytes"):
            assert getattr(output, omitted) is None
        assert output.error == "Something went wrong"