#!/usr/bin/env python3
"""
MCP Protocol Integration Test
This script implements a proper MCP client that:
1. Starts the MCP chain server in a subprocess
2. Follows the exact MCP protocol from mcp_protocol.md
3. Sends proper initialization sequence and tool calls
4. Validates responses according to the protocol
5. Checks logs to verify middleware logging works
6. Properly terminates the subprocess
"""
import os
import sys
import json
import time
import logging
import subprocess
import pytest
import select
import tempfile
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, Optional
import jsonschema
# Setup
project_root = Path(__file__).parent.parent.absolute()
log_dir = Path(tempfile.gettempdir()) / "mcp_chain_integration"
log_dir.mkdir(exist_ok=True, parents=True)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - MCP-CLIENT - %(levelname)s - %(message)s",
handlers=[logging.FileHandler(log_dir / "mcp_client.log"), logging.StreamHandler()],
)
logger = logging.getLogger("MCPProtocolTest")
# MCP Protocol Schemas for validation
TOOL_SCHEMA = {
"type": "object",
"required": ["name", "description"],
"properties": {
"name": {"type": "string"},
"description": {"type": "string"},
"inputSchema": {"type": "object"},
},
}
INITIALIZE_RESPONSE_SCHEMA = {
"type": "object",
"required": ["protocolVersion", "capabilities", "serverInfo"],
"properties": {
"protocolVersion": {"type": "string"},
"capabilities": {
"type": "object",
"properties": {
"tools": {"type": "object"},
"resources": {"type": "object"},
"prompts": {"type": "object"},
},
},
"serverInfo": {
"type": "object",
"required": ["name", "version"],
"properties": {"name": {"type": "string"}, "version": {"type": "string"}},
},
},
}
class MCPProtocolClient:
"""MCP client that follows the exact protocol specification."""
def __init__(self, server_process: subprocess.Popen):
self.process = server_process
self.message_id = 0
self.protocol_version = "2025-03-26"
def _next_id(self) -> int:
"""Get next message ID."""
self.message_id += 1
return self.message_id
def send_message(self, message: Dict[str, Any]) -> None:
"""Send a JSON-RPC message to the server via STDIN."""
message_json = json.dumps(message) + "\n"
logger.debug(f"Sending: {message_json.strip()}")
self.process.stdin.write(message_json)
self.process.stdin.flush()
def read_response(self, timeout: int = 5) -> Optional[Dict[str, Any]]:
"""Read a JSON-RPC response from the server via STDOUT with proper timeout."""
start_time = time.time()
while time.time() - start_time < timeout:
# Check if process is still alive
if self.process.poll() is not None:
logger.error("Server process died")
return None
# Use select to check if data is available with timeout
try:
ready, _, _ = select.select([self.process.stdout], [], [], 0.1)
if ready:
line = self.process.stdout.readline()
if line:
line = line.strip()
if line: # Only process non-empty lines
response = json.loads(line)
logger.debug(f"Received: {json.dumps(response, indent=2)}")
return response
except select.error as e:
logger.debug(f"Select error: {e}")
time.sleep(0.05)
continue
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON response: {line}")
logger.error(f"JSON error: {e}")
return None
except Exception as e:
logger.debug(f"Read error: {e}")
time.sleep(0.05)
continue
time.sleep(0.05) # Small sleep to prevent busy waiting
logger.warning(f"No response received within {timeout} seconds")
return None
def send_request_and_get_response(
self, method: str, params: Dict[str, Any] = None, timeout: int = 5
) -> Optional[Dict[str, Any]]:
"""Send a request and wait for response."""
request_id = self._next_id()
request = {"jsonrpc": "2.0", "id": request_id, "method": method}
if params:
request["params"] = params
self.send_message(request)
response = self.read_response(timeout)
# Validate response structure and ID
if response:
if not self._is_valid_jsonrpc_response(response):
logger.error(f"Invalid JSON-RPC response structure: {response}")
return None
if response.get("id") != request_id:
logger.error(
f"Response ID mismatch: expected {request_id}, got {response.get('id')}"
)
return None
return response
def send_notification(self, method: str, params: Dict[str, Any] = None) -> None:
"""Send a notification (no response expected)."""
notification = {"jsonrpc": "2.0", "method": method}
if params:
notification["params"] = params
self.send_message(notification)
def send_invalid_json(self) -> None:
"""Send invalid JSON to test error handling."""
try:
self.process.stdin.write("invalid json\n")
self.process.stdin.flush()
except BrokenPipeError:
logger.error("Server process stdin closed")
def _is_valid_jsonrpc_response(self, response: Dict[str, Any]) -> bool:
"""Validate JSON-RPC response format."""
if response.get("jsonrpc") != "2.0":
return False
if "id" not in response:
return False
if "result" not in response and "error" not in response:
return False
if "result" in response and "error" in response:
return False
return True
def start_mcp_server() -> subprocess.Popen:
"""Start the MCP chain server subprocess."""
logger.info("Starting MCP chain server...")
server_script = project_root / "examples" / "protocol_testing_server.py"
if not server_script.exists():
raise FileNotFoundError(f"Server script not found: {server_script}")
env = os.environ.copy()
env["PYTHONPATH"] = str(project_root / "src")
process = subprocess.Popen(
[sys.executable, str(server_script)],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
env=env,
cwd=str(project_root),
)
# Give server time to start
time.sleep(2)
if process.poll() is not None:
stdout, stderr = process.communicate()
logger.error("Server failed to start:")
logger.error(f"STDOUT: {stdout}")
logger.error(f"STDERR: {stderr}")
raise RuntimeError("Failed to start MCP server")
logger.info(f"MCP server started successfully (PID: {process.pid})")
return process
@pytest.fixture
def mcp_server():
"""Fixture to start and stop MCP server for tests."""
server_process = None
try:
server_process = start_mcp_server()
yield server_process
finally:
# Cleanup: kill the server subprocess
if server_process:
logger.info("Terminating MCP server subprocess...")
try:
server_process.terminate()
server_process.wait(timeout=2) # Reduced timeout
logger.info("Server terminated cleanly")
except subprocess.TimeoutExpired:
logger.warning("Server didn't terminate quickly, killing...")
server_process.kill()
try:
server_process.wait(timeout=1)
logger.info("Server killed")
except subprocess.TimeoutExpired:
logger.error("Server failed to die even after kill signal")
except Exception as e:
logger.error(f"Error terminating server: {e}")
# Force kill if cleanup fails
try:
server_process.kill()
except:
pass
@pytest.mark.integration
def test_mcp_protocol_initialization(mcp_server):
"""Test MCP protocol initialization sequence."""
client = MCPProtocolClient(mcp_server)
# Step 1: Send initialize request
init_response = client.send_request_and_get_response(
"initialize",
{
"protocolVersion": client.protocol_version,
"capabilities": {"roots": {"listChanged": False}},
"clientInfo": {"name": "mcp-protocol-test-client", "version": "1.0.0"},
},
timeout=10,
)
# Enhanced validation with schema
assert init_response is not None, "Initialize response should not be None"
assert "result" in init_response, "Initialize response should contain 'result'"
# Validate against schema
try:
jsonschema.validate(init_response["result"], INITIALIZE_RESPONSE_SCHEMA)
except jsonschema.ValidationError as e:
pytest.fail(f"Initialize response schema validation failed: {e}")
# Additional semantic validation
result = init_response["result"]
assert result["protocolVersion"] is not None, (
"Server should return protocol version"
)
assert isinstance(result["capabilities"], dict), "Capabilities should be an object"
# Step 2: Send initialized notification
client.send_notification("notifications/initialized")
time.sleep(0.5) # Give server time to process
@pytest.mark.integration
def test_mcp_protocol_tools_discovery(mcp_server):
"""Test MCP tools discovery."""
client = MCPProtocolClient(mcp_server)
# Initialize first
init_response = client.send_request_and_get_response(
"initialize",
{
"protocolVersion": client.protocol_version,
"capabilities": {"roots": {"listChanged": False}},
"clientInfo": {"name": "test-client", "version": "1.0.0"},
},
timeout=10,
)
assert init_response is not None
client.send_notification("notifications/initialized")
time.sleep(0.5)
# Discover tools
tools_response = client.send_request_and_get_response("tools/list", timeout=10)
assert tools_response is not None, "Tools response should not be None"
assert "result" in tools_response, "Tools response should contain 'result'"
assert "tools" in tools_response["result"], "Tools response should contain 'tools'"
tools = tools_response["result"]["tools"]
assert isinstance(tools, list), "Tools should be a list"
assert len(tools) > 0, "Should have at least one tool available"
# Enhanced tool validation with schema
for tool in tools:
try:
jsonschema.validate(tool, TOOL_SCHEMA)
except jsonschema.ValidationError as e:
pytest.fail(
f"Tool schema validation failed for {tool.get('name', 'unknown')}: {e}"
)
# Additional semantic validation
assert len(tool["name"]) > 0, "Tool name should not be empty"
assert len(tool["description"]) > 0, "Tool description should not be empty"
@pytest.mark.integration
def test_mcp_protocol_tool_execution(mcp_server):
"""Test MCP tool execution."""
client = MCPProtocolClient(mcp_server)
# Initialize and discover tools
init_response = client.send_request_and_get_response(
"initialize",
{
"protocolVersion": client.protocol_version,
"capabilities": {"roots": {"listChanged": False}},
"clientInfo": {"name": "test-client", "version": "1.0.0"},
},
timeout=10,
)
assert init_response is not None
client.send_notification("notifications/initialized")
time.sleep(0.5)
tools_response = client.send_request_and_get_response("tools/list", timeout=10)
assert tools_response is not None
assert len(tools_response["result"]["tools"]) > 0
# Execute the first available tool
tool_name = tools_response["result"]["tools"][0]["name"]
call_response = client.send_request_and_get_response(
"tools/call",
{"name": tool_name, "arguments": {"_args": ["Hello from MCP protocol test!"]}},
timeout=10,
)
assert call_response is not None, "Tool call response should not be None"
assert "result" in call_response, "Tool call response should contain 'result'"
assert "content" in call_response["result"], (
"Tool call response should contain 'content'"
)
content = call_response["result"]["content"]
assert len(content) > 0, "Tool call should return content"
@pytest.mark.integration
def test_mcp_protocol_middleware_logging(mcp_server):
"""Test that middleware logging is working."""
client = MCPProtocolClient(mcp_server)
# Initialize and make some requests to generate logs
init_response = client.send_request_and_get_response(
"initialize",
{
"protocolVersion": client.protocol_version,
"capabilities": {"roots": {"listChanged": False}},
"clientInfo": {"name": "test-client", "version": "1.0.0"},
},
timeout=10,
)
assert init_response is not None
client.send_notification("notifications/initialized")
time.sleep(0.5)
# Make a tools/list request to generate more log entries
tools_response = client.send_request_and_get_response("tools/list", timeout=10)
assert tools_response is not None
# Give some time for logs to be written
time.sleep(1)
# Check if log files were created
messages_log = log_dir / "mcp_messages.jsonl"
server_log = log_dir / "mcp_server.log"
assert messages_log.exists(), f"Messages log should exist at {messages_log}"
assert server_log.exists(), f"Server log should exist at {server_log}"
# Enhanced log content validation
with open(messages_log, "r") as f:
log_lines = f.readlines()
assert len(log_lines) > 0, "Should have log entries"
# Validate each log entry is valid JSON with required fields
for i, line in enumerate(log_lines):
try:
log_entry = json.loads(line.strip())
assert "direction" in log_entry, f"Log entry {i} should have direction"
assert "message" in log_entry, f"Log entry {i} should have message"
assert "timestamp" in log_entry, f"Log entry {i} should have timestamp"
assert log_entry["direction"] in ["request", "response"], (
f"Log entry {i} has invalid direction"
)
except json.JSONDecodeError:
pytest.fail(f"Log entry {i} is not valid JSON: {line}")
logged_requests = sum(1 for line in log_lines if '"direction": "request"' in line)
logged_responses = sum(1 for line in log_lines if '"direction": "response"' in line)
assert logged_requests >= 2, "Should have logged initialize and tools/list requests"
assert logged_responses >= 2, "Should have logged corresponding responses"
@pytest.mark.integration
@pytest.mark.timeout(30) # Add overall test timeout
def test_mcp_protocol_error_handling(mcp_server):
"""Test MCP protocol error handling scenarios."""
client = MCPProtocolClient(mcp_server)
# Initialize connection first with shorter timeout
init_response = client.send_request_and_get_response(
"initialize",
{
"protocolVersion": client.protocol_version,
"capabilities": {"roots": {"listChanged": False}},
"clientInfo": {"name": "test-client", "version": "1.0.0"},
},
timeout=5, # Reduced timeout
)
assert init_response is not None
client.send_notification("notifications/initialized")
time.sleep(0.2) # Reduced sleep time
# Test 1: Unknown method with shorter timeout
logger.info("Testing unknown method error handling...")
error_response = client.send_request_and_get_response("unknown/method", timeout=3)
# Be more lenient - server might not respond to unknown methods
if error_response is None:
logger.warning("Server did not respond to unknown method - this is acceptable")
else:
assert "error" in error_response, "Response should contain error"
error = error_response["error"]
expected_codes = [-32601, -32600] # Method not found or invalid request
assert error["code"] in expected_codes, (
f"Should return method not found or invalid request error code, got {error['code']}"
)
# Test 2: Invalid tool call (tool not found) with timeout
logger.info("Testing invalid tool call error handling...")
try:
error_response = client.send_request_and_get_response(
"tools/call", {"name": "nonexistent_tool", "arguments": {}}, timeout=3
)
if error_response and "error" in error_response:
error_msg = error_response["error"]["message"].lower()
assert (
"nonexistent" in error_msg
or "not found" in error_msg
or "unknown" in error_msg
)
except Exception as e:
logger.warning(f"Tool call error test failed gracefully: {e}")
# Test 3: Malformed request with timeout
logger.info("Testing malformed request error handling...")
try:
malformed_request = {
"jsonrpc": "2.0",
"id": client._next_id(),
"method": "tools/call",
# Missing required params
}
client.send_message(malformed_request)
response = client.read_response(timeout=3)
if response and "error" in response:
# Server should return an error for missing parameters
expected_codes = [
-32602,
-32600,
-32601,
] # Invalid params, invalid request, or method not found
assert response["error"]["code"] in expected_codes, (
f"Should return appropriate error code, got {response['error']['code']}"
)
except Exception as e:
logger.warning(f"Malformed request test failed gracefully: {e}")
@pytest.mark.integration
def test_mcp_protocol_complete_flow():
"""Test complete MCP protocol flow from start to finish."""
logger.info("=" * 60)
logger.info("MCP PROTOCOL COMPLETE INTEGRATION TEST")
logger.info("=" * 60)
test_results = []
server_process = None
try:
# Step 1: Start MCP server
server_process = start_mcp_server()
client = MCPProtocolClient(server_process)
# Step 2: Initialize connection
logger.info("Step 2: Sending initialize request...")
init_response = client.send_request_and_get_response(
"initialize",
{
"protocolVersion": client.protocol_version,
"capabilities": {"roots": {"listChanged": False}},
"clientInfo": {"name": "mcp-protocol-test-client", "version": "1.0.0"},
},
timeout=10,
)
init_success = (
init_response is not None
and "result" in init_response
and "protocolVersion" in init_response["result"]
)
test_results.append(
{"test": "initialize", "success": init_success, "response": init_response}
)
assert init_success, "Initialize step should succeed"
# Step 3: Send initialized notification
logger.info("Step 3: Sending initialized notification...")
client.send_notification("notifications/initialized")
time.sleep(0.5)
# Step 4: Discover tools
logger.info("Step 4: Discovering tools...")
tools_response = client.send_request_and_get_response("tools/list", timeout=10)
tools_success = (
tools_response is not None
and "result" in tools_response
and "tools" in tools_response["result"]
)
test_results.append(
{"test": "tools_list", "success": tools_success, "response": tools_response}
)
assert tools_success, "Tools list step should succeed"
# Step 5: Execute tool (if tools are available)
if tools_success and len(tools_response["result"]["tools"]) > 0:
logger.info("Step 5: Executing tool...")
tool_name = tools_response["result"]["tools"][0]["name"]
call_response = client.send_request_and_get_response(
"tools/call",
{
"name": tool_name,
"arguments": {"_args": ["Hello from MCP protocol test!"]},
},
timeout=10,
)
call_success = (
call_response is not None
and "result" in call_response
and "content" in call_response["result"]
)
test_results.append(
{
"test": "tool_call",
"success": call_success,
"response": call_response,
}
)
assert call_success, "Tool call step should succeed"
# Step 6: Verify middleware logging
logger.info("Step 6: Verifying middleware logging...")
time.sleep(1) # Give logs time to be written
messages_log = log_dir / "mcp_messages.jsonl"
server_log = log_dir / "mcp_server.log"
logging_success = messages_log.exists() and server_log.exists()
if logging_success:
with open(messages_log, "r") as f:
log_lines = f.readlines()
logged_requests = sum(
1 for line in log_lines if '"direction": "request"' in line
)
logged_responses = sum(
1 for line in log_lines if '"direction": "response"' in line
)
logging_success = logged_requests > 0 and logged_responses > 0
test_results.append({"test": "middleware_logging", "success": logging_success})
assert logging_success, "Middleware logging should work"
# All tests passed
passed_tests = sum(1 for r in test_results if r["success"])
total_tests = len(test_results)
logger.info(f"All {total_tests} integration tests passed!")
# Save results
results_file = (
log_dir
/ f"mcp_protocol_test_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
)
with open(results_file, "w") as f:
json.dump(
{
"test_summary": {
"timestamp": datetime.now().isoformat(),
"total_tests": total_tests,
"passed_tests": passed_tests,
"success_rate": f"{100.0:.1f}%",
"overall_success": True,
},
"test_details": test_results,
},
f,
indent=2,
)
logger.info(f"Results saved to: {results_file}")
finally:
# Cleanup: kill the server subprocess
if server_process:
logger.info("Terminating MCP server subprocess...")
try:
server_process.terminate()
server_process.wait(timeout=3)
logger.info("Server terminated cleanly")
except subprocess.TimeoutExpired:
logger.warning("Server didn't terminate, killing...")
server_process.kill()
server_process.wait()
logger.info("Server killed")
except Exception as e:
logger.error(f"Error terminating server: {e}")
def run_mcp_protocol_test():
"""Run the MCP protocol integration test."""
return test_mcp_protocol_complete_flow()
def main():
"""Main entry point - redirects to pytest."""
print("๐งช MCP Protocol Integration Tests")
print("=" * 50)
print("This file has been converted to pytest.")
print("To run the integration tests, use:")
print()
print(" # Run all integration tests:")
print(" pytest tests/test_mcp_protocol_integration.py -v -m integration")
print()
print(" # Run specific test:")
print(
" pytest tests/test_mcp_protocol_integration.py::test_mcp_protocol_complete_flow -v"
)
print()
print(" # Run with verbose output:")
print(" pytest tests/test_mcp_protocol_integration.py -v -s")
print()
print(f"๐ Logs will be in: {log_dir}")
return 0
if __name__ == "__main__":
sys.exit(main())