# FastMCP: tests for the ToolManager (tool registration, invocation, schema, context).
import json
import logging
from typing import Optional
import pytest
from pydantic import BaseModel
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools import ToolManager
class TestAddTools:
    """Checks on what ``ToolManager.add_tool`` records for various callables."""

    def test_basic_function(self):
        """Register a plain sync function and inspect the resulting tool."""

        def add(a: int, b: int) -> int:
            """Add two numbers."""
            return a + b

        registry = ToolManager()
        registry.add_tool(add)

        registered = registry.get_tool("add")
        assert registered is not None
        assert registered.name == "add"
        assert registered.description == "Add two numbers."
        assert registered.is_async is False
        # Both int-annotated arguments should be described as integers.
        props = registered.parameters["properties"]
        assert props["a"]["type"] == "integer"
        assert props["b"]["type"] == "integer"

    @pytest.mark.anyio
    async def test_async_function(self):
        """Register a coroutine function and inspect the resulting tool."""

        async def fetch_data(url: str) -> str:
            """Fetch data from URL."""
            return f"Data from {url}"

        registry = ToolManager()
        registry.add_tool(fetch_data)

        registered = registry.get_tool("fetch_data")
        assert registered is not None
        assert registered.name == "fetch_data"
        assert registered.description == "Fetch data from URL."
        assert registered.is_async is True
        url_schema = registered.parameters["properties"]["url"]
        assert url_schema["type"] == "string"

    def test_pydantic_model_function(self):
        """Register a function whose argument is a Pydantic model."""

        class UserInput(BaseModel):
            name: str
            age: int

        def create_user(user: UserInput, flag: bool) -> dict:
            """Create a new user."""
            return {"id": 1, **user.model_dump()}

        registry = ToolManager()
        registry.add_tool(create_user)

        registered = registry.get_tool("create_user")
        assert registered is not None
        assert registered.name == "create_user"
        assert registered.description == "Create a new user."
        assert registered.is_async is False
        # The model's fields are looked up under $defs; the plain bool
        # argument is looked up under the top-level properties.
        model_props = registered.parameters["$defs"]["UserInput"]["properties"]
        assert "name" in model_props
        assert "age" in model_props
        assert "flag" in registered.parameters["properties"]

    def test_add_invalid_tool(self):
        """Passing a non-callable (an int) is expected to raise AttributeError."""
        registry = ToolManager()
        with pytest.raises(AttributeError):
            registry.add_tool(1)  # type: ignore

    def test_add_lambda(self):
        """A lambda can be registered when an explicit name is supplied."""
        registry = ToolManager()
        registered = registry.add_tool(lambda x: x, name="my_tool")
        assert registered.name == "my_tool"

    def test_add_lambda_with_no_name(self):
        """A lambda registered without a name is expected to raise ValueError."""
        registry = ToolManager()
        expected_message = "You must provide a name for lambda functions"
        with pytest.raises(ValueError, match=expected_message):
            registry.add_tool(lambda x: x)

    def test_warn_on_duplicate_tools(self, caplog):
        """Adding the same function twice is expected to log a warning."""

        def f(x: int) -> int:
            return x

        registry = ToolManager()
        registry.add_tool(f)
        # Only the second registration is captured.
        with caplog.at_level(logging.WARNING):
            registry.add_tool(f)
            assert "Tool already exists: f" in caplog.text

    def test_disable_warn_on_duplicate_tools(self, caplog):
        """No duplicate warning is expected once the flag is switched off."""

        def f(x: int) -> int:
            return x

        registry = ToolManager()
        registry.add_tool(f)
        registry.warn_on_duplicate_tools = False
        with caplog.at_level(logging.WARNING):
            registry.add_tool(f)
            assert "Tool already exists: f" not in caplog.text
class TestCallTools:
    """Checks on invoking registered tools through ``ToolManager.call_tool``."""

    @pytest.mark.anyio
    async def test_call_tool(self):
        """Call a sync tool with both arguments supplied."""

        def add(a: int, b: int) -> int:
            """Add two numbers."""
            return a + b

        registry = ToolManager()
        registry.add_tool(add)

        outcome = await registry.call_tool("add", {"a": 1, "b": 2})
        assert outcome == 3

    @pytest.mark.anyio
    async def test_call_async_tool(self):
        """Call a coroutine tool and check its awaited result."""

        async def double(n: int) -> int:
            """Double a number."""
            return n * 2

        registry = ToolManager()
        registry.add_tool(double)

        outcome = await registry.call_tool("double", {"n": 5})
        assert outcome == 10

    @pytest.mark.anyio
    async def test_call_tool_with_default_args(self):
        """Omit an argument that has a default value."""

        def add(a: int, b: int = 1) -> int:
            """Add two numbers."""
            return a + b

        registry = ToolManager()
        registry.add_tool(add)

        outcome = await registry.call_tool("add", {"a": 1})
        assert outcome == 2

    @pytest.mark.anyio
    async def test_call_tool_with_missing_args(self):
        """Omitting a required argument is expected to raise ToolError."""

        def add(a: int, b: int) -> int:
            """Add two numbers."""
            return a + b

        registry = ToolManager()
        registry.add_tool(add)

        with pytest.raises(ToolError):
            await registry.call_tool("add", {"a": 1})

    @pytest.mark.anyio
    async def test_call_unknown_tool(self):
        """Calling a name that was never registered is expected to raise ToolError."""
        registry = ToolManager()
        with pytest.raises(ToolError):
            await registry.call_tool("unknown", {"a": 1})

    @pytest.mark.anyio
    async def test_call_tool_with_list_int_input(self):
        """Pass a list[int] argument as a JSON string and as a Python list."""

        def sum_vals(vals: list[int]) -> int:
            return sum(vals)

        registry = ToolManager()
        registry.add_tool(sum_vals)

        # JSON-encoded text first, then the equivalent native list.
        for payload in ("[1, 2, 3]", [1, 2, 3]):
            outcome = await registry.call_tool("sum_vals", {"vals": payload})
            assert outcome == 6

    @pytest.mark.anyio
    async def test_call_tool_with_list_str_or_str_input(self):
        """Pass list and string inputs to a ``list[str] | str`` argument."""

        def concat_strs(vals: list[str] | str) -> str:
            if isinstance(vals, str):
                return vals
            return "".join(vals)

        registry = ToolManager()
        registry.add_tool(concat_strs)

        # (input, expected) pairs: native list, JSON list, bare string, and a
        # quoted string that is expected to come back with its quotes intact.
        cases = [
            (["a", "b", "c"], "abc"),
            ('["a", "b", "c"]', "abc"),
            ("a", "a"),
            ('"a"', '"a"'),
        ]
        for payload, expected in cases:
            outcome = await registry.call_tool("concat_strs", {"vals": payload})
            assert outcome == expected

    @pytest.mark.anyio
    async def test_call_tool_with_complex_model(self):
        """Pass a nested Pydantic model argument as a dict and as JSON text."""
        from mcp.server.fastmcp import Context

        class MyShrimpTank(BaseModel):
            class Shrimp(BaseModel):
                name: str

            shrimp: list[Shrimp]
            x: None

        def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]:
            return [occupant.name for occupant in tank.shrimp]

        registry = ToolManager()
        registry.add_tool(name_shrimp)

        # Same tank twice: once as a native dict, once as its JSON encoding.
        tank_payloads = (
            {"x": None, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]},
            '{"x": null, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}',
        )
        for tank_payload in tank_payloads:
            outcome = await registry.call_tool(
                "name_shrimp",
                {"tank": tank_payload},
            )
            assert outcome == ["rex", "gertrude"]
class TestToolSchema:
    """Checks on the parameter schema generated for a registered tool."""

    @pytest.mark.anyio
    async def test_context_arg_excluded_from_schema(self):
        """The ``ctx: Context`` argument must not appear in the tool schema."""
        from mcp.server.fastmcp import Context

        def something(a: int, ctx: Context) -> int:
            return a

        registered = ToolManager().add_tool(something)

        # Serialise once and search the text for any trace of the context arg.
        schema_text = json.dumps(registered.parameters)
        assert "ctx" not in schema_text
        assert "Context" not in schema_text
        assert "ctx" not in registered.fn_metadata.arg_model.model_fields
class TestContextHandling:
    """Test context handling in the tool manager."""

    def test_context_parameter_detection(self):
        """Check ``context_kwarg`` for tools with and without a Context
        parameter."""
        from mcp.server.fastmcp import Context

        registry = ToolManager()

        def tool_with_context(x: int, ctx: Context) -> str:
            return str(x)

        with_ctx = registry.add_tool(tool_with_context)
        assert with_ctx.context_kwarg == "ctx"

        def tool_without_context(x: int) -> str:
            return str(x)

        without_ctx = registry.add_tool(tool_without_context)
        assert without_ctx.context_kwarg is None

    @pytest.mark.anyio
    async def test_context_injection(self):
        """A sync tool receives a Context instance when one is passed in."""
        from mcp.server.fastmcp import Context, FastMCP

        def tool_with_context(x: int, ctx: Context) -> str:
            assert isinstance(ctx, Context)
            return str(x)

        registry = ToolManager()
        registry.add_tool(tool_with_context)

        server = FastMCP()
        injected = server.get_context()
        outcome = await registry.call_tool(
            "tool_with_context", {"x": 42}, context=injected
        )
        assert outcome == "42"

    @pytest.mark.anyio
    async def test_context_injection_async(self):
        """An async tool receives a Context instance when one is passed in."""
        from mcp.server.fastmcp import Context, FastMCP

        async def async_tool(x: int, ctx: Context) -> str:
            assert isinstance(ctx, Context)
            return str(x)

        registry = ToolManager()
        registry.add_tool(async_tool)

        server = FastMCP()
        injected = server.get_context()
        outcome = await registry.call_tool(
            "async_tool", {"x": 42}, context=injected
        )
        assert outcome == "42"

    @pytest.mark.anyio
    async def test_context_optional(self):
        """A tool with an optional Context parameter runs without a context."""
        from mcp.server.fastmcp import Context

        def tool_with_context(x: int, ctx: Optional[Context] = None) -> str:
            return str(x)

        registry = ToolManager()
        registry.add_tool(tool_with_context)

        # No context argument is passed here; the call must still succeed.
        outcome = await registry.call_tool("tool_with_context", {"x": 42})
        assert outcome == "42"

    @pytest.mark.anyio
    async def test_context_error_handling(self):
        """A ValueError raised by a context-taking tool is expected to
        surface as ToolError."""
        from mcp.server.fastmcp import Context, FastMCP

        def tool_with_context(x: int, ctx: Context) -> str:
            raise ValueError("Test error")

        registry = ToolManager()
        registry.add_tool(tool_with_context)

        server = FastMCP()
        injected = server.get_context()
        expected_message = "Error executing tool tool_with_context"
        with pytest.raises(ToolError, match=expected_message):
            await registry.call_tool(
                "tool_with_context", {"x": 42}, context=injected
            )