"""Tests for the tool catalog."""
import pytest
from src.catalog import ToolCatalog, ToolDefinition, InputSchema, ToolReference
@pytest.fixture
def empty_catalog():
    """Provide a fresh ToolCatalog with nothing registered in it."""
    catalog = ToolCatalog()
    return catalog
@pytest.fixture
def sample_tool():
    """Provide a ToolDefinition named ``test_tool`` with one required string parameter."""
    # Build the input schema up front so the definition below stays flat.
    schema = InputSchema(
        properties={"param1": {"type": "string"}},
        required=["param1"],
    )
    return ToolDefinition(
        name="test_tool",
        description="A test tool for testing",
        input_schema=schema,
        tags=["test"],
    )
class TestToolCatalog:
    """Tests for ToolCatalog class."""

    @staticmethod
    def _make_tools(count, description_prefix="Tool"):
        """Build ``count`` minimal tools named ``tool_0`` .. ``tool_{count - 1}``.

        Args:
            count: Number of tool definitions to create.
            description_prefix: Text placed before the index in each description.

        Returns:
            A list of ToolDefinition objects with empty input schemas.
        """
        return [
            ToolDefinition(
                name=f"tool_{i}",
                description=f"{description_prefix} {i}",
                input_schema=InputSchema(),
            )
            for i in range(count)
        ]

    def test_register_tool(self, empty_catalog, sample_tool):
        """Test registering a tool."""
        empty_catalog.register_tool(sample_tool)
        assert empty_catalog.count() == 1
        assert empty_catalog.get_tool("test_tool") is not None

    def test_register_tools_bulk(self, empty_catalog):
        """Test bulk tool registration."""
        empty_catalog.register_tools(self._make_tools(5, "Test tool"))
        assert empty_catalog.count() == 5
        # The count alone does not prove each tool is reachable by its name.
        for i in range(5):
            assert empty_catalog.get_tool(f"tool_{i}") is not None

    def test_remove_tool(self, empty_catalog, sample_tool):
        """Test removing a tool."""
        empty_catalog.register_tool(sample_tool)
        result = empty_catalog.remove_tool("test_tool")
        assert result is True
        assert empty_catalog.count() == 0
        # A removed tool must no longer be retrievable, not merely uncounted.
        assert empty_catalog.get_tool("test_tool") is None

    def test_remove_nonexistent_tool(self, empty_catalog):
        """Test removing a tool that doesn't exist."""
        result = empty_catalog.remove_tool("nonexistent")
        assert result is False

    def test_get_tool(self, empty_catalog, sample_tool):
        """Test getting a tool by name."""
        empty_catalog.register_tool(sample_tool)
        tool = empty_catalog.get_tool("test_tool")
        assert tool is not None
        assert tool.name == "test_tool"
        assert tool.description == "A test tool for testing"

    def test_get_nonexistent_tool(self, empty_catalog):
        """Test getting a tool that doesn't exist."""
        tool = empty_catalog.get_tool("nonexistent")
        assert tool is None

    def test_list_tools(self, empty_catalog):
        """Test listing all tools."""
        empty_catalog.register_tools(self._make_tools(3))
        listed = empty_catalog.list_tools()
        assert len(listed) == 3

    def test_get_tool_names(self, empty_catalog):
        """Test getting all tool names."""
        empty_catalog.register_tools(self._make_tools(3))
        names = empty_catalog.get_tool_names()
        assert set(names) == {"tool_0", "tool_1", "tool_2"}

    def test_clear(self, empty_catalog):
        """Test clearing all tools."""
        empty_catalog.register_tools(self._make_tools(3))
        empty_catalog.clear()
        assert empty_catalog.count() == 0
        # Cleared tools must also be unreachable by name.
        for i in range(3):
            assert empty_catalog.get_tool(f"tool_{i}") is None

    def test_update_callback(self, empty_catalog, sample_tool):
        """Test that update callbacks are called."""
        callback_called = []

        def callback(catalog):
            # Only the number of invocations is checked; the argument is ignored.
            callback_called.append(True)

        empty_catalog.on_update(callback)
        empty_catalog.register_tool(sample_tool)
        assert len(callback_called) == 1
class TestToolDefinition:
    """Tests for the conversion helpers on ToolDefinition."""

    def test_to_searchable_text(self, sample_tool):
        """Searchable text contains the name, a description fragment, and the parameter name."""
        searchable = sample_tool.to_searchable_text()
        for fragment in ("test_tool", "test tool for testing", "param1"):
            assert fragment in searchable

    def test_to_api_format(self, sample_tool):
        """The API payload carries name, description, schema and a true defer_loading flag."""
        payload = sample_tool.to_api_format()
        expected_fields = {
            "name": "test_tool",
            "description": "A test tool for testing",
        }
        for key, value in expected_fields.items():
            assert payload[key] == value
        assert "input_schema" in payload
        assert payload["defer_loading"] is True
class TestToolReference:
    """Tests for ToolReference serialization."""

    def test_to_dict(self):
        """A reference serializes to a ``tool_reference`` dict carrying the tool name."""
        expected = {"type": "tool_reference", "tool_name": "test_tool"}
        assert ToolReference(tool_name="test_tool").to_dict() == expected