"""Unit tests for Pydantic models."""
import pytest
from pydantic import ValidationError
from Imagen_MCP.models.image import GeneratedImage, ImageGenerationError
from Imagen_MCP.models.generation import (
GenerateImageRequest,
GenerateImageResponse,
StartBatchRequest,
GetNextImageRequest,
BatchStatusResponse,
)
from Imagen_MCP.models.session import GenerationSession, SessionStatus
class TestGeneratedImage:
    """Tests for GeneratedImage model."""

    def test_generated_image_with_b64_json(self):
        """Image should be created with base64 data and keep its revised prompt."""
        image = GeneratedImage(
            b64_json="base64data",
            revised_prompt="A beautiful sunset",
        )
        assert image.b64_json == "base64data"
        # revised_prompt was supplied, so it must round-trip too.
        assert image.revised_prompt == "A beautiful sunset"
        assert image.has_data is True

    def test_generated_image_with_url(self):
        """Image should be created with URL and keep its revised prompt."""
        image = GeneratedImage(
            url="https://example.com/image.png",
            revised_prompt="A beautiful sunset",
        )
        assert image.url == "https://example.com/image.png"
        assert image.revised_prompt == "A beautiful sunset"
        assert image.has_data is True

    def test_generated_image_without_data(self):
        """Image without data should have has_data=False."""
        image = GeneratedImage()
        assert image.has_data is False

    def test_generated_image_with_metadata(self):
        """Image should store generation metadata alongside its payload."""
        image = GeneratedImage(
            b64_json="data",
            model="imagen-4",
            size="1024x1024",
            quality="hd",
            style="vivid",
        )
        assert image.model == "imagen-4"
        assert image.size == "1024x1024"
        assert image.quality == "hd"
        assert image.style == "vivid"
        # Metadata must not interfere with payload detection.
        assert image.has_data is True
class TestImageGenerationError:
    """Behavioural checks for the ImageGenerationError model."""

    def test_error_creation(self):
        """Every required field passed at construction is exposed unchanged."""
        fields = {
            "index": 0,
            "error": "Generation failed",
            "error_type": "server_error",
        }
        err = ImageGenerationError(**fields)
        for attr_name, expected in fields.items():
            assert getattr(err, attr_name) == expected
class TestGenerateImageRequest:
    """Behavioural checks for the GenerateImageRequest model."""

    def test_request_with_defaults(self):
        """Only a prompt is required; every other field falls back to a default."""
        req = GenerateImageRequest(prompt="A cat")
        expected_defaults = {
            "prompt": "A cat",
            "model": "imagen-4",
            "size": "1024x1024",
            "quality": "standard",
            "style": "vivid",
            "n": 1,
        }
        for attr_name, value in expected_defaults.items():
            assert getattr(req, attr_name) == value

    def test_request_with_custom_values(self):
        """Explicitly supplied values override the defaults verbatim."""
        overrides = {
            "prompt": "A dog",
            "model": "imagen-4-ultra",
            "size": "1792x1024",
            "quality": "hd",
            "style": "natural",
            "n": 3,
        }
        req = GenerateImageRequest(**overrides)
        for attr_name, value in overrides.items():
            assert getattr(req, attr_name) == value

    def test_request_validation_empty_prompt(self):
        """An empty prompt is rejected at validation time."""
        with pytest.raises(ValidationError):
            GenerateImageRequest(prompt="")

    def test_request_validation_n_range(self):
        """Image counts outside the permitted range are rejected."""
        for bad_n in (0, 11):
            with pytest.raises(ValidationError):
                GenerateImageRequest(prompt="test", n=bad_n)

    def test_request_to_api_payload(self):
        """to_api_payload exposes each request field under its own key."""
        params = {
            "prompt": "A cat",
            "model": "imagen-4",
            "size": "1024x1024",
            "quality": "standard",
            "style": "vivid",
            "n": 2,
        }
        body = GenerateImageRequest(**params).to_api_payload()
        for key, value in params.items():
            assert body[key] == value
class TestGenerateImageResponse:
    """Tests for GenerateImageResponse model."""

    def test_response_from_api_response(self):
        """Response should parse API response correctly."""
        api_response = {
            "created": 1734800000,
            "data": [
                {
                    "b64_json": "imagedata",
                    "revised_prompt": "A beautiful cat",
                }
            ],
        }
        request = GenerateImageRequest(prompt="A cat", model="imagen-4")
        response = GenerateImageResponse.from_api_response(api_response, request)
        assert response.created == 1734800000
        assert len(response.images) == 1
        assert response.images[0].b64_json == "imagedata"
        assert response.images[0].revised_prompt == "A beautiful cat"
        # The model comes from the originating request, not the API payload.
        assert response.images[0].model == "imagen-4"

    def test_response_from_api_response_multiple_images(self):
        """Response should keep every image, in the order the API returned them."""
        api_response = {
            "created": 1734800000,
            "data": [
                {"b64_json": "image1"},
                {"b64_json": "image2"},
                {"b64_json": "image3"},
            ],
        }
        response = GenerateImageResponse.from_api_response(api_response)
        assert len(response.images) == 3
        # A count alone would not catch dropped/duplicated/reordered entries.
        assert [img.b64_json for img in response.images] == [
            "image1",
            "image2",
            "image3",
        ]
class TestStartBatchRequest:
    """Tests for StartBatchRequest model."""

    def test_batch_request_defaults(self):
        """Batch request should have sensible defaults."""
        request = StartBatchRequest(prompt="A landscape")
        assert request.prompt == "A landscape"
        assert request.count == 4
        assert request.model == "imagen-4"

    def test_batch_request_count_validation(self):
        """Batch request should reject counts outside the allowed range."""
        with pytest.raises(ValidationError):
            StartBatchRequest(prompt="test", count=1)  # min is 2
        with pytest.raises(ValidationError):
            StartBatchRequest(prompt="test", count=11)  # max is 10

    def test_batch_request_count_boundaries_accepted(self):
        """Counts exactly at the limits (2 and 10) should be accepted."""
        # Pins inclusivity of the bounds, which the rejection test cannot detect.
        assert StartBatchRequest(prompt="test", count=2).count == 2
        assert StartBatchRequest(prompt="test", count=10).count == 10
class TestGetNextImageRequest:
    """Behavioural checks for the GetNextImageRequest model."""

    def test_get_next_image_request(self):
        """Only a session ID is required; the timeout defaults to 60 seconds."""
        req = GetNextImageRequest(session_id="abc123")
        assert (req.session_id, req.timeout) == ("abc123", 60.0)

    def test_get_next_image_custom_timeout(self):
        """An explicit timeout replaces the default."""
        custom = 30.0
        req = GetNextImageRequest(session_id="abc123", timeout=custom)
        assert req.timeout == custom
class TestBatchStatusResponse:
    """Tests for BatchStatusResponse model."""

    def test_batch_status_response(self):
        """Status response should contain all fields, including errors."""
        response = BatchStatusResponse(
            session_id="abc123",
            status="generating",
            completed_count=2,
            pending_count=3,
            total_count=5,
            errors=[],
        )
        assert response.session_id == "abc123"
        assert response.status == "generating"
        assert response.completed_count == 2
        assert response.pending_count == 3
        assert response.total_count == 5
        # errors was supplied but previously never checked.
        assert response.errors == []
class TestSessionStatus:
    """Behavioural checks for the SessionStatus enum."""

    def test_session_status_values(self):
        """Each lifecycle member maps to its lowercase string value."""
        expected_values = {
            SessionStatus.CREATED: "created",
            SessionStatus.GENERATING: "generating",
            SessionStatus.PARTIAL: "partial",
            SessionStatus.COMPLETED: "completed",
            SessionStatus.FAILED: "failed",
        }
        for member, raw in expected_values.items():
            assert member.value == raw
class TestGenerationSession:
    """Behavioural checks for the GenerationSession model."""

    @staticmethod
    def _new_session(requested_count, **extra):
        """Build a session with the shared id/prompt/model used by these tests."""
        return GenerationSession(
            id="session123",
            prompt="A mountain",
            model="imagen-4",
            requested_count=requested_count,
            **extra,
        )

    def test_session_creation(self):
        """A fresh session stores its inputs and starts in CREATED with no progress."""
        sess = self._new_session(4)
        assert sess.id == "session123"
        assert sess.prompt == "A mountain"
        assert sess.model == "imagen-4"
        assert sess.requested_count == 4
        assert sess.status == SessionStatus.CREATED
        assert (sess.completed_count, sess.pending_count) == (0, 0)

    def test_session_add_image(self):
        """add_image records the image and moves one unit from pending to completed."""
        sess = self._new_session(4, pending_count=4)
        sess.add_image({"b64_json": "imagedata"})
        assert sess.completed_count == 1
        assert sess.pending_count == 3
        assert len(sess.completed_images) == 1

    def test_session_add_error(self):
        """add_error records the failure and reduces the pending count by one."""
        sess = self._new_session(4, pending_count=4)
        sess.add_error(0, "Generation failed", "server_error")
        assert len(sess.errors) == 1
        first_error = sess.errors[0]
        assert first_error["error"] == "Generation failed"
        assert sess.pending_count == 3

    def test_session_completion_status(self):
        """Status is PARTIAL after the first image and COMPLETED after the last."""
        sess = self._new_session(2, pending_count=2)
        sess.add_image({"b64_json": "image1"})
        assert sess.status == SessionStatus.PARTIAL
        sess.add_image({"b64_json": "image2"})
        assert sess.status == SessionStatus.COMPLETED

    def test_session_is_complete(self):
        """is_complete follows the session's status field."""
        sess = self._new_session(2)
        assert sess.is_complete is False
        sess.status = SessionStatus.COMPLETED
        assert sess.is_complete is True

    def test_session_get_next_image(self):
        """get_next_image yields images in insertion order, then None."""
        sess = self._new_session(2, pending_count=2)
        payloads = ("image1", "image2")
        for payload in payloads:
            sess.add_image({"b64_json": payload})
        for payload in payloads:
            nxt = sess.get_next_image()
            assert nxt is not None
            assert nxt["b64_json"] == payload
        assert sess.get_next_image() is None

    def test_session_to_status_dict(self):
        """to_status_dict summarises identity, status and progress counters."""
        sess = self._new_session(4, pending_count=4)
        for payload in ("image1", "image2"):
            sess.add_image({"b64_json": payload})
        summary = sess.to_status_dict()
        expected = {
            "session_id": "session123",
            "status": "partial",
            "completed_count": 2,
            # pending_count drops by one per add_image call: 4 - 2 = 2.
            "pending_count": 2,
            "total_count": 4,
        }
        for key, value in expected.items():
            assert summary[key] == value