Serialisation of (online) state for online detectors #604
Nice idea with the name change, and yes I agree: "state" does not refer to any attributes set in `__init__`, as those would be "config" (by our definitions). My only concern is that users might expect a detector to give the same predictions as the original when loaded from a "checkpoint". Maybe the answer is just to make it clear in the docstrings, though... in any case, statistically the detector's behaviour should be the same after the checkpoint, even if the exact predictions are not.
Codecov Report

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master     #604      +/-   ##
==========================================
+ Coverage   80.15%   80.32%   +0.17%
==========================================
  Files         133      137       +4
  Lines        9177     9292     +115
==========================================
+ Hits         7356     7464     +108
- Misses       1821     1828       +7
```

Flags with carried forward coverage won't be shown.
Edit: Resolved.
Regarding the codecov report, the ...
@arnaudvl @ojcobb (and @jklaise/@mauicv) could do with your thoughts on this. In the latest implementation, I have removed the new ...

Question: Shall we keep a ...

As Example 2 shows, determinism in the case of saving/loading a detector is not affected by this decision anyway... Difference between ...
As an additional side-note, there is an issue with setting random seeds not giving deterministic behaviour for a given operation. The only solution I can think of for this is a scikit-learn style approach, where we accept ...
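A scikit-learn style approach could look something like the following sketch (illustrated with NumPy's `Generator` rather than torch; `make_ref_split` and its parameters are hypothetical names, not the actual alibi-detect API):

```python
import numpy as np

def make_ref_split(n, rw_size, window_size, rng=None):
    # Accepting an explicit random state (scikit-learn's `random_state` pattern)
    # makes the permutation reproducible regardless of global seed state elsewhere.
    rng = np.random.default_rng(rng)
    perm = rng.permutation(n)
    return perm[:rw_size], perm[-window_size:]

# Two detectors constructed with the same seed would get identical splits:
ref_a, test_a = make_ref_split(100, rw_size=80, window_size=10, rng=0)
ref_b, test_b = make_ref_split(100, rw_size=80, window_size=10, rng=0)
```

The point is that the seed travels with the operation rather than relying on a global `np.random.seed` call happening first.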
@arnaudvl @ojcobb a possible alternative strategy to make ... :

```python
def _configure_ref_subset(self):
    """
    Configure reference subset. If already configured, the stateful attributes `test_window` and `k_xtc` are
    reset without re-configuring a new reference subset.
    """
    etw_size = 2 * self.window_size - 1  # etw = extended test window
    nkc_size = self.n - self.n_kernel_centers  # nkc = non-kernel-centers
    rw_size = nkc_size - etw_size  # rw = ref-window
    # Check if already configured; re-initialise stateful attributes w/o searching for a new ref split if so
    configure_ref = self.init_test_inds is None
    if configure_ref:
        # Make split and ensure it doesn't cause an initial detection
        lsdd_init = None
        while lsdd_init is None or lsdd_init >= self.get_threshold(0):
            # Make split
            perm = torch.randperm(nkc_size)
            self.ref_inds, self.init_test_inds = perm[:rw_size], perm[-self.window_size:]
            self.test_window = self.x_ref_eff[self.init_test_inds]
            # Compute initial lsdd to check for initial detection
            self.c2s = self.k_xc[self.ref_inds].mean(0)  # (below Eqn 21)
            self.k_xtc = self.kernel(self.test_window, self.kernel_centers)
            h_init = self.c2s - self.k_xtc.mean(0)  # (Eqn 21)
            lsdd_init = h_init[None, :] @ self.H_lam_inv @ h_init[:, None]  # (Eqn 11)
    else:
        # Reset stateful attributes using the existing split
        self.test_window = self.x_ref_eff[self.init_test_inds]
        self.k_xtc = self.kernel(self.test_window, self.kernel_centers)
```

This seems like a reasonable compromise to me? However, the additional duplication/complexity is unnecessary if we truly don't care about repeatable predictions post-...
alibi_detect/base.py (outdated):

```python
def save_state(self, filepath): ...

def load_state(self, filepath): ...
```
Should parameters have type hints?
Added in 15d1dfa. Note: I also updated the pre-existing `get_config` and `set_config` methods here.
Mmn good point, thinking again it doesn't seem ideal to have it outside of ...

Re it becoming a public module, I suspect you're mostly right. We do ...
Weirdly though, with our ...
alibi_detect/cd/base_online.py (outdated):

```python
        filepath
            The directory to load state from.
        """
        self._set_state_dir(filepath)
```
Why is this necessary when loading?
Just so that `self.state_dir` is set (and converted from `str` to `pathlib.Path`) when `load_state` is called as well as when `save_state` is called.
I thought it might be helpful to have `state_dir` as a public attribute so that a user could interrogate the detector to see where state was loaded from. Although thinking about it more, for the backend detectors one would have to do `detector._detector.state_dir` (access a private attribute) anyway. I guess we'd probably want to define a `@property` if we actually want to support this functionality properly...

Happy to just make it private if you think it's better though...
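The `@property` idea could delegate to the backend detector, along these lines (a hypothetical sketch; the class and attribute names are illustrative only, not the actual alibi-detect classes):

```python
from pathlib import Path

class _BackendDetector:
    """Stand-in for a private backend detector that holds the state directory."""
    def __init__(self):
        self.state_dir = Path('checkpoint_dir')

class DriftDetectorFrontend:
    """Public frontend exposing the backend's state_dir via a read-only property."""
    def __init__(self):
        self._detector = _BackendDetector()

    @property
    def state_dir(self) -> Path:
        # Users can read cd.state_dir without touching the private _detector attribute
        return self._detector.state_dir

cd = DriftDetectorFrontend()
cd.state_dir  # → Path('checkpoint_dir'), no cd._detector access needed
```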
alibi_detect/cd/base_online.py (outdated):

```python
        dirpath
            The directory to save state file inside.
        """
        self.state_dir = Path(dirpath)
```
Should this be a private attribute?
See #604 (comment)
alibi_detect/cd/base_online.py (outdated):

```python
def _set_state_dir(self, dirpath: Union[str, os.PathLike]):
    """
    Set the directory path to store state in, and create an empty directory if it doesn't already exist.

    Parameters
    ----------
    dirpath
        The directory to save state file inside.
    """
    self.state_dir = Path(dirpath)
    self.state_dir.mkdir(parents=True, exist_ok=True)

def save_state(self, filepath: Union[str, os.PathLike]):
    """
    Save a detector's state to disk in order to generate a checkpoint.

    Parameters
    ----------
    filepath
        The directory to save state to.
    """
    self._set_state_dir(filepath)
    self._save_state()
    logger.info('Saved state for t={} to {}'.format(self.t, self.state_dir))

def load_state(self, filepath: Union[str, os.PathLike]):
    """
    Load the detector's state from disk, in order to restart from a checkpoint previously generated with
    `save_state`.

    Parameters
    ----------
    filepath
        The directory to load state from.
    """
    self._set_state_dir(filepath)
    self._load_state()
    logger.info('State loaded for t={} from {}'.format(self.t, self.state_dir))

def _save_state(self):
    """
    Private method to save a detector's state to disk.

    TODO - Method slightly verbose as designed to facilitate saving of "offline" state in follow-up PR.
    """
    filename = 'state'
    keys = self.online_state_keys
    save_state_dict(self, keys, self.state_dir.joinpath(filename + '.npz'))

def _load_state(self, offline: bool = False):
    """
    Private method to load a detector's state from disk.

    TODO - Method slightly verbose as designed to facilitate loading of "offline" state in follow-up PR.
    """
    filename = 'state'
    load_state_dict(self, self.state_dir.joinpath(filename + '.npz'), raise_error=True)
```
Seems like a lot of duplicated code that's exactly the same as for `BaseMultiDriftOnline`, which suggests we may want to refactor using functions instead of methods, or a mixin class? Or perhaps the class hierarchy needs to be updated.
Note that this seems to apply to other methods too, so perhaps it is a more widespread problem requiring a refactoring later...
Fair point. Having a `BaseDriftOnline` class for generic methods, or a mix-in, both seem much nicer than the current pattern. I'll have a rethink 👍🏻
To reduce duplication, 5daf1b1 adds a `BaseDriftOnline` class. @jklaise @mauicv could I get your thoughts on the design of `BaseDriftOnline` please? I've gone with a parent class rather than a mix-in, since it seems strange to define a mix-in in `alibi_detect/base.py` when it is only to be used in two classes (`BaseMultiDriftOnline` and `BaseUniDriftOnline`). I also decided to put it in `alibi_detect/cd/base_online.py` rather than `alibi_detect/base.py`, since at the moment the concept of "online" detectors is specific to drift (this may change if we decide stateful outlier detectors are in fact "online").
LGTM, however noting that there are quite a few abstract methods, some of which (not all?) are implemented in the `Multi`/`Uni` abstract child classes, which come with their own set of abstract methods... Worried that this may become a bit tricky to keep track of. As a minimum, I would group all abstract methods together and add docstrings on expected implementation, and also, where valid, note which of the `Multi`/`Uni` classes implement these methods (+ type hints as always).
See #604 (comment) wrt type hints; not sure on the best approach here.

Wrt the abstract methods, if they are missing from the `Multi`/`Uni` child classes, that will be because they are instead defined in the next subclass down, i.e. `LSDDDriftOnlineTorch._initialise_state` or `CVMDriftOnline._configure_thresholds`...

We could move the abstract methods such as `_configure_thresholds` back to their respective `Multi`/`Uni` abstract class, at the cost of more duplication (but maybe less complexity?)
Removed the new base class, and moved state methods to `StateMixin`. See #604 (comment).
```python
# Skip if backend not `tensorflow` or `pytorch`
if backend not in ('tensorflow', 'pytorch'):
    pytest.skip("Detector doesn't have this backend")
```
Is this due to some `keops` behaviour? Basically asking why skip here.
The `test_saving` file cycles through all possible backends:

```python
backends = ['tensorflow', 'pytorch', 'sklearn']
if has_keops:  # pykeops only installed in Linux CI
    backends.append('keops')
backend = param_fixture("backend", backends)
```

We have to skip tests if the associated detector doesn't have that backend. In this case, online detectors do not have a `keops` backend.
```python
@pytest.mark.parametrize('batch_size', batch_size)
@pytest.mark.parametrize('n_feat', n_features)
def test_cvmdriftonline(window_sizes, batch_size, n_feat, seed):
    with fixed_seed(seed):
```
Noting that the previous version of the tests didn't have a fixed seed; presumably it wasn't needed in this setting, as the test suite has been passing. Is there a need to introduce a fixed seed here, as it seems detrimental to the testing for this particular set of tests?
P.S. same comment applies to tests below and in other modules.
Mmn yeh, I figured it would be good to add since there is randomness in the initialization of these detectors (in `_configure_thresholds`, and in the generation of `x_ref`/`x_h0`/`x_h1`). Although the tests do currently pass without fixing the seed, this doesn't actually mean they pass for any random seed. I seem to recall that when I looked into this before, `np.random.seed`'s set in one test leaked into others. Presumably, this means we have been implicitly fixing the seed in these tests anyway.

Ideally (IMO), we'd get to a point where any random operations in tests are done inside `with fixed_seed(seed)` blocks; then if a new bug is introduced, we can go back and reproduce it with the same random seed.
I agree that in case a bug happens it's valuable to be able to reproduce it with the same seed. But, on the other hand, "any random operations in tests done inside `with fixed_seed(seed)`" sounds like the opposite of what we want to do (unless for tests where we compare outputs of the same seed), as stuff should pass most tests with any seed?
Yeh, fair point, tests should generally pass with any seed, especially if they are unit-type tests. The problem at the moment is that we have lots of functional tests where we are testing a detector's predictions and checking things like Expected Runtime (ERT) for online detectors. We probably want more granular unit tests in lots of places...

Edit: by "any random operations in tests done inside `with fixed_seed(seed)`", I more meant any random operations that might for some reason affect the outcome of the test.
Overall LGTM; main question about default behaviour wrt state saving when `save_detector` is called on online detectors. Regardless of choice, I believe this should be prominent in the saving and method docs.
alibi_detect/cd/base_online.py (outdated):

```python
@abstractmethod
def _update_state(self, x_t):
    pass
```
Type hints of parameters and return types required.
This is unfortunately necessary, since univariate online detectors have `def _update_state(self, x_t: np.ndarray):` whilst multivariate detectors have `def _update_state(self, x_t: torch.Tensor):` (or `tf.Tensor`). We violate Liskov's substitution principle slightly.

I sort of think it is OK not to add type hints in the abstract method, since we only have it there to signal that sub-classes must have an `_update_state` method which takes an instance and updates online state, but we don't specify the exact type. However, we could also do `def _update_state(self, x_t: Union[np.ndarray, 'torch.Tensor', 'tf.Tensor']):` and then add `# type: ignore[override]` in the sub-class? This is actually what we did previously...

P.s. it's a similar story for the `get_threshold` method...
d190589 removes ...
LGTM! Should we add some documentation somewhere that `save` will save the state by default, and if that's not desired one should call `reset_state` first? Or is it going to be too confusing for now?
Thanks! Will add this documentation now. Just realised I added it in #628 instead of here. Doh!
doc/source/overview/saving.md (outdated):

```diff
-[Online drift detectors](../cd/methods.md#online) are stateful, with their state updated upon each `predict` call.
-When saving an online detector, the `save_state` option controls whether to include the detector's state:
+[Online drift detectors](../cd/methods.md#online) are stateful, with their state updated each timestep `t` (each time
+`.predict()` or `.state()` is called). {func}`~alibi_detect.saving.save_detector` will save the state of online
```
Do you mean `score()` instead of `state()`? On that note, do we even document the usage/use cases of `score()`? If not, perhaps we should leave it out.
Good spot thanks. Also fair point about not really documenting it. I'll remove.
Removed `.state()` in d888276.
LGTM!
This PR implements the functionality to save and load state for online detectors. At a given time step, the `save_state` method can be called to create a "checkpoint". This can later be loaded via the `load_state` method. At any time, the `reset` method can be used to reset the state back to the `t=0` timestep.

Scope

This PR deals with online state only. See #604 (comment) for a discussion on online versus offline state.
Example(s)

Saving and loading state

The state of online detectors can now be saved and loaded via `save_state` and `load_state`. The detector's state may be saved with the `save_state` method, and the previously saved state may then be loaded via the `load_state` method. At any point, the state may be reset with the `reset` method. Also see the colab notebook.

Saving and loading detector with state

Calling `save_detector` with `save_state=True` will save an online detector's state to `state/` within the detector save directory. `load_detector` will simply attempt to load state if a `state/` directory exists.
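The checkpoint flow described above can be illustrated end-to-end with a toy stand-in detector (the class below is illustrative only, mimicking the `save_state`/`load_state`/`reset` API on a trivial sliding window, not an actual alibi-detect detector):

```python
import tempfile
import numpy as np
from pathlib import Path

class ToyOnlineDetector:
    """Toy stand-in mimicking the online-state API: a timestep t and a sliding test window."""
    def __init__(self, window_size: int = 3):
        self.window_size = window_size
        self.reset()

    def predict(self, x_t: float):
        # Each call advances the timestep and slides the newest instance into the window
        self.t += 1
        self.test_window = np.roll(self.test_window, -1)
        self.test_window[-1] = x_t

    def save_state(self, dirpath):
        d = Path(dirpath)
        d.mkdir(parents=True, exist_ok=True)
        np.savez(d / 'state.npz', t=self.t, test_window=self.test_window)

    def load_state(self, dirpath):
        state = np.load(Path(dirpath) / 'state.npz')
        self.t = int(state['t'])
        self.test_window = state['test_window']

    def reset(self):
        # Reset state back to the t=0 timestep
        self.t = 0
        self.test_window = np.zeros(self.window_size)

cd = ToyOnlineDetector()
for x in [1.0, 2.0]:
    cd.predict(x)
ckpt = tempfile.mkdtemp()
cd.save_state(ckpt)   # checkpoint at t=2
cd.predict(3.0)       # continue to t=3
cd.load_state(ckpt)   # restart from the t=2 checkpoint
```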
directory exists.TODO's:
save_detector
andload_detector
functions, to allow state to be saved and loaded when the detector itself is serialized/unserialized.test_saving.py
state test.## Outstanding considerations (specific to LSDD for now but maybe more widely applicible)There might be an open question to resolve regarding what we define "state" to be. This PR currently considers it to be only the attributes that are updated in_update_state
(self.t
,self.test_window
andself.k_xtc
). In other words, "state" is defined as any attribute that is dependent on time (updated when a new instancex_t
is given viascore
orpredict
).However, there is already a notion of "state" introduced when weinitialise
a detector (or reinitialise it via thereset
method). Here, in addition to the attributes already mentioned, we setself.ref_inds
,self.c2s
, andself.init_test_inds
. This leads to considerations:1. Will there be confusion between thereset
andreset_state
methods, and do we need to change the docstrings or names?2. There is randomness involved in the initialisation of
LSDDDrift
(in_configure_ref_subset
). It is likely that if the detector is instantiated later on, andload_state
is used to restart from a checkpoint, predictions will still be different compared to those that were observed aftersave_state
was called with the original detector. This would only be avoided if random seeds were set both times. With this in mind, do we want to change our definition of "state" to includeself.ref_inds
,self.c2s
, andself.init_test_inds
?