Skip to main content
Glama

MCP Server Replicate

replicate_client.py29.9 kB
"""Replicate API client implementation.""" import logging import os import time from typing import Any, Optional, Dict, AsyncGenerator import asyncio import random import httpx import replicate logger = logging.getLogger(__name__) # Constants REPLICATE_API_BASE = "https://api.replicate.com/v1" DEFAULT_TIMEOUT = 60.0 MAX_RETRIES = 3 MIN_RETRY_DELAY = 1.0 MAX_RETRY_DELAY = 10.0 DEFAULT_RATE_LIMIT = 100 # requests per minute class RateLimitExceeded(Exception): """Raised when rate limit is exceeded.""" def __init__(self, retry_after: float): """Initialize with retry after duration.""" self.retry_after = retry_after super().__init__(f"Rate limit exceeded. Retry after {retry_after} seconds.") class ReplicateClient: """Client for interacting with the Replicate API.""" def __init__(self, api_token: str | None = None) -> None: """Initialize the Replicate client. Args: api_token: Replicate API token for authentication """ self.client = None self.error = None self.api_token = api_token self._rate_limit = DEFAULT_RATE_LIMIT self._request_times: list[float] = [] self._retry_count = 0 self.http_client = None # Initialize to None, will be set up in __aenter__ if not api_token or not api_token.strip(): self.error = "Replicate API token is required" return os.environ["REPLICATE_API_TOKEN"] = api_token self.client = replicate.Client() async def __aenter__(self): """Async context manager entry.""" # Initialize httpx client for direct API calls self.http_client = httpx.AsyncClient( base_url=REPLICATE_API_BASE, headers={ "Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json", }, timeout=DEFAULT_TIMEOUT ) return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" if self.http_client: await self.http_client.aclose() self.http_client = None async def _ensure_http_client(self): """Ensure http_client is initialized.""" if not self.http_client: self.http_client = httpx.AsyncClient( base_url=REPLICATE_API_BASE, headers={ "Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json", }, timeout=DEFAULT_TIMEOUT ) async def _wait_for_rate_limit(self) -> None: """Wait if necessary to comply with rate limiting.""" now = time.time() # Remove request times older than 1 minute self._request_times = [t for t in self._request_times if now - t <= 60] if len(self._request_times) >= self._rate_limit: # Calculate wait time based on oldest request wait_time = 60 - (now - self._request_times[0]) if wait_time > 0: logger.debug(f"Rate limit reached. Waiting {wait_time:.2f} seconds") await asyncio.sleep(wait_time) self._request_times.append(now) async def _handle_response(self, response: httpx.Response) -> None: """Handle rate limits and other response headers. Args: response: The HTTP response to handle Raises: RateLimitExceeded: If rate limit is exceeded """ # Update rate limit from headers if available if "X-RateLimit-Limit" in response.headers: self._rate_limit = int(response.headers["X-RateLimit-Limit"]) # Handle rate limit exceeded if response.status_code == 429: retry_after = float(response.headers.get("Retry-After", 60)) raise RateLimitExceeded(retry_after) async def _make_request( self, method: str, endpoint: str, **kwargs: Any ) -> httpx.Response: """Make an HTTP request with retries and rate limiting. Args: method: HTTP method to use endpoint: API endpoint to call **kwargs: Additional arguments to pass to httpx Returns: HTTP response Raises: Exception: If the request fails after retries """ await self._wait_for_rate_limit() for attempt in range(MAX_RETRIES): try: response = await self.http_client.request(method, endpoint, **kwargs) await self._handle_response(response) response.raise_for_status() self._retry_count = 0 # Reset on success return response except RateLimitExceeded as e: logger.warning(f"Rate limit exceeded. Waiting {e.retry_after} seconds") await asyncio.sleep(e.retry_after) continue except httpx.HTTPError as e: self._retry_count += 1 if attempt == MAX_RETRIES - 1: raise # Calculate exponential backoff with jitter delay = min( MAX_RETRY_DELAY, MIN_RETRY_DELAY * (2 ** attempt) + random.uniform(0, 1) ) logger.warning( f"Request failed: {str(e)}. " f"Retrying in {delay:.2f} seconds " f"(attempt {attempt + 1}/{MAX_RETRIES})" ) await asyncio.sleep(delay) continue def list_models(self, owner: str | None = None, cursor: str | None = None) -> dict[str, Any]: """List available models on Replicate with pagination. Args: owner: Optional owner username to filter models cursor: Pagination cursor from previous response Returns: Dict containing models list, next cursor, and total count Raises: Exception: If the API request fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: # Build params dict only including cursor if provided params = {} if cursor: params["cursor"] = cursor # Get models collection with pagination models = self.client.models.list(**params) # Get pagination info next_cursor = models.next_cursor if hasattr(models, "next_cursor") else None total_models = models.total if hasattr(models, "total") else None # Filter by owner if specified if owner: models = [m for m in models if m.owner == owner] # Format models with complete structure formatted_models = [] for model in models: model_data = { "id": f"{model.owner}/{model.name}", "owner": model.owner, "name": model.name, "description": model.description, "visibility": model.visibility, "github_url": getattr(model, "github_url", None), "paper_url": getattr(model, "paper_url", None), "license_url": getattr(model, "license_url", None), "run_count": getattr(model, "run_count", None), "cover_image_url": getattr(model, "cover_image_url", None), "default_example": getattr(model, "default_example", None), "featured": getattr(model, "featured", None), "tags": getattr(model, "tags", []), } # Add latest version info if available if model.latest_version: model_data["latest_version"] = { "id": model.latest_version.id, "created_at": model.latest_version.created_at, "cog_version": model.latest_version.cog_version, "openapi_schema": model.latest_version.openapi_schema, "model": f"{model.owner}/{model.name}", "replicate_version": getattr(model.latest_version, "replicate_version", None), "hardware": getattr(model.latest_version, "hardware", None), } formatted_models.append(model_data) return { "models": formatted_models, "next_cursor": next_cursor, "total_count": total_models, } except Exception as err: logger.error(f"Failed to list models: {str(err)}") raise Exception(f"Failed to list models: {str(err)}") from err def get_model_versions(self, model: str) -> list[dict[str, Any]]: """Get available versions for a model. Args: model: Model identifier in format 'owner/model' Returns: List of model versions with their metadata Raises: ValueError: If the model is not found Exception: If the API request fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: # Get model model_obj = self.client.models.get(model) if not model_obj: raise ValueError(f"Model not found: {model}") # Get versions versions = model_obj.versions.list() # Return minimal version metadata return [ { "id": version.id, "created_at": version.created_at.isoformat() if version.created_at else None, "cog_version": version.cog_version, "openapi_schema": version.openapi_schema, } for version in versions ] except ValueError as err: logger.error(f"Validation error: {str(err)}") raise except Exception as err: logger.error(f"Failed to get model versions: {str(err)}") raise Exception(f"Failed to get model versions: {str(err)}") from err async def predict( self, model: str, input_data: dict[str, Any], version: str | None = None, wait: bool = False, wait_timeout: int | None = None, stream: bool = False, ) -> dict[str, Any]: """Run a prediction using a Replicate model. Args: model: Model identifier in format 'owner/model' input_data: Model-specific input parameters version: Optional model version hash wait: Whether to wait for prediction completion wait_timeout: Max seconds to wait if wait=True (1-60) stream: Whether to request streaming output Returns: Dict containing prediction details and optional stream URL Raises: ValueError: If the model or version is not found Exception: If the prediction fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: # Validate wait_timeout if wait and wait_timeout: if not 1 <= wait_timeout <= 60: raise ValueError("wait_timeout must be between 1 and 60 seconds") # Get model model_obj = self.client.models.get(model) if not model_obj: raise ValueError(f"Model not found: {model}") # Get specific version or latest if version: model_version = model_obj.versions.get(version) if not model_version: raise ValueError(f"Version not found: {version}") else: model_version = model_obj.latest_version # Prepare headers headers = { "Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json", } if wait: if wait_timeout: headers["Prefer"] = f"wait={wait_timeout}" else: headers["Prefer"] = "wait" # Prepare request body body = { "input": input_data, "stream": stream, } if version: body["version"] = version # Create prediction using rate-limited request response = await self._make_request( "POST", "/predictions", headers=headers, json=body ) data = response.json() # Format response result = { "id": data["id"], "status": data["status"], "input": data["input"], "output": data.get("output"), "error": data.get("error"), "logs": data.get("logs"), "created_at": data.get("created_at"), "started_at": data.get("started_at"), "completed_at": data.get("completed_at"), "urls": data.get("urls", {}), } # Add metrics if available if "metrics" in data: result["metrics"] = data["metrics"] # Add stream URL if requested and available if stream and "urls" in data and "stream" in data["urls"]: result["stream_url"] = data["urls"]["stream"] return result except ValueError as err: logger.error(f"Validation error: {str(err)}") raise except httpx.HTTPError as err: logger.error(f"HTTP error during prediction: {str(err)}") raise Exception(f"Prediction failed: {str(err)}") from err except Exception as err: logger.error(f"Prediction failed: {str(err)}") raise Exception(f"Prediction failed: {str(err)}") from err def get_prediction_status(self, prediction_id: str) -> dict[str, Any]: """Get the status of a prediction. Args: prediction_id: ID of the prediction to check Returns: Dict containing current status and output of the prediction Raises: ValueError: If the prediction is not found Exception: If the API request fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: # Get prediction prediction = self.client.predictions.get(prediction_id) if not prediction: raise ValueError(f"Prediction not found: {prediction_id}") # Return prediction status and output return { "id": prediction.id, "status": prediction.status, "output": prediction.output, "error": prediction.error, "created_at": prediction.created_at.isoformat() if prediction.created_at else None, "started_at": prediction.started_at.isoformat() if prediction.started_at else None, "completed_at": prediction.completed_at.isoformat() if prediction.completed_at else None, "urls": prediction.urls, "metrics": prediction.metrics, } except ValueError as err: logger.error(f"Validation error: {str(err)}") raise except Exception as err: logger.error(f"Failed to get prediction status: {str(err)}") raise Exception(f"Failed to get prediction status: {str(err)}") from err async def search_models( self, query: str, cursor: Optional[str] = None, ) -> dict[str, Any]: """Search for models using the QUERY endpoint. Args: query: Search query string cursor: Optional pagination cursor Returns: Dict containing search results with pagination info Raises: Exception: If the API request fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: # Build URL with cursor if provided url = "/models" if cursor: url = f"{url}?cursor={cursor}" # Make QUERY request response = await self.http_client.request( "QUERY", url, content=query, headers={"Content-Type": "text/plain"} ) response.raise_for_status() data = response.json() # Format response with complete model structure return { "models": [ { "id": f"{model['owner']}/{model['name']}", "owner": model["owner"], "name": model["name"], "description": model.get("description"), "visibility": model.get("visibility", "public"), "github_url": model.get("github_url"), "paper_url": model.get("paper_url"), "license_url": model.get("license_url"), "run_count": model.get("run_count"), "cover_image_url": model.get("cover_image_url"), "default_example": model.get("default_example"), "featured": model.get("featured", False), "tags": model.get("tags", []), "latest_version": model.get("latest_version", { "id": model.get("latest_version", {}).get("id"), "created_at": model.get("latest_version", {}).get("created_at"), "cog_version": model.get("latest_version", {}).get("cog_version"), "openapi_schema": model.get("latest_version", {}).get("openapi_schema"), "model": f"{model['owner']}/{model['name']}", "replicate_version": model.get("latest_version", {}).get("replicate_version"), "hardware": model.get("latest_version", {}).get("hardware"), } if model.get("latest_version") else None), } for model in data.get("results", []) ], "next_cursor": data.get("next"), "total_count": data.get("total"), } except httpx.HTTPError as err: logger.error(f"HTTP error during model search: {str(err)}") raise Exception(f"Failed to search models: {str(err)}") from err except Exception as err: logger.error(f"Failed to search models: {str(err)}") raise Exception(f"Failed to search models: {str(err)}") from err async def list_hardware(self) -> list[dict[str, str]]: """Get list of available hardware options for running models. Returns: List of hardware options with name and SKU Raises: Exception: If the API request fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: response = await self.http_client.get("/hardware") response.raise_for_status() return [ { "name": hw["name"], "sku": hw["sku"], } for hw in response.json() ] except httpx.HTTPError as err: logger.error(f"HTTP error getting hardware options: {str(err)}") raise Exception(f"Failed to get hardware options: {str(err)}") from err except Exception as err: logger.error(f"Failed to get hardware options: {str(err)}") raise Exception(f"Failed to get hardware options: {str(err)}") from err async def list_collections(self) -> list[dict[str, Any]]: """Get list of available model collections. Returns: List of collections with their metadata Raises: Exception: If the API request fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: response = await self.http_client.get("/collections") response.raise_for_status() data = response.json() return [ { "name": collection["name"], "slug": collection["slug"], "description": collection.get("description"), } for collection in data.get("results", []) ] except httpx.HTTPError as err: logger.error(f"HTTP error listing collections: {str(err)}") raise Exception(f"Failed to list collections: {str(err)}") from err except Exception as err: logger.error(f"Failed to list collections: {str(err)}") raise Exception(f"Failed to list collections: {str(err)}") from err async def get_collection(self, collection_slug: str) -> dict[str, Any]: """Get details of a specific collection including its models. Args: collection_slug: The slug identifier of the collection Returns: Collection details including contained models Raises: ValueError: If the collection is not found Exception: If the API request fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: response = await self.http_client.get(f"/collections/{collection_slug}") response.raise_for_status() data = response.json() return { "name": data["name"], "slug": data["slug"], "description": data.get("description"), "models": [ { "id": f"{model['owner']}/{model['name']}", "owner": model["owner"], "name": model["name"], "description": model.get("description"), "visibility": model.get("visibility", "public"), "latest_version": model.get("latest_version"), } for model in data.get("models", []) ] } except httpx.HTTPStatusError as err: if err.response.status_code == 404: raise ValueError(f"Collection not found: {collection_slug}") logger.error(f"HTTP error getting collection: {str(err)}") raise Exception(f"Failed to get collection: {str(err)}") from err except Exception as err: logger.error(f"Failed to get collection: {str(err)}") raise Exception(f"Failed to get collection: {str(err)}") from err async def get_webhook_secret(self) -> str: """Get the signing secret for the default webhook endpoint. This secret is used to verify that webhook requests are coming from Replicate. Returns: The webhook signing secret Raises: Exception: If the API request fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: response = await self.http_client.get("/webhooks/default/secret") response.raise_for_status() data = response.json() return data["key"] except httpx.HTTPError as err: logger.error(f"HTTP error getting webhook secret: {str(err)}") raise Exception(f"Failed to get webhook secret: {str(err)}") from err except Exception as err: logger.error(f"Failed to get webhook secret: {str(err)}") raise Exception(f"Failed to get webhook secret: {str(err)}") from err async def cancel_prediction(self, prediction_id: str) -> dict[str, Any]: """Cancel a running prediction. Args: prediction_id: The ID of the prediction to cancel Returns: Dict containing the updated prediction status Raises: ValueError: If the prediction is not found Exception: If the cancellation fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: response = await self.http_client.post( f"/predictions/{prediction_id}/cancel", headers={ "Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json", } ) response.raise_for_status() return response.json() except httpx.HTTPStatusError as err: if err.response.status_code == 404: raise ValueError(f"Prediction not found: {prediction_id}") logger.error(f"Failed to cancel prediction: {str(err)}") raise Exception(f"Failed to cancel prediction: {str(err)}") from err except Exception as err: logger.error(f"Failed to cancel prediction: {str(err)}") raise Exception(f"Failed to cancel prediction: {str(err)}") from err async def list_predictions( self, status: str | None = None, limit: int = 10 ) -> list[dict[str, Any]]: """List recent predictions with optional filtering. Args: status: Optional status to filter by (starting|processing|succeeded|failed|canceled) limit: Maximum number of predictions to return (1-100) Returns: List of prediction objects Raises: ValueError: If limit is out of range Exception: If the API request fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") if not 1 <= limit <= 100: raise ValueError("limit must be between 1 and 100") try: params = {"limit": limit} if status: params["status"] = status response = await self.http_client.get( "/predictions", params=params, headers={ "Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json", } ) response.raise_for_status() return response.json() except Exception as err: logger.error(f"Failed to list predictions: {str(err)}") raise Exception(f"Failed to list predictions: {str(err)}") from err async def create_prediction( self, version: str, input: Dict[str, Any], webhook: Optional[str] = None, ) -> Dict[str, Any]: """Create a new prediction using a model version. Args: version: Model version ID input: Model input parameters webhook: Optional webhook URL for prediction updates Returns: Dict containing prediction details Raises: Exception: If the prediction creation fails """ if not self.client: raise RuntimeError("Client not initialized. Check error property for details.") try: await self._ensure_http_client() # Prepare request body body = { "version": version, "input": input, } if webhook: body["webhook"] = webhook # Create prediction using rate-limited request response = await self._make_request( "POST", "/predictions", json=body ) data = response.json() # Format response result = { "id": data["id"], "status": data["status"], "input": data["input"], "output": data.get("output"), "error": data.get("error"), "logs": data.get("logs"), "created_at": data.get("created_at"), "started_at": data.get("started_at"), "completed_at": data.get("completed_at"), "urls": data.get("urls", {}), } # Add metrics if available if "metrics" in data: result["metrics"] = data["metrics"] return result except Exception as err: logger.error(f"Failed to create prediction: {str(err)}") raise Exception(f"Failed to create prediction: {str(err)}") from err

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/gerred/mcp-server-replicate'

If you have feedback or need assistance with the MCP directory API, please join our Discord server