# two_causal_analysis.py
import logging
import numpy as np
import pandas as pd
from pydantic import BaseModel, Field
from typing import Optional, Dict, List, Any
from fastapi import HTTPException, APIRouter
from statsmodels.tsa.stattools import grangercausalitytests
from config.config import *
from routers.utils.openplant import OpenPlant
# Global configuration
router = APIRouter()
logger = logging.getLogger("two_causal_analysis")
opt = OpenPlant(host=config_host, port=config_port, timeout=config_timeout)
class CausalAnalysisError(Exception):
    """Dedicated exception type for causal analysis failures."""

    def __init__(
        self, message: str, error_type: str = "Analysis Error", solution: Optional[str] = None
    ):
        self.message = message
        self.error_type = error_type
        self.solution = solution or "Check the input parameters or contact technical support"
        super().__init__(self.message)

    def to_http_exception(self, status_code: int = 400) -> HTTPException:
        """Convert this error to an HTTPException."""
        return HTTPException(
            status_code=status_code,
            detail={
                "error_type": self.error_type,
                "message": self.message,
                "solution": self.solution,
            },
        )
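
# Usage sketch (illustrative): route handlers catch CausalAnalysisError and
# re-raise it via to_http_exception(), exactly as the endpoint at the bottom
# of this file does:
#
#     try:
#         raise CausalAnalysisError("data too short", "Insufficient Data")
#     except CausalAnalysisError as e:
#         raise e.to_http_exception(status_code=400)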
def validate_analysis_inputs(
point1: str,
point2: str,
start_time: str,
end_time: str,
df_data,
processed_data1: np.ndarray,
processed_data2: np.ndarray,
method: str,
) -> None:
"""
统一验证分析输入参数和数据质量
Args:
point1, point2: 数据点名称
start_time, end_time: 时间范围
df_data: 原始数据
processed_data1, processed_data2: 处理后的数据
method: 分析方法
Raises:
CausalAnalysisError: 验证失败时抛出
"""
# 验证数据获取
if df_data is None or df_data.empty:
raise CausalAnalysisError(
f"无法获取数据点 {point1} 或 {point2} 的数据",
"数据获取失败",
"请检查数据点名称是否正确,时间范围是否有数据",
)
# 验证数据长度
if len(processed_data1) < 10:
raise CausalAnalysisError(
f"有效数据点只有{len(processed_data1)}个,因果分析至少需要10个",
"数据不足",
"建议扩大时间范围或调整采样间隔以获得更多数据",
)
# 验证数据变异性
if np.var(processed_data1) == 0 or np.var(processed_data2) == 0:
raise CausalAnalysisError(
"数据序列无变化,无法进行因果分析",
"数据无变异性",
"请提供具有变化的时间序列数据",
)
# 验证分析方法
supported_methods = ["granger", "cross_correlation", "transfer_entropy"]
if method not in supported_methods:
raise CausalAnalysisError(
f"不支持的因果分析方法: {method}",
"方法不支持",
f"请选择支持的分析方法:{', '.join(supported_methods)}",
)
class CausalAnalysisRequest(BaseModel):
    """Causal analysis request model."""

    point1: str = Field(..., description="Name of data point 1")
    point2: str = Field(..., description="Name of data point 2")
    start_time: str = Field(..., description="Start time, YYYY-MM-DD HH:MM:SS")
    end_time: str = Field(..., description="End time, YYYY-MM-DD HH:MM:SS")
    interval: str = Field(..., description="Sampling interval, e.g. '1m'")
    method: str = Field(default="granger", description="Analysis method")
    max_lag: int = Field(default=5, ge=1, le=20, description="Maximum lag")
    fill_method: str = Field(default="outer", description="Missing-value fill method")


class CausalAnalysisResponse(BaseModel):
    """Causal analysis response model."""

    causal_strength: float = Field(..., description="Causal strength (0-1)")
    direction: str = Field(..., description="Causal direction")
    method: str = Field(..., description="Analysis method")
    confidence: str = Field(..., description="Confidence level")
    data_points: int = Field(..., description="Number of valid data points")
    summary: str = Field(..., description="Result summary")
    recommendations: List[str] = Field(default_factory=list, description="Recommendations")
    details: Optional[Dict[str, Any]] = Field(None, description="Detailed results")
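
# Example request body for POST /two_causal_analysis (the point names below
# are hypothetical placeholders, not real tags):
#
#     {
#         "point1": "UNIT1.TEMP",
#         "point2": "UNIT1.PRESS",
#         "start_time": "2024-01-01 00:00:00",
#         "end_time": "2024-01-01 12:00:00",
#         "interval": "1m",
#         "method": "granger",
#         "max_lag": 5,
#         "fill_method": "outer"
#     }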
def assess_data_quality(data1: np.ndarray, data2: np.ndarray) -> Dict[str, Any]:
"""评估数据质量"""
return {
"length": len(data1),
"variance1": float(np.var(data1)),
"variance2": float(np.var(data2)),
"quality": (
"high"
if len(data1) >= 50 and np.var(data1) > 0 and np.var(data2) > 0
else "low"
),
}
def determine_confidence_level(strength: float, data_length: int, quality: str) -> str:
"""确定可信度等级"""
if strength >= 0.7 and data_length >= 50 and quality == "high":
return "high"
elif strength >= 0.4 and data_length >= 30:
return "medium"
else:
return "low"
def generate_recommendations(
strength: float, direction: str, confidence: str
) -> List[str]:
"""生成建议"""
recommendations = []
if confidence == "low":
recommendations.append("建议增加数据量或检查数据质量")
if strength < 0.3:
recommendations.append("因果关系较弱,建议验证数据或尝试其他方法")
elif strength > 0.7:
recommendations.append("发现强因果关系,可进一步分析机制")
if direction == "bidirectional":
recommendations.append("存在双向因果关系,需要更深入的分析")
return recommendations
def preprocess_data(
data1: np.ndarray, data2: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
"""数据预处理:去除NaN值"""
df = pd.DataFrame({"data1": data1, "data2": data2})
df_clean = df.dropna()
return df_clean["data1"].values, df_clean["data2"].values
def calculate_granger_causality(
data1: np.ndarray, data2: np.ndarray, max_lag: int = 5
) -> Dict[str, Any]:
"""计算格兰杰因果关系"""
try:
lag = min(max_lag, len(data1) // 4, 10)
        # Test both directions; grangercausalitytests checks whether the
        # second column Granger-causes the first.
        data_12 = np.column_stack([data2, data1])  # does data1 cause data2?
        granger_12 = grangercausalitytests(data_12, maxlag=lag, verbose=False)
        pvalue_12 = granger_12[lag][0]["ssr_ftest"][1]
        data_21 = np.column_stack([data1, data2])  # does data2 cause data1?
        granger_21 = grangercausalitytests(data_21, maxlag=lag, verbose=False)
        pvalue_21 = granger_21[lag][0]["ssr_ftest"][1]
        # Compute strength and direction (strength = 1 - p-value, clipped at 0)
strength_12 = max(0, 1 - pvalue_12)
strength_21 = max(0, 1 - pvalue_21)
causal_strength = max(strength_12, strength_21)
if strength_12 > 0.5 and strength_21 > 0.5:
direction = "bidirectional"
elif strength_12 > strength_21 and strength_12 > 0.3:
direction = "data1->data2"
elif strength_21 > strength_12 and strength_21 > 0.3:
direction = "data2->data1"
else:
direction = "none"
return {
"causal_strength": causal_strength,
"direction": direction,
"details": {"p_values": [pvalue_12, pvalue_21], "lag": lag},
"success": True,
}
except Exception as e:
return {
"causal_strength": 0.0,
"direction": "none",
"success": False,
"error": str(e),
}
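
# Illustrative run on synthetic data where data1 drives data2 with a lag of 2
# (np.roll wraps the first two samples, which is negligible at this length;
# the exact strength varies with the random seed):
#
#     >>> rng = np.random.default_rng(0)
#     >>> cause = rng.normal(size=200)
#     >>> effect = np.roll(cause, 2) + rng.normal(scale=0.3, size=200)
#     >>> calculate_granger_causality(cause, effect, max_lag=5)["direction"]
#     'data1->data2'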
def calculate_cross_correlation(
data1: np.ndarray, data2: np.ndarray, max_lag: int = 5
) -> Dict[str, Any]:
"""计算互相关分析"""
try:
correlation = np.corrcoef(data1, data2)[0, 1]
max_correlation = abs(correlation)
best_lag = 0
        # Scan lags in both directions (the bound keeps lag well below len(data1))
        for lag in range(1, min(max_lag + 1, len(data1) // 4)):
corr_12 = np.corrcoef(data1[:-lag], data2[lag:])[0, 1]
corr_21 = np.corrcoef(data2[:-lag], data1[lag:])[0, 1]
if abs(corr_12) > max_correlation:
max_correlation = abs(corr_12)
best_lag = lag
if abs(corr_21) > max_correlation:
max_correlation = abs(corr_21)
best_lag = -lag
causal_strength = max_correlation
if best_lag > 0:
direction = "data1->data2"
elif best_lag < 0:
direction = "data2->data1"
else:
direction = "simultaneous"
return {
"causal_strength": causal_strength,
"direction": direction,
"details": {"correlation": correlation, "best_lag": best_lag},
"success": True,
}
except Exception as e:
return {
"causal_strength": 0.0,
"direction": "none",
"success": False,
"error": str(e),
}
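
# Illustrative example: data2 is data1 shifted right by one step, so the best
# alignment pairs data1[:-1] with data2[1:] and best_lag comes out positive:
#
#     >>> x = np.array([0., 1., 0., 2., 0., 3., 0., 4.])
#     >>> y = np.roll(x, 1)
#     >>> r = calculate_cross_correlation(x, y, max_lag=2)
#     >>> r["direction"], r["details"]["best_lag"]
#     ('data1->data2', 1)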
def calculate_transfer_entropy(x: np.ndarray, y: np.ndarray, lag: int = 1) -> float:
"""计算传递熵"""
try:
if len(x) < lag + 1 or len(y) < lag + 1:
return 0.0
# 简化的传递熵计算
x_past = x[:-lag]
x_present = x[lag:]
y_past = y[:-lag]
# 使用条件互信息的简化估计
te = np.corrcoef(x_present, y_past)[0, 1] ** 2
return abs(te) if not np.isnan(te) else 0.0
except:
return 0.0
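
# Note the argument order: the proxy measures flow y -> x. With a driver that
# leads its follower by one step, the driver's past predicts the follower's
# present but not vice versa (a sketch; exact values depend on the seed):
#
#     >>> rng = np.random.default_rng(1)
#     >>> driver = rng.normal(size=300)
#     >>> follower = np.roll(driver, 1) + rng.normal(scale=0.5, size=300)
#     >>> calculate_transfer_entropy(follower, driver) > \
#     ...     calculate_transfer_entropy(driver, follower)
#     True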
def calculate_transfer_entropy_analysis(
data1: np.ndarray, data2: np.ndarray, max_lag: int = 3
) -> Dict[str, Any]:
"""传递熵分析"""
try:
te_12 = calculate_transfer_entropy(data2, data1, max_lag)
te_21 = calculate_transfer_entropy(data1, data2, max_lag)
causal_strength = max(te_12, te_21)
if te_12 > 0.3 and te_21 > 0.3:
direction = "bidirectional"
elif te_12 > te_21 and te_12 > 0.1:
direction = "data1->data2"
elif te_21 > te_12 and te_21 > 0.1:
direction = "data2->data1"
else:
direction = "none"
return {
"causal_strength": causal_strength,
"direction": direction,
"details": {"te_12": te_12, "te_21": te_21},
"success": True,
}
except Exception as e:
return {
"causal_strength": 0.0,
"direction": "none",
"success": False,
"error": str(e),
}
def interpret_causal_strength(strength: float, method: str) -> str:
"""解释因果关系强度"""
if strength >= 0.7:
return f"强因果关系 (强度: {strength:.3f})"
elif strength >= 0.4:
return f"中等因果关系 (强度: {strength:.3f})"
elif strength >= 0.2:
return f"弱因果关系 (强度: {strength:.3f})"
else:
return f"无明显因果关系 (强度: {strength:.3f})"
@router.post(
"/two_causal_analysis",
response_model=CausalAnalysisResponse,
operation_id="two_points_causal_analysis",
tags=["双变量因果分析"],
)
async def two_causal_analysis(request: CausalAnalysisRequest):
"""双变量因果分析接口"""
    try:
        # 1. Fetch the data
        df_data = opt.api_select_to_frame(
            point_list=[request.point1, request.point2],
            start_time=request.start_time,
            end_time=request.end_time,
            interval=request.interval,
            fill_method=request.fill_method,
        )
        # Fail fast before column access: validate_analysis_inputs repeats this
        # check, but it only runs after preprocessing, so a None/empty frame
        # would otherwise surface here as an opaque 500.
        if df_data is None or df_data.empty:
            raise CausalAnalysisError(
                f"Failed to fetch data for point {request.point1} or {request.point2}",
                "Data Retrieval Failure",
                "Check that the point names are correct and the time range contains data",
            )
        # 2. Preprocess the data
        data1 = df_data[request.point1].values
        data2 = df_data[request.point2].values
        processed_data1, processed_data2 = preprocess_data(data1, data2)
        # 3. Validate the inputs
validate_analysis_inputs(
request.point1,
request.point2,
request.start_time,
request.end_time,
df_data,
processed_data1,
processed_data2,
request.method,
)
        # 4. Run the analysis
analysis_methods = {
"granger": calculate_granger_causality,
"cross_correlation": calculate_cross_correlation,
"transfer_entropy": calculate_transfer_entropy_analysis,
}
        analysis_func = analysis_methods.get(request.method)
        if not analysis_func:
            # Defensive: validate_analysis_inputs already rejects unknown methods
            raise HTTPException(
                status_code=400, detail=f"Unsupported analysis method: {request.method}"
            )
result = analysis_func(processed_data1, processed_data2, request.max_lag)
        if not result.get("success", False):
            raise HTTPException(
                status_code=500,
                detail=f"Analysis computation failed: {result.get('error', 'unknown error')}",
            )
        # 5. Build the response
data_quality = assess_data_quality(processed_data1, processed_data2)
confidence = determine_confidence_level(
result["causal_strength"], len(processed_data1), data_quality["quality"]
)
summary = interpret_causal_strength(result["causal_strength"], request.method)
recommendations = generate_recommendations(
result["causal_strength"], result["direction"], confidence
)
return CausalAnalysisResponse(
causal_strength=result["causal_strength"],
direction=result["direction"],
method=request.method,
confidence=confidence,
data_points=len(processed_data1),
summary=summary,
recommendations=recommendations,
details=result.get("details"),
)
    except CausalAnalysisError as e:
        raise e.to_http_exception()
    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Causal analysis error: {e}")
        raise HTTPException(status_code=500, detail=f"System error: {e}")
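

# Local smoke test for the three analysis functions on synthetic data (a
# minimal sketch; it never touches OpenPlant or the HTTP layer, but running
# this file directly still assumes config.config and OpenPlant import cleanly):
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    cause = rng.normal(size=200)
    effect = np.roll(cause, 2) + rng.normal(scale=0.3, size=200)
    for name, func in (
        ("granger", calculate_granger_causality),
        ("cross_correlation", calculate_cross_correlation),
        ("transfer_entropy", calculate_transfer_entropy_analysis),
    ):
        result = func(cause, effect, 5)
        print(name, result["direction"], round(result["causal_strength"], 3))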