#!/usr/bin/env python3
"""Test the list_models and get_model_details tools."""
from __future__ import annotations
import json
import sys
from pathlib import Path
# Put the parent of this script's directory at the front of sys.path so the
# ``Imagen_MCP`` package (assumed to live under that directory) is importable
# when this script is run directly. Must execute before the import below.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from Imagen_MCP.services.model_registry import get_model_registry # noqa: E402
def test_list_models():
    """Test listing all models.

    Builds a ``list_models``-style payload from the model registry (one
    summary entry per registered model, plus total count, default model id
    and a usage hint) and prints it as JSON for manual inspection.

    Returns:
        True if the registry reported at least one model, False otherwise.
    """
    print("=" * 60)
    print("TEST: list_models tool")
    print("=" * 60)

    registry = get_model_registry()
    models = [
        {
            "id": model.id,
            "name": model.name,
            "provider": model.provider,
            "description": model.description,
            "use_cases": model.use_cases,
            "strengths": model.strengths,
            "weaknesses": model.weaknesses,
            "supported_sizes": model.capabilities.supported_sizes,
            "max_images_per_request": model.capabilities.max_images_per_request,
            "supports_hd_quality": model.capabilities.supports_hd_quality,
            "rate_limit": (
                f"{model.rate_limit.messages_per_period} messages per "
                f"{model.rate_limit.period_hours} hours"
            ),
        }
        for model in registry.get_all_models()
    ]

    result = {
        "models": models,
        "total_count": len(models),
        # NOTE(review): hard-coded here; presumably mirrors the server's
        # default model id — confirm it stays in sync with the registry.
        "default_model": "imagen-4",
        "usage_hint": "Use the 'model' parameter in generate_image or start_image_batch to specify which model to use.",
    }
    print(json.dumps(result, indent=2))
    print()
    print(f"Total models: {result['total_count']}")

    # An empty registry means there is nothing to list; report it as a
    # failure instead of unconditionally "passing".
    if not models:
        print("ERROR: Model registry returned no models")
        return False
    return True
def test_get_model_details():
    """Test getting details for a specific model.

    Looks up a known model id and prints its full details as JSON, then
    checks that looking up an unknown id yields ``None``.

    Returns:
        True if the known model is found AND the unknown id resolves to
        ``None``; False if either check fails.
    """
    print("\n" + "=" * 60)
    print("TEST: get_model_details tool")
    print("=" * 60)

    registry = get_model_registry()

    # Test with valid model
    model_id = "imagen-4-fast"
    model = registry.get_model(model_id)
    if model is None:
        print(f"ERROR: Model not found: {model_id}")
        return False

    result = {
        "id": model.id,
        "name": model.name,
        "api_id": model.api_id,
        "provider": model.provider,
        "description": model.description,
        "use_cases": model.use_cases,
        "strengths": model.strengths,
        "weaknesses": model.weaknesses,
        "capabilities": {
            "supported_sizes": model.capabilities.supported_sizes,
            "max_images_per_request": model.capabilities.max_images_per_request,
            "supports_hd_quality": model.capabilities.supports_hd_quality,
            "supports_style_parameter": model.capabilities.supports_style_parameter,
            "max_prompt_length": model.capabilities.max_prompt_length,
        },
        "rate_limit": {
            "messages_per_period": model.rate_limit.messages_per_period,
            "period_hours": model.rate_limit.period_hours,
            "category": model.rate_limit.category,
        },
    }
    print(f"Details for '{model_id}':")
    print(json.dumps(result, indent=2))

    # Test with invalid model: the lookup must NOT resolve to anything.
    # Previously a non-None result fell through and still returned True.
    print("\n" + "-" * 40)
    invalid_id = "invalid-model-xyz"
    invalid_model = registry.get_model(invalid_id)
    if invalid_model is not None:
        print(f"ERROR: Expected None for invalid model: {invalid_id}")
        return False
    print(f"✓ Correctly returns None for invalid model: {invalid_id}")
    return True
if __name__ == "__main__":
    # Run every test (even after a failure) so all output is shown, then
    # report the combined outcome through the process exit status.
    results = [test_list_models(), test_get_model_details()]
    if not all(results):
        print("\n✗ Some tests failed!")
        sys.exit(1)
    print("\n✓ All tests passed!")