"""Simple image generation tool."""
from pydantic import BaseModel, Field
from Imagen_MCP.constants import (
DEFAULT_MODEL,
DEFAULT_SIZE,
DEFAULT_QUALITY,
DEFAULT_STYLE,
)
from Imagen_MCP.exceptions import InvalidRequestError
from Imagen_MCP.models.generation import (
GenerateImageRequest,
validate_image_params,
)
from Imagen_MCP.services.model_registry import get_model_registry
from Imagen_MCP.services.nexos_client import NexosClient
from Imagen_MCP.utils import generate_temp_path, save_image_to_file
class GenerateImageInput(BaseModel):
    """Arguments accepted by the generate_image tool.

    Only ``prompt`` is required. Every other field has a default: ``None``
    for ``output_path`` and the corresponding ``DEFAULT_*`` constant for the
    generation options.
    """

    # Required: what the image should depict.
    prompt: str = Field(..., description="Text description of the image to generate")

    # Optional destination on disk; None means "use a temporary file".
    output_path: str | None = Field(
        None,
        description="File path where the image will be saved. If not provided, a temporary file will be created.",
    )

    # Identifier of the generation model.
    model: str = Field(
        DEFAULT_MODEL,
        description="Model to use for generation (imagen-4, imagen-4-fast, imagen-4-ultra, flux-1.1-pro, gpt-image-1)",
    )

    # Requested dimensions, written as "<width>x<height>".
    size: str = Field(
        DEFAULT_SIZE,
        description="Image size (256x256, 512x512, 1024x1024, 1792x1024, 1024x1792)",
    )

    # Quality tier.
    quality: str = Field(DEFAULT_QUALITY, description="Image quality (standard, hd)")

    # Rendering style.
    style: str = Field(DEFAULT_STYLE, description="Image style (vivid, natural)")
class GenerateImageOutput(BaseModel):
    """Result returned by the generate_image tool.

    ``success`` and ``model_used`` are always present. The remaining fields
    default to ``None`` and are filled in depending on the outcome: file
    details on success, ``error`` on failure.
    """

    # Outcome flag.
    success: bool = Field(..., description="Whether the generation was successful")

    # Location and size of the written image.
    file_path: str | None = Field(None, description="Path to the saved image file")
    file_size_bytes: int | None = Field(
        None, description="Size of the saved image file in bytes"
    )

    # Model identifier associated with this result.
    model_used: str = Field(..., description="Model that was used for generation")

    # Prompt as rewritten by the model, when it reports one.
    revised_prompt: str | None = Field(
        None, description="Revised prompt if the model modified it"
    )

    # Human-readable failure reason.
    error: str | None = Field(None, description="Error message if generation failed")
async def generate_image(
    prompt: str,
    output_path: str | None = None,
    model: str = DEFAULT_MODEL,
    size: str = DEFAULT_SIZE,
    quality: str = DEFAULT_QUALITY,
    style: str = DEFAULT_STYLE,
    client: NexosClient | None = None,
) -> GenerateImageOutput:
    """Generate a single image from a text prompt and save it to disk.

    The request is checked against the model registry (known model, supported
    size, HD support, resolvable API ID) before any client is created or any
    output path is generated. Problems found by these checks, as well as
    errors raised while creating the client, calling the API, or saving the
    file, are reported through a ``GenerateImageOutput`` with
    ``success=False`` and a populated ``error`` instead of being raised.

    Args:
        prompt: Text description of the image to generate.
        output_path: File path where the image will be saved. If not provided,
            a temporary path is obtained from ``generate_temp_path()``.
        model: Model to use for generation (registry model ID).
        size: Image size.
        quality: Image quality (standard or hd).
        style: Image style (vivid or natural).
        client: Optional NexosClient instance (for testing). When omitted,
            one is created with ``NexosClient.from_env()``.

    Returns:
        GenerateImageOutput with the file path and metadata on success, or
        with ``success=False`` and an error message on failure.
    """

    def _failure(message: str) -> GenerateImageOutput:
        # Every failure shares the same shape; only the message differs.
        return GenerateImageOutput(success=False, model_used=model, error=message)

    # --- Validate the request against the model registry -----------------
    registry = get_model_registry()
    if not registry.is_valid_model(model):
        return _failure(
            f"Invalid model: {model}. Available models: {[m.id for m in registry.get_all_models()]}"
        )

    model_info = registry.get_model(model)
    if model_info:
        capabilities = model_info.capabilities
        if size not in capabilities.supported_sizes:
            return _failure(
                f"Size {size} not supported by {model}. Supported sizes: {capabilities.supported_sizes}"
            )
        if quality == "hd" and not capabilities.supports_hd_quality:
            return _failure(f"HD quality not supported by {model}.")

    # The API may know the model under a different ID than the registry does.
    api_model_id = registry.get_api_id(model)
    if not api_model_id:
        return _failure(f"Could not find API ID for model: {model}")

    try:
        # Resolve the destination and build the client only once the request
        # is known to be acceptable, and inside the try so that a failure
        # here is returned to the caller rather than raised.
        if output_path is None:
            output_path = generate_temp_path()
        if client is None:
            # NOTE(review): a client created here is never explicitly closed;
            # confirm whether NexosClient requires cleanup.
            client = NexosClient.from_env()

        # Validate and convert parameters to proper types.
        validated_size, validated_quality, validated_style = validate_image_params(
            size, quality, style
        )

        # The request carries the API model ID, not the registry ID.
        request = GenerateImageRequest(
            prompt=prompt,
            model=api_model_id,
            n=1,
            size=validated_size,
            quality=validated_quality,
            style=validated_style,
        )
        response = await client.generate_image(request)

        if not response.images:
            return _failure("No images returned from API")

        image = response.images[0]
        if not image.b64_json:
            return _failure(
                "No image data returned from API (only URL-based responses not supported)"
            )

        try:
            file_path, file_size = save_image_to_file(image.b64_json, output_path)
        except Exception as e:
            return _failure(f"Failed to save image to file: {e}")

        return GenerateImageOutput(
            success=True,
            file_path=file_path,
            file_size_bytes=file_size,
            model_used=model,
            revised_prompt=image.revised_prompt,
        )
    except InvalidRequestError as e:
        return _failure(f"Invalid request: {e}")
    except Exception as e:
        # Tool boundary: report any remaining error as a structured failure.
        return _failure(f"Generation failed: {e}")