"""
Image handler implementations for Fal.ai MCP Server.
Contains: generate_image, generate_image_structured, generate_image_from_image
"""
import asyncio
import json
from typing import Any, Dict, List
from loguru import logger
from mcp.types import TextContent
from fal_mcp_server.model_registry import ModelRegistry
from fal_mcp_server.queue.base import QueueStrategy
async def handle_generate_image(
    arguments: Dict[str, Any],
    registry: ModelRegistry,
    queue_strategy: QueueStrategy,
) -> List[TextContent]:
    """Handle the generate_image tool.

    Resolves the requested model (default ``"flux_schnell"``), forwards the
    prompt and any optional generation parameters to
    ``queue_strategy.execute_fast`` and formats the returned image URLs.

    Args:
        arguments: Tool arguments. ``prompt`` is required; ``model``,
            ``image_size``, ``num_images``, ``negative_prompt``, ``seed``,
            ``enable_safety_checker`` and ``output_format`` are optional.
        registry: Used to resolve ``model`` to a model id.
        queue_strategy: Executes the request via ``execute_fast``.

    Returns:
        A single-element list holding either the generated image URLs or a
        "❌"-prefixed error message.

    Raises:
        KeyError: If ``prompt`` is missing from ``arguments``.
    """

    def _text(message: str) -> List[TextContent]:
        # Every outcome of this handler is a single text content item.
        return [TextContent(type="text", text=message)]

    model_input = arguments.get("model", "flux_schnell")
    try:
        model_id = await registry.resolve_model_id(model_input)
    except ValueError as e:
        return _text(f"❌ {e}. Use list_models to see available options.")

    fal_args: Dict[str, Any] = {
        "prompt": arguments["prompt"],
        "image_size": arguments.get("image_size", "landscape_16_9"),
        "num_images": arguments.get("num_images", 1),
    }
    # Forward optional parameters only when the caller supplied them.
    for key in ("negative_prompt", "seed", "enable_safety_checker", "output_format"):
        if key in arguments:
            fal_args[key] = arguments[key]

    # Use fast execution (no queue) for image generation
    try:
        result = await queue_strategy.execute_fast(model_id, fal_args)
    except Exception as e:
        # loguru interpolates with str.format-style "{}" placeholders;
        # "%s" would be emitted literally and the detail would be lost.
        logger.error("Image generation failed: {}", e)
        return _text(f"❌ Image generation failed: {e}")

    # Check for error in response
    if "error" in result:
        error_msg = result.get("error", "Unknown error")
        logger.error("Image generation failed for {}: {}", model_id, error_msg)
        return _text(f"❌ Image generation failed: {error_msg}")

    images = result.get("images", [])
    if not images:
        logger.warning("Image generation returned no images. Model: {}", model_id)
        return _text(
            f"❌ No images were generated by {model_id}. The prompt may have been filtered."
        )

    # Extract URLs safely: an entry without "url" or a non-mapping entry
    # is reported as a malformed response instead of raising.
    try:
        urls = [img["url"] for img in images]
    except (KeyError, TypeError) as e:
        logger.error("Malformed image response from {}: {}", model_id, e)
        return _text(
            f"❌ Image generation completed but response was malformed: {e}"
        )

    header = f"🎨 Generated {len(urls)} image(s) with {model_id}:\n\n"
    listing = "".join(f"Image {i}: {url}\n" for i, url in enumerate(urls, 1))
    return _text(header + listing)
async def handle_generate_image_structured(
    arguments: Dict[str, Any],
    registry: ModelRegistry,
    queue_strategy: QueueStrategy,
) -> List[TextContent]:
    """Handle the generate_image_structured tool.

    Builds a JSON prompt from the structured fields in ``arguments``,
    submits it through ``queue_strategy.execute_fast`` with a 60 second
    timeout and formats the returned image URLs.

    Args:
        arguments: Tool arguments. ``scene`` is required. Optional prompt
            fields are ``subjects``, ``style``, ``color_palette``,
            ``lighting``, ``mood``, ``background``, ``composition``,
            ``camera`` and ``effects``. Optional generation parameters are
            ``model``, ``image_size``, ``num_images``, ``negative_prompt``,
            ``seed``, ``enable_safety_checker`` and ``output_format``.
        registry: Used to resolve ``model`` to a model id.
        queue_strategy: Executes the request via ``execute_fast``.

    Returns:
        A single-element list holding either the generated image URLs or a
        "❌"-prefixed error message.

    Raises:
        KeyError: If ``scene`` is missing from ``arguments``.
    """

    def _text(message: str) -> List[TextContent]:
        # Every outcome of this handler is a single text content item.
        return [TextContent(type="text", text=message)]

    model_input = arguments.get("model", "flux_schnell")
    try:
        model_id = await registry.resolve_model_id(model_input)
    except ValueError as e:
        return _text(f"❌ {e}. Use list_models to see available options.")

    # Build structured JSON prompt from arguments; "scene" is required.
    structured_prompt: Dict[str, Any] = {"scene": arguments["scene"]}
    # Optional structured fields are copied only when present.
    for field in (
        "subjects",
        "style",
        "color_palette",
        "lighting",
        "mood",
        "background",
        "composition",
        "camera",
        "effects",
    ):
        if field in arguments:
            structured_prompt[field] = arguments[field]

    # The structured prompt is sent to the model as a JSON string.
    json_prompt = json.dumps(structured_prompt, indent=2)

    fal_args: Dict[str, Any] = {
        "prompt": json_prompt,
        "image_size": arguments.get("image_size", "landscape_16_9"),
        "num_images": arguments.get("num_images", 1),
    }
    # Forward optional generation parameters only when supplied.
    for key in ("negative_prompt", "seed", "enable_safety_checker", "output_format"):
        if key in arguments:
            fal_args[key] = arguments[key]

    # Use fast execution with timeout protection.
    # loguru interpolates with str.format-style "{}" placeholders; "%s"
    # would be emitted literally and the detail would be lost.
    logger.info("Starting structured image generation with {}", model_id)
    try:
        result = await asyncio.wait_for(
            queue_strategy.execute_fast(model_id, fal_args),
            timeout=60,
        )
    except asyncio.TimeoutError:
        logger.error("Structured image generation timed out for {}", model_id)
        return _text(
            f"❌ Image generation timed out after 60 seconds with {model_id}. Please try again."
        )
    except Exception as e:
        # Match the sibling handlers: report backend failures as a tool
        # response instead of letting them propagate to the caller.
        logger.exception("Structured image generation failed: {}", e)
        return _text(f"❌ Image generation failed: {e}")

    # Check for error in response
    if "error" in result:
        error_msg = result.get("error", "Unknown error")
        logger.error(
            "Structured image generation failed for {}: {}", model_id, error_msg
        )
        return _text(f"❌ Image generation failed: {error_msg}")

    images = result.get("images", [])
    if not images:
        logger.warning(
            "Structured image generation returned no images. Model: {}",
            model_id,
        )
        return _text(
            f"❌ No images were generated by {model_id}. The prompt may have been filtered or the request format was invalid."
        )

    # Extract URLs safely: an entry without "url" or a non-mapping entry
    # is reported as a malformed response instead of raising.
    try:
        urls = [img["url"] for img in images]
    except (KeyError, TypeError) as e:
        logger.error("Malformed image response from {}: {}", model_id, e)
        return _text(
            f"❌ Image generation completed but response was malformed: {e}"
        )

    header = (
        f"🎨 Generated {len(urls)} image(s) with {model_id} (structured prompt):\n\n"
    )
    listing = "".join(f"Image {i}: {url}\n" for i, url in enumerate(urls, 1))
    return _text(header + listing)
async def handle_generate_image_from_image(
    arguments: Dict[str, Any],
    registry: ModelRegistry,
    queue_strategy: QueueStrategy,
) -> List[TextContent]:
    """Handle the generate_image_from_image tool.

    Resolves the requested model (default
    ``"fal-ai/flux/dev/image-to-image"``), submits the source image and
    prompt through ``queue_strategy.execute_fast`` with a 60 second timeout
    and formats the returned image URLs.

    Args:
        arguments: Tool arguments. ``image_url`` and ``prompt`` are
            required; ``model``, ``strength`` (default 0.75),
            ``num_images`` (default 1), ``negative_prompt``, ``seed``,
            ``enable_safety_checker`` and ``output_format`` are optional.
        registry: Used to resolve ``model`` to a model id.
        queue_strategy: Executes the request via ``execute_fast``.

    Returns:
        A single-element list holding either the result image URLs or a
        "❌"-prefixed error message.

    Raises:
        KeyError: If ``image_url`` or ``prompt`` is missing from
            ``arguments``.
    """

    def _text(message: str) -> List[TextContent]:
        # Every outcome of this handler is a single text content item.
        return [TextContent(type="text", text=message)]

    model_input = arguments.get("model", "fal-ai/flux/dev/image-to-image")
    try:
        model_id = await registry.resolve_model_id(model_input)
    except ValueError as e:
        return _text(f"❌ {e}. Use list_models to see available options.")

    # Both image_url and prompt are required
    source_url = arguments["image_url"]
    img2img_args: Dict[str, Any] = {
        "image_url": source_url,
        "prompt": arguments["prompt"],
        "strength": arguments.get("strength", 0.75),
        "num_images": arguments.get("num_images", 1),
    }
    # Forward optional parameters only when the caller supplied them.
    for key in ("negative_prompt", "seed", "enable_safety_checker", "output_format"):
        if key in arguments:
            img2img_args[key] = arguments[key]

    # Shorten long source URLs for display; the ellipsis is added only when
    # something was actually cut off. Used for both the log line and the
    # user-facing response so the two stay consistent.
    source_preview = (
        (source_url[:50] + "...") if len(source_url) > 50 else source_url
    )

    # loguru interpolates with str.format-style "{}" placeholders; "%s"
    # would be emitted literally and the detail would be lost.
    logger.info(
        "Starting image-to-image transformation with {} from {}",
        model_id,
        source_preview,
    )

    # Use fast execution with timeout protection
    try:
        result = await asyncio.wait_for(
            queue_strategy.execute_fast(model_id, img2img_args),
            timeout=60,
        )
    except asyncio.TimeoutError:
        logger.error(
            "Image-to-image transformation timed out after 60s. Model: {}",
            model_id,
        )
        return _text(
            f"❌ Image transformation timed out after 60 seconds with {model_id}. Please try again."
        )
    except Exception as e:
        logger.exception("Image-to-image transformation failed: {}", e)
        return _text(f"❌ Image transformation failed: {e}")

    # Check for error in response
    if "error" in result:
        error_msg = result.get("error", "Unknown error")
        logger.error(
            "Image-to-image transformation failed for {}: {}",
            model_id,
            error_msg,
        )
        return _text(f"❌ Image transformation failed: {error_msg}")

    images = result.get("images", [])
    if not images:
        logger.warning(
            "Image-to-image transformation returned no images. Model: {}",
            model_id,
        )
        return _text(
            f"❌ No images were generated by {model_id}. The source image may have been filtered."
        )

    # Extract URLs safely: an entry without "url" or a non-mapping entry
    # is reported as a malformed response instead of raising.
    try:
        urls = [img["url"] for img in images]
    except (KeyError, TypeError) as e:
        logger.error("Malformed image response from {}: {}", model_id, e)
        return _text(
            f"❌ Image transformation completed but response was malformed: {e}"
        )

    header = f"🎨 Transformed image with {model_id}:\n\n"
    source_line = f"**Source**: {source_preview}\n\n"
    listing = "".join(f"Result {i}: {url}\n" for i, url in enumerate(urls, 1))
    return _text(header + source_line + listing)