"""Tests for sampling tools."""
import sys
import pytest
from fastmcp.client import Client
from fastmcp.client.sampling import RequestContext, SamplingMessage, SamplingParams
sys.path.insert(0, "src")
from mcpbin.app import mcp
async def mock_sampling_handler(
    messages: list[SamplingMessage],
    params: SamplingParams,
    context: RequestContext,
) -> str:
    """Stand-in sampling handler that echoes the conversation back as text.

    Each message is rendered as ``"<role>: <body>"``, where the body is
    ``content.text`` when the content object exposes a ``text`` attribute
    and the content object itself otherwise. Rendered messages are joined
    with ``" | "`` and prefixed with ``"Echo: "``. A truthy
    ``params.systemPrompt`` is put in front as ``"[system: ...] "``.

    ``context`` is accepted to match the handler signature but is not used.
    """

    def _render(message) -> str:
        # Fall back to the raw content object when it has no `.text`.
        content = message.content
        body = content.text if hasattr(content, "text") else content
        return f"{message.role}: {body}"

    joined = " | ".join(_render(message) for message in messages)
    # Only a truthy system prompt is surfaced; None and "" add no prefix.
    system_tag = f"[system: {params.systemPrompt}] " if params.systemPrompt else ""
    return f"{system_tag}Echo: {joined}"
@pytest.fixture
async def sampling_client():
    """Yield an open Client wired to ``mcp`` with the mock sampling handler.

    The client is entered before the test body runs and exited once the
    test hands control back to the fixture.
    """
    connection = Client(transport=mcp, sampling_handler=mock_sampling_handler)
    async with connection as client:
        yield client
async def test_sample_echo(sampling_client: Client):
    """sample_echo should return a response that contains the sent message."""
    arguments = {"message": "Hello, how are you?"}
    outcome = await sampling_client.call_tool("sample_echo", arguments)
    payload = outcome.data
    assert "response" in payload
    # The mock handler in this file echoes message text back verbatim.
    assert "Hello, how are you?" in payload["response"]
async def test_sample_echo_empty_message(sampling_client: Client):
    """sample_echo should still produce a response for an empty message."""
    arguments = {"message": ""}
    outcome = await sampling_client.call_tool("sample_echo", arguments)
    assert "response" in outcome.data
async def test_sample_with_system(sampling_client: Client):
    """sample_with_system should report the system prompt and pass it on."""
    arguments = {"message": "What is 2+2?", "system_prompt": "You are a math tutor."}
    outcome = await sampling_client.call_tool("sample_with_system", arguments)
    payload = outcome.data
    assert "response" in payload
    assert "system_prompt" in payload
    assert payload["system_prompt"] == "You are a math tutor."
    # The mock handler prefixes its reply with the system prompt it received,
    # so the prompt text showing up here means it reached the handler.
    assert "math tutor" in payload["response"]
async def test_sample_with_system_no_system(sampling_client: Client):
    """sample_with_system should report a None system prompt when none is given."""
    arguments = {"message": "Hello"}
    outcome = await sampling_client.call_tool("sample_with_system", arguments)
    payload = outcome.data
    assert "response" in payload
    assert payload["system_prompt"] is None
async def test_sample_with_params_temperature(sampling_client: Client):
    """sample_with_params should report back the temperature it was given."""
    arguments = {"message": "Tell me a joke", "temperature": 0.9}
    outcome = await sampling_client.call_tool("sample_with_params", arguments)
    payload = outcome.data
    assert "response" in payload
    assert payload["params"]["temperature"] == 0.9
async def test_sample_with_params_max_tokens(sampling_client: Client):
    """sample_with_params should report back the max_tokens it was given."""
    arguments = {"message": "Write a story", "max_tokens": 50}
    outcome = await sampling_client.call_tool("sample_with_params", arguments)
    payload = outcome.data
    assert "response" in payload
    assert payload["params"]["max_tokens"] == 50
async def test_sample_multi_turn(sampling_client: Client):
    """sample_multi_turn should report a turn count matching the history sent."""
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]
    outcome = await sampling_client.call_tool(
        "sample_multi_turn", {"messages": conversation}
    )
    payload = outcome.data
    assert "response" in payload
    assert payload["turn_count"] == 3
async def test_sample_multi_turn_single_message(sampling_client: Client):
    """sample_multi_turn should report one turn for a one-message history."""
    conversation = [{"role": "user", "content": "Hello"}]
    outcome = await sampling_client.call_tool(
        "sample_multi_turn", {"messages": conversation}
    )
    payload = outcome.data
    assert "response" in payload
    assert payload["turn_count"] == 1