"""Unit tests for the model registry."""
from Imagen_MCP.services.model_registry import (
ModelRegistry,
ModelInfo,
ModelCapabilities,
RateLimitInfo,
get_model_registry,
)
class TestModelRegistryInitialization:
    """Checks that a freshly constructed registry comes up populated."""

    def test_load_model_definitions(self):
        """A new ModelRegistry lists five models, including every expected ID."""
        loaded = ModelRegistry().get_all_models()
        assert len(loaded) == 5

        loaded_ids = [entry.id for entry in loaded]
        for expected_id in (
            "imagen-4",
            "imagen-4-fast",
            "imagen-4-ultra",
            "dall-e-3",
            "gpt-image-1",
        ):
            assert expected_id in loaded_ids
class TestGetModel:
    """Checks single-model lookup and ID validation on the registry."""

    def test_get_model_by_id(self):
        """Looking up a known ID yields the matching model with typed sub-objects."""
        found = ModelRegistry().get_model("imagen-4")

        assert found is not None
        assert found.id == "imagen-4"
        assert found.name == "Imagen 4"
        assert found.provider == "Google"

        # The nested structures must be the registry's own model classes.
        assert isinstance(found.capabilities, ModelCapabilities)
        assert isinstance(found.rate_limit, RateLimitInfo)

    def test_get_model_nonexistent(self):
        """Looking up an unknown ID yields None."""
        assert ModelRegistry().get_model("nonexistent-model") is None

    def test_is_valid_model(self):
        """is_valid_model returns True for known IDs and False for unknown ones."""
        reg = ModelRegistry()

        for known_id in ("imagen-4", "imagen-4-fast"):
            assert reg.is_valid_model(known_id) is True

        assert reg.is_valid_model("nonexistent") is False
class TestGetAllModels:
    """Checks the shape of the full model listing."""

    def test_get_all_models(self):
        """get_all_models returns a list of five ModelInfo entries."""
        everything = ModelRegistry().get_all_models()

        assert isinstance(everything, list)
        assert len(everything) == 5
        for entry in everything:
            assert isinstance(entry, ModelInfo)
class TestFilterByCapability:
    """Tests for ModelRegistry.filter_by_capability.

    Each test builds its own registry. Where a filter criterion maps onto a
    capability field, the test checks not only how many models come back but
    also that every returned model actually satisfies the criterion.
    """

    def test_filter_models_by_hd_support(self):
        """Filtering on supports_hd returns only models with the matching HD flag."""
        registry = ModelRegistry()

        hd_models = registry.filter_by_capability(supports_hd=True)
        assert len(hd_models) > 0
        assert all(m.capabilities.supports_hd_quality for m in hd_models)

        non_hd_models = registry.filter_by_capability(supports_hd=False)
        assert len(non_hd_models) > 0
        assert all(not m.capabilities.supports_hd_quality for m in non_hd_models)

    def test_filter_models_by_style_support(self):
        """Filtering on supports_style=True returns only style-capable models."""
        registry = ModelRegistry()

        style_models = registry.filter_by_capability(supports_style=True)
        assert len(style_models) > 0
        assert all(m.capabilities.supports_style_parameter for m in style_models)

    def test_filter_models_by_min_images(self):
        """Filtering on min_images_per_request honours the requested threshold."""
        registry = ModelRegistry()

        # All models should support at least 1 image.
        models = registry.filter_by_capability(min_images_per_request=1)
        assert len(models) == 5
        assert all(m.capabilities.max_images_per_request >= 1 for m in models)

        # A threshold of 4 must still match at least three models, and every
        # match has to meet it. A count-only check would also pass for a filter
        # that returned the wrong models.
        models = registry.filter_by_capability(min_images_per_request=4)
        assert len(models) >= 3
        assert all(m.capabilities.max_images_per_request >= 4 for m in models)

    def test_filter_models_by_size(self):
        """Filtering on supported_size returns the models offering that size."""
        registry = ModelRegistry()

        models = registry.filter_by_capability(supported_size="1024x1024")
        assert len(models) == 5  # All models support 1024x1024

        models = registry.filter_by_capability(supported_size="2048x2048")
        assert len(models) == 1  # Only ultra supports 2048x2048
        assert models[0].id == "imagen-4-ultra"

    def test_filter_models_multiple_criteria(self):
        """Combined criteria return only models that satisfy all of them."""
        registry = ModelRegistry()

        models = registry.filter_by_capability(
            supports_hd=True,
            supports_style=True,
            min_images_per_request=4,
        )

        assert len(models) >= 1
        for model in models:
            assert model.capabilities.supports_hd_quality
            assert model.capabilities.supports_style_parameter
            assert model.capabilities.max_images_per_request >= 4
class TestModelInfoValidation:
    """Checks the structure and serialization of a single ModelInfo."""

    def test_model_info_validation(self):
        """The "imagen-4" entry has all required fields with the expected types."""
        info = ModelRegistry().get_model("imagen-4")
        assert info is not None

        # Every required top-level field must be populated.
        for required in (
            "id",
            "name",
            "provider",
            "description",
            "capabilities",
            "rate_limit",
        ):
            assert getattr(info, required) is not None

        # Capability fields and their expected Python types.
        capability_types = (
            ("max_images_per_request", int),
            ("supported_sizes", list),
            ("supports_hd_quality", bool),
            ("supports_style_parameter", bool),
            ("max_prompt_length", int),
        )
        for attr, expected_type in capability_types:
            assert isinstance(getattr(info.capabilities, attr), expected_type)

        # Rate-limit fields and their expected Python types.
        rate_limit_types = (
            ("messages_per_period", int),
            ("period_hours", int),
            ("category", str),
        )
        for attr, expected_type in rate_limit_types:
            assert isinstance(getattr(info.rate_limit, attr), expected_type)

    def test_model_info_to_dict(self):
        """model_dump() yields a dict carrying the ID and nested sections."""
        info = ModelRegistry().get_model("imagen-4")
        assert info is not None

        dumped = info.model_dump()

        assert isinstance(dumped, dict)
        assert dumped["id"] == "imagen-4"
        for section in ("capabilities", "rate_limit"):
            assert section in dumped
class TestToCatalogDict:
    """Checks the catalog dictionary produced by the registry."""

    def test_to_catalog_dict(self):
        """to_catalog_dict returns a dict with models, default model and count."""
        result = ModelRegistry().to_catalog_dict()

        assert isinstance(result, dict)
        for key in ("models", "default_model", "total_count"):
            assert key in result

        assert len(result["models"]) == 5
        assert result["default_model"] == "imagen-4"
        assert result["total_count"] == 5
class TestGetModelRegistry:
    """Checks the module-level get_model_registry accessor."""

    def test_get_model_registry_singleton(self):
        """Two calls to get_model_registry hand back the very same object."""
        first = get_model_registry()
        second = get_model_registry()

        # Identity, not equality: the accessor is expected to share one instance.
        assert first is second

    def test_get_model_registry_returns_valid_registry(self):
        """get_model_registry returns a ModelRegistry holding five models."""
        shared = get_model_registry()

        assert isinstance(shared, ModelRegistry)
        assert len(shared.get_all_models()) == 5
class TestDefaultModel:
    """Checks the registry's default model selection."""

    def test_get_default_model(self):
        """The default model ID is "imagen-4" and is itself a valid model."""
        reg = ModelRegistry()
        default_id = reg.get_default_model()

        assert default_id == "imagen-4"
        assert reg.is_valid_model(default_id)