#!/usr/bin/env python3
"""
Test script for Civitai MCP Server
"""
import sys
import os
import pytest
from unittest.mock import MagicMock, patch
from urllib.parse import urlparse, parse_qs
# Add the current directory to Python path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Import the client directly to test functionality
from civitai_mcp_server import CivitaiClient, search_loras_latest_version
class TestCivitaiClient:
    """Pytest test cases for CivitaiClient.

    Every test runs against a ``CivitaiClient`` whose underlying HTTP client
    has been replaced with a ``MagicMock`` (see ``setup_method``), so no
    network requests are made.
    """

    def setup_method(self):
        """Create a fresh client with a mocked HTTP layer before each test."""
        self.client = CivitaiClient()
        # Replace the underlying HTTP client so calls never hit the network.
        self.client.client = MagicMock()

    def test_url_encoding_japanese_query(self):
        """Test URL encoding with Japanese characters - reproduces the original error."""
        # This is the exact query that caused the 400 Bad Request error.
        japanese_query = "ああああ(いいいい) LORA"
        params = {
            "query": japanese_query,
            "limit": 20,
            "types": ["Lora"],
        }

        # Build URL using the client's method.
        url = self.client._build_url("/models", params)

        # Parse the URL to check that the query round-trips losslessly.
        parsed_url = urlparse(url)
        query_params = parse_qs(parsed_url.query)
        assert "query" in query_params
        decoded_query = query_params["query"][0]
        assert decoded_query == japanese_query

        # Verify the URL doesn't contain raw Japanese characters ...
        assert "ああああ" not in url
        # ... and that it carries their UTF-8 percent-encoding instead.
        # "あ" (U+3042) is the byte sequence E3 81 82 in UTF-8; compare
        # case-insensitively because hex digits may be emitted in either case.
        assert "%E3%81%82" * 4 in url.upper()

        print(f"Original query: {japanese_query}")
        print(f"Encoded URL: {url}")
        print(f"Decoded query param: {decoded_query}")

    def test_url_encoding_special_characters(self):
        """Test URL encoding with special characters like parentheses and spaces."""
        query_with_special_chars = "test (model) & more"
        params = {
            "query": query_with_special_chars,
            "limit": 10,
        }

        url = self.client._build_url("/models", params)

        # Parse the URL.
        parsed_url = urlparse(url)
        query_params = parse_qs(parsed_url.query)

        # Verify the value survives an encode/decode round trip; this also
        # proves the "&" inside the value was escaped rather than treated as
        # a parameter separator.
        assert "query" in query_params
        decoded_query = query_params["query"][0]
        assert decoded_query == query_with_special_chars

        # Parentheses may legally appear raw in a query string, so only
        # require that they are either absent or percent-encoded.
        assert "(" not in url or "%28" in url
        assert ")" not in url or "%29" in url
        assert " " not in parsed_url.query  # Spaces should be encoded

        print(f"Original query: {query_with_special_chars}")
        print(f"Encoded URL: {url}")

    def test_url_encoding_list_parameters(self):
        """Test URL encoding with list parameters."""
        params = {
            "types": ["LORA", "Checkpoint"],
            "baseModels": ["SD 1.5", "SDXL 1.0"],
        }

        url = self.client._build_url("/models", params)

        # Parse the URL.
        parsed_url = urlparse(url)
        query_params = parse_qs(parsed_url.query)

        # Each list element must show up as its own repeated query parameter.
        assert "types" in query_params
        assert len(query_params["types"]) == 2
        assert "LORA" in query_params["types"]
        assert "Checkpoint" in query_params["types"]
        assert "baseModels" in query_params
        assert "SD 1.5" in query_params["baseModels"]
        assert "SDXL 1.0" in query_params["baseModels"]

        print(f"List params URL: {url}")

    def test_api_key_encoding(self):
        """Test that the API key is included in the URL as the ``token`` parameter."""
        # Create a separate client configured with a test API key.
        test_client = CivitaiClient(api_key="test_key_123")
        params = {"query": "test"}

        url = test_client._build_url("/models", params)

        # Parse the URL.
        parsed_url = urlparse(url)
        query_params = parse_qs(parsed_url.query)

        # Verify API key is included.
        assert "token" in query_params
        assert query_params["token"][0] == "test_key_123"

        print(f"URL with API key: {url}")

    def test_japanese_query_integration(self):
        """Mocked test: simulated API call with a Japanese query.

        Assertion failures and unexpected exceptions are left to propagate so
        pytest can report them with full tracebacks and introspection.
        """
        japanese_query = "アニメ"  # "anime" in Japanese
        params = {
            "query": japanese_query,
            "limit": 1,
            "nsfw": False,
        }

        # Mock response.
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "items": [{"id": 1, "name": "Anime Model"}],
            "metadata": {"totalItems": 1},
        }
        mock_response.raise_for_status.return_value = None
        self.client.client.get.return_value = mock_response

        # This should not raise a 400 Bad Request error with proper encoding.
        response = self.client.get_models(**params)
        assert "items" in response
        print(f"Japanese query integration test passed: {len(response.get('items', []))} results")

        # Verify the mock was called with an encoded URL: the raw Japanese
        # text must not appear, and decoding the query string must give back
        # the original value.
        args, _ = self.client.client.get.call_args
        url = args[0]
        assert japanese_query not in url
        sent_params = parse_qs(urlparse(url).query)
        assert sent_params["query"] == [japanese_query]

    def test_search_models(self):
        """Test get_models with a text query (search use case)."""
        print("1. Testing search_models...")
        params = {
            "query": "anime",
            "limit": 3,
            "sort": "Most Downloaded",
        }

        # Mock response.
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "items": [{"id": 1, "name": "Model 1"}, {"id": 2, "name": "Model 2"}, {"id": 3, "name": "Model 3"}],
            "metadata": {"totalItems": 100},
        }
        self.client.client.get.return_value = mock_response

        response = self.client.get_models(**params)

        assert len(response["items"]) > 0
        metadata = response.get('metadata', {})
        print(f"Metadata keys: {metadata.keys()}")
        print(f"Found {metadata.get('totalItems', 'Unknown')} models")
        print("✓ Search models successful")

    def test_get_popular_models(self):
        """Test get_models sorted by downloads over a period (popular use case)."""
        print("2. Testing get_popular_models...")
        params = {
            "sort": "Most Downloaded",
            "period": "Week",
            "limit": 3,
            "nsfw": False,
        }

        # Mock response.
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "items": [{"id": 1, "name": "Pop Model 1"}, {"id": 2, "name": "Pop Model 2"}],
            "metadata": {"totalItems": 50},
        }
        self.client.client.get.return_value = mock_response

        response = self.client.get_models(**params)

        assert len(response["items"]) > 0
        print("✓ Get popular models successful")

    def test_get_latest_models(self):
        """Test get_models sorted by newest (latest use case)."""
        print("3. Testing get_latest_models...")
        params = {
            "sort": "Newest",
            "limit": 3,
            "nsfw": False,
        }

        # Mock response.
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "items": [{"id": 10, "name": "New Model 1"}],
            "metadata": {"totalItems": 10},
        }
        self.client.client.get.return_value = mock_response

        response = self.client.get_models(**params)

        assert len(response["items"]) > 0
        print("✓ Get latest models successful")

    def test_browse_images(self):
        """Test get_images (browse images use case)."""
        print("4. Testing browse_images...")
        params = {
            "limit": 3,
            "nsfw": "None",
            "sort": "Newest",
        }

        # Mock response.
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "items": [{"id": 100, "url": "http://img.com/1.jpg"}],
            "metadata": {"totalItems": 100},
        }
        self.client.client.get.return_value = mock_response

        response = self.client.get_images(**params)

        items = response.get("items", [])
        assert len(items) > 0
        print("✓ Browse images successful")

    def test_get_creators(self):
        """Test get_creators."""
        print("5. Testing get_creators...")
        params = {"limit": 3}

        # Mock response.
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "items": [{"username": "user1"}, {"username": "user2"}],
            "metadata": {"totalItems": 20},
        }
        self.client.client.get.return_value = mock_response

        response = self.client.get_creators(**params)

        items = response.get("items", [])
        assert len(items) > 0
        print("✓ Get creators successful")

    def test_get_tags(self):
        """Test get_tags."""
        print("6. Testing get_tags...")
        params = {"limit": 5}

        # Mock response.
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "items": [{"name": "tag1"}, {"name": "tag2"}],
            "metadata": {"totalItems": 50},
        }
        self.client.client.get.return_value = mock_response

        response = self.client.get_tags(**params)

        items = response.get("items", [])
        assert len(items) > 0
        print("✓ Get tags successful")

    def test_get_models_by_type(self):
        """Test get_models restricted to a model type."""
        print("7. Testing get_models_by_type...")
        params = {
            "types": ["LORA"],
            "limit": 3,
            "sort": "Most Downloaded",
        }

        # Mock response.
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "items": [{"id": 3, "name": "Lora Model", "type": "LORA"}],
        }
        self.client.client.get.return_value = mock_response

        response = self.client.get_models(**params)

        assert len(response["items"]) > 0
        # NOTE: the mocked payload only contains a LORA item, so this test
        # checks that results come back, not that type filtering is applied.
        print("✓ Get models by type successful")

    def test_search_loras_latest_version(self):
        """Test search_loras_latest_version picks the highest version id."""
        print("8. Testing search_loras_latest_version...")
        params = {
            "query": "anime",
            "base_models": ["SD 1.5"],
        }

        # Prepare mock responses (sequential calls):
        # 1. get_models
        # 2. get_model_version (for each item found, here just 1)
        mock_models_response = {
            "items": [
                {
                    "id": 101,
                    "name": "Lora 1",
                    "modelVersions": [{"id": 1001}, {"id": 1002}],  # 1002 is max
                }
            ]
        }
        mock_version_response = {
            "id": 1002,
            "downloadUrl": "http://civitai.com/download/1002",
            "trainedWords": ["anime style"],
            "description": "A great lora",
            "images": [
                {"meta": {"prompt": "masterpiece, anime girl"}},
                {"meta": None},  # prompt missing case
            ],
        }

        # The same mocked response object is returned for every GET; its
        # successive .json() calls yield the two payloads above in order.
        mock_response = MagicMock()
        mock_response.json.side_effect = [mock_models_response, mock_version_response]
        self.client.client.get.return_value = mock_response

        result = self.client.search_loras_latest_version(**params)

        assert result is not None
        assert len(result["models"]) > 0
        model = result["models"][0]
        assert model["model_id"] == 101
        assert model["model_version_id"] == 1002
        assert model["images"][0]["meta"]["prompt"] == "masterpiece, anime girl"
        print("✓ Search loras successful")

    def test_search_loras_latest_version_debug_mode(self, caplog):
        """Test search_loras_latest_version emits the expected debug log lines."""
        print("9. Testing search_loras_latest_version with debug mode...")
        import logging

        # Capture everything down to DEBUG so the messages below are recorded.
        caplog.set_level(logging.DEBUG)
        params = {
            "query": "test",
            "base_models": ["SDXL"],
        }

        # Mock responses.
        mock_models_response = {
            "items": [
                {
                    "id": 201,
                    "name": "Debug Test Model",
                    "modelVersions": [{"id": 2001}],
                }
            ]
        }
        mock_version_response = {
            "id": 2001,
            "downloadUrl": "http://civitai.com/download/2001",
            "trainedWords": ["debug"],
            "description": "Debug test",
            "images": [{"meta": {"prompt": "debug prompt"}}],
        }

        mock_response = MagicMock()
        mock_response.json.side_effect = [mock_models_response, mock_version_response]
        self.client.client.get.return_value = mock_response

        self.client.search_loras_latest_version(**params)

        # Check logs.
        assert "Searching LoRAs with params" in caplog.text
        assert "Found 1 models" in caplog.text
        assert "Processing model 'Debug Test Model'" in caplog.text
        assert "Selected version ID: 2001" in caplog.text
        print("✓ Search loras debug mode successful")
if __name__ == "__main__":
    # Run only this file's TestCivitaiClient cases when executed as a script.
    print("\n=== Running pytest cases ===")
    # pytest.main() returns the run's exit code; pass it to sys.exit() so a
    # failing run yields a non-zero process status for shells and CI.
    sys.exit(pytest.main([__file__ + "::TestCivitaiClient", "-v"]))