Serialisation of (online) state for online detectors #604

Merged
merged 48 commits into from
Jan 24, 2023
Changes from 45 commits
Commits
2171395
Initial implementation of save/load state for LSDDDriftOnlineTorch
Sep 2, 2022
68883d4
Add tests for saving state
Sep 7, 2022
944a65d
WIP: LSDDDriftOnlineTF implementation
Nov 3, 2022
672fbcc
State save/load/reset for online LSDD (tf) and MMD (pt and tf)
Nov 7, 2022
9411d54
Online FET and CVM implementation
Nov 7, 2022
d3c397a
Update LSDD and MMD to reset via reset() rather than reset_state()
Nov 7, 2022
06a032b
use .get for state_dict in mmd/lsdd for consistency
Nov 7, 2022
9ea9339
WIP: saving state via save_detector. Need to add tests
Nov 7, 2022
faa483b
Fixes, and tests for save_detector save_state kwarg
Nov 7, 2022
ec48bb6
Merge branch 'master' into feature/save_state
Nov 23, 2022
6deb17f
Fix tests
Nov 23, 2022
272ef3a
Update changelog
Nov 24, 2022
e9c3685
Update docs
Nov 24, 2022
f73fbbf
POC of refactoring. LSDDDriftOnlineTorch only
Nov 30, 2022
09c4305
Refactor initialise methods (and tests)
Dec 14, 2022
1774ee9
Move save/load_state to base classes
Dec 14, 2022
d9d2e25
Manage online test seeds with fixed_seed
Dec 15, 2022
975094e
Add missing _state.py files
Dec 15, 2022
363b435
Fix ops deps tests
Dec 15, 2022
cfc8918
Remove misplaced offline_state code
Dec 15, 2022
30bc57a
Fix test_saving
Dec 15, 2022
1d81d35
Merge branch 'master' into feature/save_state
Jan 4, 2023
2d9c5cf
Remove uneccesary Framework import
Jan 4, 2023
bb53453
Simplify online state tests
Jan 4, 2023
bb4035a
Incorperate some updates from feature/save_offline_state
Jan 5, 2023
430ebd5
Revert 21519e4, but keep changes to tests and additions to online_sta…
Jan 10, 2023
91e1b61
Test saving of state for all online detectors in test_saving.py
Jan 10, 2023
9e84410
Added logging messages
Jan 10, 2023
5dea156
Remove old _set_state_path method
Jan 10, 2023
785365d
Make StatefulDetector inherit from ConfigurableDetector
Jan 10, 2023
7ce6476
Replace indexing with tf.gather
Jan 11, 2023
a2bce2b
Merge branch 'fix/tf_indexing_bug' into feature/save_state
Jan 11, 2023
b2806cc
Add filepath check back to _load_detector_config
Jan 12, 2023
2d4b6a3
Rename StatefulDetector to StatefulDetectorOnline
Jan 12, 2023
67915b8
Add online_state_keys attr to StatefulDetectorOnline
Jan 12, 2023
feb9fda
Revert "Add online_state_keys attr to StatefulDetectorOnline"
Jan 12, 2023
aa6b0d3
Rename rest to reset_state
Jan 12, 2023
15d1dfa
Update typings
Jan 13, 2023
da56714
Remove extra parentheses
Jan 16, 2023
6240fdf
Remove unused seed's in tests
Jan 16, 2023
5daf1b1
Add BaseDriftOnline class
Jan 16, 2023
c7d72c7
Add description for BaseUniDriftOnline
Jan 16, 2023
d190589
Remove save_state kwarg from save_detector
Jan 16, 2023
4de6fad
Move state methods to StateMixin
Jan 17, 2023
7b824d8
Update docs
Jan 18, 2023
d888276
Only mention .predict() in saving page
Jan 18, 2023
1026713
Move conftest.py's containing seed fixture to local test dirs
Jan 18, 2023
828e486
Add missing conftest
Jan 19, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@
See the [documentation](https://docs.seldon.io/projects/alibi-detect/en/latest/cd/methods/mmddrift.html) and [example notebook](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_mmd_keops.html) for more info ([#548](https://github.com/SeldonIO/alibi-detect/pull/548)).
- **New feature** Added support for serializing detectors with PyTorch backends, and detectors containing PyTorch models in their preprocessing functions ([#656](https://github.com/SeldonIO/alibi-detect/pull/656)).
- **New feature** Added support for serializing detectors with KeOps backends ([#681](https://github.com/SeldonIO/alibi-detect/pull/681)).
- **New feature** Added support for saving and loading online detectors' state. This allows a detector to be restarted from previously generated checkpoints ([#604](https://github.com/SeldonIO/alibi-detect/pull/604)).
- **New feature** Added a PyTorch version of the `UAE` preprocessing utility function ([#656](https://github.com/SeldonIO/alibi-detect/pull/656)).
- If a `categories_per_feature` dictionary is not passed to `TabularDrift`, a warning is now raised to inform the user that all features are assumed to be numerical ([#606](https://github.com/SeldonIO/alibi-detect/pull/606)).
- For the `ClassifierDrift` and `SpotTheDiffDrift` detectors, we can also return the out-of-fold instances of the reference and test sets. When using `train_size` for training the detector, this allows to associate the returned prediction probabilities with the correct instances.
29 changes: 20 additions & 9 deletions alibi_detect/base.py
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
import os
import copy
import json
import numpy as np
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, Union
from typing_extensions import Protocol, runtime_checkable
from alibi_detect.version import __version__

@@ -153,7 +154,7 @@ def from_config(cls, config: dict):
detector.config['meta']['version_warning'] = version_warning
return detector

def _set_config(self, inputs): # TODO - move to BaseDetector once config save/load implemented for non-drift
def _set_config(self, inputs: dict): # TODO - move to BaseDetector once config save/load implemented for non-drift
"""
Set a detector's `config` attribute upon detector instantiation.

@@ -216,17 +217,27 @@ def predict(self) -> Any: ...
class ConfigurableDetector(Detector, Protocol):
"""Type Protocol for detectors that have support for saving via config.

Used for typing save and load functionality in `alibi_detect.saving.saving.py`.
Used for typing save and load functionality in `alibi_detect.saving.saving`.
"""
def get_config(self) -> dict: ...

Note:
This exists to distinguish between detectors with and without support for config saving and loading. Once all
detector support this then this protocol will be removed.
@classmethod
def from_config(cls, config: dict): ...

def _set_config(self, inputs: dict): ...


@runtime_checkable
class StatefulDetectorOnline(ConfigurableDetector, Protocol):
"""Type Protocol for detectors that have support for save/loading of online state.

Used for typing save and load functionality in `alibi_detect.saving.saving`.
"""
def get_config(self): ...
t: int = 0

def from_config(self): ...
def save_state(self, filepath: Union[str, os.PathLike]): ...

def _set_config(self): ...
def load_state(self, filepath: Union[str, os.PathLike]): ...


class NumpyEncoder(json.JSONEncoder):
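The `StatefulDetectorOnline` protocol above is decorated with `@runtime_checkable`, so `isinstance` checks succeed for any object that structurally provides the protocol's members. The following is a minimal standalone sketch of that mechanism. `SupportsOnlineState`, `WithState` and `WithoutState` are hypothetical names used for illustration only, and the sketch uses `typing.Protocol` from the standard library rather than `typing_extensions`.

```python
import os
from typing import Protocol, Union, runtime_checkable


@runtime_checkable
class SupportsOnlineState(Protocol):
    """Hypothetical protocol mirroring the shape of StatefulDetectorOnline."""
    t: int

    def save_state(self, filepath: Union[str, os.PathLike]) -> None: ...

    def load_state(self, filepath: Union[str, os.PathLike]) -> None: ...


class WithState:
    """Provides every protocol member without inheriting from the protocol."""
    t = 0

    def save_state(self, filepath):
        pass

    def load_state(self, filepath):
        pass


class WithoutState:
    """Has `t` but no save/load methods, so it does not satisfy the protocol."""
    t = 0


# isinstance only checks that the members exist, not their signatures or types.
has_state = isinstance(WithState(), SupportsOnlineState)
lacks_state = isinstance(WithoutState(), SupportsOnlineState)
```

Because the check is structural, a detector does not need to subclass the protocol to be recognised as stateful.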
108 changes: 93 additions & 15 deletions alibi_detect/cd/base_online.py
@@ -1,23 +1,26 @@
import logging
import warnings
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, TYPE_CHECKING

import numpy as np
from alibi_detect.base import BaseDetector, concept_drift_dict
from alibi_detect.cd.utils import get_input_shape
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
from alibi_detect.utils.state import StateMixin
from alibi_detect.utils._types import Literal

if has_pytorch:
if TYPE_CHECKING:
import torch

if has_tensorflow:
import tensorflow as tf

logger = logging.getLogger(__name__)


class BaseMultiDriftOnline(BaseDetector):
class BaseMultiDriftOnline(BaseDetector, StateMixin):
t: int = 0
thresholds: np.ndarray
backend: Literal['pytorch', 'tensorflow']
online_state_keys: Tuple[str, ...]

def __init__(
self,
@@ -126,17 +129,46 @@ def _preprocess_xt(self, x_t: Union[np.ndarray, Any]) -> np.ndarray:
return x_t[None, :]

def get_threshold(self, t: int) -> float:
"""
Return the threshold for timestep `t`.

Parameters
----------
t
The timestep to return a threshold for.

Returns
-------
The threshold at timestep `t`.
"""
return self.thresholds[t] if t < self.window_size else self.thresholds[-1]

def _initialise(self) -> None:
def _initialise_state(self) -> None:
"""
Initialise online state (the stateful attributes updated by `score` and `predict`).

If a subclassed detector has additional online state, an additional `_initialise_state` should be defined,
with a call to `super()._initialise_state()` included (see `LSDDDriftOnlineTorch._initialise_state()` for
an example).
"""
self.t = 0 # corresponds to a test set of ref data
self.test_stats = np.array([]) # type: ignore[var-annotated]
self.drift_preds = np.array([]) # type: ignore[var-annotated]
self._configure_ref_subset()

def reset(self) -> None:
"Resets the detector but does not reconfigure thresholds."
self._initialise()
"""
Deprecated reset method. This method will be repurposed or removed in the future. To reset the detector to
its initial state (`t=0`) use :meth:`reset_state`.
"""
self.reset_state()
warnings.warn('This method is deprecated and will be removed/repurposed in the future. To reset the detector '
'to its initial state use `reset_state`.', DeprecationWarning)

def reset_state(self) -> None:
"""
Resets the detector to its initial state (`t=0`). This does not include reconfiguring thresholds.
"""
self._initialise_state()

def predict(self, x_t: Union[np.ndarray, Any], return_test_stat: bool = True,
) -> Dict[Dict[str, str], Dict[str, Union[int, float]]]:
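The `reset` method in this hunk is kept as a deprecated alias that forwards to `reset_state` and then emits a `DeprecationWarning`. A minimal standalone sketch of that pattern follows. `ToyDetector` is a hypothetical class, not part of alibi-detect, and the warning text is shortened.

```python
import warnings


class ToyDetector:
    """Hypothetical detector used only to illustrate the deprecation pattern."""

    def __init__(self) -> None:
        self.t = 5  # pretend five instances have already been processed

    def _initialise_state(self) -> None:
        self.t = 0

    def reset_state(self) -> None:
        # New public entry point: return to the initial state (t=0).
        self._initialise_state()

    def reset(self) -> None:
        # Deprecated alias: still works, but warns the caller.
        self.reset_state()
        warnings.warn('`reset` is deprecated, use `reset_state`.', DeprecationWarning)


detector = ToyDetector()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')  # make sure the warning is not filtered out
    detector.reset()
```

`DeprecationWarning` is hidden by default in many contexts, which is why the sketch enables it explicitly with `simplefilter('always')`. Passing `stacklevel=2` to `warnings.warn` would attribute the warning to the caller's line; the code in the diff does not do this.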
@@ -177,8 +209,10 @@ def predict(self, x_t: Union[np.ndarray, Any], return_test_stat: bool = True,
return cd


class BaseUniDriftOnline(BaseDetector):
class BaseUniDriftOnline(BaseDetector, StateMixin):
t: int = 0
thresholds: np.ndarray
online_state_keys: Tuple[str, ...]

def __init__(
self,
@@ -291,6 +325,20 @@ def _update_state(self, x_t: np.ndarray):
pass

def _check_x(self, x: Any, x_ref: bool = False) -> np.ndarray:
"""
Check the type and shape of the data `x`, and coerce it to the correct shape if possible.

Parameters
----------
x
The data to be checked.
x_ref
Whether `x` is a batch of reference data instances (if `True`), or a single test data instance (if `False`).

Returns
-------
The checked data, coerced to be a np.ndarray of the correct shape.
"""
# Check the type of x
if isinstance(x, np.ndarray):
pass
@@ -333,21 +381,51 @@ def _preprocess_xt(self, x_t: Union[np.ndarray, Any]) -> np.ndarray:
return x_t

def get_threshold(self, t: int) -> np.ndarray:
"""
Return the threshold for timestep `t`.

Parameters
----------
t
The timestep to return a threshold for.

Returns
-------
The threshold at timestep `t`.
"""
return self.thresholds[t] if t < len(self.thresholds) else self.thresholds[-1]

def _initialise(self) -> None:
def _initialise_state(self) -> None:
"""
Initialise online state (the stateful attributes updated by `score` and `predict`).

If a subclassed detector has additional online state, an additional `_initialise_state` should be defined,
with a call to `super()._initialise_state()` included (see `CVMDriftOnlineTorch._initialise_state()` for
an example).
"""
self.t = 0
self.xs = np.array([]) # type: ignore[var-annotated]
self.test_stats = np.empty([0, len(self.window_sizes), self.n_features])
self.drift_preds = np.array([]) # type: ignore[var-annotated]
self._configure_ref()

@abstractmethod
def _check_drift(self, test_stats: np.ndarray, thresholds: np.ndarray) -> int:
pass

def reset(self) -> None:
"Resets the detector but does not reconfigure thresholds."
self._initialise()
"""
Deprecated reset method. This method will be repurposed or removed in the future. To reset the detector to
its initial state (`t=0`) use :meth:`reset_state`.
"""
self.reset_state()
warnings.warn('This method is deprecated and will be removed/repurposed in the future. To reset the detector '
'to its initial state use `reset_state`.', DeprecationWarning)

def reset_state(self) -> None:
"""
Resets the detector to its initial state (`t=0`). This does not include reconfiguring thresholds.
"""
self._initialise_state()

def predict(self, x_t: Union[np.ndarray, Any], return_test_stat: bool = True,
) -> Dict[Dict[str, str], Dict[str, Union[int, float]]]:
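The `get_threshold` docstrings added in this file describe a lookup that is indexed by timestep and falls back to the last configured threshold once `t` runs past the end. A standalone sketch of that clamping logic, using a plain list and made-up threshold values in place of the detector's `thresholds` array:

```python
from typing import Sequence


def get_threshold(thresholds: Sequence[float], t: int) -> float:
    # Same expression as BaseUniDriftOnline.get_threshold, written as a free function.
    return thresholds[t] if t < len(thresholds) else thresholds[-1]


thresholds = [0.9, 0.7, 0.5]  # illustrative values only

early = get_threshold(thresholds, 1)   # within range: indexed directly
late = get_threshold(thresholds, 10)   # past the end: last threshold is reused
```

`BaseMultiDriftOnline.get_threshold` uses the same expression but compares `t` against `self.window_size` rather than `len(self.thresholds)`.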
25 changes: 24 additions & 1 deletion alibi_detect/cd/cvm_online.py
@@ -9,6 +9,8 @@


class CVMDriftOnline(BaseUniDriftOnline, DriftConfigMixin):
online_state_keys = ('t', 'test_stats', 'drift_preds', 'xs', 'ids_ref_wins', 'ids_wins_ref', 'ids_wins_wins')

def __init__(
self,
x_ref: Union[np.ndarray, list],
@@ -92,10 +94,14 @@ def __init__(
self.batch_size = n_bootstraps if batch_size is None else batch_size

# Configure thresholds and initialise detector
self._initialise()
self._initialise_state()
self._configure_thresholds()
self._configure_ref()

def _configure_ref(self) -> None:
"""
Configure the reference data.
"""
ids_ref_ref = self.x_ref[None, :, :] >= self.x_ref[:, None, :]
self.ref_cdf_ref = np.sum(ids_ref_ref, axis=0) / self.n

@@ -162,6 +168,14 @@ def _simulate_streams(self, t_max: int) -> np.ndarray:
return stats

def _update_state(self, x_t: np.ndarray):
"""
Update online state based on the provided test instance.

Parameters
----------
x_t
The test instance.
"""
self.t += 1
if self.t == 1:
# Initialise stream
@@ -186,6 +200,15 @@ def _update_state(self, x_t: np.ndarray):
[self.ids_wins_wins, (x_t <= self.xs[-self.max_ws:, :])[None, :, :]], 0
)

def _initialise_state(self) -> None:
"""
Initialise online state (the stateful attributes updated by `score` and `predict`).
"""
super()._initialise_state()
self.ids_ref_wins = np.array([])
self.ids_wins_ref = np.array([])
self.ids_wins_wins = np.array([])

def score(self, x_t: Union[np.ndarray, Any]) -> np.ndarray:
"""
Compute the test-statistic (CVM) between the reference window(s) and test window.
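The `_initialise_state` override in this file follows the pattern described in the base-class docstring: a subclass with extra online state calls `super()._initialise_state()` and then resets its own attributes, so that `reset_state` clears everything. A minimal sketch of that chaining, using hypothetical `BaseToy` and `ChildToy` classes with plain lists instead of NumPy arrays:

```python
class BaseToy:
    """Hypothetical base class holding the shared online state."""
    online_state_keys = ('t', 'xs')

    def __init__(self) -> None:
        self._initialise_state()

    def _initialise_state(self) -> None:
        self.t = 0
        self.xs = []

    def reset_state(self) -> None:
        # Dispatches to the most-derived _initialise_state.
        self._initialise_state()


class ChildToy(BaseToy):
    """Hypothetical subclass with one extra stateful attribute."""
    online_state_keys = ('t', 'xs', 'ids')

    def _initialise_state(self) -> None:
        super()._initialise_state()  # reset the shared state first
        self.ids = []                # then the subclass-specific state


det = ChildToy()
det.t, det.xs, det.ids = 4, [1.0, 2.0], [True, False]  # simulate some updates
det.reset_state()
```

If the `super()` call were omitted, `reset_state` would leave `t` and `xs` at their stale values.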
16 changes: 15 additions & 1 deletion alibi_detect/cd/fet_online.py
@@ -10,6 +10,8 @@


class FETDriftOnline(BaseUniDriftOnline, DriftConfigMixin):
online_state_keys = ('t', 'test_stats', 'drift_preds', 'xs')

def __init__(
self,
x_ref: Union[np.ndarray, list],
@@ -119,10 +121,14 @@ def __init__(
raise ValueError("The `x_ref` data consists of all 0's or all 1's. Thresholds cannot be configured.")

# Configure thresholds and initialise detector
self._initialise()
self._initialise_state()
self._configure_thresholds()
self._configure_ref()

def _configure_ref(self) -> None:
"""
Configure the reference data.
"""
self.sum_ref = np.sum(self.x_ref, axis=0)

def _configure_thresholds(self) -> None:
@@ -227,6 +233,14 @@ def _exp_moving_avg(arr: np.ndarray, lam: float) -> np.ndarray:
return output

def _update_state(self, x_t: np.ndarray):
"""
Update online state based on the provided test instance.

Parameters
----------
x_t
The test instance.
"""
self.t += 1
if self.t == 1:
# Initialise stream
32 changes: 29 additions & 3 deletions alibi_detect/cd/lsdd_online.py
@@ -1,3 +1,4 @@
import os
import numpy as np
from typing import Any, Callable, Dict, Optional, Union
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator, Framework
@@ -112,9 +113,11 @@ def test_stats(self):
def thresholds(self):
return [self._detector.thresholds[min(s, self._detector.window_size-1)] for s in range(self.t)]

def reset(self):
"Resets the detector but does not reconfigure thresholds."
self._detector.reset()
def reset_state(self):
"""
Resets the detector to its initial state (`t=0`). This does not include reconfiguring thresholds.
"""
self._detector.reset_state()

def predict(self, x_t: Union[np.ndarray, Any], return_test_stat: bool = True) \
-> Dict[Dict[str, str], Dict[str, Union[int, float]]]:
@@ -163,3 +166,26 @@ def get_config(self) -> dict: # Needed due to need to unnormalize x_ref
# Unnormalize x_ref
cfg['x_ref'] = self._detector._unnormalize(cfg['x_ref'])
return cfg

def save_state(self, filepath: Union[str, os.PathLike]):
"""
Save a detector's state to disk in order to generate a checkpoint.

Parameters
----------
filepath
The directory to save state to.
"""
self._detector.save_state(filepath)

def load_state(self, filepath: Union[str, os.PathLike]):
"""
Load the detector's state from disk, in order to restart from a checkpoint previously generated with
:meth:`~save_state`.

Parameters
----------
filepath
The directory to load state from.
"""
self._detector.load_state(filepath)
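The wrapper's `save_state` and `load_state` above simply forward to the backend-specific `self._detector`. The `StateMixin` implementation is not part of this excerpt, so the sketch below is only a guess at the general idea: a hypothetical `ToyStateMixin` that persists whichever attributes are listed in `online_state_keys`, plus a wrapper that delegates to it. The JSON format and the `state.json` file name are assumptions made for illustration.

```python
import json
import tempfile
from pathlib import Path
from typing import Tuple


class ToyStateMixin:
    """Hypothetical mixin: saves and loads the attributes named in `online_state_keys`."""
    online_state_keys: Tuple[str, ...] = ()

    def save_state(self, filepath) -> None:
        path = Path(filepath)
        path.mkdir(parents=True, exist_ok=True)
        state = {key: getattr(self, key) for key in self.online_state_keys}
        (path / 'state.json').write_text(json.dumps(state))  # assumed file name

    def load_state(self, filepath) -> None:
        state = json.loads((Path(filepath) / 'state.json').read_text())
        for key in self.online_state_keys:
            setattr(self, key, state[key])


class ToyBackendDetector(ToyStateMixin):
    """Hypothetical backend detector holding the actual online state."""
    online_state_keys = ('t', 'test_stats')

    def __init__(self) -> None:
        self.t = 0
        self.test_stats = []

    def update(self, stat: float) -> None:
        self.t += 1
        self.test_stats.append(stat)


class ToyWrapper:
    """Hypothetical front-end that delegates state handling to its backend."""

    def __init__(self) -> None:
        self._detector = ToyBackendDetector()

    def update(self, stat: float) -> None:
        self._detector.update(stat)

    def save_state(self, filepath) -> None:
        self._detector.save_state(filepath)

    def load_state(self, filepath) -> None:
        self._detector.load_state(filepath)


with tempfile.TemporaryDirectory() as tmp:
    wrapper = ToyWrapper()
    for stat in (0.25, 0.5, 0.75):
        wrapper.update(stat)
    wrapper.save_state(tmp)        # checkpoint at t=3

    restarted = ToyWrapper()       # fresh wrapper, t=0
    restarted.load_state(tmp)      # resume from the checkpoint

restored_t = restarted._detector.t
restored_stats = restarted._detector.test_stats
```

Keeping the list of stateful attributes in one class-level tuple means a subclass only has to extend `online_state_keys` to have its extra state included in the checkpoint.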