create_prediction
Generate predictions with AI models on Replicate for image generation and other inference tasks by specifying model inputs; execution requires explicit user confirmation because it consumes Replicate credits.
Instructions
Create a new prediction using a specific model version on Replicate.
Args:
- input: Model input parameters including version or model details
- confirmed: Whether the user has explicitly confirmed the generation

Returns:
- Prediction details if confirmed, or a confirmation request if not
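A minimal sketch of the two-step confirmation flow as seen from an MCP client calling this tool (the model name and prompt are placeholders; the return shapes follow the handler code in the Implementation Reference below):

```python
# Step 1: call without confirmation; the tool answers with a confirmation
# request instead of spending Replicate credits.
create_prediction(
    input={
        "model_owner": "stability-ai",  # placeholder model
        "model_name": "sdxl",
        "prompt": "a watercolor fox",
    },
    confirmed=False,
)
# -> {"requires_confirmation": True, "message": "⚠️ This will use Replicate credits ..."}

# Step 2: once the user agrees, repeat the call with confirmed=True to
# actually create the prediction.
create_prediction(
    input={
        "model_owner": "stability-ai",
        "model_name": "sdxl",
        "prompt": "a watercolor fox",
    },
    confirmed=True,
)
# -> {"id": "...", "status": "...", ..., "_next_prompt": "after_generation"}
```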
Input Schema
| Name | Required | Description | Default |
|---|---|---|---|
| input | Yes | Model input parameters, including version or model details | |
| confirmed | No | Whether the user has explicitly confirmed the generation | `false` |
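The `input` object accepts either a pinned `version` ID or a `model_owner`/`model_name` pair (resolved to the model's latest version); any remaining keys, plus an optional `webhook`, are forwarded to the model. Two illustrative payloads, with all values as placeholders:

```python
# Pin an exact model version (placeholder version ID)
input_by_version = {
    "version": "<version-id>",
    "prompt": "a watercolor fox",
}

# Or name the model and let the handler resolve its latest version
input_by_model = {
    "model_owner": "stability-ai",
    "model_name": "sdxl",
    "prompt": "a watercolor fox",
    "webhook": "https://example.com/replicate-callback",  # optional status callback
}
```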
Implementation Reference
- The primary MCP tool handler for `create_prediction`. It handles user confirmation, resolves the model/version, and delegates to `ReplicateClient` for the actual prediction creation:

```python
@mcp.tool()
async def create_prediction(input: dict[str, Any], confirmed: bool = False) -> dict[str, Any]:
    """Create a new prediction using a specific model version on Replicate.

    Args:
        input: Model input parameters including version or model details
        confirmed: Whether the user has explicitly confirmed the generation

    Returns:
        Prediction details if confirmed, or a confirmation request if not
    """
    # If not confirmed, return info about what will be generated
    if not confirmed:
        # Extract model info for display
        model_info = ""
        if "version" in input:
            model_info = f"version: {input['version']}"
        elif "model_owner" in input and "model_name" in input:
            model_info = f"model: {input['model_owner']}/{input['model_name']}"

        return {
            "requires_confirmation": True,
            "message": (
                "⚠️ This will use Replicate credits to generate an image with these parameters:\n\n"
                f"Model: {model_info}\n"
                f"Prompt: {input.get('prompt', 'Not specified')}\n"
                f"Quality: {input.get('quality', 'balanced')}\n\n"
                "Please confirm if you want to proceed with the generation."
            ),
        }

    async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client:
        # If version is provided directly, use it
        if "version" in input:
            version = input.pop("version")
        # Otherwise, try to find the model and get its latest version
        elif "model_owner" in input and "model_name" in input:
            model_id = f"{input.pop('model_owner')}/{input.pop('model_name')}"
            search_result = await client.search_models(model_id)
            if not search_result["models"]:
                raise ValueError(f"Model not found: {model_id}")
            model = search_result["models"][0]
            if not model.get("latest_version"):
                raise ValueError(f"No versions found for model: {model_id}")
            version = model["latest_version"]["id"]
        else:
            raise ValueError("Must provide either 'version' or both 'model_owner' and 'model_name'")

        # Create prediction with remaining parameters as input
        result = await client.create_prediction(version=version, input=input, webhook=input.pop("webhook", None))

        # Return result with prompt about waiting
        return {
            **result,
            "_next_prompt": "after_generation",  # Signal to show the waiting prompt
        }
```
- The `ReplicateClient` helper method called by the tool handler to perform the actual API call that creates the prediction on Replicate (a standalone sketch of the equivalent raw REST request appears after this list):

```python
async def create_prediction(
    self,
    version: str,
    input: Dict[str, Any],
    webhook: Optional[str] = None,
) -> Dict[str, Any]:
    """Create a new prediction using a model version.

    Args:
        version: Model version ID
        input: Model input parameters
        webhook: Optional webhook URL for prediction updates

    Returns:
        Dict containing prediction details

    Raises:
        Exception: If the prediction creation fails
    """
    if not self.client:
        raise RuntimeError("Client not initialized. Check error property for details.")

    try:
        await self._ensure_http_client()

        # Prepare request body
        body = {
            "version": version,
            "input": input,
        }
        if webhook:
            body["webhook"] = webhook

        # Create prediction using rate-limited request
        response = await self._make_request(
            "POST",
            "/predictions",
            json=body,
        )
        data = response.json()

        # Format response
        result = {
            "id": data["id"],
            "status": data["status"],
            "input": data["input"],
            "output": data.get("output"),
            "error": data.get("error"),
            "logs": data.get("logs"),
            "created_at": data.get("created_at"),
            "started_at": data.get("started_at"),
            "completed_at": data.get("completed_at"),
            "urls": data.get("urls", {}),
        }

        # Add metrics if available
        if "metrics" in data:
            result["metrics"] = data["metrics"]

        return result

    except Exception as err:
        logger.error(f"Failed to create prediction: {str(err)}")
        raise Exception(f"Failed to create prediction: {str(err)}") from err
```
- src/mcp_server_replicate/server.py:702 (registration): the `@mcp.tool()` decorator registers the `create_prediction` function as an MCP tool named `create_prediction` (the tool name defaults to the function name).

```python
@mcp.tool()
```
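For orientation, the request that `ReplicateClient.create_prediction` ultimately issues corresponds to Replicate's public `POST /predictions` endpoint. Below is a self-contained sketch using `httpx`, independent of the project's rate-limited client wrapper; the base URL and auth header follow Replicate's documented REST API, and all values are placeholders:

```python
import os
from typing import Any, Optional

import httpx


async def create_prediction_raw(
    version: str,
    input: dict[str, Any],
    webhook: Optional[str] = None,
) -> dict[str, Any]:
    """Minimal direct call to Replicate's POST /predictions endpoint."""
    body: dict[str, Any] = {"version": version, "input": input}
    if webhook:
        body["webhook"] = webhook

    async with httpx.AsyncClient(
        base_url="https://api.replicate.com/v1",
        headers={"Authorization": f"Bearer {os.environ['REPLICATE_API_TOKEN']}"},
        timeout=30.0,
    ) as http:
        response = await http.post("/predictions", json=body)
        response.raise_for_status()  # unlike the wrapper, surface HTTP errors directly
        return response.json()
```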