"""Session manager for batch image generation."""
import asyncio
import threading
import uuid
from datetime import datetime, timedelta
from typing import cast
from Imagen_MCP.models.session import GenerationSession, SessionStatus
from Imagen_MCP.models.generation import (
GenerateImageRequest,
ImageSize,
ImageQuality,
ImageStyle,
)
from Imagen_MCP.services.nexos_client import NexosClient
from Imagen_MCP.services.model_registry import get_model_registry
from Imagen_MCP.exceptions import SessionNotFoundError, SessionExpiredError
class SessionManager:
    """Manages batch generation sessions with background processing.

    Sessions live in an in-memory dict keyed by a UUID4 string. Once a
    session is started, it owns one asyncio task. That task requests images
    one at a time from the NexosClient and records every result or error on
    the session object. A session expires ``session_ttl_minutes`` after its
    ``created_at`` timestamp.
    """

    def __init__(
        self,
        client: NexosClient | None = None,
        max_concurrent_sessions: int = 10,
        session_ttl_minutes: int = 60,
    ) -> None:
        """Initialize the session manager.

        Args:
            client: NexosClient instance for API calls. If ``None``, one is
                created lazily via ``NexosClient.from_env()`` on first use.
            max_concurrent_sessions: Maximum number of sessions that may be
                incomplete (``not session.is_complete``) at the same time.
            session_ttl_minutes: Session time-to-live in minutes, measured
                from each session's ``created_at``.
        """
        self._client = client
        self._sessions: dict[str, GenerationSession] = {}
        self._max_concurrent_sessions = max_concurrent_sessions
        self._session_ttl = timedelta(minutes=session_ttl_minutes)
        # NOTE(review): never assigned within this class; kept in case other
        # code relies on the attribute.
        self._cleanup_task: asyncio.Task | None = None

    @property
    def client(self) -> NexosClient:
        """Return the NexosClient, creating it from the environment if unset."""
        if self._client is None:
            self._client = NexosClient.from_env()
        return self._client

    def _generate_session_id(self) -> str:
        """Generate a unique session ID (a random UUID4 string)."""
        return str(uuid.uuid4())

    def _is_expired(
        self, session: GenerationSession, now: datetime | None = None
    ) -> bool:
        """Return True if *session* is older than the configured TTL.

        Args:
            session: Session to check.
            now: Reference time; defaults to ``datetime.now()``. Passing it
                lets a caller check many sessions against one timestamp.
        """
        # NOTE(review): naive local time, assumed to match how
        # GenerationSession sets ``created_at`` -- confirm.
        reference = datetime.now() if now is None else now
        return reference - session.created_at > self._session_ttl

    def _purge_expired(self) -> int:
        """Remove every expired session, requesting cancellation of its task.

        Returns:
            Number of sessions removed.
        """
        now = datetime.now()
        # Collect IDs first: the dict must not change while it is iterated.
        expired = [
            session_id
            for session_id, session in self._sessions.items()
            if self._is_expired(session, now)
        ]
        for session_id in expired:
            self._cleanup_session(session_id)
        return len(expired)

    def create_session(
        self,
        prompt: str,
        model: str,
        count: int,
        size: str = "1024x1024",
        quality: str = "standard",
        style: str = "vivid",
    ) -> GenerationSession:
        """Create and register a new generation session (status CREATED).

        Generation does not begin until ``start_generation`` is called.

        Args:
            prompt: The image generation prompt.
            model: Model to use for generation (resolved to an API ID later).
            count: Number of images to generate; must be at least 1.
            size: Image size.
            quality: Image quality.
            style: Image style.

        Returns:
            The created GenerationSession.

        Raises:
            ValueError: If ``count`` is less than 1, or if the maximum number
                of concurrent sessions has been reached.
        """
        if count < 1:
            # A non-positive count would leave pending_count non-zero (or
            # negative) with nothing to generate, so the session would never
            # reach a final status.
            raise ValueError(f"count must be at least 1, got {count}")

        # Expired sessions are otherwise only dropped when looked up or when
        # cleanup_expired_sessions() runs. Purge them here so that abandoned
        # sessions (e.g. created but never started) cannot hold concurrency
        # slots indefinitely.
        self._purge_expired()
        if self.get_active_session_count() >= self._max_concurrent_sessions:
            raise ValueError(
                f"Maximum concurrent sessions ({self._max_concurrent_sessions}) reached"
            )

        session_id = self._generate_session_id()
        session = GenerationSession(
            id=session_id,
            prompt=prompt,
            model=model,
            requested_count=count,
            pending_count=count,
            size=size,
            quality=quality,
            style=style,
            status=SessionStatus.CREATED,
        )
        self._sessions[session_id] = session
        return session

    def get_session(self, session_id: str) -> GenerationSession:
        """Get a session by ID, enforcing the TTL.

        Args:
            session_id: The session identifier.

        Returns:
            The GenerationSession.

        Raises:
            SessionNotFoundError: If the session doesn't exist.
            SessionExpiredError: If the session has expired. The session is
                removed and its running task (if any) is cancelled first.
        """
        session = self._sessions.get(session_id)
        if session is None:
            raise SessionNotFoundError(f"Session not found: {session_id}")
        if self._is_expired(session):
            self._cleanup_session(session_id)
            raise SessionExpiredError(f"Session expired: {session_id}")
        return session

    def _cleanup_session(self, session_id: str) -> None:
        """Remove a session and request cancellation of its running task.

        This does not await the task; cancellation takes effect at its next
        suspension point. Unknown IDs are ignored.
        """
        session = self._sessions.pop(session_id, None)
        if session is None:
            return
        task = session._generation_task
        if task is not None and not task.done():
            task.cancel()

    async def start_generation(self, session_id: str) -> None:
        """Start background generation for a session.

        If the session already has a generation task that is still running,
        this is a no-op.

        Args:
            session_id: The session identifier.

        Raises:
            SessionNotFoundError: If the session doesn't exist.
            SessionExpiredError: If the session has expired.
        """
        session = self.get_session(session_id)
        running = session._generation_task
        if running is not None and not running.done():
            # A second task would request every image again for the same
            # session and double-count results.
            return
        session.status = SessionStatus.GENERATING
        # Storing the task on the session keeps a strong reference to it, so
        # it isn't garbage-collected mid-run.
        session._generation_task = asyncio.create_task(
            self._background_generation(session)
        )

    async def _background_generation(self, session: GenerationSession) -> None:
        """Generate the session's images one request at a time.

        Each successful image is stored via ``session.add_image``. Each failed
        request is recorded via ``session.add_error`` and does not stop the
        batch. When ``pending_count`` reaches 0, the status becomes COMPLETED
        if any image succeeded, otherwise FAILED.

        Args:
            session: The generation session.

        Raises:
            asyncio.CancelledError: Re-raised after marking the session FAILED
                if the task is cancelled.
        """
        registry = get_model_registry()
        api_model_id = registry.get_api_id(session.model)
        if not api_model_id:
            session.add_error(0, f"Could not find API ID for model: {session.model}")
            session.status = SessionStatus.FAILED
            return

        for i in range(session.requested_count):
            try:
                # One image per request, so progress can be reported per image.
                request = GenerateImageRequest(
                    prompt=session.prompt,
                    model=api_model_id,
                    n=1,
                    size=cast(ImageSize, session.size),
                    quality=cast(ImageQuality, session.quality),
                    style=cast(ImageStyle, session.style),
                )
                response = await self.client.generate_image(request)
                if response.images:
                    session.add_image(response.images[0].model_dump())
                else:
                    session.add_error(i, "No image returned from API")
            except asyncio.CancelledError:
                # Record the interruption and mark the session failed so it
                # doesn't stay GENERATING forever. Then re-raise, so the task
                # is reported as cancelled to anyone awaiting it.
                session.add_error(i, "Generation cancelled")
                session.status = SessionStatus.FAILED
                raise
            except Exception as e:
                # A single failed request doesn't abort the remaining images.
                session.add_error(i, str(e), type(e).__name__)

        # NOTE(review): relies on add_image/add_error decrementing
        # pending_count -- confirm in GenerationSession.
        if session.pending_count == 0:
            session.status = (
                SessionStatus.COMPLETED
                if session.completed_images
                else SessionStatus.FAILED
            )

    async def get_next_image(
        self, session_id: str, timeout: float = 60.0
    ) -> dict | None:
        """Get the next available image from a session.

        Args:
            session_id: The session identifier.
            timeout: Maximum time to wait for an image.

        Returns:
            An already-available image dict if there is one. ``None`` if the
            session is complete and has no images left. Otherwise, whatever
            ``session.wait_for_image(timeout)`` returns (presumably ``None``
            on timeout -- verify against GenerationSession).

        Raises:
            SessionNotFoundError: If the session doesn't exist.
            SessionExpiredError: If the session has expired.
        """
        session = self.get_session(session_id)
        if session.has_available_images:
            return session.get_next_image()
        if session.is_complete:
            return None
        return await session.wait_for_image(timeout)

    def get_session_status(self, session_id: str) -> dict:
        """Get the status dictionary of a session.

        Args:
            session_id: The session identifier.

        Returns:
            The result of ``session.to_status_dict()``.

        Raises:
            SessionNotFoundError: If the session doesn't exist.
            SessionExpiredError: If the session has expired.
        """
        return self.get_session(session_id).to_status_dict()

    def list_sessions(self) -> list[dict]:
        """List every tracked session, complete or not.

        Expired sessions that have not been purged yet are included.

        Returns:
            List of session status dictionaries.
        """
        return [s.to_status_dict() for s in self._sessions.values()]

    async def cleanup_expired_sessions(self) -> int:
        """Remove expired sessions and request cancellation of their tasks.

        Returns:
            Number of sessions cleaned up.
        """
        return self._purge_expired()

    async def cancel_session(self, session_id: str) -> None:
        """Cancel a session's background task and mark the session FAILED.

        A session whose status is already COMPLETED keeps that status.

        Args:
            session_id: The session identifier.

        Raises:
            SessionNotFoundError: If the session doesn't exist.
            SessionExpiredError: If the session has expired.
        """
        session = self.get_session(session_id)
        task = session._generation_task
        if task is not None and not task.done():
            task.cancel()
            try:
                # Wait for the task to unwind so the status below is final.
                await task
            except asyncio.CancelledError:
                pass
        # Don't misreport a session that already finished successfully.
        if session.status != SessionStatus.COMPLETED:
            session.status = SessionStatus.FAILED

    def get_session_count(self) -> int:
        """Get the number of tracked sessions, complete or not.

        Returns:
            Number of sessions currently stored (including expired ones that
            have not been purged yet).
        """
        return len(self._sessions)

    def get_active_session_count(self) -> int:
        """Get the number of non-complete sessions.

        Returns:
            Number of sessions for which ``is_complete`` is false.
        """
        return sum(1 for s in self._sessions.values() if not s.is_complete)

# Process-wide SessionManager singleton, created lazily by
# get_session_manager(); the lock serializes that first creation.
_session_manager_lock = threading.Lock()
_session_manager: SessionManager | None = None

def get_session_manager() -> SessionManager:
    """Return the shared SessionManager, creating it on first call.

    The unlocked fast path returns an existing instance immediately. Creation
    happens under ``_session_manager_lock`` with a second ``None`` check, so
    concurrent first calls from several threads still produce one instance.

    Returns:
        The global SessionManager instance.
    """
    global _session_manager
    current = _session_manager
    if current is not None:
        return current
    with _session_manager_lock:
        # Another thread may have created it while we waited for the lock.
        if _session_manager is None:
            _session_manager = SessionManager()
        return _session_manager