Commit 1af9fc0

Return the OOF instances for classifier-based drift detectors (ClassifierDrift and SpotTheDiffDrift) (#665)
* add out-of-fold instances to the return dict for classifier detectors
* update docs
* update score return type
* fix typo and mypy error
* extend to list inputs and update score return types
* add Union import
* fix type error
* add changelog
1 parent 47a0134 · commit 1af9fc0

File tree

10 files changed: +97 -46 lines

CHANGELOG.md (+1)
@@ -7,6 +7,7 @@
 - **New feature** MMD drift detector has been extended with a [KeOps](https://www.kernel-operations.io/keops/index.html) backend to scale and speed up the detector.
   See the [documentation](https://docs.seldon.io/projects/alibi-detect/en/latest/cd/methods/mmddrift.html) and [example notebook](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_mmd_keops.html) for more info ([#548](https://github.com/SeldonIO/alibi-detect/pull/548)).
 - If a `categories_per_feature` dictionary is not passed to `TabularDrift`, a warning is now raised to inform the user that all features are assumed to be numerical ([#606](https://github.com/SeldonIO/alibi-detect/pull/606)).
+- The `ClassifierDrift` and `SpotTheDiffDrift` detectors can now also return the out-of-fold instances of the reference and test sets. When `train_size` is used to train the detector, this makes it possible to associate the returned prediction probabilities with the correct instances.
 
 ### Changed
 - Minimum `prophet` version bumped to `1.1.0` (used by `OutlierProphet`). This upgrade removes the dependency on `pystan` as `cmdstanpy` is used instead. This version also comes with pre-built wheels for all major platforms and Python versions, making both installation and testing easier ([#627](https://github.com/SeldonIO/alibi-detect/pull/627)).
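
To illustrate the changelog entry above, a minimal sketch of how the new return values can be consumed. The data and the small PyTorch classifier are hypothetical stand-ins; only the `ClassifierDrift` API and the `x_ref_oof`/`x_test_oof` keys come from this change.

```python
import numpy as np
import torch.nn as nn
from alibi_detect.cd import ClassifierDrift

# hypothetical reference and (possibly drifted) test data
x_ref = np.random.randn(500, 10).astype(np.float32)
x_test = np.random.randn(500, 10).astype(np.float32)

# hypothetical binary classifier discriminating reference (0) from test (1) instances
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))

# with train_size=.75 only the held-out 25% receives out-of-fold predictions
cd = ClassifierDrift(x_ref, model, backend='pytorch', p_val=.05, train_size=.75)

preds = cd.predict(x_test, return_probs=True)
probs_ref = preds['data']['probs_ref']   # OOF probabilities on the reference data
x_ref_oof = preds['data']['x_ref_oof']   # the instances those probabilities refer to
assert len(probs_ref) == len(x_ref_oof)  # aligned one-to-one after this change
```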

alibi_detect/cd/base.py (+9 -4)
@@ -240,7 +240,8 @@ def test_probs(
         return p_val, dist
 
     @abstractmethod
-    def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
+    def score(self, x: Union[np.ndarray, list]) \
+            -> Tuple[float, float, np.ndarray, np.ndarray, Union[np.ndarray, list], Union[np.ndarray, list]]:
         pass
 
     def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
@@ -260,7 +261,8 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
             K-S test stat if binarize_preds=False, otherwise relative error reduction.
         return_probs
             Whether to return the instance level classifier probabilities for the reference and test data
-            (0=reference data, 1=test data).
+            (0=reference data, 1=test data). The reference and test instances associated with the
+            probabilities are also returned.
         return_model
             Whether to return the updated model trained to discriminate reference and test instances.
 
@@ -270,10 +272,11 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
         'meta' has the model's metadata.
         'data' contains the drift prediction and optionally the p-value, performance of the classifier
         relative to its expectation under the no-change null, the out-of-fold classifier model
-        prediction probabilities on the reference and test data, and the trained model.
+        prediction probabilities on the reference and test data, the associated reference
+        and test instances of the out-of-fold predictions, and the trained model.
         """
         # compute drift scores
-        p_val, dist, probs_ref, probs_test = self.score(x)
+        p_val, dist, probs_ref, probs_test, x_ref_oof, x_test_oof = self.score(x)
         drift_pred = int(p_val < self.p_val)
 
         # update reference dataset
@@ -297,6 +300,8 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
         if return_probs:
             cd['data']['probs_ref'] = probs_ref
             cd['data']['probs_test'] = probs_test
+            cd['data']['x_ref_oof'] = x_ref_oof
+            cd['data']['x_test_oof'] = x_test_oof
         if return_model:
             cd['data']['model'] = self.model
         return cd

alibi_detect/cd/pytorch/classifier.py (+13 -4)
@@ -157,7 +157,8 @@ def __init__(
         if isinstance(train_kwargs, dict):
             self.train_kwargs.update(train_kwargs)
 
-    def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
+    def score(self, x: Union[np.ndarray, list]) \
+            -> Tuple[float, float, np.ndarray, np.ndarray, Union[np.ndarray, list], Union[np.ndarray, list]]:
         """
         Compute the out-of-fold drift metric such as the accuracy from a classifier
         trained to distinguish the reference data from the data to be tested.
@@ -171,7 +172,8 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
         -------
         p-value, a notion of distance between the trained classifier's out-of-fold performance \
         and that which we'd expect under the null assumption of no drift, \
-        and the out-of-fold classifier model prediction probabilities on the reference and test data
+        and the out-of-fold classifier model prediction probabilities on the reference and test data \
+        as well as the associated reference and test instances of the out-of-fold predictions.
         """
         x_ref, x = self.preprocess(x)
         x, y, splits = self.get_splits(x_ref, x)  # type: ignore
@@ -202,5 +204,12 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
         n_cur = y_oof.sum()
         n_ref = len(y_oof) - n_cur
         p_val, dist = self.test_probs(y_oof, probs_oof, n_ref, n_cur)
-        probs_sort = probs_oof[np.argsort(idx_oof)]
-        return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1]
+        idx_sort = np.argsort(idx_oof)
+        probs_sort = probs_oof[idx_sort]
+        if isinstance(x, np.ndarray):
+            x_oof = x[idx_oof]
+            x_sort = x_oof[idx_sort]
+        else:
+            x_oof = [x[_] for _ in idx_oof]
+            x_sort = [x_oof[_] for _ in idx_sort]
+        return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1], x_sort[:n_ref], x_sort[n_ref:]
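
The reordering added above hinges on `np.argsort(idx_oof)` being the inverse of the shuffling permutation, so the probabilities and instances end up aligned in the original concatenated order. A toy, self-contained illustration (not detector code; all names here are made up):

```python
import numpy as np

x = np.array([10., 11., 12., 13., 14.])  # concatenated [x_ref, x] in original order
idx_oof = np.array([3, 0, 4, 1, 2])      # order in which OOF predictions were produced
probs_oof = x[idx_oof] / 100.            # dummy "probabilities", one per OOF instance

idx_sort = np.argsort(idx_oof)           # inverse permutation
probs_sort = probs_oof[idx_sort]
x_sort = x[idx_oof][idx_sort]

assert np.allclose(x_sort, x)            # instances restored to original order
assert np.allclose(probs_sort, x / 100.) # probabilities follow the same reordering
```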

alibi_detect/cd/pytorch/spot_the_diff.py (+2 -1)
@@ -219,7 +219,8 @@ def predict(
         'data' contains the drift prediction, the diffs used to distinguish reference from test instances,
         and optionally the p-value, performance of the classifier relative to its expectation under the
         no-change null, the out-of-fold classifier model prediction probabilities on the reference and test
-        data, and the trained model.
+        data as well as the associated reference and test instances of the out-of-fold predictions,
+        and the trained model.
         """
         preds = self._detector.predict(x, return_p_val, return_distance, return_probs, return_model=True)
         preds['data']['diffs'] = preds['data']['model'].diffs.detach().cpu().numpy()  # type: ignore

alibi_detect/cd/sklearn/classifier.py (+29 -12)
@@ -229,7 +229,8 @@ def predict_proba(self, X):
 
         return model
 
-    def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
+    def score(self, x: Union[np.ndarray, list]) \
+            -> Tuple[float, float, np.ndarray, np.ndarray, Union[np.ndarray, list], Union[np.ndarray, list]]:
         """
         Compute the out-of-fold drift metric such as the accuracy from a classifier
         trained to distinguish the reference data from the data to be tested.
@@ -243,14 +244,16 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
         -------
         p-value, a notion of distance between the trained classifier's out-of-fold performance \
         and that which we'd expect under the null assumption of no drift, \
-        and the out-of-fold classifier model prediction probabilities on the reference and test data
+        and the out-of-fold classifier model prediction probabilities on the reference and test data \
+        as well as the associated reference and test instances of the out-of-fold predictions.
         """
         if self.use_oob and isinstance(self.model, RandomForestClassifier):
             return self._score_rf(x)
 
         return self._score(x)
 
-    def _score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
+    def _score(self, x: Union[np.ndarray, list]) \
+            -> Tuple[float, float, np.ndarray, np.ndarray, Union[np.ndarray, list], Union[np.ndarray, list]]:
         x_ref, x = self.preprocess(x)
         x, y, splits = self.get_splits(x_ref, x, return_splits=True)  # type: ignore
 
@@ -274,20 +277,34 @@ def _score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
         n_cur = y_oof.sum()
         n_ref = len(y_oof) - n_cur
         p_val, dist = self.test_probs(y_oof, probs_oof, n_ref, n_cur)
-        probs_sort = probs_oof[np.argsort(idx_oof)]
-        return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1]
+        idx_sort = np.argsort(idx_oof)
+        probs_sort = probs_oof[idx_sort]
+        if isinstance(x, np.ndarray):
+            x_oof = x[idx_oof]
+            x_sort = x_oof[idx_sort]
+        else:
+            x_oof = [x[_] for _ in idx_oof]
+            x_sort = [x_oof[_] for _ in idx_sort]
+        return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1], x_sort[:n_ref], x_sort[n_ref:]
 
-    def _score_rf(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
+    def _score_rf(self, x: Union[np.ndarray, list]) \
+            -> Tuple[float, float, np.ndarray, np.ndarray, Union[np.ndarray, list], Union[np.ndarray, list]]:
         x_ref, x = self.preprocess(x)
         x, y = self.get_splits(x_ref, x, return_splits=False)  # type: ignore
         self.model.fit(x, y)
         # it is possible that some inputs do not have OOB scores. This probably means
         # that too few trees were used to compute any reliable estimates.
-        index_oob = np.where(np.all(~np.isnan(self.model.oob_decision_function_), axis=1))[0]
-        probs_oob = self.model.oob_decision_function_[index_oob]
-        y_oob = y[index_oob]
+        idx_oob = np.where(np.all(~np.isnan(self.model.oob_decision_function_), axis=1))[0]
+        probs_oob = self.model.oob_decision_function_[idx_oob]
+        y_oob = y[idx_oob]
+        if isinstance(x, np.ndarray):
+            x_oob = x[idx_oob]
+        elif isinstance(x, list):
+            x_oob = [x[_] for _ in idx_oob]
+        else:
+            raise TypeError(f'x needs to be of type np.ndarray or list and not {type(x)}.')
         # comparison due to ordering in get_split (i.e., x = [x_ref, x])
-        n_ref = np.sum(index_oob < len(x_ref)).item()
-        n_cur = np.sum(index_oob >= len(x_ref)).item()
+        n_ref = np.sum(idx_oob < len(x_ref)).item()
+        n_cur = np.sum(idx_oob >= len(x_ref)).item()
         p_val, dist = self.test_probs(y_oob, probs_oob, n_ref, n_cur)
-        return p_val, dist, probs_oob[:n_ref, 1], probs_oob[n_ref:, 1]
+        return p_val, dist, probs_oob[:n_ref, 1], probs_oob[n_ref:, 1], x_oob[:n_ref], x_oob[n_ref:]
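
For the out-of-bag path, scikit-learn's `oob_decision_function_` contains `NaN` rows for instances that were in-bag in every tree, which is why the `idx_oob` filter above is needed. A small standalone sketch of that filtering with hypothetical data (a real `RandomForestClassifier` with few trees will typically leave some rows as `NaN` and warn about it):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 5))                                # hypothetical concatenated data
y = np.concatenate([np.zeros(50), np.ones(50)]).astype(int)  # 0 = reference, 1 = test

# few trees on purpose: some instances may never be out-of-bag
model = RandomForestClassifier(n_estimators=5, oob_score=True, bootstrap=True)
model.fit(x, y)

oob = model.oob_decision_function_                  # shape (100, 2), may contain NaN rows
idx_oob = np.where(np.all(~np.isnan(oob), axis=1))[0]
probs_oob, x_oob = oob[idx_oob], x[idx_oob]         # aligned probabilities and instances
assert len(probs_oob) == len(x_oob)
```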

alibi_detect/cd/spot_the_diff.py (+2 -1)
@@ -181,6 +181,7 @@ def predict(
         'data' contains the drift prediction, the diffs used to distinguish reference from test instances,
         and optionally the p-value, performance of the classifier relative to its expectation under the
         no-change null, the out-of-fold classifier model prediction probabilities on the reference and test
-        data, and the trained model.
+        data as well as the associated reference and test instances of the out-of-fold predictions,
+        and the trained model.
         """
         return self._detector.predict(x, return_p_val, return_distance, return_probs, return_model)

alibi_detect/cd/tensorflow/classifier.py (+14 -5)
@@ -3,7 +3,7 @@
 import tensorflow as tf
 from tensorflow.keras.losses import BinaryCrossentropy
 from scipy.special import softmax
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Union
 from alibi_detect.cd.base import BaseClassifierDrift
 from alibi_detect.models.tensorflow.trainer import trainer
 from alibi_detect.utils.tensorflow.data import TFDataset
@@ -144,7 +144,8 @@ def __init__(
         if isinstance(train_kwargs, dict):
             self.train_kwargs.update(train_kwargs)
 
-    def score(self, x: np.ndarray) -> Tuple[float, float, np.ndarray, np.ndarray]:  # type: ignore[override]
+    def score(self, x: np.ndarray) -> Tuple[float, float, np.ndarray, np.ndarray,  # type: ignore[override]
+                                            Union[np.ndarray, list], Union[np.ndarray, list]]:
         """
         Compute the out-of-fold drift metric such as the accuracy from a classifier
         trained to distinguish the reference data from the data to be tested.
@@ -158,7 +159,8 @@ def score(self, x: np.ndarray) -> Tuple[float, float, np.ndarray, np.ndarray]:
         -------
         p-value, a notion of distance between the trained classifier's out-of-fold performance \
         and that which we'd expect under the null assumption of no drift, \
-        and the out-of-fold classifier model prediction probabilities on the reference and test data
+        and the out-of-fold classifier model prediction probabilities on the reference and test data \
+        as well as the associated reference and test instances of the out-of-fold predictions.
         """
         x_ref, x = self.preprocess(x)  # type: ignore[assignment]
         x, y, splits = self.get_splits(x_ref, x)  # type: ignore
@@ -189,5 +191,12 @@ def score(self, x: np.ndarray) -> Tuple[float, float, np.ndarray, np.ndarray]:
         n_cur = y_oof.sum()
         n_ref = len(y_oof) - n_cur
         p_val, dist = self.test_probs(y_oof, probs_oof, n_ref, n_cur)
-        probs_sort = probs_oof[np.argsort(idx_oof)]
-        return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1]
+        idx_sort = np.argsort(idx_oof)
+        probs_sort = probs_oof[idx_sort]
+        if isinstance(x, np.ndarray):
+            x_oof = x[idx_oof]
+            x_sort = x_oof[idx_sort]
+        else:
+            x_oof = [x[_] for _ in idx_oof]
+            x_sort = [x_oof[_] for _ in idx_sort]
+        return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1], x_sort[:n_ref], x_sort[n_ref:]

alibi_detect/cd/tensorflow/spot_the_diff.py (+2 -1)
@@ -215,7 +215,8 @@ def predict(
         'data' contains the drift prediction, the diffs used to distinguish reference from test instances,
         and optionally the p-value, performance of the classifier relative to its expectation under the
         no-change null, the out-of-fold classifier model prediction probabilities on the reference and test
-        data, and the trained model.
+        data as well as the associated reference and test instances of the out-of-fold predictions,
+        and the trained model.
         """
         preds = self._detector.predict(x, return_p_val, return_distance, return_probs, return_model=True)
         preds['data']['diffs'] = preds['data']['model'].diffs.numpy()  # type: ignore

doc/source/cd/methods/classifierdrift.ipynb (+4 -1)
@@ -139,7 +139,7 @@
    "source": [
     "### Detect Drift\n",
     "\n",
-    "We detect data drift by simply calling `predict` on a batch of instances `x`. `return_p_val` equal to *True* will also return the p-value of the test, `return_distance` equal to *True* will return a notion of strength of the drift and `return_probs` equals *True* also returns the out-of-fold classifier model prediction probabilities on the reference and test data (0 = reference data, 1 = test data).\n",
+    "We detect data drift by simply calling `predict` on a batch of instances `x`. `return_p_val` equal to *True* will also return the p-value of the test, `return_distance` equal to *True* will return a notion of strength of the drift and `return_probs` equals *True* also returns the out-of-fold classifier model prediction probabilities on the reference and test data (0 = reference data, 1 = test data) as well as the associated out-of-fold reference and test instances.\n",
     "\n",
     "The prediction takes the form of a dictionary with `meta` and `data` keys. `meta` contains the detector's metadata while `data` is also a dictionary which contains the actual predictions stored in the following keys:\n",
     "\n",
@@ -155,6 +155,9 @@
     "\n",
     "* `probs_test`: the instance level prediction probability for the test data `x` if `return_probs` is *True*.\n",
     "\n",
+    "* `x_ref_oof`: the instances associated with `probs_ref` if `return_probs` equals *True*.\n",
+    "\n",
+    "* `x_test_oof`: the instances associated with `probs_test` if `return_probs` equals *True*.\n",
     "\n",
     "```python\n",
     "preds = cd.predict(x)\n",

doc/source/cd/methods/spotthediffdrift.ipynb (+21 -17)
@@ -53,7 +53,7 @@
     "* `initial_diffs`: Array used to initialise the diffs that will be learned. Defaults to Gaussian for each feature with equal variance to that of reference data.\n",
     "\n",
     "* `l1_reg`: Strength of l1 regularisation to apply to the differences.\n",
-    " \n",
+    "\n",
     "* `binarize_preds`: Whether to test for discrepancy on soft (e.g. probs/logits) model predictions directly with a K-S test or binarise to 0-1 prediction errors and apply a binomial test.\n",
     "\n",
     "* `train_size`: Optional fraction (float between 0 and 1) of the dataset used to train the classifier. The drift is detected on *1 - train_size*. Cannot be used in combination with `n_folds`.\n",
@@ -109,12 +109,12 @@
     "from alibi_detect.cd import SpotTheDiffDrift\n",
     "\n",
     "cd = SpotTheDiffDrift(\n",
-    "    x_ref, \n",
-    "    backend='pytorch', \n",
-    "    p_val=.05, \n",
-    "    n_diffs=1, \n",
-    "    l1_reg=1e-3, \n",
-    "    epochs=10, \n",
+    "    x_ref,\n",
+    "    backend='pytorch',\n",
+    "    p_val=.05,\n",
+    "    n_diffs=1,\n",
+    "    l1_reg=1e-3,\n",
+    "    epochs=10,\n",
     "    batch_size=32\n",
     ")\n",
     "\n",
@@ -143,13 +143,13 @@
     "\n",
     "# instantiate the detector\n",
     "cd = SpotTheDiffDrift(\n",
-    "    x_ref, \n",
-    "    backend='tensorflow', \n",
-    "    p_val=.05, \n",
-    "    kernel=kernel, \n",
-    "    n_diffs=1, \n",
-    "    l1_reg=1e-3, \n",
-    "    epochs=10, \n",
+    "    x_ref,\n",
+    "    backend='tensorflow',\n",
+    "    p_val=.05,\n",
+    "    kernel=kernel,\n",
+    "    n_diffs=1,\n",
+    "    l1_reg=1e-3,\n",
+    "    epochs=10,\n",
     "    batch_size=32\n",
     ")\n",
     "```"
@@ -161,7 +161,7 @@
    "source": [
     "### Detect Drift\n",
     "\n",
-    "We detect data drift by simply calling `predict` on a batch of instances `x`. `return_p_val` equal to *True* will also return the p-value of the test, `return_distance` equal to *True* will return a notion of strength of the drift, `return_probs` equals *True* returns the out-of-fold classifier model prediction probabilities on the reference and test data (0 = reference data, 1 = test data) and `return_kernel` equals *True* will also return the trained kernel.\n",
+    "We detect data drift by simply calling `predict` on a batch of instances `x`. `return_p_val` equal to *True* will also return the p-value of the test, `return_distance` equal to *True* will return a notion of strength of the drift, `return_probs` equals *True* returns the out-of-fold classifier model prediction probabilities on the reference and test data (0 = reference data, 1 = test data) as well as the associated out-of-fold reference and test instances, and `return_kernel` equals *True* will also return the trained kernel.\n",
     "\n",
     "The prediction takes the form of a dictionary with `meta` and `data` keys. `meta` contains the detector's metadata while `data` is also a dictionary which contains the actual predictions stored in the following keys:\n",
     "\n",
@@ -181,6 +181,10 @@
     "\n",
     "* `probs_test`: the instance level prediction probability for the test data `x` if `return_probs` is *True*.\n",
     "\n",
+    "* `x_ref_oof`: the instances associated with `probs_ref` if `return_probs` equals *True*.\n",
+    "\n",
+    "* `x_test_oof`: the instances associated with `probs_test` if `return_probs` equals *True*.\n",
+    "\n",
     "* `kernel`: The trained kernel if `return_kernel` equals *True*.\n",
     "\n",
     "\n",
@@ -201,7 +205,7 @@
   ],
   "metadata": {
    "kernelspec": {
-    "display_name": "Python 3 (ipykernel)",
+    "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
@@ -215,7 +219,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.11"
+   "version": "3.7.6"
   }
  },
 "nbformat": 4,
