"""Tests for MCP tool definitions and serialization."""
import pytest
from schwab_mcp.server import list_tools
class TestToolDefinitions:
    """Tests for tool definitions returned by list_tools().

    Each test awaits ``list_tools()`` itself and inspects the returned tool
    objects through their ``name``, ``description`` and ``inputSchema``
    attributes and their ``model_dump()`` serialization.
    """

    @staticmethod
    def _find_tool(tools, name):
        """Return the tool called *name*, failing with a clear message if absent.

        A bare ``next(...)`` over a generator expression raises
        ``StopIteration`` when nothing matches, which gives no hint about
        which tool is missing (and inside a coroutine it is re-raised as a
        ``RuntimeError``). Asserting here names the missing tool instead.
        """
        matches = [tool for tool in tools if tool.name == name]
        assert matches, f"Tool {name!r} is not registered"
        return matches[0]

    @pytest.mark.asyncio
    async def test_all_tools_have_required_fields(self):
        """Verify all tools have required MCP fields."""
        tools = await list_tools()

        for tool in tools:
            assert tool.name, "Tool missing name"
            assert tool.description, f"Tool {tool.name} missing description"
            assert tool.inputSchema, f"Tool {tool.name} missing inputSchema"

            # Verify inputSchema structure
            schema = tool.inputSchema
            assert schema.get("type") == "object", (
                f"Tool {tool.name} inputSchema type must be 'object'"
            )
            assert "properties" in schema, (
                f"Tool {tool.name} inputSchema missing 'properties'"
            )

    @pytest.mark.asyncio
    async def test_expected_tools_exist(self):
        """Verify all expected tools are registered."""
        tools = await list_tools()
        tool_names = {t.name for t in tools}

        expected_tools = {
            "get_positions",
            # NOTE: get_account is NOT exposed for now
            "get_quote",
            "get_quotes",
            # NOTE: get_option_chain and get_price_history are NOT exposed
            # (too large for context)
            "get_instruments",
            "get_movers",
            # SQL query tools (use these instead of
            # get_option_chain/get_price_history)
            "load_price_history",
            "load_option_chain",
            "query_market_data",
            "get_data_schema",
        }

        # Exact set equality: both missing and unexpected tools fail the test.
        assert tool_names == expected_tools, (
            f"Missing tools: {expected_tools - tool_names}, "
            f"Extra tools: {tool_names - expected_tools}"
        )

    @pytest.mark.asyncio
    async def test_no_tools_have_allowed_callers(self):
        """Verify no tools have allowed_callers (large data tools removed)."""
        tools = await list_tools()

        for tool in tools:
            serialized = tool.model_dump(exclude_none=True)
            # No tools should have allowed_callers since we removed the
            # large-output tools
            assert "allowed_callers" not in serialized, (
                f"Tool {tool.name} should not have allowed_callers"
            )

    @pytest.mark.asyncio
    async def test_tool_serialization_format(self):
        """Verify tool serialization matches expected JSON structure."""
        tools = await list_tools()

        for tool in tools:
            serialized = tool.model_dump(exclude_none=True)

            # Check top-level keys
            assert "name" in serialized
            assert "description" in serialized
            assert "inputSchema" in serialized

            # Verify name matches
            assert serialized["name"] == tool.name

            # Verify inputSchema is properly serialized
            assert isinstance(serialized["inputSchema"], dict)
            assert serialized["inputSchema"]["type"] == "object"

    @pytest.mark.asyncio
    async def test_load_option_chain_schema(self):
        """Verify load_option_chain has correct input schema."""
        tools = await list_tools()
        tool = self._find_tool(tools, "load_option_chain")
        schema = tool.inputSchema
        props = schema["properties"]

        # Required fields
        assert "symbol" in props
        assert schema["required"] == ["symbol"]

        # Optional fields
        assert "contract_type" in props
        assert props["contract_type"]["enum"] == ["CALL", "PUT", "ALL"]
        assert "strike_count" in props
        assert "from_date" in props
        assert "to_date" in props

    @pytest.mark.asyncio
    async def test_load_price_history_schema(self):
        """Verify load_price_history has correct input schema."""
        tools = await list_tools()
        tool = self._find_tool(tools, "load_price_history")
        schema = tool.inputSchema
        props = schema["properties"]

        # Required fields
        assert "symbol" in props
        assert schema["required"] == ["symbol"]

        # Optional fields with enums (compared as sets: order is irrelevant)
        assert "period_type" in props
        assert set(props["period_type"]["enum"]) == {
            "day", "month", "year", "ytd",
        }
        assert "frequency_type" in props
        assert set(props["frequency_type"]["enum"]) == {
            "minute", "daily", "weekly", "monthly",
        }

    @pytest.mark.asyncio
    async def test_get_movers_schema(self):
        """Verify get_movers has correct input schema."""
        tools = await list_tools()
        tool = self._find_tool(tools, "get_movers")
        schema = tool.inputSchema
        props = schema["properties"]

        # All fields optional ("required" may be absent or an empty list)
        assert schema.get("required", []) == []

        # Index options
        assert "index" in props
        expected_indices = {
            "$DJI", "$COMPX", "$SPX", "NYSE", "NASDAQ", "OTCBB",
            "INDEX_ALL", "EQUITY_ALL", "OPTION_ALL", "OPTION_PUT", "OPTION_CALL",
        }
        assert set(props["index"]["enum"]) == expected_indices

        # Direction/change options
        assert "direction" in props
        assert set(props["direction"]["enum"]) == {"up", "down"}
        assert "change" in props
        assert set(props["change"]["enum"]) == {"percent", "value"}

    @pytest.mark.asyncio
    async def test_get_instruments_schema(self):
        """Verify get_instruments has correct input schema."""
        tools = await list_tools()
        tool = self._find_tool(tools, "get_instruments")
        schema = tool.inputSchema
        props = schema["properties"]

        # Required fields
        assert "symbol" in props
        assert schema["required"] == ["symbol"]

        # Projection options
        assert "projection" in props
        expected_projections = {
            "symbol-search",
            "symbol-regex",
            "desc-search",
            "desc-regex",
            "search",
            "fundamental",
        }
        assert set(props["projection"]["enum"]) == expected_projections

    @pytest.mark.asyncio
    async def test_account_tools_schema(self):
        """Verify account tools have correct input schema."""
        tools = await list_tools()

        # get_positions
        positions_tool = self._find_tool(tools, "get_positions")
        assert "account_id" in positions_tool.inputSchema["properties"]
        assert positions_tool.inputSchema.get("required", []) == []  # Optional

        # NOTE: get_account is not exposed for now

    @pytest.mark.asyncio
    async def test_quote_tools_schema(self):
        """Verify quote tools have correct input schema."""
        tools = await list_tools()

        # get_quote - single symbol
        quote_tool = self._find_tool(tools, "get_quote")
        assert "symbol" in quote_tool.inputSchema["properties"]
        assert quote_tool.inputSchema["required"] == ["symbol"]

        # get_quotes - multiple symbols
        quotes_tool = self._find_tool(tools, "get_quotes")
        quotes_props = quotes_tool.inputSchema["properties"]
        assert "symbols" in quotes_props
        assert quotes_props["symbols"]["type"] == "array"
        assert quotes_tool.inputSchema["required"] == ["symbols"]