"""Session state models for batch generation."""
import asyncio
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field, ConfigDict
# Placeholder for imports needed only by static type checkers; currently
# empty, so this guard has no effect at runtime.
if TYPE_CHECKING:
    pass
class SessionStatus(str, Enum):
    """Status of a generation session.

    Mixes in ``str`` so members compare equal to their string values
    (e.g. ``SessionStatus.COMPLETED == "completed"``).
    """

    CREATED = "created"  # initial default status of a GenerationSession
    # NOTE(review): not assigned by GenerationSession itself; presumably set
    # by whatever drives generation -- confirm against callers.
    GENERATING = "generating"
    PARTIAL = "partial"  # at least one image done, more still pending
    COMPLETED = "completed"  # nothing pending, at least one image succeeded
    FAILED = "failed"  # nothing pending and no image succeeded
class GenerationSession(BaseModel):
    """Represents a batch generation session.

    Holds the request parameters, the images produced so far, and the
    bookkeeping needed to hand images back one at a time as they become
    available. Producers report each finished item through ``add_image`` or
    ``add_error``; consumers pull images with ``get_next_image`` (non-blocking)
    or ``wait_for_image`` (async, with timeout).
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Session identification
    id: str = Field(description="Unique session identifier")
    created_at: datetime = Field(
        default_factory=datetime.now, description="When the session was created"
    )

    # Request parameters
    prompt: str = Field(description="The prompt used for generation")
    model: str = Field(description="Model used for generation")
    requested_count: int = Field(description="Number of images requested")
    size: str = Field(default="1024x1024", description="Image size")
    quality: str = Field(default="standard", description="Image quality")
    style: str = Field(default="vivid", description="Image style")

    # Session state
    status: SessionStatus = Field(
        default=SessionStatus.CREATED, description="Current session status"
    )

    # Image storage - using list of dicts to avoid circular imports
    completed_images: list[dict] = Field(
        default_factory=list, description="Completed images"
    )
    # Must be set by the caller to the number of outstanding items; add_image
    # and add_error decrement it and derive the terminal status from it.
    pending_count: int = Field(default=0, description="Number of pending images")

    # Tracking
    next_image_index: int = Field(
        default=0, description="Index of next image to return"
    )
    errors: list[dict] = Field(
        default_factory=list, description="Errors encountered during generation"
    )

    # Background task reference and async primitives (not serialized).
    # Underscore-prefixed annotations are private attributes in pydantic v2.
    # NOTE(review): _lock is created in model_post_init but not used by the
    # methods of this class.
    _generation_task: asyncio.Task | None = None
    _image_ready_event: asyncio.Event | None = None
    _lock: asyncio.Lock | None = None

    def model_post_init(self, __context) -> None:
        """Create the async primitives once pydantic has initialized the model."""
        self._image_ready_event = asyncio.Event()
        self._lock = asyncio.Lock()

    @property
    def completed_count(self) -> int:
        """Number of images generated so far."""
        return len(self.completed_images)

    @property
    def is_complete(self) -> bool:
        """True once the session has reached a terminal status (COMPLETED or FAILED)."""
        return self.status in (SessionStatus.COMPLETED, SessionStatus.FAILED)

    @property
    def has_available_images(self) -> bool:
        """True if some completed image has not yet been handed to a consumer."""
        return self.next_image_index < len(self.completed_images)

    def _notify_waiters(self) -> None:
        """Wake every coroutine blocked in ``wait_for_image`` so it re-checks state."""
        if self._image_ready_event is not None:
            self._image_ready_event.set()

    def add_image(self, image_data: dict) -> None:
        """Record a completed image and update the session status.

        Decrements ``pending_count`` (never below zero). The session becomes
        COMPLETED when nothing is pending, otherwise PARTIAL. Waiters are
        woken so they can pick up the new image.
        """
        self.completed_images.append(image_data)
        self.pending_count = max(0, self.pending_count - 1)
        # At least one image exists now, so the only outcomes are
        # "everything finished" or "some but not all".
        if self.pending_count == 0:
            self.status = SessionStatus.COMPLETED
        else:
            self.status = SessionStatus.PARTIAL
        self._notify_waiters()

    def add_error(self, index: int, error: str, error_type: str | None = None) -> None:
        """Record a failed generation and update the session status.

        Args:
            index: Caller-supplied index of the failed item, stored as-is.
            error: Error message.
            error_type: Optional error classification.

        Decrements ``pending_count`` (never below zero). When nothing remains
        pending the session becomes FAILED if no image succeeded, otherwise
        COMPLETED, and waiters are woken.
        """
        self.errors.append(
            {
                "index": index,
                "error": error,
                "error_type": error_type,
                "timestamp": datetime.now().isoformat(),
            }
        )
        self.pending_count = max(0, self.pending_count - 1)
        if self.pending_count == 0:
            self.status = (
                SessionStatus.COMPLETED
                if self.completed_images
                else SessionStatus.FAILED
            )
            # Without this, a consumer blocked in wait_for_image would sleep
            # until its timeout even though no further image can arrive.
            self._notify_waiters()

    def get_next_image(self) -> dict | None:
        """Return the next not-yet-returned image and advance the cursor, or None."""
        if not self.has_available_images:
            return None
        image = self.completed_images[self.next_image_index]
        self.next_image_index += 1
        return image

    async def wait_for_image(self, timeout: float | None = 60.0) -> dict | None:
        """Return the next image, waiting up to ``timeout`` seconds for one.

        Args:
            timeout: Maximum seconds to wait in total; ``None`` waits
                indefinitely.

        Returns:
            The next image, or None if the session is complete with no
            unreturned images, or if the timeout elapses first.
        """
        loop = asyncio.get_running_loop()
        deadline = None if timeout is None else loop.time() + timeout
        while True:
            if self.has_available_images:
                return self.get_next_image()
            if self.is_complete or self._image_ready_event is None:
                return None
            remaining = None
            if deadline is not None:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    return None
            # Clear before waiting so only a set() issued after the checks
            # above wakes us; a stale set left by an earlier image would
            # otherwise return at once with nothing to hand out. There is no
            # await between those checks and this clear, so no notification
            # can be lost in between.
            self._image_ready_event.clear()
            try:
                await asyncio.wait_for(
                    self._image_ready_event.wait(), timeout=remaining
                )
            except asyncio.TimeoutError:
                return None

    def to_status_dict(self) -> dict:
        """Summarize the session (id, status, counts, errors) as a plain dict."""
        return {
            "session_id": self.id,
            "status": self.status.value,
            "completed_count": self.completed_count,
            "pending_count": self.pending_count,
            "total_count": self.requested_count,
            "errors": self.errors,
        }