"""Batch image generation tools."""
from pathlib import Path
from pydantic import BaseModel, Field
from Imagen_MCP.constants import (
DEFAULT_MODEL,
DEFAULT_SIZE,
DEFAULT_QUALITY,
DEFAULT_STYLE,
DEFAULT_BATCH_COUNT,
MIN_BATCH_COUNT,
MAX_BATCH_COUNT,
DEFAULT_GENERATION_TIMEOUT,
MIN_TIMEOUT,
MAX_TIMEOUT,
)
from Imagen_MCP.exceptions import SessionNotFoundError, SessionExpiredError
from Imagen_MCP.services.model_registry import get_model_registry
from Imagen_MCP.services.session_manager import SessionManager, get_session_manager
from Imagen_MCP.utils import generate_temp_dir, generate_temp_path, save_image_to_file
class StartBatchInput(BaseModel):
    """Input schema for the ``start_image_batch`` tool.

    Defaults and bounds are taken from ``Imagen_MCP.constants`` so the schema
    matches the defaults of the ``start_image_batch`` function.
    """

    # Free-text prompt forwarded to the image model.
    prompt: str = Field(description="Text description of the image to generate")
    output_dir: str | None = Field(
        default=None,
        description="Directory where images will be saved. If not provided, a temporary directory will be created.",
    )
    # The advertised range is built from the same constants used for ge/le,
    # so the description can never disagree with what is actually validated.
    count: int = Field(
        default=DEFAULT_BATCH_COUNT,
        ge=MIN_BATCH_COUNT,
        le=MAX_BATCH_COUNT,
        description=f"Number of images to generate ({MIN_BATCH_COUNT}-{MAX_BATCH_COUNT})",
    )
    model: str = Field(
        default=DEFAULT_MODEL,
        description="Model to use for generation",
    )
    size: str = Field(
        default=DEFAULT_SIZE,
        description="Image size",
    )
    quality: str = Field(
        default=DEFAULT_QUALITY,
        description="Image quality (standard, hd)",
    )
    style: str = Field(
        default=DEFAULT_STYLE,
        description="Image style (vivid, natural)",
    )
class StartBatchOutput(BaseModel):
    """Result returned by the ``start_image_batch`` tool.

    ``success`` tells the caller whether the batch started; when it is False,
    ``error`` carries the reason. Optional fields default to ``None`` and the
    pending counter defaults to zero.
    """

    success: bool = Field(description="Whether the batch was started successfully")

    # Identifier used with get_next_image / get_batch_status.
    session_id: str | None = Field(
        None,
        description="Session ID for retrieving more images",
    )

    # Location and size of the image saved while the tool blocked.
    first_image_path: str | None = Field(
        None,
        description="Path to the first generated image file",
    )
    first_image_size_bytes: int | None = Field(
        None,
        description="Size of the first image file in bytes",
    )

    pending_count: int = Field(
        0,
        description="Number of images still being generated",
    )
    error: str | None = Field(
        None,
        description="Error message if batch failed to start",
    )
class GetNextImageInput(BaseModel):
    """Arguments accepted by the ``get_next_image`` tool."""

    session_id: str = Field(description="Session ID from start_image_batch")

    # Optional destination file; when omitted a temporary file is used.
    output_path: str | None = Field(
        None,
        description="File path where the image will be saved. If not provided, a temporary file will be created.",
    )

    # How long to block for an image, constrained to [MIN_TIMEOUT, MAX_TIMEOUT].
    timeout: float = Field(
        DEFAULT_GENERATION_TIMEOUT,
        description="Maximum time to wait for an image (seconds)",
        ge=MIN_TIMEOUT,
        le=MAX_TIMEOUT,
    )
class GetNextImageOutput(BaseModel):
    """Result returned by the ``get_next_image`` tool.

    ``file_path`` stays ``None`` whenever no image was written; callers should
    use ``has_more`` / ``pending_count`` to decide whether to poll again.
    """

    success: bool = Field(description="Whether an image was retrieved")

    # Where the retrieved image was written, and how large it is.
    file_path: str | None = Field(
        None,
        description="Path to the saved image file",
    )
    file_size_bytes: int | None = Field(
        None,
        description="Size of the saved image file in bytes",
    )

    # Progress indicators for the remainder of the batch.
    has_more: bool = Field(
        False,
        description="Whether more images are available or pending",
    )
    pending_count: int = Field(
        0,
        description="Number of images still being generated",
    )

    error: str | None = Field(
        None,
        description="Error message if retrieval failed",
    )
class GetBatchStatusInput(BaseModel):
    """Arguments accepted by the ``get_batch_status`` tool."""

    # Required: identifies which batch session to inspect.
    session_id: str = Field(
        description="Session ID from start_image_batch",
    )
class GetBatchStatusOutput(BaseModel):
    """Result returned by the ``get_batch_status`` tool.

    Counters default to zero and ``errors`` to a fresh empty list, so a failed
    lookup still yields a well-formed object with ``error`` set.
    """

    success: bool = Field(description="Whether status was retrieved successfully")
    session_id: str | None = Field(None, description="Session ID")
    status: str | None = Field(
        None,
        description="Session status (created, generating, partial, completed, failed)",
    )

    # Progress counters for the batch.
    completed_count: int = Field(0, description="Number of completed images")
    pending_count: int = Field(0, description="Number of pending images")
    total_count: int = Field(0, description="Total number of requested images")

    # default_factory gives each instance its own list (no shared mutable default).
    errors: list[dict] = Field(
        default_factory=list,
        description="List of errors encountered",
    )
    error: str | None = Field(
        None,
        description="Error message if status retrieval failed",
    )
# Seconds start_image_batch blocks waiting for the first image.
# NOTE(review): this is independent of DEFAULT_GENERATION_TIMEOUT — confirm
# whether the two are meant to be aligned.
_FIRST_IMAGE_TIMEOUT = 120.0


def _validate_batch_request(
    model: str, size: str, quality: str, count: int
) -> str | None:
    """Return an error message if the batch parameters are invalid, else ``None``."""
    registry = get_model_registry()
    if not registry.is_valid_model(model):
        return f"Invalid model: {model}. Available models: {[m.id for m in registry.get_all_models()]}"

    model_info = registry.get_model(model)
    if model_info and size not in model_info.capabilities.supported_sizes:
        return f"Size {size} not supported by {model}. Supported sizes: {model_info.capabilities.supported_sizes}"

    if (
        quality == "hd"
        and model_info
        and not model_info.capabilities.supports_hd_quality
    ):
        return f"HD quality not supported by {model}."

    # The Pydantic input schema enforces this range, but direct callers of the
    # function bypass it, so check it here as well.
    if not MIN_BATCH_COUNT <= count <= MAX_BATCH_COUNT:
        return (
            f"Invalid count: {count}. Must be between "
            f"{MIN_BATCH_COUNT} and {MAX_BATCH_COUNT}."
        )

    return None


async def start_image_batch(
    prompt: str,
    output_dir: str | None = None,
    count: int = DEFAULT_BATCH_COUNT,
    model: str = DEFAULT_MODEL,
    size: str = DEFAULT_SIZE,
    quality: str = DEFAULT_QUALITY,
    style: str = DEFAULT_STYLE,
    session_manager: SessionManager | None = None,
) -> StartBatchOutput:
    """Start a batch image generation.

    Validates the request, creates a session, and starts background
    generation. It then blocks until the first image is ready (or the
    first-image timeout elapses) and saves that image as ``image_001.png``.
    The result includes the file path and a session ID for fetching the rest.

    Args:
        prompt: Text description of the image to generate.
        output_dir: Directory where images will be saved. If not provided,
            a temporary directory is created, but only once there is an
            image to save.
        count: Number of images to generate, between MIN_BATCH_COUNT and
            MAX_BATCH_COUNT.
        model: Model to use for generation.
        size: Image size; must be supported by ``model``.
        quality: Image quality; ``"hd"`` requires model support.
        style: Image style.
        session_manager: Optional SessionManager instance (for testing).

    Returns:
        StartBatchOutput with session ID and first image file path, or
        ``success=False`` and an ``error`` message. Errors are reported in
        the result rather than raised.
    """
    validation_error = _validate_batch_request(model, size, quality, count)
    if validation_error is not None:
        return StartBatchOutput(success=False, error=validation_error)

    if session_manager is None:
        session_manager = get_session_manager()

    try:
        session = session_manager.create_session(
            prompt=prompt,
            model=model,
            count=count,
            size=size,
            quality=quality,
            style=style,
        )

        # Kick off background generation, then block for the first result.
        await session_manager.start_generation(session.id)
        first_image = await session_manager.get_next_image(
            session.id, timeout=_FIRST_IMAGE_TIMEOUT
        )

        if first_image is None:
            # Distinguish a generation failure from a plain timeout.
            status = session_manager.get_session_status(session.id)
            errors = status.get("errors")
            if errors:
                return StartBatchOutput(
                    success=False,
                    session_id=session.id,
                    error=f"First image generation failed: {errors[0].get('error', 'Unknown error')}",
                )
            return StartBatchOutput(
                success=False,
                session_id=session.id,
                error="First image generation timed out",
            )

        b64_data = first_image.get("b64_json")
        if not b64_data:
            return StartBatchOutput(
                success=False,
                session_id=session.id,
                error="First image has no data",
            )

        try:
            # Create the temp directory only now, so an earlier failure or
            # timeout doesn't leave an empty directory behind.
            target_dir = output_dir if output_dir is not None else generate_temp_dir()
            output_path = Path(target_dir) / "image_001.png"
            file_path, file_size = save_image_to_file(b64_data, str(output_path))
        except Exception as e:
            return StartBatchOutput(
                success=False,
                session_id=session.id,
                error=f"Failed to save first image: {e}",
            )

        # Refresh status so pending_count reflects the image just consumed.
        status = session_manager.get_session_status(session.id)
        return StartBatchOutput(
            success=True,
            session_id=session.id,
            first_image_path=file_path,
            first_image_size_bytes=file_size,
            pending_count=status.get("pending_count", 0),
        )
    except ValueError as e:
        return StartBatchOutput(
            success=False,
            error=str(e),
        )
    except Exception as e:
        return StartBatchOutput(
            success=False,
            error=f"Failed to start batch: {e}",
        )
async def get_next_image(
    session_id: str,
    output_path: str | None = None,
    timeout: float = DEFAULT_GENERATION_TIMEOUT,
    session_manager: SessionManager | None = None,
) -> GetNextImageOutput:
    """Get the next available image from a batch generation session.

    If an image is already queued, it is saved and returned immediately.
    Otherwise the call blocks until an image arrives or ``timeout`` elapses.

    Args:
        session_id: Session ID from start_image_batch.
        output_path: File path where the image will be saved. If not provided
            (or empty), a temporary path is generated, but only when there
            is image data to write.
        timeout: Maximum time to wait for an image (seconds).
        session_manager: Optional SessionManager instance (for testing).

    Returns:
        GetNextImageOutput. ``success=True`` with ``file_path=None`` and
        ``has_more=False`` means the session is completed/failed and nothing
        is left. Errors are reported in the result rather than raised.
    """
    if session_manager is None:
        session_manager = get_session_manager()

    try:
        image = await session_manager.get_next_image(session_id, timeout)

        # Read status after the wait so counts reflect the image just taken.
        status = session_manager.get_session_status(session_id)
        pending = status.get("pending_count", 0)
        session_status = status.get("status", "unknown")

        # One rule for "more to come", shared by every non-terminal response,
        # so a session that is still generating is never reported as exhausted.
        has_more = pending > 0 or session_status in ("generating", "partial")

        if image is None:
            if session_status in ("completed", "failed"):
                # Terminal session with nothing left to hand out.
                return GetNextImageOutput(
                    success=True,
                    file_path=None,
                    has_more=False,
                    pending_count=0,
                )
            return GetNextImageOutput(
                success=False,
                has_more=has_more,
                pending_count=pending,
                error="Timeout waiting for next image",
            )

        b64_data = image.get("b64_json")
        if not b64_data:
            return GetNextImageOutput(
                success=False,
                has_more=has_more,
                pending_count=pending,
                error="Image has no data",
            )

        try:
            # Resolve the destination only once we know there is data to write.
            actual_output_path = output_path if output_path else generate_temp_path()
            file_path, file_size = save_image_to_file(b64_data, actual_output_path)
        except Exception as e:
            return GetNextImageOutput(
                success=False,
                has_more=has_more,
                pending_count=pending,
                error=f"Failed to save image: {e}",
            )

        return GetNextImageOutput(
            success=True,
            file_path=file_path,
            file_size_bytes=file_size,
            has_more=has_more,
            pending_count=pending,
        )
    except SessionNotFoundError:
        return GetNextImageOutput(
            success=False,
            error=f"Session not found: {session_id}",
        )
    except SessionExpiredError:
        return GetNextImageOutput(
            success=False,
            error=f"Session expired: {session_id}",
        )
    except Exception as e:
        return GetNextImageOutput(
            success=False,
            error=f"Failed to get next image: {e}",
        )
async def get_batch_status(
    session_id: str,
    session_manager: SessionManager | None = None,
) -> GetBatchStatusOutput:
    """Report the progress of a batch generation session.

    Args:
        session_id: Session ID from start_image_batch.
        session_manager: Optional SessionManager instance (for testing);
            the shared manager from ``get_session_manager()`` is used otherwise.

    Returns:
        GetBatchStatusOutput populated from the session's status dict. Lookup
        failures (missing or expired session, or any other exception) are
        returned with ``success=False`` and an ``error`` message rather than
        raised.
    """
    manager = session_manager if session_manager is not None else get_session_manager()

    try:
        snapshot = manager.get_session_status(session_id)
        # Counters that fall back to 0 when absent from the snapshot.
        counters = {
            field_name: snapshot.get(field_name, 0)
            for field_name in ("completed_count", "pending_count", "total_count")
        }
        return GetBatchStatusOutput(
            success=True,
            session_id=snapshot.get("session_id"),
            status=snapshot.get("status"),
            errors=snapshot.get("errors", []),
            **counters,
        )
    except SessionNotFoundError:
        failure = f"Session not found: {session_id}"
    except SessionExpiredError:
        failure = f"Session expired: {session_id}"
    except Exception as e:
        failure = f"Failed to get batch status: {e}"

    return GetBatchStatusOutput(success=False, error=failure)