|
|
""" |
|
|
Causal Discovery Module |
|
|
|
|
|
Discover causal relationships from observational data using various algorithms. |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from typing import Optional, Dict, List, Tuple |
|
|
from .causal_graph import CausalGraph |
|
|
|
|
|
|
|
|
class CausalDiscovery: |
|
|
""" |
|
|
Discover causal structures from data. |
|
|
|
|
|
Implements various causal discovery algorithms to learn |
|
|
causal graphs from observational data. |
|
|
""" |
|
|
|
|
|
def __init__(self, method: str = 'pc'): |
|
|
""" |
|
|
Initialize causal discovery. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
method : str |
|
|
Discovery method ('pc', 'ges', 'lingam') |
|
|
""" |
|
|
self.method = method |
|
|
|
|
|
def discover_from_data( |
|
|
self, |
|
|
data: pd.DataFrame, |
|
|
alpha: float = 0.05, |
|
|
max_cond_vars: int = 3 |
|
|
) -> CausalGraph: |
|
|
""" |
|
|
Discover causal graph from data. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
data : pd.DataFrame |
|
|
Observational data |
|
|
alpha : float |
|
|
Significance level for independence tests |
|
|
max_cond_vars : int |
|
|
Maximum number of conditioning variables |
|
|
|
|
|
Returns |
|
|
------- |
|
|
CausalGraph |
|
|
Discovered causal graph |
|
|
""" |
|
|
if self.method == 'pc': |
|
|
return self._pc_algorithm(data, alpha, max_cond_vars) |
|
|
elif self.method == 'ges': |
|
|
return self._ges_algorithm(data) |
|
|
elif self.method == 'lingam': |
|
|
return self._lingam_algorithm(data) |
|
|
else: |
|
|
raise ValueError(f"Unknown method: {self.method}") |
|
|
|
|
|
def _pc_algorithm( |
|
|
self, |
|
|
data: pd.DataFrame, |
|
|
alpha: float, |
|
|
max_cond_vars: int |
|
|
) -> CausalGraph: |
|
|
""" |
|
|
PC (Peter-Clark) algorithm for causal discovery. |
|
|
|
|
|
This is a constraint-based algorithm that uses conditional |
|
|
independence tests to discover causal structure. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
data : pd.DataFrame |
|
|
Observational data |
|
|
alpha : float |
|
|
Significance level |
|
|
max_cond_vars : int |
|
|
Maximum conditioning set size |
|
|
|
|
|
Returns |
|
|
------- |
|
|
CausalGraph |
|
|
Discovered graph |
|
|
""" |
|
|
try: |
|
|
from pgmpy.estimators import PC |
|
|
from pgmpy.independence_tests import ChiSquareTest |
|
|
|
|
|
|
|
|
pc = PC(data=data) |
|
|
model = pc.estimate( |
|
|
significance_level=alpha, |
|
|
max_cond_vars=max_cond_vars |
|
|
) |
|
|
|
|
|
|
|
|
graph = CausalGraph(name="pc_discovered") |
|
|
|
|
|
|
|
|
for node in model.nodes(): |
|
|
graph.add_node(node) |
|
|
|
|
|
|
|
|
for edge in model.edges(): |
|
|
graph.add_edge(edge[0], edge[1]) |
|
|
|
|
|
return graph |
|
|
|
|
|
except ImportError: |
|
|
print("pgmpy required for PC algorithm") |
|
|
return self._simple_correlation_graph(data) |
|
|
|
|
|
def _ges_algorithm(self, data: pd.DataFrame) -> CausalGraph: |
|
|
""" |
|
|
GES (Greedy Equivalence Search) algorithm. |
|
|
|
|
|
Score-based causal discovery algorithm. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
data : pd.DataFrame |
|
|
Observational data |
|
|
|
|
|
Returns |
|
|
------- |
|
|
CausalGraph |
|
|
Discovered graph |
|
|
""" |
|
|
|
|
|
print("GES algorithm not fully implemented yet") |
|
|
return self._simple_correlation_graph(data) |
|
|
|
|
|
def _lingam_algorithm(self, data: pd.DataFrame) -> CausalGraph: |
|
|
""" |
|
|
LiNGAM (Linear Non-Gaussian Acyclic Model) algorithm. |
|
|
|
|
|
Assumes linear relationships and non-Gaussian noise. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
data : pd.DataFrame |
|
|
Observational data |
|
|
|
|
|
Returns |
|
|
------- |
|
|
CausalGraph |
|
|
Discovered graph |
|
|
""" |
|
|
|
|
|
print("LiNGAM algorithm not fully implemented yet") |
|
|
return self._simple_correlation_graph(data) |
|
|
|
|
|
def _simple_correlation_graph( |
|
|
self, |
|
|
data: pd.DataFrame, |
|
|
threshold: float = 0.3 |
|
|
) -> CausalGraph: |
|
|
""" |
|
|
Create a simple graph based on correlations. |
|
|
|
|
|
This is a fallback method and does NOT imply causation. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
data : pd.DataFrame |
|
|
Data |
|
|
threshold : float |
|
|
Correlation threshold |
|
|
|
|
|
Returns |
|
|
------- |
|
|
CausalGraph |
|
|
Correlation-based graph |
|
|
""" |
|
|
graph = CausalGraph(name="correlation_based") |
|
|
|
|
|
|
|
|
for col in data.columns: |
|
|
graph.add_node(col) |
|
|
|
|
|
|
|
|
corr_matrix = data.corr() |
|
|
|
|
|
for i, col1 in enumerate(data.columns): |
|
|
for j, col2 in enumerate(data.columns): |
|
|
if i < j: |
|
|
corr = abs(corr_matrix.loc[col1, col2]) |
|
|
if corr > threshold: |
|
|
|
|
|
try: |
|
|
graph.add_edge( |
|
|
col1, col2, |
|
|
strength=corr, |
|
|
confidence=0.5, |
|
|
mechanism="correlation (not causal)" |
|
|
) |
|
|
except ValueError: |
|
|
|
|
|
try: |
|
|
graph.add_edge( |
|
|
col2, col1, |
|
|
strength=corr, |
|
|
confidence=0.5, |
|
|
mechanism="correlation (not causal)" |
|
|
) |
|
|
except ValueError: |
|
|
|
|
|
pass |
|
|
|
|
|
return graph |
|
|
|
|
|
def test_conditional_independence( |
|
|
self, |
|
|
data: pd.DataFrame, |
|
|
X: str, |
|
|
Y: str, |
|
|
Z: Optional[List[str]] = None, |
|
|
method: str = 'fisherz' |
|
|
) -> Tuple[float, float]: |
|
|
""" |
|
|
Test conditional independence X ⊥ Y | Z. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
data : pd.DataFrame |
|
|
Data |
|
|
X : str |
|
|
First variable |
|
|
Y : str |
|
|
Second variable |
|
|
Z : List[str], optional |
|
|
Conditioning variables |
|
|
method : str |
|
|
Test method ('fisherz', 'chi_square') |
|
|
|
|
|
Returns |
|
|
------- |
|
|
tuple |
|
|
(test_statistic, p_value) |
|
|
""" |
|
|
if Z is None: |
|
|
Z = [] |
|
|
|
|
|
if method == 'fisherz': |
|
|
return self._fisherz_test(data, X, Y, Z) |
|
|
elif method == 'chi_square': |
|
|
return self._chi_square_test(data, X, Y, Z) |
|
|
else: |
|
|
raise ValueError(f"Unknown test method: {method}") |
|
|
|
|
|
def _fisherz_test( |
|
|
self, |
|
|
data: pd.DataFrame, |
|
|
X: str, |
|
|
Y: str, |
|
|
Z: List[str] |
|
|
) -> Tuple[float, float]: |
|
|
""" |
|
|
Fisher's Z test for conditional independence. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
data : pd.DataFrame |
|
|
Data |
|
|
X : str |
|
|
First variable |
|
|
Y : str |
|
|
Second variable |
|
|
Z : List[str] |
|
|
Conditioning variables |
|
|
|
|
|
Returns |
|
|
------- |
|
|
tuple |
|
|
(test_statistic, p_value) |
|
|
""" |
|
|
from scipy.stats import norm |
|
|
|
|
|
n = len(data) |
|
|
|
|
|
if len(Z) == 0: |
|
|
|
|
|
corr = data[[X, Y]].corr().loc[X, Y] |
|
|
else: |
|
|
|
|
|
all_vars = [X, Y] + Z |
|
|
corr_matrix = data[all_vars].corr() |
|
|
|
|
|
|
|
|
|
|
|
corr_XY = corr_matrix.loc[X, Y] |
|
|
corr = corr_XY |
|
|
|
|
|
|
|
|
if abs(corr) >= 0.9999: |
|
|
corr = 0.9999 * np.sign(corr) |
|
|
|
|
|
z = 0.5 * np.log((1 + corr) / (1 - corr)) |
|
|
test_stat = np.sqrt(n - len(Z) - 3) * z |
|
|
|
|
|
|
|
|
p_value = 2 * (1 - norm.cdf(abs(test_stat))) |
|
|
|
|
|
return test_stat, p_value |
|
|
|
|
|
def _chi_square_test( |
|
|
self, |
|
|
data: pd.DataFrame, |
|
|
X: str, |
|
|
Y: str, |
|
|
Z: List[str] |
|
|
) -> Tuple[float, float]: |
|
|
""" |
|
|
Chi-square test for conditional independence. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
data : pd.DataFrame |
|
|
Data |
|
|
X : str |
|
|
First variable |
|
|
Y : str |
|
|
Second variable |
|
|
Z : List[str] |
|
|
Conditioning variables |
|
|
|
|
|
Returns |
|
|
------- |
|
|
tuple |
|
|
(test_statistic, p_value) |
|
|
""" |
|
|
from scipy.stats import chi2_contingency |
|
|
|
|
|
if len(Z) == 0: |
|
|
|
|
|
contingency_table = pd.crosstab(data[X], data[Y]) |
|
|
chi2, p_value, dof, expected = chi2_contingency(contingency_table) |
|
|
return chi2, p_value |
|
|
else: |
|
|
|
|
|
|
|
|
chi2_sum = 0 |
|
|
dof_sum = 0 |
|
|
|
|
|
for z_value in data[Z[0]].unique(): |
|
|
subset = data[data[Z[0]] == z_value] |
|
|
if len(subset) > 1: |
|
|
contingency_table = pd.crosstab(subset[X], subset[Y]) |
|
|
if contingency_table.shape[0] > 1 and contingency_table.shape[1] > 1: |
|
|
chi2, _, dof, _ = chi2_contingency(contingency_table) |
|
|
chi2_sum += chi2 |
|
|
dof_sum += dof |
|
|
|
|
|
|
|
|
from scipy.stats import chi2 |
|
|
p_value = 1 - chi2.cdf(chi2_sum, dof_sum) if dof_sum > 0 else 1.0 |
|
|
|
|
|
return chi2_sum, p_value |
|
|
|