|
| 1 | +import pytest |
| 2 | +import pytorch_lightning as pl |
| 3 | +import torch |
| 4 | +from torch.utils.data import TensorDataset |
| 5 | + |
| 6 | +from dowhy.causal_prediction.algorithms.cacm import CACM |
| 7 | +from dowhy.causal_prediction.algorithms.erm import ERM |
| 8 | +from dowhy.causal_prediction.dataloaders.get_data_loader import get_loaders |
| 9 | +from dowhy.causal_prediction.models.networks import MLP, Classifier |
| 10 | +from dowhy.datasets import linear_dataset |
| 11 | + |
| 12 | + |
| 13 | +class LinearTensorDataset: |
| 14 | + N_WORKERS = 0 |
| 15 | + |
| 16 | + def __init__(self, n_envs, n_samples, input_shape, num_classes): |
| 17 | + self.input_shape = input_shape |
| 18 | + self.num_classes = num_classes |
| 19 | + self.datasets = [] |
| 20 | + |
| 21 | + for env in range(n_envs): |
| 22 | + data = linear_dataset( |
| 23 | + beta=10, |
| 24 | + num_common_causes=2, |
| 25 | + num_instruments=0, |
| 26 | + num_samples=n_samples, |
| 27 | + treatment_is_binary=True, |
| 28 | + outcome_is_binary=True, |
| 29 | + ) |
| 30 | + df = data["df"] |
| 31 | + |
| 32 | + # Use treatment as input, outcome as label |
| 33 | + x = torch.tensor(df[data["treatment_name"]].values, dtype=torch.float32).reshape(-1, 1) |
| 34 | + y = torch.tensor(df[data["outcome_name"]].values, dtype=torch.long) |
| 35 | + |
| 36 | + # Use common causes as attributes |
| 37 | + cc_names = data["common_causes_names"] |
| 38 | + a = torch.tensor(df[cc_names].values, dtype=torch.float32) |
| 39 | + self.datasets.append(TensorDataset(x, y, a)) |
| 40 | + |
| 41 | + def __getitem__(self, index): |
| 42 | + return self.datasets[index] |
| 43 | + |
| 44 | + def __len__(self): |
| 45 | + return len(self.datasets) |
| 46 | + |
| 47 | + |
| 48 | +@pytest.mark.usefixtures("fixed_seed") |
| 49 | +@pytest.mark.parametrize( |
| 50 | + "algorithm_cls, algorithm_kwargs", |
| 51 | + [ |
| 52 | + (ERM, {}), |
| 53 | + (CACM, {"gamma": 1e-2, "attr_types": ["causal"], "lambda_causal": 10.0}), |
| 54 | + ], |
| 55 | +) |
| 56 | +def test_causal_prediction_training_and_eval(algorithm_cls, algorithm_kwargs, fixed_seed): |
| 57 | + # Use the new linear dataset-based class |
| 58 | + dataset = LinearTensorDataset(n_envs=4, n_samples=1000, input_shape=(1,), num_classes=2) |
| 59 | + loaders = get_loaders(dataset, train_envs=[0, 1], batch_size=64, val_envs=[2], test_envs=[3]) |
| 60 | + |
| 61 | + # Model |
| 62 | + n_inputs = dataset.input_shape[0] |
| 63 | + mlp_width = 128 |
| 64 | + mlp_depth = 4 |
| 65 | + mlp_dropout = 0.1 |
| 66 | + n_outputs = mlp_width |
| 67 | + featurizer = MLP(n_inputs, n_outputs, mlp_width, mlp_depth, mlp_dropout) |
| 68 | + classifier = Classifier(featurizer.n_outputs, dataset.num_classes) |
| 69 | + model = torch.nn.Sequential(featurizer, classifier) |
| 70 | + |
| 71 | + # Train |
| 72 | + algorithm = algorithm_cls(model, lr=1e-3, **algorithm_kwargs) |
| 73 | + trainer = pl.Trainer(devices=1, max_epochs=5, accelerator="cpu", logger=False, enable_checkpointing=False) |
| 74 | + |
| 75 | + # Fit |
| 76 | + trainer.fit(algorithm, loaders["train_loaders"], loaders["val_loaders"]) |
| 77 | + |
| 78 | + # Check results |
| 79 | + results = trainer.test(algorithm, dataloaders=loaders["test_loaders"]) |
| 80 | + assert isinstance(results, list) |
| 81 | + assert len(results) > 0 |
| 82 | + for r in results: |
| 83 | + if "test_acc" in r: |
| 84 | + assert r["test_acc"] > 0.7, f"Test accuracy too low: {r['test_acc']}" |
| 85 | + if "test_loss" in r: |
| 86 | + assert r["test_loss"] < 1.0, f"Test loss too high: {r['test_loss']}" |
0 commit comments