"""MCP server for fal.ai image generation."""
import os
import time
from pathlib import Path
from typing import Any
import httpx
from dotenv import load_dotenv
from fastmcp import FastMCP
# Load API credentials (FAL_KEY) from the repo-root .env.local file.
load_dotenv(Path(__file__).resolve().parents[1] / ".env.local")
mcp = FastMCP("fal")
# Generated images are written next to this module, under ./output.
OUTPUT_DIR = Path(__file__).parent / "output"
OUTPUT_DIR.mkdir(exist_ok=True)
# Queue API base (async submit + poll). FAL_SYNC is the synchronous base;
# NOTE(review): FAL_SYNC is unused in this file — presumably kept for callers/future use.
FAL_BASE = "https://queue.fal.run"
FAL_SYNC = "https://fal.run"
POLL_INTERVAL = 2.0  # seconds between status polls
MAX_POLL = 120  # give up waiting for a job after this many seconds
def _key() -> str:
key = os.environ.get("FAL_KEY", "")
if not key:
raise RuntimeError("FAL_KEY not set in .env.local")
return key
def _headers() -> dict[str, str]:
    """Build the standard auth + JSON headers for fal.ai API calls."""
    return {
        "Authorization": f"Key {_key()}",
        "Content-Type": "application/json",
    }
def _download(url: str, dest: Path) -> Path:
    """Stream a URL to a local file and return the destination path."""
    with httpx.stream("GET", url, timeout=60, follow_redirects=True) as response:
        response.raise_for_status()
        # Write chunk-by-chunk so large images never sit fully in memory.
        with open(dest, "wb") as out:
            for chunk in response.iter_bytes():
                out.write(chunk)
        return dest
@mcp.tool()
def generate_image(
    prompt: str,
    model: str = "fal-ai/flux/dev",
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
    seed: int | None = None,
    num_images: int = 1,
    output_format: str = "png",
    filename: str | None = None,
) -> dict[str, Any]:
    """Generate images from a text prompt using fal.ai.
    Common models:
    - fal-ai/flux/dev (high quality, 28 steps)
    - fal-ai/flux/schnell (fast, 1-4 steps)
    - fal-ai/flux-pro/v1.1 (professional, up to 2K)
    - fal-ai/flux-general (supports LoRA, ControlNet, IP-Adapter)
    - fal-ai/recraft/v3/text-to-image (illustration style)
    Returns paths to saved images and metadata.
    """
    # Assemble the request payload for the model endpoint.
    payload: dict[str, Any] = dict(
        prompt=prompt,
        image_size={"width": width, "height": height},
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images=num_images,
        output_format=output_format,
    )
    # Only send a seed when the caller pinned one.
    if seed is not None:
        payload["seed"] = seed
    response = _submit_and_wait(model, payload)
    base_name = filename if filename else _slugify(prompt)
    return _save_images(response, base_name)
@mcp.tool()
def generate_with_reference(
    prompt: str,
    reference_image_url: str,
    reference_strength: float = 0.65,
    model: str = "fal-ai/flux-general",
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
    seed: int | None = None,
    num_images: int = 1,
    output_format: str = "png",
    filename: str | None = None,
) -> dict[str, Any]:
    """Generate images with a style/content reference image.
    Uses the reference_image feature of flux-general to guide generation
    toward a similar style or content as the reference.
    Args:
        reference_image_url: URL of the reference image for style guidance.
        reference_strength: How strongly to follow the reference (0.0-1.0).
    """
    # Standard generation parameters plus the reference-image fields.
    payload: dict[str, Any] = dict(
        prompt=prompt,
        image_size={"width": width, "height": height},
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images=num_images,
        output_format=output_format,
        reference_image_url=reference_image_url,
        reference_strength=reference_strength,
    )
    # Only send a seed when the caller pinned one.
    if seed is not None:
        payload["seed"] = seed
    response = _submit_and_wait(model, payload)
    base_name = filename if filename else _slugify(prompt)
    return _save_images(response, base_name)
@mcp.tool()
def generate_with_lora(
    prompt: str,
    lora_url: str,
    lora_scale: float = 1.0,
    model: str = "fal-ai/flux-general",
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
    seed: int | None = None,
    num_images: int = 1,
    output_format: str = "png",
    filename: str | None = None,
) -> dict[str, Any]:
    """Generate images with a LoRA model applied.
    Args:
        lora_url: URL to the LoRA safetensors file (e.g. from HuggingFace).
        lora_scale: Strength of the LoRA effect (0.0-2.0, default 1.0).
    """
    # Standard generation parameters plus a single-entry LoRA list.
    payload: dict[str, Any] = dict(
        prompt=prompt,
        image_size={"width": width, "height": height},
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images=num_images,
        output_format=output_format,
        loras=[{"path": lora_url, "scale": lora_scale}],
    )
    # Only send a seed when the caller pinned one.
    if seed is not None:
        payload["seed"] = seed
    response = _submit_and_wait(model, payload)
    base_name = filename if filename else _slugify(prompt)
    return _save_images(response, base_name)
@mcp.tool()
def edit_image(
    prompt: str,
    image_url: str,
    model: str = "fal-ai/flux-pro/kontext",
    seed: int | None = None,
    output_format: str = "png",
    filename: str | None = None,
) -> dict[str, Any]:
    """Edit an existing image using natural language instructions (FLUX Kontext).
    Pass a reference image and describe the changes you want in the prompt.
    Great for iterating on generated images.
    Args:
        image_url: URL of the image to edit.
        prompt: Description of the desired changes.
    """
    # Kontext takes the source image plus an edit instruction.
    payload: dict[str, Any] = dict(
        prompt=prompt,
        image_url=image_url,
        output_format=output_format,
    )
    # Only send a seed when the caller pinned one.
    if seed is not None:
        payload["seed"] = seed
    response = _submit_and_wait(model, payload)
    base_name = filename if filename else _slugify(prompt)
    return _save_images(response, base_name)
@mcp.tool()
def raw_generate(
    model: str,
    body: dict[str, Any],
    filename: str | None = None,
) -> dict[str, Any]:
    """Submit an arbitrary request to any fal.ai model endpoint.
    Use this for advanced configurations (ControlNet, IP-Adapter, multi-LoRA,
    custom parameters) that aren't covered by the other tools.
    Args:
        model: The fal.ai model endpoint ID (e.g. "fal-ai/flux-general").
        body: The full JSON request body to send to the model.
    """
    response = _submit_and_wait(model, body)
    # Fall back to a slug of the prompt (or "raw") when no filename given.
    return _save_images(response, filename or _slugify(body.get("prompt", "raw")))
@mcp.tool()
def list_outputs() -> dict[str, Any]:
    """List all images previously saved in the output directory."""
    entries = []
    # PNGs first, then JPGs, each group sorted by name.
    for pattern in ("*.png", "*.jpg"):
        for path in sorted(OUTPUT_DIR.glob(pattern)):
            entries.append({
                "name": path.name,
                "path": str(path),
                "size_kb": round(path.stat().st_size / 1024, 1),
            })
    return {"output_dir": str(OUTPUT_DIR), "files": entries}
# --- Internal helpers ---
def _submit_and_wait(model: str, body: dict[str, Any]) -> dict[str, Any]:
    """Submit a job to the queue API and poll until complete.

    Args:
        model: fal.ai model endpoint ID, e.g. "fal-ai/flux/dev".
        body: JSON request body forwarded to the model.

    Returns:
        The job's result payload as parsed JSON.

    Raises:
        RuntimeError: if the job reaches a terminal state other than COMPLETED.
        TimeoutError: if the job is still running after MAX_POLL seconds.
        httpx.HTTPStatusError: on any non-2xx API response.
    """
    with httpx.Client(timeout=30) as client:
        # Submit: the queue response carries the URLs we poll for status
        # and fetch the final result from.
        resp = client.post(f"{FAL_BASE}/{model}", headers=_headers(), json=body)
        resp.raise_for_status()
        job = resp.json()
        request_id = job["request_id"]
        status_url = job["status_url"]
        result_url = job["response_url"]
        # Poll against a wall-clock deadline. (The previous version summed
        # only the sleep intervals, ignoring HTTP round-trip time, so jobs
        # could run well past MAX_POLL seconds of real time.)
        deadline = time.monotonic() + MAX_POLL
        while time.monotonic() < deadline:
            time.sleep(POLL_INTERVAL)
            status_resp = client.get(f"{status_url}?logs=1", headers=_headers())
            status_resp.raise_for_status()
            status = status_resp.json()
            if status.get("status") == "COMPLETED":
                # Done: fetch the full result payload.
                result_resp = client.get(result_url, headers=_headers())
                result_resp.raise_for_status()
                return result_resp.json()
            if status.get("status") not in ("IN_QUEUE", "IN_PROGRESS"):
                raise RuntimeError(
                    f"Job {request_id} failed with status: {status}"
                )
    raise TimeoutError(
        f"Job {request_id} did not complete within {MAX_POLL}s"
    )
def _save_images(result: dict[str, Any], base_name: str) -> dict[str, Any]:
    """Download images from result and save locally."""
    images = result.get("images", [])
    multiple = len(images) > 1
    saved: list[dict[str, Any]] = []
    for index, img in enumerate(images):
        url = img["url"]
        # Pick an extension from the content type; default to png.
        ext = "png" if "png" in img.get("content_type", "png") else "jpg"
        stem = f"{base_name}_{index}" if multiple else base_name
        dest = OUTPUT_DIR / f"{stem}.{ext}"
        _download(url, dest)
        saved.append({
            "path": str(dest),
            "width": img.get("width"),
            "height": img.get("height"),
            "url": url,
        })
    return {
        "saved_images": saved,
        "seed": result.get("seed"),
        "timings": result.get("timings"),
        "prompt": result.get("prompt"),
    }
def _slugify(text: str) -> str:
"""Create a simple filename from text."""
slug = text.lower().strip()
slug = "".join(c if c.isalnum() or c == " " else "" for c in slug)
slug = "_".join(slug.split()[:6])
return slug or "output"
if __name__ == "__main__":
    # Run the MCP server when this module is executed directly.
    mcp.run()