Image Generation MCP Server
- mcp-image-gen
- src
- image_gen
from typing import Any, Optional
import asyncio
import httpx
from mcp.server.models import InitializationOptions
import mcp.types as types
from mcp.server import NotificationOptions, Server
import mcp.server.stdio
import os
TOGETHER_AI_BASE = "https://api.together.xyz/v1/images/generations"
API_KEY = os.getenv("TOGETHER_AI_API_KEY")
DEFAULT_MODEL = "black-forest-labs/FLUX.1-schnell"
server = Server("image-gen")
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
"""
List available tools.
Each tool specifies its arguments using JSON Schema validation.
"""
return [
types.Tool(
name="generate_image",
description="Generate an image based on the text prompt, model, and optional dimensions",
inputSchema={
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "The text prompt for image generation",
},
"model": {
"type": "string",
"description": "The exact model name as it appears in Together AI. If incorrect, it will fallback to the default model (black-forest-labs/FLUX.1-schnell).",
},
"width": {
"type": "number",
"description": "Optional width for the image",
},
"height": {
"type": "number",
"description": "Optional height for the image",
},
},
"required": ["prompt", "model"],
},
)
]
async def make_together_request(
client: httpx.AsyncClient,
prompt: str,
model: str,
width: Optional[int] = None,
height: Optional[int] = None,
) -> dict[str, Any]:
"""Make a request to the Together API with error handling and fallback for incorrect model."""
request_body = {"model": model, "prompt": prompt, "response_format": "b64_json"}
headers = {"Authorization": f"Bearer {API_KEY}"}
if width is not None:
request_body["width"] = width
if height is not None:
request_body["height"] = height
async def send_request(body: dict) -> (int, dict):
response = await client.post(TOGETHER_AI_BASE, headers=headers, json=body)
try:
data = response.json()
except Exception:
data = {}
return response.status_code, data
# First request with user-provided model
status, data = await send_request(request_body)
# Check if the request failed due to an invalid model error
if status != 200 and "error" in data:
error_info = data["error"]
error_msg = error_info.get("message", "").lower()
error_code = error_info.get("code", "").lower()
if (
"model" in error_msg and "not available" in error_msg
) or error_code == "model_not_available":
# Fallback to the default model
request_body["model"] = DEFAULT_MODEL
status, data = await send_request(request_body)
if status != 200 or "error" in data:
return {
"error": f"Fallback API error: {data.get('error', 'Unknown error')} (HTTP {status})"
}
return data
else:
return {"error": f"Together API error: {data.get('error')}"}
elif status != 200:
return {"error": f"HTTP error {status}"}
return data
@server.call_tool()
async def handle_call_tool(
name: str, arguments: dict | None
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
"""
Handle tool execution requests.
Tools can generate images and notify clients of changes.
"""
if not arguments:
return [
types.TextContent(type="text", text="Missing arguments for the request")
]
if name == "generate_image":
prompt = arguments.get("prompt")
model = arguments.get("model")
width = arguments.get("width")
height = arguments.get("height")
if not prompt or not model:
return [
types.TextContent(type="text", text="Missing prompt or model parameter")
]
async with httpx.AsyncClient() as client:
response_data = await make_together_request(
client=client,
prompt=prompt,
model=model, # User-provided model (or fallback will be used)
width=width,
height=height,
)
if "error" in response_data:
return [types.TextContent(type="text", text=response_data["error"])]
try:
b64_image = response_data["data"][0]["b64_json"]
return [
types.ImageContent(
type="image", data=b64_image, mimeType="image/jpeg"
)
]
except (KeyError, IndexError) as e:
return [
types.TextContent(
type="text", text=f"Failed to parse API response: {e}"
)
]
async def main():
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
await server.run(
read_stream,
write_stream,
InitializationOptions(
server_name="image-gen",
server_version="0.1.0",
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
)
if __name__ == "__main__":
asyncio.run(main())