"""Tests for the FastMCP ToolManager."""

import json
import logging
from typing import Optional

import pytest
from pydantic import BaseModel

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools import ToolManager


class TestAddTools:
    """Registering callables through ``ToolManager.add_tool``."""

    def test_basic_function(self):
        """A sync function is registered with its name, docstring and schema."""

        def add(a: int, b: int) -> int:
            """Add two numbers."""
            return a + b

        manager = ToolManager()
        manager.add_tool(add)

        tool = manager.get_tool("add")
        assert tool is not None
        assert tool.name == "add"
        assert tool.description == "Add two numbers."
        assert tool.is_async is False
        assert tool.parameters["properties"]["a"]["type"] == "integer"
        assert tool.parameters["properties"]["b"]["type"] == "integer"

    @pytest.mark.anyio
    async def test_async_function(self):
        """A coroutine function is registered and flagged as async."""

        async def fetch_data(url: str) -> str:
            """Fetch data from URL."""
            return f"Data from {url}"

        manager = ToolManager()
        manager.add_tool(fetch_data)

        tool = manager.get_tool("fetch_data")
        assert tool is not None
        assert tool.name == "fetch_data"
        assert tool.description == "Fetch data from URL."
        assert tool.is_async is True
        assert tool.parameters["properties"]["url"]["type"] == "string"

    def test_pydantic_model_function(self):
        """A Pydantic-model parameter appears under ``$defs`` in the schema."""

        class UserInput(BaseModel):
            name: str
            age: int

        def create_user(user: UserInput, flag: bool) -> dict:
            """Create a new user."""
            return {"id": 1, **user.model_dump()}

        manager = ToolManager()
        manager.add_tool(create_user)

        tool = manager.get_tool("create_user")
        assert tool is not None
        assert tool.name == "create_user"
        assert tool.description == "Create a new user."
        assert tool.is_async is False
        assert "name" in tool.parameters["$defs"]["UserInput"]["properties"]
        assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
        assert "flag" in tool.parameters["properties"]

    def test_add_invalid_tool(self):
        """Passing a non-callable (an int) to ``add_tool`` raises AttributeError."""
        manager = ToolManager()
        with pytest.raises(AttributeError):
            manager.add_tool(1)  # type: ignore

    def test_add_lambda(self):
        """A lambda can be registered when an explicit name is supplied."""
        manager = ToolManager()
        tool = manager.add_tool(lambda x: x, name="my_tool")
        assert tool.name == "my_tool"

    def test_add_lambda_with_no_name(self):
        """A lambda without an explicit name is rejected with ValueError."""
        manager = ToolManager()
        with pytest.raises(
            ValueError, match="You must provide a name for lambda functions"
        ):
            manager.add_tool(lambda x: x)

    def test_warn_on_duplicate_tools(self, caplog):
        """Test warning on duplicate tools."""

        def f(x: int) -> int:
            return x

        manager = ToolManager()
        manager.add_tool(f)
        with caplog.at_level(logging.WARNING):
            manager.add_tool(f)
            assert "Tool already exists: f" in caplog.text

    def test_disable_warn_on_duplicate_tools(self, caplog):
        """Test disabling warning on duplicate tools."""

        def f(x: int) -> int:
            return x

        manager = ToolManager()
        manager.add_tool(f)
        manager.warn_on_duplicate_tools = False
        with caplog.at_level(logging.WARNING):
            manager.add_tool(f)
            assert "Tool already exists: f" not in caplog.text


class TestCallTools:
    """Invoking registered tools through ``ToolManager.call_tool``."""

    @pytest.mark.anyio
    async def test_call_tool(self):
        """A sync tool is called with its arguments and returns its result."""

        def add(a: int, b: int) -> int:
            """Add two numbers."""
            return a + b

        manager = ToolManager()
        manager.add_tool(add)
        result = await manager.call_tool("add", {"a": 1, "b": 2})
        assert result == 3

    @pytest.mark.anyio
    async def test_call_async_tool(self):
        """An async tool is awaited and its result returned."""

        async def double(n: int) -> int:
            """Double a number."""
            return n * 2

        manager = ToolManager()
        manager.add_tool(double)
        result = await manager.call_tool("double", {"n": 5})
        assert result == 10

    @pytest.mark.anyio
    async def test_call_tool_with_default_args(self):
        """An omitted argument falls back to the function's default value."""

        def add(a: int, b: int = 1) -> int:
            """Add two numbers."""
            return a + b

        manager = ToolManager()
        manager.add_tool(add)
        result = await manager.call_tool("add", {"a": 1})
        assert result == 2

    @pytest.mark.anyio
    async def test_call_tool_with_missing_args(self):
        """Omitting a required argument raises ToolError."""

        def add(a: int, b: int) -> int:
            """Add two numbers."""
            return a + b

        manager = ToolManager()
        manager.add_tool(add)
        with pytest.raises(ToolError):
            await manager.call_tool("add", {"a": 1})

    @pytest.mark.anyio
    async def test_call_unknown_tool(self):
        """Calling a tool name that was never registered raises ToolError."""
        manager = ToolManager()
        with pytest.raises(ToolError):
            await manager.call_tool("unknown", {"a": 1})

    @pytest.mark.anyio
    async def test_call_tool_with_list_int_input(self):
        """A ``list[int]`` argument is accepted as a list or as a JSON string."""

        def sum_vals(vals: list[int]) -> int:
            return sum(vals)

        manager = ToolManager()
        manager.add_tool(sum_vals)
        # Try both with plain list and with JSON list
        result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"})
        assert result == 6
        result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]})
        assert result == 6

    @pytest.mark.anyio
    async def test_call_tool_with_list_str_or_str_input(self):
        """A ``list[str] | str`` argument keeps plain strings as strings.

        A JSON list string is expanded to a list, while a bare string and a
        JSON-quoted string (``'"a"'``) are both passed through unchanged.
        """

        def concat_strs(vals: list[str] | str) -> str:
            return vals if isinstance(vals, str) else "".join(vals)

        manager = ToolManager()
        manager.add_tool(concat_strs)
        # Try both with plain python object and with JSON list
        result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]})
        assert result == "abc"
        result = await manager.call_tool("concat_strs", {"vals": '["a", "b", "c"]'})
        assert result == "abc"
        result = await manager.call_tool("concat_strs", {"vals": "a"})
        assert result == "a"
        result = await manager.call_tool("concat_strs", {"vals": '"a"'})
        assert result == '"a"'

    @pytest.mark.anyio
    async def test_call_tool_with_complex_model(self):
        """A nested Pydantic model argument is accepted as a dict or JSON string."""

        class MyShrimpTank(BaseModel):
            class Shrimp(BaseModel):
                name: str

            shrimp: list[Shrimp]
            x: None

        def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]:
            return [x.name for x in tank.shrimp]

        manager = ToolManager()
        manager.add_tool(name_shrimp)
        result = await manager.call_tool(
            "name_shrimp",
            {"tank": {"x": None, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}},
        )
        assert result == ["rex", "gertrude"]
        result = await manager.call_tool(
            "name_shrimp",
            {"tank": '{"x": null, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}'},
        )
        assert result == ["rex", "gertrude"]


class TestToolSchema:
    """Shape of the parameter schema generated for a tool."""

    @pytest.mark.anyio
    async def test_context_arg_excluded_from_schema(self):
        """The ``Context`` parameter is left out of the schema and arg model."""

        def something(a: int, ctx: Context) -> int:
            return a

        manager = ToolManager()
        tool = manager.add_tool(something)
        assert "ctx" not in json.dumps(tool.parameters)
        assert "Context" not in json.dumps(tool.parameters)
        assert "ctx" not in tool.fn_metadata.arg_model.model_fields


class TestContextHandling:
    """Test context handling in the tool manager."""

    def test_context_parameter_detection(self):
        """Test that a ``Context``-annotated parameter is detected on registration."""

        def tool_with_context(x: int, ctx: Context) -> str:
            return str(x)

        manager = ToolManager()
        tool = manager.add_tool(tool_with_context)
        assert tool.context_kwarg == "ctx"

        def tool_without_context(x: int) -> str:
            return str(x)

        tool = manager.add_tool(tool_without_context)
        assert tool.context_kwarg is None

    @pytest.mark.anyio
    async def test_context_injection(self):
        """Test that context is properly injected during tool execution."""

        def tool_with_context(x: int, ctx: Context) -> str:
            assert isinstance(ctx, Context)
            return str(x)

        manager = ToolManager()
        manager.add_tool(tool_with_context)

        mcp = FastMCP()
        ctx = mcp.get_context()
        result = await manager.call_tool("tool_with_context", {"x": 42}, context=ctx)
        assert result == "42"

    @pytest.mark.anyio
    async def test_context_injection_async(self):
        """Test that context is properly injected in async tools."""

        async def async_tool(x: int, ctx: Context) -> str:
            assert isinstance(ctx, Context)
            return str(x)

        manager = ToolManager()
        manager.add_tool(async_tool)

        mcp = FastMCP()
        ctx = mcp.get_context()
        result = await manager.call_tool("async_tool", {"x": 42}, context=ctx)
        assert result == "42"

    @pytest.mark.anyio
    async def test_context_optional(self):
        """Test that context is optional when calling tools."""

        def tool_with_context(x: int, ctx: Optional[Context] = None) -> str:
            return str(x)

        manager = ToolManager()
        manager.add_tool(tool_with_context)
        # Should not raise an error when context is not provided
        result = await manager.call_tool("tool_with_context", {"x": 42})
        assert result == "42"

    @pytest.mark.anyio
    async def test_context_error_handling(self):
        """Test that an exception raised by a context-taking tool becomes ToolError."""

        def tool_with_context(x: int, ctx: Context) -> str:
            raise ValueError("Test error")

        manager = ToolManager()
        manager.add_tool(tool_with_context)

        mcp = FastMCP()
        ctx = mcp.get_context()
        with pytest.raises(ToolError, match="Error executing tool tool_with_context"):
            await manager.call_tool("tool_with_context", {"x": 42}, context=ctx)