#!/usr/bin/env python3
"""
Civitai MCP Server - Python FastMCP implementation
Browse AI models, images, creators, and more from Civitai
"""
import os
import json
from typing import Optional, List, Dict, Any, Union
from datetime import datetime
from urllib.parse import urlencode
import httpx
from pydantic import BaseModel, Field
from fastmcp import FastMCP
import logging
# Initialize logger
# Module-level logger named after this module; its level is raised to DEBUG by the
# --debug flag handled in the __main__ section of this file.
logger = logging.getLogger(__name__)
# Initialize FastMCP server
# NOTE(review): host="0.0.0.0" presumably makes the HTTP transport listen on every
# network interface - confirm this exposure is intended for the deployment.
# NOTE(review): stateless_http=True presumably disables per-session server state;
# verify against the FastMCP documentation for the installed version.
mcp = FastMCP("Civitai MCP Server", host="0.0.0.0", stateless_http=True)
class CivitaiClient:
    """Client for interacting with the Civitai API.

    Wraps a synchronous ``httpx.Client`` and exposes one method per REST
    endpoint. When an API key is configured it is appended to every URL as
    the ``token`` query parameter.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Create a client.

        Args:
            api_key: Civitai API key. Falls back to the ``CIVITAI_API_KEY``
                environment variable when omitted or empty.
        """
        self.base_url = "https://civitai.com/api/v1"
        self.api_key = api_key or os.getenv("CIVITAI_API_KEY")
        self.client = httpx.Client(timeout=30.0)

    @staticmethod
    def _encode_value(value: Any) -> Any:
        """Return *value* in the form it should take in a query string.

        ``urlencode`` would render Python booleans as ``True``/``False``;
        JSON-style HTTP APIs expect lowercase ``true``/``false``. Sequences
        are converted element by element; everything else is unchanged.
        """
        if isinstance(value, bool):
            return "true" if value else "false"
        if isinstance(value, (list, tuple)):
            return [CivitaiClient._encode_value(item) for item in value]
        return value

    def _redact(self, text: str) -> str:
        """Return *text* with the configured API key masked.

        httpx error messages include the full request URL, which carries the
        key as ``token=...``; masking keeps the secret out of messages that
        are raised to callers.
        """
        if not self.api_key:
            return text
        encoded_pair = urlencode({"token": self.api_key})
        return text.replace(encoded_pair, "token=***").replace(self.api_key, "***")

    def _build_url(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> str:
        """Build URL with parameters.

        Args:
            endpoint: Path appended to ``base_url`` (must start with ``/``).
            params: Query parameters. Entries whose value is ``None`` are
                dropped; list values are emitted as repeated keys.

        Returns:
            The absolute URL, including ``token`` when an API key is set.
        """
        url = f"{self.base_url}{endpoint}"
        query = {
            key: self._encode_value(value)
            for key, value in (params or {}).items()
            if value is not None
        }
        if self.api_key:
            query["token"] = self.api_key
        if query:
            # doseq=True expands list values into repeated key=value pairs.
            url = f"{url}?{urlencode(query, doseq=True)}"
        return url

    def _make_request(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Make HTTP request to Civitai API.

        Args:
            endpoint: Path relative to ``base_url``.
            params: Optional query parameters (see ``_build_url``).

        Returns:
            The decoded JSON response body.

        Raises:
            Exception: On transport errors or non-2xx responses
                ("API request failed: ...") and on any other failure such as
                an undecodable body ("Request error: ..."). The original
                exception is attached as ``__cause__``.
        """
        url = self._build_url(endpoint, params)
        try:
            response = self.client.get(url)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPError as e:
            # The message is redacted because it embeds the request URL (and
            # therefore the token); the chained cause keeps full detail.
            raise Exception(f"API request failed: {self._redact(str(e))}") from e
        except Exception as e:
            raise Exception(f"Request error: {self._redact(str(e))}") from e

    def get_models(self, **params) -> Dict[str, Any]:
        """Get models with optional filters"""
        return self._make_request("/models", params)

    def get_model(self, model_id: int) -> Dict[str, Any]:
        """Get specific model by ID"""
        return self._make_request(f"/models/{model_id}")

    def get_model_version(self, version_id: int) -> Dict[str, Any]:
        """Get specific model version"""
        return self._make_request(f"/model-versions/{version_id}")

    def get_model_version_by_hash(self, hash_value: str) -> Dict[str, Any]:
        """Get model version by file hash"""
        return self._make_request(f"/model-versions/by-hash/{hash_value}")

    def get_images(self, **params) -> Dict[str, Any]:
        """Get images with optional filters"""
        return self._make_request("/images", params)

    def get_creators(self, **params) -> Dict[str, Any]:
        """Get creators with optional filters"""
        return self._make_request("/creators", params)

    def get_tags(self, **params) -> Dict[str, Any]:
        """Get tags with optional filters"""
        return self._make_request("/tags", params)

    def get_download_url(self, version_id: int) -> str:
        """Get download URL for model version.

        No request is made; the URL is only constructed. When an API key is
        configured the returned URL contains it as the ``token`` parameter.
        """
        # NOTE(review): this yields <base_url>/download/models/<id>, i.e. under
        # /api/v1. Confirm against the Civitai docs that the download route
        # lives under /api/v1 and not directly under /api.
        return self._build_url(f"/download/models/{version_id}")

    def search_loras_latest_version(self, query: Optional[str] = None, base_models: Optional[List[str]] = None) -> Dict[str, Any]:
        """Search LoRA models and describe the newest version of each.

        Fetches up to 10 LORA models matching the filters, picks the version
        with the largest ID for each, and loads that version's details with a
        separate request.

        Args:
            query: Free-text search string.
            base_models: Base model names to filter by.

        Returns:
            ``{"models": [...]}`` where each entry holds ``name``,
            ``model_id``, ``model_version_id``, ``download_url``,
            ``trainedWords``, ``description`` and ``images`` (each image
            reduced to its ``url`` and ``meta.prompt``). Fields absent from
            the API response are returned as ``None``.

        Raises:
            Exception: Propagated from ``_make_request``.
        """
        params = {
            "query": query,
            "limit": 10,
            "types": ["LORA"],
            "nsfw": True,
            "baseModels": base_models,
        }
        logger.debug("Searching LoRAs with params: %s", params)
        models = self.get_models(**params)
        items = models.get("items") or []
        logger.debug("Found %d models", len(items))

        response = []
        for model in items:
            model_versions = model.get("modelVersions") or []
            if not model_versions:
                continue
            # Select the model version with the largest ID
            target_version = max(model_versions, key=lambda version: version["id"])
            version_id = target_version["id"]
            model_name = model.get("name")
            logger.debug("Processing model '%s' (ID: %s)", model_name, model["id"])
            logger.debug("Selected version ID: %s", version_id)

            details = self.get_model_version(version_id)
            # Extract prompts from images; "meta" may be missing or null.
            images = [
                {
                    "url": image.get("url", ""),
                    "meta": {"prompt": (image.get("meta") or {}).get("prompt", "")},
                }
                for image in details.get("images") or []
            ]
            logger.debug("Found %d images for version %s", len(images), version_id)

            response.append({
                "name": model_name,
                "model_id": model["id"],
                "model_version_id": version_id,
                "download_url": details.get("downloadUrl"),
                "trainedWords": details.get("trainedWords"),
                "description": details.get("description"),
                "images": images,
            })

        result = {"models": response}
        # Serialising the whole payload is only worth doing when it is logged.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Returning response: %s", json.dumps(result, indent=2))
        return result
# Initialize client
# Single shared client instance used by the tool functions in this module. No api_key
# is passed, so the constructor falls back to the CIVITAI_API_KEY environment variable;
# because this runs at module import, the variable is read at import time.
client = CivitaiClient()
@mcp.tool()
def search_models(
    query: Optional[str] = None,
    limit: Optional[int] = Field(default=20, ge=1, le=100),
    page: Optional[int] = Field(default=None, ge=1),
    types: Optional[List[str]] = Field(
        default=None,
        description=(
            "Model types to include. Allowed values: Checkpoint, TextualInversion, "
            "Hypernetwork, AestheticGradient, LORA, Controlnet, Poses"
        ),
    ),
    sort: Optional[str] = Field(default=None, pattern="^(Highest Rated|Most Downloaded|Newest)$"),
    period: Optional[str] = Field(default=None, pattern="^(AllTime|Year|Month|Week|Day)$"),
    nsfw: Optional[bool] = None,
    base_models: Optional[List[str]] = None
) -> Dict:
    """Search for AI models on Civitai with various filters"""
    # Parameters:
    #   query       - free-text search string
    #   limit/page  - pagination controls (limit is constrained to 1..100)
    #   types       - list of model types; each element is validated below
    #   sort/period - ordering and time window, constrained by their patterns
    #   nsfw        - include or exclude NSFW results
    #   base_models - forwarded to the API as "baseModels"
    # Returns the result of client.get_models unchanged.
    # Raises ValueError when `types` contains a value outside the allowed set.
    #
    # `pattern` is a string constraint and cannot validate a list, so the
    # allowed values are checked per element here instead of in Field().
    allowed_types = (
        "Checkpoint",
        "TextualInversion",
        "Hypernetwork",
        "AestheticGradient",
        "LORA",
        "Controlnet",
        "Poses",
    )
    if types is not None:
        invalid = [item for item in types if item not in allowed_types]
        if invalid:
            raise ValueError(
                f"Invalid model type(s): {', '.join(map(str, invalid))}. "
                f"Allowed values: {', '.join(allowed_types)}"
            )
    params = {
        "query": query,
        "limit": limit,
        "page": page,
        "types": types,
        "sort": sort,
        "period": period,
        "nsfw": nsfw,
        "baseModels": base_models,
    }
    return client.get_models(**params)
@mcp.tool()
def get_model(model_id: int) -> Dict:
    """Get detailed information about a specific model by ID"""
    # Thin tool wrapper: forwards the ID to the shared client and returns
    # whatever the client returns, unmodified.
    model_details = client.get_model(model_id)
    return model_details
@mcp.tool()
def get_model_version(model_version_id: int) -> Dict:
    """Get detailed information about a specific model version"""
    # Thin tool wrapper around the shared client's get_model_version.
    version_details = client.get_model_version(model_version_id)
    return version_details
@mcp.tool()
def get_model_version_by_hash(hash: str) -> Dict:
    """Get model version information by file hash"""
    # The parameter name `hash` is part of the tool's public schema, so it is
    # kept even though it shadows the builtin inside this function.
    version_details = client.get_model_version_by_hash(hash)
    return version_details
@mcp.tool()
def get_images(
    limit: Optional[int] = Field(default=100, ge=1, le=200),
    page: Optional[int] = Field(default=None, ge=1),
    model_id: Optional[int] = None,
    model_version_id: Optional[int] = None,
    post_id: Optional[int] = None,
    username: Optional[str] = None,
    nsfw: Optional[str] = Field(default=None, pattern="^(None|Soft|Mature|X)$"),
    sort: Optional[str] = Field(default=None, pattern="^(Most Reactions|Most Comments|Newest)$"),
    period: Optional[str] = Field(default=None, pattern="^(AllTime|Year|Month|Week|Day)$")
) -> Dict:
    """Get AI-generated images from Civitai"""
    # Translate the snake_case tool arguments into the camelCase names that
    # are sent as query parameters, then delegate to the shared client.
    return client.get_images(
        limit=limit,
        page=page,
        modelId=model_id,
        modelVersionId=model_version_id,
        postId=post_id,
        username=username,
        nsfw=nsfw,
        sort=sort,
        period=period,
    )
@mcp.tool()
def get_creators(
    limit: Optional[int] = Field(default=20, ge=0, le=200),
    page: Optional[int] = Field(default=None, ge=1),
    query: Optional[str] = None
) -> Dict:
    """Browse and search for model creators on Civitai"""
    # NOTE(review): unlike the other tools, `limit` accepts 0 here - presumably
    # the API gives 0 a special meaning; confirm against the Civitai docs.
    filters = dict(limit=limit, page=page, query=query)
    return client.get_creators(**filters)
@mcp.tool()
def get_tags(
    limit: Optional[int] = Field(default=20, ge=1, le=200),
    page: Optional[int] = Field(default=None, ge=1),
    query: Optional[str] = None
) -> Dict:
    """Browse and search for model tags on Civitai"""
    # Forward the pagination and search arguments to the shared client as-is.
    filters = dict(limit=limit, page=page, query=query)
    return client.get_tags(**filters)
@mcp.tool()
def get_popular_models(
    period: Optional[str] = Field(default="Week", pattern="^(AllTime|Year|Month|Week|Day)$"),
    limit: Optional[int] = Field(default=20, ge=1, le=100)
) -> Dict:
    """Get the most popular/downloaded models"""
    # Preset search: fixed "Most Downloaded" ordering with NSFW results
    # excluded; only the time window and page size are caller-controlled.
    return client.get_models(
        sort="Most Downloaded",
        period=period,
        limit=limit,
        nsfw=False,
    )
@mcp.tool()
def get_latest_models(
    limit: Optional[int] = Field(default=20, ge=1, le=100)
) -> Dict:
    """Get the newest models uploaded to Civitai"""
    # Preset search: fixed "Newest" ordering with NSFW results excluded.
    return client.get_models(sort="Newest", limit=limit, nsfw=False)
@mcp.tool()
def get_top_rated_models(
    period: Optional[str] = Field(default="AllTime", pattern="^(AllTime|Year|Month|Week|Day)$"),
    limit: Optional[int] = Field(default=20, ge=1, le=100)
) -> Dict:
    """Get the highest rated models"""
    # Preset search: fixed "Highest Rated" ordering with NSFW results
    # excluded; only the time window and page size are caller-controlled.
    return client.get_models(
        sort="Highest Rated",
        period=period,
        limit=limit,
        nsfw=False,
    )
@mcp.tool()
def search_models_by_tag(
    tag: str,
    limit: Optional[int] = Field(default=20, ge=1, le=100),
    sort: Optional[str] = Field(default=None, pattern="^(Highest Rated|Most Downloaded|Newest)$")
) -> Dict:
    """Search for models by a specific tag"""
    # Model search restricted to a single tag; the other filters stay unset.
    return client.get_models(tag=tag, limit=limit, sort=sort)
@mcp.tool()
def search_models_by_creator(
    username: str,
    limit: Optional[int] = Field(default=20, ge=1, le=100),
    sort: Optional[str] = Field(default=None, pattern="^(Highest Rated|Most Downloaded|Newest)$")
) -> Dict:
    """Search for models by a specific creator"""
    # Model search restricted to one creator's username.
    return client.get_models(username=username, limit=limit, sort=sort)
@mcp.tool()
def get_models_by_type(
    type: str = Field(pattern="^(Checkpoint|TextualInversion|Hypernetwork|AestheticGradient|LORA|Controlnet|Poses)$"),
    limit: Optional[int] = Field(default=20, ge=1, le=100),
    sort: Optional[str] = Field(default=None, pattern="^(Highest Rated|Most Downloaded|Newest)$")
) -> Dict:
    """Get models filtered by type (Checkpoint, LORA, etc.)"""
    # The client takes a list under "types", so the single requested type is
    # wrapped in a one-element list. The parameter name `type` is part of the
    # tool's public schema and is therefore kept.
    requested_types = [type]
    return client.get_models(types=requested_types, limit=limit, sort=sort)
@mcp.tool()
def get_download_url(model_version_id: int) -> str:
    """Get the download URL for a specific model version"""
    # CivitaiClient.get_download_url returns a plain URL string, not a dict,
    # so the return annotation is `str` to match the value actually produced.
    # The URL is only constructed by the client; no request is made.
    return client.get_download_url(model_version_id)
@mcp.tool()
def search_loras_latest_version(
    query: Optional[str] = None,
    base_models: Optional[List[str]] = None
) -> Dict:
    """Search for LoRAs and the prompts for generation"""
    # Delegates to the client method of the same name, passing both filters
    # positionally in (query, base_models) order.
    lora_results = client.search_loras_latest_version(query, base_models)
    return lora_results
if __name__ == "__main__":
    import argparse

    # Command-line interface: a single optional --debug switch.
    arg_parser = argparse.ArgumentParser(description='Civitai MCP Server')
    arg_parser.add_argument('--debug', action='store_true', help='Enable debug logging')
    cli_args = arg_parser.parse_args()

    # Root logger stays at INFO so third-party libraries (httpx, etc.) do not
    # flood the output with their own debug messages.
    logging.basicConfig(level=logging.INFO)

    if cli_args.debug:
        # Raise verbosity for this module's logger only.
        logger.setLevel(logging.DEBUG)
        logger.debug("Debug logging enabled")

    # Start serving over the streamable HTTP transport.
    mcp.run(transport="streamable-http")