"""
纯 DoWhy 归因分析工具 - 100% 使用 DoWhy 原生方法,不重复造轮子
"""
import logging
from typing import Any, Dict, List, Optional
import pandas as pd
import numpy as np
import dowhy
from mcp.server.fastmcp import FastMCP
from ..utils.data_processor import load_and_validate_data
logger = logging.getLogger("dowhy-mcp-server.attribution")
def _estimate_linear_effect(
    data: Any,
    treatment: str,
    outcome: str,
    common_causes: Optional[List[str]] = None,
) -> float:
    """Return DoWhy's backdoor linear-regression estimate of one effect.

    Builds a ``dowhy.CausalModel`` without an explicit graph, identifies the
    estimand with ``proceed_when_unidentifiable=True`` and estimates it with
    the ``backdoor.linear_regression`` method.

    Args:
        data: Dataset returned by ``load_and_validate_data`` (presumably a
            pandas DataFrame -- TODO confirm against the loader).
        treatment: Name of the treatment column.
        outcome: Name of the outcome column.
        common_causes: Columns to adjust for. ``None`` or an empty list means
            no common causes are passed to the model.

    Returns:
        The estimate's ``value`` converted to ``float``.

    Raises:
        Exception: Anything DoWhy raises is propagated unchanged; callers
            decide whether to record a placeholder or abort.
    """
    model = dowhy.CausalModel(
        data=data,
        treatment=treatment,
        outcome=outcome,
        # An empty adjustment list is normalised to None, as the tools did
        # before this helper was extracted.
        common_causes=common_causes or None,
    )
    identified_estimand = model.identify_effect(
        proceed_when_unidentifiable=True
    )
    estimate = model.estimate_effect(
        identified_estimand,
        method_name="backdoor.linear_regression",
    )
    return float(estimate.value)


def register_attribution_tools(server: FastMCP) -> None:
    """Register the DoWhy-based attribution tools on ``server``.

    Each tool is a closure decorated with ``server.tool()``. Every tool
    returns a dict with ``"success": True`` and its results, or
    ``"success": False`` plus ``"error"`` when loading or analysis fails as a
    whole; tools do not raise to the caller for ordinary exceptions.

    Args:
        server: The FastMCP server instance the tools are attached to.
    """

    @server.tool()
    def shapley_value_attribution(
        data_path: str,
        outcome: str,
        features: List[str],
        num_samples: int = 100
    ) -> Dict[str, Any]:
        """Attribute the outcome to each feature via DoWhy effect estimates.

        For every feature, the remaining features are used as common causes
        and a backdoor linear-regression effect is estimated. Despite the
        tool's name, no Shapley sampling is performed here.

        Args:
            data_path: Location of the dataset, passed to the data loader.
            outcome: Name of the outcome column.
            features: Feature columns to attribute the outcome to.
            num_samples: Accepted for interface compatibility; currently
                unused by the implementation.

        Returns:
            On success: ``shapley_values`` (feature -> estimate),
            ``total_attribution`` (their sum), ``failed_features`` (features
            whose estimate failed and was recorded as ``0.0``) and
            ``sample_size``. On failure: ``success=False`` with ``error``.
        """
        try:
            data = load_and_validate_data(data_path, [outcome] + features)
            shapley_values = {}
            failed_features = []
            for feature in features:
                other_features = [f for f in features if f != feature]
                try:
                    shapley_values[feature] = _estimate_linear_effect(
                        data, feature, outcome, other_features
                    )
                except Exception as e:
                    logger.warning(f"DoWhy 模型失败 {feature}: {e}")
                    # Best effort: keep the feature in the result with a 0.0
                    # placeholder, but record it so callers can tell a failed
                    # estimate apart from a genuine zero effect.
                    shapley_values[feature] = 0.0
                    failed_features.append(feature)
            return {
                "success": True,
                "method": "DoWhy Causal Attribution",
                "outcome": outcome,
                "features": features,
                "shapley_values": shapley_values,
                "total_attribution": sum(shapley_values.values()),
                "failed_features": failed_features,
                "sample_size": int(len(data))
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "method": "DoWhy Causal Attribution"
            }

    @server.tool()
    def direct_causal_influence(
        data_path: str,
        target: str,
        variables: List[str],
        method: str = "partial_correlation"
    ) -> Dict[str, Any]:
        """Estimate each variable's effect on the target, adjusting for the rest.

        For every variable, all other listed variables are passed as common
        causes to a backdoor linear-regression estimate.

        Args:
            data_path: Location of the dataset, passed to the data loader.
            target: Name of the target (outcome) column.
            variables: Candidate cause columns.
            method: Accepted for interface compatibility; currently unused,
                the estimator is always backdoor linear regression.

        Returns:
            On success: ``influence_results`` mapping each variable to
            ``direct_influence`` and ``method`` (``"Failed"`` with ``0.0``
            when that variable's estimate raised), plus ``sample_size``.
            On failure: ``success=False`` with ``error``.
        """
        try:
            data = load_and_validate_data(data_path, [target] + variables)
            influence_results = {}
            for variable in variables:
                other_vars = [v for v in variables if v != variable]
                try:
                    influence_results[variable] = {
                        "direct_influence": _estimate_linear_effect(
                            data, variable, target, other_vars
                        ),
                        "method": "DoWhy Direct Causal Effect"
                    }
                except Exception as e:
                    logger.warning(f"DoWhy 直接影响计算失败 {variable}: {e}")
                    influence_results[variable] = {
                        "direct_influence": 0.0,
                        "method": "Failed"
                    }
            return {
                "success": True,
                "method": "DoWhy Direct Causal Influence",
                "target": target,
                "variables": variables,
                "influence_results": influence_results,
                "sample_size": int(len(data))
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "method": "DoWhy Direct Causal Influence"
            }

    @server.tool()
    def total_causal_influence(
        data_path: str,
        target: str,
        variables: List[str],
        include_indirect: bool = True
    ) -> Dict[str, Any]:
        """Estimate each variable's effect on the target with no adjustment set.

        Unlike ``direct_causal_influence``, no common causes are supplied, so
        each estimate is a regression of the target on that variable alone.

        Args:
            data_path: Location of the dataset, passed to the data loader.
            target: Name of the target (outcome) column.
            variables: Candidate cause columns.
            include_indirect: Echoed back as ``includes_indirect`` in every
                result entry; it does not change the estimation.

        Returns:
            On success: ``influence_results`` mapping each variable to
            ``total_influence``, ``includes_indirect`` and ``method``
            (``"Failed"`` with ``0.0`` when that estimate raised), plus
            ``sample_size``. On failure: ``success=False`` with ``error``.
        """
        try:
            data = load_and_validate_data(data_path, [target] + variables)
            influence_results = {}
            for variable in variables:
                try:
                    influence_results[variable] = {
                        # No common causes on purpose: nothing is adjusted for.
                        "total_influence": _estimate_linear_effect(
                            data, variable, target
                        ),
                        "includes_indirect": include_indirect,
                        "method": "DoWhy Total Causal Effect"
                    }
                except Exception as e:
                    logger.warning(f"DoWhy 总影响计算失败 {variable}: {e}")
                    influence_results[variable] = {
                        "total_influence": 0.0,
                        "includes_indirect": include_indirect,
                        "method": "Failed"
                    }
            return {
                "success": True,
                "method": "DoWhy Total Causal Influence",
                "target": target,
                "variables": variables,
                "influence_results": influence_results,
                "sample_size": int(len(data))
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "method": "DoWhy Total Causal Influence"
            }

    @server.tool()
    def path_specific_effects(
        data_path: str,
        treatment: str,
        outcome: str,
        mediators: List[str],
        confounders: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Estimate the treatment effect via front-door, falling back to backdoor.

        Tries ``frontdoor.two_stage_regression`` first; if that raises, the
        failure is logged and ``backdoor.linear_regression`` is used instead.
        The ``method`` field of the result states which estimator produced
        the value.

        Args:
            data_path: Location of the dataset, passed to the data loader.
            treatment: Name of the treatment column.
            outcome: Name of the outcome column.
            mediators: Mediator columns between treatment and outcome.
            confounders: Optional common-cause columns.

        Returns:
            On success: ``path_specific_effect`` (the estimate as ``float``),
            the estimator label in ``method``, the echoed inputs and
            ``sample_size``. On failure: ``success=False`` with ``error``.
        """
        try:
            all_vars = [treatment, outcome] + mediators
            if confounders:
                all_vars.extend(confounders)
            data = load_and_validate_data(data_path, all_vars)
            # NOTE(review): `mediators` is forwarded to CausalModel as a
            # keyword argument; confirm the installed DoWhy version actually
            # uses it when constructing the graph.
            model = dowhy.CausalModel(
                data=data,
                treatment=treatment,
                outcome=outcome,
                mediators=mediators,
                common_causes=confounders if confounders else None
            )
            identified_estimand = model.identify_effect(
                proceed_when_unidentifiable=True
            )
            try:
                estimate = model.estimate_effect(
                    identified_estimand,
                    method_name="frontdoor.two_stage_regression"
                )
                method_used = "DoWhy Front-door Adjustment"
            except Exception as frontdoor_error:
                # Catch Exception rather than everything so KeyboardInterrupt
                # and SystemExit still propagate, and keep a trace of why the
                # front-door estimate was abandoned.
                logger.warning(
                    f"DoWhy 前门调整失败 {treatment}: {frontdoor_error}"
                )
                estimate = model.estimate_effect(
                    identified_estimand,
                    method_name="backdoor.linear_regression"
                )
                method_used = "DoWhy Backdoor Adjustment (fallback)"
            return {
                "success": True,
                "method": method_used,
                "treatment": treatment,
                "outcome": outcome,
                "mediators": mediators,
                "confounders": confounders or [],
                "path_specific_effect": float(estimate.value),
                "sample_size": int(len(data))
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "method": "DoWhy Path-Specific Effects"
            }

    @server.tool()
    def mechanism_attribution(
        data_path: str,
        outcome: str,
        mechanisms: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Estimate one effect per mechanism from its listed variables.

        For each mechanism the first entry of ``variables`` is the treatment
        and the remaining entries are common causes. Mechanisms without
        variables are skipped and do not appear in the results.

        Args:
            data_path: Location of the dataset, passed to the data loader.
            outcome: Name of the outcome column.
            mechanisms: Mechanism descriptions; the keys read are ``name``
                (defaults to ``mechanism_<index>``) and ``variables``.

        Returns:
            On success: ``mechanism_results`` mapping each mechanism name to
            ``attribution``, ``variables`` and ``method`` (``"Failed"`` with
            ``0.0`` when that estimate raised), plus ``sample_size``.
            On failure: ``success=False`` with ``error``.
        """
        try:
            all_vars = [outcome]
            for mechanism in mechanisms:
                if "variables" in mechanism:
                    all_vars.extend(mechanism["variables"])
            # De-duplicate while keeping first-seen order so the column list
            # handed to the loader is deterministic across runs.
            all_vars = list(dict.fromkeys(all_vars))
            data = load_and_validate_data(data_path, all_vars)
            mechanism_results = {}
            for i, mechanism in enumerate(mechanisms):
                mechanism_name = mechanism.get("name", f"mechanism_{i}")
                mechanism_vars = mechanism.get("variables", [])
                if not mechanism_vars:
                    continue
                try:
                    # First variable is the treatment; the rest (if any) are
                    # adjusted for as common causes.
                    treatment_var = mechanism_vars[0]
                    mechanism_results[mechanism_name] = {
                        "attribution": _estimate_linear_effect(
                            data, treatment_var, outcome, mechanism_vars[1:]
                        ),
                        "variables": mechanism_vars,
                        "method": "DoWhy Mechanism Attribution"
                    }
                except Exception as e:
                    logger.warning(f"DoWhy 机制归因失败 {mechanism_name}: {e}")
                    mechanism_results[mechanism_name] = {
                        "attribution": 0.0,
                        "variables": mechanism_vars,
                        "method": "Failed"
                    }
            return {
                "success": True,
                "method": "DoWhy Mechanism Attribution",
                "outcome": outcome,
                "mechanisms": mechanisms,
                "mechanism_results": mechanism_results,
                "sample_size": int(len(data))
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "method": "DoWhy Mechanism Attribution"
            }

    @server.tool()
    def causal_contribution_analysis(
        data_path: str,
        outcome: str,
        factors: List[str],
        method: str = "variance_decomposition"
    ) -> Dict[str, Any]:
        """Estimate each factor's effect on the outcome, adjusting for the rest.

        For every factor, all other listed factors are passed as common
        causes to a backdoor linear-regression estimate.

        Args:
            data_path: Location of the dataset, passed to the data loader.
            outcome: Name of the outcome column.
            factors: Factor columns to analyse.
            method: Accepted for interface compatibility; currently unused,
                the estimator is always backdoor linear regression.

        Returns:
            On success: ``contribution_results`` mapping each factor to
            ``contribution`` and ``method`` (``"Failed"`` with ``0.0`` when
            that estimate raised), plus ``sample_size``.
            On failure: ``success=False`` with ``error``.
        """
        try:
            data = load_and_validate_data(data_path, [outcome] + factors)
            contribution_results = {}
            for factor in factors:
                other_factors = [f for f in factors if f != factor]
                try:
                    contribution_results[factor] = {
                        "contribution": _estimate_linear_effect(
                            data, factor, outcome, other_factors
                        ),
                        "method": "DoWhy Causal Contribution"
                    }
                except Exception as e:
                    logger.warning(f"DoWhy 贡献分析失败 {factor}: {e}")
                    contribution_results[factor] = {
                        "contribution": 0.0,
                        "method": "Failed"
                    }
            return {
                "success": True,
                "method": "DoWhy Causal Contribution Analysis",
                "outcome": outcome,
                "factors": factors,
                "contribution_results": contribution_results,
                "sample_size": int(len(data))
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "method": "DoWhy Causal Contribution Analysis"
            }