"""
纯 DoWhy 反事实分析工具 - 100% 使用 DoWhy 原生方法,不重复造轮子
"""
import logging
from typing import Any, Dict, List, Optional
import pandas as pd
import numpy as np
import dowhy
from mcp.server.fastmcp import FastMCP
from ..utils.data_processor import load_and_validate_data
logger = logging.getLogger("dowhy-mcp-server.counterfactual")
_ESTIMATION_METHOD = "backdoor.linear_regression"
_WHAT_IF_METHOD = "DoWhy What-If Analysis"


def _to_native(value: Any) -> Any:
    """Convert a NumPy scalar to the equivalent built-in Python value."""
    return value.item() if isinstance(value, np.generic) else value


def _unique(names: List[str]) -> List[str]:
    """Drop duplicate names while keeping first-seen order."""
    return list(dict.fromkeys(names))


def _adjustment_set(
    candidates: List[str], treatment: str, outcome: str
) -> Optional[List[str]]:
    """Return de-duplicated adjustment variables, excluding treatment and outcome.

    Returns ``None`` when nothing is left, so callers can pass the result
    straight to ``common_causes``.
    """
    kept = _unique([name for name in candidates if name not in (treatment, outcome)])
    return kept or None


def _estimate_effect(
    data: pd.DataFrame,
    treatment: str,
    outcome: str,
    common_causes: Optional[List[str]],
) -> float:
    """Build a DoWhy model, identify the effect and return the linear-regression estimate."""
    model = dowhy.CausalModel(
        data=data,
        treatment=treatment,
        outcome=outcome,
        common_causes=common_causes,
    )
    identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
    estimate = model.estimate_effect(
        identified_estimand,
        method_name=_ESTIMATION_METHOD,
    )
    return float(estimate.value)


def _failure(method: str, exc: Exception) -> Dict[str, Any]:
    """Log ``exc`` and build the failure payload shared by every tool."""
    logger.warning("%s failed: %s", method, exc)
    return {
        "success": False,
        "error": str(exc),
        "method": method,
    }


def _what_if_analysis(
    data_path: str,
    scenarios: List[Dict[str, Any]],
    outcome: str,
    confounders: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Run the what-if analysis shared by ``what_if_analyzer`` and ``scenario_comparator``.

    For every scenario the first key of its ``"variables"`` mapping is used as
    the treatment; the remaining keys plus ``confounders`` form the adjustment
    set. The values attached to the scenario variables are not used by the
    estimation: the reported ``scenario_effect`` is the effect estimated by
    DoWhy for the treatment variable, and ``scenario_outcome`` is ``None``
    because no scenario-level outcome prediction is made.

    Returns:
        The payload returned by the ``what_if_analyzer`` tool, including the
        failure payload when loading the data raises.
    """
    try:
        extra = list(confounders or [])

        scenario_vars: List[str] = []
        for scenario in scenarios:
            scenario_vars.extend(scenario.get("variables") or {})

        data = load_and_validate_data(
            data_path, _unique([outcome] + scenario_vars + extra)
        )
        baseline_mean = float(data[outcome].mean())

        scenario_results: Dict[str, Any] = {}
        for i, scenario in enumerate(scenarios):
            scenario_name = scenario.get("name", f"scenario_{i}")
            scenario_variables = scenario.get("variables") or {}
            if not scenario_variables:
                continue

            try:
                main_var, *other_vars = scenario_variables
                scenario_effect = _estimate_effect(
                    data,
                    main_var,
                    outcome,
                    _adjustment_set(other_vars + extra, main_var, outcome),
                )
                scenario_results[scenario_name] = {
                    "variables": scenario_variables,
                    "baseline_outcome": baseline_mean,
                    # No scenario-level prediction is made; see "note".
                    "scenario_outcome": None,
                    "scenario_effect": scenario_effect,
                    "note": "需要专门的反事实推理方法",
                    "method": _WHAT_IF_METHOD,
                }
            except Exception as e:
                # Best effort: one failing scenario must not abort the others.
                logger.warning("DoWhy 假设分析失败 %s: %s", scenario_name, e)
                scenario_results[scenario_name] = {
                    "variables": scenario_variables,
                    "baseline_outcome": baseline_mean,
                    "scenario_outcome": baseline_mean,
                    "scenario_effect": 0.0,
                    "method": "Failed",
                    "error": str(e),
                }

        return {
            "success": True,
            "method": _WHAT_IF_METHOD,
            "outcome": outcome,
            "confounders": extra,
            "baseline_outcome": baseline_mean,
            "scenario_results": scenario_results,
            "sample_size": int(len(data)),
        }
    except Exception as e:
        return _failure(_WHAT_IF_METHOD, e)


def register_counterfactual_tools(server: FastMCP) -> None:
    """Register the DoWhy-based counterfactual analysis tools on ``server``.

    Args:
        server: FastMCP instance; each tool is registered through its
            ``tool()`` decorator.
    """

    @server.tool()
    def individual_counterfactual(
        data_path: str,
        individual_id: str,
        treatment: str,
        outcome: str,
        confounders: List[str],
        counterfactual_treatment: Any,
        id_column: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Estimate the counterfactual outcome of one individual with DoWhy.

        The effect estimated by backdoor linear regression is applied linearly
        to the individual's observed outcome, scaled by the difference between
        the counterfactual and the observed treatment value.

        Args:
            data_path: Location of the data set, passed to
                ``load_and_validate_data``.
            individual_id: Identifier of the individual to look up.
            treatment: Treatment column.
            outcome: Outcome column.
            confounders: Columns used as common causes.
            counterfactual_treatment: Treatment value to simulate; must be
                convertible to ``float``.
            id_column: Column holding the identifiers. When omitted,
                ``individual_id`` is also used as the column name, which is
                how the tool behaved before this parameter existed.

        Returns:
            A dict with ``success`` and, on success, the observed and
            counterfactual treatment/outcome and the individual effect.
        """
        method = "DoWhy Individual Counterfactual"
        try:
            lookup_column = individual_id if id_column is None else id_column
            data = load_and_validate_data(
                data_path, [lookup_column, treatment, outcome] + confounders
            )

            # Compare as strings so numeric identifier columns match the
            # string identifier received from the caller.
            matches = data[data[lookup_column].astype(str) == str(individual_id)]
            if matches.empty:
                return {
                    "success": False,
                    "error": f"Individual {individual_id} not found",
                    "method": method,
                }

            row = matches.iloc[0]
            observed_treatment = _to_native(row[treatment])
            observed_outcome = float(row[outcome])

            unit_effect = _estimate_effect(data, treatment, outcome, confounders)

            # Linear extrapolation: the estimate is the effect of a one-unit
            # treatment change, so scale it by the size of the change.
            treatment_shift = float(counterfactual_treatment) - float(observed_treatment)
            counterfactual_outcome = observed_outcome + unit_effect * treatment_shift

            return {
                "success": True,
                "method": method,
                "individual_id": individual_id,
                "treatment": treatment,
                "outcome": outcome,
                "confounders": confounders,
                "observed_treatment": observed_treatment,
                "observed_outcome": observed_outcome,
                "counterfactual_treatment": counterfactual_treatment,
                "counterfactual_outcome": float(counterfactual_outcome),
                "individual_treatment_effect": float(
                    counterfactual_outcome - observed_outcome
                ),
            }
        except Exception as e:
            return _failure(method, e)

    @server.tool()
    def population_counterfactual(
        data_path: str,
        treatment: str,
        outcome: str,
        confounders: List[str],
        counterfactual_treatment: Any = 1,
    ) -> Dict[str, Any]:
        """Report that population-level counterfactuals are not supported.

        The data set is still loaded so that problems raised by
        ``load_and_validate_data`` are reported with their own message.
        ``counterfactual_treatment`` is accepted for interface compatibility
        but not used.

        Returns:
            Always a dict with ``success`` set to ``False``.
        """
        method = "DoWhy Population Counterfactual"
        try:
            load_and_validate_data(data_path, [treatment, outcome] + confounders)
        except Exception as e:
            return _failure(method, e)

        return {
            "success": False,
            "error": "人群反事实分析需要专门的反事实推理库。DoWhy主要专注于因果效应估计。",
            "method": method,
            "recommendation": "使用DoWhy的因果效应估计结合专门的反事实推理方法",
            "alternative": "使用individual_counterfactual工具对个体进行反事实分析",
        }

    @server.tool()
    def intervention_simulator(
        data_path: str,
        interventions: List[Dict[str, Any]],
        outcome: str,
        confounders: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Simulate the effect of several interventions with DoWhy.

        Each intervention is a dict with ``"variable"`` and optionally
        ``"name"`` and ``"value"`` (default 1). For every intervention the
        other intervention variables plus ``confounders`` are used as common
        causes. The simulated outcome is a simplification: the mean outcome
        plus the estimated effect multiplied by the intervention value.

        Returns:
            A dict with ``success`` and one entry per intervention under
            ``intervention_results``; an intervention whose estimation raises
            is reported with ``method`` set to ``"Failed"`` and zero effect.
        """
        method = "DoWhy Intervention Simulation"
        try:
            extra = list(confounders or [])
            intervention_vars = [
                spec.get("variable") for spec in interventions if spec.get("variable")
            ]

            data = load_and_validate_data(
                data_path, _unique([outcome] + intervention_vars + extra)
            )
            baseline_mean = float(data[outcome].mean())

            intervention_results: Dict[str, Any] = {}
            for i, spec in enumerate(interventions):
                intervention_name = spec.get("name", f"intervention_{i}")
                var_name = spec.get("variable")
                if not var_name:
                    continue
                intervention_value = spec.get("value", 1)

                try:
                    intervention_effect = _estimate_effect(
                        data,
                        var_name,
                        outcome,
                        _adjustment_set(intervention_vars + extra, var_name, outcome),
                    )
                    simulated_outcome = (
                        baseline_mean + intervention_effect * intervention_value
                    )
                    intervention_results[intervention_name] = {
                        "variable": var_name,
                        "intervention_value": intervention_value,
                        "baseline_outcome": baseline_mean,
                        "simulated_outcome": float(simulated_outcome),
                        "intervention_effect": float(simulated_outcome - baseline_mean),
                        "method": method,
                    }
                except Exception as e:
                    # Best effort: one failing intervention must not abort the others.
                    logger.warning("DoWhy 干预模拟失败 %s: %s", intervention_name, e)
                    intervention_results[intervention_name] = {
                        "variable": var_name,
                        "intervention_value": intervention_value,
                        "baseline_outcome": baseline_mean,
                        "simulated_outcome": baseline_mean,
                        "intervention_effect": 0.0,
                        "method": "Failed",
                        "error": str(e),
                    }

            return {
                "success": True,
                "method": method,
                "outcome": outcome,
                "confounders": extra,
                "baseline_outcome": baseline_mean,
                "intervention_results": intervention_results,
                "sample_size": int(len(data)),
            }
        except Exception as e:
            return _failure(method, e)

    @server.tool()
    def what_if_analyzer(
        data_path: str,
        scenarios: List[Dict[str, Any]],
        outcome: str,
        confounders: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Analyse hypothetical scenarios with DoWhy.

        Each scenario is a dict with a ``"variables"`` mapping and optionally
        a ``"name"``. The first variable of a scenario is treated as the
        treatment and the remaining ones, plus ``confounders``, as common
        causes. ``scenario_effect`` is the effect DoWhy estimates for that
        treatment; ``scenario_outcome`` is ``None`` because no scenario-level
        outcome is predicted.

        Returns:
            A dict with ``success`` and one entry per scenario under
            ``scenario_results``.
        """
        return _what_if_analysis(data_path, scenarios, outcome, confounders)

    @server.tool()
    def scenario_comparator(
        data_path: str,
        scenarios: List[Dict[str, Any]],
        outcome: str,
        confounders: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Rank scenarios by the effect estimated in the what-if analysis.

        Runs the same analysis as ``what_if_analyzer`` and sorts the scenarios
        by ``scenario_effect`` in descending order.

        Returns:
            The what-if failure payload when that analysis fails; otherwise a
            dict with the per-scenario results, the ranking, and the names of
            the best and worst scenario (``None`` when there are no results).
        """
        # Call the shared helper directly rather than going through the
        # decorated tool, so this does not depend on what ``server.tool()``
        # returns.
        what_if_result = _what_if_analysis(data_path, scenarios, outcome, confounders)
        if not what_if_result["success"]:
            return what_if_result

        scenario_results = what_if_result["scenario_results"]
        ranked = sorted(
            scenario_results.items(),
            key=lambda item: item[1]["scenario_effect"],
            reverse=True,
        )

        return {
            "success": True,
            "method": "DoWhy Scenario Comparison",
            "outcome": outcome,
            "confounders": confounders or [],
            "baseline_outcome": what_if_result["baseline_outcome"],
            "scenario_results": scenario_results,
            "ranked_scenarios": [
                {
                    "name": name,
                    "effect": result["scenario_effect"],
                    "outcome": result["scenario_outcome"],
                }
                for name, result in ranked
            ],
            "best_scenario": ranked[0][0] if ranked else None,
            "worst_scenario": ranked[-1][0] if ranked else None,
            "sample_size": what_if_result["sample_size"],
        }

    @server.tool()
    def counterfactual_fairness(
        data_path: str,
        treatment: str,
        outcome: str,
        sensitive_attributes: List[str],
        confounders: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Summarise outcome disparities across sensitive attributes with DoWhy.

        For each sensitive attribute the treatment effect is estimated with
        ``confounders`` and the other sensitive attributes as common causes.
        ``group_effects`` holds the plain mean outcome of each attribute value
        (not a causal effect), and ``fairness_gap`` is the difference between
        the largest and smallest of those means.

        Returns:
            A dict with ``success`` and one entry per sensitive attribute under
            ``fairness_results``; an attribute whose analysis raises is
            reported with ``method`` set to ``"Failed"``.
        """
        method = "DoWhy Counterfactual Fairness"
        try:
            extra = list(confounders or [])
            data = load_and_validate_data(
                data_path, [treatment, outcome] + sensitive_attributes + extra
            )

            fairness_results: Dict[str, Any] = {}
            for sensitive_attr in sensitive_attributes:
                try:
                    other_attrs = [
                        attr for attr in sensitive_attributes if attr != sensitive_attr
                    ]
                    overall_effect = _estimate_effect(
                        data,
                        treatment,
                        outcome,
                        _adjustment_set(extra + other_attrs, treatment, outcome),
                    )

                    group_effects: Dict[str, float] = {}
                    for value in data[sensitive_attr].unique():
                        group_data = data[data[sensitive_attr] == value]
                        # A NaN attribute value never equals itself, so its
                        # group is empty and is skipped here.
                        if len(group_data) > 0:
                            group_effects[str(value)] = float(group_data[outcome].mean())

                    if group_effects:
                        group_means = group_effects.values()
                        fairness_gap = max(group_means) - min(group_means)
                    else:
                        fairness_gap = 0.0

                    fairness_results[sensitive_attr] = {
                        "overall_effect": overall_effect,
                        "group_effects": group_effects,
                        "fairness_gap": fairness_gap,
                        "method": method,
                    }
                except Exception as e:
                    # Best effort: one failing attribute must not abort the others.
                    logger.warning("DoWhy 公平性分析失败 %s: %s", sensitive_attr, e)
                    fairness_results[sensitive_attr] = {
                        "overall_effect": 0.0,
                        "group_effects": {},
                        "fairness_gap": 0.0,
                        "method": "Failed",
                        "error": str(e),
                    }

            return {
                "success": True,
                "method": method,
                "treatment": treatment,
                "outcome": outcome,
                "sensitive_attributes": sensitive_attributes,
                "confounders": extra,
                "fairness_results": fairness_results,
                "sample_size": int(len(data)),
            }
        except Exception as e:
            return _failure(method, e)