
Commit 17a97cb

Tests for causal prediction
Signed-off-by: maartenvanhooft <[email protected]>
1 parent 28ceb19 commit 17a97cb

File tree

4 files changed: +141 -0 lines changed


tests/causal_prediction/__init__.py

Whitespace-only changes.
Lines changed: 51 additions & 0 deletions (new test file for the causal prediction algorithm utilities)

@@ -0,0 +1,51 @@
import torch

from dowhy.causal_prediction.algorithms.utils import gaussian_kernel, mmd_compute, my_cdist


class TestAlgorithmUtils:
    def test_my_cdist(self):
        # Squared Euclidean distances between x1 and x2
        x1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        x2 = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        distances = my_cdist(x1, x2)
        expected = torch.tensor([[1.0, 1.0], [13.0, 5.0]])
        assert torch.allclose(distances, expected, rtol=1e-5)

        # Single vector case
        x1 = torch.tensor([[1.0, 2.0]])
        x2 = torch.tensor([[1.0, 1.0]])
        distances = my_cdist(x1, x2)
        expected = torch.tensor([[1.0]])
        assert torch.allclose(distances, expected, rtol=1e-5)

    def test_gaussian_kernel(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        y = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        gamma = 1.0
        kernel = gaussian_kernel(x, y, gamma)

        # Kernel values are exp(-gamma * squared distance)
        assert kernel.shape == (2, 2)
        assert torch.all(kernel >= 0) and torch.all(kernel <= 1)

        # Symmetry for same input
        kernel_xx = gaussian_kernel(x, x, gamma)
        assert torch.allclose(kernel_xx, kernel_xx.t(), rtol=1e-5)

    def test_mmd_compute(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        y = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        gamma = 1.0

        # MMD^2 = mean(K(x, x)) + mean(K(y, y)) - 2 * mean(K(x, y))
        mmd_gaussian = mmd_compute(x, y, "gaussian", gamma)
        assert mmd_gaussian >= 0

        # MMD for identical distributions should be zero
        mmd_same = mmd_compute(x, x, "gaussian", gamma)
        assert torch.allclose(mmd_same, torch.tensor(0.0), rtol=1e-5)

        # 'other' kernel: sum of mean squared difference of means and covariances
        mmd_other = mmd_compute(x, y, "other", gamma)
        assert mmd_other >= 0
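For readers unfamiliar with these helpers, here is a minimal sketch of what my_cdist, gaussian_kernel, and mmd_compute could look like, consistent with the formulas referenced in the test comments above. This is an illustrative assumption, not the actual dowhy.causal_prediction.algorithms.utils implementation.

# Illustrative sketch only; the real dowhy utilities may differ in details.
import torch


def my_cdist(x1, x2):
    # Pairwise squared Euclidean distances: ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
    res = x1_norm + x2_norm.transpose(-2, -1) - 2.0 * (x1 @ x2.transpose(-2, -1))
    return res.clamp_min(0.0)


def gaussian_kernel(x, y, gamma):
    # K(a, b) = exp(-gamma * ||a - b||^2), so every entry lies in (0, 1]
    return torch.exp(-gamma * my_cdist(x, y))


def mmd_compute(x, y, kernel_type, gamma):
    if kernel_type == "gaussian":
        # MMD^2 = mean(K(x, x)) + mean(K(y, y)) - 2 * mean(K(x, y))
        return (
            gaussian_kernel(x, x, gamma).mean()
            + gaussian_kernel(y, y, gamma).mean()
            - 2.0 * gaussian_kernel(x, y, gamma).mean()
        )
    # "other": squared difference of feature means plus squared difference of covariances
    mean_diff = (x.mean(dim=0) - y.mean(dim=0)).pow(2).mean()
    cent_x, cent_y = x - x.mean(dim=0), y - y.mean(dim=0)
    cova_x = cent_x.t() @ cent_x / (x.shape[0] - 1)
    cova_y = cent_y.t() @ cent_y / (y.shape[0] - 1)
    cova_diff = (cova_x - cova_y).pow(2).mean()
    return mean_diff + cova_diff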
Lines changed: 86 additions & 0 deletions (new end-to-end training and evaluation test for causal prediction)

@@ -0,0 +1,86 @@
import pytest
import pytorch_lightning as pl
import torch
from torch.utils.data import TensorDataset

from dowhy.causal_prediction.algorithms.cacm import CACM
from dowhy.causal_prediction.algorithms.erm import ERM
from dowhy.causal_prediction.dataloaders.get_data_loader import get_loaders
from dowhy.causal_prediction.models.networks import MLP, Classifier
from dowhy.datasets import linear_dataset


class LinearTensorDataset:
    N_WORKERS = 0

    def __init__(self, n_envs, n_samples, input_shape, num_classes):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.datasets = []

        for env in range(n_envs):
            data = linear_dataset(
                beta=10,
                num_common_causes=2,
                num_instruments=0,
                num_samples=n_samples,
                treatment_is_binary=True,
                outcome_is_binary=True,
            )
            df = data["df"]

            # Use treatment as input, outcome as label
            x = torch.tensor(df[data["treatment_name"]].values, dtype=torch.float32).reshape(-1, 1)
            y = torch.tensor(df[data["outcome_name"]].values, dtype=torch.long)

            # Use common causes as attributes
            cc_names = data["common_causes_names"]
            a = torch.tensor(df[cc_names].values, dtype=torch.float32)
            self.datasets.append(TensorDataset(x, y, a))

    def __getitem__(self, index):
        return self.datasets[index]

    def __len__(self):
        return len(self.datasets)


@pytest.mark.usefixtures("fixed_seed")
@pytest.mark.parametrize(
    "algorithm_cls, algorithm_kwargs",
    [
        (ERM, {}),
        (CACM, {"gamma": 1e-2, "attr_types": ["causal"], "lambda_causal": 10.0}),
    ],
)
def test_causal_prediction_training_and_eval(algorithm_cls, algorithm_kwargs, fixed_seed):
    # Use the new linear dataset-based class
    dataset = LinearTensorDataset(n_envs=4, n_samples=1000, input_shape=(1,), num_classes=2)
    loaders = get_loaders(dataset, train_envs=[0, 1], batch_size=64, val_envs=[2], test_envs=[3])

    # Model
    n_inputs = dataset.input_shape[0]
    mlp_width = 128
    mlp_depth = 4
    mlp_dropout = 0.1
    n_outputs = mlp_width
    featurizer = MLP(n_inputs, n_outputs, mlp_width, mlp_depth, mlp_dropout)
    classifier = Classifier(featurizer.n_outputs, dataset.num_classes)
    model = torch.nn.Sequential(featurizer, classifier)

    # Train
    algorithm = algorithm_cls(model, lr=1e-3, **algorithm_kwargs)
    trainer = pl.Trainer(devices=1, max_epochs=5, accelerator="cpu", logger=False, enable_checkpointing=False)

    # Fit
    trainer.fit(algorithm, loaders["train_loaders"], loaders["val_loaders"])

    # Check results
    results = trainer.test(algorithm, dataloaders=loaders["test_loaders"])
    assert isinstance(results, list)
    assert len(results) > 0
    for r in results:
        if "test_acc" in r:
            assert r["test_acc"] > 0.7, f"Test accuracy too low: {r['test_acc']}"
        if "test_loss" in r:
            assert r["test_loss"] < 1.0, f"Test loss too high: {r['test_loss']}"

tests/conftest.py

Lines changed: 4 additions & 0 deletions
@@ -2,9 +2,13 @@

 import numpy
 import pytest
+import torch


 @pytest.fixture
 def fixed_seed():
     random.seed(0)
     numpy.random.seed(0)
+    torch.manual_seed(0)
+    if hasattr(torch, "cuda"):
+        torch.cuda.manual_seed_all(0)
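A minimal illustration (not part of the commit) of why the fixture now also seeds torch: once the seed is fixed, torch's random number generation, and therefore the weight initialisation used in these tests, is reproducible across runs.

# Illustrative only: with a fixed seed, torch's random tensors are identical across runs.
import torch

torch.manual_seed(0)
first = torch.randn(3)
torch.manual_seed(0)
second = torch.randn(3)
assert torch.equal(first, second)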
