"""工具处理函数测试."""
import pytest
import asyncio
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from src.tools import handle_tool_call
from src.config_manager import ConfigManager
from src.graphiti_client import GraphitiClient
class TestToolsHandlers:
    """Tests for the tool dispatcher ``handle_tool_call``."""

    @pytest.fixture
    def config_manager(self, temp_config_dir):
        """Create a ConfigManager whose config file lives under ``temp_config_dir``."""
        return ConfigManager(config_path=temp_config_dir / ".graphitiace" / "config.json")

    @pytest.fixture
    def mock_client(self):
        """Create a GraphitiClient mock that reports itself as connected."""
        # spec= restricts the mock to attributes that exist on GraphitiClient,
        # so a typo in a method name fails loudly instead of silently passing.
        client = MagicMock(spec=GraphitiClient)
        client.is_connected.return_value = True
        return client

    @pytest.mark.asyncio
    async def test_handle_configure_neo4j(self, config_manager):
        """Configure the Neo4j connection through the tool."""
        result = await handle_tool_call(
            tool_name="configure_neo4j",
            arguments={
                "uri": "bolt://localhost:7687",
                "username": "neo4j",
                "password": "test",
            },
            config_manager=config_manager,
        )
        assert len(result) > 0
        assert result[0].text is not None

    @pytest.mark.asyncio
    async def test_handle_configure_api(self, config_manager):
        """Configure an API provider together with an API key."""
        result = await handle_tool_call(
            tool_name="configure_api",
            arguments={
                "provider": "openai",
                "api_key": "test_key",
            },
            config_manager=config_manager,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_configure_api_no_key(self, config_manager):
        """Configure an API provider without supplying an API key."""
        result = await handle_tool_call(
            tool_name="configure_api",
            arguments={
                "provider": "openai",
            },
            config_manager=config_manager,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_add_episode(self, config_manager, mock_client):
        """Add an episode through a connected client."""
        mock_client.add_episode = AsyncMock(return_value={
            "success": True,
            "message": "Episode 已添加",
            "episode_id": 1,
        })
        result = await handle_tool_call(
            tool_name="add_episode",
            arguments={
                "content": "Test episode",
                "metadata": {"type": "test"},
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0
        mock_client.add_episode.assert_called_once()

    @pytest.mark.asyncio
    async def test_handle_add_episode_no_client(self, config_manager):
        """Adding an episode with no client must produce an error message."""
        result = await handle_tool_call(
            tool_name="add_episode",
            arguments={
                "content": "Test episode",
            },
            config_manager=config_manager,
        )
        assert len(result) > 0
        # Accept either the "not initialized" or the generic "error" wording.
        assert "未初始化" in result[0].text or "错误" in result[0].text

    @pytest.mark.asyncio
    async def test_handle_search_entities(self, config_manager, mock_client):
        """Search entities through a connected client."""
        mock_client.search_entities.return_value = {
            "success": True,
            "message": "找到 1 个结果",
            "results": [{"name": "Test"}],
        }
        result = await handle_tool_call(
            tool_name="search_entities",
            arguments={
                "query": "test",
                "limit": 10,
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0
        mock_client.search_entities.assert_called_once()

    @pytest.mark.asyncio
    async def test_handle_search_relationships(self, config_manager, mock_client):
        """Search relationships through a connected client."""
        mock_client.search_relationships.return_value = {
            "success": True,
            "message": "找到 1 个关系",
            "results": [],
        }
        result = await handle_tool_call(
            tool_name="search_relationships",
            arguments={
                "query": "test",
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_query_knowledge_graph(self, config_manager, mock_client):
        """Run a raw knowledge-graph query through a connected client."""
        mock_client.query_knowledge_graph.return_value = {
            "success": True,
            "message": "查询成功",
            "results": [],
        }
        result = await handle_tool_call(
            tool_name="query_knowledge_graph",
            arguments={
                "query": "MATCH (n) RETURN n",
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_delete_episode(self, config_manager, mock_client):
        """Delete an episode by id through a connected client."""
        mock_client.delete_episode.return_value = {
            "success": True,
            "message": "已删除",
        }
        result = await handle_tool_call(
            tool_name="delete_episode",
            arguments={
                "episode_id": 1,
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_clear_graph(self, config_manager, mock_client):
        """Clear the graph with explicit confirmation."""
        mock_client.clear_graph.return_value = {
            "success": True,
            "message": "已清空",
        }
        result = await handle_tool_call(
            tool_name="clear_graph",
            arguments={
                "confirm": True,
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_query_by_time_range(self, config_manager, mock_client):
        """Query graph content within a time range given in days."""
        mock_client.query_by_time_range.return_value = {
            "success": True,
            "message": "找到结果",
            "results": [],
        }
        result = await handle_tool_call(
            tool_name="query_by_time_range",
            arguments={
                "days": 7,
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_semantic_search(self, config_manager, mock_client):
        """Run a semantic search through a connected client."""
        mock_client.semantic_search = AsyncMock(return_value={
            "success": True,
            "message": "找到结果",
            "results": [],
            "search_type": "semantic",
        })
        result = await handle_tool_call(
            tool_name="semantic_search",
            arguments={
                "query": "test",
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_add_episodes_bulk(self, config_manager, mock_client):
        """Add several episodes in a single bulk call."""
        mock_client.add_episodes_bulk = AsyncMock(return_value={
            "success": True,
            "message": "批量添加完成",
            "total": 2,
            "success_count": 2,
            "fail_count": 0,
            "results": [],
        })
        result = await handle_tool_call(
            tool_name="add_episodes_bulk",
            arguments={
                "episodes": [
                    {"content": "Test 1"},
                    {"content": "Test 2"},
                ],
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_export_graph_data(self, config_manager, mock_client):
        """Export graph data in JSON format."""
        mock_client.export_graph_data.return_value = {
            "success": True,
            "message": "导出成功",
            "data": {"nodes": [], "relationships": []},
        }
        result = await handle_tool_call(
            tool_name="export_graph_data",
            arguments={
                "format": "json",
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_get_statistics(self, config_manager, mock_client):
        """Fetch node/relationship/episode statistics."""
        mock_client.get_statistics.return_value = {
            "success": True,
            "message": "统计信息获取成功",
            "statistics": {
                "nodes": {"total": 10},
                "relationships": {"total": 5},
                "episodes": {"total": 3},
            },
        }
        result = await handle_tool_call(
            tool_name="get_statistics",
            arguments={},
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_import_graph_data(self, config_manager, mock_client):
        """Import graph data supplied as JSON."""
        mock_client.import_graph_data.return_value = {
            "success": True,
            "message": "导入成功",
            "imported_nodes": 5,
            "imported_relationships": 3,
        }
        result = await handle_tool_call(
            tool_name="import_graph_data",
            arguments={
                "data": {"nodes": [], "relationships": []},
                "format": "json",
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_validate_data(self, config_manager, mock_client):
        """Run data validation with every check enabled."""
        mock_client.validate_data.return_value = {
            "success": True,
            "message": "验证完成",
            "issues": [],
            "issue_count": 0,
        }
        result = await handle_tool_call(
            tool_name="validate_data",
            arguments={
                "check_orphaned": True,
                "check_duplicates": True,
                "check_integrity": True,
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_clean_orphaned_nodes(self, config_manager, mock_client):
        """Clean orphaned nodes with explicit confirmation."""
        mock_client.clean_orphaned_nodes.return_value = {
            "success": True,
            "message": "已清理",
            "deleted_count": 5,
        }
        result = await handle_tool_call(
            tool_name="clean_orphaned_nodes",
            arguments={
                "confirm": True,
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_clean_orphaned_nodes_no_confirm(self, config_manager, mock_client):
        """Cleaning orphaned nodes without confirmation must ask for it."""
        result = await handle_tool_call(
            tool_name="clean_orphaned_nodes",
            arguments={
                "confirm": False,
            },
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0
        # The response must mention confirmation, in either Chinese or English.
        assert "确认" in result[0].text or "confirm" in result[0].text.lower()

    @pytest.mark.asyncio
    async def test_handle_rebuild_indexes(self, config_manager, mock_client):
        """Rebuild/check indexes through a connected client."""
        mock_client.rebuild_indexes.return_value = {
            "success": True,
            "message": "索引检查完成",
            "indexes": ["index1", "index2"],
            "index_count": 2,
        }
        result = await handle_tool_call(
            tool_name="rebuild_indexes",
            arguments={},
            config_manager=config_manager,
            graphiti_client=mock_client,
        )
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_health_check(self, config_manager):
        """Run the health-check tool with the health check patched out."""
        # NOTE(review): this patches the attribute on ``src.health_check``. If
        # ``src.tools`` binds the name at import time (``from src.health_check
        # import health_check``), the patch will not intercept the call and the
        # real health check runs instead -- confirm the intended patch target.
        with patch('src.health_check.health_check') as mock_health_check:
            mock_health_check.return_value = {
                "status": "healthy",
                "checks": {},
            }
            result = await handle_tool_call(
                tool_name="health_check",
                arguments={},
                config_manager=config_manager,
            )
            assert len(result) > 0

    @pytest.mark.asyncio
    async def test_handle_invalid_tool(self, config_manager):
        """An unknown tool name must yield an 'unknown tool' message."""
        result = await handle_tool_call(
            tool_name="invalid_tool",
            arguments={},
            config_manager=config_manager,
        )
        assert len(result) > 0
        assert "未知工具" in result[0].text or "unknown" in result[0].text.lower()

    @pytest.mark.asyncio
    async def test_handle_tool_with_invalid_args(self, config_manager):
        """Calling a tool with missing required arguments must fail gracefully.

        Either outcome is accepted: ``handle_tool_call`` raises, or it returns
        a non-empty (error) response. Only the call itself sits inside ``try``:
        ``AssertionError`` is a subclass of ``Exception``, so asserting inside
        the ``try`` would swallow a failed assertion and the test could never
        fail.
        """
        try:
            result = await handle_tool_call(
                tool_name="configure_neo4j",
                arguments={},  # required parameters deliberately omitted
                config_manager=config_manager,
            )
        except Exception:
            # Raising on invalid arguments is an acceptable outcome.
            pass
        else:
            # Otherwise an error message must be returned.
            assert len(result) > 0