Skip to main content
Glama

MCP Server Replicate

server.py50.9 kB
"""FastMCP server implementation for Replicate API.""" import asyncio import base64 import hashlib import hmac import json import logging import os import webbrowser from collections.abc import Sequence from typing import Any import httpx import jsonschema from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.prompts.base import Message, TextContent, UserMessage from mcp.server.session import ServerSession from mcp.types import ( AnyUrl, BlobResourceContents, EmptyResult, ResourceUpdatedNotification, TextResourceContents, ) from pydantic import BaseModel, Field, field_validator from .models.collection import Collection, CollectionList from .models.hardware import Hardware, HardwareList from .models.model import Model, ModelList from .models.webhook import WebhookPayload from .replicate_client import ReplicateClient from .templates.parameters.common_configs import QUALITY_PRESETS, STYLE_PRESETS, TEMPLATES logger = logging.getLogger(__name__) class SubscriptionRequest(BaseModel): """Request model for subscription operations.""" uri: str = Field(..., description="Resource URI to subscribe to") session_id: str = Field(..., description="ID of the session making the request") class GenerationSubscriptionManager: """Manages subscriptions to generation resources.""" def __init__(self): self._subscriptions: dict[str, set[ServerSession]] = {} self._check_task: asyncio.Task | None = None async def subscribe(self, uri: str, session: ServerSession): """Subscribe a session to generation updates.""" prediction_id = uri.replace("generations://", "") if prediction_id not in self._subscriptions: self._subscriptions[prediction_id] = set() self._subscriptions[prediction_id].add(session) # Start checking if not already running if not self._check_task: self._check_task = asyncio.create_task(self._check_generations()) async def unsubscribe(self, uri: str, session: ServerSession): """Unsubscribe a session from generation updates.""" prediction_id = uri.replace("generations://", "") if prediction_id in self._subscriptions: self._subscriptions[prediction_id].discard(session) if not self._subscriptions[prediction_id]: del self._subscriptions[prediction_id] # Stop checking if no more subscriptions if not self._subscriptions and self._check_task: self._check_task.cancel() self._check_task = None async def _check_generations(self): """Periodically check subscribed generations and notify of updates.""" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: while True: try: for prediction_id, sessions in list(self._subscriptions.items()): try: result = await client.get_prediction_status(prediction_id) # Notify on completion or failure if result["status"] in ["succeeded", "failed", "canceled"]: # For succeeded generations with image output if result["status"] == "succeeded" and result.get("output"): # For image generation models, output is typically a list with the image URL as first item image_url = ( result["output"][0] if isinstance(result["output"], list) else result["output"] ) # First send a notification with just the URL and metadata notification = ResourceUpdatedNotification( method="notifications/resources/updated", params={"uri": f"generations://{prediction_id}"}, ) # Create text resource with metadata and URL text_resource = TextResourceContents( type="text", uri=f"generations://{prediction_id}", mimeType="application/json", text=json.dumps( { "status": "succeeded", "image_url": image_url, "created_at": result.get("created_at"), "completed_at": result.get("completed_at"), "metrics": result.get("metrics", {}), "urls": result.get("urls", {}), "input": result.get("input", {}), }, indent=2, ), ) # Send notification and text resource to all sessions for session in sessions: await session.send_notification(notification) await session.send_resource(text_resource) # Remove from subscriptions since we've notified del self._subscriptions[prediction_id] else: # For failed or canceled generations, create text resource resource = TextResourceContents( uri=f"generations://{prediction_id}", mimeType="application/json", text=json.dumps( { "status": result["status"], "error": result.get("error"), "created_at": result.get("created_at"), "completed_at": result.get("completed_at"), "metrics": result.get("metrics", {}), "urls": result.get("urls", {}), }, indent=2, ), ) # Send notification with the resource notification = ResourceUpdatedNotification( method="notifications/resources/updated", params={"uri": AnyUrl(f"generations://{prediction_id}")}, ) for session in sessions: await session.send_notification(notification) # Also send the resource directly await session.send_resource(resource) # Remove completed/failed generation from subscriptions del self._subscriptions[prediction_id] except Exception as e: logger.error(f"Error checking generation {prediction_id}: {e}") if not self._subscriptions: break await asyncio.sleep(2.0) # Poll every 2 seconds except asyncio.CancelledError: break except Exception as e: logger.error(f"Error in generation check loop: {e}") await asyncio.sleep(5.0) # Back off on errors async def select_model_for_task( task: str, style: str | None = None, quality: str = "balanced", ) -> tuple[Model, dict[str, Any]]: """Select the best model for a given task and get optimal parameters. Args: task: Task description/prompt style: Optional style preference quality: Quality preset (draft, balanced, quality, extreme) Returns: Tuple of (selected model, optimized parameters) """ # Build search query search_query = task if style: search_query = f"{style} style {search_query}" # Search for models async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.search_models(search_query) if not result["models"]: raise ValueError("No suitable models found for the task") # Score and rank models scored_models = [] for model in result["models"]: score = 0 # Popularity score (0-50) run_count = model.get("run_count", 0) score += min(50, (run_count / 1000) * 50) # Featured bonus if model.get("featured"): score += 20 # Version stability if model.get("latest_version"): score += 10 # Tag matching tags = model.get("tags", []) if style and any(style.lower() in tag.lower() for tag in tags): score += 15 if "image" in tags or "text-to-image" in tags: score += 15 scored_models.append((model, score)) # Sort by score scored_models.sort(key=lambda x: x[1], reverse=True) selected_model = scored_models[0][0] # Get quality preset quality_preset = TEMPLATES["quality-presets"]["presets"].get( quality, TEMPLATES["quality-presets"]["presets"]["balanced"] ) # Get style preset if specified parameters = quality_preset["parameters"].copy() if style: style_preset = TEMPLATES["style-presets"]["presets"].get( style.lower(), TEMPLATES["style-presets"]["presets"].get("photorealistic") ) if style_preset: parameters.update(style_preset["parameters"]) return Model(**selected_model), parameters class TemplateInput(BaseModel): """Input for template-based operations.""" template: str = Field(..., description="Template identifier") parameters: dict[str, Any] = Field(default_factory=dict, description="Template parameters") @field_validator("template") def validate_template(cls, v: str) -> str: """Validate template identifier.""" if v not in TEMPLATES: raise ValueError(f"Unknown template: {v}") return v @field_validator("parameters") def validate_parameters(cls, v: dict[str, Any], values: dict[str, Any]) -> dict[str, Any]: """Validate template parameters.""" if "template" not in values: return v template = TEMPLATES[values["template"]] try: jsonschema.validate(v, template["parameter_schema"]) except jsonschema.exceptions.ValidationError as e: raise ValueError(f"Invalid parameters: {e.message}") from e return v class PredictionInput(BaseModel): """Input for prediction operations.""" version: str = Field(..., description="Model version ID") input: dict[str, Any] = Field(..., description="Model input parameters") webhook: str | None = Field(None, description="Webhook URL for prediction updates") @field_validator("input") def validate_input(cls, v: dict[str, Any]) -> dict[str, Any]: """Validate prediction input.""" if not isinstance(v, dict): raise ValueError("Input must be a dictionary") return v def create_server(*, log_level: int = logging.WARNING) -> FastMCP: """Create and configure the FastMCP server. Args: log_level: The logging level to use. Defaults to WARNING. Returns: Configured FastMCP server instance. """ # Configure logging logging.basicConfig(level=log_level) logger.setLevel(log_level) # Verify API token is available api_token = os.getenv("REPLICATE_API_TOKEN") if not api_token: raise ValueError( "REPLICATE_API_TOKEN environment variable is required. " "Get your token from https://replicate.com/account" ) # Create server instance mcp = FastMCP("Replicate Server") # Add resources @mcp.resource("templates://list") def list_available_templates() -> str: """List all available templates with descriptions.""" template_info = [] for name, template in TEMPLATES.items(): template_info.append( f"Template: {name}\n" f"Description: {template.get('description', 'No description')}\n" f"Version: {template.get('version', '1.0.0')}\n" "---" ) return "\n".join(template_info) @mcp.resource("templates://{name}") def get_template_details(name: str) -> str: """Get detailed information about a specific template.""" if name not in TEMPLATES: raise ValueError(f"Template not found: {name}") template = TEMPLATES[name] return json.dumps( { "name": name, "description": template.get("description", ""), "version": template.get("version", "1.0.0"), "parameter_schema": template["parameter_schema"], "examples": template.get("examples", []), }, indent=2, ) @mcp.resource("generations://{prediction_id}") async def get_generation(prediction_id: str) -> TextResourceContents | BlobResourceContents: """Get a specific image generation result.""" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.get_prediction_status(prediction_id) # If not succeeded, return status info if result["status"] != "succeeded": return TextResourceContents( uri=f"generations://{prediction_id}", mimeType="application/json", text=json.dumps( { "status": result["status"], "created_at": result.get("created_at"), "started_at": result.get("started_at"), "completed_at": result.get("completed_at"), "error": result.get("error"), "logs": result.get("logs"), "urls": result.get("urls", {}), "metrics": result.get("metrics", {}), } ), ) # For succeeded generations, return image URL and metadata image_url = result["output"][0] if isinstance(result["output"], list) else result["output"] return TextResourceContents( uri=f"generations://{prediction_id}", mimeType="application/json", text=json.dumps( { "status": "succeeded", "image_url": image_url, "created_at": result.get("created_at"), "completed_at": result.get("completed_at"), "metrics": result.get("metrics", {}), "urls": result.get("urls", {}), "input": result.get("input", {}), } ), ) @mcp.resource("generations://list") async def list_generations() -> TextResourceContents: """List all available generations with their details and resource URIs.""" async with ReplicateClient() as client: predictions = await client.list_predictions(limit=100) return TextResourceContents( uri="generations://list", mimeType="application/json", text=json.dumps( { "total_count": len(predictions), "generations": [ { "id": p["id"], "status": p["status"], "created_at": p.get("created_at"), "completed_at": p.get("completed_at"), "prompt": p.get("input", {}).get("prompt"), # Extract prompt for easy reference "style": p.get("input", {}).get("style"), # Extract style for filtering "quality": p.get("input", {}).get("quality", "balanced"), "error": p.get("error"), "resource_uri": f"generations://{p['id']}", "metrics": p.get("metrics", {}), # Include performance metrics "urls": p.get("urls", {}), # Include direct URLs } for p in predictions ], }, indent=2, ), ) @mcp.resource("generations://search/{query}") async def search_generations(query: str) -> TextResourceContents: """Search through available generations by prompt text or metadata.""" async with ReplicateClient() as client: predictions = await client.list_predictions(limit=100) # Improved search - check prompt, style, and quality filtered = [] query_lower = query.lower() for p in predictions: input_data = p.get("input", {}) searchable_text = ( f"{input_data.get('prompt', '')} {input_data.get('style', '')} {input_data.get('quality', '')}" ) if query_lower in searchable_text.lower(): filtered.append(p) return TextResourceContents( uri=f"generations://search/{query}", mimeType="application/json", text=json.dumps( { "query": query, "total_count": len(filtered), "generations": [ { "id": p["id"], "status": p["status"], "created_at": p.get("created_at"), "completed_at": p.get("completed_at"), "prompt": p.get("input", {}).get("prompt"), "style": p.get("input", {}).get("style"), "quality": p.get("input", {}).get("quality", "balanced"), "error": p.get("error"), "resource_uri": f"generations://{p['id']}", "metrics": p.get("metrics", {}), "urls": p.get("urls", {}), } for p in filtered ], }, indent=2, ), ) @mcp.resource("generations://status/{status}") async def filter_generations_by_status(status: str) -> TextResourceContents: """Get generations filtered by status (starting, processing, succeeded, failed, canceled).""" async with ReplicateClient() as client: predictions = await client.list_predictions(status=status, limit=100) return TextResourceContents( uri=f"generations://status/{status}", mimeType="application/json", text=json.dumps( { "status": status, "total_count": len(predictions), "generations": [ { "id": p["id"], "created_at": p.get("created_at"), "completed_at": p.get("completed_at"), "prompt": p.get("input", {}).get("prompt"), "style": p.get("input", {}).get("style"), "quality": p.get("input", {}).get("quality", "balanced"), "error": p.get("error"), "resource_uri": f"generations://{p['id']}", "metrics": p.get("metrics", {}), "urls": p.get("urls", {}), } for p in predictions ], }, indent=2, ), ) @mcp.resource("models://popular") async def get_popular_models() -> str: """Get a list of popular models on Replicate.""" async with ReplicateClient() as client: models = await client.list_models() return json.dumps( { "models": models["models"], "total": models.get("total_models", 0), }, indent=2, ) # Add prompts @mcp.prompt() def text_to_image() -> Sequence[Message]: """Generate an image from text using available models.""" return [ UserMessage( content=TextContent( type="text", text=( "I'll help you create an image using Replicate's SDXL model. " "To get the best results, please tell me:\n\n" "1. What you want to see in the image (be specific)\n" "2. Style preferences (e.g., photorealistic, anime, oil painting)\n" "3. Quality level (draft=fast, balanced=default, quality=better, extreme=best)\n" "4. Any specific requirements (size, aspect ratio, etc.)\n\n" "For example: 'Create a photorealistic mountain landscape at sunset with snow-capped peaks, " "quality level, 16:9 aspect ratio'\n\n" "Once you provide these details:\n" "- I'll start the generation and provide real-time updates\n" "- You can wait for completion or start another generation\n" "- When ready, I can show you the image or open it on your system\n" "- You can also browse, search, or manage your generations" ), ) ) ] @mcp.prompt() def image_to_image() -> Sequence[Message]: """Transform an existing image using various models.""" return [ UserMessage( content=TextContent( type="text", text=( "I'll help you transform an existing image using Replicate's models. " "Please provide:\n\n" "1. The URL or path to your source image\n" "2. The type of transformation you want (e.g., style transfer, upscaling, inpainting)\n" "3. Any specific settings or parameters\n\n" "I'll help you choose the right model and format your request." ), ) ) ] @mcp.prompt() def model_selection(task: str | None = None) -> Sequence[Message]: """Help choose the right model for a specific task.""" base_prompt = ( "I'll help you select the best Replicate model for your needs. " "Please tell me about your task and requirements, including:\n\n" "1. The type of input you have\n" "2. Your desired output\n" "3. Any specific quality or performance requirements\n" "4. Any budget or hardware constraints\n\n" ) if task: base_prompt += f"\nFor {task}, I recommend considering these aspects:\n" if "image" in task.lower(): base_prompt += ( "- Input/output image dimensions\n" "- Style and quality requirements\n" "- Processing speed needs\n" ) elif "text" in task.lower(): base_prompt += ( "- Input length considerations\n" "- Output format requirements\n" "- Specific language needs\n" ) return [UserMessage(content=TextContent(type="text", text=base_prompt))] @mcp.prompt() def parameter_help(template: str | None = None) -> Sequence[Message]: """Get help with model parameters and templates.""" if template and template in TEMPLATES: tmpl = TEMPLATES[template] text = ( f"I'll help you with the {template} template.\n\n" f"Description: {tmpl.get('description', 'No description')}\n" "Required Parameters:\n" + "\n".join( f"- {param}: {schema.get('description', 'No description')}" for param, schema in tmpl["parameter_schema"]["properties"].items() if param in tmpl["parameter_schema"].get("required", []) ) + "\n\nOptional Parameters:\n" + "\n".join( f"- {param}: {schema.get('description', 'No description')}" for param, schema in tmpl["parameter_schema"]["properties"].items() if param not in tmpl["parameter_schema"].get("required", []) ) ) else: text = ( "I'll help you understand and configure model parameters. " "Please provide:\n\n" "1. The model or template you're using\n" "2. Any specific parameters you need help with\n" "3. Your use case or requirements\n\n" "I'll explain the parameters and suggest appropriate values." ) return [UserMessage(content=TextContent(type="text", text=text))] @mcp.prompt() def after_generation() -> Sequence[Message]: """Prompt shown after starting an image generation.""" return [ UserMessage( content=TextContent( type="text", text=( "Your image generation has started! You have several options:\n\n" "1. Wait here - I'll check the progress and show you the image when it's ready\n" "2. Browse your generations - I can show you a list of all your generations\n" "3. Start another generation - We can create more images while this one processes\n\n" "When the image is ready, I can:\n" "- Show you the image directly\n" "- Open it with your system's default image viewer\n" "- Save it or share it\n" "- Create variations or apply transformations\n\n" "What would you like to do?" ), ) ) ] # Model Discovery Tools @mcp.tool() async def list_models(owner: str | None = None) -> ModelList: """List available models on Replicate with optional filtering by owner.""" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.list_models(owner=owner) return ModelList( models=[Model(**model) for model in result["models"]], next_cursor=result.get("next_cursor"), total_count=result.get("total_models"), ) @mcp.tool() async def search_models(query: str) -> ModelList: """Search for models using semantic search.""" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.search_models(query) models = [Model(**model) for model in result["models"]] # Sort by run count as a proxy for popularity/reliability models.sort(key=lambda m: m.run_count if m.run_count else 0, reverse=True) return ModelList( models=models, next_cursor=result.get("next_cursor"), total_count=result.get("total_models") ) # Collection Management Tools @mcp.tool() async def list_collections() -> CollectionList: """List available model collections on Replicate.""" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.list_collections() return CollectionList(collections=[Collection(**collection) for collection in result]) @mcp.tool() async def get_collection_details(collection_slug: str) -> Collection: """Get detailed information about a specific collection.""" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.get_collection(collection_slug) return Collection(**result) # Hardware Tools @mcp.tool() async def list_hardware() -> HardwareList: """List available hardware options for running models.""" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.list_hardware() return HardwareList(hardware=[Hardware(**hw) for hw in result]) # Template Tools @mcp.tool() def list_templates() -> dict[str, Any]: """List all available templates with their schemas.""" return { name: { "schema": template["parameter_schema"], "description": template.get("description", ""), "version": template.get("version", "1.0.0"), } for name, template in TEMPLATES.items() } @mcp.tool() def validate_template_parameters(input: dict[str, Any]) -> bool: """Validate parameters against a template schema.""" template_input = TemplateInput(**input) return True # If we get here, validation passed # Prediction Tools @mcp.tool() async def create_prediction(input: dict[str, Any], confirmed: bool = False) -> dict[str, Any]: """Create a new prediction using a specific model version on Replicate. Args: input: Model input parameters including version or model details confirmed: Whether the user has explicitly confirmed the generation Returns: Prediction details if confirmed, or a confirmation request if not """ # If not confirmed, return info about what will be generated if not confirmed: # Extract model info for display model_info = "" if "version" in input: model_info = f"version: {input['version']}" elif "model_owner" in input and "model_name" in input: model_info = f"model: {input['model_owner']}/{input['model_name']}" return { "requires_confirmation": True, "message": ( "⚠️ This will use Replicate credits to generate an image with these parameters:\n\n" f"Model: {model_info}\n" f"Prompt: {input.get('prompt', 'Not specified')}\n" f"Quality: {input.get('quality', 'balanced')}\n\n" "Please confirm if you want to proceed with the generation." ), } async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: # If version is provided directly, use it if "version" in input: version = input.pop("version") # Otherwise, try to find the model and get its latest version elif "model_owner" in input and "model_name" in input: model_id = f"{input.pop('model_owner')}/{input.pop('model_name')}" search_result = await client.search_models(model_id) if not search_result["models"]: raise ValueError(f"Model not found: {model_id}") model = search_result["models"][0] if not model.get("latest_version"): raise ValueError(f"No versions found for model: {model_id}") version = model["latest_version"]["id"] else: raise ValueError("Must provide either 'version' or both 'model_owner' and 'model_name'") # Create prediction with remaining parameters as input result = await client.create_prediction(version=version, input=input, webhook=input.pop("webhook", None)) # Return result with prompt about waiting return { **result, "_next_prompt": "after_generation", # Signal to show the waiting prompt } @mcp.tool() async def get_prediction(prediction_id: str, wait: bool = False, max_retries: int | None = None) -> dict[str, Any]: """Get the status and results of a prediction.""" consecutive_errors = 0 while True: try: async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: response = await client.http_client.get( f"/predictions/{prediction_id}", headers={ "Authorization": f"Bearer {client.api_token}", "Content-Type": "application/json", }, ) response.raise_for_status() data = response.json() # Build URL message with all available links urls_msg = [] # Add streaming URL if available if data.get("urls", {}).get("stream"): urls_msg.append(f"🔄 Stream URL: {data['urls']['stream']}") # Add get URL if available if data.get("urls", {}).get("get"): urls_msg.append(f"📡 Status URL: {data['urls']['get']}") # Add cancel URL if available and still processing if data.get("urls", {}).get("cancel") and data["status"] in ["starting", "processing"]: urls_msg.append(f"🛑 Cancel URL: {data['urls']['cancel']}") # If prediction is complete and has an image output if data["status"] == "succeeded" and data.get("output"): # For image generation models, output is typically a list with the image URL as first item image_url = data["output"][0] if isinstance(data["output"], list) else data["output"] return { "id": data["id"], "status": "succeeded", "image_url": image_url, "resource_uri": f"generations://{prediction_id}", "image_resource_uri": f"generations://{prediction_id}/image", "created_at": data.get("created_at"), "completed_at": data.get("completed_at"), "metrics": data.get("metrics", {}), "urls": data.get("urls", {}), "message": ( "🎨 Generation completed successfully!\n\n" f"You can:\n" f"1. View the image at: {image_url}\n" f"2. Open it with your system viewer (just ask me to open it)\n" f"3. Access image data at: generations://{prediction_id}/image\n" f"4. Get metadata at: generations://{prediction_id}\n" f"5. Create variations or apply transformations\n\n" "Would you like me to open the image for you?\n\n" "Available URLs:\n" + "\n".join(urls_msg) ), } # If prediction failed or was cancelled if data["status"] in ["failed", "canceled"]: error_msg = data.get("error", "No error details available") return { "id": data["id"], "status": data["status"], "error": error_msg, "resource_uri": f"generations://{prediction_id}", "created_at": data.get("created_at"), "completed_at": data.get("completed_at"), "metrics": data.get("metrics", {}), "urls": data.get("urls", {}), "message": ( f"❌ Generation {data['status']}: {error_msg}\n\n" "Available URLs:\n" + "\n".join(urls_msg) ), } # If we're still processing and not waiting, return status if not wait: return { "id": data["id"], "status": data["status"], "resource_uri": f"generations://{prediction_id}", "created_at": data.get("created_at"), "started_at": data.get("started_at"), "metrics": data.get("metrics", {}), "urls": data.get("urls", {}), "message": ( f"⏳ Generation is {data['status']}...\n" "You can:\n" "1. Keep waiting (I'll check again)\n" "2. Use the URLs above to check progress yourself\n" "3. Cancel the generation if needed" ), } # Reset error count on successful request consecutive_errors = 0 # Wait before polling again await asyncio.sleep(2.0) # Increased poll interval to reduce API load except Exception as e: logger.error(f"Error checking prediction status: {str(e)}") consecutive_errors += 1 # Only stop if we have a max_retries set and exceeded it if max_retries is not None and consecutive_errors >= max_retries: return { "id": prediction_id, "status": "error", "error": str(e), "resource_uri": f"generations://{prediction_id}", "message": ( f"⚠️ Having trouble checking the prediction status (tried {consecutive_errors} times).\n\n" f"The prediction might still be running! You can:\n" "1. Try checking again in a few minutes\n" "2. Visit the status URL directly: " f"https://replicate.com/p/{prediction_id}\n" "3. Start a new check with a higher retry limit\n\n" f"Last error: {str(e)}" ), } # Wait with exponential backoff before retrying await asyncio.sleep(min(30, 2.0**consecutive_errors)) # Cap at 30 seconds @mcp.tool() async def cancel_prediction(prediction_id: str) -> dict[str, Any]: """Cancel a running prediction.""" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: response = await client.cancel_prediction(prediction_id) return await response.json() # Webhook Tools @mcp.tool() async def get_webhook_secret() -> str: """Get the signing secret for verifying webhook requests.""" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: return await client.get_webhook_secret() @mcp.tool() def verify_webhook(payload: WebhookPayload, signature: str, secret: str) -> bool: """Verify that a webhook request came from Replicate using HMAC-SHA256. Args: payload: The webhook payload to verify signature: The signature from the X-Replicate-Signature header secret: The webhook signing secret from get_webhook_secret Returns: True if signature is valid, False otherwise """ if not signature or not secret: return False # Convert payload to canonical JSON string payload_str = json.dumps(payload.model_dump(), sort_keys=True) # Calculate expected signature expected = hmac.new(secret.encode(), payload_str.encode(), hashlib.sha256).hexdigest() # Compare signatures using constant-time comparison return hmac.compare_digest(signature, expected) @mcp.tool() async def search_available_models( query: str, style: str | None = None, ) -> ModelList: """Search for available models matching the query. Args: query: Search query describing the desired model style: Optional style to filter by Returns: List of matching models with scores """ search_query = query if style: search_query = f"{style} style {search_query}" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.search_models(search_query) models = [Model(**model) for model in result["models"]] # Score models but don't auto-select scored_models = [] for model in models: score = 0 run_count = getattr(model, "run_count", 0) or 0 score += min(50, (run_count / 1000) * 50) if getattr(model, "featured", False): score += 20 if model.latest_version: score += 10 tags = getattr(model, "tags", []) if style and any(style.lower() in tag.lower() for tag in tags): score += 15 if "image" in tags or "text-to-image" in tags: score += 15 scored_models.append((model, score)) # Sort by score but return all for user selection scored_models.sort(key=lambda x: x[1], reverse=True) return ModelList( models=[m[0] for m in scored_models], next_cursor=result.get("next_cursor"), total_count=result.get("total_count"), ) @mcp.tool() async def get_model_details(model_id: str) -> Model: """Get detailed information about a specific model. Args: model_id: Model identifier in format owner/name Returns: Detailed model information """ async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.get_model_details(model_id) return Model(**result) @mcp.tool() async def generate_image( prompt: str, style: str | None = None, quality: str = "balanced", width: int | None = None, height: int | None = None, num_outputs: int = 1, seed: int | None = None, ) -> dict[str, Any]: """Generate an image using the specified parameters.""" # Get quality preset parameters if quality not in QUALITY_PRESETS["presets"]: quality = "balanced" parameters = QUALITY_PRESETS["presets"][quality]["parameters"].copy() # Apply style preset if specified if style: if style in STYLE_PRESETS["presets"]: style_params = STYLE_PRESETS["presets"][style]["parameters"] # Merge prompt prefixes if "prompt_prefix" in style_params: prompt = f"{style_params['prompt_prefix']}, {prompt}" # Copy other parameters for k, v in style_params.items(): if k != "prompt_prefix": parameters[k] = v # Override size if specified if width: parameters["width"] = width if height: parameters["height"] = height # Add other parameters parameters.update( { "prompt": prompt, "num_outputs": num_outputs, } ) if seed is not None: parameters["seed"] = seed # Create prediction with SDXL model async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.create_prediction( version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", # SDXL v1.0 input=parameters, ) # Return resource information return { "resource_uri": f"generations://{result['id']}", "status": result["status"], "message": ( "🎨 Starting your image generation...\n\n" f"Prompt: {prompt}\n" f"Style: {style or 'Default'}\n" f"Quality: {quality}\n\n" "Let me check the status of your generation. I'll use:\n" f"`get_prediction(\"{result['id']}\", wait=true)`\n\n" "This will let me monitor the progress and show you the image as soon as it's ready." ), "next_prompt": "after_generation", "metadata": { # Add metadata for client use "prompt": prompt, "style": style, "quality": quality, "width": width, "height": height, "seed": seed, "model": "SDXL v1.0", "created_at": result.get("created_at"), }, } # Initialize subscription manager subscription_manager = GenerationSubscriptionManager() @mcp.tool() async def subscribe_to_generation(request: SubscriptionRequest) -> EmptyResult: """Handle resource subscription requests.""" if request.uri.startswith("generations://"): session = ServerSession(request.session_id) await subscription_manager.subscribe(request.uri, session) return EmptyResult() @mcp.tool() async def unsubscribe_from_generation(request: SubscriptionRequest) -> EmptyResult: """Handle resource unsubscribe requests.""" if request.uri.startswith("generations://"): session = ServerSession(request.session_id) await subscription_manager.unsubscribe(request.uri, session) return EmptyResult() @mcp.resource("generations://{prediction_id}/image") async def get_generation_image(prediction_id: str) -> BlobResourceContents: """Get the image data for a completed generation.""" async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: result = await client.get_prediction_status(prediction_id) if result["status"] != "succeeded": raise ValueError(f"Generation not completed: {result['status']}") if not result.get("output"): raise ValueError("No image output available") # Get image URL image_url = result["output"][0] if isinstance(result["output"], list) else result["output"] # Download image async with httpx.AsyncClient() as http_client: img_response = await http_client.get(image_url) img_response.raise_for_status() # Determine mime type from URL extension ext = image_url.split(".")[-1].lower() mime_type = { "png": "image/png", "jpg": "image/jpeg", "jpeg": "image/jpeg", "gif": "image/gif", "webp": "image/webp", }.get(ext, "image/png") # Return blob contents return BlobResourceContents( type="blob", mimeType=mime_type, uri=image_url, blob=base64.b64encode(img_response.content).decode("ascii"), description="Generated image data", ) @mcp.tool() async def open_image_with_system(image_url: str) -> dict[str, Any]: """Open an image URL with the system's default application. Args: image_url: URL of the image to open Returns: Dict containing status of the operation """ try: # Open URL directly with system default webbrowser.open(image_url) return {"status": "success", "message": "Image opened with system default application", "url": image_url} except Exception as e: logger.error(f"Failed to open image: {str(e)}") return {"status": "error", "message": f"Failed to open image: {str(e)}", "url": image_url} return mcp

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/gerred/mcp-server-replicate'

If you have feedback or need assistance with the MCP directory API, please join our Discord server