Skip to content

Learned kernel MMD with KeOps backend #602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
341 changes: 341 additions & 0 deletions alibi_detect/cd/keops/learned_kernel.py

Large diffs are not rendered by default.

130 changes: 130 additions & 0 deletions alibi_detect/cd/keops/tests/test_learned_kernel_keops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from itertools import product
import numpy as np
import pytest
import torch
import torch.nn as nn
from typing import Callable, Optional, Union
from alibi_detect.utils.frameworks import has_keops
from alibi_detect.utils.pytorch import GaussianRBF as GaussianRBFTorch
from alibi_detect.utils.pytorch import mmd2_from_kernel_matrix
if has_keops:
from alibi_detect.cd.keops.learned_kernel import LearnedKernelDriftKeops
from alibi_detect.utils.keops import GaussianRBF
from pykeops.torch import LazyTensor

n = 50 # number of instances used for the reference and test data samples in the tests


if has_keops:
class MyKernel(nn.Module):
def __init__(self, n_features: int, proj: bool):
super().__init__()
sigma = .1
self.kernel = GaussianRBF(trainable=True, sigma=torch.Tensor([sigma]))
self.has_proj = proj
if proj:
self.proj = nn.Linear(n_features, 2)
self.kernel_b = GaussianRBF(trainable=True, sigma=torch.Tensor([sigma]))

def forward(self, x_proj: LazyTensor, y_proj: LazyTensor, x: Optional[LazyTensor] = None,
y: Optional[LazyTensor] = None) -> LazyTensor:
similarity = self.kernel(x_proj, y_proj)
if self.has_proj:
similarity = similarity + self.kernel_b(x, y)
return similarity


# test List[Any] inputs to the detector
def identity_fn(x: Union[torch.Tensor, list]) -> torch.Tensor:
if isinstance(x, list):
return torch.from_numpy(np.array(x))
else:
return x


p_val = [.05]
n_features = [4]
preprocess_at_init = [True, False]
update_x_ref = [None, {'reservoir_sampling': 1000}]
preprocess_fn = [None, identity_fn]
n_permutations = [10]
batch_size_permutations = [5, 1000000]
train_size = [.5]
retrain_from_scratch = [True]
batch_size_predict = [1000000]
preprocess_batch = [None, identity_fn]
has_proj = [True, False]
tests_lkdrift = list(product(p_val, n_features, preprocess_at_init, update_x_ref, preprocess_fn,
n_permutations, batch_size_permutations, train_size, retrain_from_scratch,
batch_size_predict, preprocess_batch, has_proj))
n_tests = len(tests_lkdrift)


@pytest.fixture
def lkdrift_params(request):
return tests_lkdrift[request.param]


@pytest.mark.skipif(not has_keops, reason='Skipping since pykeops is not installed.')
@pytest.mark.parametrize('lkdrift_params', list(range(n_tests)), indirect=True)
def test_lkdrift(lkdrift_params):
p_val, n_features, preprocess_at_init, update_x_ref, preprocess_fn, \
n_permutations, batch_size_permutations, train_size, retrain_from_scratch, \
batch_size_predict, preprocess_batch, has_proj = lkdrift_params

np.random.seed(0)
torch.manual_seed(0)

kernel = MyKernel(n_features, has_proj)
x_ref = np.random.randn(*(n, n_features)).astype(np.float32)
x_test1 = np.ones_like(x_ref)
to_list = False
if preprocess_batch is not None and preprocess_fn is None:
to_list = True
x_ref = [_ for _ in x_ref]
update_x_ref = None

cd = LearnedKernelDriftKeops(
x_ref=x_ref,
kernel=kernel,
p_val=p_val,
preprocess_at_init=preprocess_at_init,
update_x_ref=update_x_ref,
preprocess_fn=preprocess_fn,
n_permutations=n_permutations,
batch_size_permutations=batch_size_permutations,
train_size=train_size,
retrain_from_scratch=retrain_from_scratch,
batch_size_predict=batch_size_predict,
preprocess_batch_fn=preprocess_batch,
batch_size=32,
epochs=1
)

x_test0 = x_ref.copy()
preds_0 = cd.predict(x_test0)
assert cd.n == len(x_test0) + len(x_ref)
assert preds_0['data']['is_drift'] == 0

if to_list:
x_test1 = [_ for _ in x_test1]
preds_1 = cd.predict(x_test1)
assert cd.n == len(x_test1) + len(x_test0) + len(x_ref)
assert preds_1['data']['is_drift'] == 1
assert preds_0['data']['distance'] < preds_1['data']['distance']

# ensure the keops MMD^2 estimate matches the pytorch implementation for the same kernel
if not isinstance(x_ref, list) and update_x_ref is None and not has_proj:
if isinstance(preprocess_fn, Callable):
x_ref, x_test1 = cd.preprocess(x_test1)
n_ref, n_test = x_ref.shape[0], x_test1.shape[0]
x_all = torch.from_numpy(np.concatenate([x_ref, x_test1], axis=0)).float()
perms = [torch.randperm(n_ref + n_test) for _ in range(n_permutations)]
mmd2 = cd._mmd2(x_all, perms, n_ref, n_test)[0]

if isinstance(preprocess_batch, Callable):
x_all = preprocess_batch(x_all)
kernel = GaussianRBFTorch(sigma=cd.kernel.kernel.sigma)
kernel_mat = kernel(x_all, x_all)
mmd2_torch = mmd2_from_kernel_matrix(kernel_mat, n_test)
np.testing.assert_almost_equal(mmd2, mmd2_torch, decimal=6)
35 changes: 26 additions & 9 deletions alibi_detect/cd/learned_kernel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from typing import Callable, Dict, Optional, Union
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, has_keops, BackendValidator
from alibi_detect.utils.warnings import deprecated_alias
from alibi_detect.base import DriftConfigMixin

Expand All @@ -13,6 +13,9 @@
from alibi_detect.cd.tensorflow.learned_kernel import LearnedKernelDriftTF
from alibi_detect.utils.tensorflow.data import TFDataset

if has_keops:
from alibi_detect.cd.keops.learned_kernel import LearnedKernelDriftKeops


class LearnedKernelDrift(DriftConfigMixin):
@deprecated_alias(preprocess_x_ref='preprocess_at_init')
Expand All @@ -27,13 +30,15 @@ def __init__(
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
n_permutations: int = 100,
batch_size_permutations: int = 1000000,
var_reg: float = 1e-5,
reg_loss_fn: Callable = (lambda kernel: 0),
train_size: Optional[float] = .75,
retrain_from_scratch: bool = True,
optimizer: Optional[Callable] = None,
learning_rate: float = 1e-3,
batch_size: int = 32,
batch_size_predict: int = 1000000,
preprocess_batch_fn: Optional[Callable] = None,
epochs: int = 3,
verbose: int = 0,
Expand All @@ -52,7 +57,6 @@ def __init__(
For details see Liu et al (2020): Learning Deep Kernels for Non-Parametric Two-Sample Tests
(https://arxiv.org/abs/2002.09116)


Parameters
----------
x_ref
Expand All @@ -78,6 +82,9 @@ def __init__(
Function to preprocess the data before applying the kernel.
n_permutations
The number of permutations to use in the permutation test once the MMD has been computed.
batch_size_permutations
KeOps computes the n_permutations of the MMD^2 statistics in chunks of batch_size_permutations.
Only relevant for 'keops' backend.
var_reg
Constant added to the estimated variance of the MMD for stability.
reg_loss_fn
Expand All @@ -94,6 +101,8 @@ def __init__(
Learning rate used by optimizer.
batch_size
Batch size used during training of the kernel.
batch_size_predict
Batch size used for the trained drift detector predictions. Only relevant for 'keops' backend.
preprocess_batch_fn
Optional batch preprocessing function. For example to convert a list of objects to a batch which can be
processed by the kernel.
Expand All @@ -105,11 +114,11 @@ def __init__(
Optional additional kwargs when training the kernel.
device
Device type used. The default None tries to use the GPU and falls back on CPU if needed.
Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend.
Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Relevant for 'pytorch' and 'keops' backends.
dataset
Dataset object used during training.
dataloader
Dataloader object used during training. Only relevant for 'pytorch' backend.
Dataloader object used during training. Relevant for 'pytorch' and 'keops' backends.
input_shape
Shape of input data.
data_type
Expand All @@ -123,7 +132,8 @@ def __init__(
backend = backend.lower()
BackendValidator(
backend_options={'tensorflow': ['tensorflow'],
'pytorch': ['pytorch']},
'pytorch': ['pytorch'],
'keops': ['keops']},
construct_name=self.__class__.__name__
).verify_backend(backend)

Expand All @@ -134,18 +144,25 @@ def __init__(
pop_kwargs += ['optimizer']
[kwargs.pop(k, None) for k in pop_kwargs]

if backend == 'tensorflow' and has_tensorflow:
pop_kwargs = ['device', 'dataloader']
if backend == 'tensorflow':
pop_kwargs = ['device', 'dataloader', 'batch_size_permutations', 'batch_size_predict']
[kwargs.pop(k, None) for k in pop_kwargs]
if dataset is None:
kwargs.update({'dataset': TFDataset})
self._detector = LearnedKernelDriftTF(*args, **kwargs) # type: ignore
detector = LearnedKernelDriftTF
else:
if dataset is None:
kwargs.update({'dataset': TorchDataset})
if dataloader is None:
kwargs.update({'dataloader': DataLoader})
self._detector = LearnedKernelDriftTorch(*args, **kwargs) # type: ignore
if backend == 'pytorch':
pop_kwargs = ['batch_size_permutations', 'batch_size_predict']
[kwargs.pop(k, None) for k in pop_kwargs]
detector = LearnedKernelDriftTorch
else:
detector = LearnedKernelDriftKeops

self._detector = detector(*args, **kwargs) # type: ignore
self.meta = self._detector.meta

def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
Expand Down
22 changes: 21 additions & 1 deletion alibi_detect/cd/tests/test_learned_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from alibi_detect.cd import LearnedKernelDrift
from alibi_detect.cd.pytorch.learned_kernel import LearnedKernelDriftTorch
from alibi_detect.cd.tensorflow.learned_kernel import LearnedKernelDriftTF
from alibi_detect.utils.frameworks import has_keops
if has_keops:
from alibi_detect.cd.keops.learned_kernel import LearnedKernelDriftKeops
from pykeops.torch import LazyTensor

n, n_features = 100, 5

Expand Down Expand Up @@ -37,7 +41,16 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.einsum('ji,ki->jk', self.dense(x), self.dense(y))


tests_lkdrift = ['tensorflow', 'pytorch', 'PyToRcH', 'mxnet']
if has_keops:
class MyKernelKeops(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as this one?

def __init__(self):
super().__init__()

def forward(self, x: LazyTensor, y: LazyTensor) -> LazyTensor:
return (- ((x - y) ** 2).sum(-1)).exp()


tests_lkdrift = ['tensorflow', 'pytorch', 'keops', 'PyToRcH', 'mxnet']
n_tests = len(tests_lkdrift)


Expand All @@ -53,6 +66,8 @@ def test_lkdrift(lkdrift_params):
kernel = MyKernelTorch(n_features)
elif backend.lower() == 'tensorflow':
kernel = MyKernelTF(n_features)
elif has_keops and backend.lower() == 'keops':
kernel = MyKernelKeops()
else:
kernel = None
x_ref = np.random.randn(*(n, n_features))
Expand All @@ -61,10 +76,15 @@ def test_lkdrift(lkdrift_params):
cd = LearnedKernelDrift(x_ref=x_ref, kernel=kernel, backend=backend)
except NotImplementedError:
cd = None
except ImportError:
assert not has_keops
cd = None

if backend.lower() == 'pytorch':
assert isinstance(cd._detector, LearnedKernelDriftTorch)
elif backend.lower() == 'tensorflow':
assert isinstance(cd._detector, LearnedKernelDriftTF)
elif has_keops and backend.lower() == 'keops':
assert isinstance(cd._detector, LearnedKernelDriftKeops)
else:
assert cd is None
4 changes: 4 additions & 0 deletions alibi_detect/saving/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,13 +842,15 @@ class LearnedKernelDriftConfig(DriftDetectorConfig):
preprocess_at_init: bool = True
update_x_ref: Optional[Dict[str, int]] = None
n_permutations: int = 100
batch_size_permutations: int = 1000000
var_reg: float = 1e-5
reg_loss_fn: Optional[str] = None
train_size: Optional[float] = .75
retrain_from_scratch: bool = True
optimizer: Optional[Union[str, OptimizerConfig]] = None
learning_rate: float = 1e-3
batch_size: int = 32
batch_size_predict: int = 1000000
preprocess_batch_fn: Optional[str] = None
epochs: int = 3
verbose: int = 0
Expand All @@ -872,13 +874,15 @@ class LearnedKernelDriftConfigResolved(DriftDetectorConfigResolved):
preprocess_at_init: bool = True
update_x_ref: Optional[Dict[str, int]] = None
n_permutations: int = 100
batch_size_permutations: int = 1000000
var_reg: float = 1e-5
reg_loss_fn: Optional[Callable] = None
train_size: Optional[float] = .75
retrain_from_scratch: bool = True
optimizer: Optional['tf.keras.optimizers.Optimizer'] = None
learning_rate: float = 1e-3
batch_size: int = 32
batch_size_predict: int = 1000000
preprocess_batch_fn: Optional[Callable] = None
epochs: int = 3
verbose: int = 0
Expand Down
1 change: 1 addition & 0 deletions alibi_detect/tests/test_dep_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def test_keops_utils_dependencies(opt_dep):
dependency_map = defaultdict(lambda: ['default'])
for dependency, relations in [
("GaussianRBF", ['keops']),
("DeepKernel", ['keops']),
]:
dependency_map[dependency] = relations
from alibi_detect.utils import keops as keops_utils
Expand Down
8 changes: 6 additions & 2 deletions alibi_detect/utils/keops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from alibi_detect.utils.missing_optional_dependency import import_optional


GaussianRBF = import_optional('alibi_detect.utils.keops.kernels', names=['GaussianRBF'])
GaussianRBF, DeepKernel = import_optional(
'alibi_detect.utils.keops.kernels',
names=['GaussianRBF', 'DeepKernel']
)

__all__ = [
"GaussianRBF"
"GaussianRBF",
"DeepKernel"
]
Loading