"""
纯 DoWhy 统计工具 - 只包含安全计算函数,不包含自定义统计方法
"""
import numpy as np
from typing import Union
import warnings
def safe_divide(
    numerator: Union[float, np.ndarray],
    denominator: Union[float, np.ndarray],
    default_value: float = 0.0
) -> Union[float, np.ndarray]:
    """
    Divide safely, substituting ``default_value`` wherever the denominator is zero.

    Args:
        numerator: Dividend, a scalar or a NumPy array.
        denominator: Divisor, a scalar or a NumPy array.
        default_value: Value returned (or written element-wise) where the
            denominator equals zero.

    Returns:
        - If ``denominator`` is an ``np.ndarray``: a float array with the
          broadcast shape of both operands. It holds ``default_value`` where
          ``denominator == 0`` and the quotient elsewhere.
        - If ``denominator`` is a scalar zero: an array of ``default_value``
          shaped like ``numerator`` when ``numerator`` is an ``np.ndarray``,
          otherwise the scalar ``default_value``.
        - Otherwise: ``numerator / denominator``.

    Note:
        A NaN denominator is not equal to zero, so it yields NaN rather than
        ``default_value``.
    """
    if isinstance(denominator, np.ndarray):
        # Size the output from the broadcast of BOTH operands. Sizing it from the
        # denominator alone fails when the numerator has the larger shape.
        shape = np.broadcast(numerator, denominator).shape
        out = np.full(shape, default_value, dtype=float)
        # np.errstate is thread-local. warnings.catch_warnings mutates global state.
        with np.errstate(all="ignore"):
            return np.divide(numerator, denominator, out=out,
                             where=(denominator != 0))
    if denominator == 0:
        # Keep array-shaped results for array numerators, so the output shape
        # does not depend on the divisor's value.
        if isinstance(numerator, np.ndarray):
            return np.full(numerator.shape, default_value, dtype=float)
        return default_value
    with np.errstate(all="ignore"):
        return numerator / denominator
def safe_log(
    x: Union[float, np.ndarray],
    epsilon: float = 1e-10
) -> Union[float, np.ndarray]:
    """
    Natural logarithm with inputs clamped from below at ``epsilon``.

    Values smaller than ``epsilon`` (including zero and negatives) are replaced
    by ``epsilon`` before ``np.log`` is applied. With a positive ``epsilon``
    this avoids -inf and NaN for finite input. NaN input stays NaN.

    Args:
        x: A scalar, a NumPy array, or any array-like that ``np.maximum``
            accepts (e.g. a list or a pandas Series).
        epsilon: Lower clamp applied before taking the log.

    Returns:
        ``log(max(x, epsilon))``, computed element-wise for non-scalar input.
    """
    if np.isscalar(x):
        return np.log(max(x, epsilon))
    # Use np.maximum for arrays and array-likes. The builtin max would compare
    # the whole sequence to epsilon and raise (TypeError for lists, "truth
    # value is ambiguous" for pandas objects).
    return np.log(np.maximum(x, epsilon))
def safe_sqrt(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """
    Square root with negative inputs clamped to zero.

    Args:
        x: A scalar, a NumPy array, or any array-like that ``np.maximum``
            accepts (e.g. a list or a pandas Series).

    Returns:
        ``sqrt(max(x, 0))``, computed element-wise for non-scalar input.
        NaN input stays NaN.
    """
    if np.isscalar(x):
        return np.sqrt(max(x, 0))
    # Use np.maximum for arrays and array-likes. The builtin max cannot compare
    # a whole sequence against 0.
    return np.sqrt(np.maximum(x, 0))