# test_sdk_server.py (6.68 kB)
"""
Integration tests for MCP SDK-based server.
Exercises the session tool handlers (sessions_create, sessions_state and
sessions_end) of the SDK server implementation.
These tests call the MCPServerV2 handlers directly, bypassing the MCP transport.
"""
import asyncio
import json
import tempfile
from pathlib import Path
import pytest
from mcp_debug_tool.server import MCPServerV2
# Debuggee program written into each temporary workspace: a module-level
# ``main`` that calls a small helper, kept as valid, properly indented Python
# so the server can actually load and run it.
_TEST_SCRIPT = '''
def add(a, b):
    result = a + b
    return result
def main():
    x = 10
    y = 20
    z = add(x, y)
    return z
if __name__ == "__main__":
    main()
'''


@pytest.fixture
def workspace_with_script():
    """Create a temporary workspace containing ``test_script.py``.

    Yields:
        Path of the temporary directory. The directory and its contents
        are removed when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        workspace = Path(tmpdir)
        script_path = workspace / "test_script.py"
        # Explicit encoding so the file does not depend on the platform locale.
        script_path.write_text(_TEST_SCRIPT, encoding="utf-8")
        yield workspace
@pytest.fixture
def sdk_server(workspace_with_script):
    """Provide an ``MCPServerV2`` rooted at the temporary test workspace."""
    server = MCPServerV2(workspace_root=workspace_with_script)
    return server
class TestSDKServerTools:
    """Test the session tool handlers of the SDK server.

    Only ``sessions_create``, ``sessions_state`` and ``sessions_end`` are
    exercised here; each handler is awaited directly on ``MCPServerV2``.
    Sessions a test creates are ended before it returns so they do not
    outlive the temporary workspace.
    """

    @staticmethod
    async def _create_session(server):
        """Create a session for ``test_script.py`` and return its ID."""
        result = await server._handle_sessions_create({
            "entry": "test_script.py"
        })
        response = json.loads(result[0].text)
        # Show the server's error payload instead of a bare KeyError.
        assert "sessionId" in response, response
        return response["sessionId"]

    @staticmethod
    async def _end_session(server, session_id):
        """End ``session_id``; used as per-test cleanup."""
        await server._handle_sessions_end({"sessionId": session_id})

    @pytest.mark.asyncio
    async def test_sessions_create_tool(self, sdk_server):
        """Test sessions_create tool creates a session"""
        result = await sdk_server._handle_sessions_create({
            "entry": "test_script.py",
        })
        assert len(result) == 1
        response = json.loads(result[0].text)
        assert "sessionId" in response
        await self._end_session(sdk_server, response["sessionId"])

    @pytest.mark.asyncio
    async def test_sessions_create_invalid_file(self, sdk_server):
        """Test sessions_create with invalid file returns error"""
        result = await sdk_server._handle_sessions_create({
            "entry": "nonexistent.py",
        })
        assert len(result) == 1
        response = json.loads(result[0].text)
        assert "error" in response

    @pytest.mark.asyncio
    async def test_sessions_state_tool(self, sdk_server):
        """Test sessions_state tool returns session state"""
        session_id = await self._create_session(sdk_server)
        try:
            state_result = await sdk_server._handle_sessions_state({
                "sessionId": session_id
            })
            assert len(state_result) == 1
            state = json.loads(state_result[0].text)
            assert "status" in state
        finally:
            await self._end_session(sdk_server, session_id)

    @pytest.mark.asyncio
    async def test_sessions_state_invalid_session(self, sdk_server):
        """Test sessions_state with invalid session ID returns error"""
        result = await sdk_server._handle_sessions_state({
            "sessionId": "invalid-id"
        })
        assert len(result) == 1
        response = json.loads(result[0].text)
        assert "error" in response

    @pytest.mark.asyncio
    async def test_sessions_end_tool(self, sdk_server):
        """Test sessions_end tool terminates session"""
        session_id = await self._create_session(sdk_server)
        # The end call under test doubles as this test's cleanup.
        end_result = await sdk_server._handle_sessions_end({
            "sessionId": session_id
        })
        assert len(end_result) == 1
        response = json.loads(end_result[0].text)
        assert response.get("ended") is True

    @pytest.mark.asyncio
    async def test_tool_response_format(self, sdk_server):
        """Test tools return consistent response format"""
        result = await sdk_server._handle_sessions_create({
            "entry": "test_script.py"
        })
        assert len(result) == 1
        assert result[0].type == "text"
        assert isinstance(result[0].text, str)
        response = json.loads(result[0].text)
        assert isinstance(response, dict)
        if "sessionId" in response:
            await self._end_session(sdk_server, response["sessionId"])
class TestSDKServerErrorHandling:
    """Malformed ``sessions_create`` arguments must produce an error payload."""

    @staticmethod
    def _assert_single_error(contents):
        """Check the handler returned exactly one item whose JSON has ``error``."""
        assert len(contents) == 1
        payload = json.loads(contents[0].text)
        assert "error" in payload

    @pytest.mark.asyncio
    async def test_missing_entry_argument(self, sdk_server):
        """An empty argument dict (no ``entry``) is rejected."""
        contents = await sdk_server._handle_sessions_create({})
        self._assert_single_error(contents)

    @pytest.mark.asyncio
    async def test_invalid_args_type(self, sdk_server):
        """A non-list ``args`` value is rejected."""
        bad_arguments = {"entry": "test_script.py", "args": "not_a_list"}
        contents = await sdk_server._handle_sessions_create(bad_arguments)
        self._assert_single_error(contents)
class TestSDKServerConcurrency:
    """Test concurrent request handling.

    Handlers are scheduled together with ``asyncio.gather``. Every session a
    test creates is ended in a ``finally`` block so it does not outlive the
    temporary workspace.
    """

    @pytest.mark.asyncio
    async def test_concurrent_session_creation(self, sdk_server):
        """Test creating multiple sessions concurrently"""
        tasks = [
            sdk_server._handle_sessions_create({
                "entry": "test_script.py",
            })
            for _ in range(3)
        ]
        results = await asyncio.gather(*tasks)
        assert len(results) == 3
        assert all(len(result) == 1 for result in results)
        responses = [json.loads(result[0].text) for result in results]
        session_ids = [r["sessionId"] for r in responses if "sessionId" in r]
        try:
            # Include the raw payloads so a failed creation shows its error.
            assert len(session_ids) == 3, responses
            assert len(set(session_ids)) == 3  # All unique
        finally:
            # set() avoids ending the same session twice if IDs collided.
            for session_id in set(session_ids):
                await sdk_server._handle_sessions_end({"sessionId": session_id})

    @pytest.mark.asyncio
    async def test_concurrent_state_queries(self, sdk_server):
        """Test querying one session's state several times concurrently"""
        create_result = await sdk_server._handle_sessions_create({
            "entry": "test_script.py"
        })
        create_response = json.loads(create_result[0].text)
        assert "sessionId" in create_response, create_response
        session_id = create_response["sessionId"]
        try:
            # Query the same session's state 3 times concurrently.
            tasks = [
                sdk_server._handle_sessions_state({
                    "sessionId": session_id,
                })
                for _ in range(3)
            ]
            results = await asyncio.gather(*tasks)
            assert len(results) == 3
            for result in results:
                assert len(result) == 1
                response = json.loads(result[0].text)
                assert "status" in response
        finally:
            await sdk_server._handle_sessions_end({"sessionId": session_id})
if __name__ == "__main__":
    # pytest.main returns the session exit code. Propagate it so running this
    # file as a script reports test failures to the shell/CI instead of
    # always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))