We provide all the information about MCP servers via our MCP API.
curl -X GET 'https://glama.ai/api/mcp/v1/servers/krzemienski/shannon-mcp'
If you have feedback or need assistance with the MCP directory API, please join our Discord server.
"""
Tests for Session Manager functionality.
"""
import pytest
import asyncio
from pathlib import Path
from datetime import datetime, timezone
import json
from unittest.mock import Mock, AsyncMock, patch
from shannon_mcp.models.session import Session, SessionStatus
from shannon_mcp.managers.session import SessionManager
from tests.fixtures.session_fixtures import SessionFixtures
class TestSessionCreation:
    """Exercise SessionManager.create_session: explicit values, defaults and validation."""

    @pytest.mark.asyncio
    async def test_create_session(self, session_manager):
        """A session built from explicit arguments echoes them back in CREATED state."""
        requested = {
            "project_path": "/test/project",
            "prompt": "Test prompt",
            "model": "claude-3-opus",
            "temperature": 0.7,
            "max_tokens": 4096,
        }
        created = await session_manager.create_session(**requested)

        assert created.id is not None
        # Every argument passed in should be readable back as an attribute.
        for field_name, expected in requested.items():
            assert getattr(created, field_name) == expected
        assert created.status == SessionStatus.CREATED
        assert created.created_at is not None

    @pytest.mark.asyncio
    async def test_create_session_with_defaults(self, session_manager):
        """Omitted model/temperature/max_tokens are filled in with the expected defaults."""
        created = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Values this test expects the manager to supply when the caller omits them.
        expected_defaults = {
            "model": "claude-3-opus",
            "temperature": 0.7,
            "max_tokens": 4096,
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(created, field_name) == expected

    @pytest.mark.asyncio
    async def test_create_session_validation(self, session_manager):
        """Out-of-range generation parameters are rejected with ValueError."""
        invalid_overrides = (
            {"temperature": 2.0},  # this test treats anything > 1.0 as invalid
            {"max_tokens": -100},  # a negative token budget is invalid
        )
        for override in invalid_overrides:
            with pytest.raises(ValueError):
                await session_manager.create_session(
                    project_path="/test",
                    prompt="Test",
                    **override,
                )
class TestSessionLifecycle:
    """Test session lifecycle management: start, complete, fail and cancel."""

    @pytest.mark.asyncio
    async def test_start_session(self, session_manager, mock_claude_binary):
        """Starting a created session reports the process and marks it RUNNING.

        ``mock_claude_binary`` is not referenced in the body; it is presumably
        requested for its setup side effects (TODO confirm in conftest).
        """
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Patch asyncio.create_subprocess_exec so that, assuming the manager
        # spawns through it, no real process is started.
        with patch('asyncio.create_subprocess_exec') as mock_subprocess:
            mock_process = AsyncMock()
            mock_process.pid = 12345
            # returncode None mirrors a real asyncio Process that has not exited.
            mock_process.returncode = None
            mock_subprocess.return_value = mock_process

            process_info = await session_manager.start_session(session.id)

            assert process_info["pid"] == 12345
            assert process_info["session_id"] == session.id

            # The stored session must reflect the transition to RUNNING.
            updated_session = await session_manager.get_session(session.id)
            assert updated_session.status == SessionStatus.RUNNING
            assert updated_session.started_at is not None

    @pytest.mark.asyncio
    async def test_complete_session(self, session_manager):
        """Completing a running session sets COMPLETED, completed_at and metadata."""
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Simulate a started session directly instead of spawning a process.
        session.status = SessionStatus.RUNNING
        session.started_at = datetime.now(timezone.utc)
        await session_manager._update_session(session)

        await session_manager.complete_session(
            session.id,
            success=True,
            metadata={"tokens_used": 1500},
        )

        completed = await session_manager.get_session(session.id)
        assert completed.status == SessionStatus.COMPLETED
        assert completed.completed_at is not None
        assert completed.metadata["tokens_used"] == 1500

    @pytest.mark.asyncio
    async def test_fail_session(self, session_manager):
        """Failing a session sets FAILED, completed_at and records the error text."""
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        await session_manager.fail_session(
            session.id,
            error="Test error occurred",
        )

        failed = await session_manager.get_session(session.id)
        assert failed.status == SessionStatus.FAILED
        assert failed.completed_at is not None
        assert failed.metadata["error"] == "Test error occurred"

    @pytest.mark.asyncio
    async def test_cancel_session(self, session_manager):
        """Cancelling a session with a tracked process returns True and sets CANCELLED."""
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # asyncio.subprocess.Process.terminate()/kill() are plain (non-async)
        # methods. A bare AsyncMock would make them coroutine functions, which
        # produces "coroutine was never awaited" warnings and lets misuse such
        # as ``await process.terminate()`` pass unnoticed.
        mock_process = AsyncMock()
        mock_process.terminate = Mock()
        mock_process.kill = Mock()
        session_manager._processes[session.id] = {
            "process": mock_process,
            "pid": 12345,
        }

        result = await session_manager.cancel_session(session.id)
        # Identity check instead of ``== True`` (PEP 8, E712).
        assert result is True

        canceled = await session_manager.get_session(session.id)
        assert canceled.status == SessionStatus.CANCELLED
class TestSessionStreaming:
    """Test session JSONL streaming functionality."""

    @pytest.mark.asyncio
    async def test_stream_session_output(self, session_manager):
        """stream_output yields each stdout line of the tracked process, in order."""
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        messages = SessionFixtures.create_streaming_messages(count=5)

        # AsyncMock's side_effect must be a plain iterable (or a callable). An
        # async-generator object is neither, so readline() would raise
        # TypeError on its first call. Feed the encoded lines as a list and
        # finish with b'', which is what StreamReader.readline() returns at EOF.
        stdout_lines = [msg.encode() + b'\n' for msg in messages]
        stdout_lines.append(b'')

        mock_stdout = AsyncMock()
        mock_stdout.readline = AsyncMock(side_effect=stdout_lines)
        mock_process = AsyncMock()
        mock_process.stdout = mock_stdout

        session_manager._processes[session.id] = {
            "process": mock_process,
            "pid": 12345,
        }

        collected = []
        async for message in session_manager.stream_output(session.id):
            collected.append(message)
            if len(collected) >= len(messages):
                break

        assert len(collected) == len(messages)

        # The fixture's first and last messages frame the session.
        message_data = [json.loads(msg) for msg in collected]
        assert message_data[0]["type"] == "session_start"
        assert message_data[-1]["type"] == "session_complete"

    @pytest.mark.asyncio
    async def test_stream_with_errors(self, session_manager):
        """An error line on stdout is surfaced by stream_output as an error message."""
        session = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        error_msg = SessionFixtures.create_error_message("timeout")

        # readline() keeps returning the same error line; only the first is consumed.
        mock_stdout = AsyncMock()
        mock_stdout.readline = AsyncMock(return_value=error_msg.encode() + b'\n')
        mock_process = AsyncMock()
        mock_process.stdout = mock_stdout

        session_manager._processes[session.id] = {
            "process": mock_process,
            "pid": 12345,
        }

        first_message = None
        async for message in session_manager.stream_output(session.id):
            first_message = message
            break

        # Without this guard an empty stream would skip the checks below and
        # the test would pass without asserting anything.
        assert first_message is not None, "stream_output yielded no messages"

        data = json.loads(first_message)
        assert data["type"] == "error"
        assert data["error"] == "TimeoutError"
class TestSessionQuerying:
    """Exercise SessionManager listing, filtering and statistics."""

    @pytest.mark.asyncio
    async def test_list_sessions(self, session_manager):
        """All created sessions are listed, and a CREATED status filter returns them all."""
        created = [
            await session_manager.create_session(
                project_path=f"/test/project_{idx}",
                prompt=f"Test prompt {idx}",
            )
            for idx in range(5)
        ]

        listed = await session_manager.list_sessions()
        assert len(listed) == 5

        only_created = await session_manager.list_sessions(
            status=SessionStatus.CREATED
        )
        assert len(only_created) == 5

    @pytest.mark.asyncio
    async def test_list_sessions_by_project(self, session_manager):
        """Filtering by project_path returns only that project's sessions."""
        # Number of sessions to create for each project path.
        per_project = {"/project1": 3, "/project2": 2}

        for path, how_many in per_project.items():
            for idx in range(how_many):
                await session_manager.create_session(
                    project_path=path,
                    prompt=f"Test {idx}",
                )

        for path, how_many in per_project.items():
            listed = await session_manager.list_sessions(project_path=path)
            assert len(listed) == how_many

    @pytest.mark.asyncio
    async def test_get_session_stats(self, session_manager):
        """Stats report the total count plus status, model and duration breakdowns."""
        batch = SessionFixtures.create_batch_sessions(count=10)
        for item in batch:
            await session_manager._save_session(item)

        stats = await session_manager.get_session_stats()

        assert stats["total"] == 10
        for section in ("by_status", "by_model", "average_duration"):
            assert section in stats
class TestSessionCaching:
    """Exercise SessionManager's session cache behavior."""

    @pytest.mark.asyncio
    async def test_session_cache(self, session_manager):
        """Repeated lookups of the same session return consistent data."""
        original = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        # Intent: the first lookup loads the session and the second is served
        # from the cache. NOTE(review): nothing here observes whether the cache
        # was actually hit; only consistency between the two reads is checked.
        first, second = [
            await session_manager.get_session(original.id) for _ in range(2)
        ]

        assert first.id == second.id
        assert first.created_at == second.created_at

    @pytest.mark.asyncio
    async def test_cache_invalidation(self, session_manager):
        """A status change made after a cached read is visible on the next read."""
        original = await session_manager.create_session(
            project_path="/test/project",
            prompt="Test prompt",
        )

        before = await session_manager.get_session(original.id)
        assert before.status == SessionStatus.CREATED

        await session_manager.complete_session(original.id)

        after = await session_manager.get_session(original.id)
        assert after.status == SessionStatus.COMPLETED
class TestConcurrentSessions:
    """Test concurrent session handling."""

    @pytest.mark.asyncio
    async def test_concurrent_session_limit(self, session_manager):
        """With max_concurrent=3 and three RUNNING sessions, a fourth is rejected."""
        session_config = session_manager._config._config["session"]

        # patch.dict restores the original mapping when the block exits, so the
        # lowered limit cannot leak to anything else that reads this config
        # (the original code mutated it in place and never restored it).
        with patch.dict(session_config, {"max_concurrent": 3}):
            for i in range(3):
                session = await session_manager.create_session(
                    project_path=f"/test/project_{i}",
                    prompt=f"Test {i}",
                )
                session.status = SessionStatus.RUNNING
                await session_manager._update_session(session)

            # NOTE(review): it is not visible here whether create_session or
            # _update_session enforces the limit, so both stay inside raises().
            with pytest.raises(RuntimeError, match="concurrent session limit"):
                session = await session_manager.create_session(
                    project_path="/test/project_4",
                    prompt="Test 4",
                )
                session.status = SessionStatus.RUNNING
                await session_manager._update_session(session)

    @pytest.mark.asyncio
    async def test_session_cleanup_on_exit(self, session_manager):
        """stop() clears every tracked process."""
        for i in range(3):
            session = await session_manager.create_session(
                project_path=f"/test/project_{i}",
                prompt=f"Test {i}",
            )
            # asyncio.subprocess.Process.terminate()/kill() are synchronous;
            # keep them non-async so the mock matches the real API.
            mock_process = AsyncMock()
            mock_process.terminate = Mock()
            mock_process.kill = Mock()
            session_manager._processes[session.id] = {
                "process": mock_process,
                "pid": 12345 + i,
            }

        await session_manager.stop()

        assert not session_manager._processes