"""Utility functions for Midjourney MCP server."""
import base64
import re
from typing import List, Optional
from exceptions import ValidationError
def validate_base64_image(base64_str: str) -> bool:
    """Check whether a string holds non-empty, strictly valid base64 data.

    Both raw base64 and ``data:image/...`` data URLs are accepted; for a
    data URL only the part after the first comma is checked. Only the
    base64 encoding is verified, not that the decoded bytes form an image.

    Args:
        base64_str: Raw base64 payload or a ``data:image/...`` data URL.

    Returns:
        True if the payload is non-empty and decodes under strict base64
        validation, False otherwise (including for non-string input).
    """
    # Non-string input can never be a valid encoded image.
    if not isinstance(base64_str, str):
        return False

    payload = base64_str
    if payload.startswith('data:image/'):
        # Drop the "data:image/<type>;base64," header. Without a comma the
        # whole string is decoded below and fails on the ':' character.
        _, separator, remainder = payload.partition(',')
        if separator:
            payload = remainder

    # An empty payload decodes without error but carries no image data.
    if not payload:
        return False

    try:
        base64.b64decode(payload, validate=True)
    except ValueError:
        # binascii.Error (bad alphabet or padding) subclasses ValueError;
        # non-ASCII characters raise ValueError directly.
        return False
    return True
def format_base64_image(base64_str: str, image_type: str = "png") -> str:
    """Ensure a base64 image string carries a data URL prefix.

    Args:
        base64_str: Base64 encoded image string, with or without a
            ``data:image/...`` prefix.
        image_type: Image subtype placed in the prefix (png, jpg, jpeg, webp).

    Returns:
        The input unchanged when it already starts with ``data:image/``;
        otherwise the input prefixed with ``data:image/<image_type>;base64,``.
    """
    already_prefixed = base64_str.startswith('data:image/')
    # An existing prefix is kept as-is; image_type is ignored in that case.
    return base64_str if already_prefixed else f"data:image/{image_type};base64,{base64_str}"
def validate_aspect_ratio(aspect_ratio: str) -> bool:
    """Validate aspect ratio format.

    A valid value is two runs of ASCII digits separated by a single colon,
    with nothing before or after (e.g., "16:9", "1:1").

    Args:
        aspect_ratio: Aspect ratio string to check.

    Returns:
        True if the whole string has the ``<digits>:<digits>`` form,
        False otherwise (including for non-string input).
    """
    if not isinstance(aspect_ratio, str):
        return False
    # fullmatch rejects a trailing newline that '$' with re.match would let
    # through; [0-9] keeps non-ASCII Unicode digits out.
    return re.fullmatch(r'[0-9]+:[0-9]+', aspect_ratio) is not None
def validate_prompt(prompt: str) -> str:
    """Strip a prompt and check that it is non-empty and within the limit.

    Args:
        prompt: Raw prompt text.

    Returns:
        The prompt with leading and trailing whitespace removed.

    Raises:
        ValidationError: If the prompt is empty or whitespace-only, or if
            the stripped prompt is longer than 4000 characters.
    """
    cleaned = prompt.strip() if prompt else ""
    if not cleaned:
        raise ValidationError("Prompt cannot be empty")

    # The length limit applies to the stripped text.
    if len(cleaned) > 4000:
        raise ValidationError("Prompt is too long (max 4000 characters)")
    return cleaned
def validate_task_id(task_id: str) -> str:
    """Strip a task ID and apply a basic format check.

    Args:
        task_id: Raw task ID string.

    Returns:
        The task ID with leading and trailing whitespace removed.

    Raises:
        ValidationError: If the task ID is empty or whitespace-only, or if
            it is both non-numeric and shorter than 10 characters.
    """
    normalized = task_id.strip() if task_id else ""
    if not normalized:
        raise ValidationError("Task ID cannot be empty")

    # Basic format check: all-digit IDs of any length pass, as does any ID
    # of 10 or more characters (adjust based on GPTNB format).
    if normalized.isdigit() or len(normalized) >= 10:
        return normalized
    raise ValidationError("Invalid task ID format")
def validate_image_index(index: int) -> int:
    """Validate image index for variations/upscales.

    Args:
        index: Image index (1-4).

    Returns:
        The index unchanged when it is valid.

    Raises:
        ValidationError: If index is not an integer in the range 1-4.
            Booleans are rejected even though ``bool`` subclasses ``int``.
    """
    # bool is a subclass of int, so True would otherwise pass as index 1.
    is_plain_int = isinstance(index, int) and not isinstance(index, bool)
    if not is_plain_int or not 1 <= index <= 4:
        raise ValidationError("Image index must be between 1 and 4")
    return index
def validate_base64_images(base64_images: List[str], min_count: int = 1, max_count: int = 5) -> List[str]:
    """Check the count and encoding of a list of base64 images.

    Args:
        base64_images: List of base64 image strings.
        min_count: Minimum number of images required.
        max_count: Maximum number of images allowed.

    Returns:
        A new list with every image passed through ``format_base64_image``.
        An empty list when the input is empty and ``min_count`` is not
        positive.

    Raises:
        ValidationError: If there are too few or too many images, or if any
            image fails ``validate_base64_image``.
    """
    if not base64_images:
        # Empty input is acceptable only when no images are required.
        if min_count <= 0:
            return []
        raise ValidationError(f"At least {min_count} image(s) required")

    total = len(base64_images)
    if total < min_count:
        raise ValidationError(f"At least {min_count} image(s) required")
    if total > max_count:
        raise ValidationError(f"Maximum {max_count} image(s) allowed")

    # Stop at the first bad image and report its position.
    formatted: List[str] = []
    for position, image in enumerate(base64_images):
        if validate_base64_image(image):
            formatted.append(format_base64_image(image))
            continue
        raise ValidationError(f"Invalid base64 image at index {position}")
    return formatted
def extract_task_id_from_response(response_text: str) -> Optional[str]:
    """Pull a numeric task ID out of free-form response text.

    Patterns are tried from most to least specific, and the first one that
    matches anywhere in the text wins.

    Args:
        response_text: Response text that may contain a task ID.

    Returns:
        The captured digit string from the first matching pattern, or None
        if no pattern matches.
    """
    candidate_patterns = (
        r'[Tt]ask\s+ID[:\s]+(\d+)',  # e.g. "Task ID: 1234567890"
        r'ID[:\s]+(\d+)',            # any "ID" followed by digits
        r'(\d{10,})',                # Long numeric IDs
    )
    # Lazy generator: later patterns are only searched if earlier ones miss.
    hits = (re.search(candidate, response_text) for candidate in candidate_patterns)
    return next((hit.group(1) for hit in hits if hit), None)
def format_error_message(error: Exception, context: str = "") -> str:
    """Build a user-facing message from an exception.

    Args:
        error: Exception object.
        context: Optional description of where the error happened.

    Returns:
        ``"Error in <context>: <Type> - <message>"`` when context is given,
        otherwise ``"<Type>: <message>"``.
    """
    kind = type(error).__name__
    detail = str(error)
    if not context:
        return f"{kind}: {detail}"
    return f"Error in {context}: {kind} - {detail}"
def truncate_text(text: str, max_length: int = 100) -> str:
    """Truncate text to specified length.

    Text longer than ``max_length`` is cut and ends with "..." so that the
    result is exactly ``max_length`` characters. When ``max_length`` is
    smaller than the marker itself, the text is cut without a marker.

    Args:
        text: Input text.
        max_length: Maximum length of the returned string.

    Returns:
        The original text if it already fits, otherwise a truncated string
        no longer than ``max_length`` (empty for non-positive limits).
    """
    marker = "..."
    if len(text) <= max_length:
        return text
    if max_length < len(marker):
        # No room for the marker: a negative slice end would keep almost
        # the whole text, so cut hard at the limit instead.
        return text[:max(max_length, 0)]
    return text[:max_length - len(marker)] + marker