
Commit 1898ad2

Author: Ashley Scillitoe
Support for serializing detectors with scikit-learn backends and/or models (#642)
1 parent b915d63 commit 1898ad2

14 files changed: +660 −403 lines changed

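In practice, the commit allows detectors that wrap scikit-learn models to be saved and reloaded with the standard saving utilities. A minimal, hedged sketch of the intended workflow (the detector, data and paths are illustrative and not taken from this diff):

import numpy as np
from sklearn.ensemble import RandomForestClassifier

from alibi_detect.cd import ClassifierDrift
from alibi_detect.saving import load_detector, save_detector

# Illustrative reference data
x_ref = np.random.randn(100, 5)

# A drift detector backed by a scikit-learn model
cd = ClassifierDrift(x_ref, RandomForestClassifier(), backend='sklearn', p_val=.05)

# With this change, the sklearn model is serialized (via joblib) alongside config.toml
save_detector(cd, 'my_detector/')
cd_loaded = load_detector('my_detector/')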
+7

@@ -0,0 +1,7 @@
from alibi_detect.saving._sklearn.saving import save_model_config as save_model_config_sk
from alibi_detect.saving._sklearn.loading import load_model as load_model_sk

__all__ = [
    "save_model_config_sk",
    "load_model_sk"
]
+26

@@ -0,0 +1,26 @@
import os
from pathlib import Path
from typing import Union

import joblib
from sklearn.base import BaseEstimator


def load_model(filepath: Union[str, os.PathLike],
               ) -> BaseEstimator:
    """
    Load scikit-learn (or xgboost) model. Models are assumed to be a subclass of :class:`~sklearn.base.BaseEstimator`.
    This includes xgboost models following the scikit-learn API
    (see https://xgboost.readthedocs.io/en/latest/python/python_api.html#module-xgboost.sklearn).

    Parameters
    ----------
    filepath
        Saved model directory.

    Returns
    -------
    Loaded model.
    """
    model_dir = Path(filepath)
    return joblib.load(model_dir.joinpath('model.joblib'))
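For orientation, a brief usage sketch of the loader above, assuming a directory that already contains a `model.joblib` in the expected layout (the paths and estimator are illustrative, not part of the commit):

from pathlib import Path

import joblib
from sklearn.linear_model import LogisticRegression

from alibi_detect.saving._sklearn import load_model_sk

# Illustrative setup: persist an estimator in the layout load_model expects
model_dir = Path('saved/model')
model_dir.mkdir(parents=True, exist_ok=True)
joblib.dump(LogisticRegression(), model_dir.joinpath('model.joblib'))

# load_model takes the directory containing model.joblib and returns the estimator
clf = load_model_sk('saved/model')
assert isinstance(clf, LogisticRegression)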
+68

@@ -0,0 +1,68 @@
import logging
import os
from pathlib import Path
from typing import Union

import joblib
from sklearn.base import BaseEstimator

logger = logging.getLogger(__name__)


def save_model_config(model: BaseEstimator,
                      base_path: Path,
                      local_path: Path = Path('.')) -> dict:
    """
    Save a scikit-learn (or xgboost) model to a config dictionary.
    Models are assumed to be a subclass of :class:`~sklearn.base.BaseEstimator`. This includes xgboost models
    following the scikit-learn API
    (see https://xgboost.readthedocs.io/en/latest/python/python_api.html#module-xgboost.sklearn).

    Parameters
    ----------
    model
        The model to save.
    base_path
        Base filepath to save to (the location of the `config.toml` file).
    local_path
        A local (relative) filepath to append to base_path.

    Returns
    -------
    The model config dict.
    """
    filepath = base_path.joinpath(local_path)
    save_model(model, filepath=filepath, save_dir='model')
    cfg_model = {
        'flavour': 'sklearn',
        'src': local_path.joinpath('model')
    }
    return cfg_model


def save_model(model: BaseEstimator,
               filepath: Union[str, os.PathLike],
               save_dir: Union[str, os.PathLike] = 'model') -> None:
    """
    Save scikit-learn (and xgboost) models. Models are assumed to be a subclass of :class:`~sklearn.base.BaseEstimator`.
    This includes xgboost models following the scikit-learn API
    (see https://xgboost.readthedocs.io/en/latest/python/python_api.html#module-xgboost.sklearn).

    Parameters
    ----------
    model
        The model to save.
    filepath
        Save directory.
    save_dir
        Name of folder to save to within the filepath directory.
    """
    # Create folder to save model in
    model_path = Path(filepath).joinpath(save_dir)
    if not model_path.is_dir():
        logger.warning('Directory {} does not exist and is now created.'.format(model_path))
        model_path.mkdir(parents=True, exist_ok=True)

    # Save model
    model_path = model_path.joinpath('model.joblib')
    joblib.dump(model, model_path)
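Taken together, a hedged round-trip sketch of the saving functions above with the loader from the previous file (the directory name and the unfitted LogisticRegression are illustrative assumptions):

from pathlib import Path

from sklearn.linear_model import LogisticRegression

from alibi_detect.saving._sklearn import load_model_sk, save_model_config_sk

base_path = Path('detector_dir')  # where a detector's config.toml would live

# Writes detector_dir/model/model.joblib and returns the model's config entry,
# e.g. {'flavour': 'sklearn', 'src': PosixPath('model')}
cfg_model = save_model_config_sk(LogisticRegression(), base_path=base_path)

# Loading expects the directory that contains model.joblib
clf = load_model_sk(base_path.joinpath('model'))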
@@ -0,0 +1,32 @@
from pytest_cases import param_fixture, parametrize, parametrize_with_cases

from alibi_detect.saving.tests.datasets import ContinuousData
from alibi_detect.saving.tests.models import classifier_model, xgb_classifier_model

from alibi_detect.saving.loading import _load_model_config
from alibi_detect.saving.saving import _path2str, _save_model_config
from alibi_detect.saving.schemas import ModelConfig

backend = param_fixture("backend", ['sklearn'])


@parametrize_with_cases("data", cases=ContinuousData.data_synthetic_nd, prefix='data_')
@parametrize('model', [classifier_model, xgb_classifier_model])
def test_save_model_sk(data, model, tmp_path):
    """
    Unit test for _save_model_config and _load_model_config with scikit-learn and xgboost models.
    """
    # Save model
    filepath = tmp_path
    cfg_model, _ = _save_model_config(model, base_path=filepath)
    cfg_model = _path2str(cfg_model)
    cfg_model = ModelConfig(**cfg_model).dict()
    assert tmp_path.joinpath('model').is_dir()
    assert tmp_path.joinpath('model/model.joblib').is_file()

    # Adjust config
    cfg_model['src'] = tmp_path.joinpath('model')  # Need to manually set to absolute path here

    # Load model
    model_load = _load_model_config(cfg_model)
    assert isinstance(model_load, type(model))

alibi_detect/saving/_tensorflow/loading.py

+7 −15

@@ -69,7 +69,7 @@ def load_model(filepath: Union[str, os.PathLike],
     return model
 
 
-def prep_model_and_emb(model: Optional[Callable], emb: Optional[TransformerEmbedding]) -> Callable:
+def prep_model_and_emb(model: Callable, emb: Optional[TransformerEmbedding]) -> Callable:
     """
     Function to perform final preprocessing of model (and/or embedding) before it is passed to preprocess_drift.
 
@@ -78,25 +78,17 @@ def prep_model_and_emb(model: Optional[Callable], emb: Optional[TransformerEmbed
     model
         A compatible model.
     emb
-        A text embedding model.
+        An optional text embedding model.
 
     Returns
     -------
     The final model, ready to be passed to preprocess_drift.
     """
-    # If a model exists, process it (and embedding)
-    if model is not None:
-        model = model.encoder if isinstance(model, UAE) else model  # Avoid nesting UAEs if model is already a UAE
-        if emb is not None:
-            model = _Encoder(emb, mlp=model)
-            model = UAE(encoder_net=model)
-    # If no model exists, store embedding as model
-    else:
-        model = emb
-    if model is None:
-        raise ValueError("A 'model' and/or `embedding` must be specified when "
-                         "preprocess_fn='preprocess_drift'")
-
+    # Process model (and embedding)
+    model = model.encoder if isinstance(model, UAE) else model  # Avoid nesting UAEs if model is already a UAE
+    if emb is not None:
+        model = _Encoder(emb, mlp=model)
+        model = UAE(encoder_net=model)
     return model
 
 
alibi_detect/saving/_tensorflow/saving.py

+10 −3

@@ -28,10 +28,10 @@
 
 def save_model_config(model: Callable,
                       base_path: Path,
-                      input_shape: tuple,
+                      input_shape: Optional[tuple],
                       local_path: Path = Path('.')) -> Tuple[dict, Optional[dict]]:
     """
-    Save a model to a config dictionary. When a model has a text embedding model contained within it,
+    Save a TensorFlow model to a config dictionary. When a model has a text embedding model contained within it,
     this is extracted and saved separately.
 
     Parameters
@@ -53,6 +53,9 @@ def save_model_config(model: Callable,
     cfg_embed = None  # type: Optional[Dict[str, Any]]
     if isinstance(model, UAE):
         if isinstance(model.encoder.layers[0], TransformerEmbedding):  # if UAE contains embedding and encoder
+            if input_shape is None:
+                raise ValueError('Cannot save combined embedding and model when `input_shape` is None.')
+
             # embedding
             embed = model.encoder.layers[0]
             cfg_embed = save_embedding_config(embed, base_path, local_path.joinpath('embedding'))
@@ -78,7 +81,10 @@ def save_model_config(model: Callable,
     if model is not None:
         filepath = base_path.joinpath(local_path)
         save_model(model, filepath=filepath, save_dir='model')
-        cfg_model = {'src': local_path.joinpath('model')}
+        cfg_model = {
+            'flavour': 'tensorflow',
+            'src': local_path.joinpath('model')
+        }
     return cfg_model, cfg_embed
 
 
@@ -142,6 +148,7 @@ def save_embedding_config(embed: TransformerEmbedding,
     cfg_embed.update({'type': embed.emb_type})
     cfg_embed.update({'layers': embed.hs_emb.keywords['layers']})
     cfg_embed.update({'src': local_path})
+    cfg_embed.update({'flavour': 'tensorflow'})
 
     # Save embedding model
     logger.info('Saving embedding model to {}.'.format(filepath))
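The net effect of these additions is that both the model and embedding entries in a saved config now carry a `flavour` field telling the loader which backend to use. A rough sketch of the dictionaries returned after this change (the field values below are illustrative assumptions, not output captured from the library):

# Sketch of the TensorFlow model config entry returned by save_model_config
cfg_model = {
    'flavour': 'tensorflow',
    'src': 'model',          # relative to the directory holding config.toml
}

# save_embedding_config now tags the embedding entry the same way
cfg_embed = {
    'type': 'hidden_state',  # illustrative embedding type
    'layers': [-1],          # illustrative layer selection
    'src': 'embedding',
    'flavour': 'tensorflow',
}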
@@ -0,0 +1,64 @@
from pytest_cases import param_fixture, parametrize, parametrize_with_cases

from alibi_detect.saving.tests.datasets import ContinuousData
from alibi_detect.saving.tests.models import encoder_model

from alibi_detect.cd.tensorflow import HiddenOutput as HiddenOutput_tf
from alibi_detect.saving.loading import _load_model_config, _load_optimizer_config
from alibi_detect.saving.saving import _path2str, _save_model_config
from alibi_detect.saving.schemas import ModelConfig

backend = param_fixture("backend", ['tensorflow'])


def test_load_optimizer_tf(backend):
    "Test the tensorflow _load_optimizer_config."
    class_name = 'Adam'
    learning_rate = 0.01
    epsilon = 1e-7
    amsgrad = False

    # Load
    cfg_opt = {
        'class_name': class_name,
        'config': {
            'name': class_name,
            'learning_rate': learning_rate,
            'epsilon': epsilon,
            'amsgrad': amsgrad
        }
    }
    optimizer = _load_optimizer_config(cfg_opt, backend=backend)
    assert type(optimizer).__name__ == class_name
    assert optimizer.learning_rate == learning_rate
    assert optimizer.epsilon == epsilon
    assert optimizer.amsgrad == amsgrad


@parametrize_with_cases("data", cases=ContinuousData.data_synthetic_nd, prefix='data_')
@parametrize('model', [encoder_model])
@parametrize('layer', [None, -1])
def test_save_model_tf(data, model, layer, tmp_path):
    """
    Unit test for _save_model_config and _load_model_config with a tensorflow model.
    """
    # Save model
    filepath = tmp_path
    input_shape = (data[0].shape[1],)
    cfg_model, _ = _save_model_config(model, base_path=filepath, input_shape=input_shape)
    cfg_model = _path2str(cfg_model)
    cfg_model = ModelConfig(**cfg_model).dict()
    assert tmp_path.joinpath('model').is_dir()
    assert tmp_path.joinpath('model/model.h5').is_file()

    # Adjust config
    cfg_model['src'] = tmp_path.joinpath('model')  # Need to manually set to absolute path here
    if layer is not None:
        cfg_model['layer'] = layer

    # Load model
    model_load = _load_model_config(cfg_model)
    if layer is None:
        assert isinstance(model_load, type(model))
    else:
        assert isinstance(model_load, HiddenOutput_tf)
