Skip to main content
Glama

Shannon MCP

test_session_manager.py (14.3 kB)
"""
Tests for Session Manager functionality.
"""

import asyncio
import json
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch

import pytest

from shannon_mcp.managers.session import SessionManager
from shannon_mcp.models.session import Session, SessionStatus
from tests.fixtures.session_fixtures import SessionFixtures


class TestSessionCreation:
    """Test session creation functionality."""

    @pytest.mark.asyncio
    async def test_create_session(self, session_manager):
        """Test creating a new session."""
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
            model="claude-3-opus",
            temperature=0.7,
            max_tokens=4096,
        )

        assert session.id is not None
        assert session.project_path == "/test/project"
        assert session.prompt == "Test prompt"
        assert session.model == "claude-3-opus"
        assert session.temperature == 0.7
        assert session.max_tokens == 4096
        assert session.status == SessionStatus.CREATED
        assert session.created_at is not None

    @pytest.mark.asyncio
    async def test_create_session_with_defaults(self, session_manager):
        """Test creating session with default values."""
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        assert session.model == "claude-3-opus"  # Default model
        assert session.temperature == 0.7  # Default temperature
        assert session.max_tokens == 4096  # Default max tokens

    @pytest.mark.asyncio
    async def test_create_session_validation(self, session_manager):
        """Test session creation validation."""
        # Test invalid temperature
        with pytest.raises(ValueError):
            await session_manager.create_session(
                project_path="/test",
                prompt="Test",
                temperature=2.0,  # Invalid: > 1.0
            )

        # Test invalid max_tokens
        with pytest.raises(ValueError):
            await session_manager.create_session(
                project_path="/test",
                prompt="Test",
                max_tokens=-100,  # Invalid: negative
            )


class TestSessionLifecycle:
    """Test session lifecycle management."""

    @pytest.mark.asyncio
    async def test_start_session(self, session_manager, mock_claude_binary):
        """Test starting a session."""
        # Create session
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Mock subprocess so no real binary is launched.
        with patch("asyncio.create_subprocess_exec") as mock_subprocess:
            mock_process = AsyncMock()
            mock_process.pid = 12345
            mock_process.returncode = None
            mock_subprocess.return_value = mock_process

            # Start session
            process_info = await session_manager.start_session(session.id)

            assert process_info["pid"] == 12345
            assert process_info["session_id"] == session.id

        # Verify session status updated
        updated_session = await session_manager.get_session(session.id)
        assert updated_session.status == SessionStatus.RUNNING
        assert updated_session.started_at is not None

    @pytest.mark.asyncio
    async def test_complete_session(self, session_manager):
        """Test completing a session."""
        # Create and mock start session
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Mock as running
        session.status = SessionStatus.RUNNING
        session.started_at = datetime.now(timezone.utc)
        await session_manager._update_session(session)

        # Complete session
        await session_manager.complete_session(
            session.id,
            success=True,
            metadata={"tokens_used": 1500},
        )

        # Verify completion
        completed = await session_manager.get_session(session.id)
        assert completed.status == SessionStatus.COMPLETED
        assert completed.completed_at is not None
        assert completed.metadata["tokens_used"] == 1500

    @pytest.mark.asyncio
    async def test_fail_session(self, session_manager):
        """Test failing a session."""
        # Create session
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Fail session
        await session_manager.fail_session(
            session.id,
            error="Test error occurred",
        )

        # Verify failure
        failed = await session_manager.get_session(session.id)
        assert failed.status == SessionStatus.FAILED
        assert failed.completed_at is not None
        assert failed.metadata["error"] == "Test error occurred"

    @pytest.mark.asyncio
    async def test_cancel_session(self, session_manager):
        """Test canceling a running session."""
        # Create session
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Register a fake running process directly in the manager's table.
        session_manager._processes[session.id] = {
            "process": AsyncMock(),
            "pid": 12345,
        }

        # Cancel session
        result = await session_manager.cancel_session(session.id)
        assert result is True

        # Verify canceled
        canceled = await session_manager.get_session(session.id)
        assert canceled.status == SessionStatus.CANCELLED


class TestSessionStreaming:
    """Test session JSONL streaming functionality."""

    @pytest.mark.asyncio
    async def test_stream_session_output(self, session_manager):
        """Test streaming session output."""
        # Create session
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Create mock messages
        messages = SessionFixtures.create_streaming_messages(count=5)

        # Mock process with stdout.
        mock_process = AsyncMock()
        mock_stdout = AsyncMock()
        # AsyncMock's side_effect must be a *sync* iterable (it calls next()
        # on it); an async generator object raises TypeError. Each awaited
        # readline() returns the next encoded JSONL line, and the trailing
        # b"" mimics EOF so the iterator is never exhausted mid-stream.
        lines = [msg.encode() + b"\n" for msg in messages]
        mock_stdout.readline = AsyncMock(side_effect=[*lines, b""])
        mock_process.stdout = mock_stdout

        session_manager._processes[session.id] = {
            "process": mock_process,
            "pid": 12345,
        }

        # Stream output
        collected = []
        async for message in session_manager.stream_output(session.id):
            collected.append(message)
            if len(collected) >= len(messages):
                break

        assert len(collected) == len(messages)

        # Verify message types
        message_data = [json.loads(msg) for msg in collected]
        assert message_data[0]["type"] == "session_start"
        assert message_data[-1]["type"] == "session_complete"

    @pytest.mark.asyncio
    async def test_stream_with_errors(self, session_manager):
        """Test streaming with error handling."""
        # Create session
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Create error message
        error_msg = SessionFixtures.create_error_message("timeout")

        # Mock process whose stdout keeps returning the same error line.
        mock_process = AsyncMock()
        mock_stdout = AsyncMock()
        mock_stdout.readline = AsyncMock(return_value=error_msg.encode() + b"\n")
        mock_process.stdout = mock_stdout

        session_manager._processes[session.id] = {
            "process": mock_process,
            "pid": 12345,
        }

        # Stream should surface the error message; only the first is needed.
        received = None
        async for message in session_manager.stream_output(session.id):
            received = json.loads(message)
            break

        # Asserting outside the loop prevents a vacuous pass when the
        # stream yields nothing at all.
        assert received is not None, "stream_output yielded no messages"
        assert received["type"] == "error"
        assert received["error"] == "TimeoutError"


class TestSessionQuerying:
    """Test session querying functionality."""

    @pytest.mark.asyncio
    async def test_list_sessions(self, session_manager):
        """Test listing sessions."""
        # Create multiple sessions
        for i in range(5):
            await session_manager.create_session(
                project_path=f"/test/project_{i}",
                prompt=f"Test prompt {i}",
            )

        # List all sessions
        all_sessions = await session_manager.list_sessions()
        assert len(all_sessions) == 5

        # List by status
        created_sessions = await session_manager.list_sessions(
            status=SessionStatus.CREATED
        )
        assert len(created_sessions) == 5

    @pytest.mark.asyncio
    async def test_list_sessions_by_project(self, session_manager):
        """Test listing sessions by project."""
        # Create sessions for different projects
        for i in range(3):
            await session_manager.create_session(
                project_path="/project1",
                prompt=f"Test {i}",
            )

        for i in range(2):
            await session_manager.create_session(
                project_path="/project2",
                prompt=f"Test {i}",
            )

        # List by project
        p1_list = await session_manager.list_sessions(project_path="/project1")
        assert len(p1_list) == 3

        p2_list = await session_manager.list_sessions(project_path="/project2")
        assert len(p2_list) == 2

    @pytest.mark.asyncio
    async def test_get_session_stats(self, session_manager):
        """Test getting session statistics."""
        # Create sessions with different statuses
        sessions = SessionFixtures.create_batch_sessions(count=10)
        for session in sessions:
            await session_manager._save_session(session)

        # Get stats
        stats = await session_manager.get_session_stats()

        assert stats["total"] == 10
        assert "by_status" in stats
        assert "by_model" in stats
        assert "average_duration" in stats


class TestSessionCaching:
    """Test session caching functionality."""

    @pytest.mark.asyncio
    async def test_session_cache(self, session_manager):
        """Test session caching mechanism."""
        # Create session
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # First get should hit database
        session1 = await session_manager.get_session(session.id)

        # Second get should use cache
        session2 = await session_manager.get_session(session.id)

        assert session1.id == session2.id
        assert session1.created_at == session2.created_at

    @pytest.mark.asyncio
    async def test_cache_invalidation(self, session_manager):
        """Test cache invalidation on updates."""
        # Create session
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Cache it
        cached = await session_manager.get_session(session.id)
        assert cached.status == SessionStatus.CREATED

        # Update status
        await session_manager.complete_session(session.id)

        # Should get updated version
        updated = await session_manager.get_session(session.id)
        assert updated.status == SessionStatus.COMPLETED


class TestConcurrentSessions:
    """Test concurrent session handling."""

    @pytest.mark.asyncio
    async def test_concurrent_session_limit(self, session_manager):
        """Test concurrent session limits."""
        # Set max concurrent to 3
        session_manager._config._config["session"]["max_concurrent"] = 3

        # Create 3 running sessions
        for i in range(3):
            session = await session_manager.create_session(
                project_path=f"/test/project_{i}",
                prompt=f"Test {i}",
            )
            session.status = SessionStatus.RUNNING
            await session_manager._update_session(session)

        # Fourth session should fail, either on creation or on being
        # marked running.
        with pytest.raises(RuntimeError, match="concurrent session limit"):
            session = await session_manager.create_session(
                project_path="/test/project_4",
                prompt="Test 4",
            )
            session.status = SessionStatus.RUNNING
            await session_manager._update_session(session)

    @pytest.mark.asyncio
    async def test_session_cleanup_on_exit(self, session_manager):
        """Test session cleanup on manager exit."""
        # Create sessions and register a fake running process for each.
        for i in range(3):
            session = await session_manager.create_session(
                project_path=f"/test/project_{i}",
                prompt=f"Test {i}",
            )
            session_manager._processes[session.id] = {
                "process": AsyncMock(),
                "pid": 12345 + i,
            }

        # Stop manager (should cleanup)
        await session_manager.stop()

        # Verify processes terminated
        assert len(session_manager._processes) == 0

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/krzemienski/shannon-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.