"""Tests for the agent tool implementation."""
import json
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from mcp_claude_code.tools.agent.agent_tool import AgentTool
from mcp_claude_code.tools.common.base import BaseTool
from mcp_claude_code.tools.common.context import DocumentContext
from mcp_claude_code.tools.common.permissions import PermissionManager
class TestAgentTool:
"""Test cases for the AgentTool."""
    @pytest.fixture
    def document_context(self):
        """Create a test document context."""
        return MagicMock(spec=DocumentContext)

    @pytest.fixture
    def permission_manager(self):
        """Create a test permission manager."""
        return MagicMock(spec=PermissionManager)

    @pytest.fixture
    def mcp_context(self):
        """Create a test MCP context."""
        return MagicMock()

    @pytest.fixture
    def command_executor(self):
        """Create a test command executor."""
        return MagicMock()

    @pytest.fixture
    def agent_tool(self, document_context, permission_manager, command_executor, monkeypatch):
        """Create a test agent tool."""
        with patch("mcp_claude_code.tools.agent.agent_tool.litellm"):
            # Set the API key via monkeypatch so it is removed again after each test
            monkeypatch.setenv("OPENAI_API_KEY", "test_key")
            return AgentTool(document_context, permission_manager, command_executor)

    @pytest.fixture
    def agent_tool_with_params(self, document_context, permission_manager, command_executor):
        """Create a test agent tool with custom parameters."""
        with patch("mcp_claude_code.tools.agent.agent_tool.litellm"):
            return AgentTool(
                document_context=document_context,
                permission_manager=permission_manager,
                command_executor=command_executor,
                model="anthropic/claude-3-sonnet",
                api_key="test_anthropic_key",
                max_tokens=2000,
                max_iterations=40,
                max_tool_uses=150,
            )

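    # The mock tools below populate only the attributes the agent tool is
    # expected to read when building its tool schemas: name, description,
    # parameters, and required.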
    @pytest.fixture
    def mock_tools(self):
        """Create a list of mock tools."""
        tools = []
        for name in ["read_files", "search_content", "directory_tree"]:
            tool = MagicMock(spec=BaseTool)
            tool.name = name
            tool.description = f"Description for {name}"
            tool.parameters = {"properties": {}, "type": "object"}
            tool.required = []
            tools.append(tool)
        return tools

    def test_initialization(self, agent_tool):
        """Test agent tool initialization."""
        assert agent_tool.name == "dispatch_agent"
        assert "Launch an agent" in agent_tool.description
        assert agent_tool.required == ["prompt"]
        assert agent_tool.model_override is None
        assert agent_tool.api_key_override is None
        assert agent_tool.max_tokens_override is None
        assert agent_tool.max_iterations == 10
        assert agent_tool.max_tool_uses == 30

    def test_initialization_with_params(self, agent_tool_with_params):
        """Test agent tool initialization with custom parameters."""
        assert agent_tool_with_params.name == "dispatch_agent"
        assert agent_tool_with_params.model_override == "anthropic/claude-3-sonnet"
        assert agent_tool_with_params.api_key_override == "test_anthropic_key"
        assert agent_tool_with_params.max_tokens_override == 2000
        assert agent_tool_with_params.max_iterations == 40
        assert agent_tool_with_params.max_tool_uses == 150

    def test_parameters(self, agent_tool):
        """Test agent tool parameters."""
        params = agent_tool.parameters
        assert "prompt" in params["properties"]
        assert params["properties"]["prompt"]["type"] == "string"
        assert "Task for the agent to perform" in params["properties"]["prompt"]["description"]
        assert params["required"] == ["prompt"]

    def test_model_and_api_key_override(self, document_context, permission_manager, command_executor):
        """Test API key and model override functionality."""
        # Test with an Anthropic model and API key
        agent_tool = AgentTool(
            document_context=document_context,
            permission_manager=permission_manager,
            command_executor=command_executor,
            model="anthropic/claude-3-sonnet",
            api_key="test_anthropic_key",
        )
        assert agent_tool.model_override == "anthropic/claude-3-sonnet"
        assert agent_tool.api_key_override == "test_anthropic_key"

        # Test with an OpenAI model and API key
        agent_tool = AgentTool(
            document_context=document_context,
            permission_manager=permission_manager,
            command_executor=command_executor,
            model="openai/gpt-4o",
            api_key="test_openai_key",
        )
        assert agent_tool.model_override == "openai/gpt-4o"
        assert agent_tool.api_key_override == "test_openai_key"

        # Test with no model or API key
        agent_tool = AgentTool(
            document_context=document_context,
            permission_manager=permission_manager,
            command_executor=command_executor,
        )
        assert agent_tool.model_override is None
        assert agent_tool.api_key_override is None

    @pytest.mark.asyncio
    async def test_call_no_prompt(self, agent_tool, mcp_context):
        """Test agent tool call with no prompt."""
        # Mock the tool context
        tool_ctx = MagicMock()
        tool_ctx.error = AsyncMock()
        tool_ctx.info = AsyncMock()
        tool_ctx.set_tool_info = AsyncMock()

        with patch("mcp_claude_code.tools.agent.agent_tool.create_tool_context", return_value=tool_ctx):
            result = await agent_tool.call(ctx=mcp_context)

        assert "Error" in result
        assert "prompt" in result
        tool_ctx.error.assert_called_once()

    @pytest.mark.asyncio
    async def test_call_with_litellm_error(self, agent_tool, mcp_context):
        """Test agent tool call when litellm raises an error."""
        # Mock the tool context
        tool_ctx = MagicMock()
        tool_ctx.error = AsyncMock()
        tool_ctx.info = AsyncMock()
        tool_ctx.set_tool_info = AsyncMock()
        tool_ctx.get_tools = AsyncMock(return_value=[])

        # Mock litellm to raise an error
        with patch("mcp_claude_code.tools.agent.agent_tool.create_tool_context", return_value=tool_ctx):
            with patch("mcp_claude_code.tools.agent.agent_tool.litellm.completion", side_effect=RuntimeError("API key error")):
                with patch("mcp_claude_code.tools.agent.agent_tool.get_allowed_agent_tools", return_value=[]):
                    with patch("mcp_claude_code.tools.agent.agent_tool.convert_tools_to_openai_functions", return_value=[]):
                        with patch("mcp_claude_code.tools.agent.agent_tool.get_system_prompt", return_value="System prompt"):
                            result = await agent_tool.call(ctx=mcp_context, prompt="Test prompt")

        # We only check that an error is surfaced; the exact message may vary
        assert "Error" in result
        tool_ctx.error.assert_called()

    @pytest.mark.asyncio
    async def test_call_with_valid_prompt(self, agent_tool, mcp_context):
        """Test agent tool call with a valid prompt string."""
        # Mock the tool context
        tool_ctx = MagicMock()
        tool_ctx.set_tool_info = AsyncMock()
        tool_ctx.info = AsyncMock()
        tool_ctx.error = AsyncMock()
        tool_ctx.mcp_context = mcp_context

        # Mock the _execute_agent method
        with patch.object(agent_tool, "_execute_agent", AsyncMock(return_value="Agent result")):
            with patch("mcp_claude_code.tools.agent.agent_tool.create_tool_context", return_value=tool_ctx):
                result = await agent_tool.call(ctx=mcp_context, prompt="Test prompt")

        assert "Agent execution completed" in result
        assert "Agent result" in result
        tool_ctx.info.assert_called()

    @pytest.mark.asyncio
    async def test_call_with_empty_prompt(self, agent_tool, mcp_context):
        """Test agent tool call with an empty prompt string."""
        # Mock the tool context
        tool_ctx = MagicMock()
        tool_ctx.set_tool_info = AsyncMock()
        tool_ctx.info = AsyncMock()
        tool_ctx.error = AsyncMock()

        with patch("mcp_claude_code.tools.agent.agent_tool.create_tool_context", return_value=tool_ctx):
            # Test with an empty string
            result = await agent_tool.call(ctx=mcp_context, prompt="")

        assert "Error" in result
        assert "must be a non-empty string" in result
        tool_ctx.error.assert_called()

    @pytest.mark.asyncio
    async def test_call_with_invalid_type(self, agent_tool, mcp_context):
        """Test agent tool call with an invalid parameter type."""
        # Mock the tool context
        tool_ctx = MagicMock()
        tool_ctx.set_tool_info = AsyncMock()
        tool_ctx.info = AsyncMock()
        tool_ctx.error = AsyncMock()

        with patch("mcp_claude_code.tools.agent.agent_tool.create_tool_context", return_value=tool_ctx):
            # Test with an invalid type (number)
            result = await agent_tool.call(ctx=mcp_context, prompt=123)

        assert "Error" in result
        assert "Parameter 'prompt' must be a string" in result
        tool_ctx.error.assert_called()

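    # NOTE: The tests below set TEST_MODE, which is assumed to put the agent's
    # completion loop on a deterministic code path. It is set with monkeypatch
    # so the variable is removed again after each test.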
    @pytest.mark.asyncio
    async def test_execute_agent_with_tools_simple(self, agent_tool, mcp_context, mock_tools, monkeypatch):
        """Test _execute_agent_with_tools with a simple response."""
        # Mock the tool context
        tool_ctx = MagicMock()
        tool_ctx.info = AsyncMock()
        tool_ctx.error = AsyncMock()
        tool_ctx.mcp_context = mcp_context

        # Mock the OpenAI response
        mock_message = MagicMock()
        mock_message.content = "Simple result"
        mock_message.tool_calls = None
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]

        # Set test mode and mock litellm
        monkeypatch.setenv("TEST_MODE", "1")
        with patch("mcp_claude_code.tools.agent.agent_tool.litellm.completion", return_value=mock_response):
            # Execute the method
            result = await agent_tool._execute_agent_with_tools(
                "System prompt",
                "User prompt",
                mock_tools,
                [],  # openai_tools
                tool_ctx,
            )

        assert result == "Simple result"

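    # The next test simulates a full tool-call round trip: the first litellm
    # response requests a tool call, the second returns the final answer.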
    @pytest.mark.asyncio
    async def test_execute_agent_with_tools_tool_calls(self, agent_tool, mcp_context, mock_tools, monkeypatch):
        """Test _execute_agent_with_tools with tool calls."""
        # Mock the tool context
        tool_ctx = MagicMock()
        tool_ctx.info = AsyncMock()
        tool_ctx.error = AsyncMock()
        tool_ctx.mcp_context = mcp_context

        # Set up one of the mock tools
        mock_tool = mock_tools[0]
        mock_tool.call = AsyncMock(return_value="Tool result")

        # Create a tool call
        mock_tool_call = MagicMock()
        mock_tool_call.id = "call_123"
        mock_tool_call.function = MagicMock()
        mock_tool_call.function.name = mock_tool.name
        mock_tool_call.function.arguments = json.dumps({"param": "value"})

        # First response with a tool call
        first_message = MagicMock()
        first_message.content = None
        first_message.tool_calls = [mock_tool_call]
        first_choice = MagicMock()
        first_choice.message = first_message
        first_response = MagicMock()
        first_response.choices = [first_choice]

        # Second response with the final result
        second_message = MagicMock()
        second_message.content = "Final result"
        second_message.tool_calls = None
        second_choice = MagicMock()
        second_choice.message = second_message
        second_response = MagicMock()
        second_response.choices = [second_choice]

        # Set test mode and mock litellm
        monkeypatch.setenv("TEST_MODE", "1")
        with patch("mcp_claude_code.tools.agent.agent_tool.litellm.completion", side_effect=[first_response, second_response]):
            # Patch json.loads so the tool-call arguments decode to the expected values
            with patch.object(json, "loads", return_value={"param": "value"}):
                result = await agent_tool._execute_agent_with_tools(
                    "System prompt",
                    "User prompt",
                    mock_tools,
                    [{"type": "function", "function": {"name": mock_tool.name}}],  # openai_tools
                    tool_ctx,
                )

        assert result == "Final result"
        mock_tool.call.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_agent(self, agent_tool, mcp_context, mock_tools):
        """Test the _execute_agent method."""
        # Mock the tool context
        tool_ctx = MagicMock()
        tool_ctx.info = AsyncMock()
        tool_ctx.error = AsyncMock()
        tool_ctx.mcp_context = mcp_context

        # Create a test prompt
        test_prompt = "Task 1"

        # Mock the necessary dependencies
        with patch("mcp_claude_code.tools.agent.agent_tool.get_allowed_agent_tools", return_value=mock_tools):
            with patch("mcp_claude_code.tools.agent.agent_tool.get_system_prompt", return_value="System prompt"):
                with patch("mcp_claude_code.tools.agent.agent_tool.convert_tools_to_openai_functions", return_value=[]):
                    with patch.object(agent_tool, "_execute_agent_with_tools", AsyncMock(return_value="Result 1")):
                        result = await agent_tool._execute_agent(test_prompt, tool_ctx)

        # Check the result
        assert result == "Result 1"

    @pytest.mark.asyncio
    async def test_execute_agent_with_exception(self, agent_tool, mcp_context, mock_tools):
        """Test the _execute_agent method with an exception."""
        # Mock the tool context
        tool_ctx = MagicMock()
        tool_ctx.info = AsyncMock()
        tool_ctx.error = AsyncMock()
        tool_ctx.mcp_context = mcp_context

        # Create a test prompt
        test_prompt = "Task that fails"

        # Mock the necessary dependencies
        with patch("mcp_claude_code.tools.agent.agent_tool.get_allowed_agent_tools", return_value=mock_tools):
            with patch("mcp_claude_code.tools.agent.agent_tool.get_system_prompt", return_value="System prompt"):
                with patch("mcp_claude_code.tools.agent.agent_tool.convert_tools_to_openai_functions", return_value=[]):
                    with patch.object(agent_tool, "_execute_agent_with_tools", side_effect=Exception("Task failed")):
                        result = await agent_tool._execute_agent(test_prompt, tool_ctx)

        # Check the result format
        assert "Error" in result
        assert "Task failed" in result