Fix two bugs w/ kwargs and preprocess_batch_fn in preprocess_fn serialisation #752

Merged · 4 commits · Mar 3, 2023
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v0.990
+ rev: v1.0.1
hooks:
- id: mypy
additional_dependencies: [
10 changes: 9 additions & 1 deletion CHANGELOG.md
@@ -1,8 +1,16 @@
# Change Log

## v0.12.0dev
- [Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.11.0...master)
+ [Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.11.1...master)

## v0.11.1
[Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.11.0...v0.11.1)

### Fixed

- Fixed two bugs with the saving/loading of drift detector `preprocess_fn`s [#752](https://github.com/SeldonIO/alibi-detect/pull/752):
  - When `preprocess_fn` was a custom Python function wrapped in a partial, included kwargs were not serialized. This has now been fixed.
  - When saving drift detector `preprocess_fn`s, kwargs serialized to `.dill` files now have filenames prefixed with the kwarg name, so that files aren't overwritten when multiple kwargs are saved to `.dill`.

## v0.11.0
[Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.10.5...v0.11.0)
5 changes: 4 additions & 1 deletion alibi_detect/saving/loading.py
@@ -253,7 +253,10 @@ def _load_preprocess_config(cfg: dict) -> Optional[Callable]:
logger.warning('Unable to process preprocess_fn. No preprocessing function is defined.')
return None

- return partial(preprocess_fn, **kwargs)
+ if kwargs == {}:
+     return preprocess_fn
Contributor Author: Previously the preprocess function was wrapped in a partial even if there were no kwargs. Now the bare function is returned in this case.

+ else:
+     return partial(preprocess_fn, **kwargs)


def _load_model_config(cfg: dict) -> Callable:
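A minimal sketch of the effect (an editorial illustration, not part of the diff): when no kwargs are present, the loaded `preprocess_fn` is now the bare function, so attributes such as `__name__` are accessed directly rather than via `partial.func`.

```python
from functools import partial

def preprocess_simple(x):
    return x * 2.0  # stand-in for a user-defined preprocessing function

# Sketch of _load_preprocess_config's return logic after this fix:
def load_preprocess(preprocess_fn, kwargs):
    if kwargs == {}:
        return preprocess_fn  # bare function, no partial wrapper
    else:
        return partial(preprocess_fn, **kwargs)

assert load_preprocess(preprocess_simple, {}).__name__ == 'preprocess_simple'
loaded = load_preprocess(preprocess_simple, {'kwarg1': 42})
assert loaded.func is preprocess_simple and loaded.keywords == {'kwarg1': 42}
```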
8 changes: 4 additions & 4 deletions alibi_detect/saving/saving.py
@@ -4,7 +4,7 @@
import warnings
from functools import partial
from pathlib import Path
- from typing import Callable, Optional, Tuple, Union, Any, TYPE_CHECKING
+ from typing import Callable, Optional, Tuple, Union, Any, Dict, TYPE_CHECKING
import dill
import numpy as np
import toml
@@ -264,7 +264,7 @@ def _save_preprocess_config(preprocess_fn: Callable,
The config dictionary, containing references to the serialized artefacts. The format of this dict matches that
of the `preprocess` field in the drift detector specification.
"""
- preprocess_cfg = {}
+ preprocess_cfg: Dict[str, Any] = {}
local_path = Path('preprocess_fn')

# Serialize function
@@ -292,7 +292,7 @@

# Arbitrary function
elif callable(v):
- src, _ = _serialize_object(v, filepath, local_path)
+ src, _ = _serialize_object(v, filepath, local_path.joinpath(k))
Contributor Author: This is bug fix 1, i.e. add the kwarg name to the serialised dill file.

kwargs.update({k: src})

# Put remaining kwargs directly into cfg
@@ -302,7 +302,7 @@
if 'preprocess_drift' in func:
preprocess_cfg.update(kwargs)
else:
- kwargs.update({'kwargs': kwargs})
+ preprocess_cfg.update({'kwargs': kwargs})
Contributor Author: This is bug fix 2. Previously the kwargs dict was updated with itself, which had no effect; it should have been the returned preprocess_cfg dict that was updated instead.


return preprocess_cfg
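
An editorial sketch of the corrected saving behaviour (the function and kwarg names below are hypothetical, not taken from the PR):

```python
from functools import partial
from pathlib import Path

def my_preprocess(x, factor=1.0, batch_fn=None):
    return x * factor  # hypothetical custom preprocessing function

fn = partial(my_preprocess, factor=2.0, batch_fn=abs)

# Roughly what _save_preprocess_config is now expected to produce for a
# custom function (the 'src' path matches the test assertions further down):
expected_cfg = {
    'src': str(Path('preprocess_fn', 'function.dill')),
    'kwargs': {
        'factor': 2.0,  # plain kwarg, stored inline in the config
        'batch_fn': str(Path('preprocess_fn', 'batch_fn.dill')),  # fix 1: file named after the kwarg
    },
}
# Fix 2: previously `kwargs.update({'kwargs': kwargs})` mutated kwargs itself,
# so the returned preprocess_cfg never received a 'kwargs' entry at all.
```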

10 changes: 9 additions & 1 deletion alibi_detect/saving/tests/models.py
@@ -89,7 +89,7 @@ def encoder_dropout_model(backend, current_cases):


@fixture
- def preprocess_custom(encoder_model):
+ def preprocess_uae(encoder_model):
Contributor Author: More descriptive naming. preprocess_custom actually contains a UAE model, so it is renamed preprocess_uae. Otherwise there would be confusion in test_saving.py, where we are now testing "custom" preprocessing functions.

"""
Preprocess function with Untrained Autoencoder.
"""
@@ -263,6 +263,14 @@ def preprocess_simple(x: np.ndarray):
return x*2.0


@fixture
def preprocess_simple_with_kwargs():
"""
Simple function to test serialization of generic Python function with kwargs, within preprocess_fn.
"""
return partial(preprocess_simple, kwarg1=42, kwarg2=True)


@fixture
def preprocess_nlp(embedding, tokenizer, max_len, backend):
"""
87 changes: 57 additions & 30 deletions alibi_detect/saving/tests/test_saving.py
@@ -5,6 +5,7 @@
Internal functions such as save_kernel/load_kernel_config etc are also tested.
"""
from functools import partial
import os
from pathlib import Path
from typing import Callable

Expand All @@ -19,7 +20,8 @@
import torch.nn as nn

from .datasets import BinData, CategoricalData, ContinuousData, MixedData, TextData
- from .models import (encoder_model, preprocess_custom, preprocess_hiddenoutput, preprocess_simple,  # noqa: F401
+ from .models import (encoder_model, preprocess_uae, preprocess_hiddenoutput, preprocess_simple,  # noqa: F401
+                      preprocess_simple_with_kwargs,
preprocess_nlp, LATENT_DIM, classifier_model, kernel, deep_kernel, nlp_embedding_and_tokenizer,
embedding, tokenizer, max_len, enc_dim, encoder_dropout_model, optimizer)

Expand Down Expand Up @@ -105,7 +107,7 @@ def test_load_simple_config(cfg, tmp_path):
assert v == cfg_new[k]


- @parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
+ @parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput])
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_ksdrift(data, preprocess_fn, tmp_path):
"""
@@ -171,7 +173,7 @@ def test_save_ksdrift_nlp(data, preprocess_fn, enc_dim, tmp_path):  # noqa: F811
@pytest.mark.skipif(version.parse(scipy.__version__) < version.parse('1.7.0'),
reason="Requires scipy version >= 1.7.0")
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
- def test_save_cvmdrift(data, preprocess_custom, tmp_path):
+ def test_save_cvmdrift(data, preprocess_uae, tmp_path):
"""
Test CVMDrift on continuous datasets, with UAE as preprocess_fn.

@@ -181,14 +183,14 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path):
X_ref, X_h0 = data
cd = CVMDrift(X_ref,
p_val=P_VAL,
- preprocess_fn=preprocess_custom,
+ preprocess_fn=preprocess_uae,
preprocess_at_init=True,
)
save_detector(cd, tmp_path)
cd_load = load_detector(tmp_path)

# Assert
- np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load.x_ref)
+ np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load.x_ref)
assert cd_load.n_features == LATENT_DIM
assert cd_load.p_val == P_VAL
assert isinstance(cd_load.preprocess_fn, Callable)
@@ -203,7 +205,7 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path):
], indirect=True
)
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
- def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed):  # noqa: F811
+ def test_save_mmddrift(data, kernel, preprocess_uae, backend, tmp_path, seed):  # noqa: F811
"""
Test MMDDrift on continuous datasets, with UAE as preprocess_fn.

@@ -217,7 +219,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
kwargs = {
'p_val': P_VAL,
'backend': backend,
- 'preprocess_fn': preprocess_custom,
+ 'preprocess_fn': preprocess_uae,
'n_permutations': N_PERMUTATIONS,
'preprocess_at_init': True,
'kernel': kernel,
@@ -237,7 +239,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
preds_load = cd_load.predict(X_h0)

# assertions
- np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load._detector.x_ref)
+ np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load._detector.x_ref)
assert not cd_load._detector.infer_sigma
assert cd_load._detector.n_permutations == N_PERMUTATIONS
assert cd_load._detector.p_val == P_VAL
@@ -248,7 +250,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
assert preds['data']['p_val'] == preds_load['data']['p_val']


- # @parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
+ # @parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput])
@parametrize('preprocess_at_init', [True, False])
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_lsdddrift(data, preprocess_at_init, backend, tmp_path, seed):
@@ -553,7 +555,7 @@ def test_save_contextmmddrift(data, kernel, backend, tmp_path, seed):  # noqa: F
assert cd_load._detector.n_permutations == N_PERMUTATIONS
assert cd_load._detector.p_val == P_VAL
assert isinstance(cd_load._detector.preprocess_fn, Callable)
- assert cd_load._detector.preprocess_fn.func.__name__ == 'preprocess_simple'
+ assert cd_load._detector.preprocess_fn.__name__ == 'preprocess_simple'
Contributor Author: Change required since we no longer wrap the function in a partial when there are no kwargs.

assert cd._detector.x_kernel.sigma == cd_load._detector.x_kernel.sigma
assert cd._detector.c_kernel.sigma == cd_load._detector.c_kernel.sigma
assert cd._detector.x_kernel.init_sigma_fn == cd_load._detector.x_kernel.init_sigma_fn
@@ -629,7 +631,7 @@ def test_save_regressoruncertaintydrift(data, regressor, backend, tmp_path, seed
], indirect=True
)
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
- def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed):  # noqa: F811
+ def test_save_onlinemmddrift(data, kernel, preprocess_uae, backend, tmp_path, seed):  # noqa: F811
"""
Test MMDDriftOnline on continuous datasets, with UAE as preprocess_fn.

@@ -645,7 +647,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,
cd = MMDDriftOnline(X_ref,
ert=ERT,
backend=backend,
- preprocess_fn=preprocess_custom,
+ preprocess_fn=preprocess_uae,
n_bootstraps=N_BOOTSTRAPS,
kernel=kernel,
window_size=WINDOW_SIZE
@@ -667,7 +669,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,
stats_load.append(pred['data']['test_stat'])

# assertions
- np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load._detector.x_ref)
+ np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load._detector.x_ref)
assert cd_load._detector.n_bootstraps == N_BOOTSTRAPS
assert cd_load._detector.ert == ERT
assert isinstance(cd_load._detector.preprocess_fn, Callable)
@@ -678,7 +680,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,


@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
- def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
+ def test_save_onlinelsdddrift(data, preprocess_uae, backend, tmp_path, seed):
"""
Test LSDDDriftOnline on continuous datasets, with UAE as preprocess_fn.

@@ -694,7 +696,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
cd = LSDDDriftOnline(X_ref,
ert=ERT,
backend=backend,
- preprocess_fn=preprocess_custom,
+ preprocess_fn=preprocess_uae,
n_bootstraps=N_BOOTSTRAPS,
window_size=WINDOW_SIZE
)
@@ -715,7 +717,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
stats_load.append(pred['data']['test_stat'])

# assertions
- np.testing.assert_array_almost_equal(preprocess_custom(X_ref), cd_load.get_config()['x_ref'], 5)
+ np.testing.assert_array_almost_equal(preprocess_uae(X_ref), cd_load.get_config()['x_ref'], 5)
assert cd_load._detector.n_bootstraps == N_BOOTSTRAPS
assert cd_load._detector.ert == ERT
assert isinstance(cd_load._detector.preprocess_fn, Callable)
@@ -726,7 +728,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):


@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
- def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
+ def test_save_onlinecvmdrift(data, preprocess_uae, tmp_path, seed):
"""
Test CVMDriftOnline on continuous datasets, with UAE as preprocess_fn.

@@ -738,7 +740,7 @@ def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
with fixed_seed(seed):
cd = CVMDriftOnline(X_ref,
ert=ERT,
- preprocess_fn=preprocess_custom,
+ preprocess_fn=preprocess_uae,
n_bootstraps=N_BOOTSTRAPS,
window_sizes=[WINDOW_SIZE]
)
@@ -759,7 +761,7 @@ def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
stats_load.append(pred['data']['test_stat'])

# assertions
- np.testing.assert_array_almost_equal(preprocess_custom(X_ref), cd_load.get_config()['x_ref'], 5)
+ np.testing.assert_array_almost_equal(preprocess_uae(X_ref), cd_load.get_config()['x_ref'], 5)
assert cd_load.n_bootstraps == N_BOOTSTRAPS
assert cd_load.ert == ERT
assert isinstance(cd_load.preprocess_fn, Callable)
@@ -1100,15 +1102,12 @@ def test_save_deepkernel(data, deep_kernel, backend, tmp_path):  # noqa: F811
assert kernel_loaded.kernel_b.sigma == deep_kernel.kernel_b.sigma


- @parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
+ @parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput])
@parametrize_with_cases("data", cases=ContinuousData.data_synthetic_nd, prefix='data_')
- def test_save_preprocess(data, preprocess_fn, tmp_path, backend):
+ def test_save_preprocess_drift(data, preprocess_fn, tmp_path, backend):
"""
- Unit test for _save_preprocess_config and _load_preprocess_config, with continuous data.
-
- preprocess_fn's are saved (serialized) and then loaded, with assertions to check equivalence.
- Note: _save_model_config, _save_embedding_config, _save_tokenizer_config, _load_model_config,
- _load_embedding_config, _load_tokenizer_config and _prep_model_and_embedding are all well covered by this test.
+ Test saving/loading of the inbuilt `preprocess_drift` preprocessing functions when containing a `model`, with the
+ `model` either being a simple tf/torch model, or a `HiddenOutput` class.
"""
registry_str = 'tensorflow' if backend == 'tensorflow' else 'pytorch'
# Save preprocess_fn to config
Expand All @@ -1132,14 +1131,40 @@ def test_save_preprocess(data, preprocess_fn, tmp_path, backend):
assert isinstance(preprocess_fn_load.keywords['model'], nn.Module)


@parametrize('preprocess_fn', [preprocess_simple, preprocess_simple_with_kwargs])
def test_save_preprocess_custom(preprocess_fn, tmp_path):
Contributor Author: New test function to test save/load of:

  1. preprocess_fn where it is a Python function with no kwargs.
  2. preprocess_fn where it is a partial with kwargs.
"""
Test saving/loading of custom preprocessing functions, without and with kwargs.
"""
# Save preprocess_fn to config
filepath = tmp_path
cfg_preprocess = _save_preprocess_config(preprocess_fn, input_shape=None, filepath=filepath)
cfg_preprocess = _path2str(cfg_preprocess)
cfg_preprocess = PreprocessConfig(**cfg_preprocess).dict() # pydantic validation

assert tmp_path.joinpath(cfg_preprocess['src']).is_file()
assert cfg_preprocess['src'] == os.path.join('preprocess_fn', 'function.dill')
if isinstance(preprocess_fn, partial): # kwargs expected
assert cfg_preprocess['kwargs'] == preprocess_fn.keywords
else: # no kwargs expected
assert cfg_preprocess['kwargs'] == {}

# Resolve and load preprocess config
cfg = {'preprocess_fn': cfg_preprocess}
preprocess_fn_load = resolve_config(cfg, tmp_path)['preprocess_fn'] # tests _load_preprocess_config implicitly
if isinstance(preprocess_fn, partial):
assert preprocess_fn_load.func == preprocess_fn.func
assert preprocess_fn_load.keywords == preprocess_fn.keywords
else:
assert preprocess_fn_load == preprocess_fn
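
For context, a hedged end-to-end sketch of what this enables through the public API (the detector choice and data below are illustrative assumptions, not taken from the PR):

```python
import numpy as np
from functools import partial
from alibi_detect.cd import KSDrift
from alibi_detect.saving import save_detector, load_detector

def scale(x: np.ndarray, factor: float = 1.0) -> np.ndarray:
    return x * factor  # hypothetical custom preprocessing

x_ref = np.random.randn(100, 10).astype(np.float32)
cd = KSDrift(x_ref, p_val=0.05, preprocess_fn=partial(scale, factor=2.0))
save_detector(cd, './detector')  # the kwarg {'factor': 2.0} now round-trips
cd_loaded = load_detector('./detector')
assert cd_loaded.preprocess_fn.keywords == {'factor': 2.0}
```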


@parametrize('preprocess_fn', [preprocess_nlp])
@parametrize_with_cases("data", cases=TextData.movie_sentiment_data, prefix='data_')
def test_save_preprocess_nlp(data, preprocess_fn, tmp_path, backend):
"""
- Unit test for _save_preprocess_config and _load_preprocess_config, with text data.
-
- Note: _save_model_config, _save_embedding_config, _save_tokenizer_config, _load_model_config,
- _load_embedding_config, _load_tokenizer_config and _prep_model_and_embedding are all covered by this test.
+ Test saving/loading of the inbuilt `preprocess_drift` preprocessing functions when containing a `model`, text
+ `tokenizer` and text `embedding` model.
"""
registry_str = 'tensorflow' if backend == 'tensorflow' else 'pytorch'
# Save preprocess_fn to config
Expand All @@ -1152,6 +1177,8 @@ def test_save_preprocess_nlp(data, preprocess_fn, tmp_path, backend):
assert cfg_preprocess['src'] == '@cd.' + registry_str + '.preprocess.preprocess_drift'
assert cfg_preprocess['embedding']['src'] == 'preprocess_fn/embedding'
assert cfg_preprocess['tokenizer']['src'] == 'preprocess_fn/tokenizer'
assert tmp_path.joinpath(cfg_preprocess['preprocess_batch_fn']).is_file()
assert cfg_preprocess['preprocess_batch_fn'] == os.path.join('preprocess_fn', 'preprocess_batch_fn.dill')

if isinstance(preprocess_fn.keywords['model'], (TransformerEmbedding_tf, TransformerEmbedding_pt)):
assert cfg_preprocess['model'] is None
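
For reference, a hedged sketch of the artefact layout these assertions imply (inferred from the paths above; the exact file set depends on the model and embedding used):

```python
from pathlib import Path

# Assumed save-directory contents for an NLP preprocess_fn:
expected_artefacts = [
    Path('preprocess_fn', 'embedding'),                 # serialized text embedding
    Path('preprocess_fn', 'tokenizer'),                 # serialized tokenizer
    Path('preprocess_fn', 'preprocess_batch_fn.dill'),  # saved under the kwarg's name (fix 1)
]
```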