"""Generation request/response models."""
from typing import Literal, get_args
from pydantic import BaseModel, Field
from Imagen_MCP.models.image import GeneratedImage
# Accepted values for the ``size`` parameter of image generation.
ImageSize = Literal[
    "256x256",
    "512x512",
    "1024x1024",
    "1792x1024",
    "1024x1792",
]

# Accepted values for the ``quality`` parameter.
ImageQuality = Literal["standard", "hd"]

# Accepted values for the ``style`` parameter.
ImageStyle = Literal["vivid", "natural"]

# Accepted values for the ``response_format`` parameter.
ResponseFormat = Literal["url", "b64_json"]
def validate_image_params(
    size: str, quality: str, style: str
) -> tuple[ImageSize, ImageQuality, ImageStyle]:
    """Check size, quality and style against their Literal aliases.

    Args:
        size: Requested image size, e.g. ``"1024x1024"``.
        quality: Requested image quality, e.g. ``"standard"``.
        style: Requested image style, e.g. ``"vivid"``.

    Returns:
        The unchanged ``(size, quality, style)`` values, typed as
        ``ImageSize``, ``ImageQuality`` and ``ImageStyle``.

    Raises:
        ValueError: For the first value (checked in the order size, quality,
            style) that is not among its alias's options; the message lists
            the allowed options.
    """
    checks = (
        ("size", size, get_args(ImageSize)),
        ("quality", quality, get_args(ImageQuality)),
        ("style", style, get_args(ImageStyle)),
    )
    for label, value, allowed in checks:
        if value not in allowed:
            raise ValueError(f"Invalid {label}: {value}. Must be one of: {allowed}")
    # All three values passed their membership checks, so the narrower
    # Literal types hold even though the parameters are typed as ``str``.
    return size, quality, style  # type: ignore[return-value]
class GenerateImageRequest(BaseModel):
    """Request model for image generation."""

    # Prompt text; validation rejects an empty string.
    prompt: str = Field(
        min_length=1, description="Text description of the image to generate"
    )
    # Model name, forwarded as-is in the API payload.
    model: str = Field(description="Model to use for generation", default="imagen-4")
    # Image count, constrained to 1..10 inclusive.
    n: int = Field(le=10, ge=1, default=1, description="Number of images to generate")
    size: ImageSize = Field(
        description="Size of the generated image", default="1024x1024"
    )
    quality: ImageQuality = Field(
        description="Quality of the generated image", default="standard"
    )
    style: ImageStyle = Field(
        description="Style of the generated image", default="vivid"
    )
    response_format: ResponseFormat = Field(
        description="Format of the response", default="b64_json"
    )

    def to_api_payload(self) -> dict:
        """Build the request body for the API call.

        Returns:
            A new dict with the keys ``prompt``, ``model``, ``n``, ``size``,
            ``quality``, ``style`` and ``response_format`` (in that order),
            each mapped to the current value of the same-named field.
        """
        payload_keys = (
            "prompt",
            "model",
            "n",
            "size",
            "quality",
            "style",
            "response_format",
        )
        return {key: getattr(self, key) for key in payload_keys}
class GenerateImageResponse(BaseModel):
    """Response model for image generation."""

    # Required creation timestamp (see the field description).
    created: int = Field(description="Unix timestamp of when the images were created")
    # Defaults to a fresh empty list per instance.
    images: list[GeneratedImage] = Field(
        default_factory=list, description="List of generated images"
    )

    @classmethod
    def from_api_response(
        cls, response: dict, request: GenerateImageRequest | None = None
    ) -> "GenerateImageResponse":
        """Create from API response.

        Args:
            response: Raw API response dict. The optional ``"created"`` key
                and the optional ``"data"`` list are read; each item of
                ``"data"`` may carry ``"b64_json"``, ``"url"`` and
                ``"revised_prompt"``.
            request: The request that produced ``response``. When given, its
                ``model``, ``size``, ``quality`` and ``style`` are copied onto
                every returned image; otherwise those image fields are None.

        Returns:
            A new ``GenerateImageResponse``. ``created`` falls back to 0 and
            ``images`` to an empty list when the corresponding key is missing
            or holds a null value.
        """
        # The request metadata is the same for every image, so build it once
        # instead of re-evaluating four conditionals per item.
        if request is None:
            metadata = {"model": None, "size": None, "quality": None, "style": None}
        else:
            metadata = {
                "model": request.model,
                "size": request.size,
                "quality": request.quality,
                "style": request.style,
            }

        # ``or`` (rather than a ``.get`` default) also covers keys that are
        # present with a JSON null: iterating None would raise TypeError, and
        # None would fail validation of the int ``created`` field.
        items = response.get("data") or []
        images = [
            GeneratedImage(
                b64_json=item.get("b64_json"),
                url=item.get("url"),
                revised_prompt=item.get("revised_prompt"),
                **metadata,
            )
            for item in items
        ]
        return cls(created=response.get("created") or 0, images=images)
class StartBatchRequest(BaseModel):
    """Request model for starting a batch image generation."""

    # Prompt text for the batch; validation rejects an empty string.
    prompt: str = Field(
        min_length=1, description="Text description of the image to generate"
    )
    # Number of variations, constrained to 2..10 inclusive.
    count: int = Field(
        ge=2, le=10, default=4, description="Number of image variations to generate"
    )
    model: str = Field(description="Model to use for generation", default="imagen-4")
    size: ImageSize = Field(
        description="Size of the generated images", default="1024x1024"
    )
    quality: ImageQuality = Field(
        description="Quality of the generated images", default="standard"
    )
    style: ImageStyle = Field(
        description="Style of the generated images", default="vivid"
    )
class GetNextImageRequest(BaseModel):
    """Request model for getting the next image from a batch."""

    # Required; the session identifier issued by start_image_batch.
    session_id: str = Field(description="Session ID from start_image_batch")
    # Wait limit in seconds; must be non-negative.
    timeout: float = Field(
        description="Maximum seconds to wait for next image", ge=0, default=60.0
    )
class BatchStatusResponse(BaseModel):
    """Response model for batch generation status."""

    # Required identifying/status fields (no defaults).
    session_id: str = Field(description="Session ID")
    # Plain string; the set of possible statuses is not constrained here.
    status: str = Field(description="Current session status")
    # Required image counters for the session.
    completed_count: int = Field(description="Number of completed images")
    pending_count: int = Field(description="Number of pending images")
    total_count: int = Field(description="Total number of images requested")
    # Defaults to a fresh empty list per instance.
    errors: list[dict] = Field(
        description="List of errors encountered", default_factory=list
    )