"""工具综合测试."""
import pytest
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from src.tools import get_tools, handle_tool_call
from src.config_manager import ConfigManager
from src.graphiti_client import GraphitiClient
class TestToolsComprehensive:
    """Comprehensive tests for tool registration (``get_tools``) and dispatch (``handle_tool_call``)."""

    @pytest.fixture
    def config_manager(self, temp_config_dir):
        """Create a ConfigManager whose config file lives under the temporary directory."""
        # `temp_config_dir` is not defined in this class; presumably supplied by conftest.py.
        return ConfigManager(config_path=temp_config_dir / ".graphitiace" / "config.json")

    @pytest.fixture
    def mock_client(self):
        """Create a GraphitiClient mock that reports itself as connected."""
        # spec= limits attribute *reads* to names GraphitiClient defines, so a misspelled
        # method name raises AttributeError instead of silently returning a child mock.
        client = MagicMock(spec=GraphitiClient)
        client.is_connected.return_value = True
        return client

    @staticmethod
    def _assert_tools_registered(config_manager, expected):
        """Assert that every name in ``expected`` is returned by ``get_tools``.

        Args:
            config_manager: The ConfigManager passed through to ``get_tools``.
            expected: Iterable of tool names that must be registered.

        All missing names are reported in one failure message instead of
        stopping at the first one.
        """
        registered = {tool.name for tool in get_tools(config_manager)}
        missing = sorted(set(expected) - registered)
        assert not missing, f"tools not registered: {missing}"

    def test_get_tools_returns_list(self, config_manager):
        """get_tools returns a non-empty list."""
        tools = get_tools(config_manager)
        assert isinstance(tools, list)
        assert len(tools) > 0

    def test_get_tools_contains_configure_tools(self, config_manager):
        """The configuration tools are registered."""
        self._assert_tools_registered(
            config_manager,
            [
                "configure_neo4j",
                "configure_api",
                "check_configuration",
                "reset_configuration",
            ],
        )

    def test_get_tools_contains_episode_tools(self, config_manager):
        """The episode tools are registered."""
        self._assert_tools_registered(
            config_manager,
            ["add_episode", "add_episodes_bulk", "delete_episode"],
        )

    def test_get_tools_contains_search_tools(self, config_manager):
        """The search tools are registered."""
        self._assert_tools_registered(
            config_manager,
            [
                "search_entities",
                "search_relationships",
                "semantic_search",
                "query_knowledge_graph",
            ],
        )

    def test_get_tools_contains_utility_tools(self, config_manager):
        """The utility / maintenance tools are registered."""
        self._assert_tools_registered(
            config_manager,
            [
                "export_graph_data",
                "import_graph_data",
                "get_statistics",
                "validate_data",
                "clean_orphaned_nodes",
                "rebuild_indexes",
                "get_cache_stats",
            ],
        )

    @pytest.mark.asyncio
    async def test_handle_tool_call_with_missing_client(self, config_manager):
        """Calling a client-backed tool without passing a client still yields a non-empty response."""
        result = await handle_tool_call(
            tool_name="add_episode",
            arguments={"content": "Test"},
            config_manager=config_manager,
        )
        assert result is not None
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_tool_call_with_client(self, config_manager, mock_client):
        """With a client, add_episode is forwarded to the client exactly once."""
        mock_client.add_episode = AsyncMock(return_value={
            "success": True,
            "message": "Episode 已添加",
            "episode_id": 1,
        })
        result = await handle_tool_call(
            tool_name="add_episode",
            arguments={"content": "Test"},
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert result is not None
        assert len(result) > 0
        mock_client.add_episode.assert_called_once()

    @pytest.mark.asyncio
    async def test_handle_tool_call_error_handling(self, config_manager, mock_client):
        """A client exception is turned into a response rather than propagated."""
        mock_client.add_episode = AsyncMock(side_effect=Exception("Test error"))
        # No pytest.raises here: the test requires handle_tool_call to swallow the
        # client error and still return a non-empty result.
        result = await handle_tool_call(
            tool_name="add_episode",
            arguments={"content": "Test"},
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert result is not None
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_all_configuration_tools(self, config_manager):
        """Every configuration tool returns a response when called with empty arguments."""
        config_tools = [
            "configure_neo4j",
            "configure_api",
            "check_configuration",
            "reset_configuration",
            "set_group_id",
        ]
        # The tools run in order against the same ConfigManager, so any state one
        # of them changes is visible to the ones that follow.
        for tool_name in config_tools:
            result = await handle_tool_call(
                tool_name=tool_name,
                arguments={},
                config_manager=config_manager,
            )
            # Name the tool so a failure pinpoints which one misbehaved.
            assert result is not None, f"{tool_name} returned None"

    @pytest.mark.asyncio
    async def test_all_episode_tools(self, config_manager, mock_client):
        """Every episode tool returns a response when backed by a mocked client."""
        mock_client.add_episode = AsyncMock(return_value={"success": True})
        mock_client.add_episodes_bulk = AsyncMock(return_value={"success": True, "total": 1})
        mock_client.delete_episode.return_value = {"success": True}
        episode_tools = [
            ("add_episode", {"content": "Test"}),
            ("add_episodes_bulk", {"episodes": [{"content": "Test"}]}),
            ("delete_episode", {"episode_id": 1}),
        ]
        for tool_name, args in episode_tools:
            result = await handle_tool_call(
                tool_name=tool_name,
                arguments=args,
                config_manager=config_manager,
                graphiti_client=mock_client,
            )
            assert result is not None, f"{tool_name} returned None"

    @pytest.mark.asyncio
    async def test_all_search_tools(self, config_manager, mock_client):
        """Every search tool returns a response when backed by a mocked client."""
        mock_client.search_entities.return_value = {"success": True, "results": []}
        mock_client.search_relationships.return_value = {"success": True, "results": []}
        mock_client.semantic_search = AsyncMock(return_value={"success": True, "results": []})
        mock_client.query_knowledge_graph.return_value = {"success": True, "results": []}
        search_tools = [
            ("search_entities", {"query": "test"}),
            ("search_relationships", {"query": "test"}),
            ("semantic_search", {"query": "test"}),
            ("query_knowledge_graph", {"query": "MATCH (n) RETURN n"}),
        ]
        for tool_name, args in search_tools:
            result = await handle_tool_call(
                tool_name=tool_name,
                arguments=args,
                config_manager=config_manager,
                graphiti_client=mock_client,
            )
            assert result is not None, f"{tool_name} returned None"