test_server.py•11.5 kB
"""Tests for MCP server implementation."""
import json
from unittest.mock import AsyncMock, patch
import pytest
from mcp.types import TextContent
from ala_mcp.server import call_tool, list_tools
@pytest.mark.asyncio
async def test_list_tools():
    """The server advertises exactly the seven ALA tools and a query-driven search schema."""
    expected_names = {
        "search_occurrences",
        "get_occurrence",
        "download_occurrences",
        "count_taxa",
        "create_query_id",
        "get_query",
        "api_request",
    }

    tools = await list_tools()
    assert len(tools) == 7

    # With the count pinned at 7, requiring every expected name means no extras or duplicates.
    tools_by_name = {tool.name: tool for tool in tools}
    assert expected_names <= set(tools_by_name)

    # The search tool's description and JSON schema make `query` the only required argument.
    search_tool = tools_by_name["search_occurrences"]
    assert "152+ million" in search_tool.description
    input_schema = search_tool.inputSchema
    assert input_schema["type"] == "object"
    assert "query" in input_schema["properties"]
    assert input_schema["required"] == ["query"]
@pytest.mark.asyncio
async def test_call_tool_unknown():
    """Dispatching a tool name the server does not recognise must raise ValueError."""
    bogus_tool, no_arguments = "unknown_tool", {}
    with pytest.raises(ValueError, match="Unknown tool"):
        await call_tool(bogus_tool, no_arguments)
@pytest.mark.asyncio
async def test_call_tool_success(sample_api_response):
    """A successful api_request call is returned as a single JSON-encoded TextContent."""
    canned_response = {
        "status_code": 200,
        "headers": {"content-type": "application/json"},
        "url": "https://biocache-ws.ala.org.au/ws/occurrences",
        "data": sample_api_response,
    }
    request_arguments = {"method": "GET", "endpoint": "/occurrences", "params": {"limit": 10}}

    with patch("ala_mcp.server.ala_client") as fake_client:
        fake_client.request = AsyncMock(return_value=canned_response)
        contents = await call_tool("api_request", request_arguments)

    assert len(contents) == 1
    (only_item,) = contents
    assert isinstance(only_item, TextContent)

    # The TextContent body is the client's result dict serialised as JSON.
    payload = json.loads(only_item.text)
    assert payload["status_code"] == 200
    assert payload["data"] == sample_api_response
@pytest.mark.asyncio
async def test_call_tool_with_all_parameters():
    """api_request forwards every optional argument (params, data, headers) to the client.

    Also checks that the client's status code and body are serialised back into the
    single JSON TextContent returned to the caller.
    """
    mock_result = {
        "status_code": 201,
        "headers": {"content-type": "application/json"},
        "url": "https://biocache-ws.ala.org.au/ws/data",
        "data": {"success": True},
    }
    arguments = {
        "method": "POST",
        "endpoint": "/data",
        "params": {"debug": "true"},
        "data": {"key": "value"},
        "headers": {"X-Custom": "header"},
    }

    with patch("ala_mcp.server.ala_client") as mock_client:
        mock_client.request = AsyncMock(return_value=mock_result)
        result = await call_tool("api_request", arguments)

    # Every argument must reach the client unchanged, as keyword arguments.
    mock_client.request.assert_called_once_with(
        method="POST",
        endpoint="/data",
        params={"debug": "true"},
        data={"key": "value"},
        headers={"X-Custom": "header"},
    )

    assert len(result) == 1
    assert isinstance(result[0], TextContent)
    response_data = json.loads(result[0].text)
    assert response_data["status_code"] == 201
    # The response body must round-trip, not just the status code.
    assert response_data["data"] == {"success": True}
@pytest.mark.asyncio
async def test_call_tool_error_handling():
    """A client exception is reported as a JSON error payload instead of propagating."""
    with patch("ala_mcp.server.ala_client") as fake_client:
        fake_client.request = AsyncMock(side_effect=Exception("Network error"))
        contents = await call_tool("api_request", {"method": "GET", "endpoint": "/occurrences"})

    assert len(contents) == 1
    error_item = contents[0]
    assert isinstance(error_item, TextContent)

    # The payload carries the exception message and the exception class name.
    error_payload = json.loads(error_item.text)
    assert "error" in error_payload
    assert "Network error" in error_payload["error"]
    assert error_payload["type"] == "Exception"
@pytest.mark.asyncio
async def test_call_tool_minimal_parameters():
    """With only method and endpoint given, api_request fills in empty params/headers and no body."""
    canned_response = {
        "status_code": 200,
        "headers": {},
        "url": "https://biocache-ws.ala.org.au/ws/test",
        "data": {},
    }

    with patch("ala_mcp.server.ala_client") as fake_client:
        fake_client.request = AsyncMock(return_value=canned_response)
        contents = await call_tool("api_request", {"method": "GET", "endpoint": "/test"})

    # The server substitutes these defaults for the omitted optional arguments.
    fake_client.request.assert_called_once_with(
        method="GET",
        endpoint="/test",
        params={},
        data=None,
        headers={},
    )
    assert len(contents) == 1
@pytest.mark.asyncio
async def test_search_occurrences_basic():
    """A plain search_occurrences query becomes a GET on /occurrences/search with paging params."""
    canned_response = {
        "status_code": 200,
        "headers": {"content-type": "application/json"},
        "url": "https://biocache-ws.ala.org.au/ws/occurrences/search",
        "data": {"totalRecords": 100, "occurrences": []},
    }
    search_arguments = {"query": "Eucalyptus", "pageSize": 20}

    with patch("ala_mcp.server.ala_client") as fake_client:
        fake_client.request = AsyncMock(return_value=canned_response)
        contents = await call_tool("search_occurrences", search_arguments)

    # `query` is renamed to Biocache's `q`; `start` defaults to 0 when omitted.
    expected_params = {"q": "Eucalyptus", "pageSize": 20, "start": 0}
    fake_client.request.assert_called_once_with(
        method="GET",
        endpoint="/occurrences/search",
        params=expected_params,
        data=None,
        headers={},
    )

    assert len(contents) == 1
    assert json.loads(contents[0].text)["status_code"] == 200
@pytest.mark.asyncio
async def test_search_occurrences_with_filters():
    """search_occurrences maps filters, spatial and facet arguments onto Biocache params.

    Also pins that exactly one GET is sent to /occurrences/search (the same endpoint as
    a basic search) and that the client's response is returned to the caller.
    """
    mock_result = {
        "status_code": 200,
        "headers": {},
        "url": "https://biocache-ws.ala.org.au/ws/occurrences/search",
        "data": {"totalRecords": 50},
    }
    arguments = {
        "query": "*:*",
        "filters": ["state:Victoria", "year:[2020 TO 2024]"],
        "lat": -37.8136,
        "lon": 144.9631,
        "radius": 10,
        "facets": "basis_of_record",
    }

    with patch("ala_mcp.server.ala_client") as mock_client:
        mock_client.request = AsyncMock(return_value=mock_result)
        result = await call_tool("search_occurrences", arguments)

    # Extra arguments must not change which endpoint is hit or how often.
    mock_client.request.assert_called_once()
    call_kwargs = mock_client.request.call_args.kwargs
    assert call_kwargs["method"] == "GET"
    assert call_kwargs["endpoint"] == "/occurrences/search"

    # `query` -> `q` and `filters` -> `fq`; spatial and facet values pass through unchanged.
    params = call_kwargs["params"]
    assert params["q"] == "*:*"
    assert params["fq"] == ["state:Victoria", "year:[2020 TO 2024]"]
    assert params["lat"] == -37.8136
    assert params["lon"] == 144.9631
    assert params["radius"] == 10
    assert params["facets"] == "basis_of_record"

    assert len(result) == 1
    assert json.loads(result[0].text)["status_code"] == 200
@pytest.mark.asyncio
async def test_get_occurrence():
    """get_occurrence fetches a single record by embedding its UUID in the endpoint path."""
    record_uuid = "test-uuid"
    canned_response = {
        "status_code": 200,
        "headers": {},
        "url": "https://biocache-ws.ala.org.au/ws/occurrence/test-uuid",
        "data": {"uuid": "test-uuid", "scientificName": "Eucalyptus"},
    }

    with patch("ala_mcp.server.ala_client") as fake_client:
        fake_client.request = AsyncMock(return_value=canned_response)
        contents = await call_tool("get_occurrence", {"uuid": record_uuid})

    fake_client.request.assert_called_once_with(
        method="GET",
        endpoint="/occurrence/test-uuid",
        params={},
        data=None,
        headers={},
    )

    assert len(contents) == 1
    record = json.loads(contents[0].text)["data"]
    assert record["uuid"] == "test-uuid"
@pytest.mark.asyncio
async def test_download_occurrences():
    """download_occurrences POSTs an offline-download request carrying every argument as params."""
    canned_response = {
        "status_code": 200,
        "headers": {},
        "url": "https://biocache-ws.ala.org.au/ws/occurrences/offline/download",
        "data": {"statusUrl": "https://example.com/status/123"},
    }
    download_arguments = {
        "query": "taxon_name:Eucalyptus",
        "email": "test@example.com",
        "reasonTypeId": "10",
        "fields": "uuid,scientificName",
        "filters": ["state:Victoria"],
        "mintDoi": True,
    }

    with patch("ala_mcp.server.ala_client") as fake_client:
        fake_client.request = AsyncMock(return_value=canned_response)
        await call_tool("download_occurrences", download_arguments)

    sent = fake_client.request.call_args.kwargs
    assert sent["method"] == "POST"
    assert sent["endpoint"] == "/occurrences/offline/download"

    # `query` -> `q` and `filters` -> `fq`; the remaining arguments keep their names.
    sent_params = sent["params"]
    expected_params = {
        "q": "taxon_name:Eucalyptus",
        "email": "test@example.com",
        "reasonTypeId": "10",
        "fields": "uuid,scientificName",
        "fq": ["state:Victoria"],
    }
    for param_name, expected_value in expected_params.items():
        assert sent_params[param_name] == expected_value
    # Identity check: the flag must arrive as the bool True, not merely a truthy value.
    assert sent_params["mintDoi"] is True
@pytest.mark.asyncio
async def test_count_taxa():
    """count_taxa GETs /occurrences/taxaCount and returns the client's response as JSON.

    `query` is sent as `q` and `filters` as `fq`.
    """
    mock_result = {
        "status_code": 200,
        "headers": {},
        "url": "https://biocache-ws.ala.org.au/ws/occurrences/taxaCount",
        "data": {"count": 42},
    }

    with patch("ala_mcp.server.ala_client") as mock_client:
        mock_client.request = AsyncMock(return_value=mock_result)
        arguments = {"query": "Eucalyptus", "filters": ["state:NSW"]}
        result = await call_tool("count_taxa", arguments)

    mock_client.request.assert_called_once_with(
        method="GET",
        endpoint="/occurrences/taxaCount",
        params={"q": "Eucalyptus", "fq": ["state:NSW"]},
        data=None,
        headers={},
    )

    # Check the caller-visible output too, as the other tool tests do.
    assert len(result) == 1
    assert isinstance(result[0], TextContent)
    response_data = json.loads(result[0].text)
    assert response_data["status_code"] == 200
    assert response_data["data"] == {"count": 42}
@pytest.mark.asyncio
async def test_create_query_id():
    """create_query_id POSTs the query, its filters and any extra params in one params dict."""
    canned_response = {
        "status_code": 200,
        "headers": {},
        "url": "https://biocache-ws.ala.org.au/ws/qid",
        "data": {"qid": "12345"},
    }
    qid_arguments = {
        "query": "*:*",
        "filters": ["state:Victoria"],
        "params": {"facets": "year"},
    }

    with patch("ala_mcp.server.ala_client") as fake_client:
        fake_client.request = AsyncMock(return_value=canned_response)
        await call_tool("create_query_id", qid_arguments)

    sent = fake_client.request.call_args.kwargs
    assert sent["method"] == "POST"

    # The caller's extra `params` entries are merged alongside the renamed q/fq keys.
    sent_params = sent["params"]
    assert sent_params["q"] == "*:*"
    assert sent_params["fq"] == ["state:Victoria"]
    assert sent_params["facets"] == "year"
@pytest.mark.asyncio
async def test_get_query():
    """get_query looks up a stored query by ID via GET /qid/<id> and returns it as JSON."""
    mock_result = {
        "status_code": 200,
        "headers": {},
        "url": "https://biocache-ws.ala.org.au/ws/qid/12345",
        "data": {"q": "*:*", "fq": ["state:Victoria"]},
    }

    with patch("ala_mcp.server.ala_client") as mock_client:
        mock_client.request = AsyncMock(return_value=mock_result)
        arguments = {"queryId": "12345"}
        result = await call_tool("get_query", arguments)

    # The query ID is embedded in the path; no params, body or headers are sent.
    mock_client.request.assert_called_once_with(
        method="GET", endpoint="/qid/12345", params={}, data=None, headers={}
    )

    # Check the caller-visible output too, as the other tool tests do.
    assert len(result) == 1
    assert isinstance(result[0], TextContent)
    response_data = json.loads(result[0].text)
    assert response_data["status_code"] == 200
    assert response_data["data"] == {"q": "*:*", "fq": ["state:Victoria"]}