"""Unified task handling for Midjourney operations - task management and scheduling."""
import asyncio
import logging
from client import GPTNBClient
from models import TaskDetail, TaskStatus, TaskResponse
from config import Config
from exceptions import TaskFailedError, TimeoutError, TaskNotFoundError
# Logger named after this module; used by TaskManager below for progress and error reporting.
logger: logging.Logger = logging.getLogger(name=__name__)
# ============================================================================
# Task Manager
# ============================================================================
class TaskManager:
    """Manage the lifecycle of Midjourney tasks: submit, poll, and format results.

    Task state is read through ``client.get_task``; the overall polling budget
    (in seconds) comes from ``config.timeout``.
    """

    def __init__(
        self,
        client: GPTNBClient,
        config: Config,
        *,
        poll_interval: float = 5,
    ):
        """Initialize the task manager.

        Args:
            client: API client; its ``get_task(task_id)`` coroutine is used to
                fetch task details.
            config: Configuration; ``config.timeout`` is used as the maximum
                number of seconds to wait for a task to finish.
            poll_interval: Seconds to sleep between status polls (default 5).
        """
        self.client = client
        self.config = config
        self.poll_interval = poll_interval  # seconds between polls
        self.max_poll_time = config.timeout  # total polling budget, seconds

    async def submit_and_wait(self, submit_func, *args, **kwargs) -> TaskDetail:
        """Submit a task and wait until it reaches a terminal state.

        Args:
            submit_func: Coroutine function that submits the task and returns a
                ``TaskResponse``; awaited as ``submit_func(*args, **kwargs)``.
            *args: Forwarded unchanged to ``submit_func``.
            **kwargs: Forwarded unchanged to ``submit_func``.

        Returns:
            The final ``TaskDetail`` once the task reports success.

        Raises:
            TaskFailedError: If submission is rejected (``response.code != 1``),
                no task ID is returned, or the task later fails.
            TimeoutError: (project exception) if the task does not finish
                within ``max_poll_time`` seconds.
            TaskNotFoundError: Propagated from ``client.get_task``.
        """
        # partials / callable objects have no __name__; fall back to repr.
        func_name = getattr(submit_func, "__name__", repr(submit_func))
        logger.info("Submitting task with function: %s", func_name)
        response: TaskResponse = await submit_func(*args, **kwargs)

        # Submission is treated as successful only when code == 1.
        if response.code != 1:
            raise TaskFailedError(f"Task submission failed: {response.description}")
        if not response.result:
            raise TaskFailedError("No task ID returned from submission")

        task_id = response.result
        logger.info("Task submitted successfully with ID: %s", task_id)
        return await self.wait_for_completion(task_id)

    async def wait_for_completion(self, task_id: str) -> TaskDetail:
        """Poll ``client.get_task`` until the task succeeds, fails, or times out.

        Errors from ``get_task`` other than ``TaskNotFoundError`` are logged and
        retried, but the time they take still counts against ``max_poll_time``.

        Args:
            task_id: ID returned by the submission endpoint.

        Returns:
            The ``TaskDetail`` whose status is ``TaskStatus.SUCCESS``.

        Raises:
            TaskFailedError: If the task reports ``TaskStatus.FAILURE``.
            TimeoutError: (project exception imported from ``exceptions``, not
                the builtin) if ``max_poll_time`` seconds elapse first.
            TaskNotFoundError: Propagated immediately from ``client.get_task``.
        """
        logger.info("Waiting for task completion: %s", task_id)
        loop = asyncio.get_running_loop()
        start_time = loop.time()

        while True:
            # Guard only the API call. TaskFailedError/TimeoutError raised
            # further down must escape; wrapping them in the retry handler
            # would turn a terminal outcome into endless polling.
            try:
                task = await self.client.get_task(task_id)
            except TaskNotFoundError:
                logger.error("Task %s not found", task_id)
                raise
            except Exception as e:
                logger.error("Error polling task %s: %s", task_id, e)
            else:
                logger.debug(
                    "Task %s status: %s, progress: %s",
                    task_id, task.status, task.progress,
                )
                if task.status == TaskStatus.SUCCESS:
                    logger.info("Task %s completed successfully", task_id)
                    return task
                if task.status == TaskStatus.FAILURE:
                    error_msg = task.failReason or "Unknown error"
                    logger.error("Task %s failed: %s", task_id, error_msg)
                    raise TaskFailedError(f"Task failed: {error_msg}")

            # Checked on every iteration, including after a failed poll, so a
            # persistently erroring API cannot keep us looping forever.
            elapsed = loop.time() - start_time
            if elapsed > self.max_poll_time:
                logger.error("Task %s timed out after %.1f seconds", task_id, elapsed)
                raise TimeoutError(f"Task timed out after {self.max_poll_time} seconds")

            await asyncio.sleep(self.poll_interval)

    async def get_task_status(self, task_id: str) -> TaskDetail:
        """Fetch the current ``TaskDetail`` for ``task_id`` with one API call (no polling)."""
        return await self.client.get_task(task_id)

    def format_task_result(self, task: TaskDetail) -> str:
        """Render ``task`` as a Markdown-formatted status message for display.

        A successful task with an ``imageUrl`` gets the URL, an inline Markdown
        image and a direct link. Other outcomes get a one-line summary. Every
        message ends with the task ID.
        """
        task_id_line = f"**Task ID:** {task.id}"

        if task.status == TaskStatus.SUCCESS:
            if task.imageUrl:
                return (
                    "✅ Task completed successfully!\n\n"
                    f"**Image URL:** {task.imageUrl}\n\n"
                    "🖼️ **Generated Image:**\n"
                    # Markdown image embed shown under the heading above.
                    f"![Generated Image]({task.imageUrl})\n\n"
                    f"🔗 **Direct Link:** {task.imageUrl}\n\n"
                    f"{task_id_line}"
                )
            if task.description:
                return (
                    "✅ Task completed successfully!\n\n"
                    f"**Result:** {task.description}\n\n"
                    f"{task_id_line}"
                )
            return f"✅ Task completed successfully!\n\n{task_id_line}"

        if task.status == TaskStatus.FAILURE:
            error_msg = task.failReason or "Unknown error"
            return f"❌ Task failed: {error_msg}\n\n{task_id_line}"

        if task.status == TaskStatus.IN_PROGRESS:
            progress = task.progress or "Processing"
            return f"🔄 Task in progress: {progress}\n\n{task_id_line}"

        return f"⏳ Task status: {task.status}\n\n{task_id_line}"