Serialisation of (online) state for online detectors #604

Merged
merged 48 commits into from
Jan 24, 2023
Changes from 37 commits
2171395
Initial implementation of save/load state for LSDDDriftOnlineTorch
Sep 2, 2022
68883d4
Add tests for saving state
Sep 7, 2022
944a65d
WIP: LSDDDriftOnlineTF implementation
Nov 3, 2022
672fbcc
State save/load/reset for online LSDD (tf) and MMD (pt and tf)
Nov 7, 2022
9411d54
Online FET and CVM implementation
Nov 7, 2022
d3c397a
Update LSDD and MMD to reset via reset() rather than reset_state()
Nov 7, 2022
06a032b
use .get for state_dict in mmd/lsdd for consistency
Nov 7, 2022
9ea9339
WIP: saving state via save_detector. Need to add tests
Nov 7, 2022
faa483b
Fixes, and tests for save_detector save_state kwarg
Nov 7, 2022
ec48bb6
Merge branch 'master' into feature/save_state
Nov 23, 2022
6deb17f
Fix tests
Nov 23, 2022
272ef3a
Update changelog
Nov 24, 2022
e9c3685
Update docs
Nov 24, 2022
f73fbbf
POC of refactoring. LSDDDriftOnlineTorch only
Nov 30, 2022
09c4305
Refactor initialise methods (and tests)
Dec 14, 2022
1774ee9
Move save/load_state to base classes
Dec 14, 2022
d9d2e25
Manage online test seeds with fixed_seed
Dec 15, 2022
975094e
Add missing _state.py files
Dec 15, 2022
363b435
Fix ops deps tests
Dec 15, 2022
cfc8918
Remove misplaced offline_state code
Dec 15, 2022
30bc57a
Fix test_saving
Dec 15, 2022
1d81d35
Merge branch 'master' into feature/save_state
Jan 4, 2023
2d9c5cf
Remove uneccesary Framework import
Jan 4, 2023
bb53453
Simplify online state tests
Jan 4, 2023
bb4035a
Incorperate some updates from feature/save_offline_state
Jan 5, 2023
430ebd5
Revert 21519e4, but keep changes to tests and additions to online_sta…
Jan 10, 2023
91e1b61
Test saving of state for all online detectors in test_saving.py
Jan 10, 2023
9e84410
Added logging messages
Jan 10, 2023
5dea156
Remove old _set_state_path method
Jan 10, 2023
785365d
Make StatefulDetector inherit from ConfigurableDetector
Jan 10, 2023
7ce6476
Replace indexing with tf.gather
Jan 11, 2023
a2bce2b
Merge branch 'fix/tf_indexing_bug' into feature/save_state
Jan 11, 2023
b2806cc
Add filepath check back to _load_detector_config
Jan 12, 2023
2d4b6a3
Rename StatefulDetector to StatefulDetectorOnline
Jan 12, 2023
67915b8
Add online_state_keys attr to StatefulDetectorOnline
Jan 12, 2023
feb9fda
Revert "Add online_state_keys attr to StatefulDetectorOnline"
Jan 12, 2023
aa6b0d3
Rename rest to reset_state
Jan 12, 2023
15d1dfa
Update typings
Jan 13, 2023
da56714
Remove extra parentheses
Jan 16, 2023
6240fdf
Remove unused seed's in tests
Jan 16, 2023
5daf1b1
Add BaseDriftOnline class
Jan 16, 2023
c7d72c7
Add description for BaseUniDriftOnline
Jan 16, 2023
d190589
Remove save_state kwarg from save_detector
Jan 16, 2023
4de6fad
Move state methods to StateMixin
Jan 17, 2023
7b824d8
Update docs
Jan 18, 2023
d888276
Only mention .predict() in saving page
Jan 18, 2023
1026713
Move conftest.py's containing seed fixture to local test dirs
Jan 18, 2023
828e486
Add missing conftest
Jan 19, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@
See the [documentation](https://docs.seldon.io/projects/alibi-detect/en/latest/cd/methods/mmddrift.html) and [example notebook](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_mmd_keops.html) for more info ([#548](https://github.com/SeldonIO/alibi-detect/pull/548)).
- **New feature** Added support for serializing detectors with PyTorch backends, and detectors containing PyTorch models in their preprocessing functions ([#656](https://github.com/SeldonIO/alibi-detect/pull/656)).
- **New feature** Added support for serializing detectors with KeOps backends ([#681](https://github.com/SeldonIO/alibi-detect/pull/681)).
- **New feature** Added support for saving and loading online detectors' state. This allows a detector to be restarted from previously generated checkpoints ([#604](https://github.com/SeldonIO/alibi-detect/pull/604)).
- **New feature** Added a PyTorch version of the `UAE` preprocessing utility function ([#656](https://github.com/SeldonIO/alibi-detect/pull/656)).
- If a `categories_per_feature` dictionary is not passed to `TabularDrift`, a warning is now raised to inform the user that all features are assumed to be numerical ([#606](https://github.com/SeldonIO/alibi-detect/pull/606)).
- For the `ClassifierDrift` and `SpotTheDiffDrift` detectors, we can also return the out-of-fold instances of the reference and test sets. When using `train_size` for training the detector, this allows to associate the returned prediction probabilities with the correct instances.
17 changes: 12 additions & 5 deletions alibi_detect/base.py
@@ -216,11 +216,7 @@ def predict(self) -> Any: ...
class ConfigurableDetector(Detector, Protocol):
"""Type Protocol for detectors that have support for saving via config.

Used for typing save and load functionality in `alibi_detect.saving.saving.py`.

Note:
This exists to distinguish between detectors with and without support for config saving and loading. Once all
detector support this then this protocol will be removed.
Used for typing save and load functionality in `alibi_detect.saving.saving`.
"""
def get_config(self): ...

@@ -229,6 +225,17 @@ def from_config(self): ...
def _set_config(self): ...


@runtime_checkable
class StatefulDetectorOnline(ConfigurableDetector, Protocol):
"""Type Protocol for detectors that have support for save/loading of online state.

Used for typing save and load functionality in `alibi_detect.saving.saving`.
"""
def save_state(self, filepath): ...

def load_state(self, filepath): ...
Contributor

Should parameters have type hints?

Contributor Author

Added in 15d1dfa. Note: I also updated the pre-existing get_config and set_config methods here.
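
For context on the type-hint question, here is a minimal sketch of a runtime-checkable protocol with hinted state methods. The exact hints added in 15d1dfa are not shown in this thread, so the `Union[str, os.PathLike]` annotations below are an assumption borrowed from `base_online.py`, and `ToyDetector` and `StatelessDetector` are hypothetical classes used only for illustration.

```python
import os
from typing import Protocol, Union, runtime_checkable


@runtime_checkable
class StatefulDetectorOnline(Protocol):
    """Structural type for detectors that can save and load online state."""

    def save_state(self, filepath: Union[str, os.PathLike]) -> None: ...

    def load_state(self, filepath: Union[str, os.PathLike]) -> None: ...


class ToyDetector:
    """Hypothetical detector; note that it does not inherit from the protocol."""

    def save_state(self, filepath: Union[str, os.PathLike]) -> None:
        pass

    def load_state(self, filepath: Union[str, os.PathLike]) -> None:
        pass


class StatelessDetector:
    """Hypothetical detector without any state methods."""


# isinstance() against a runtime-checkable protocol only checks that the
# methods exist; it does not validate their signatures or type hints.
print(isinstance(ToyDetector(), StatefulDetectorOnline))        # True
print(isinstance(StatelessDetector(), StatefulDetectorOnline))  # False
```

The hints therefore help static checkers such as mypy, while the `isinstance` check used at save/load time stays purely structural.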



class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(
219 changes: 209 additions & 10 deletions alibi_detect/cd/base_online.py
@@ -1,11 +1,16 @@
import os
from pathlib import Path
import logging
import warnings
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

import numpy as np
from alibi_detect.base import BaseDetector, concept_drift_dict
from alibi_detect.cd.utils import get_input_shape
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
from alibi_detect.utils.frameworks import Framework, has_pytorch, has_tensorflow
from alibi_detect.utils._state import save_state_dict, load_state_dict
from alibi_detect.utils._types import Literal

if has_pytorch:
import torch
@@ -18,6 +23,8 @@

class BaseMultiDriftOnline(BaseDetector):
thresholds: np.ndarray
backend: Literal['pytorch', 'tensorflow']
online_state_keys: Tuple[str, ...]

def __init__(
self,
@@ -106,6 +113,66 @@ def _configure_ref_subset(self):
def _update_state(self, x_t: Union[np.ndarray, 'tf.Tensor', 'torch.Tensor']):
pass

def _set_state_dir(self, dirpath: Union[str, os.PathLike]):
"""
Set the directory path to store state in, and create an empty directory if it doesn't already exist.

Parameters
----------
dirpath
The directory to save state file inside.
"""
self.state_dir = Path(dirpath)
Contributor

Should this be a private attribute?

self.state_dir.mkdir(parents=True, exist_ok=True)

def save_state(self, filepath: Union[str, os.PathLike]):
"""
Save a detector's state to disk in order to generate a checkpoint.

Parameters
----------
filepath
The directory to save state to.
"""
self._set_state_dir(filepath)
self._save_state()

def load_state(self, filepath: Union[str, os.PathLike]):
"""
Load the detector's state from disk, in order to restart from a checkpoint previously generated with
`save_state`.

Parameters
----------
filepath
The directory to load state from.
"""
self._set_state_dir(filepath)
Contributor

Why is this necessary when loading?

Contributor Author

Just so that `self.state_dir` is set (and converted from `str` to `pathlib.Path`) when `load_state` is called as well as when `save_state` is called.

I thought it might be helpful to have `state_dir` as a public attribute so that a user could interrogate the detector to see where state was loaded from. Although, thinking about it more, for the backend detectors one would have to do `detector._detector.state_dir` (access a private attribute) anyway. I guess we'd probably want to define a `@property` if we actually want to support this functionality properly...

Happy to just make it private if you think it's better though...
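
A minimal sketch of the `@property` option raised in this comment, assuming the path is kept in a private `_state_dir` attribute. `ToyDetector` is a hypothetical stand-in, not a class from this PR, and the PR itself does not necessarily adopt this design.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional, Union


class ToyDetector:
    """Hypothetical detector used only to illustrate the property pattern."""

    def __init__(self) -> None:
        self._state_dir: Optional[Path] = None  # private storage

    @property
    def state_dir(self) -> Optional[Path]:
        """Read-only view of the directory last used to save or load state."""
        return self._state_dir

    def _set_state_dir(self, dirpath: Union[str, os.PathLike]) -> None:
        # Coerce to Path and create the directory if it does not exist yet.
        self._state_dir = Path(dirpath)
        self._state_dir.mkdir(parents=True, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    detector = ToyDetector()
    detector._set_state_dir(os.path.join(tmp, 'checkpoint'))
    assert isinstance(detector.state_dir, Path)
    assert detector.state_dir.is_dir()
    try:
        detector.state_dir = Path(tmp)  # no setter is defined
    except AttributeError:
        print('state_dir is read-only')
```

A wrapper detector could expose the same information by forwarding a property of its own to the backend detector, so that users never need to reach into `_detector`.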

self._load_state()
logger.info('State loaded for t={} from {}'.format(self.t, self.state_dir))

def _save_state(self):
"""
Private method to save a detector's state to disk.

TODO - Method slightly verbose as designed to facilitate saving of "offline" state in follow-up PR.
"""
suffix = '.pt' if self.backend == Framework.PYTORCH else '.npz'
filename = 'state'
keys = self.online_state_keys
save_state_dict(self, keys, self.state_dir.joinpath(filename + suffix))
logger.info('Saved state for t={} to {}'.format(self.t, self.state_dir))

def _load_state(self, offline: bool = False):
"""
Private method to load a detector's state from disk.

TODO - Method slightly verbose as designed to facilitate loading of "offline" state in follow-up PR.
"""
suffix = '.pt' if self.backend == Framework.PYTORCH else '.npz'
filename = 'state'
load_state_dict(self, self.state_dir.joinpath(filename + suffix), raise_error=True)
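
The helpers `save_state_dict` and `load_state_dict` live in `alibi_detect/utils/_state.py`, which is not shown in this part of the diff. As a rough illustration of the pattern the calls above imply (collect the attributes named in `online_state_keys`, write them to one file, and set them back on load), here is a sketch using `pickle`. The `toy_*` functions and `ToyDetector` are hypothetical stand-ins and should not be read as the library's implementation, which writes `.npz` or `.pt` files.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Tuple


def toy_save_state_dict(detector: Any, keys: Tuple[str, ...], filepath: Path) -> None:
    """Collect the attributes named in `keys` and write them to `filepath`."""
    state = {key: getattr(detector, key) for key in keys}
    with open(filepath, 'wb') as f:
        pickle.dump(state, f)


def toy_load_state_dict(detector: Any, filepath: Path, raise_error: bool = True) -> None:
    """Read a state dict from `filepath` and set each entry on `detector`."""
    if not filepath.is_file():
        if raise_error:
            raise FileNotFoundError(f'State file not found at {filepath}.')
        return
    with open(filepath, 'rb') as f:
        state = pickle.load(f)
    for key, value in state.items():
        setattr(detector, key, value)


class ToyDetector:
    online_state_keys = ('t', 'test_stats')

    def __init__(self) -> None:
        self.t = 0
        self.test_stats = []


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'state.pkl'
    saved = ToyDetector()
    saved.t, saved.test_stats = 3, [0.1, 0.4, 0.2]  # made-up values
    toy_save_state_dict(saved, saved.online_state_keys, path)

    restored = ToyDetector()
    toy_load_state_dict(restored, path)
    assert (restored.t, restored.test_stats) == (3, [0.1, 0.4, 0.2])
```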

def _preprocess_xt(self, x_t: Union[np.ndarray, Any]) -> np.ndarray:
"""
Private method to preprocess a single test instance ready for _update_state.
@@ -126,17 +193,46 @@ def _preprocess_xt(self, x_t: Union[np.ndarray, Any]) -> np.ndarray:
return x_t[None, :]

def get_threshold(self, t: int) -> float:
"""
Return the threshold for timestep `t`.

Parameters
----------
t
The timestep to return a threshold for.

Returns
-------
The threshold at timestep `t`.
"""
return self.thresholds[t] if t < self.window_size else self.thresholds[-1]
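
The lookup above returns a per-timestep threshold during the initial window and the final threshold afterwards. A standalone sketch with a plain list makes the behaviour concrete; the threshold values and `window_size` are made up, and the free function is an illustrative rewrite rather than the method itself.

```python
from typing import Sequence


def get_threshold(thresholds: Sequence[float], t: int, window_size: int) -> float:
    """Return thresholds[t] for t < window_size, otherwise the final threshold."""
    return thresholds[t] if t < window_size else thresholds[-1]


thresholds = [0.9, 0.7, 0.5, 0.4]  # made-up values
print(get_threshold(thresholds, t=1, window_size=3))    # 0.7
print(get_threshold(thresholds, t=250, window_size=3))  # 0.4
```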

def _initialise(self) -> None:
def _initialise_state(self) -> None:
"""
Initialise online state (the stateful attributes updated by `score` and `predict`).

If a subclassed detector has additional online state, an additional `_initialise_state` should be defined,
with a call to `super()._initialise_state()` included (see `LSDDDriftOnlineTorch._initialise_state()` for
an example).
"""
self.t = 0 # corresponds to a test set of ref data
self.test_stats = np.array([]) # type: ignore[var-annotated]
self.drift_preds = np.array([]) # type: ignore[var-annotated]
self._configure_ref_subset()

def reset(self) -> None:
"Resets the detector but does not reconfigure thresholds."
self._initialise()
"""
Deprecated reset method. This method will be repurposed or removed in the future. To reset the detector to
its initial state (`t=0`) use :meth:`reset_state`.
"""
self.reset_state()
warnings.warn('This method is deprecated and will be removed/repurposed in the future. To reset the detector '
'to its initial state use `reset_state`.', DeprecationWarning)

def reset_state(self) -> None:
"""
Resets the detector to its initial state (`t=0`). This does not include reconfiguring thresholds.
"""
self._initialise_state()
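
A small sketch of the deprecation pattern used by `reset` and `reset_state`, together with one way to check it in a test. `ToyDetector` is hypothetical, and `stacklevel=2` is an addition not present in the diff; it makes the warning point at the caller rather than at the `reset` body.

```python
import warnings


class ToyDetector:
    """Hypothetical detector with a single piece of online state."""

    def __init__(self) -> None:
        self._initialise_state()

    def _initialise_state(self) -> None:
        self.t = 0

    def reset_state(self) -> None:
        """Reset the detector to its initial state (t=0)."""
        self._initialise_state()

    def reset(self) -> None:
        """Deprecated alias for reset_state."""
        self.reset_state()
        warnings.warn('`reset` is deprecated; use `reset_state`.',
                      DeprecationWarning, stacklevel=2)


detector = ToyDetector()
detector.t = 5
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')  # make sure the warning is not filtered out
    detector.reset()

assert detector.t == 0
assert caught[0].category is DeprecationWarning
```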

def predict(self, x_t: Union[np.ndarray, Any], return_test_stat: bool = True,
) -> Dict[Dict[str, str], Dict[str, Union[int, float]]]:
@@ -179,6 +275,7 @@ def predict(self, x_t: Union[np.ndarray, Any], return_test_stat: bool = True,

class BaseUniDriftOnline(BaseDetector):
thresholds: np.ndarray
online_state_keys: Tuple[str, ...]

def __init__(
self,
@@ -290,7 +387,79 @@ def _configure_ref(self):
def _update_state(self, x_t: np.ndarray):
pass

def _set_state_dir(self, dirpath: Union[str, os.PathLike]):
"""
Set the directory path to store state in, and create an empty directory if it doesn't already exist.

Parameters
----------
dirpath
The directory to save state file inside.
"""
self.state_dir = Path(dirpath)
self.state_dir.mkdir(parents=True, exist_ok=True)

def save_state(self, filepath: Union[str, os.PathLike]):
"""
Save a detector's state to disk in order to generate a checkpoint.

Parameters
----------
filepath
The directory to save state to.
"""
self._set_state_dir(filepath)
self._save_state()
logger.info('Saved state for t={} to {}'.format(self.t, self.state_dir))

def load_state(self, filepath: Union[str, os.PathLike]):
"""
Load the detector's state from disk, in order to restart from a checkpoint previously generated with
`save_state`.

Parameters
----------
filepath
The directory to load state from.
"""
self._set_state_dir(filepath)
self._load_state()
logger.info('State loaded for t={} from {}'.format(self.t, self.state_dir))

def _save_state(self):
"""
Private method to save a detector's state to disk.

TODO - Method slightly verbose as designed to facilitate saving of "offline" state in follow-up PR.
"""
filename = 'state'
keys = self.online_state_keys
save_state_dict(self, keys, self.state_dir.joinpath(filename + '.npz'))

def _load_state(self, offline: bool = False):
"""
Private method to load a detector's state from disk.

TODO - Method slightly verbose as designed to facilitate loading of "offline" state in follow-up PR.
"""
filename = 'state'
load_state_dict(self, self.state_dir.joinpath(filename + '.npz'), raise_error=True)

Contributor

Seems like a lot of duplicated code that's exactly the same as for BaseMultiDriftOnline which suggests we may want to refactor using functions instead of methods or a mixin class? Or perhaps the class hierarchy needs to be updated.

Contributor

Note that this seems to apply to other methods too, so perhaps is a more widespread problem requiring a refactoring later...

Contributor Author

@ascillitoe, Jan 13, 2023

Fair point. Having a BaseDriftOnline class for generic methods, or a mix-in both seem much nicer than this current pattern. I'll have a rethink 👍🏻

Contributor Author

To reduce duplication, 5daf1b1 adds a BaseDriftOnline class. @jklaise @mauicv could I get your thoughts on the design of BaseDriftOnline please? I've gone with a parent class rather than mix-in since it seems strange to define a mix-in in alibi_detect/base.py when it is only to be used in two classes (BaseMultiDriftOnline and BaseUniDriftOnline). I also decided to put it in alibi_detect/cd/base_online.py rather than alibi_detect/base.py since at the moment the concept of "online" detectors is specific to drift (this may change if we decide stateful outlier detectors are in fact "online").

Contributor

@jklaise, Jan 16, 2023

LGTM however noting that there's quite a few abstract methods, some of which (not all?) are implemented in the Multi/Uni abstract child classes, which come with their own set of abstract methods... Worried that this may become a bit tricky to keep track of. As a minimum, would group all abstract methods to come after each other and add docstrings on expected implementation and also, where valid, which of the Multi/Uni classes implement these methods (+ type hints as always).

Contributor Author

See #604 (comment) wrt type hints, not sure on best approach here.

Wrt the abstract methods, if they are missing from the Multi/Uni child classes that will be because they are instead defined in the next subclass down i.e. `LSDDDriftOnlineTorch._initialise_state` or `CVMDriftOnline._configure_thresholds`...

We could move the abstract methods such as _configure_thresholds back to their respective Multi/Uni abstract class, at the cost of more duplication (but maybe less complexity?)

Contributor Author

@ascillitoe, Jan 17, 2023

Removed the new base class, and moved state methods to StateMixin. See #604 (comment).
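
To illustrate the mixin approach this thread settles on, here is a minimal sketch in which two otherwise unrelated base classes share one save/load implementation driven by their own `online_state_keys`. `ToyStateMixin` and the two toy detectors are hypothetical; the actual `StateMixin` is not shown in this part of the diff, and JSON is used here only to keep the example dependency-free.

```python
import json
import tempfile
from pathlib import Path
from typing import Tuple


class ToyStateMixin:
    """Hypothetical mixin: shared save/load logic driven by `online_state_keys`."""
    online_state_keys: Tuple[str, ...]

    def save_state(self, filepath: str) -> None:
        state_dir = Path(filepath)
        state_dir.mkdir(parents=True, exist_ok=True)
        state = {key: getattr(self, key) for key in self.online_state_keys}
        (state_dir / 'state.json').write_text(json.dumps(state))

    def load_state(self, filepath: str) -> None:
        state = json.loads((Path(filepath) / 'state.json').read_text())
        for key, value in state.items():
            setattr(self, key, value)


class ToyMultiDetector(ToyStateMixin):
    online_state_keys = ('t', 'test_stats')

    def __init__(self) -> None:
        self.t, self.test_stats = 0, []


class ToyUniDetector(ToyStateMixin):
    online_state_keys = ('t', 'xs')

    def __init__(self) -> None:
        self.t, self.xs = 0, []


with tempfile.TemporaryDirectory() as tmp:
    uni = ToyUniDetector()
    uni.t, uni.xs = 2, [0.3, 0.8]  # made-up values
    uni.save_state(tmp)

    fresh = ToyUniDetector()
    fresh.load_state(tmp)
    assert (fresh.t, fresh.xs) == (2, [0.3, 0.8])
```

Each concrete class only declares which attributes count as state, so the duplicated `save_state`/`load_state` bodies flagged in the review collapse into one place.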

def _check_x(self, x: Any, x_ref: bool = False) -> np.ndarray:
"""
Check the type and shape of the data `x`, and coerces it to the correct shape if possible.

Parameters
----------
x
The data to be checked.
x_ref
Whether `x` is a batch of reference data instances (if `True`), or a single test data instance (if `False`).

Returns
-------
The checked data, coerced to be a np.ndarray of the correct shape.
"""
# Check the type of x
if isinstance(x, np.ndarray):
pass
@@ -333,21 +502,51 @@ def _preprocess_xt(self, x_t: Union[np.ndarray, Any]) -> np.ndarray:
return x_t

def get_threshold(self, t: int) -> np.ndarray:
"""
Return the threshold for timestep `t`.

Parameters
----------
t
The timestep to return a threshold for.

Returns
-------
The threshold at timestep `t`.
"""
return self.thresholds[t] if t < len(self.thresholds) else self.thresholds[-1]

def _initialise(self) -> None:
def _initialise_state(self) -> None:
"""
Initialise online state (the stateful attributes updated by `score` and `predict`).

If a subclassed detector has additional online state, an additional `_initialise_state` should be defined,
with a call to `super()._initialise_state()` included (see `CVMDriftOnline._initialise_state()` for
an example).
"""
self.t = 0
self.xs = np.array([])
self.test_stats = np.empty([0, len(self.window_sizes), self.n_features])
self.drift_preds = np.array([]) # type: ignore[var-annotated]
self._configure_ref()

@abstractmethod
def _check_drift(self, test_stats: np.ndarray, thresholds: np.ndarray) -> int:
pass

def reset(self) -> None:
"Resets the detector but does not reconfigure thresholds."
self._initialise()
"""
Deprecated reset method. This method will be repurposed or removed in the future. To reset the detector to
its initial state (`t=0`) use :meth:`reset_state`.
"""
self.reset_state()
warnings.warn('This method is deprecated and will be removed/repurposed in the future. To reset the detector '
'to its initial state use `reset_state`.', DeprecationWarning)

def reset_state(self) -> None:
"""
Resets the detector to its initial state (`t=0`). This does not include reconfiguring thresholds.
"""
self._initialise_state()

def predict(self, x_t: Union[np.ndarray, Any], return_test_stat: bool = True,
) -> Dict[Dict[str, str], Dict[str, Union[int, float]]]:
25 changes: 24 additions & 1 deletion alibi_detect/cd/cvm_online.py
@@ -9,6 +9,8 @@


class CVMDriftOnline(BaseUniDriftOnline, DriftConfigMixin):
online_state_keys = ('t', 'test_stats', 'drift_preds', 'xs', 'ids_ref_wins', 'ids_wins_ref', 'ids_wins_wins')

def __init__(
self,
x_ref: Union[np.ndarray, list],
@@ -92,10 +94,14 @@ def __init__(
self.batch_size = n_bootstraps if batch_size is None else batch_size

# Configure thresholds and initialise detector
self._initialise()
self._initialise_state()
self._configure_thresholds()
self._configure_ref()

def _configure_ref(self) -> None:
"""
Configure the reference data.
"""
ids_ref_ref = self.x_ref[None, :, :] >= self.x_ref[:, None, :]
self.ref_cdf_ref = np.sum(ids_ref_ref, axis=0) / self.n

@@ -162,6 +168,14 @@ def _simulate_streams(self, t_max: int) -> np.ndarray:
return stats

def _update_state(self, x_t: np.ndarray):
"""
Update online state based on the provided test instance.

Parameters
----------
x_t
The test instance.
"""
self.t += 1
if self.t == 1:
# Initialise stream
@@ -186,6 +200,15 @@ def _update_state(self, x_t: np.ndarray):
[self.ids_wins_wins, (x_t <= self.xs[-self.max_ws:, :])[None, :, :]], 0
)

def _initialise_state(self) -> None:
"""
Initialise online state (the stateful attributes updated by `score` and `predict`).
"""
super()._initialise_state()
self.ids_ref_wins = np.array([])
self.ids_wins_ref = np.array([])
self.ids_wins_wins = np.array([])
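
The override above follows the pattern described in the base-class docstring: reset the shared state via `super()`, then reset the detector-specific state. A dependency-free sketch with hypothetical toy classes, using lists in place of NumPy arrays:

```python
class ToyBaseOnline:
    """Hypothetical base class holding the state shared by every detector."""

    def __init__(self) -> None:
        self._initialise_state()

    def _initialise_state(self) -> None:
        self.t = 0
        self.drift_preds = []

    def reset_state(self) -> None:
        self._initialise_state()


class ToyCVMOnline(ToyBaseOnline):
    def _initialise_state(self) -> None:
        super()._initialise_state()  # reset the shared state first
        self.ids_ref_wins = []       # then the detector-specific state


detector = ToyCVMOnline()
detector.t, detector.ids_ref_wins = 7, [1, 2]
detector.reset_state()
assert (detector.t, detector.ids_ref_wins) == (0, [])
```

Because `reset_state` only calls `_initialise_state`, a subclass that forgets the `super()` call would silently leave `t` and `drift_preds` untouched on reset.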

def score(self, x_t: Union[np.ndarray, Any]) -> np.ndarray:
"""
Compute the test-statistic (CVM) between the reference window(s) and test window.