-
Notifications
You must be signed in to change notification settings - Fork 229
Fix two bugs w/ kwargs and preprocess_batch_fn
in preprocess_fn
serialisation
#752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
import warnings | ||
from functools import partial | ||
from pathlib import Path | ||
from typing import Callable, Optional, Tuple, Union, Any, TYPE_CHECKING | ||
from typing import Callable, Optional, Tuple, Union, Any, Dict, TYPE_CHECKING | ||
import dill | ||
import numpy as np | ||
import toml | ||
|
@@ -264,7 +264,7 @@ def _save_preprocess_config(preprocess_fn: Callable, | |
The config dictionary, containing references to the serialized artefacts. The format if this dict matches that | ||
of the `preprocess` field in the drift detector specification. | ||
""" | ||
preprocess_cfg = {} | ||
preprocess_cfg: Dict[str, Any] = {} | ||
local_path = Path('preprocess_fn') | ||
|
||
# Serialize function | ||
|
@@ -292,7 +292,7 @@ def _save_preprocess_config(preprocess_fn: Callable, | |
|
||
# Arbitrary function | ||
elif callable(v): | ||
src, _ = _serialize_object(v, filepath, local_path) | ||
src, _ = _serialize_object(v, filepath, local_path.joinpath(k)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is bug fix 1. i.e. add the kwarg name to the serialised dill file. |
||
kwargs.update({k: src}) | ||
|
||
# Put remaining kwargs directly into cfg | ||
|
@@ -302,7 +302,7 @@ def _save_preprocess_config(preprocess_fn: Callable, | |
if 'preprocess_drift' in func: | ||
preprocess_cfg.update(kwargs) | ||
else: | ||
kwargs.update({'kwargs': kwargs}) | ||
preprocess_cfg.update({'kwargs': kwargs}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is bug fix 2. Previously the |
||
|
||
return preprocess_cfg | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,7 +89,7 @@ def encoder_dropout_model(backend, current_cases): | |
|
||
|
||
@fixture | ||
def preprocess_custom(encoder_model): | ||
def preprocess_uae(encoder_model): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More descriptive naming. |
||
""" | ||
Preprocess function with Untrained Autoencoder. | ||
""" | ||
|
@@ -263,6 +263,14 @@ def preprocess_simple(x: np.ndarray): | |
return x*2.0 | ||
|
||
|
||
@fixture | ||
def preprocess_simple_with_kwargs(): | ||
""" | ||
Simple function to test serialization of generic Python function with kwargs, within preprocess_fn. | ||
""" | ||
return partial(preprocess_simple, kwarg1=42, kwarg2=True) | ||
|
||
|
||
@fixture | ||
def preprocess_nlp(embedding, tokenizer, max_len, backend): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
Internal functions such as save_kernel/load_kernel_config etc are also tested. | ||
""" | ||
from functools import partial | ||
import os | ||
from pathlib import Path | ||
from typing import Callable | ||
|
||
|
@@ -19,7 +20,8 @@ | |
import torch.nn as nn | ||
|
||
from .datasets import BinData, CategoricalData, ContinuousData, MixedData, TextData | ||
from .models import (encoder_model, preprocess_custom, preprocess_hiddenoutput, preprocess_simple, # noqa: F401 | ||
from .models import (encoder_model, preprocess_uae, preprocess_hiddenoutput, preprocess_simple, # noqa: F401 | ||
preprocess_simple_with_kwargs, | ||
preprocess_nlp, LATENT_DIM, classifier_model, kernel, deep_kernel, nlp_embedding_and_tokenizer, | ||
embedding, tokenizer, max_len, enc_dim, encoder_dropout_model, optimizer) | ||
|
||
|
@@ -105,7 +107,7 @@ def test_load_simple_config(cfg, tmp_path): | |
assert v == cfg_new[k] | ||
|
||
|
||
@parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput]) | ||
@parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput]) | ||
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_') | ||
def test_save_ksdrift(data, preprocess_fn, tmp_path): | ||
""" | ||
|
@@ -171,7 +173,7 @@ def test_save_ksdrift_nlp(data, preprocess_fn, enc_dim, tmp_path): # noqa: F811 | |
@pytest.mark.skipif(version.parse(scipy.__version__) < version.parse('1.7.0'), | ||
reason="Requires scipy version >= 1.7.0") | ||
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_') | ||
def test_save_cvmdrift(data, preprocess_custom, tmp_path): | ||
def test_save_cvmdrift(data, preprocess_uae, tmp_path): | ||
""" | ||
Test CVMDrift on continuous datasets, with UAE as preprocess_fn. | ||
|
||
|
@@ -181,14 +183,14 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path): | |
X_ref, X_h0 = data | ||
cd = CVMDrift(X_ref, | ||
p_val=P_VAL, | ||
preprocess_fn=preprocess_custom, | ||
preprocess_fn=preprocess_uae, | ||
preprocess_at_init=True, | ||
) | ||
save_detector(cd, tmp_path) | ||
cd_load = load_detector(tmp_path) | ||
|
||
# Assert | ||
np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load.x_ref) | ||
np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load.x_ref) | ||
assert cd_load.n_features == LATENT_DIM | ||
assert cd_load.p_val == P_VAL | ||
assert isinstance(cd_load.preprocess_fn, Callable) | ||
|
@@ -203,7 +205,7 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path): | |
], indirect=True | ||
) | ||
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_') | ||
def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed): # noqa: F811 | ||
def test_save_mmddrift(data, kernel, preprocess_uae, backend, tmp_path, seed): # noqa: F811 | ||
""" | ||
Test MMDDrift on continuous datasets, with UAE as preprocess_fn. | ||
|
||
|
@@ -217,7 +219,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed) | |
kwargs = { | ||
'p_val': P_VAL, | ||
'backend': backend, | ||
'preprocess_fn': preprocess_custom, | ||
'preprocess_fn': preprocess_uae, | ||
'n_permutations': N_PERMUTATIONS, | ||
'preprocess_at_init': True, | ||
'kernel': kernel, | ||
|
@@ -237,7 +239,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed) | |
preds_load = cd_load.predict(X_h0) | ||
|
||
# assertions | ||
np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load._detector.x_ref) | ||
np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load._detector.x_ref) | ||
assert not cd_load._detector.infer_sigma | ||
assert cd_load._detector.n_permutations == N_PERMUTATIONS | ||
assert cd_load._detector.p_val == P_VAL | ||
|
@@ -248,7 +250,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed) | |
assert preds['data']['p_val'] == preds_load['data']['p_val'] | ||
|
||
|
||
# @parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput]) | ||
# @parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput]) | ||
@parametrize('preprocess_at_init', [True, False]) | ||
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_') | ||
def test_save_lsdddrift(data, preprocess_at_init, backend, tmp_path, seed): | ||
|
@@ -553,7 +555,7 @@ def test_save_contextmmddrift(data, kernel, backend, tmp_path, seed): # noqa: F | |
assert cd_load._detector.n_permutations == N_PERMUTATIONS | ||
assert cd_load._detector.p_val == P_VAL | ||
assert isinstance(cd_load._detector.preprocess_fn, Callable) | ||
assert cd_load._detector.preprocess_fn.func.__name__ == 'preprocess_simple' | ||
assert cd_load._detector.preprocess_fn.__name__ == 'preprocess_simple' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Change required since we no longer wrap the function in a |
||
assert cd._detector.x_kernel.sigma == cd_load._detector.x_kernel.sigma | ||
assert cd._detector.c_kernel.sigma == cd_load._detector.c_kernel.sigma | ||
assert cd._detector.x_kernel.init_sigma_fn == cd_load._detector.x_kernel.init_sigma_fn | ||
|
@@ -629,7 +631,7 @@ def test_save_regressoruncertaintydrift(data, regressor, backend, tmp_path, seed | |
], indirect=True | ||
) | ||
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_') | ||
def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed): # noqa: F811 | ||
def test_save_onlinemmddrift(data, kernel, preprocess_uae, backend, tmp_path, seed): # noqa: F811 | ||
""" | ||
Test MMDDriftOnline on continuous datasets, with UAE as preprocess_fn. | ||
|
||
|
@@ -645,7 +647,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path, | |
cd = MMDDriftOnline(X_ref, | ||
ert=ERT, | ||
backend=backend, | ||
preprocess_fn=preprocess_custom, | ||
preprocess_fn=preprocess_uae, | ||
n_bootstraps=N_BOOTSTRAPS, | ||
kernel=kernel, | ||
window_size=WINDOW_SIZE | ||
|
@@ -667,7 +669,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path, | |
stats_load.append(pred['data']['test_stat']) | ||
|
||
# assertions | ||
np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load._detector.x_ref) | ||
np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load._detector.x_ref) | ||
assert cd_load._detector.n_bootstraps == N_BOOTSTRAPS | ||
assert cd_load._detector.ert == ERT | ||
assert isinstance(cd_load._detector.preprocess_fn, Callable) | ||
|
@@ -678,7 +680,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path, | |
|
||
|
||
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_') | ||
def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed): | ||
def test_save_onlinelsdddrift(data, preprocess_uae, backend, tmp_path, seed): | ||
""" | ||
Test LSDDDriftOnline on continuous datasets, with UAE as preprocess_fn. | ||
|
||
|
@@ -694,7 +696,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed): | |
cd = LSDDDriftOnline(X_ref, | ||
ert=ERT, | ||
backend=backend, | ||
preprocess_fn=preprocess_custom, | ||
preprocess_fn=preprocess_uae, | ||
n_bootstraps=N_BOOTSTRAPS, | ||
window_size=WINDOW_SIZE | ||
) | ||
|
@@ -715,7 +717,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed): | |
stats_load.append(pred['data']['test_stat']) | ||
|
||
# assertions | ||
np.testing.assert_array_almost_equal(preprocess_custom(X_ref), cd_load.get_config()['x_ref'], 5) | ||
np.testing.assert_array_almost_equal(preprocess_uae(X_ref), cd_load.get_config()['x_ref'], 5) | ||
assert cd_load._detector.n_bootstraps == N_BOOTSTRAPS | ||
assert cd_load._detector.ert == ERT | ||
assert isinstance(cd_load._detector.preprocess_fn, Callable) | ||
|
@@ -726,7 +728,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed): | |
|
||
|
||
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_') | ||
def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed): | ||
def test_save_onlinecvmdrift(data, preprocess_uae, tmp_path, seed): | ||
""" | ||
Test CVMDriftOnline on continuous datasets, with UAE as preprocess_fn. | ||
|
||
|
@@ -738,7 +740,7 @@ def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed): | |
with fixed_seed(seed): | ||
cd = CVMDriftOnline(X_ref, | ||
ert=ERT, | ||
preprocess_fn=preprocess_custom, | ||
preprocess_fn=preprocess_uae, | ||
n_bootstraps=N_BOOTSTRAPS, | ||
window_sizes=[WINDOW_SIZE] | ||
) | ||
|
@@ -759,7 +761,7 @@ def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed): | |
stats_load.append(pred['data']['test_stat']) | ||
|
||
# assertions | ||
np.testing.assert_array_almost_equal(preprocess_custom(X_ref), cd_load.get_config()['x_ref'], 5) | ||
np.testing.assert_array_almost_equal(preprocess_uae(X_ref), cd_load.get_config()['x_ref'], 5) | ||
assert cd_load.n_bootstraps == N_BOOTSTRAPS | ||
assert cd_load.ert == ERT | ||
assert isinstance(cd_load.preprocess_fn, Callable) | ||
|
@@ -1100,15 +1102,12 @@ def test_save_deepkernel(data, deep_kernel, backend, tmp_path): # noqa: F811 | |
assert kernel_loaded.kernel_b.sigma == deep_kernel.kernel_b.sigma | ||
|
||
|
||
@parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput]) | ||
@parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput]) | ||
@parametrize_with_cases("data", cases=ContinuousData.data_synthetic_nd, prefix='data_') | ||
def test_save_preprocess(data, preprocess_fn, tmp_path, backend): | ||
def test_save_preprocess_drift(data, preprocess_fn, tmp_path, backend): | ||
""" | ||
Unit test for _save_preprocess_config and _load_preprocess_config, with continuous data. | ||
|
||
preprocess_fn's are saved (serialized) and then loaded, with assertions to check equivalence. | ||
Note: _save_model_config, _save_embedding_config, _save_tokenizer_config, _load_model_config, | ||
_load_embedding_config, _load_tokenizer_config and _prep_model_and_embedding are all well covered by this test. | ||
Test saving/loading of the inbuilt `preprocess_drift` preprocessing functions when containing a `model`, with the | ||
`model` either being a simple tf/torch model, or a `HiddenOutput` class. | ||
""" | ||
registry_str = 'tensorflow' if backend == 'tensorflow' else 'pytorch' | ||
# Save preprocess_fn to config | ||
|
@@ -1132,14 +1131,40 @@ def test_save_preprocess(data, preprocess_fn, tmp_path, backend): | |
assert isinstance(preprocess_fn_load.keywords['model'], nn.Module) | ||
|
||
|
||
@parametrize('preprocess_fn', [preprocess_simple, preprocess_simple_with_kwargs]) | ||
def test_save_preprocess_custom(preprocess_fn, tmp_path): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. New test function to test save/load of:
|
||
""" | ||
Test saving/loading of custom preprocessing functions, without and with kwargs. | ||
""" | ||
# Save preprocess_fn to config | ||
filepath = tmp_path | ||
cfg_preprocess = _save_preprocess_config(preprocess_fn, input_shape=None, filepath=filepath) | ||
cfg_preprocess = _path2str(cfg_preprocess) | ||
cfg_preprocess = PreprocessConfig(**cfg_preprocess).dict() # pydantic validation | ||
|
||
assert tmp_path.joinpath(cfg_preprocess['src']).is_file() | ||
assert cfg_preprocess['src'] == os.path.join('preprocess_fn', 'function.dill') | ||
if isinstance(preprocess_fn, partial): # kwargs expected | ||
assert cfg_preprocess['kwargs'] == preprocess_fn.keywords | ||
else: # no kwargs expected | ||
assert cfg_preprocess['kwargs'] == {} | ||
|
||
# Resolve and load preprocess config | ||
cfg = {'preprocess_fn': cfg_preprocess} | ||
preprocess_fn_load = resolve_config(cfg, tmp_path)['preprocess_fn'] # tests _load_preprocess_config implicitly | ||
if isinstance(preprocess_fn, partial): | ||
assert preprocess_fn_load.func == preprocess_fn.func | ||
assert preprocess_fn_load.keywords == preprocess_fn.keywords | ||
else: | ||
assert preprocess_fn_load == preprocess_fn | ||
|
||
|
||
@parametrize('preprocess_fn', [preprocess_nlp]) | ||
@parametrize_with_cases("data", cases=TextData.movie_sentiment_data, prefix='data_') | ||
def test_save_preprocess_nlp(data, preprocess_fn, tmp_path, backend): | ||
""" | ||
Unit test for _save_preprocess_config and _load_preprocess_config, with text data. | ||
|
||
Note: _save_model_config, _save_embedding_config, _save_tokenizer_config, _load_model_config, | ||
_load_embedding_config, _load_tokenizer_config and _prep_model_and_embedding are all covered by this test. | ||
Test saving/loading of the inbuilt `preprocess_drift` preprocessing functions when containing a `model`, text | ||
`tokenizer` and text `embedding` model. | ||
""" | ||
registry_str = 'tensorflow' if backend == 'tensorflow' else 'pytorch' | ||
# Save preprocess_fn to config | ||
|
@@ -1152,6 +1177,8 @@ def test_save_preprocess_nlp(data, preprocess_fn, tmp_path, backend): | |
assert cfg_preprocess['src'] == '@cd.' + registry_str + '.preprocess.preprocess_drift' | ||
assert cfg_preprocess['embedding']['src'] == 'preprocess_fn/embedding' | ||
assert cfg_preprocess['tokenizer']['src'] == 'preprocess_fn/tokenizer' | ||
assert tmp_path.joinpath(cfg_preprocess['preprocess_batch_fn']).is_file() | ||
assert cfg_preprocess['preprocess_batch_fn'] == os.path.join('preprocess_fn', 'preprocess_batch_fn.dill') | ||
|
||
if isinstance(preprocess_fn.keywords['model'], (TransformerEmbedding_tf, TransformerEmbedding_pt)): | ||
assert cfg_preprocess['model'] is None | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previously the preprocess function was wrapped in a
partial
even if there were no kwarg's. Now the bare function is returned in this case.