Skip to content

Commit a7c883d

Browse files
committed
Feature Optional backends functionality (#538)
* Add BackendValidator class * Protect ad, od and cd and other API objects from tensorflow and torch optional dependency import errors * Update import statements in notebooks
1 parent 642de25 commit a7c883d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+701
-242
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ jobs:
5454
if [ "$RUNNER_OS" != "Windows" ] && [ ${{ matrix.python }} < '3.10' ]; then # Skip Prophet tests on Windows as installation complex. Skip on Python 3.10 as not supported.
5555
python -m pip install --upgrade --upgrade-strategy eager -e .[prophet]
5656
fi
57-
python -m pip install --upgrade --upgrade-strategy eager -e .[torch]
57+
python -m pip install --upgrade --upgrade-strategy eager -e .[tensorflow,torch]
5858
python -m pip freeze
5959
6060
- name: Lint with flake8

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ We will use the [VAE outlier detector](https://docs.seldon.io/projects/alibi-det
117117

118118
```python
119119
from alibi_detect.od import OutlierVAE
120-
from alibi_detect.utils import save_detector, load_detector
120+
from alibi_detect.utils.saving import save_detector, load_detector
121121

122122
# initialize and fit detector
123123
od = OutlierVAE(threshold=0.1, encoder_net=encoder_net, decoder_net=decoder_net, latent_dim=1024)

alibi_detect/ad/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
from .adversarialae import AdversarialAE
2-
from .model_distillation import ModelDistillation
1+
from alibi_detect.utils.missing_optional_dependency import import_optional
2+
3+
AdversarialAE = import_optional('alibi_detect.ad.adversarialae', names=['AdversarialAE'])
4+
ModelDistillation = import_optional('alibi_detect.ad.model_distillation', names=['ModelDistillation'])
35

46
__all__ = [
57
"AdversarialAE",

alibi_detect/cd/classifier.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from typing import Callable, Dict, Optional, Union
3-
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, has_sklearn
3+
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, has_sklearn, \
4+
BackendValidator
45
from alibi_detect.base import DriftConfigMixin
56

67
if has_sklearn:
@@ -147,12 +148,12 @@ def __init__(
147148
self._set_config(locals())
148149

149150
backend = backend.lower()
150-
if (backend == 'tensorflow' and not has_tensorflow) or (backend == 'pytorch' and not has_pytorch) or \
151-
(backend == 'sklearn' and not has_sklearn):
152-
raise ImportError(f'{backend} not installed. Cannot initialize and run the '
153-
f'ClassifierDrift detector with {backend} backend.')
154-
elif backend not in ['tensorflow', 'pytorch', 'sklearn']:
155-
raise NotImplementedError(f"{backend} not implemented. Use 'tensorflow', 'pytorch' or 'sklearn' instead.")
151+
BackendValidator(
152+
backend_options={'tensorflow': ['tensorflow'],
153+
'pytorch': ['pytorch'],
154+
'sklearn': ['sklearn']},
155+
construct_name='ClassifierDrift'
156+
).verify_backend(backend)
156157

157158
kwargs = locals()
158159
args = [kwargs['x_ref'], kwargs['model']]

alibi_detect/cd/context_aware.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import numpy as np
33
from typing import Callable, Dict, Optional, Union, Tuple
4-
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
4+
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator
55
from alibi_detect.utils.warnings import deprecated_alias
66
from alibi_detect.base import DriftConfigMixin
77

@@ -92,11 +92,9 @@ def __init__(
9292
self._set_config(locals())
9393

9494
backend = backend.lower()
95-
if backend == 'tensorflow' and not has_tensorflow or backend == 'pytorch' and not has_pytorch:
96-
raise ImportError(f'{backend} not installed. Cannot initialize and run the '
97-
f'ContextMMMDrift detector with {backend} backend.')
98-
elif backend not in ['tensorflow', 'pytorch']:
99-
raise NotImplementedError(f'{backend} not implemented. Use tensorflow or pytorch instead.')
95+
BackendValidator(backend_options={'tensorflow': ['tensorflow'],
96+
'pytorch': ['pytorch']},
97+
construct_name='ContextMMDDrift').verify_backend(backend)
10098

10199
kwargs = locals()
102100
args = [kwargs['x_ref'], kwargs['c_ref']]

alibi_detect/cd/learned_kernel.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from typing import Callable, Dict, Optional, Union
3-
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
3+
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator
44
from alibi_detect.utils.warnings import deprecated_alias
55
from alibi_detect.base import DriftConfigMixin
66

@@ -121,11 +121,11 @@ def __init__(
121121
self._set_config(locals())
122122

123123
backend = backend.lower()
124-
if backend == 'tensorflow' and not has_tensorflow or backend == 'pytorch' and not has_pytorch:
125-
raise ImportError(f'{backend} not installed. Cannot initialize and run the '
126-
f'LearnedKernel detector with {backend} backend.')
127-
elif backend not in ['tensorflow', 'pytorch']:
128-
raise NotImplementedError(f'{backend} not implemented. Use tensorflow or pytorch instead.')
124+
BackendValidator(
125+
backend_options={'tensorflow': ['tensorflow'],
126+
'pytorch': ['pytorch']},
127+
construct_name='LearnedKernelDrift'
128+
).verify_backend(backend)
129129

130130
kwargs = locals()
131131
args = [kwargs['x_ref'], kwargs['kernel']]

alibi_detect/cd/lsdd.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from typing import Callable, Dict, Optional, Union, Tuple
3-
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
3+
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator
44
from alibi_detect.utils.warnings import deprecated_alias
55
from alibi_detect.base import DriftConfigMixin
66

@@ -81,11 +81,11 @@ def __init__(
8181
self._set_config(locals())
8282

8383
backend = backend.lower()
84-
if backend == 'tensorflow' and not has_tensorflow or backend == 'pytorch' and not has_pytorch:
85-
raise ImportError(f'{backend} not installed. Cannot initialize and run the '
86-
f'LSDDDrift detector with {backend} backend.')
87-
elif backend not in ['tensorflow', 'pytorch']:
88-
raise NotImplementedError(f'{backend} not implemented. Use tensorflow or pytorch instead.')
84+
BackendValidator(
85+
backend_options={'tensorflow': ['tensorflow'],
86+
'pytorch': ['pytorch']},
87+
construct_name='LSDDDrift'
88+
).verify_backend(backend)
8989

9090
kwargs = locals()
9191
args = [kwargs['x_ref']]

alibi_detect/cd/lsdd_online.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from typing import Any, Callable, Dict, Optional, Union
3-
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
3+
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator
44
from alibi_detect.base import DriftConfigMixin
55
if has_pytorch:
66
from alibi_detect.cd.pytorch.lsdd_online import LSDDDriftOnlineTorch
@@ -82,11 +82,11 @@ def __init__(
8282
self._set_config(locals())
8383

8484
backend = backend.lower()
85-
if backend == 'tensorflow' and not has_tensorflow or backend == 'pytorch' and not has_pytorch:
86-
raise ImportError(f'{backend} not installed. Cannot initialize and run the '
87-
f'MMDDrift detector with {backend} backend.')
88-
elif backend not in ['tensorflow', 'pytorch']:
89-
raise NotImplementedError(f'{backend} not implemented. Use tensorflow or pytorch instead.')
85+
BackendValidator(
86+
backend_options={'tensorflow': ['tensorflow'],
87+
'pytorch': ['pytorch']},
88+
construct_name='LSDDDriftOnline'
89+
).verify_backend(backend)
9090

9191
kwargs = locals()
9292
args = [kwargs['x_ref'], kwargs['ert'], kwargs['window_size']]

alibi_detect/cd/mmd.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import numpy as np
33
from typing import Callable, Dict, Optional, Union, Tuple
4-
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
4+
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator
55
from alibi_detect.utils.warnings import deprecated_alias
66
from alibi_detect.base import DriftConfigMixin
77

@@ -80,11 +80,11 @@ def __init__(
8080
self._set_config(locals())
8181

8282
backend = backend.lower()
83-
if backend == 'tensorflow' and not has_tensorflow or backend == 'pytorch' and not has_pytorch:
84-
raise ImportError(f'{backend} not installed. Cannot initialize and run the '
85-
f'MMDDrift detector with {backend} backend.')
86-
elif backend not in ['tensorflow', 'pytorch']:
87-
raise NotImplementedError(f'{backend} not implemented. Use tensorflow or pytorch instead.')
83+
BackendValidator(
84+
backend_options={'tensorflow': ['tensorflow'],
85+
'pytorch': ['pytorch']},
86+
construct_name='MMDDrift'
87+
).verify_backend(backend)
8888

8989
kwargs = locals()
9090
args = [kwargs['x_ref']]

alibi_detect/cd/mmd_online.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from typing import Any, Callable, Dict, Optional, Union
3-
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
3+
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator
44
from alibi_detect.base import DriftConfigMixin
55

66
if has_pytorch:
@@ -75,11 +75,11 @@ def __init__(
7575
self._set_config(locals())
7676

7777
backend = backend.lower()
78-
if backend == 'tensorflow' and not has_tensorflow or backend == 'pytorch' and not has_pytorch:
79-
raise ImportError(f'{backend} not installed. Cannot initialize and run the '
80-
f'MMDDrift detector with {backend} backend.')
81-
elif backend not in ['tensorflow', 'pytorch']:
82-
raise NotImplementedError(f'{backend} not implemented. Use tensorflow or pytorch instead.')
78+
BackendValidator(
79+
backend_options={'tensorflow': ['tensorflow'],
80+
'pytorch': ['pytorch']},
81+
construct_name='MMDDriftOnline'
82+
).verify_backend(backend)
8383

8484
kwargs = locals()
8585
args = [kwargs['x_ref'], kwargs['ert'], kwargs['window_size']]

alibi_detect/cd/model_uncertainty.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from alibi_detect.cd.chisquare import ChiSquareDrift
77
from alibi_detect.cd.preprocess import classifier_uncertainty, regressor_uncertainty
88
from alibi_detect.cd.utils import encompass_batching, encompass_shuffling_and_batch_filling
9-
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
9+
from alibi_detect.utils.frameworks import BackendValidator
1010
from alibi_detect.base import DriftConfigMixin
1111

1212
logger = logging.getLogger(__name__)
@@ -83,9 +83,12 @@ def __init__(
8383
# Set config
8484
self._set_config(locals())
8585

86-
if backend == 'tensorflow' and not has_tensorflow or backend == 'pytorch' and not has_pytorch:
87-
raise ImportError(f'{backend} not installed. Cannot initialize and run the '
88-
f'ClassifierUncertaintyDrift detector with {backend} backend.')
86+
if backend:
87+
backend = backend.lower()
88+
BackendValidator(backend_options={'tensorflow': ['tensorflow'],
89+
'pytorch': ['pytorch'],
90+
None: []},
91+
construct_name='ClassifierUncertaintyDrift').verify_backend(backend)
8992

9093
if backend is None:
9194
if device not in [None, 'cpu']:
@@ -233,9 +236,12 @@ def __init__(
233236
# Set config
234237
self._set_config(locals())
235238

236-
if backend == 'tensorflow' and not has_tensorflow or backend == 'pytorch' and not has_pytorch:
237-
raise ImportError(f'{backend} not installed. Cannot initialize and run the '
238-
f'RegressorUncertaintyDrift detector with {backend} backend.')
239+
if backend:
240+
backend = backend.lower()
241+
BackendValidator(backend_options={'tensorflow': ['tensorflow'],
242+
'pytorch': ['pytorch'],
243+
None: []},
244+
construct_name='RegressorUncertaintyDrift').verify_backend(backend)
239245

240246
if backend is None:
241247
model_fn = model

alibi_detect/cd/pytorch/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
from .preprocess import HiddenOutput, preprocess_drift
1+
from alibi_detect.utils.missing_optional_dependency import import_optional
2+
3+
HiddenOutput, preprocess_drift = import_optional(
4+
'alibi_detect.cd.pytorch.preprocess',
5+
names=['HiddenOutput', 'preprocess_drift'])
26

37
__all__ = [
48
"HiddenOutput",

alibi_detect/cd/pytorch/context_aware.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -274,5 +274,5 @@ def _sigma_median_diag(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) ->
274274
The computed bandwidth, `sigma`.
275275
"""
276276
n_median = np.prod(dist.shape) // 2
277-
sigma = (.5 * dist.flatten().sort().values[n_median].unsqueeze(dim=-1)) ** .5
277+
sigma = (.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** .5
278278
return sigma

alibi_detect/cd/spot_the_diff.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from typing import Callable, Dict, Optional, Union
3-
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
3+
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator
44
from alibi_detect.base import DriftConfigMixin
55

66
if has_pytorch:
@@ -126,12 +126,9 @@ def __init__(
126126
self._set_config(locals())
127127

128128
backend = backend.lower()
129-
if backend == 'tensorflow' and not has_tensorflow or backend == 'pytorch' and not has_pytorch:
130-
raise ImportError(f'{backend} not installed. Cannot initialize and run the '
131-
f'SpotTheDiffDrift detector with {backend} backend.')
132-
elif backend not in ['tensorflow', 'pytorch']:
133-
raise NotImplementedError(f'{backend} not implemented. Use tensorflow or pytorch instead.')
134-
129+
BackendValidator(backend_options={'tensorflow': ['tensorflow'],
130+
'pytorch': ['pytorch']},
131+
construct_name='SpotTheDiffDrift').verify_backend(backend)
135132
kwargs = locals()
136133
args = [kwargs['x_ref']]
137134
pop_kwargs = ['self', 'x_ref', 'backend', '__class__']

alibi_detect/cd/tensorflow/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from .preprocess import HiddenOutput, UAE, preprocess_drift
1+
from alibi_detect.utils.missing_optional_dependency import import_optional
2+
3+
HiddenOutput, UAE, preprocess_drift = import_optional(
4+
'alibi_detect.cd.tensorflow.preprocess',
5+
names=['HiddenOutput', 'UAE', 'preprocess_drift']
6+
)
27

38
__all__ = [
49
"HiddenOutput",

alibi_detect/datasets.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pandas as pd
1010
import requests
1111
from alibi_detect.utils.data import Bunch
12-
from alibi_detect.utils.fetching import _join_url
12+
from alibi_detect.utils.url import _join_url
1313
from requests import RequestException
1414
from scipy.io import arff
1515
from sklearn.datasets import fetch_kddcup99

alibi_detect/models/pytorch/__init__.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1-
from .embedding import TransformerEmbedding
2-
from .trainer import trainer
1+
from alibi_detect.utils.missing_optional_dependency import import_optional
2+
3+
TransformerEmbedding = import_optional(
4+
'alibi_detect.models.pytorch.embedding',
5+
names=['TransformerEmbedding'])
6+
7+
trainer = import_optional(
8+
'alibi_detect.models.pytorch.trainer',
9+
names=['trainer'])
10+
311

412
__all__ = [
513
"TransformerEmbedding",
+36-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,31 @@
1-
from .autoencoder import AE, AEGMM, VAE, VAEGMM, Seq2Seq
2-
from .embedding import TransformerEmbedding
3-
from .pixelcnn import PixelCNN
4-
from .resnet import resnet
5-
from .trainer import trainer
1+
from alibi_detect.utils.missing_optional_dependency import import_optional
2+
3+
4+
AE, AEGMM, VAE, VAEGMM, Seq2Seq, eucl_cosim_features = import_optional(
5+
'alibi_detect.models.tensorflow.autoencoder',
6+
names=['AE', 'AEGMM', 'VAE', 'VAEGMM', 'Seq2Seq', 'eucl_cosim_features'])
7+
8+
TransformerEmbedding = import_optional(
9+
'alibi_detect.models.tensorflow.embedding',
10+
names=['TransformerEmbedding'])
11+
12+
PixelCNN = import_optional(
13+
'alibi_detect.models.tensorflow.pixelcnn',
14+
names=['PixelCNN'])
15+
16+
resnet, scale_by_instance = import_optional(
17+
'alibi_detect.models.tensorflow.resnet',
18+
names=['resnet', 'scale_by_instance'])
19+
20+
trainer = import_optional(
21+
'alibi_detect.models.tensorflow.trainer',
22+
names=['trainer'])
23+
24+
loss_aegmm, loss_adv_ae, loss_distillation, elbo, loss_vaegmm = import_optional(
25+
'alibi_detect.models.tensorflow.losses',
26+
names=['loss_aegmm', 'loss_adv_ae', 'loss_distillation', 'elbo', 'loss_vaegmm']
27+
)
28+
629

730
__all__ = [
831
"AE",
@@ -11,7 +34,14 @@
1134
"VAE",
1235
"VAEGMM",
1336
"resnet",
37+
"scale_by_instance",
1438
"PixelCNN",
1539
"TransformerEmbedding",
16-
"trainer"
40+
"trainer",
41+
"eucl_cosim_features",
42+
"elbo",
43+
"loss_aegmm",
44+
"loss_vaegmm",
45+
"loss_adv_ae",
46+
"loss_distillation"
1747
]

alibi_detect/od/__init__.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
from .aegmm import OutlierAEGMM
1+
from alibi_detect.utils.missing_optional_dependency import import_optional
2+
23
from .isolationforest import IForest
34
from .mahalanobis import Mahalanobis
4-
from .ae import OutlierAE
5-
from .vae import OutlierVAE
6-
from .vaegmm import OutlierVAEGMM
75
from .prophet import PROPHET_INSTALLED, OutlierProphet
8-
from .seq2seq import OutlierSeq2Seq
96
from .sr import SpectralResidual
10-
from .llr import LLR
7+
8+
OutlierAEGMM = import_optional('alibi_detect.od.aegmm', names=['OutlierAEGMM'])
9+
OutlierAE = import_optional('alibi_detect.od.ae', names=['OutlierAE'])
10+
OutlierVAE = import_optional('alibi_detect.od.vae', names=['OutlierVAE'])
11+
OutlierVAEGMM = import_optional('alibi_detect.od.vaegmm', names=['OutlierVAEGMM'])
12+
OutlierSeq2Seq = import_optional('alibi_detect.od.seq2seq', names=['OutlierSeq2Seq'])
13+
LLR = import_optional('alibi_detect.od.llr', names=['LLR'])
1114

1215
__all__ = [
1316
"OutlierAEGMM",

0 commit comments

Comments
 (0)