"""ACEManager 内部辅助方法测试."""
from unittest.mock import MagicMock, patch
import pytest
from src.ace_manager import ACEManager
from src.config_manager import ConfigManager
class TestACEManagerHelpers:
    """Cover ``_build_strategy_filter`` and the internal Skillbook update logic."""

    @pytest.fixture
    def config_manager(self, temp_config_dir):
        """Return a ConfigManager whose config file lives under *temp_config_dir*.

        NOTE(review): ``temp_config_dir`` is not defined in this file; it is
        presumably a shared fixture (e.g. from conftest.py) -- confirm.
        """
        return ConfigManager(config_path=temp_config_dir / ".graphitiace" / "config.json")

    def _create_manager(self, config_manager, graphiti_client=None):
        """Create an ACEManager with initialization disabled.

        ``_initialize_ace`` is replaced by a no-op while the manager is
        constructed, so the internal methods can be tested directly.

        Args:
            config_manager: The ConfigManager handed to the ACEManager.
            graphiti_client: Optional client double. When omitted, a
                ``MagicMock`` whose ``is_connected()`` returns ``True`` is used.

        Returns:
            An ACEManager with ``enabled`` forced to ``True`` and ``ace_agent``
            replaced by a ``MagicMock``.
        """
        if graphiti_client is None:
            graphiti_client = MagicMock()
            graphiti_client.is_connected.return_value = True
        with patch.object(ACEManager, "_initialize_ace", lambda self: None):
            manager = ACEManager(config_manager, graphiti_client)
        manager.enabled = True
        manager.ace_agent = MagicMock()
        return manager

    @staticmethod
    def _recorded_kwargs(mock_create):
        """Return the keyword arguments of the most recent call to *mock_create*.

        The mock is asserted to have been called first: ``call_args`` is
        ``None`` on an uncalled mock, so reading ``.kwargs`` directly would
        fail with an opaque ``AttributeError`` instead of a clear assertion
        message.
        """
        mock_create.assert_called()
        return mock_create.call_args.kwargs

    # ----- _build_strategy_filter -----

    def test_build_strategy_filter_all_conditions(self, config_manager):
        """Every where fragment should be present when all filters are passed."""
        manager = self._create_manager(config_manager)
        where_clause, params = manager._build_strategy_filter(
            group_id="group-1",
            tool_name="search_entities",
            success_rate_min=0.5,
            success_rate_max=0.9,
            usage_min=10,
            usage_max=200,
            tags=["ace", "neo4j"],
            enabled=False,
            latest_only=True,
        )

        assert "s.group_id = $group_id" in where_clause
        assert "coalesce(s.is_latest, true)" in where_clause
        assert "s.tool_name = $tool_name" in where_clause
        assert "coalesce(s.success_rate, 0.0) >= $success_rate_min" in where_clause
        assert "coalesce(s.success_rate, 0.0) <= $success_rate_max" in where_clause
        assert "coalesce(s.usage_count, 0) >= $usage_min" in where_clause
        assert "coalesce(s.usage_count, 0) <= $usage_max" in where_clause
        assert "ANY(tag IN $tags" in where_clause
        assert "s.enabled = $enabled_filter" in where_clause
        # Numeric parameters are compared approximately, so the check does not
        # depend on whether the filter builder keeps ints or converts to floats.
        assert params == {
            "group_id": "group-1",
            "tool_name": "search_entities",
            "success_rate_min": pytest.approx(0.5),
            "success_rate_max": pytest.approx(0.9),
            "usage_min": pytest.approx(10.0),
            "usage_max": pytest.approx(200.0),
            "tags": ["ace", "neo4j"],
            "enabled_filter": False,
        }

    def test_build_strategy_filter_without_latest_only(self, config_manager):
        """With latest_only=False there is no is_latest fragment and only group_id is returned."""
        manager = self._create_manager(config_manager)
        where_clause, params = manager._build_strategy_filter(group_id="group-2", latest_only=False)

        assert "coalesce(s.is_latest" not in where_clause
        assert where_clause == "s.group_id = $group_id"
        assert params == {"group_id": "group-2"}

    # ----- _update_skillbook / _update_skillbook_with_feedback -----

    def test_update_skillbook_success_branch(self, config_manager):
        """The success branch should pass success_increment=1 and failure_increment=0."""
        graphiti_client = MagicMock()
        graphiti_client.is_connected.return_value = True
        manager = self._create_manager(config_manager, graphiti_client)

        # ``hash`` is pinned inside src.ace_manager so the expected
        # arguments_hash value is deterministic.
        with patch.object(manager, "_create_strategy_version") as mock_create, patch(
            "src.ace_manager.hash", return_value=42
        ):
            manager._update_skillbook(
                tool_name="search_entities",
                arguments={"query": "python"},
                result={"success": True},
                reflection="ok",
                success=True,
            )

        mock_create.assert_called_once()
        kwargs = self._recorded_kwargs(mock_create)
        assert kwargs["arguments_hash"] == "42"
        assert kwargs["success_increment"] == pytest.approx(1.0)
        assert kwargs["failure_increment"] == pytest.approx(0.0)
        assert kwargs["usage_increment"] == pytest.approx(1.0)

    def test_update_skillbook_failure_branch(self, config_manager):
        """The failure branch should pass failure_increment=1 and success_increment=0."""
        graphiti_client = MagicMock()
        graphiti_client.is_connected.return_value = True
        manager = self._create_manager(config_manager, graphiti_client)

        with patch.object(manager, "_create_strategy_version") as mock_create, patch(
            "src.ace_manager.hash", return_value=100
        ):
            manager._update_skillbook(
                tool_name="search_entities",
                arguments={"query": "bad"},
                result={"success": False},
                reflection="oops",
                success=False,
            )

        kwargs = self._recorded_kwargs(mock_create)
        assert kwargs["arguments_hash"] == "100"
        assert kwargs["success_increment"] == pytest.approx(0.0)
        assert kwargs["failure_increment"] == pytest.approx(1.0)
        assert kwargs["usage_increment"] == pytest.approx(1.0)

    def test_update_skillbook_with_feedback_success(self, config_manager):
        """Successful weighted feedback should update the success count by the weight."""
        manager = self._create_manager(config_manager)

        with patch.object(manager, "_create_strategy_version") as mock_create, patch(
            "src.ace_manager.hash", return_value=7
        ):
            manager._update_skillbook_with_feedback(
                tool_name="search_entities",
                arguments={"query": "graph"},
                result={"rating": 5},
                reflection="great",
                success=True,
                weight=0.8,
            )

        kwargs = self._recorded_kwargs(mock_create)
        assert kwargs["arguments_hash"] == "7"
        assert kwargs["success_increment"] == pytest.approx(0.8)
        assert kwargs["failure_increment"] == pytest.approx(0.0)
        assert kwargs["usage_increment"] == pytest.approx(0.8)
        assert kwargs["rating"] == 5
        assert kwargs["feedback_increment"] == pytest.approx(1.0)

    def test_update_skillbook_with_feedback_failure(self, config_manager):
        """Failed weighted feedback should accumulate failure_increment by the weight."""
        manager = self._create_manager(config_manager)

        with patch.object(manager, "_create_strategy_version") as mock_create, patch(
            "src.ace_manager.hash", return_value=9
        ):
            manager._update_skillbook_with_feedback(
                tool_name="search_entities",
                arguments={"query": "graph"},
                result={"rating": 2},
                reflection="bad",
                success=False,
                weight=0.4,
            )

        kwargs = self._recorded_kwargs(mock_create)
        assert kwargs["arguments_hash"] == "9"
        assert kwargs["success_increment"] == pytest.approx(0.0)
        assert kwargs["failure_increment"] == pytest.approx(0.4)
        assert kwargs["usage_increment"] == pytest.approx(0.4)
        assert kwargs["rating"] == 2
        assert kwargs["feedback_increment"] == pytest.approx(1.0)