"""ACE (Agentic Context Engineering) 管理器 - 集成到 MCP Server."""
from typing import Dict, Any, Optional, List
import json
import uuid
from datetime import datetime
from .config_manager import ConfigManager
from .logger import default_logger as logger
class ACEManager:
    """ACE manager embedded in the MCP server; learns tool-calling strategies."""

    def __init__(self, config_manager: ConfigManager, graphiti_client: Optional[Any] = None):
        """
        Create the manager and immediately try to bring up the ACE agent.

        Args:
            config_manager: Source of the API configuration and the active group id.
            graphiti_client: Optional Graphiti client used to persist the Skillbook
                (Strategy nodes) in Neo4j.
        """
        # Collaborators supplied by the caller.
        self.graphiti_client = graphiti_client
        self.config_manager = config_manager
        # Runtime state, filled in by _initialize_ace().
        self.enabled = False
        self.ace_agent = None
        self._last_api_config: Optional[Dict[str, Any]] = None
        # In-memory copy of the strategies loaded by _load_skillbook().
        self._cached_strategies: List[Dict[str, Any]] = []
        self._initialize_ace()
def _initialize_ace(self):
    """
    Create (or re-create) the ACE agent from the current API configuration.

    Steps:
    - Record a ``model_dump()`` snapshot of the API config so check_reload() can
      detect changes.
    - Export a configured API key into the provider's environment variable
      (openai / anthropic only).
    - Build an ``ACELiteLLM`` agent. A ``DeduplicationConfig`` is attached when the
      installed ace-framework provides one, in MCP and non-MCP mode alike.
    - Load the Skillbook and set ``self.enabled``.

    Nothing is raised: a missing ace-framework or any other failure is logged
    and leaves ``self.enabled`` False.
    """
    import os
    import sys
    from contextlib import ExitStack, redirect_stderr, redirect_stdout
    from io import StringIO

    try:
        api_config = self.config_manager.get_api_config()
        # Snapshot compared by check_reload().
        self._last_api_config = api_config.model_dump() if api_config else None

        # In MCP mode stdout carries JSON-RPC frames, so anything the ACE
        # framework prints while loading must be swallowed. A missing, closed
        # or pseudo stdin is treated as MCP mode.
        try:
            is_mcp_mode = not sys.stdin.isatty()
        except (AttributeError, ValueError, OSError):
            is_mcp_mode = True

        with ExitStack() as stack:
            if is_mcp_mode:
                stack.enter_context(redirect_stderr(StringIO()))
                stack.enter_context(redirect_stdout(StringIO()))

            from ace import ACELiteLLM

            # Default model when no API key is configured (LiteLLM defaults apply).
            model = "gpt-4o-mini"
            if api_config and api_config.api_key:
                # ACE/LiteLLM read provider credentials from the environment.
                env_var = {
                    "openai": "OPENAI_API_KEY",
                    "anthropic": "ANTHROPIC_API_KEY",
                }.get(api_config.provider)
                if env_var:
                    os.environ[env_var] = api_config.api_key
                model = api_config.model or model

            try:
                # ace-framework 0.6.0+ supports skill deduplication.
                from ace import DeduplicationConfig
            except ImportError:
                # Older ace-framework releases have no DeduplicationConfig.
                self.ace_agent = ACELiteLLM(model=model)
            else:
                self.ace_agent = ACELiteLLM(
                    model=model,
                    dedup_config=DeduplicationConfig(similarity_threshold=0.85),
                )

        self.enabled = True
        # Warm the strategy cache from Neo4j.
        self._load_skillbook()
        logger.info("ACE Manager 初始化成功")
    except ImportError:
        logger.warning("ace-framework 未安装,ACE 功能将被禁用")
        self.enabled = False
    except Exception as e:
        logger.error(f"初始化 ACE Manager 失败: {e}", exc_info=True)
        self.enabled = False
def check_reload(self):
    """
    Re-run ACE initialization when the API configuration has changed.

    Compares a ``model_dump()`` snapshot of the current API config (or None when
    there is none) with the snapshot recorded by the last ``_initialize_ace()``.
    """
    current = self.config_manager.get_api_config()
    snapshot = None if not current else current.model_dump()
    if snapshot == self._last_api_config:
        return
    logger.info("检测到 API 配置变更,重新初始化 ACE Manager...")
    self._initialize_ace()
def is_enabled(self) -> bool:
    """Return True only when initialization succeeded and an ACE agent exists."""
    if not self.enabled:
        return self.enabled
    return self.ace_agent is not None
def generate_tool_strategy(
    self,
    user_query: str,
    tool_name: str,
    context: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Ask the ACE agent for an optimized plan for one tool call.

    The prompt contains:
    - the user query and the tool name;
    - the current arguments (non-JSON values are stringified);
    - up to five cached historical strategies for the same tool.

    Args:
        user_query: The user's original request.
        tool_name: Name of the tool about to be called.
        context: Current call arguments / context.

    Returns:
        The first JSON object found in the agent's reply (requested keys:
        ``optimized_arguments``, ``suggested_tools``, ``reasoning``). Returns None
        when ACE is disabled, the call fails, or no JSON object can be parsed.
    """
    if not self.is_enabled():
        return None
    import re
    import sys
    from contextlib import redirect_stderr, redirect_stdout
    from io import StringIO

    try:
        relevant_strategies = self._get_relevant_strategies(tool_name)
        if relevant_strategies:
            parts = ["\n\n历史成功策略:\n"]
            for i, strategy in enumerate(relevant_strategies[:5], 1):
                # Stored metrics/content may be None; never let that break the prompt.
                content = str(strategy.get('content') or '')
                success_rate = float(strategy.get('success_rate') or 0)
                parts.append(f"{i}. 工具: {strategy.get('tool_name', 'N/A')}\n")
                parts.append(f" 成功率: {success_rate:.2%}\n")
                parts.append(f" 使用次数: {strategy.get('usage_count', 0)}\n")
                parts.append(f" 策略内容: {content[:200]}...\n\n")
            strategies_text = "".join(parts)
        else:
            strategies_text = "\n\n注意:暂无相关历史策略,将基于通用经验进行优化。\n"

        prompt = f"""
根据 Skillbook 中的策略,优化以下工具调用:
用户查询:{user_query}
工具名称:{tool_name}
当前参数:{json.dumps(context, ensure_ascii=False, indent=2, default=str)}
{strategies_text}
请基于历史经验,提供优化建议:
1. 是否需要调整参数?
2. 是否需要先调用其他工具?
3. 是否有更好的工具选择?
只返回 JSON 格式的建议,格式:
{{
"optimized_arguments": {{...}},
"suggested_tools": ["tool1", "tool2"],
"reasoning": "优化原因"
}}
"""
        # In MCP mode, stdout/stderr output from ACE would corrupt JSON-RPC.
        try:
            is_mcp_mode = not sys.stdin.isatty()
        except (AttributeError, ValueError, OSError):
            is_mcp_mode = True
        if is_mcp_mode:
            with redirect_stderr(StringIO()), redirect_stdout(StringIO()):
                response = self.ace_agent.ask(prompt)
        else:
            response = self.ace_agent.ask(prompt)

        text = response if isinstance(response, str) else str(response)
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed

        # The model may wrap the JSON in prose or a ```json fence. Decode from
        # each '{' in turn and return the first complete JSON object. A greedy
        # first-'{'-to-last-'}' slice breaks when the text has several braces.
        decoder = json.JSONDecoder()
        for match in re.finditer(r'\{', text):
            try:
                candidate, _ = decoder.raw_decode(text, match.start())
            except json.JSONDecodeError:
                continue
            if isinstance(candidate, dict):
                return candidate

        logger.warning(f"无法解析 ACE 响应为 JSON: {response}")
        return None
    except Exception as e:
        logger.error(f"生成工具策略失败: {e}", exc_info=True)
        return None
def _get_relevant_strategies(self, tool_name: str, limit: int = 5) -> List[Dict[str, Any]]:
"""
获取与指定工具相关的策略.
Args:
tool_name: 工具名称
limit: 返回策略数量限制
Returns:
相关策略列表
"""
if not hasattr(self, '_cached_strategies') or not self._cached_strategies:
return []
# 筛选相关策略(匹配工具名称)
relevant = [s for s in self._cached_strategies if s.get('tool_name') == tool_name]
# 按成功率和使用次数排序
relevant.sort(key=lambda x: (x.get('success_rate', 0), x.get('usage_count', 0)), reverse=True)
return relevant[:limit]
def reflect_on_result(
    self,
    tool_name: str,
    arguments: Dict[str, Any],
    result: Dict[str, Any],
    user_feedback: Optional[bool] = None
):
    """
    Have the ACE agent reflect on a finished tool call and record the outcome.

    Success rule: the call counts as successful when ``result['success']`` is
    truthy or ``result['message']`` contains '✅'.

    The agent's reflection and that outcome are passed to
    ``self._update_skillbook``, then the strategy cache is reloaded. Arguments
    and results that are not JSON-serializable are stringified in the prompt
    instead of aborting the reflection. Errors are logged and never propagated,
    so the caller's main flow is unaffected.

    Args:
        tool_name: Name of the tool that ran.
        arguments: Arguments the tool was called with.
        result: The tool's result dict.
        user_feedback: Optional explicit user verdict; only quoted in the prompt.
    """
    if not self.is_enabled():
        return
    import sys
    from contextlib import redirect_stderr, redirect_stdout
    from io import StringIO

    try:
        success = result.get('success', False) or '✅' in str(result.get('message', ''))

        prompt = f"""
分析以下工具调用的执行结果:
工具:{tool_name}
参数:{json.dumps(arguments, ensure_ascii=False, indent=2, default=str)}
结果:{json.dumps(result, ensure_ascii=False, indent=2, default=str)}
成功:{success}
用户反馈:{user_feedback if user_feedback is not None else '未知'}
请分析:
1. 这次执行是否成功?
2. 哪些策略有效?
3. 哪些策略需要改进?
4. 可以提取什么经验教训?
返回 JSON 格式的反思结果,包含:
- "success": 是否成功
- "effective_strategies": 有效的策略
- "improvements": 需要改进的地方
- "lessons": 经验教训
"""
        # In MCP mode, stdout/stderr output from ACE would corrupt JSON-RPC.
        try:
            is_mcp_mode = not sys.stdin.isatty()
        except (AttributeError, ValueError, OSError):
            is_mcp_mode = True
        if is_mcp_mode:
            with redirect_stderr(StringIO()), redirect_stdout(StringIO()):
                reflection = self.ace_agent.ask(prompt)
        else:
            reflection = self.ace_agent.ask(prompt)

        self._update_skillbook(tool_name, arguments, result, reflection, success)
        # Refresh the cached strategies (runs synchronously).
        self._reload_strategies_async()
    except Exception as e:
        error_msg = f"反思工具执行结果失败: {e}"
        logger.error(error_msg, exc_info=True)
        logger.debug(f"工具: {tool_name}, 成功: {result.get('success', False)}")
        # Deliberately swallowed: reflection must never break the main flow.
def _reload_strategies_async(self):
"""异步重新加载策略缓存."""
try:
# 简单重新加载,不阻塞
self._load_skillbook()
except Exception as e:
logger.debug(f"异步重新加载策略失败: {e}")
# ===== 批量策略管理 =====
def _build_strategy_filter(
self,
group_id: str,
tool_name: Optional[str] = None,
success_rate_min: Optional[float] = None,
success_rate_max: Optional[float] = None,
usage_min: Optional[float] = None,
usage_max: Optional[float] = None,
tags: Optional[List[str]] = None,
enabled: Optional[bool] = None,
latest_only: bool = True,
):
"""
构建批量操作的过滤条件片段和参数.
"""
conditions = ["s.group_id = $group_id"]
params: Dict[str, Any] = {"group_id": group_id}
if latest_only:
conditions.append("coalesce(s.is_latest, true)")
if tool_name:
conditions.append("s.tool_name = $tool_name")
params["tool_name"] = tool_name
if success_rate_min is not None:
conditions.append("coalesce(s.success_rate, 0.0) >= $success_rate_min")
params["success_rate_min"] = float(success_rate_min)
if success_rate_max is not None:
conditions.append("coalesce(s.success_rate, 0.0) <= $success_rate_max")
params["success_rate_max"] = float(success_rate_max)
if usage_min is not None:
conditions.append("coalesce(s.usage_count, 0) >= $usage_min")
params["usage_min"] = float(usage_min)
if usage_max is not None:
conditions.append("coalesce(s.usage_count, 0) <= $usage_max")
params["usage_max"] = float(usage_max)
if tags:
# 只要有一个标签命中即可
conditions.append("ANY(tag IN $tags WHERE tag IN coalesce(s.tags, []))")
params["tags"] = tags
if enabled is not None:
conditions.append("s.enabled = $enabled_filter")
params["enabled_filter"] = enabled
return " AND ".join(conditions), params
def _create_strategy_version(
self,
tool_name: str,
arguments_hash: str,
reflection: str,
success_increment: float,
failure_increment: float,
usage_increment: float,
rating: Optional[float] = None,
feedback_increment: float = 0.0
):
"""
为指定策略创建一个新的版本节点,保留历史记录.
"""
if not self.graphiti_client or not self.graphiti_client.is_connected():
return
try:
group_id = self.config_manager.get_group_id()
now = datetime.utcnow()
success_increment = float(success_increment or 0.0)
failure_increment = float(failure_increment or 0.0)
usage_increment = float(
usage_increment if usage_increment and usage_increment > 0 else success_increment + failure_increment
)
feedback_increment = float(feedback_increment or 0.0)
rating_sum_increment = (rating or 0.0) * feedback_increment if rating is not None else 0.0
with self.graphiti_client.driver.session() as session:
prev_record = session.run(
"""
OPTIONAL MATCH (s:Strategy {group_id: $group_id, tool_name: $tool_name, arguments_hash: $arguments_hash})
WHERE coalesce(s.is_latest, true)
RETURN s
ORDER BY coalesce(s.version, 0) DESC, s.updated_at DESC
LIMIT 1
""",
group_id=group_id,
tool_name=tool_name,
arguments_hash=arguments_hash
).single()
prev_node = prev_record["s"] if prev_record else None
prev_data: Dict[str, Any] = {}
prev_id = None
if prev_node:
prev_id = prev_node.id
prev_data = dict(prev_node)
if not prev_data.get("strategy_uuid"):
prev_uuid = str(uuid.uuid4())
session.run(
"MATCH (s:Strategy) WHERE id(s)=$id SET s.strategy_uuid=$uuid",
id=prev_id,
uuid=prev_uuid
)
prev_data["strategy_uuid"] = prev_uuid
prev_success = float(prev_data.get("success_count", 0) or 0.0)
prev_failure = float(prev_data.get("failure_count", 0) or 0.0)
prev_usage = float(prev_data.get("usage_count", 0) or 0.0)
prev_feedback_count = float(prev_data.get("user_feedback_count", 0) or 0.0)
new_success = prev_success + success_increment
new_failure = prev_failure + failure_increment
new_usage = prev_usage + usage_increment
success_rate = (new_success / new_usage) if new_usage > 0 else 0.0
new_feedback_count = prev_feedback_count + feedback_increment
if feedback_increment > 0 and rating is not None:
total_rating = (prev_data.get("avg_rating", 0) or 0.0) * prev_feedback_count + rating * feedback_increment
avg_rating = (total_rating / new_feedback_count) if new_feedback_count > 0 else rating
else:
avg_rating = prev_data.get("avg_rating")
params = {
"strategy_uuid": str(uuid.uuid4()),
"parent_id": prev_data.get("strategy_uuid"),
"version": int((prev_data.get("version") or 0) + 1),
"group_id": group_id,
"tool_name": tool_name,
"arguments_hash": arguments_hash,
"content": reflection,
"success_count": new_success,
"failure_count": new_failure,
"usage_count": new_usage,
"success_rate": success_rate,
"created_at": now,
"updated_at": now,
"enabled": prev_data.get("enabled", True),
"is_latest": True,
"user_feedback_count": new_feedback_count if new_feedback_count > 0 else 0.0,
"avg_rating": avg_rating,
}
session.run(
"""
CREATE (s:Strategy {
strategy_uuid: $strategy_uuid,
parent_id: $parent_id,
version: $version,
group_id: $group_id,
tool_name: $tool_name,
arguments_hash: $arguments_hash,
content: $content,
success_count: $success_count,
failure_count: $failure_count,
usage_count: $usage_count,
success_rate: $success_rate,
created_at: $created_at,
updated_at: $updated_at,
enabled: $enabled,
is_latest: $is_latest,
user_feedback_count: $user_feedback_count,
avg_rating: $avg_rating
})
""",
params
)
if prev_id is not None:
session.run(
"MATCH (s:Strategy) WHERE id(s)=$id SET s.is_latest = false",
id=prev_id
)
self._record_strategy_trend(
tool_name=tool_name,
arguments_hash=arguments_hash,
success_increment=success_increment,
failure_increment=failure_increment,
usage_increment=usage_increment,
feedback_increment=feedback_increment,
rating_sum_increment=rating_sum_increment,
timestamp=now
)
except Exception as e:
logger.error(f"创建策略版本失败: {e}", exc_info=True)
def _record_strategy_trend(
self,
tool_name: str,
arguments_hash: str,
success_increment: float,
failure_increment: float,
usage_increment: float,
feedback_increment: float,
rating_sum_increment: float,
timestamp: datetime
):
"""记录策略的学习趋势."""
if not self.graphiti_client or not self.graphiti_client.is_connected():
return
try:
group_id = self.config_manager.get_group_id()
trend_date = timestamp.date()
params = {
"group_id": group_id,
"tool_name": tool_name,
"arguments_hash": arguments_hash,
"date": trend_date,
"success_increment": success_increment,
"failure_increment": failure_increment,
"usage_increment": usage_increment,
"feedback_increment": feedback_increment,
"rating_sum_increment": rating_sum_increment,
}
query = """
MERGE (t:StrategyTrend {
group_id: $group_id,
tool_name: $tool_name,
arguments_hash: $arguments_hash,
date: $date
})
ON CREATE SET
t.success_count = $success_increment,
t.failure_count = $failure_increment,
t.usage_count = $usage_increment,
t.feedback_count = $feedback_increment,
t.rating_sum = $rating_sum_increment,
t.sample_count = 1,
t.created_at = datetime(),
t.updated_at = datetime()
ON MATCH SET
t.success_count = coalesce(t.success_count, 0) + $success_increment,
t.failure_count = coalesce(t.failure_count, 0) + $failure_increment,
t.usage_count = coalesce(t.usage_count, 0) + $usage_increment,
t.feedback_count = coalesce(t.feedback_count, 0) + $feedback_increment,
t.rating_sum = coalesce(t.rating_sum, 0) + $rating_sum_increment,
t.sample_count = coalesce(t.sample_count, 0) + 1,
t.updated_at = datetime()
"""
with self.graphiti_client.driver.session() as session:
session.run(query, params)
except Exception as e:
logger.debug(f"记录策略趋势失败: {e}")
def query_strategies(self, tool_name: Optional[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
    """
    Return the latest enabled strategies of the current group, best first.

    Ordering is by success rate, then usage count, both descending. Missing
    metrics rank as 0; otherwise Cypher's DESC ordering would put nulls first.

    Args:
        tool_name: When truthy, only strategies for this tool are returned.
        limit: Maximum number of rows; negative values are treated as 0.

    Returns:
        List of strategy summaries; [] when not connected or on error.
    """
    if not self.graphiti_client or not self.graphiti_client.is_connected():
        return []
    try:
        group_id = self.config_manager.get_group_id()
        # Cypher rejects a negative LIMIT; clamp rather than fail the query.
        params: Dict[str, Any] = {"group_id": group_id, "limit": max(0, int(limit))}
        if tool_name:
            pattern = "(s:Strategy {group_id: $group_id, tool_name: $tool_name})"
            params["tool_name"] = tool_name
        else:
            pattern = "(s:Strategy {group_id: $group_id})"
        query = f"""
        MATCH {pattern}
        WHERE s.enabled = true AND coalesce(s.is_latest, true)
        RETURN s
        ORDER BY coalesce(s.success_rate, 0.0) DESC, coalesce(s.usage_count, 0) DESC
        LIMIT $limit
        """

        strategies = []
        with self.graphiti_client.driver.session() as session:
            result = session.run(query, params)
            for record in result:
                strategy_node = record["s"]
                strategies.append({
                    "tool_name": strategy_node.get("tool_name", ""),
                    "content": strategy_node.get("content", ""),
                    "success_rate": strategy_node.get("success_rate", 0.0),
                    "usage_count": strategy_node.get("usage_count", 0),
                    "success_count": strategy_node.get("success_count", 0),
                    "failure_count": strategy_node.get("failure_count", 0),
                    "arguments_hash": strategy_node.get("arguments_hash", ""),
                    "created_at": str(strategy_node.get("created_at", "")),
                    "updated_at": str(strategy_node.get("updated_at", "")),
                })
        return strategies
    except Exception as e:
        logger.error(f"查询策略失败: {e}", exc_info=True)
        return []
def list_strategy_versions(
    self,
    tool_name: str,
    arguments_hash: Optional[str] = None,
    limit: int = 20
) -> List[Dict[str, Any]]:
    """
    List the stored versions of a strategy, newest version first.

    All versions (latest or not) for ``tool_name`` in the current group are
    included, optionally narrowed to one ``arguments_hash``. Returns [] when
    not connected or on error.
    """
    client = self.graphiti_client
    if not client or not client.is_connected():
        return []
    try:
        query = """
        MATCH (s:Strategy {group_id: $group_id, tool_name: $tool_name})
        WHERE ($arguments_hash IS NULL OR s.arguments_hash = $arguments_hash)
        RETURN s
        ORDER BY coalesce(s.version, 1) DESC, s.updated_at DESC
        LIMIT $limit
        """
        params = {
            "group_id": self.config_manager.get_group_id(),
            "tool_name": tool_name,
            "arguments_hash": arguments_hash,
            "limit": limit,
        }
        with client.driver.session() as session:
            nodes = (record["s"] for record in session.run(query, params))
            versions: List[Dict[str, Any]] = [
                {
                    "version": node.get("version", 1),
                    "strategy_uuid": node.get("strategy_uuid"),
                    "parent_id": node.get("parent_id"),
                    "is_latest": node.get("is_latest", True),
                    "content": node.get("content", ""),
                    "success_rate": node.get("success_rate", 0.0),
                    "usage_count": node.get("usage_count", 0),
                    "success_count": node.get("success_count", 0),
                    "failure_count": node.get("failure_count", 0),
                    "enabled": node.get("enabled", True),
                    "updated_at": str(node.get("updated_at", "")),
                    "created_at": str(node.get("created_at", "")),
                    "arguments_hash": node.get("arguments_hash", ""),
                }
                for node in nodes
            ]
        return versions
    except Exception as e:
        logger.error(f"查询策略版本失败: {e}", exc_info=True)
        return []
def get_learning_trends(
    self,
    tool_name: Optional[str] = None,
    arguments_hash: Optional[str] = None,
    days: int = 30
) -> List[Dict[str, Any]]:
    """
    Return daily StrategyTrend rows for the last ``days`` days, oldest first.

    Filtering:
    - ``days`` is clamped to 1..180.
    - ``tool_name`` / ``arguments_hash`` narrow the rows when given.

    Each row has the raw counters plus two derived values:
    - success rate = success / usage (0.0 when usage is 0);
    - average rating = rating_sum / feedback_count (None without feedback).

    Returns [] when not connected or on error.
    """
    client = self.graphiti_client
    if not client or not client.is_connected():
        return []
    try:
        group_id = self.config_manager.get_group_id()
        query = """
        MATCH (t:StrategyTrend {group_id: $group_id})
        WHERE ($tool_name IS NULL OR t.tool_name = $tool_name)
          AND ($arguments_hash IS NULL OR t.arguments_hash = $arguments_hash)
          AND t.date >= date() - duration({days: $days})
        RETURN t
        ORDER BY t.date ASC
        """
        params = {
            "group_id": group_id,
            "tool_name": tool_name,
            "arguments_hash": arguments_hash,
            "days": max(1, min(days, 180)),
        }

        def _num(node: Any, key: str) -> float:
            # Missing or null counters count as 0.
            return float(node.get(key, 0) or 0.0)

        trends: List[Dict[str, Any]] = []
        with client.driver.session() as session:
            for record in session.run(query, params):
                node = record["t"]
                usage = _num(node, "usage_count")
                success = _num(node, "success_count")
                feedback_count = _num(node, "feedback_count")
                rating_sum = _num(node, "rating_sum")
                trends.append({
                    "date": str(node.get("date")),
                    "tool_name": node.get("tool_name", ""),
                    "arguments_hash": node.get("arguments_hash", ""),
                    "usage_count": usage,
                    "success_count": success,
                    "failure_count": _num(node, "failure_count"),
                    "success_rate": success / usage if usage > 0 else 0.0,
                    "feedback_count": feedback_count,
                    "avg_rating": rating_sum / feedback_count if feedback_count > 0 else None,
                    "sample_count": int(node.get("sample_count", 0) or 0),
                })
        return trends
    except Exception as e:
        logger.error(f"获取学习趋势失败: {e}", exc_info=True)
        return []
def get_strategy_alerts(
    self,
    failure_threshold: int = 5,
    days: int = 1,
    tool_name: Optional[str] = None,
    arguments_hash: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    Build a simple alert list from StrategyTrend data.

    An alert is raised for every (tool_name, arguments_hash) whose summed
    failures over the last ``days`` days (clamped to 1..30) reach
    ``failure_threshold`` (at least 1).

    Severity:
    - "critical" at 4x the threshold or more;
    - "warning" at 2x or more;
    - "info" otherwise.

    This is a cumulative-count heuristic; finer time-series monitoring (e.g.
    via Opik) could replace it later. Returns [] when not connected or on error.
    """
    client = self.graphiti_client
    if not client or not client.is_connected():
        return []
    try:
        group_id = self.config_manager.get_group_id()
        window_days = max(1, min(days, 30))
        threshold = max(1, failure_threshold)
        query = """
        MATCH (t:StrategyTrend {group_id: $group_id})
        WHERE ($tool_name IS NULL OR t.tool_name = $tool_name)
          AND ($arguments_hash IS NULL OR t.arguments_hash = $arguments_hash)
          AND t.date >= date() - duration({days: $days})
        WITH t.tool_name as tool_name,
             t.arguments_hash as arguments_hash,
             sum(coalesce(t.failure_count, 0)) as failures,
             sum(coalesce(t.usage_count, 0)) as usage,
             max(t.date) as last_date
        WHERE failures >= $failure_threshold
        RETURN tool_name, arguments_hash, failures, usage, last_date
        ORDER BY failures DESC, usage DESC
        """
        params = {
            "group_id": group_id,
            "tool_name": tool_name,
            "arguments_hash": arguments_hash,
            "days": window_days,
            "failure_threshold": threshold,
        }

        alerts: List[Dict[str, Any]] = []
        with client.driver.session() as session:
            for record in session.run(query, params):
                failures = float(record.get("failures", 0) or 0.0)
                usage = float(record.get("usage", 0) or 0.0)
                if failures >= threshold * 4:
                    severity = "critical"
                elif failures >= threshold * 2:
                    severity = "warning"
                else:
                    severity = "info"
                alerts.append({
                    "tool_name": record.get("tool_name", ""),
                    "arguments_hash": record.get("arguments_hash", ""),
                    "failures": failures,
                    "usage": usage,
                    "failure_rate": failures / usage if usage > 0 else 0.0,
                    "last_date": str(record.get("last_date")),
                    "severity": severity,
                })
        return alerts
    except Exception as e:
        logger.error(f"获取策略告警失败: {e}", exc_info=True)
        return []
def bulk_update_strategies(
    self,
    action: str,
    tool_name: Optional[str] = None,
    success_rate_min: Optional[float] = None,
    success_rate_max: Optional[float] = None,
    usage_min: Optional[float] = None,
    usage_max: Optional[float] = None,
    tags: Optional[List[str]] = None,
    enabled_filter: Optional[bool] = None,
    limit: int = 100,
) -> Optional[Dict[str, Any]]:
    """
    Enable, disable or delete up to ``limit`` (clamped to 1..1000) strategies.

    ``action`` is case-insensitive and must be one of enable / disable / delete.
    Only latest-version strategies matching the filters are affected. The
    strategy cache is refreshed afterwards.

    Returns:
        ``{"action", "affected", "limit"}``, or None when not connected, when
        the action is invalid, or on error.
    """
    client = self.graphiti_client
    if not client or not client.is_connected():
        return None
    op = (action or "").lower()
    if op not in {"enable", "disable", "delete"}:
        logger.warning(f"无效的批量操作类型: {op}")
        return None
    try:
        where_clause, params = self._build_strategy_filter(
            group_id=self.config_manager.get_group_id(),
            tool_name=tool_name,
            success_rate_min=success_rate_min,
            success_rate_max=success_rate_max,
            usage_min=usage_min,
            usage_max=usage_max,
            tags=tags,
            enabled=enabled_filter,
            latest_only=True,
        )
        params["limit"] = max(1, min(limit, 1000))

        if op == "delete":
            # NOTE(review): latest_only=True means only the latest version node
            # is deleted; older versions (is_latest = false) stay in the graph.
            # Confirm that history should be kept.
            query = f"""
            MATCH (s:Strategy)
            WHERE {where_clause}
            WITH s LIMIT $limit
            WITH collect(s) as targets
            FOREACH (node IN targets | DETACH DELETE node)
            RETURN size(targets) as affected
            """
        else:
            params["target_enabled"] = op == "enable"
            query = f"""
            MATCH (s:Strategy)
            WHERE {where_clause}
            WITH s LIMIT $limit
            SET s.enabled = $target_enabled,
                s.updated_at = datetime()
            RETURN count(s) as affected
            """

        with client.driver.session() as session:
            record = session.run(query, params).single()
        affected = record.get("affected", 0) if record else 0

        # Keep the in-memory cache in sync with the new state.
        self._reload_strategies_async()
        return {"action": op, "affected": affected, "limit": params["limit"]}
    except Exception as e:
        logger.error(f"批量更新策略失败: {e}", exc_info=True)
        return None
def bulk_export_strategies(
    self,
    tool_name: Optional[str] = None,
    success_rate_min: Optional[float] = None,
    success_rate_max: Optional[float] = None,
    usage_min: Optional[float] = None,
    usage_max: Optional[float] = None,
    tags: Optional[List[str]] = None,
    enabled_filter: Optional[bool] = None,
    limit: int = 500,
    file_path: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """
    Export strategies (all versions) that match the filters, newest first.

    Up to ``limit`` (clamped to 1..5000) nodes are exported as plain property
    dicts.
    - With ``file_path``, a JSON payload (filters, group id, strategies) is
      written there and the returned ``strategies`` list is empty.
    - Without it, the strategies are returned inline.

    Returns None when not connected or on error.
    """
    client = self.graphiti_client
    if not client or not client.is_connected():
        return None
    try:
        group_id = self.config_manager.get_group_id()
        where_clause, params = self._build_strategy_filter(
            group_id=group_id,
            tool_name=tool_name,
            success_rate_min=success_rate_min,
            success_rate_max=success_rate_max,
            usage_min=usage_min,
            usage_max=usage_max,
            tags=tags,
            enabled=enabled_filter,
            latest_only=False,
        )
        params["limit"] = max(1, min(limit, 5000))
        query = f"""
        MATCH (s:Strategy)
        WHERE {where_clause}
        RETURN s
        ORDER BY coalesce(s.updated_at, s.created_at, datetime()) DESC
        LIMIT $limit
        """
        with client.driver.session() as session:
            strategies: List[Dict[str, Any]] = [dict(record["s"]) for record in session.run(query, params)]

        if not strategies:
            return {"count": 0, "file_path": file_path, "strategies": []}
        if not file_path:
            return {"count": len(strategies), "file_path": None, "strategies": strategies}

        import os
        export_payload = {
            "version": "bulk-1.0",
            "group_id": group_id,
            "strategy_count": len(strategies),
            "filters": {
                "tool_name": tool_name,
                "success_rate_min": success_rate_min,
                "success_rate_max": success_rate_max,
                "usage_min": usage_min,
                "usage_max": usage_max,
                "tags": tags,
                "enabled": enabled_filter,
            },
            "strategies": strategies,
        }
        os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as fh:
            # default=str stringifies driver temporal values and similar types.
            json.dump(export_payload, fh, ensure_ascii=False, indent=2, default=str)
        return {"count": len(strategies), "file_path": file_path, "strategies": []}
    except Exception as e:
        logger.error(f"批量导出策略失败: {e}", exc_info=True)
        return None
def get_strategy_heatmap(
    self,
    limit: int = 20,
    group_by: str = "tool"
) -> Dict[str, Any]:
    """
    Aggregate latest enabled strategies into success-rate buckets for a heatmap.

    Each entry is one (tool_name, bucket) pair, where bucket is '0-40%',
    '40-70%' or '70-100%'. Entries carry the strategy count, total usage, mean
    success rate and last update. The summary regroups entries by tool or by
    bucket; each group's success rate is weighted by strategy count.

    Args:
        limit: Maximum number of entries (clamped to 1..100), highest usage first.
        group_by: "tool" or "bucket"; anything else falls back to "tool".

    Returns:
        ``{"success": True, "generated_at", "entries", "summary"}`` on success,
        otherwise ``{"success": False, "message"}``.
    """
    if not self.graphiti_client or not self.graphiti_client.is_connected():
        return {
            "success": False,
            "message": "Graphiti 客户端未连接,无法生成策略热力图。"
        }
    from datetime import timezone

    try:
        group_id = self.config_manager.get_group_id()
        limit_value = max(1, min(limit, 100))
        normalized_group_by = (group_by or "tool").lower()
        if normalized_group_by not in {"tool", "bucket"}:
            normalized_group_by = "tool"

        query = """
        MATCH (s:Strategy {group_id: $group_id})
        WHERE s.enabled = true AND coalesce(s.is_latest, true)
        WITH
            s.tool_name as tool_name,
            coalesce(s.success_rate, 0.0) as success_rate,
            coalesce(s.usage_count, 0) as usage_count,
            coalesce(s.updated_at, s.created_at, datetime()) as last_updated
        WITH
            tool_name,
            CASE
                WHEN success_rate < 0.4 THEN '0-40%'
                WHEN success_rate < 0.7 THEN '40-70%'
                ELSE '70-100%'
            END as bucket,
            success_rate,
            usage_count,
            last_updated
        RETURN
            tool_name,
            bucket,
            count(*) as strategy_count,
            sum(usage_count) as total_usage,
            avg(success_rate) as avg_success_rate,
            toString(max(last_updated)) as last_updated
        ORDER BY total_usage DESC
        LIMIT $limit
        """
        params = {"group_id": group_id, "limit": limit_value}

        entries: List[Dict[str, Any]] = []
        bucket_order = ["0-40%", "40-70%", "70-100%"]
        total_usage = 0
        total_strategies = 0
        summary_map: Dict[str, Dict[str, Any]] = {}

        with self.graphiti_client.driver.session() as session:
            result = session.run(query, params)
            for record in result:
                # A null property comes back as None (the key is present), so a
                # .get() default would never apply; fall back explicitly.
                tool_name = record.get("tool_name") or "未知工具"
                bucket = record.get("bucket") or "0-40%"
                strategy_count = int(record.get("strategy_count", 0) or 0)
                usage_count = int(record.get("total_usage", 0) or 0)
                avg_success = float(record.get("avg_success_rate", 0.0) or 0.0)

                entries.append({
                    "tool_name": tool_name,
                    "bucket": bucket,
                    "strategy_count": strategy_count,
                    "total_usage": usage_count,
                    "avg_success_rate": avg_success,
                    "last_updated": record.get("last_updated")
                })
                total_usage += usage_count
                total_strategies += strategy_count

                summary_key = tool_name if normalized_group_by == "tool" else bucket
                summary_entry = summary_map.setdefault(summary_key, {
                    "usage": 0,
                    "strategies": 0,
                    "success_weight": 0.0
                })
                summary_entry["usage"] += usage_count
                summary_entry["strategies"] += strategy_count
                # Weight each entry's mean by its strategy count.
                summary_entry["success_weight"] += avg_success * max(strategy_count, 1)

        groups: List[Dict[str, Any]] = []
        for key, data in summary_map.items():
            strategies = data["strategies"]
            groups.append({
                "name": key,
                "total_usage": data["usage"],
                "strategies": strategies,
                "avg_success_rate": data["success_weight"] / strategies if strategies else 0.0
            })
        groups.sort(key=lambda item: item["total_usage"], reverse=True)

        return {
            "success": True,
            # Naive UTC ISO string (same format datetime.utcnow().isoformat() gave).
            "generated_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            "entries": entries,
            "summary": {
                "total_entries": len(entries),
                "total_usage": total_usage,
                "total_strategies": total_strategies,
                "unique_tools": len({entry["tool_name"] for entry in entries}),
                "group_by": normalized_group_by,
                "groups": groups,
                "bucket_order": bucket_order
            }
        }
    except Exception as e:
        logger.error(f"生成策略热力图失败: {e}", exc_info=True)
        return {
            "success": False,
            "message": f"生成策略热力图失败: {e}"
        }
def get_strategy_stats(self, tool_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """
    Summarize the latest enabled strategies of the current group.

    ``avg_success_rate`` is the mean success rate over strategies (nulls
    ignored) in both modes, so filtered and unfiltered numbers are comparable.

    Args:
        tool_name: When truthy, only that tool's strategies are summarized.
            Otherwise a per-tool breakdown is also returned under ``by_tool``.

    Returns:
        Dict with total_strategies, avg_success_rate, total_usage,
        total_success, total_failure (plus by_tool when unfiltered), or None
        when not connected, no row is returned, or on error.
    """
    if not self.graphiti_client or not self.graphiti_client.is_connected():
        return None
    try:
        group_id = self.config_manager.get_group_id()
        if tool_name:
            query = """
            MATCH (s:Strategy {group_id: $group_id, tool_name: $tool_name})
            WHERE s.enabled = true AND coalesce(s.is_latest, true)
            RETURN
                count(s) as total_strategies,
                avg(s.success_rate) as avg_success_rate,
                sum(s.usage_count) as total_usage,
                sum(s.success_count) as total_success,
                sum(s.failure_count) as total_failure
            """
            params = {"group_id": group_id, "tool_name": tool_name}
        else:
            # The overall rate is rebuilt from per-tool sums/counts. Averaging the
            # per-tool averages would give a tool with one strategy the same
            # weight as a tool with hundreds.
            query = """
            MATCH (s:Strategy {group_id: $group_id})
            WHERE s.enabled = true AND coalesce(s.is_latest, true)
            WITH s.tool_name as tool_name,
                 count(s) as count,
                 avg(s.success_rate) as avg_success_rate,
                 sum(s.success_rate) as rate_sum,
                 count(s.success_rate) as rated,
                 sum(s.usage_count) as usage,
                 sum(s.success_count) as success,
                 sum(s.failure_count) as failure
            WITH sum(count) as total_strategies,
                 sum(rate_sum) as rate_total,
                 sum(rated) as rated_total,
                 sum(usage) as total_usage,
                 sum(success) as total_success,
                 sum(failure) as total_failure,
                 collect({
                     tool_name: tool_name,
                     count: count,
                     avg_success_rate: avg_success_rate,
                     usage: usage,
                     success: success,
                     failure: failure
                 }) as by_tool
            RETURN
                total_strategies,
                CASE WHEN rated_total > 0 THEN toFloat(rate_total) / rated_total ELSE null END as avg_success_rate,
                total_usage,
                total_success,
                total_failure,
                by_tool
            """
            params = {"group_id": group_id}

        with self.graphiti_client.driver.session() as session:
            record = session.run(query, params).single()
            if not record:
                return None
            stats = {
                "total_strategies": record.get("total_strategies", 0),
                "avg_success_rate": record.get("avg_success_rate", 0.0) or 0.0,
                "total_usage": record.get("total_usage", 0),
                "total_success": record.get("total_success", 0),
                "total_failure": record.get("total_failure", 0),
            }
            # Per-tool breakdown (unfiltered query only); rows without a tool name are skipped.
            by_tool_raw = record.get("by_tool")
            if by_tool_raw:
                stats["by_tool"] = {
                    item.get("tool_name"): {
                        "count": item.get("count", 0),
                        "avg_success_rate": item.get("avg_success_rate", 0.0) or 0.0,
                        "usage": item.get("usage", 0),
                        "success": item.get("success", 0),
                        "failure": item.get("failure", 0),
                    }
                    for item in by_tool_raw
                    if item.get("tool_name")
                }
            return stats
    except Exception as e:
        logger.error(f"获取策略统计失败: {e}", exc_info=True)
        return None
def rate_result(
    self,
    tool_name: str,
    rating: int,
    feedback: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Record a user's 1-5 rating of a tool result as weighted feedback.

    Conversion rules:
    - A rating of 3 or more counts as success.
    - The sample weight is ``rating / 5`` (1 -> 0.2, 5 -> 1.0).

    The ACE agent reflects on the rating, and the reflection is stored through
    ``self._update_skillbook_with_feedback``.

    Args:
        tool_name: Name of the rated tool.
        rating: Score from 1 to 5. Other values are rejected because they would
            yield weights outside 0.2-1.0 (or negative counter increments).
        feedback: Optional free-text comment.
        context: Optional call arguments the rating refers to.

    Returns:
        True when the feedback was processed; False when ACE is disabled, the
        rating is invalid, or an error occurred.
    """
    if not self.is_enabled():
        return False
    if not isinstance(rating, (int, float)) or not 1 <= rating <= 5:
        logger.warning(f"无效的评分: {rating},评分必须在 1-5 之间")
        return False
    import sys
    from contextlib import redirect_stderr, redirect_stdout
    from io import StringIO

    try:
        success = rating >= 3
        result = {
            "success": success,
            "rating": rating,
            "feedback": feedback or "",
            "message": f"用户评分: {rating}/5"
        }
        arguments = context or {}

        prompt = f"""
分析以下工具调用的用户反馈:
工具:{tool_name}
参数:{json.dumps(arguments, ensure_ascii=False, indent=2, default=str)}
用户评分:{rating}/5
用户反馈:{feedback or '无'}
成功:{success}
请分析:
1. 用户为什么给出这个评分?
2. 哪些策略有效?
3. 哪些策略需要改进?
4. 可以提取什么经验教训?
返回 JSON 格式的反思结果,包含:
- "success": 是否成功
- "effective_strategies": 有效的策略
- "improvements": 需要改进的地方
- "lessons": 经验教训
"""
        # In MCP mode, stdout/stderr output from ACE would corrupt JSON-RPC.
        try:
            is_mcp_mode = not sys.stdin.isatty()
        except (AttributeError, ValueError, OSError):
            is_mcp_mode = True
        if is_mcp_mode:
            with redirect_stderr(StringIO()), redirect_stdout(StringIO()):
                reflection = self.ace_agent.ask(prompt)
        else:
            reflection = self.ace_agent.ask(prompt)

        # Higher ratings carry more weight.
        weight = rating / 5.0
        self._update_skillbook_with_feedback(
            tool_name=tool_name,
            arguments=arguments,
            result=result,
            reflection=reflection,
            success=success,
            weight=weight
        )
        self._reload_strategies_async()
        logger.info(f"用户反馈已保存: {tool_name} - {rating}/5")
        return True
    except Exception as e:
        logger.error(f"保存用户反馈失败: {e}", exc_info=True)
        return False
def _update_skillbook_with_feedback(
    self,
    tool_name: str,
    arguments: Dict[str, Any],
    result: Dict[str, Any],
    reflection: str,
    success: bool,
    weight: float = 1.0
):
    """
    Write a rating-weighted feedback update for a tool call to the Skillbook.

    Does nothing when the Graphiti client is missing or disconnected.
    Persistence is delegated to ``self._create_strategy_version``. Any
    failure is logged and swallowed, so feedback handling never raises from
    here.

    Args:
        tool_name: Tool the feedback refers to.
        arguments: Tool arguments. Their key-sorted JSON form is hashed to key
            the strategy.
        result: Feedback result dict. Its ``rating`` (default 3) is forwarded.
        reflection: Reflection text produced by the ACE agent.
        success: Whether the feedback counts as a success.
        weight: Amount added to usage and to either the success or the
            failure counter.
    """
    if not self.graphiti_client or not self.graphiti_client.is_connected():
        return
    try:
        # NOTE(review): built-in hash() of a str is salted per process unless
        # PYTHONHASHSEED is fixed, so this key is not stable across restarts.
        # Moving to hashlib must be coordinated with any other place that
        # computes arguments_hash.
        serialized_args = json.dumps(arguments, sort_keys=True)
        arguments_hash = str(hash(serialized_args))
        if success:
            success_inc, failure_inc = weight, 0.0
        else:
            success_inc, failure_inc = 0.0, weight
        self._create_strategy_version(
            tool_name=tool_name,
            arguments_hash=arguments_hash,
            reflection=reflection,
            success_increment=success_inc,
            failure_increment=failure_inc,
            usage_increment=weight,
            rating=result.get('rating', 3),
            feedback_increment=1.0
        )
        outcome = '成功' if success else '失败'
        logger.debug(f"更新 Skillbook(加权): {tool_name} - {outcome} (权重: {weight:.2f})")
    except Exception as e:
        logger.error(f"更新 Skillbook(加权)失败: {e}", exc_info=True)
def _load_skillbook(self):
    """
    Refresh ``self._cached_strategies`` from Strategy nodes stored in Neo4j.

    Reads at most 100 enabled, latest-version strategies of the current
    group, ordered by success rate and then usage count. An empty result
    clears the cache. On any error the cache is reset to an empty list, so
    the rest of the system keeps running.
    """
    if not self.graphiti_client or not self.graphiti_client.is_connected():
        return
    try:
        group_id = self.config_manager.get_group_id()
        query = """
MATCH (s:Strategy {group_id: $group_id})
WHERE s.enabled = true AND coalesce(s.is_latest, true)
RETURN s
ORDER BY s.success_rate DESC, s.usage_count DESC
LIMIT 100
"""

        def as_strategy(node):
            # Project the node's properties into the plain dict the cache stores.
            return {
                "tool_name": node.get("tool_name", ""),
                "content": node.get("content", ""),
                "success_rate": node.get("success_rate", 0.0),
                "usage_count": node.get("usage_count", 0),
                "arguments_hash": node.get("arguments_hash", ""),
            }

        with self.graphiti_client.driver.session() as session:
            rows = session.run(query, group_id=group_id)
            loaded = [as_strategy(row["s"]) for row in rows]
        # Cache is replaced even when nothing was found (it becomes []).
        self._cached_strategies = loaded
        if loaded:
            logger.info(f"从 Neo4j 加载了 {len(loaded)} 个策略")
    except Exception as e:
        logger.warning(f"加载 Skillbook 失败: {e}", exc_info=True)
        # Fall back to an empty cache rather than propagating the error.
        self._cached_strategies = []
        logger.debug("已重置策略缓存为空列表")
def _update_skillbook(
    self,
    tool_name: str,
    arguments: Dict[str, Any],
    result: Dict[str, Any],
    reflection: str,
    success: bool
):
    """
    Record one unweighted tool-call outcome in the Skillbook (Neo4j).

    Adds 1 to usage and 1 to either the success or the failure counter via
    ``self._create_strategy_version``. ``result`` is part of the signature
    but is not read here. No-op when the Graphiti client is missing or
    disconnected. Errors are logged and swallowed so the caller's main flow
    is not interrupted.
    """
    if not self.graphiti_client or not self.graphiti_client.is_connected():
        return
    try:
        # NOTE(review): built-in hash() of a str is salted per process unless
        # PYTHONHASHSEED is fixed, so this key is not stable across restarts.
        # Moving to hashlib must be coordinated with any other place that
        # computes arguments_hash.
        arguments_hash = str(hash(json.dumps(arguments, sort_keys=True)))
        hit = 1.0 if success else 0.0
        self._create_strategy_version(
            tool_name=tool_name,
            arguments_hash=arguments_hash,
            reflection=reflection,
            success_increment=hit,
            failure_increment=1.0 - hit,
            usage_increment=1.0
        )
        outcome = '成功' if success else '失败'
        logger.debug(f"更新 Skillbook: {tool_name} - {outcome}")
    except Exception as e:
        logger.error(f"更新 Skillbook 失败: {e}", exc_info=True)
        logger.debug(f"工具: {tool_name}, 成功: {success}, 参数: {arguments}")
        # Deliberately not re-raised: Skillbook updates are best-effort.
def save_skillbook(self, file_path: Optional[str] = None):
    """
    Persist the ACE agent's Skillbook to a file.

    Args:
        file_path: Destination path. Defaults to ``ace_skillbook.json`` in
            the directory that holds the config file.

    Does nothing when ACE is disabled. Logs a warning if the agent has no
    ``save_skillbook`` method. Errors are logged, not raised.
    """
    if not self.is_enabled():
        return
    try:
        target = file_path
        if target is None:
            target = str(self.config_manager.config_path.parent / "ace_skillbook.json")
        if not hasattr(self.ace_agent, 'save_skillbook'):
            logger.warning("ACE Agent 不支持 save_skillbook 方法")
            return
        self.ace_agent.save_skillbook(target)
        logger.info(f"Skillbook 已保存到: {target}")
    except Exception as e:
        logger.error(f"保存 Skillbook 失败: {e}", exc_info=True)
def load_skillbook(self, file_path: str):
    """
    Load a Skillbook file into the ACE agent.

    Uses the agent's ``load_skillbook`` method when it has one. Otherwise the
    agent is rebuilt as ``ACELiteLLM(model=..., skillbook_path=file_path)``.
    Does nothing when ACE is disabled. Errors are logged, not raised.

    Args:
        file_path: Path of the Skillbook file to load.
    """
    if not self.is_enabled():
        return
    try:
        import sys
        from contextlib import ExitStack, redirect_stderr, redirect_stdout
        from io import StringIO
        # Same MCP-mode detection used elsewhere in this class. There, stdout
        # carries JSON-RPC, so anything ACE prints while loading or importing
        # is captured instead of corrupting the protocol stream.
        _is_mcp_mode = not sys.stdin.isatty() if hasattr(sys.stdin, 'isatty') else True
        rebuilt = False
        with ExitStack() as stack:
            if _is_mcp_mode:
                stack.enter_context(redirect_stderr(StringIO()))
                stack.enter_context(redirect_stdout(StringIO()))
            if hasattr(self.ace_agent, 'load_skillbook'):
                # Newer ACE API (0.7.x): load into the existing agent.
                self.ace_agent.load_skillbook(file_path)
            else:
                # Fallback: rebuild the agent around the skillbook file.
                from ace import ACELiteLLM
                api_config = self.config_manager.get_api_config()
                # Use the default model when none is configured, matching
                # `api_config.model or "gpt-4o-mini"` in _initialize_ace.
                model = (api_config.model if api_config else None) or "gpt-4o-mini"
                self.ace_agent = ACELiteLLM(model=model, skillbook_path=file_path)
                rebuilt = True
        # Log after the redirection is undone so the messages are not captured.
        if rebuilt:
            logger.info(f"Skillbook 已从文件加载(重新初始化): {file_path}")
        else:
            logger.info(f"Skillbook 已从文件加载: {file_path}")
    except Exception as e:
        logger.error(f"加载 Skillbook 失败: {e}", exc_info=True)
def export_strategies(
    self,
    tool_name: Optional[str] = None,
    file_path: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """
    Export strategies to a JSON file.

    Args:
        tool_name: Optional tool name used to filter strategies. When
            ``file_path`` is omitted, a sanitized form of it is also embedded
            in the generated file name.
        file_path: Destination path. Defaults to a timestamped
            ``ace_strategies_[<tool>_]<YYYYmmdd_HHMMSS>.json`` file in the
            directory that holds the config file.

    Returns:
        ``{"file_path", "count", "size"}`` on success. Returns None if the
        Graphiti client is not connected, there is nothing to export, or an
        error occurs.
    """
    if not self.graphiti_client or not self.graphiti_client.is_connected():
        return None
    try:
        import os
        import re
        # Up to 1000 strategies, optionally filtered by tool.
        strategies = self.query_strategies(tool_name=tool_name, limit=1000)
        if not strategies:
            logger.warning("没有策略可导出")
            return None
        # Read the clock once so export_time and the file-name timestamp agree.
        now = datetime.now()
        export_data = {
            "version": "1.0",
            "export_time": now.isoformat(),
            "group_id": self.config_manager.get_group_id(),
            "tool_name": tool_name,
            "strategies": strategies
        }
        if file_path is None:
            timestamp = now.strftime("%Y%m%d_%H%M%S")
            if tool_name:
                # tool_name may come straight from a client request. Keep only
                # word characters, '.' and '-' so it cannot inject path
                # separators and write outside the config directory.
                safe_tool_name = re.sub(r"[^\w.-]", "_", tool_name)
                filename = f"ace_strategies_{safe_tool_name}_{timestamp}.json"
            else:
                filename = f"ace_strategies_{timestamp}.json"
            file_path = str(self.config_manager.config_path.parent / filename)
        # A bare file name has no directory part; "." keeps makedirs valid.
        os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            # default=str stringifies values the JSON encoder cannot handle.
            json.dump(export_data, f, ensure_ascii=False, indent=2, default=str)
        file_size = os.path.getsize(file_path)
        logger.info(f"策略已导出到: {file_path} ({len(strategies)} 个策略)")
        return {
            "file_path": file_path,
            "count": len(strategies),
            "size": file_size
        }
    except Exception as e:
        logger.error(f"导出策略失败: {e}", exc_info=True)
        return None
def import_strategies(
    self,
    file_path: str,
    overwrite: bool = False
) -> Optional[Dict[str, Any]]:
    """
    Import strategies from a JSON file such as one written by ``export_strategies``.

    A strategy counts as "existing" only when a latest-version node matches
    its (tool_name, arguments_hash, group_id). Latest means ``is_latest`` is
    true or unset, the same filter the other Strategy queries in this class
    use. With ``overwrite`` only those latest nodes are updated; historical
    versions are never touched. New strategies are created as latest
    versions.

    Args:
        file_path: Path of the JSON file to import.
        overwrite: When True, replace content/counters of existing strategies.
            When False, skip them.

    Returns:
        A dict with these keys:
            ``count``: number of strategies imported.
            ``created``: number of new strategy nodes.
            ``overwritten``: number of existing strategies updated.
            ``skipped``: strategies that already existed and were left alone
                because ``overwrite`` is False.
            ``errors``: per-strategy error messages.
        Returns None if the Graphiti client is not connected, the file is
        missing or malformed, it holds no strategies, or an unexpected error
        occurs.
    """
    if not self.graphiti_client or not self.graphiti_client.is_connected():
        return None
    try:
        import os
        if not os.path.exists(file_path):
            logger.error(f"文件不存在: {file_path}")
            return None
        with open(file_path, 'r', encoding='utf-8') as f:
            import_data = json.load(f)
        # Validate the top-level structure.
        if not isinstance(import_data, dict) or 'strategies' not in import_data:
            logger.error("无效的策略文件格式")
            return None
        strategies = import_data.get('strategies', [])
        if not strategies:
            logger.warning("文件中没有策略数据")
            return None
        if not isinstance(strategies, list):
            logger.error("无效的策略文件格式")
            return None

        group_id = self.config_manager.get_group_id()
        # Existence check restricted to the latest version.
        check_query = """
MATCH (s:Strategy {
tool_name: $tool_name,
arguments_hash: $arguments_hash,
group_id: $group_id
})
WHERE coalesce(s.is_latest, true)
RETURN count(s) AS count
"""
        # Overwrite only the latest version; a MERGE on the key properties
        # would also match and rewrite every historical version node.
        update_query = """
MATCH (s:Strategy {
tool_name: $tool_name,
arguments_hash: $arguments_hash,
group_id: $group_id
})
WHERE coalesce(s.is_latest, true)
SET s.content = $content,
s.success_count = $success_count,
s.failure_count = $failure_count,
s.usage_count = $usage_count,
s.success_rate = $success_rate,
s.updated_at = datetime(),
s.enabled = true
"""
        create_query = """
CREATE (s:Strategy {
tool_name: $tool_name,
arguments_hash: $arguments_hash,
group_id: $group_id,
content: $content,
success_count: $success_count,
failure_count: $failure_count,
usage_count: $usage_count,
success_rate: $success_rate,
created_at: datetime(),
updated_at: datetime(),
enabled: true,
is_latest: true
})
"""
        created_count = 0
        overwritten_count = 0
        skipped_count = 0
        errors = []
        for strategy in strategies:
            try:
                if not isinstance(strategy, dict):
                    errors.append("无效的策略条目")
                    continue
                tool_name = strategy.get('tool_name')
                if not tool_name:
                    errors.append("策略缺少 tool_name")
                    continue
                key = {
                    "tool_name": tool_name,
                    "arguments_hash": strategy.get('arguments_hash', ''),
                    "group_id": group_id,
                }
                values = {
                    "content": strategy.get('content', ''),
                    "success_rate": strategy.get('success_rate', 0.0),
                    "usage_count": strategy.get('usage_count', 0),
                    "success_count": strategy.get('success_count', 0),
                    "failure_count": strategy.get('failure_count', 0),
                }
                with self.graphiti_client.driver.session() as session:
                    record = session.run(check_query, **key).single()
                    exists = bool(record and record["count"])
                    if exists and not overwrite:
                        skipped_count += 1
                        continue
                    query = update_query if exists else create_query
                    # consume() so a failing write raises here, before it is counted.
                    session.run(query, **key, **values).consume()
                if exists:
                    overwritten_count += 1
                else:
                    created_count += 1
            except Exception as e:
                errors.append(f"导入策略失败: {str(e)}")
                logger.warning(f"导入策略失败: {e}")
        imported_count = created_count + overwritten_count
        # Refresh the in-memory strategy cache.
        self._reload_strategies_async()
        logger.info(f"策略导入完成: {imported_count} 个策略(创建: {created_count}, 覆盖: {overwritten_count})")
        return {
            "count": imported_count,
            "created": created_count,
            "overwritten": overwritten_count,
            "skipped": skipped_count,
            "errors": errors
        }
    except Exception as e:
        logger.error(f"导入策略失败: {e}", exc_info=True)
        return None
def toggle_strategy(
    self,
    tool_name: str,
    arguments_hash: Optional[str] = None,
    enabled: bool = True
) -> Optional[Dict[str, Any]]:
    """
    Set the ``enabled`` flag on the latest-version strategies of a tool.

    Args:
        tool_name: Tool whose strategies are changed.
        arguments_hash: When truthy, only the strategy with this hash is
            changed. Otherwise every strategy of the tool is.
        enabled: New value of the ``enabled`` flag.

    Returns:
        ``{"count", "tool_name", "enabled"}``, where ``count`` is the number
        of nodes the query updated. Returns None if the Graphiti client is
        not connected or the update fails.
    """
    if not self.graphiti_client or not self.graphiti_client.is_connected():
        return None
    try:
        params: Dict[str, Any] = {
            "tool_name": tool_name,
            "group_id": self.config_manager.get_group_id(),
            "enabled": enabled,
        }
        if not arguments_hash:
            # No hash given: target all of the tool's strategies.
            query = """
MATCH (s:Strategy {
tool_name: $tool_name,
group_id: $group_id
})
WHERE coalesce(s.is_latest, true)
SET s.enabled = $enabled, s.updated_at = datetime()
RETURN count(s) as count
"""
        else:
            # Target the single strategy identified by its arguments hash.
            params["arguments_hash"] = arguments_hash
            query = """
MATCH (s:Strategy {
tool_name: $tool_name,
arguments_hash: $arguments_hash,
group_id: $group_id
})
WHERE coalesce(s.is_latest, true)
SET s.enabled = $enabled, s.updated_at = datetime()
RETURN count(s) as count
"""
        with self.graphiti_client.driver.session() as session:
            row = session.run(query, params).single()
        affected = row.get("count", 0) if row else 0
        # Refresh the in-memory strategy cache.
        self._reload_strategies_async()
        state = '启用' if enabled else '禁用'
        logger.info(f"策略已{state}: {tool_name} ({affected} 个策略)")
        return {"count": affected, "tool_name": tool_name, "enabled": enabled}
    except Exception as e:
        logger.error(f"操作策略失败: {e}", exc_info=True)
        return None
def validate_strategies(self) -> Dict[str, Any]:
    """
    Report data-integrity statistics for the group's latest-version strategies.

    Returns:
        On success, ``{"valid", "total", "enabled", "disabled",
        "avg_success_rate", "issues"}``. ``valid`` is True when no strategy
        lacks a success rate or usage count.
        When the Graphiti client is not connected or the query fails,
        ``{"valid": False, "error": <message>}``.
    """
    if not self.graphiti_client or not self.graphiti_client.is_connected():
        return {"valid": False, "error": "Graphiti 客户端未连接"}
    try:
        group_id = self.config_manager.get_group_id()
        query = """
MATCH (s:Strategy {group_id: $group_id})
WHERE coalesce(s.is_latest, true)
RETURN
count(s) as total,
count(CASE WHEN s.enabled = true THEN 1 END) as enabled_count,
count(CASE WHEN s.enabled = false THEN 1 END) as disabled_count,
count(CASE WHEN s.success_rate IS NULL THEN 1 END) as missing_success_rate,
count(CASE WHEN s.usage_count IS NULL THEN 1 END) as missing_usage_count,
avg(s.success_rate) as avg_success_rate
"""
        with self.graphiti_client.driver.session() as session:
            row = session.run(query, group_id=group_id).single()
        if not row:
            return {
                "valid": True,
                "total": 0,
                "enabled": 0,
                "disabled": 0,
                "avg_success_rate": 0.0,
                "issues": []
            }
        no_rate = row.get("missing_success_rate", 0)
        no_usage = row.get("missing_usage_count", 0)
        problems = []
        if no_rate > 0:
            problems.append(f"{no_rate} 个策略缺少成功率")
        if no_usage > 0:
            problems.append(f"{no_usage} 个策略缺少使用次数")
        return {
            "valid": not problems,
            "total": row.get("total", 0),
            "enabled": row.get("enabled_count", 0),
            "disabled": row.get("disabled_count", 0),
            # avg() yields null when no strategy has a success rate.
            "avg_success_rate": row.get("avg_success_rate", 0.0) or 0.0,
            "issues": problems
        }
    except Exception as e:
        logger.error(f"验证策略失败: {e}", exc_info=True)
        return {"valid": False, "error": str(e)}