Fix two bugs w/ kwargs and `preprocess_batch_fn` in `preprocess_fn` serialisation #752
@@ -253,7 +253,10 @@ def _load_preprocess_config(cfg: dict) -> Optional[Callable]:
         logger.warning('Unable to process preprocess_fn. No preprocessing function is defined.')
         return None

-    return partial(preprocess_fn, **kwargs)
+    if kwargs == {}:
+        return preprocess_fn
Previously the preprocess function was wrapped in a `partial` even if there were no kwargs. Now the bare function is returned in this case.
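A minimal sketch of the behaviour described above, for illustration only. `wrap_preprocess` is a hypothetical stand-in for the tail end of `_load_preprocess_config`, and `preprocess_simple` here is a toy function, not the one used in the test suite.

```python
from functools import partial
from typing import Callable, Optional


def preprocess_simple(x, scale=1.0):
    # Toy preprocessing function, used only for this illustration
    return [v * scale for v in x]


def wrap_preprocess(preprocess_fn: Callable, kwargs: Optional[dict] = None) -> Callable:
    # Hypothetical helper: only wrap in a partial when there are kwargs to bind
    kwargs = kwargs or {}
    if kwargs == {}:
        return preprocess_fn
    return partial(preprocess_fn, **kwargs)


bare = wrap_preprocess(preprocess_simple)
wrapped = wrap_preprocess(preprocess_simple, {'scale': 2.0})

# The bare function exposes __name__ directly; a partial exposes it via .func
print(bare.__name__)
print(wrapped.func.__name__)
print(wrapped([1, 2]))
```

A consequence is that callers can no longer assume the loaded `preprocess_fn` has a `.func` attribute, so any code inspecting it needs to handle both cases.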
@@ -292,7 +292,7 @@ def _save_preprocess_config(preprocess_fn: Callable,

         # Arbitrary function
         elif callable(v):
-            src, _ = _serialize_object(v, filepath, local_path)
+            src, _ = _serialize_object(v, filepath, local_path.joinpath(k))
This is bug fix 1: the kwarg name is now included in the path of the serialised `.dill` file.
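To illustrate why the kwarg name matters, here is a rough sketch of the target paths before and after the fix. It assumes the serialiser simply appends `.dill` to the path it is given; `target_file` and the kwarg `another_fn` are hypothetical names introduced for this example.

```python
from pathlib import Path

# Hypothetical callable kwargs of a preprocess_fn partial (names are illustrative)
callable_kwargs = {
    'preprocess_batch_fn': lambda x: x,
    'another_fn': lambda x: x,
}
local_path = Path('preprocess_fn')


def target_file(path: Path) -> str:
    # Assumption: the serialiser appends '.dill' to whatever path it receives
    return path.with_suffix('.dill').as_posix()


# Before the fix: every callable kwarg maps to the same file
old_targets = {k: target_file(local_path) for k in callable_kwargs}
# After the fix: each kwarg gets its own file under preprocess_fn/
new_targets = {k: target_file(local_path.joinpath(k)) for k in callable_kwargs}

print(old_targets)
print(new_targets)
```

With a single callable kwarg the old naming round-trips fine, which is why the collision only shows up once a second callable kwarg is added.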
@@ -302,7 +302,7 @@ def _save_preprocess_config(preprocess_fn: Callable,
     if 'preprocess_drift' in func:
         preprocess_cfg.update(kwargs)
     else:
-        kwargs.update({'kwargs': kwargs})
+        preprocess_cfg.update({'kwargs': kwargs})
This is bug fix 2. Previously the `kwargs` dict was updated, but this had no effect on the saved config. It is the returned `preprocess_cfg` dict that should have been updated instead.
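The difference can be reproduced in isolation with a small sketch. `build_cfg_buggy` and `build_cfg_fixed` are hypothetical reductions of the `else` branch above, and the `'src'` value is a placeholder rather than the real serialised path.

```python
def build_cfg_buggy(kwargs: dict) -> dict:
    # Placeholder config; the real one is built earlier in _save_preprocess_config
    preprocess_cfg = {'src': 'placeholder'}
    # Bug: this mutates the local kwargs dict, not the config that is returned
    kwargs.update({'kwargs': kwargs})
    return preprocess_cfg


def build_cfg_fixed(kwargs: dict) -> dict:
    preprocess_cfg = {'src': 'placeholder'}
    # Fix: store the kwargs on the returned config
    preprocess_cfg.update({'kwargs': kwargs})
    return preprocess_cfg


buggy = build_cfg_buggy({'kwarg1': 42})
fixed = build_cfg_fixed({'kwarg1': 42})

print('kwargs' in buggy)
print(fixed['kwargs'])
```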
@@ -89,7 +89,7 @@ def encoder_dropout_model(backend, current_cases):


 @fixture
-def preprocess_custom(encoder_model):
+def preprocess_uae(encoder_model):
More descriptive naming. `preprocess_custom` actually contains a UAE model, so it is renamed to `preprocess_uae`. Otherwise this would cause confusion in `test_saving.py`, where we are now testing "custom" preprocessing functions.
@@ -1132,14 +1130,40 @@ def test_save_preprocess(data, preprocess_fn, tmp_path, backend):
         assert isinstance(preprocess_fn_load.keywords['model'], nn.Module)


+@parametrize('preprocess_fn', [preprocess_simple, preprocess_simple_with_kwargs])
+def test_save_preprocess_custom(preprocess_fn, tmp_path):
New test function to test save/load of:
- `preprocess_fn` where it is a Python func with no kwargs.
- `preprocess_fn` where it is a `partial` with kwargs.
@@ -1152,6 +1176,8 @@ def test_save_preprocess_nlp(data, preprocess_fn, tmp_path, backend):
     assert cfg_preprocess['src'] == '@cd.' + registry_str + '.preprocess.preprocess_drift'
     assert cfg_preprocess['embedding']['src'] == 'preprocess_fn/embedding'
     assert cfg_preprocess['tokenizer']['src'] == 'preprocess_fn/tokenizer'
+    assert tmp_path.joinpath(cfg_preprocess['preprocess_batch_fn']).is_file()
+    assert cfg_preprocess['preprocess_batch_fn'] == 'preprocess_fn/preprocess_batch_fn.dill'
This tests that the `preprocess_batch_fn` is serialised to disk and has the correct name. It implicitly tests bug fix 1.
@@ -553,7 +554,7 @@ def test_save_contextmmddrift(data, kernel, backend, tmp_path, seed):  # noqa: F
     assert cd_load._detector.n_permutations == N_PERMUTATIONS
     assert cd_load._detector.p_val == P_VAL
     assert isinstance(cd_load._detector.preprocess_fn, Callable)
-    assert cd_load._detector.preprocess_fn.func.__name__ == 'preprocess_simple'
+    assert cd_load._detector.preprocess_fn.__name__ == 'preprocess_simple'
Change required since we no longer wrap the function in a `partial` when there are no kwargs.
Codecov Report
@@            Coverage Diff             @@
##           master     #752      +/-   ##
==========================================
- Coverage   80.48%   80.32%   -0.17%
==========================================
  Files         137      137
  Lines        9302     9304       +2
==========================================
- Hits         7487     7473      -14
- Misses       1815     1831      +16
LGTM! Agree with doing the patch release from `master` instead of a patch branch.
LGTM!
This PR fixes two bugs in serialisation/deserialisation of the drift detector `preprocess_fn`:

1. Callable kwargs of a `partial` that are serialised to `.dill` are now named according to their kwarg name, i.e. `preprocess_batch_fn` is serialised to `preprocess_fn/preprocess_batch_fn.dill`. Previously, this would have been saved to `preprocess_fn.dill`. This loaded fine, so it was not spotted in tests, but there would have been a problem if multiple kwargs in a `partial` were to be saved, since they'd all have the same name.
2. When `preprocess_fn` was a custom Python function with kwargs (i.e. NOT `preprocess_drift`), the kwargs were not serialised. This has now been fixed, and a test has been added to check for such a bug in the future.

@jklaise @mauicv assuming we are ok to include the `mypy`, `flake8` and docs being built w/ 3.10 changes in the patch, we could merge this directly into `master` and release off that, rather than `patch/v0.11.1`?

TODO: