"""
Causal Graph Utilities
Utilities for building and validating causal graphs.
"""
import logging
import networkx as nx
import pandas as pd
from typing import Any, Dict, List, Optional
# Module-level logger. The name is a fixed string rather than __name__,
# presumably so records land under the "dowhy-mcp-server" logger hierarchy —
# NOTE(review): confirm this matches the server's logging configuration.
logger = logging.getLogger("dowhy-mcp-server.utils.graph")
def build_causal_graph(
    data: pd.DataFrame,
    treatment: str,
    outcome: str,
    confounders: List[str],
    graph_type: str = "dag"
) -> Dict[str, Any]:
    """
    Build a back-door causal graph from a treatment, an outcome and confounders.

    The base structure is always a DAG in which every confounder has an edge
    into both the treatment and the outcome, plus the edge treatment -> outcome.
    The 'cpdag' and 'pdag' variants are simplified illustrations derived from
    that DAG, not the output of a proper equivalence-class algorithm.

    Args:
        data: Input DataFrame. Only used to warn when a variable name is not
            one of its columns; the graph is built from the names alone.
        treatment: Treatment variable name.
        outcome: Outcome variable name.
        confounders: Confounder variable names. Duplicates are dropped (first
            occurrence wins), as are entries equal to ``treatment`` or
            ``outcome`` (they would create self-loops). ``None`` is treated
            as an empty list.
        graph_type: 'dag' (default), 'cpdag' or 'pdag'. Any other value is
            logged and a DAG is built instead.

    Returns:
        Dictionary with:
            type: the graph type actually built.
            is_dag: whether the resulting graph is acyclic ('cpdag' adds
                reverse confounder -> outcome edges, so it is cyclic whenever
                there is at least one confounder).
            nodes: node names, in the same order as the adjacency-matrix
                rows and columns.
            edges: list of {"from": u, "to": v} dicts.
            adjacency_matrix: nested lists of 0/1; entry [i][j] is 1 when
                there is an edge nodes[i] -> nodes[j].
            treatment, outcome: echoed back.
            confounders: the cleaned confounder list used to build the graph
                (a new list, not the caller's object).
    """
    # Deduplicate (order-preserving) and drop confounders that coincide with
    # the treatment/outcome: duplicates would desynchronise "nodes" from the
    # adjacency matrix, and self-references would add self-loops.
    clean_confounders: List[str] = []
    for conf in dict.fromkeys(confounders or ()):
        if conf == treatment or conf == outcome:
            logger.warning(
                "Ignoring confounder %r: it is also the treatment or outcome", conf
            )
            continue
        clean_confounders.append(conf)

    # Non-fatal sanity check against the data so existing callers keep working.
    columns = getattr(data, "columns", None)
    if columns is not None:
        missing = [
            v for v in [treatment, outcome, *clean_confounders] if v not in columns
        ]
        if missing:
            logger.warning("Variables not found in data columns: %s", missing)

    # Base DAG: confounder -> treatment, confounder -> outcome, treatment -> outcome.
    base = nx.DiGraph()
    base.add_nodes_from([treatment, outcome, *clean_confounders])
    for conf in clean_confounders:
        base.add_edge(conf, treatment)
        base.add_edge(conf, outcome)
    base.add_edge(treatment, outcome)

    if graph_type == "cpdag":
        # Illustrative only: mark each confounder--outcome edge as
        # "undirected" by adding the reverse edge (this introduces 2-cycles).
        graph = base.copy()
        for conf in clean_confounders:
            graph.add_edge(outcome, conf)
    elif graph_type == "pdag":
        # Illustrative only: drop the first confounder's edge into the outcome
        # when there is more than one confounder.
        graph = base.copy()
        if len(clean_confounders) > 1:
            graph.remove_edge(clean_confounders[0], outcome)
    else:
        if graph_type != "dag":
            # Report the type actually built instead of echoing an unknown one.
            logger.warning(
                "Unknown graph_type %r; building a DAG instead", graph_type
            )
            graph_type = "dag"
        graph = base

    is_dag = nx.is_directed_acyclic_graph(graph)

    # Use one explicit node order for both "nodes" and the matrix so rows and
    # columns are guaranteed to line up (and no SciPy dependency is needed).
    nodes = list(graph.nodes())
    adj_matrix = [[1 if graph.has_edge(u, v) else 0 for v in nodes] for u in nodes]

    return {
        "type": graph_type,
        "is_dag": is_dag,
        "nodes": nodes,
        "edges": [{"from": u, "to": v} for u, v in graph.edges()],
        "adjacency_matrix": adj_matrix,
        "treatment": treatment,
        "outcome": outcome,
        "confounders": list(clean_confounders),
    }
def validate_graph_structure(G: nx.DiGraph, data: pd.DataFrame) -> Dict[str, Any]:
    """
    Run basic structural checks on a causal graph.

    Checks acyclicity, isolated nodes and cycles. For every pair of
    non-adjacent nodes, it also records whether the graph implies the pair is
    independent given all remaining nodes.

    Args:
        G: Graph to validate (expected to be a NetworkX DiGraph).
        data: DataFrame with the observed data. Not used by the current
            checks; the recorded independence claims are NOT tested against it.

    Returns:
        Dictionary with:
            is_dag: whether G is a directed acyclic graph.
            isolated_nodes: nodes with no incident edges.
            cycles: simple cycles, each as a list of nodes. This is empty when
                G is a DAG or when NetworkX cannot enumerate cycles for G.
            d_separation_tests: one entry per non-adjacent pair, with keys
                node_i, node_j, separators (all other nodes) and
                should_be_independent. The last key is True iff the pair is
                d-separated given the separators.
            validation_passed: True when G is acyclic and has no isolated nodes.
    """
    is_dag = nx.is_directed_acyclic_graph(G)
    isolated_nodes = list(nx.isolates(G))

    # A DAG has no cycles by definition, so only enumerate them when needed;
    # simple_cycles can be very expensive on dense cyclic graphs.
    cycles: List[List[Any]] = []
    if not is_dag:
        try:
            cycles = [list(c) for c in nx.simple_cycles(G)]
        except nx.NetworkXException as exc:
            # E.g. older NetworkX raises NetworkXNotImplemented for undirected graphs.
            logger.warning("Cycle enumeration failed: %s", exc)
            cycles = []

    nodes = list(G.nodes())

    # Given ALL other nodes, every non-collider on a path is conditioned on
    # (blocking it). The only path that can remain open between two
    # non-adjacent nodes is therefore a single collider i -> c <- j, because
    # consecutive colliders are impossible. So the pair is d-separated iff it
    # has no common child. For undirected graphs there are no colliders, and
    # non-adjacent nodes are separated by the rest.
    if G.is_directed():
        children = {n: set(G.successors(n)) for n in nodes}
    else:
        children = {n: set() for n in nodes}

    d_separation_tests = []
    for idx, node_i in enumerate(nodes):
        for node_j in nodes[idx + 1:]:
            # Adjacent nodes can never be d-separated; skip them.
            if G.has_edge(node_i, node_j) or G.has_edge(node_j, node_i):
                continue
            separators = [n for n in nodes if n != node_i and n != node_j]
            common_children = children[node_i] & children[node_j]
            d_separation_tests.append({
                "node_i": node_i,
                "node_j": node_j,
                "separators": separators,
                "should_be_independent": not common_children,
            })

    return {
        "is_dag": is_dag,
        "isolated_nodes": isolated_nodes,
        "cycles": cycles,
        "d_separation_tests": d_separation_tests,
        "validation_passed": is_dag and not isolated_nodes and not cycles,
    }