"""Unit tests for the session manager."""
import pytest
from unittest.mock import AsyncMock, MagicMock
from datetime import datetime, timedelta
from Imagen_MCP.services.session_manager import SessionManager, get_session_manager
from Imagen_MCP.models.session import SessionStatus
from Imagen_MCP.models.generation import GenerateImageResponse
from Imagen_MCP.models.image import GeneratedImage
from Imagen_MCP.exceptions import SessionNotFoundError, SessionExpiredError
class TestSessionManagerInitialization:
    """Construction-time behaviour of ``SessionManager``."""

    def test_session_manager_initialization(self):
        """A manager built without arguments allows 10 sessions and holds none."""
        default_manager = SessionManager()

        assert default_manager._max_concurrent_sessions == 10
        assert default_manager._sessions == {}

    def test_session_manager_with_custom_values(self):
        """An explicit client and session limit are stored on the manager."""
        fake_client = MagicMock()
        options = {
            "client": fake_client,
            "max_concurrent_sessions": 5,
            "session_ttl_minutes": 30,
        }

        configured = SessionManager(**options)

        assert configured._client == fake_client
        assert configured._max_concurrent_sessions == 5
class TestCreateSession:
    """Behaviour of ``SessionManager.create_session``."""

    def test_create_session(self):
        """A new session echoes its inputs and starts out in the CREATED state."""
        prompt_text = "A beautiful sunset"
        model_name = "imagen-4"
        image_count = 4

        new_session = SessionManager().create_session(
            prompt=prompt_text,
            model=model_name,
            count=image_count,
        )

        assert new_session.id is not None
        assert new_session.prompt == prompt_text
        assert new_session.model == model_name
        # Nothing has been generated yet, so every requested image is pending.
        assert new_session.requested_count == image_count
        assert new_session.pending_count == image_count
        assert new_session.status == SessionStatus.CREATED

    def test_create_session_with_all_parameters(self):
        """Optional size, quality and style are carried onto the session."""
        optional_fields = {
            "size": "1792x1024",
            "quality": "hd",
            "style": "natural",
        }

        new_session = SessionManager().create_session(
            prompt="A mountain landscape",
            model="imagen-4-ultra",
            count=3,
            **optional_fields,
        )

        for field_name, expected_value in optional_fields.items():
            assert getattr(new_session, field_name) == expected_value

    def test_create_session_concurrent_limit(self):
        """Creating a session beyond the configured limit raises ValueError."""
        limited_manager = SessionManager(max_concurrent_sessions=2)

        # Fill the manager up to its limit.
        for prompt_text in ("Test 1", "Test 2"):
            limited_manager.create_session(
                prompt=prompt_text, model="imagen-4", count=2
            )

        # One more than the limit must be rejected.
        with pytest.raises(ValueError, match="Maximum concurrent sessions"):
            limited_manager.create_session(
                prompt="Test 3", model="imagen-4", count=2
            )
class TestGetSession:
    """Behaviour of ``SessionManager.get_session``."""

    def test_get_session_by_id(self):
        """Looking a session up by its ID yields the session that was created."""
        store = SessionManager()
        original = store.create_session(prompt="Test", model="imagen-4", count=2)

        fetched = store.get_session(original.id)

        assert fetched.id == original.id
        assert fetched.prompt == "Test"

    def test_get_nonexistent_session_raises(self):
        """An unknown ID raises SessionNotFoundError."""
        empty_store = SessionManager()

        with pytest.raises(SessionNotFoundError):
            empty_store.get_session("nonexistent-id")

    def test_get_expired_session_raises(self):
        """A session older than the TTL raises SessionExpiredError on lookup."""
        # A zero-minute TTL plus a back-dated creation time puts the session
        # past its lifetime without the test having to wait.
        store = SessionManager(session_ttl_minutes=0)
        stale = store.create_session(prompt="Test", model="imagen-4", count=2)
        one_minute = timedelta(minutes=1)
        stale.created_at = datetime.now() - one_minute

        with pytest.raises(SessionExpiredError):
            store.get_session(stale.id)
class TestStartGeneration:
    """Behaviour of ``SessionManager.start_generation``."""

    @pytest.mark.asyncio
    async def test_start_generation_changes_status(self):
        """After generation starts, the session is GENERATING, PARTIAL or COMPLETED."""
        import asyncio

        canned_response = GenerateImageResponse(
            created=1234567890,
            images=[GeneratedImage(b64_json="imagedata")],
        )
        fake_client = MagicMock()
        fake_client.generate_image = AsyncMock(return_value=canned_response)

        manager = SessionManager(client=fake_client)
        session = manager.create_session(prompt="Test", model="imagen-4", count=2)

        await manager.start_generation(session.id)
        # Yield to the event loop briefly so the generation work can progress.
        await asyncio.sleep(0.1)

        states_after_start = (
            SessionStatus.GENERATING,
            SessionStatus.PARTIAL,
            SessionStatus.COMPLETED,
        )
        assert session.status in states_after_start
class TestGetNextImage:
    """Behaviour of ``SessionManager.get_next_image``."""

    @pytest.mark.asyncio
    async def test_get_next_image_available(self):
        """Once generation has had time to run, the next image is returned."""
        import asyncio

        canned_response = GenerateImageResponse(
            created=1234567890,
            images=[GeneratedImage(b64_json="imagedata")],
        )
        fake_client = MagicMock()
        fake_client.generate_image = AsyncMock(return_value=canned_response)

        manager = SessionManager(client=fake_client)
        session = manager.create_session(prompt="Test", model="imagen-4", count=2)
        await manager.start_generation(session.id)

        # Let the mocked client deliver at least one image before asking.
        await asyncio.sleep(0.2)
        next_image = await manager.get_next_image(session.id, timeout=5.0)

        assert next_image is not None
        assert next_image.get("b64_json") == "imagedata"

    @pytest.mark.asyncio
    async def test_get_next_image_session_not_found(self):
        """An unknown session ID raises SessionNotFoundError."""
        empty_manager = SessionManager()

        with pytest.raises(SessionNotFoundError):
            await empty_manager.get_next_image("nonexistent", timeout=1.0)
class TestGetSessionStatus:
    """Behaviour of ``SessionManager.get_session_status``."""

    def test_get_session_status(self):
        """A fresh session reports its ID, "created" status and image counts."""
        manager = SessionManager()
        fresh = manager.create_session(prompt="Test", model="imagen-4", count=4)

        report = manager.get_session_status(fresh.id)

        expected_entries = {
            "session_id": fresh.id,
            "status": "created",
            "completed_count": 0,
            "pending_count": 4,
            "total_count": 4,
        }
        for key, expected_value in expected_entries.items():
            assert report[key] == expected_value
class TestListSessions:
    """Behaviour of ``SessionManager.list_sessions``."""

    def test_list_sessions(self):
        """Two created sessions produce a listing of length two."""
        manager = SessionManager()
        for prompt_text, image_count in (("Test 1", 2), ("Test 2", 3)):
            manager.create_session(
                prompt=prompt_text, model="imagen-4", count=image_count
            )

        listing = manager.list_sessions()

        assert len(listing) == 2
class TestCleanupExpiredSessions:
    """Behaviour of ``SessionManager.cleanup_expired_sessions``."""

    @pytest.mark.asyncio
    async def test_cleanup_expired_sessions(self):
        """Cleanup reports one removal and leaves the manager empty."""
        manager = SessionManager(session_ttl_minutes=0)
        stale = manager.create_session(prompt="Test", model="imagen-4", count=2)
        # Back-date the session so it is past the zero-minute TTL.
        one_minute = timedelta(minutes=1)
        stale.created_at = datetime.now() - one_minute

        removed_count = await manager.cleanup_expired_sessions()

        assert removed_count == 1
        assert manager.get_session_count() == 0
class TestCancelSession:
    """Behaviour of ``SessionManager.cancel_session``."""

    @pytest.mark.asyncio
    async def test_cancel_session(self):
        """Cancelling a session whose generation is in flight leaves it FAILED."""
        import asyncio

        async def slow_generate(*args, **kwargs):
            # Sleeps far longer than the test runs, so the call is still
            # pending when cancel_session is invoked.
            await asyncio.sleep(10)
            return GenerateImageResponse(
                created=1234567890,
                images=[GeneratedImage(b64_json="data")],
            )

        fake_client = MagicMock()
        fake_client.generate_image = slow_generate

        manager = SessionManager(client=fake_client)
        session = manager.create_session(prompt="Test", model="imagen-4", count=5)
        await manager.start_generation(session.id)

        # Yield to the event loop so generation gets under way before cancelling.
        await asyncio.sleep(0.1)
        await manager.cancel_session(session.id)

        assert session.status == SessionStatus.FAILED
class TestSessionCounts:
    """Behaviour of the session counting helpers."""

    def test_get_session_count(self):
        """The total count starts at zero and grows by one per created session."""
        manager = SessionManager()
        assert manager.get_session_count() == 0

        for expected_total, prompt_text in enumerate(("Test 1", "Test 2"), start=1):
            manager.create_session(prompt=prompt_text, model="imagen-4", count=2)
            assert manager.get_session_count() == expected_total

    def test_get_active_session_count(self):
        """A session marked COMPLETED no longer counts as active."""
        manager = SessionManager()
        finished = manager.create_session(prompt="Test 1", model="imagen-4", count=2)
        manager.create_session(prompt="Test 2", model="imagen-4", count=2)

        assert manager.get_active_session_count() == 2

        # Flip one session to COMPLETED; only the other remains active.
        finished.status = SessionStatus.COMPLETED

        assert manager.get_active_session_count() == 1
class TestGetSessionManager:
    """Tests for the module-level session manager accessor."""

    def test_get_session_manager_singleton(self):
        """Two consecutive ``get_session_manager`` calls return the same object.

        The module-level ``_session_manager`` slot is cleared before the calls
        so the check starts from a clean state. It is cleared again in a
        ``finally`` clause so that a failing assertion, or an exception raised
        while building the manager, cannot leak the instance into other tests.
        """
        import Imagen_MCP.services.session_manager as sm

        # Reset global state so this test does not observe a manager left
        # behind by an earlier test.
        sm._session_manager = None
        try:
            first = get_session_manager()
            second = get_session_manager()
            assert first is second
        finally:
            # Always clean up, even when the assertion above fails.
            sm._session_manager = None