# test_concurrent_sessions.py
"""
Concurrent session tests for SDK server (v2).
Tests async request handling and concurrent session management.
"""
import asyncio
import json
import sys
import tempfile
import textwrap
from pathlib import Path

import pytest

from mcp_debug_tool.server import MCPServerV2
@pytest.fixture
def workspace_with_scripts():
    """Yield a temporary workspace containing ``script_0.py`` .. ``script_2.py``.

    Script ``i`` defines ``func_i(x)`` returning ``x * (i + 1)`` and a ``main()``
    that calls ``func_i(10)``; it calls ``main()`` when run as ``__main__``.
    The directory is deleted when the ``with`` block exits at fixture teardown.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        workspace = Path(tmpdir)
        for i in range(3):
            # dedent so the generated file is valid Python no matter how deeply
            # this literal happens to be indented in the test source.
            script_content = textwrap.dedent(f'''
                def func_{i}(x):
                    y = x * {i + 1}
                    return y


                def main():
                    result = func_{i}(10)
                    return result


                if __name__ == "__main__":
                    main()
            ''')
            script_path = workspace / f"script_{i}.py"
            script_path.write_text(script_content, encoding="utf-8")
        yield workspace
@pytest.fixture
def sdk_server(workspace_with_scripts):
    """Provide an MCPServerV2 whose workspace root is the temporary script directory."""
    server = MCPServerV2(workspace_root=workspace_with_scripts)
    return server
class TestConcurrentSessions:
    """Test concurrent session operations"""

    @staticmethod
    async def _create_sessions(sdk_server, entries):
        """Create one session per entry script, all concurrently.

        Returns ``(responses, session_ids)``: every decoded create response, and
        the ``sessionId`` of each response that contained one.
        """
        results = await asyncio.gather(*(
            sdk_server._handle_sessions_create({
                "entry": entry,
                "pythonPath": sys.executable,
            })
            for entry in entries
        ))
        responses = [json.loads(r[0].text) for r in results]
        session_ids = [resp["sessionId"] for resp in responses if "sessionId" in resp]
        return responses, session_ids

    @staticmethod
    async def _end_sessions(sdk_server, session_ids, *, return_exceptions=False):
        """End every session in ``session_ids`` concurrently and return the raw results.

        With ``return_exceptions=True`` a failing end call is returned as an
        exception object instead of propagating (used for best-effort cleanup).
        """
        return await asyncio.gather(
            *(
                sdk_server._handle_sessions_end({"sessionId": sid})
                for sid in session_ids
            ),
            return_exceptions=return_exceptions,
        )

    @pytest.mark.asyncio
    async def test_concurrent_create_and_state(self, sdk_server):
        """Test creating and querying multiple sessions concurrently"""
        # Create 5 sessions concurrently, cycling over the 3 workspace scripts.
        responses, session_ids = await self._create_sessions(
            sdk_server, [f"script_{i % 3}.py" for i in range(5)]
        )
        try:
            # Original tolerance kept: at most one create may fail.
            assert len(session_ids) >= 4, f"too few sessions created: {responses}"

            # Query all sessions concurrently.
            state_results = await asyncio.gather(*(
                sdk_server._handle_sessions_state({"sessionId": sid})
                for sid in session_ids
            ))
            assert len(state_results) == len(session_ids)
            for result in state_results:
                response = json.loads(result[0].text)
                assert "status" in response, f"state response lacks status: {response}"
        finally:
            # Nothing else ends these sessions; close them so they do not outlive
            # the test. Exceptions are collected, not raised, so a cleanup failure
            # cannot mask the assertion that actually failed.
            await self._end_sessions(sdk_server, session_ids, return_exceptions=True)

    @pytest.mark.asyncio
    async def test_concurrent_session_cleanup(self, sdk_server):
        """Test creating and ending multiple sessions concurrently"""
        responses, session_ids = await self._create_sessions(
            sdk_server, ["script_0.py"] * 3
        )
        # Without this, the checks below pass vacuously when every create fails.
        assert session_ids, f"no sessions created: {responses}"

        # End all sessions concurrently.
        end_results = await self._end_sessions(sdk_server, session_ids)
        assert len(end_results) == len(session_ids)
        for result in end_results:
            response = json.loads(result[0].text)
            assert response.get("ended") is True, f"session not ended: {response}"
if __name__ == "__main__":
    # pytest.main returns the session exit code; propagate it so running this
    # file directly reports failures to the shell instead of always exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))