Tests for causal prediction #1321

Open · wants to merge 2 commits into main
51 changes: 51 additions & 0 deletions tests/causal_prediction/test_algorithm_utils.py
@@ -0,0 +1,51 @@
import torch

from dowhy.causal_prediction.algorithms.utils import gaussian_kernel, mmd_compute, my_cdist


class TestAlgorithmUtils:
    def test_my_cdist(self):
        # Squared Euclidean distances between x1 and x2
        x1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        x2 = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        distances = my_cdist(x1, x2)
        expected = torch.tensor([[1.0, 1.0], [13.0, 5.0]])
        assert torch.allclose(distances, expected, rtol=1e-5)

        # Single vector case
        x1 = torch.tensor([[1.0, 2.0]])
        x2 = torch.tensor([[1.0, 1.0]])
        distances = my_cdist(x1, x2)
        expected = torch.tensor([[1.0]])
        assert torch.allclose(distances, expected, rtol=1e-5)

    def test_gaussian_kernel(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        y = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        gamma = 1.0
        kernel = gaussian_kernel(x, y, gamma)

        # Kernel values are exp(-gamma * squared distance)
        assert kernel.shape == (2, 2)
        assert torch.all(kernel >= 0) and torch.all(kernel <= 1)

        # Symmetry for same input
        kernel_xx = gaussian_kernel(x, x, gamma)
        assert torch.allclose(kernel_xx, kernel_xx.t(), rtol=1e-5)

    def test_mmd_compute(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        y = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        gamma = 1.0

        # MMD^2 = mean(K(x, x)) + mean(K(y, y)) - 2 * mean(K(x, y))
        mmd_gaussian = mmd_compute(x, y, "gaussian", gamma)
        assert mmd_gaussian >= 0

        # MMD for identical distributions should be zero
        mmd_same = mmd_compute(x, x, "gaussian", gamma)
        assert torch.allclose(mmd_same, torch.tensor(0.0), rtol=1e-5)

        # 'other' kernel: sum of mean squared difference of means and covariances
        mmd_other = mmd_compute(x, y, "other", gamma)
        assert mmd_other >= 0
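
The Gaussian branch of these assertions can be cross-checked by hand from the helpers the test already imports. The snippet below is only an illustrative reference built from the formula in the comment above; reference_gaussian_mmd is a name introduced here, not part of dowhy, and the exact reduction inside mmd_compute may differ.

import torch

from dowhy.causal_prediction.algorithms.utils import gaussian_kernel, mmd_compute


def reference_gaussian_mmd(x, y, gamma):
    # MMD^2 = mean(K(x, x)) + mean(K(y, y)) - 2 * mean(K(x, y))
    k_xx = gaussian_kernel(x, x, gamma).mean()
    k_yy = gaussian_kernel(y, y, gamma).mean()
    k_xy = gaussian_kernel(x, y, gamma).mean()
    return k_xx + k_yy - 2 * k_xy


if __name__ == "__main__":
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    y = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
    # Should be non-negative and agree with mmd_compute(x, y, "gaussian", 1.0)
    print(reference_gaussian_mmd(x, y, 1.0), mmd_compute(x, y, "gaussian", 1.0))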
115 changes: 115 additions & 0 deletions tests/causal_prediction/test_causal_prediction_algorithms.py
@@ -0,0 +1,115 @@
import pytest
import torch
import pytorch_lightning as pl

from dowhy.causal_prediction.algorithms.cacm import CACM
from dowhy.causal_prediction.algorithms.erm import ERM
from dowhy.causal_prediction.dataloaders.get_data_loader import get_loaders
from dowhy.causal_prediction.models.networks import MLP, Classifier
from dowhy.datasets import linear_dataset


class LinearTensorDataset:
    N_WORKERS = 0

    def __init__(self, env_specs, n_samples, input_shape, num_classes):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.datasets = []
        self.env_names = []

        for env_name, beta in env_specs:
            data = linear_dataset(
                beta=beta,
                num_common_causes=2,
                num_instruments=0,
                num_samples=n_samples,
                treatment_is_binary=True,
                outcome_is_binary=True,
            )
            df = data["df"]

            # Use treatment as input, outcome as label
            x = torch.tensor(df[data["treatment_name"]].values, dtype=torch.float32).reshape(-1, 1)
            y = torch.tensor(df[data["outcome_name"]].values, dtype=torch.long)

            # Use common causes as attributes
            cc_names = data["common_causes_names"]
            a = torch.tensor(df[cc_names].values, dtype=torch.float32)
            self.datasets.append(torch.utils.data.TensorDataset(x, y, a))
            self.env_names.append(env_name)

        self.env_by_name = {name: idx for idx, name in enumerate(self.env_names)}

    def __getitem__(self, index):
        return self.datasets[index]

    def __len__(self):
        return len(self.datasets)


def _evaluate_algorithm(algorithm_cls, algorithm_kwargs):
    # Setup dataset
    env_specs = [
        ("train", 10),
        ("val_same_distribution", 10),
        ("test_new_distribution", 5),
    ]
    dataset = LinearTensorDataset(env_specs=env_specs, n_samples=1000, input_shape=(1,), num_classes=2)
    envs = dataset.env_by_name

    loaders = get_loaders(
        dataset,
        train_envs=[envs["train"]],
        val_envs=[envs["val_same_distribution"]],
        test_envs=[envs["test_new_distribution"]],
        batch_size=64,
    )

    # Model
    n_inputs = dataset.input_shape[0]
    mlp_width = 128
    mlp_depth = 4
    mlp_dropout = 0.1
    n_outputs = mlp_width
    featurizer = MLP(n_inputs, n_outputs, mlp_width, mlp_depth, mlp_dropout)
    classifier = Classifier(featurizer.n_outputs, dataset.num_classes)
    model = torch.nn.Sequential(featurizer, classifier)

    # Train
    algorithm = algorithm_cls(model, lr=1e-3, **algorithm_kwargs)
    trainer = pl.Trainer(devices=1, max_epochs=5, accelerator="cpu", logger=False, enable_checkpointing=False)
    trainer.fit(algorithm, loaders["train_loaders"], loaders["val_loaders"])

    # Test
    val_same_distribution = trainer.test(algorithm, dataloaders=loaders["val_loaders"], verbose=False)[0]["test_acc"]
    test_new_distribution = trainer.test(algorithm, dataloaders=loaders["test_loaders"], verbose=False)[0]["test_acc"]

    return val_same_distribution, test_new_distribution


@pytest.mark.usefixtures("fixed_seed")
def test_cacm_vs_erm_accuracy_and_gap():
    # CACM
    val_same_distribution_cacm, test_new_distribution_cacm = _evaluate_algorithm(
        CACM,
        {
            "gamma": 1e-4,
            "attr_types": ["causal"],
            "lambda_causal": 1.0,
        },
    )
    gap_cacm = val_same_distribution_cacm - test_new_distribution_cacm

    # ERM
    val_same_distribution_erm, test_new_distribution_erm = _evaluate_algorithm(ERM, {})
    gap_erm = val_same_distribution_erm - test_new_distribution_erm

    # Accuracy checks
    assert val_same_distribution_erm > 0.7
    assert test_new_distribution_erm > 0.7
    assert val_same_distribution_cacm > 0.7
    assert test_new_distribution_cacm > 0.7

    # Generalization gap check
    assert gap_erm > gap_cacm, f"Expected ERM to degrade more. ERM gap={gap_erm:.4f}, CACM gap={gap_cacm:.4f}"
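
For context on why the CACM run above is configured with gamma, attr_types, and lambda_causal: CACM trains the same featurizer/classifier as ERM but adds a penalty that matches feature distributions across environments for attributes declared "causal", which is what should shrink its generalization gap here. The sketch below only illustrates that idea using the mmd_compute helper exercised by the other test file; it is a hypothetical stand-in, not dowhy's actual CACM loss, and illustrative_cacm_loss and its arguments are invented for this example.

import torch
import torch.nn.functional as F

from dowhy.causal_prediction.algorithms.utils import mmd_compute


def illustrative_cacm_loss(features_by_env, logits_by_env, labels_by_env, gamma=1e-4, lambda_causal=1.0):
    # ERM part: average cross-entropy over the training environments
    erm_loss = torch.stack(
        [F.cross_entropy(logits, labels) for logits, labels in zip(logits_by_env, labels_by_env)]
    ).mean()

    # Regularization part: penalize feature-distribution mismatch across environment pairs
    # (the real algorithm conditions this matching on the causal attributes and labels)
    penalty = torch.tensor(0.0)
    for i in range(len(features_by_env)):
        for j in range(i + 1, len(features_by_env)):
            penalty = penalty + mmd_compute(features_by_env[i], features_by_env[j], "gaussian", gamma)

    return erm_loss + lambda_causal * penalty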
4 changes: 4 additions & 0 deletions tests/conftest.py
@@ -2,9 +2,13 @@

import numpy
import pytest
import torch


@pytest.fixture
def fixed_seed():
    random.seed(0)
    numpy.random.seed(0)
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)
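
A minimal illustration of what the extended fixture provides (a hypothetical test, not part of this PR): any test that requests fixed_seed starts from seed 0 for Python's random, NumPy, and torch, so pseudo-random draws are reproducible within the test.

import pytest
import torch


@pytest.mark.usefixtures("fixed_seed")
def test_fixed_seed_makes_torch_draws_reproducible():
    first = torch.rand(3)
    # Re-seeding with the same value the fixture used reproduces the draw
    torch.manual_seed(0)
    second = torch.rand(3)
    assert torch.allclose(first, second)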