replicate-flux-mcp

by awkoy
Verified
  • src
#!/usr/bin/env node
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
import { ErrorCode, McpError } from "@modelcontextprotocol/sdk/types.js";
import { z } from "zod";
import Replicate from "replicate";

// Configuration
const CONFIG = {
  serverName: "replicate-flux-mcp",
  serverVersion: "0.0.1",
  modelId: "black-forest-labs/flux-schnell",
  // Maximum number of status checks made while waiting for a prediction.
  pollingAttempts: 5,
  // Delay between consecutive status checks.
  pollingInterval: 2000, // ms
};

// Initialize MCP server
const server = new McpServer({
  name: CONFIG.serverName,
  version: CONFIG.serverVersion,
});

/**
 * Reads the Replicate API token from the environment.
 * Terminates the process with exit code 1 when REPLICATE_API_TOKEN is unset or empty.
 */
function getReplicateApiToken(): string {
  const token = process.env.REPLICATE_API_TOKEN;
  if (!token) {
    console.error("Error: REPLICATE_API_TOKEN environment variable is required");
    process.exit(1);
  }
  return token;
}

// Initialize Replicate client
const replicate = new Replicate({
  auth: getReplicateApiToken(),
});

// Schema definitions (input parameters forwarded verbatim to the model)
const imageGenerationSchema = {
  prompt: z.string().min(1).describe("Prompt for generated image"),
  seed: z
    .number()
    .int()
    .optional()
    .describe("Random seed. Set for reproducible generation"),
  go_fast: z
    .boolean()
    .default(true)
    .describe(
      "Run faster predictions with model optimized for speed (currently fp8 quantized); disable to run in original bf16"
    ),
  megapixels: z
    .enum(["1", "0.25"])
    .default("1")
    .describe("Approximate number of megapixels for generated image"),
  num_outputs: z
    .number()
    .int()
    .min(1)
    .max(4)
    .default(1)
    .describe("Number of outputs to generate"),
  aspect_ratio: z
    .enum([
      "1:1",
      "16:9",
      "21:9",
      "3:2",
      "2:3",
      "4:5",
      "5:4",
      "3:4",
      "4:3",
      "9:16",
      "9:21",
    ])
    .default("1:1")
    .describe("Aspect ratio for the generated image"),
  output_format: z
    .enum(["webp", "jpg", "png"])
    .default("webp")
    .describe("Format of the output images"),
  output_quality: z
    .number()
    .int()
    .min(0)
    .max(100)
    .default(80)
    .describe(
      "Quality when saving the output images, from 0 to 100. 100 is best quality, 0 is lowest quality. Not relevant for .png outputs"
    ),
  num_inference_steps: z
    .number()
    .int()
    .min(1)
    .max(4)
    .default(4)
    .describe(
      "Number of denoising steps. 4 is recommended, and lower number of steps produce lower quality outputs, faster."
    ),
  disable_safety_checker: z
    .boolean()
    .default(false)
    .describe("Disable safety checker for generated images."),
};

// Helper functions

/** Resolves after `ms` milliseconds. */
function sleep(ms: number): Promise<void> {
  return new Promise((resolve) => setTimeout(resolve, ms));
}

/** True while a prediction status is "starting" or "processing" (i.e. not yet finished). */
function isPending(status: string): boolean {
  return status === "starting" || status === "processing";
}

/**
 * Fetches a prediction until it leaves the pending states or
 * `CONFIG.pollingAttempts` checks have been made, sleeping
 * `CONFIG.pollingInterval` ms between checks (never after the last one).
 *
 * @param predictionId - ID returned by `replicate.predictions.create`.
 * @returns The most recently fetched prediction. Its status may still be
 *   pending if the attempt budget ran out; callers must check.
 */
async function pollForCompletion(predictionId: string) {
  let latest = await replicate.predictions.get(predictionId);
  for (
    let attempt = 1;
    attempt < CONFIG.pollingAttempts && isPending(latest.status);
    attempt++
  ) {
    await sleep(CONFIG.pollingInterval);
    latest = await replicate.predictions.get(predictionId);
  }
  return latest;
}

/**
 * Converts any thrown value into an McpError and rethrows it.
 * Existing McpErrors are rethrown unchanged so their error code is preserved;
 * everything else becomes ErrorCode.InternalError.
 */
function handleError(error: unknown): never {
  if (error instanceof McpError) {
    throw error;
  }
  const message = error instanceof Error ? error.message : String(error);
  throw new McpError(ErrorCode.InternalError, message);
}

// Register tools
server.tool(
  "generate_image",
  "Generate an image from a text prompt using Flux Schnell model",
  imageGenerationSchema,
  async (input) => {
    try {
      const prediction = await replicate.predictions.create({
        model: CONFIG.modelId,
        input,
      });

      const latest = await pollForCompletion(prediction.id);

      if (isPending(latest.status)) {
        // Surface the ID so the client can follow up with get_prediction.
        return {
          content: [
            {
              type: "text",
              text: `Processing timed out after ${CONFIG.pollingAttempts} polling attempts; prediction ${latest.id} is still "${latest.status}". Use get_prediction with this ID to check on it later.`,
            },
            {
              type: "text",
              text: JSON.stringify(latest, null, 2),
            },
          ],
        };
      }

      return {
        content: [
          {
            type: "text",
            text: JSON.stringify(latest, null, 2),
          },
        ],
      };
    } catch (error) {
      handleError(error);
    }
  }
);

server.tool(
  "prediction_list",
  "Get a list of recent predictions from Replicate",
  {
    limit: z
      .number()
      .int()
      .min(1)
      .max(100)
      .default(50)
      .describe("Maximum number of predictions to return"),
  },
  async ({ limit }) => {
    try {
      const predictions = [];
      // Stop fetching pages as soon as enough items have been collected.
      for await (const page of replicate.paginate(replicate.predictions.list)) {
        predictions.push(...page);
        if (predictions.length >= limit) {
          break;
        }
      }

      // A page may overshoot the limit; trim to exactly `limit` items.
      const limitedPredictions = predictions.slice(0, limit);

      return {
        content: [
          {
            type: "text",
            text: `Found ${limitedPredictions.length} predictions (limit: ${limit})`,
          },
          {
            type: "text",
            text: JSON.stringify(limitedPredictions, null, 2),
          },
        ],
      };
    } catch (error) {
      handleError(error);
    }
  }
);

server.tool(
  "get_prediction",
  "Get details of a specific prediction by ID",
  {
    predictionId: z
      .string()
      .min(1)
      .describe("ID of the prediction to retrieve"),
  },
  async ({ predictionId }) => {
    try {
      const prediction = await replicate.predictions.get(predictionId);
      return {
        content: [
          {
            type: "text",
            text: JSON.stringify(prediction, null, 2),
          },
        ],
      };
    } catch (error) {
      handleError(error);
    }
  }
);

/**
 * Connects the MCP server to a stdio transport.
 * Status messages go to stderr because stdout carries the MCP protocol stream.
 */
async function main() {
  try {
    const transport = new StdioServerTransport();
    await server.connect(transport);
    console.error(`${CONFIG.serverName} v${CONFIG.serverVersion} running on stdio`);
  } catch (error) {
    console.error(
      "Server initialization error:",
      error instanceof Error ? error.message : String(error)
    );
    process.exit(1);
  }
}

main().catch((error: unknown) => {
  console.error(
    "Unhandled server error:",
    error instanceof Error ? error.message : String(error)
  );
  process.exit(1);
});