"""Model registry for image generation models."""
import threading
from pydantic import BaseModel, Field
class RateLimitInfo(BaseModel):
    """Rate limit information for a model.

    Describes a quota of ``messages_per_period`` messages allowed every
    ``period_hours`` hours, tagged with a provider rate-limit ``category``.
    Both numeric fields must be strictly positive; a zero or negative
    quota/period is rejected at validation time rather than handed to
    callers that would divide by it or treat it as "unlimited".
    """

    messages_per_period: int = Field(
        gt=0, description="Number of messages allowed per period"
    )
    period_hours: int = Field(gt=0, description="Period duration in hours")
    category: str = Field(description="Rate limit category (e.g., 'Category 3')")
class ModelCapabilities(BaseModel):
    """Capabilities of an image generation model.

    All fields have defaults, so an empty ``ModelCapabilities()`` describes a
    model that returns one image per request, lists no supported sizes,
    supports neither HD quality nor the style parameter, and accepts prompts
    up to 4000 characters. Count/length fields must be at least 1; a value
    of 0 would describe a model that can produce nothing or accept no prompt.
    """

    max_images_per_request: int = Field(
        default=1, ge=1, description="Maximum images per request"
    )
    # Sizes are "WIDTHxHEIGHT" strings such as "1024x1024" in the bundled
    # definitions; the format itself is not validated here.
    supported_sizes: list[str] = Field(
        default_factory=list, description="Supported image sizes"
    )
    supports_hd_quality: bool = Field(
        default=False, description="Whether HD quality is supported"
    )
    supports_style_parameter: bool = Field(
        default=False, description="Whether style parameter is supported"
    )
    max_prompt_length: int = Field(
        default=4000, ge=1, description="Maximum prompt length in characters"
    )
class ModelInfo(BaseModel):
    """Full description of a single image generation model.

    Combines identity (internal ID, Nexos.ai API ID, display name, provider),
    a free-text description, the nested capability and rate-limit records,
    and three optional descriptive lists that default to empty.
    """

    # --- Identity -----------------------------------------------------------
    id: str = Field(description="Unique model identifier (internal)")
    api_id: str = Field(description="API model identifier for Nexos.ai requests")
    name: str = Field(description="Human-readable model name")
    provider: str = Field(
        description="Model provider (e.g., 'Google', 'Black Forest Labs')"
    )
    description: str = Field(description="Model description")

    # --- Nested records -----------------------------------------------------
    capabilities: ModelCapabilities = Field(description="Model capabilities")
    rate_limit: RateLimitInfo = Field(description="Rate limit information")

    # --- Optional descriptive lists (fresh empty list per instance) ---------
    use_cases: list[str] = Field(
        description="Recommended use cases", default_factory=list
    )
    strengths: list[str] = Field(description="Model strengths", default_factory=list)
    weaknesses: list[str] = Field(description="Model weaknesses", default_factory=list)
# Raw model definitions, based on the Nexos.ai documentation.
# "api_id" values are the exact model identifiers sent in Nexos.ai requests.
# Every model below is in rate-limit Category 3 (100 messages per 3 hours).
# Each entry receives its own copy of the shared record below, so no two
# entries alias the same rate-limit dict.
_CATEGORY_3_RATE_LIMIT: dict = {
    "messages_per_period": 100,
    "period_hours": 3,
    "category": "Category 3",
}

_MODEL_DEFINITIONS: list[dict] = [
    # ---------------------------------------------------------------- Google
    {
        "id": "imagen-4",
        "api_id": "Imagen 4 (Public)",
        "name": "Imagen 4",
        "provider": "Google",
        "description": "Google's flagship image generation model with excellent prompt following and photorealistic output.",
        "capabilities": {
            "max_images_per_request": 4,
            "supported_sizes": ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
            "supports_hd_quality": True,
            "supports_style_parameter": True,
            "max_prompt_length": 4000,
        },
        "rate_limit": dict(_CATEGORY_3_RATE_LIMIT),
        "use_cases": [
            "Photorealistic image generation",
            "Marketing and advertising visuals",
            "Product visualization",
            "Concept art and illustrations",
        ],
        "strengths": [
            "Excellent prompt adherence",
            "High-quality photorealistic output",
            "Good text rendering in images",
            "Consistent style across generations",
        ],
        "weaknesses": [
            "Slower generation time compared to fast models",
            "May struggle with very complex multi-object scenes",
        ],
    },
    {
        "id": "imagen-4-fast",
        "api_id": "Imagen 4 Fast (Public)",
        "name": "Imagen 4 Fast",
        "provider": "Google",
        "description": "Faster variant of Imagen 4, optimized for speed while maintaining good quality.",
        "capabilities": {
            "max_images_per_request": 4,
            "supported_sizes": ["256x256", "512x512", "1024x1024"],
            "supports_hd_quality": False,
            "supports_style_parameter": True,
            "max_prompt_length": 4000,
        },
        "rate_limit": dict(_CATEGORY_3_RATE_LIMIT),
        "use_cases": [
            "Rapid prototyping and iteration",
            "Batch image generation",
            "Real-time applications",
            "Draft and preview generation",
        ],
        "strengths": [
            "Fast generation speed",
            "Good quality-to-speed ratio",
            "Lower latency for interactive use",
        ],
        "weaknesses": [
            "Lower quality than standard Imagen 4",
            "Limited size options",
            "No HD quality support",
        ],
    },
    {
        "id": "imagen-4-ultra",
        "api_id": "Imagen 4 Ultra (Public)",
        "name": "Imagen 4 Ultra",
        "provider": "Google",
        "description": "Highest quality Imagen model for premium image generation with maximum detail.",
        "capabilities": {
            "max_images_per_request": 2,
            "supported_sizes": ["1024x1024", "1792x1024", "1024x1792", "2048x2048"],
            "supports_hd_quality": True,
            "supports_style_parameter": True,
            "max_prompt_length": 4000,
        },
        "rate_limit": dict(_CATEGORY_3_RATE_LIMIT),
        "use_cases": [
            "High-resolution print materials",
            "Professional photography replacement",
            "Detailed artwork and illustrations",
            "Premium marketing content",
        ],
        "strengths": [
            "Highest quality output",
            "Excellent detail preservation",
            "Best text rendering",
            "Superior photorealism",
        ],
        "weaknesses": [
            "Slowest generation time",
            "Fewer images per request",
            "Higher resource consumption",
        ],
    },
    # ---------------------------------------------------------------- OpenAI
    {
        "id": "dall-e-3",
        "api_id": "DALL-E 3 (Public)",
        "name": "DALL-E 3",
        "provider": "OpenAI",
        "description": "OpenAI's latest DALL-E model with excellent prompt understanding and artistic capabilities.",
        "capabilities": {
            "max_images_per_request": 4,
            "supported_sizes": ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
            "supports_hd_quality": True,
            "supports_style_parameter": True,
            "max_prompt_length": 4000,
        },
        "rate_limit": dict(_CATEGORY_3_RATE_LIMIT),
        "use_cases": [
            "Artistic and creative imagery",
            "Stylized illustrations",
            "Fantasy and sci-fi content",
            "Abstract and conceptual art",
        ],
        "strengths": [
            "Excellent artistic style",
            "Creative interpretation of prompts",
            "Good at stylized content",
            "Strong composition",
        ],
        "weaknesses": [
            "May be overly cautious with some prompts",
            "Less photorealistic than Imagen",
        ],
    },
    {
        "id": "gpt-image-1",
        "api_id": "GPT-image-1 (Public)",
        "name": "GPT Image 1",
        "provider": "OpenAI",
        "description": "OpenAI's image generation model with strong prompt understanding.",
        "capabilities": {
            "max_images_per_request": 4,
            "supported_sizes": ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
            "supports_hd_quality": True,
            "supports_style_parameter": True,
            "max_prompt_length": 4000,
        },
        "rate_limit": dict(_CATEGORY_3_RATE_LIMIT),
        "use_cases": [
            "General-purpose image generation",
            "Content creation",
            "Educational materials",
            "Social media content",
        ],
        "strengths": [
            "Strong prompt understanding",
            "Versatile output styles",
            "Good at following complex instructions",
            "Consistent quality",
        ],
        "weaknesses": [
            "May be overly cautious with some prompts",
            "Less specialized than domain-specific models",
        ],
    },
]
class ModelRegistry:
    """Registry for image generation models and their capabilities.

    On construction, every entry of ``_MODEL_DEFINITIONS`` is converted into a
    ``ModelInfo`` and stored under its internal ``id``. Lookup methods return
    the registry's own ``ModelInfo`` instances rather than copies, so mutating
    a returned object changes what later lookups see.
    """

    # Internal ID returned by get_default_model(). It is checked against the
    # loaded models at construction so the catalog never advertises a default
    # that does not exist.
    DEFAULT_MODEL_ID = "imagen-4"

    def __init__(self) -> None:
        """Initialize the model registry with pre-defined models.

        Raises:
            ValueError: If the definitions contain a duplicate internal ID or
                API ID, or do not include ``DEFAULT_MODEL_ID``.
        """
        self._models: dict[str, ModelInfo] = {}
        self._load_model_definitions()

    def _load_model_definitions(self) -> None:
        """Build a ``ModelInfo`` for each pre-defined entry and register it.

        Raises:
            ValueError: On a duplicate internal ID (which would otherwise
                silently replace the earlier entry), on a duplicate API ID
                (which would make ``get_model_by_api_id`` ambiguous), or when
                ``DEFAULT_MODEL_ID`` is not among the definitions.
        """
        seen_api_ids: set[str] = set()
        for model_data in _MODEL_DEFINITIONS:
            model_info = ModelInfo(
                id=model_data["id"],
                api_id=model_data["api_id"],
                name=model_data["name"],
                provider=model_data["provider"],
                description=model_data["description"],
                capabilities=ModelCapabilities(**model_data["capabilities"]),
                rate_limit=RateLimitInfo(**model_data["rate_limit"]),
                use_cases=model_data["use_cases"],
                strengths=model_data["strengths"],
                weaknesses=model_data["weaknesses"],
            )
            if model_info.id in self._models:
                raise ValueError(
                    f"Duplicate model id in definitions: {model_info.id!r}"
                )
            if model_info.api_id in seen_api_ids:
                raise ValueError(
                    f"Duplicate API id in definitions: {model_info.api_id!r}"
                )
            seen_api_ids.add(model_info.api_id)
            self._models[model_info.id] = model_info

        if self.DEFAULT_MODEL_ID not in self._models:
            raise ValueError(
                f"Default model {self.DEFAULT_MODEL_ID!r} is not defined"
            )

    def get_model(self, model_id: str) -> ModelInfo | None:
        """Get model information by ID.

        Args:
            model_id: The model identifier (internal ID).

        Returns:
            The registry's ``ModelInfo`` instance if found, None otherwise.
        """
        return self._models.get(model_id)

    def get_api_id(self, model_id: str) -> str | None:
        """Get the API model ID for a given internal model ID.

        Args:
            model_id: The internal model identifier.

        Returns:
            The API model ID if found, None otherwise.
        """
        model = self._models.get(model_id)
        return model.api_id if model else None

    def get_model_by_api_id(self, api_id: str) -> ModelInfo | None:
        """Get model information by API ID.

        API IDs are unique because duplicates are rejected at load time.

        Args:
            api_id: The API model identifier (exact, case-sensitive match).

        Returns:
            ModelInfo if found, None otherwise.
        """
        return next(
            (model for model in self._models.values() if model.api_id == api_id),
            None,
        )

    def get_all_models(self) -> list[ModelInfo]:
        """Get all available models.

        Returns:
            A new list of all registered models, in definition order.
        """
        return list(self._models.values())

    def filter_by_capability(
        self,
        supports_hd: bool | None = None,
        supports_style: bool | None = None,
        min_images_per_request: int | None = None,
        supported_size: str | None = None,
    ) -> list[ModelInfo]:
        """Filter models by capabilities.

        Each criterion left as None is ignored. A model must satisfy all
        given criteria to be included.

        Args:
            supports_hd: Required value of ``supports_hd_quality``.
            supports_style: Required value of ``supports_style_parameter``.
            min_images_per_request: Keep models whose
                ``max_images_per_request`` is at least this value.
            supported_size: Size string that must appear in
                ``supported_sizes`` (exact match, e.g. "1024x1024").

        Returns:
            List of models matching the criteria, in definition order.
        """
        result = []
        for model in self._models.values():
            caps = model.capabilities
            if supports_hd is not None and caps.supports_hd_quality != supports_hd:
                continue
            if (
                supports_style is not None
                and caps.supports_style_parameter != supports_style
            ):
                continue
            if (
                min_images_per_request is not None
                and caps.max_images_per_request < min_images_per_request
            ):
                continue
            if (
                supported_size is not None
                and supported_size not in caps.supported_sizes
            ):
                continue
            result.append(model)
        return result

    def is_valid_model(self, model_id: str) -> bool:
        """Check if a model ID is valid.

        Args:
            model_id: The internal model identifier to check.

        Returns:
            True if the model exists, False otherwise.
        """
        return model_id in self._models

    def get_default_model(self) -> str:
        """Get the default model ID.

        Returns:
            ``DEFAULT_MODEL_ID``, which is guaranteed at construction to be
            a registered model.
        """
        return self.DEFAULT_MODEL_ID

    def to_catalog_dict(self) -> dict:
        """Convert the registry to a catalog dictionary for the MCP resource.

        Returns:
            A dict with ``models`` (each model serialized via ``model_dump``),
            ``default_model`` (the default internal ID) and ``total_count``.
        """
        return {
            "models": [model.model_dump() for model in self._models.values()],
            "default_model": self.get_default_model(),
            "total_count": len(self._models),
        }
# Module-level registry, created lazily on the first get_model_registry() call.
# The lock serializes that one-time construction.
_registry: ModelRegistry | None = None
_registry_lock = threading.Lock()


def get_model_registry() -> ModelRegistry:
    """Return the shared ModelRegistry, constructing it on first use.

    Uses double-checked locking. Once the registry exists it is returned
    without taking the lock. Otherwise the lock is acquired and the global is
    checked again before construction. This way only one instance is built
    even if several threads race on the first call.

    Returns:
        The global ModelRegistry instance.
    """
    global _registry

    # Fast path: already initialized, no locking needed.
    if _registry is not None:
        return _registry

    with _registry_lock:
        # Another thread may have finished construction while we waited.
        if _registry is None:
            _registry = ModelRegistry()

    return _registry