#!/usr/bin/env python3
"""
Test MCP stdio server concurrent access and parallel function handling
Tests stdio transport with multiple concurrent requests to verify thread safety
"""
import asyncio
import json
import subprocess
import sys
import time
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
class MCPStdioConcurrentTester:
    """Test MCP server stdio transport with concurrent operations.

    All worker threads share a single server subprocess and a single
    stdio channel. Because stdio is one ordered stream, every
    request/response exchange (and the request-id counter) is guarded by
    one lock — without it, threads could interleave partial writes or
    read each other's responses off the shared pipe.
    """

    def __init__(self):
        self.server_process = None      # subprocess.Popen handle once started
        self.test_results = []          # list of (test_name, passed, message)
        self.request_id_counter = 1     # id 1 is reserved for "initialize"
        # Guards the id counter and all stdin/stdout pipe access.
        self._lock = threading.Lock()

    async def start_mcp_server(self):
        """Start the MCP server subprocess in stdio mode."""
        print("Starting MCP stdio server for concurrent testing...")
        # Start the hybrid server in stdio mode, pipes attached for JSON-RPC.
        self.server_process = subprocess.Popen(
            [sys.executable, "mcp_server_hybrid.py"],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=Path(__file__).parent,
        )
        # Give it a moment to start before the first request.
        await asyncio.sleep(1)
        print("MCP stdio server started")

    def send_request_sync(self, request: dict):
        """Send a JSON-RPC request and read one response line (thread-safe).

        The write+read pair runs under the shared lock so each thread
        receives the response to its own request (the server answers the
        stream in order).

        Returns the decoded response dict, ``None`` on EOF, or an
        ``{"error": ...}`` dict on transport/decode failure.

        Raises:
            RuntimeError: if the server has not been started.
        """
        if not self.server_process:
            raise RuntimeError("MCP server not started")
        try:
            with self._lock:
                # Send the request as a single newline-terminated line.
                request_line = json.dumps(request) + '\n'
                self.server_process.stdin.write(request_line)
                self.server_process.stdin.flush()
                # Read the matching response while still holding the lock.
                response_line = self.server_process.stdout.readline()
            if not response_line:
                return None
            return json.loads(response_line.strip())
        except Exception as e:
            print(f"Error in send_request_sync: {e}")
            return {"error": str(e)}

    def get_next_request_id(self):
        """Get next request ID in thread-safe manner."""
        # += on an attribute is a read-modify-write; lock it so concurrent
        # callers never receive duplicate ids.
        with self._lock:
            self.request_id_counter += 1
            return self.request_id_counter

    def test_concurrent_resource_access(self):
        """Run several resources/read requests in parallel and record results."""
        print("\n[CONCURRENT] Testing concurrent resource access...")
        # First initialize the server (MCP handshake, reserved id 1).
        init_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "resources": {"subscribe": False},
                    "tools": {},
                    "prompts": {}
                },
                "clientInfo": {
                    "name": "mcp-stdio-concurrent-test",
                    "version": "1.0.0"
                }
            }
        }
        init_response = self.send_request_sync(init_request)
        if not init_response or "result" not in init_response:
            print("[FAIL] Server initialization failed for concurrent test")
            return
        # Send the "initialized" notification; it produces no response, so
        # write it directly — but still under the pipe lock.
        initialized_request = {
            "jsonrpc": "2.0",
            "method": "notifications/initialized"
        }
        with self._lock:
            self.server_process.stdin.write(json.dumps(initialized_request) + '\n')
            self.server_process.stdin.flush()

        def make_resource_request(resource_uri, thread_id):
            # Worker body: one timed resources/read round-trip.
            request_id = self.get_next_request_id()
            request = {
                "jsonrpc": "2.0",
                "id": request_id,
                "method": "resources/read",
                "params": {
                    "uri": resource_uri
                }
            }
            start_time = time.time()
            response = self.send_request_sync(request)
            end_time = time.time()
            return {
                "thread_id": thread_id,
                "request_id": request_id,
                "resource_uri": resource_uri,
                "response": response,
                "duration": end_time - start_time,
                # bool() so a None response yields False, not None.
                "success": bool(response) and "result" in response
            }

        # Define test resources
        test_resources = [
            "database://tables",
            "database://schemas",
            "database://statistics",
            "database://pgconfig",
            "table://db3.public.accounts"
        ]
        # Execute concurrent requests and collect results as they finish.
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(make_resource_request, resource_uri, i + 1)
                for i, resource_uri in enumerate(test_resources)
            ]
            concurrent_results = []
            for future in as_completed(futures):
                try:
                    concurrent_results.append(future.result(timeout=10))
                except Exception as e:
                    print(f"[ERROR] Concurrent request failed: {e}")
        # Analyze results
        successful_requests = [r for r in concurrent_results if r["success"]]
        failed_requests = [r for r in concurrent_results if not r["success"]]
        print(f"[RESULT] Concurrent resource access:")
        print(f" Successful: {len(successful_requests)}/{len(test_resources)}")
        print(f" Failed: {len(failed_requests)}")
        # Guard against ZeroDivisionError when every request failed.
        if successful_requests:
            print(f" Average duration: {sum(r['duration'] for r in successful_requests)/len(successful_requests):.3f}s")
        if len(successful_requests) == len(test_resources):
            print(" [OK] All concurrent resource requests succeeded")
            self.test_results.append(("Concurrent Resources", True, f"All {len(test_resources)} concurrent requests succeeded"))
        else:
            print(" [FAIL] Some concurrent resource requests failed")
            self.test_results.append(("Concurrent Resources", False, f"{len(failed_requests)} out of {len(test_resources)} requests failed"))

    def test_concurrent_tool_calls(self):
        """Run several tools/call requests in parallel and record results."""
        print("\n[CONCURRENT] Testing concurrent tool calls...")

        def make_tool_request(tool_name, arguments, thread_id):
            # Worker body: one timed tools/call round-trip.
            request_id = self.get_next_request_id()
            request = {
                "jsonrpc": "2.0",
                "id": request_id,
                "method": "tools/call",
                "params": {
                    "name": tool_name,
                    "arguments": arguments
                }
            }
            start_time = time.time()
            response = self.send_request_sync(request)
            end_time = time.time()
            return {
                "thread_id": thread_id,
                "request_id": request_id,
                "tool_name": tool_name,
                "response": response,
                "duration": end_time - start_time,
                "success": bool(response) and "result" in response
            }

        # Define test tool calls (using safe, read-only tools).
        test_tool_calls = [
            ("validate_sql_syntax", {"sql_query": "SELECT 1"}),
            ("validate_sql_syntax", {"sql_query": "SELECT * FROM information_schema.tables LIMIT 1"}),
            ("validate_sql_syntax", {"sql_query": "SELECT COUNT(*) FROM pg_tables"}),
            ("get_table_info", {"table_name": "accounts", "database": "db3"}),
            ("get_table_info", {"table_name": "users", "database": "db3"})
        ]
        # Execute concurrent tool calls.
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(make_tool_request, tool_name, arguments, i + 1)
                for i, (tool_name, arguments) in enumerate(test_tool_calls)
            ]
            concurrent_results = []
            for future in as_completed(futures):
                try:
                    concurrent_results.append(future.result(timeout=15))
                except Exception as e:
                    print(f"[ERROR] Concurrent tool call failed: {e}")
        # Analyze results
        successful_calls = [r for r in concurrent_results if r["success"]]
        failed_calls = [r for r in concurrent_results if not r["success"]]
        print(f"[RESULT] Concurrent tool calls:")
        print(f" Successful: {len(successful_calls)}/{len(test_tool_calls)}")
        print(f" Failed: {len(failed_calls)}")
        if successful_calls:
            print(f" Average duration: {sum(r['duration'] for r in successful_calls)/len(successful_calls):.3f}s")
        if len(successful_calls) == len(test_tool_calls):
            print(" [OK] All concurrent tool calls succeeded")
            self.test_results.append(("Concurrent Tools", True, f"All {len(test_tool_calls)} concurrent tool calls succeeded"))
        else:
            print(" [FAIL] Some concurrent tool calls failed")
            self.test_results.append(("Concurrent Tools", False, f"{len(failed_calls)} out of {len(test_tool_calls)} tool calls failed"))

    def test_mixed_concurrent_operations(self):
        """Run resources, tools, and prompts requests concurrently, mixed."""
        print("\n[CONCURRENT] Testing mixed concurrent operations...")

        def make_mixed_request(operation_type, params, thread_id):
            # Worker body: build the request for the given operation type,
            # then do one timed round-trip.
            request_id = self.get_next_request_id()
            if operation_type == "resource":
                request = {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "method": "resources/read",
                    "params": {"uri": params["uri"]}
                }
            elif operation_type == "tool":
                request = {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "method": "tools/call",
                    "params": {
                        "name": params["name"],
                        "arguments": params["arguments"]
                    }
                }
            elif operation_type == "prompt":
                request = {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "method": "prompts/get",
                    "params": {
                        "name": params["name"],
                        "arguments": params.get("arguments", {})
                    }
                }
            else:
                # Previously this fell through and raised NameError on
                # `request`; fail explicitly instead.
                raise ValueError(f"Unknown operation type: {operation_type}")
            start_time = time.time()
            response = self.send_request_sync(request)
            end_time = time.time()
            return {
                "thread_id": thread_id,
                "request_id": request_id,
                "operation_type": operation_type,
                "response": response,
                "duration": end_time - start_time,
                "success": bool(response) and "result" in response
            }

        # Define mixed operations
        mixed_operations = [
            ("resource", {"uri": "database://tables"}),
            ("tool", {"name": "validate_sql_syntax", "arguments": {"sql_query": "SELECT 1"}}),
            ("prompt", {"name": "check_database_health", "arguments": {"database": "db3"}}),
            ("resource", {"uri": "database://statistics"}),
            ("tool", {"name": "get_table_info", "arguments": {"table_name": "accounts", "database": "db3"}}),
            ("prompt", {"name": "analyze_database_schema", "arguments": {"database": "db3"}}),
        ]
        # Execute mixed concurrent operations.
        with ThreadPoolExecutor(max_workers=6) as executor:
            futures = [
                executor.submit(make_mixed_request, op_type, params, i + 1)
                for i, (op_type, params) in enumerate(mixed_operations)
            ]
            concurrent_results = []
            for future in as_completed(futures):
                try:
                    concurrent_results.append(future.result(timeout=20))
                except Exception as e:
                    print(f"[ERROR] Mixed concurrent operation failed: {e}")
        # Analyze results by type.
        resource_results = [r for r in concurrent_results if r["operation_type"] == "resource"]
        tool_results = [r for r in concurrent_results if r["operation_type"] == "tool"]
        prompt_results = [r for r in concurrent_results if r["operation_type"] == "prompt"]
        successful_resources = [r for r in resource_results if r["success"]]
        successful_tools = [r for r in tool_results if r["success"]]
        successful_prompts = [r for r in prompt_results if r["success"]]
        total_successful = len(successful_resources) + len(successful_tools) + len(successful_prompts)
        total_operations = len(mixed_operations)
        print(f"[RESULT] Mixed concurrent operations:")
        print(f" Resources: {len(successful_resources)}/{len(resource_results)} successful")
        print(f" Tools: {len(successful_tools)}/{len(tool_results)} successful")
        print(f" Prompts: {len(successful_prompts)}/{len(prompt_results)} successful")
        print(f" Overall: {total_successful}/{total_operations} successful")
        if total_successful == total_operations:
            print(" [OK] All mixed concurrent operations succeeded")
            self.test_results.append(("Mixed Concurrent Ops", True, f"All {total_operations} mixed operations succeeded"))
        else:
            print(" [FAIL] Some mixed concurrent operations failed")
            self.test_results.append(("Mixed Concurrent Ops", False, f"{total_operations - total_successful} out of {total_operations} operations failed"))

    def cleanup(self):
        """Terminate the MCP server process, escalating to kill on timeout."""
        if self.server_process:
            print("\n[CLEANUP] Cleaning up MCP server...")
            self.server_process.terminate()
            try:
                self.server_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Graceful shutdown failed; force it.
                self.server_process.kill()
                self.server_process.wait()
            print("[OK] MCP server stopped")

    def print_summary(self):
        """Print a pass/fail summary of all recorded test results."""
        print("\n" + "="*70)
        print("[SUMMARY] MCP STDIO CONCURRENT TEST SUMMARY")
        print("="*70)
        passed = sum(1 for _, success, _ in self.test_results if success)
        total = len(self.test_results)
        print(f"Tests passed: {passed}/{total}")
        print()
        for test_name, success, message in self.test_results:
            status = "[PASS]" if success else "[FAIL]"
            print(f"{status:10} {test_name:<20} {message}")
        print("\n" + "="*70)
        if passed == total:
            print("[SUCCESS] ALL CONCURRENT TESTS PASSED! MCP stdio server handles parallel operations correctly.")
        else:
            print(f"[WARNING] {total - passed} concurrent test(s) failed. Check stdio thread safety.")
        print("="*70)
async def main():
    """Entry point: launch the stdio server, run the concurrent test
    suite, then always clean up and print the summary."""
    print("MCP STDIO CONCURRENT ACCESS TEST")
    print("Testing parallel function handling via stdio transport")
    print("=" * 70)
    tester = MCPStdioConcurrentTester()
    try:
        # Bring up the server, then drive the three concurrent scenarios.
        await tester.start_mcp_server()
        for scenario in (
            tester.test_concurrent_resource_access,
            tester.test_concurrent_tool_calls,
            tester.test_mixed_concurrent_operations,
        ):
            scenario()
    except KeyboardInterrupt:
        print("\n[INTERRUPT] Test interrupted by user")
    except Exception as e:
        print(f"\n[ERROR] Test runner error: {e}")
    finally:
        # Always stop the subprocess and report, even on failure.
        tester.cleanup()
        tester.print_summary()
# Run the async test suite only when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())