"""Tests for MCP server implementation."""
import pytest
from unittest.mock import AsyncMock, Mock, patch
from fastapi.testclient import TestClient
from src.server import MCPServer, ToolSchema, ToolResponse, MCPServerRequest, MCPServerResponse
from src.utils.errors import APIError, ErrorCode, MCPError
class TestMCPServer:
    """Test cases for the MCPServer class: construction and tool registration."""

    def test_mcp_server_initialization(self, mock_settings):
        """A new server keeps its settings, builds the FastAPI app and has no tools."""
        server = MCPServer(mock_settings)

        assert server.settings == mock_settings
        assert server.app is not None
        assert server.app.title == "API Aggregator MCP Server"
        # Both registries start empty; tools are only added via register_tool().
        assert not server._tools
        assert not server._handlers

    def test_mcp_server_get_app(self, mcp_server):
        """get_app() returns the very same FastAPI instance stored on the server."""
        app = mcp_server.get_app()

        assert app is mcp_server.app

    def test_register_tool(self, mcp_server):
        """register_tool() records both the tool schema and its handler under the name."""
        async def dummy_handler(params):
            return {"result": "test"}

        schema = {
            "type": "object",
            "properties": {
                "param1": {"type": "string"},
            },
        }

        mcp_server.register_tool(
            name="test_tool",
            description="Test tool description",
            input_schema=schema,
            handler=dummy_handler,
        )

        assert "test_tool" in mcp_server._tools
        assert "test_tool" in mcp_server._handlers

        tool_schema = mcp_server._tools["test_tool"]
        assert tool_schema.name == "test_tool"
        assert tool_schema.description == "Test tool description"
        assert tool_schema.input_schema == schema
        # The exact handler object must be stored, not a copy or wrapper.
        assert mcp_server._handlers["test_tool"] is dummy_handler

    def test_register_multiple_tools(self, mcp_server):
        """Registering several tools keeps each name bound to its own schema and handler."""
        async def handler1(params):
            return {"result": "1"}

        async def handler2(params):
            return {"result": "2"}

        mcp_server.register_tool("tool1", "Description 1", {}, handler1)
        mcp_server.register_tool("tool2", "Description 2", {}, handler2)

        assert len(mcp_server._tools) == 2
        assert len(mcp_server._handlers) == 2
        assert "tool1" in mcp_server._tools
        assert "tool2" in mcp_server._tools
        # Guard against registrations overwriting or swapping each other.
        assert mcp_server._tools["tool1"].description == "Description 1"
        assert mcp_server._tools["tool2"].description == "Description 2"
        assert mcp_server._handlers["tool1"] is handler1
        assert mcp_server._handlers["tool2"] is handler2
class TestMCPServerRoutes:
    """Test cases for the HTTP routes exposed by MCPServer.

    Every test here is a plain synchronous function. Starlette's ``TestClient``
    drives the ASGI app itself, so ``async`` tool handlers are still awaited by
    the server while the tests stay synchronous. FastAPI's documentation notes
    that ``TestClient`` is not meant to be used from inside ``async def`` tests,
    so no ``pytest.mark.asyncio`` is needed or wanted here.

    Error cases are asserted with HTTP status 200: these routes report failures
    in the response body (``success``/``error`` fields or a JSON-RPC ``error``),
    not through the HTTP status code.
    """

    @pytest.fixture
    def client(self, mcp_server):
        """Create a test client for the MCP server."""
        return TestClient(mcp_server.app)

    @staticmethod
    def _jsonrpc_request(method, request_id, params=None):
        """Build a JSON-RPC 2.0 request body; ``params`` is omitted when None."""
        body = {"jsonrpc": "2.0", "method": method, "id": request_id}
        if params is not None:
            body["params"] = params
        return body

    @staticmethod
    def _assert_jsonrpc_envelope(data, request_id):
        """Assert that ``data`` is a JSON-RPC 2.0 reply echoing ``request_id``."""
        assert data["jsonrpc"] == "2.0"
        assert data["id"] == request_id

    def test_root_endpoint(self, client):
        """The root health-check endpoint reports a healthy server."""
        response = client.get("/")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
        assert data["server"] == "API Aggregator MCP"

    def test_list_tools_empty(self, client):
        """Listing tools with nothing registered returns an empty list."""
        response = client.get("/tools")

        assert response.status_code == 200
        assert response.json() == []

    def test_list_tools_with_registered_tools(self, mcp_server):
        """Listing tools returns one entry per registered tool."""
        async def dummy_handler(params):
            return {}

        mcp_server.register_tool("tool1", "Description 1", {"type": "object"}, dummy_handler)
        mcp_server.register_tool("tool2", "Description 2", {"type": "object"}, dummy_handler)
        client = TestClient(mcp_server.app)

        response = client.get("/tools")

        assert response.status_code == 200
        data = response.json()
        assert len(data) == 2
        tool_names = [tool["name"] for tool in data]
        assert "tool1" in tool_names
        assert "tool2" in tool_names

    def test_invoke_tool_success(self, mcp_server):
        """The request body is passed to the handler and its result is returned as data."""
        async def greet_handler(params):
            return {"message": f"Hello {params.get('name', 'World')}"}

        mcp_server.register_tool("test_tool", "Test tool", {}, greet_handler)
        client = TestClient(mcp_server.app)

        response = client.post("/tools/test_tool", json={"name": "Alice"})

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["data"]["message"] == "Hello Alice"
        assert data["error"] is None

    def test_invoke_nonexistent_tool(self, client):
        """Invoking an unknown tool yields a METHOD_NOT_FOUND error payload."""
        response = client.post("/tools/nonexistent_tool", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is False
        assert data["error"]["code"] == ErrorCode.METHOD_NOT_FOUND.value
        assert "Tool 'nonexistent_tool' not found" in data["error"]["message"]

    def test_invoke_tool_handler_error(self, mcp_server):
        """An APIError raised by a handler is surfaced with its own code and message."""
        async def error_handler(params):
            raise APIError("Handler error", ErrorCode.INVALID_PARAMS)

        mcp_server.register_tool("error_tool", "Error tool", {}, error_handler)
        client = TestClient(mcp_server.app)

        response = client.post("/tools/error_tool", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is False
        assert data["error"]["code"] == ErrorCode.INVALID_PARAMS.value
        assert data["error"]["message"] == "Handler error"

    def test_invoke_tool_unexpected_error(self, mcp_server):
        """A non-APIError exception is masked as a generic INTERNAL_ERROR."""
        async def failing_handler(params):
            raise ValueError("Unexpected error")

        mcp_server.register_tool("failing_tool", "Failing tool", {}, failing_handler)
        client = TestClient(mcp_server.app)

        response = client.post("/tools/failing_tool", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is False
        assert data["error"]["code"] == ErrorCode.INTERNAL_ERROR.value
        # The original exception text must not leak to the client.
        assert data["error"]["message"] == "Internal server error"

    def test_mcp_endpoint_tools_list(self, mcp_server):
        """The JSON-RPC ``tools/list`` method returns the registered tools."""
        async def dummy_handler(params):
            return {}

        mcp_server.register_tool("test_tool", "Test tool", {"type": "object"}, dummy_handler)
        client = TestClient(mcp_server.app)

        response = client.post(
            "/mcp", json=self._jsonrpc_request("tools/list", "test-request-1")
        )

        assert response.status_code == 200
        data = response.json()
        self._assert_jsonrpc_envelope(data, "test-request-1")
        assert "result" in data
        assert "tools" in data["result"]
        assert len(data["result"]["tools"]) == 1
        assert data["result"]["tools"][0]["name"] == "test_tool"

    def test_mcp_endpoint_tool_call_success(self, mcp_server):
        """The JSON-RPC ``tools/call`` method returns the handler result as text content."""
        async def process_handler(params):
            return {"output": f"Processed {params.get('input', 'data')}"}

        mcp_server.register_tool("process_tool", "Process tool", {}, process_handler)
        client = TestClient(mcp_server.app)

        request_data = self._jsonrpc_request(
            "tools/call",
            "test-call-1",
            params={"name": "process_tool", "arguments": {"input": "test data"}},
        )
        response = client.post("/mcp", json=request_data)

        assert response.status_code == 200
        data = response.json()
        self._assert_jsonrpc_envelope(data, "test-call-1")
        assert "result" in data
        assert "content" in data["result"]
        assert len(data["result"]["content"]) == 1
        assert data["result"]["content"][0]["type"] == "text"

    def test_mcp_endpoint_tool_call_missing_name(self, client):
        """``tools/call`` without a tool name is rejected with INVALID_PARAMS."""
        request_data = self._jsonrpc_request(
            "tools/call", "test-call-error", params={"arguments": {}}
        )

        response = client.post("/mcp", json=request_data)

        assert response.status_code == 200
        data = response.json()
        self._assert_jsonrpc_envelope(data, "test-call-error")
        assert "error" in data
        assert data["error"]["code"] == ErrorCode.INVALID_PARAMS.value
        assert "Tool name is required" in data["error"]["message"]

    def test_mcp_endpoint_tool_call_nonexistent_tool(self, client):
        """``tools/call`` for an unregistered tool is rejected with METHOD_NOT_FOUND."""
        request_data = self._jsonrpc_request(
            "tools/call",
            "test-call-404",
            params={"name": "nonexistent_tool", "arguments": {}},
        )

        response = client.post("/mcp", json=request_data)

        assert response.status_code == 200
        data = response.json()
        self._assert_jsonrpc_envelope(data, "test-call-404")
        assert "error" in data
        assert data["error"]["code"] == ErrorCode.METHOD_NOT_FOUND.value

    def test_mcp_endpoint_unknown_method(self, client):
        """An unrecognised JSON-RPC method is rejected with METHOD_NOT_FOUND."""
        response = client.post(
            "/mcp", json=self._jsonrpc_request("unknown/method", "test-unknown")
        )

        assert response.status_code == 200
        data = response.json()
        self._assert_jsonrpc_envelope(data, "test-unknown")
        assert "error" in data
        assert data["error"]["code"] == ErrorCode.METHOD_NOT_FOUND.value
        assert "Unknown method: unknown/method" in data["error"]["message"]
class TestServerModels:
    """Unit tests for the data models the server uses for tools and JSON-RPC."""

    def test_tool_schema_model(self):
        """ToolSchema exposes exactly the name, description and schema it was built with."""
        tool = ToolSchema(
            name="test_tool",
            description="Test description",
            input_schema={"type": "object"},
        )

        assert (tool.name, tool.description) == ("test_tool", "Test description")
        assert tool.input_schema == {"type": "object"}

    def test_tool_response_success(self):
        """A successful ToolResponse carries data and leaves error unset."""
        result = ToolResponse(success=True, data={"result": "success"})

        assert result.success is True
        assert result.error is None
        assert result.data == {"result": "success"}

    def test_tool_response_error(self):
        """A failed ToolResponse carries the error and leaves data unset."""
        failure = MCPError(code=-32602, message="Invalid params")

        result = ToolResponse(success=False, error=failure)

        assert result.success is False
        assert result.error == failure
        assert result.data is None

    def test_mcp_server_request(self):
        """MCPServerRequest keeps explicit fields and fills in the protocol version."""
        req = MCPServerRequest(
            method="tools/list",
            params={"test": "data"},
            id="request-1",
        )

        assert (req.jsonrpc, req.method, req.id) == ("2.0", "tools/list", "request-1")
        assert req.params == {"test": "data"}

    def test_mcp_server_request_defaults(self):
        """Only ``method`` is required; params and id default to None."""
        req = MCPServerRequest(method="tools/list")

        assert (req.jsonrpc, req.method) == ("2.0", "tools/list")
        for optional_field in (req.params, req.id):
            assert optional_field is None

    def test_mcp_server_response_success(self):
        """A result-bearing MCPServerResponse has no error."""
        reply = MCPServerResponse(result={"tools": []}, id="request-1")

        assert (reply.jsonrpc, reply.id) == ("2.0", "request-1")
        assert reply.result == {"tools": []}
        assert reply.error is None

    def test_mcp_server_response_error(self):
        """An error-bearing MCPServerResponse has no result."""
        failure = MCPError(code=-32601, message="Method not found")

        reply = MCPServerResponse(error=failure, id="request-1")

        assert (reply.jsonrpc, reply.id) == ("2.0", "request-1")
        assert reply.error == failure
        assert reply.result is None