"""ACEManager 统计、评分与导入导出相关逻辑测试."""
import json
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from src.ace_manager import ACEManager
from src.config_manager import ConfigManager
class TestACEManagerStatsAndIO:
    """Tests for ACEManager statistics, rating and strategy import/export.

    Covers get_strategy_stats / rate_result / export_strategies /
    import_strategies, plus toggle_strategy / validate_strategies.
    """

    @pytest.fixture
    def config_manager(self, temp_config_dir):
        """Build a ConfigManager whose config file lives under the temp config dir."""
        return ConfigManager(config_path=temp_config_dir / ".graphitiace" / "config.json")

    def _make_client(self, connected=True):
        """Return a mock Graphiti client whose ``is_connected()`` returns *connected*."""
        client = MagicMock()
        client.is_connected.return_value = connected
        return client

    def _create_manager(self, config_manager, graphiti_client):
        """Create an enabled ACEManager with ``_initialize_ace`` patched out.

        Skipping real ACE initialization lets tests inject dependencies. A fake
        ``ace_agent`` is attached whose ``ask`` returns a minimal successful
        JSON payload, which rate_result consumes.
        """
        with patch.object(ACEManager, "_initialize_ace", lambda _self: None):
            manager = ACEManager(config_manager, graphiti_client)
        manager.enabled = True
        manager.ace_agent = MagicMock()
        manager.ace_agent.ask.return_value = json.dumps(
            {
                "success": True,
                "effective_strategies": [],
                "improvements": [],
                "lessons": [],
            }
        )
        return manager

    def _mock_session(self, graphiti_client):
        """Return the mock session yielded by ``with graphiti_client.driver.session()``."""
        session = MagicMock()
        graphiti_client.driver.session.return_value.__enter__.return_value = session
        return session

    # ----- get_strategy_stats -----

    def test_get_strategy_stats_global(self, config_manager):
        """Without tool_name, global stats are returned along with a by_tool breakdown."""
        graphiti_client = self._make_client()
        session = self._mock_session(graphiti_client)
        session.run.return_value.single.return_value = {
            "total_strategies": 3,
            "avg_success_rate": 0.75,
            "total_usage": 30,
            "total_success": 24,
            "total_failure": 6,
            "by_tool": [
                {
                    "tool_name": "search_entities",
                    "count": 2,
                    "avg_success_rate": 0.8,
                    "usage": 20,
                    "success": 16,
                    "failure": 4,
                },
                {
                    "tool_name": "add_episode",
                    "count": 1,
                    "avg_success_rate": 0.6,
                    "usage": 10,
                    "success": 8,
                    "failure": 2,
                },
            ],
        }
        manager = self._create_manager(config_manager, graphiti_client)

        stats = manager.get_strategy_stats()

        assert stats is not None
        assert stats["total_strategies"] == 3
        assert stats["avg_success_rate"] == 0.75
        assert stats["total_usage"] == 30
        assert stats["total_success"] == 24
        assert stats["total_failure"] == 6
        assert "by_tool" in stats
        # The by_tool list from the query is expected to be keyed by tool name.
        assert set(stats["by_tool"].keys()) == {"search_entities", "add_episode"}

    def test_get_strategy_stats_single_tool_no_data(self, config_manager):
        """With tool_name set and no matching data, None is returned."""
        graphiti_client = self._make_client()
        session = self._mock_session(graphiti_client)
        session.run.return_value.single.return_value = None
        manager = self._create_manager(config_manager, graphiti_client)

        stats = manager.get_strategy_stats(tool_name="search_entities")

        assert stats is None

    def test_get_strategy_stats_exception(self, config_manager):
        """A query exception results in None."""
        graphiti_client = self._make_client()
        session = self._mock_session(graphiti_client)
        session.run.side_effect = Exception("stats error")
        manager = self._create_manager(config_manager, graphiti_client)

        stats = manager.get_strategy_stats()

        assert stats is None

    # ----- rate_result + _update_skillbook_with_feedback -----

    def test_rate_result_success_high_rating(self, config_manager):
        """A high rating returns True and triggers the feedback update and async reload."""
        manager = self._create_manager(config_manager, self._make_client())

        with patch.object(manager, "_update_skillbook_with_feedback") as mock_update, patch.object(
            manager, "_reload_strategies_async"
        ) as mock_reload:
            ok = manager.rate_result(
                tool_name="search_entities",
                rating=5,
                feedback="非常好",
                context={"foo": "bar"},
            )

        assert ok is True
        mock_update.assert_called_once()
        mock_reload.assert_called_once()
        _, kwargs = mock_update.call_args
        assert kwargs["tool_name"] == "search_entities"
        assert kwargs["success"] is True
        # rating=5 should land at the top of the weight scale (expected 1.0).
        assert 0.8 <= kwargs["weight"] <= 1.0

    def test_rate_result_ace_disabled(self, config_manager):
        """When ACE is disabled, rate_result returns False immediately."""
        manager = self._create_manager(config_manager, self._make_client())
        manager.enabled = False
        manager.ace_agent = None

        ok = manager.rate_result(tool_name="search_entities", rating=4)

        assert ok is False

    def test_rate_result_exception(self, config_manager):
        """An exception from ace_agent.ask makes rate_result return False."""
        manager = self._create_manager(config_manager, self._make_client())
        manager.ace_agent.ask.side_effect = Exception("ask failed")

        ok = manager.rate_result(tool_name="search_entities", rating=4)

        assert ok is False

    # ----- export_strategies -----

    def test_export_strategies_no_data(self, config_manager, tmp_path):
        """With no strategies to export, None is returned.

        The warning the implementation is described as emitting is not asserted here.
        """
        manager = self._create_manager(config_manager, self._make_client())

        with patch.object(manager, "query_strategies", return_value=[]):
            result = manager.export_strategies(
                tool_name="search_entities", file_path=str(tmp_path / "out.json")
            )

        assert result is None

    def test_export_strategies_write_file(self, config_manager, tmp_path):
        """With strategies present, a JSON file is written and its details are returned."""
        manager = self._create_manager(config_manager, self._make_client())
        strategies = [
            {
                "tool_name": "search_entities",
                "content": "s1",
                "success_rate": 0.8,
                "usage_count": 10,
                "success_count": 8,
                "failure_count": 2,
                "arguments_hash": "hash1",
                "created_at": "2025-01-01T00:00:00",
                "updated_at": "2025-01-02T00:00:00",
            }
        ]
        out_file = tmp_path / "export.json"

        with patch.object(manager, "query_strategies", return_value=strategies):
            result = manager.export_strategies(tool_name="search_entities", file_path=str(out_file))

        assert result is not None
        assert result["file_path"] == str(out_file)
        assert result["count"] == 1
        assert result["size"] > 0
        assert out_file.exists()
        data = json.loads(out_file.read_text(encoding="utf-8"))
        assert data["strategies"][0]["tool_name"] == "search_entities"

    # ----- import_strategies -----

    def test_import_strategies_file_not_exists(self, config_manager, tmp_path):
        """A missing import file yields None."""
        manager = self._create_manager(config_manager, self._make_client())
        # Use a path inside tmp_path so the result does not depend on the current
        # working directory (a stray "non_exist.json" there would break the test).
        missing = tmp_path / "non_exist.json"
        assert not missing.exists()

        result = manager.import_strategies(file_path=str(missing))

        assert result is None

    def test_import_strategies_invalid_format(self, config_manager, tmp_path):
        """A file without a "strategies" field is rejected with None."""
        file_path = tmp_path / "invalid.json"
        file_path.write_text(json.dumps({"foo": "bar"}), encoding="utf-8")
        manager = self._create_manager(config_manager, self._make_client())

        result = manager.import_strategies(file_path=str(file_path))

        assert result is None

    def test_import_strategies_success_create_and_overwrite(self, config_manager, tmp_path):
        """A valid import counts created/overwritten entries and triggers an async reload."""
        graphiti_client = self._make_client()
        session = self._mock_session(graphiti_client)
        # First single() -> None (strategy absent); second -> a record (strategy exists).
        session.run.return_value.single.side_effect = [None, {"s": {"tool_name": "search_entities"}}]
        strategies_payload = {
            "strategies": [
                {
                    "tool_name": "search_entities",
                    "arguments_hash": "hash1",
                    "content": "s1",
                    "success_rate": 0.8,
                    "usage_count": 10,
                    "success_count": 8,
                    "failure_count": 2,
                },
                {
                    "tool_name": "search_entities",
                    "arguments_hash": "hash2",
                    "content": "s2",
                    "success_rate": 0.6,
                    "usage_count": 5,
                    "success_count": 3,
                    "failure_count": 2,
                },
            ]
        }
        file_path = tmp_path / "import.json"
        file_path.write_text(json.dumps(strategies_payload), encoding="utf-8")
        manager = self._create_manager(config_manager, graphiti_client)

        with patch.object(manager, "_reload_strategies_async") as mock_reload:
            result = manager.import_strategies(file_path=str(file_path), overwrite=True)

        assert result is not None
        assert result["count"] == 2
        assert result["created"] >= 1
        assert result["overwritten"] >= 1
        assert isinstance(result["errors"], list)
        mock_reload.assert_called_once()

    # ----- toggle_strategy / validate_strategies -----

    def test_toggle_strategy_with_hash(self, config_manager):
        """With arguments_hash given, only the matching strategies are toggled."""
        graphiti_client = self._make_client()
        session = self._mock_session(graphiti_client)
        session.run.return_value.single.return_value = {"count": 2}
        manager = self._create_manager(config_manager, graphiti_client)

        with patch.object(manager, "_reload_strategies_async") as mock_reload:
            result = manager.toggle_strategy("search_entities", arguments_hash="hash123", enabled=False)

        assert result["count"] == 2
        assert result["enabled"] is False
        mock_reload.assert_called_once()

    def test_toggle_strategy_without_hash(self, config_manager):
        """Without a hash, every strategy of the tool is updated."""
        graphiti_client = self._make_client()
        session = self._mock_session(graphiti_client)
        session.run.return_value.single.return_value = {"count": 5}
        manager = self._create_manager(config_manager, graphiti_client)

        with patch.object(manager, "_reload_strategies_async") as mock_reload:
            result = manager.toggle_strategy("search_entities", enabled=True)

        assert result["count"] == 5
        assert result["enabled"] is True
        mock_reload.assert_called_once()

    def test_toggle_strategy_not_connected(self, config_manager):
        """When Neo4j is not connected, toggle_strategy returns None."""
        manager = self._create_manager(config_manager, self._make_client(connected=False))

        assert manager.toggle_strategy("search_entities") is None

    def test_validate_strategies_success_with_issues(self, config_manager):
        """validate_strategies reports detailed counts and flags detected issues."""
        graphiti_client = self._make_client()
        session = self._mock_session(graphiti_client)
        session.run.return_value.single.return_value = {
            "total": 3,
            "enabled_count": 2,
            "disabled_count": 1,
            "missing_success_rate": 1,
            "missing_usage_count": 0,
            "avg_success_rate": 0.8,
        }
        manager = self._create_manager(config_manager, graphiti_client)

        result = manager.validate_strategies()

        assert result["valid"] is False
        assert result["total"] == 3
        assert len(result["issues"]) == 1

    def test_validate_strategies_exception(self, config_manager):
        """On a query exception, the result is invalid and carries an error field."""
        graphiti_client = self._make_client()
        session = self._mock_session(graphiti_client)
        session.run.side_effect = Exception("validate error")
        manager = self._create_manager(config_manager, graphiti_client)

        result = manager.validate_strategies()

        assert result["valid"] is False
        assert "validate error" in result["error"]