-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Add support and unit test for PyOD models #34709
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b838cee
d5f4ea6
d314f00
e84a284
3bd93c7
8f1f965
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import pickle | ||
from typing import Any | ||
from typing import Dict | ||
from typing import Iterable | ||
from typing import Optional | ||
from typing import Sequence | ||
|
||
import numpy as np | ||
|
||
import apache_beam as beam | ||
from apache_beam.io.filesystems import FileSystems | ||
from apache_beam.ml.anomaly.detectors.offline import OfflineDetector | ||
from apache_beam.ml.anomaly.specifiable import specifiable | ||
from apache_beam.ml.anomaly.thresholds import FixedThreshold | ||
from apache_beam.ml.inference.base import KeyedModelHandler | ||
from apache_beam.ml.inference.base import ModelHandler | ||
from apache_beam.ml.inference.base import PredictionResult | ||
from apache_beam.ml.inference.base import _PostProcessingModelHandler | ||
from apache_beam.ml.inference.utils import _convert_to_result | ||
from pyod.models.base import BaseDetector as PyODBaseDetector | ||
|
||
# Turn the used ModelHandler classes into specifiable ones, but without lazy
# init (both `on_demand_init` and `just_in_time_init` are disabled).
# NOTE(review): per the PR discussion, the wrapping is needed so that a spec
# can be created for the whole model handler object without pickling it; this
# is also why the private `_PostProcessingModelHandler` is imported here.
# The module-level names are deliberately rebound to the wrapped classes, so
# every later use in this module gets the specifiable version.
KeyedModelHandler = specifiable(  # type: ignore[misc]
    KeyedModelHandler,
    on_demand_init=False,
    just_in_time_init=False)
_PostProcessingModelHandler = specifiable(  # type: ignore[misc]
    _PostProcessingModelHandler,
    on_demand_init=False,
    just_in_time_init=False)
|
||
|
||
@specifiable
class PyODModelHandler(ModelHandler[beam.Row,
                                    PredictionResult,
                                    PyODBaseDetector]):
  """Implementation of the ModelHandler interface for PyOD [#]_ Models.

  The ModelHandler processes input data as `beam.Row` objects.

  **NOTE:** This API and its implementation are currently under active
  development and may not be backward compatible.

  Args:
    model_uri: The URI specifying the location of the pickled PyOD model.

  .. [#] https://github.com/yzhao062/pyod
  """
  def __init__(self, model_uri: str):
    self._model_uri = model_uri

  def load_model(self) -> PyODBaseDetector:
    """Loads and unpickles the PyOD model stored at `model_uri`.

    SECURITY: `pickle.load` can execute arbitrary code while deserializing.
    Only point `model_uri` at model files that come from a trusted source.

    Returns:
      The object unpickled from the file at `model_uri`.
    """
    # NOTE(review): the second positional parameter of `FileSystems.open`
    # appears to be `mime_type` rather than a file mode -- confirm whether
    # passing 'rb' here is intended. It is kept as-is to preserve behavior.
    # The context manager guarantees the handle is closed even if unpickling
    # raises.
    with FileSystems.open(self._model_uri, 'rb') as file:
      return pickle.load(file)

  def run_inference(
      self,
      batch: Sequence[beam.Row],
      model: PyODBaseDetector,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Scores a batch of rows with the model's `decision_function`.

    Args:
      batch: The rows to score. The values of each row are converted to a
        float64 vector in iteration order.
      model: The PyOD detector used for scoring.
      inference_args: Not used by this handler.

    Returns:
      One `PredictionResult` per input row, tagged with the model URI as the
      model id.
    """
    if not batch:
      # `np.stack` rejects an empty list, so short-circuit with no results.
      return []

    # Stack a batch of samples into a 2-D array for better performance.
    vectorized_batch = np.stack(
        [np.fromiter(row, dtype=np.float64) for row in batch], axis=0)
    predictions = model.decision_function(vectorized_batch)

    return _convert_to_result(batch, predictions, model_id=self._model_uri)
|
||
|
||
class PyODFactory():
  """Factory for building anomaly detectors backed by pickled PyOD models."""
  @staticmethod
  def create_detector(model_uri: str, **kwargs) -> OfflineDetector:
    """A utility function to create OfflineDetector for a PyOD model.

    The model is loaded once, eagerly, while the detector is being created so
    that its fitted `threshold_` attribute can be used as the fixed decision
    threshold of the detector.

    **NOTE:** This API and its implementation are currently under active
    development and may not be backward compatible.

    Args:
      model_uri: The URI specifying the location of the pickled PyOD model.
      **kwargs: Additional keyword arguments, forwarded unchanged to the
        `OfflineDetector` constructor.

    Returns:
      An `OfflineDetector` wrapping a keyed model handler for the model, with
      a `FixedThreshold` taken from the model's `threshold_`.

    Raises:
      TypeError: If the object loaded from `model_uri` is not a PyOD
        `BaseDetector`.
    """
    model_handler = KeyedModelHandler(
        PyODModelHandler(model_uri=model_uri)).with_postprocess_fn(
            OfflineDetector.score_prediction_adapter)
    m = model_handler.load_model()
    # Validate explicitly instead of using `assert`, which is stripped when
    # Python runs with -O and would let a wrong object slip through.
    if not isinstance(m, PyODBaseDetector):
      raise TypeError(
          'Expected a PyOD BaseDetector at %r, but loaded %s' %
          (model_uri, type(m).__name__))
    threshold = float(m.threshold_)
    detector = OfflineDetector(
        model_handler, threshold_criterion=FixedThreshold(threshold), **kwargs)  # type: ignore[arg-type]
    return detector
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,160 @@ | ||||
# | ||||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||||
# contributor license agreements. See the NOTICE file distributed with | ||||
# this work for additional information regarding copyright ownership. | ||||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||||
# (the "License"); you may not use this file except in compliance with | ||||
# the License. You may obtain a copy of the License at | ||||
# | ||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||
# | ||||
# Unless required by applicable law or agreed to in writing, software | ||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# See the License for the specific language governing permissions and | ||||
# limitations under the License. | ||||
# | ||||
|
||||
import logging | ||||
import os.path | ||||
import pickle | ||||
import shutil | ||||
import tempfile | ||||
import unittest | ||||
|
||||
import numpy as np | ||||
from parameterized import parameterized | ||||
|
||||
import apache_beam as beam | ||||
from apache_beam.ml.anomaly.base import AnomalyPrediction | ||||
from apache_beam.ml.anomaly.base import AnomalyResult | ||||
from apache_beam.ml.anomaly.transforms import AnomalyDetection | ||||
from apache_beam.ml.anomaly.transforms_test import _keyed_result_is_equal_to | ||||
from apache_beam.options.pipeline_options import PipelineOptions | ||||
from apache_beam.testing.util import assert_that | ||||
from apache_beam.testing.util import equal_to | ||||
|
||||
# Protect against environments where the PyOD library is not available. | ||||
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports | ||||
try: | ||||
from apache_beam.ml.anomaly.detectors.pyod_adapter import PyODFactory | ||||
from pyod.models.iforest import IForest | ||||
except ImportError: | ||||
raise unittest.SkipTest('PyOD dependencies are not installed') | ||||
|
||||
|
||||
class PyODIForestTest(unittest.TestCase):
  """End-to-end tests that score data with a pickled PyOD IForest model."""
  def setUp(self) -> None:
    # Fit a small, seeded IForest and pickle it into a scratch directory that
    # is removed again in tearDown.
    self.tmp_dir = tempfile.mkdtemp()

    iforest = IForest(random_state=1234)
    iforest.fit(self.get_train_data())
    self.pickled_model_uri = os.path.join(self.tmp_dir, 'iforest_pickled')

    with open(self.pickled_model_uri, 'wb') as model_file:
      pickle.dump(iforest, model_file)

  def tearDown(self) -> None:
    shutil.rmtree(self.tmp_dir)

  def get_train_data(self):
    """Returns two-feature training samples, one of which is an outlier."""
    samples = [
        [1, 5],
        [2, 6],
        [3, 4],
        [2, 6],
        [10, 10],  # need an outlier in training data
        [3, 4],
        [2, 6],
        [2, 6],
        [2, 5],
    ]
    return [np.array(sample, dtype="float32") for sample in samples]

  def get_test_data(self):
    """Returns two-feature samples to score."""
    return [np.array(sample, dtype="float32") for sample in ([2, 6], [100, 100])]

  def get_test_data_with_target(self):
    """Returns samples to score that carry an extra third (target) value."""
    samples = ([2, 6, 0], [100, 100, 1])
    return [np.array(sample, dtype="float32") for sample in samples]

  @parameterized.expand([True, False])
  def test_scoring_with_matched_features(self, with_target):
    if not with_target:
      rows = [beam.Row(a=2, b=6), beam.Row(a=100, b=100)]
      field_names = ["a", "b"]
      detector = PyODFactory.create_detector(self.pickled_model_uri)
      input_data = self.get_test_data()
    else:
      rows = [beam.Row(a=2, b=6, target=0), beam.Row(a=100, b=100, target=1)]
      field_names = ["a", "b", "target"]
      # The selected features should match the features used for training
      detector = PyODFactory.create_detector(
          self.pickled_model_uri, features=["a", "b"])
      input_data = self.get_test_data_with_target()

    # Pinned (score, label) pairs expected for the model fitted in setUp, in
    # the same order as `rows`; both share the same pinned threshold.
    pinned_threshold = 8.326672684688674e-17
    pinned_outcomes = [(-0.20316164744828075, 0), (0.179516865091218, 1)]
    expected_out = [(
        0,
        AnomalyResult(
            example=row,
            predictions=[
                AnomalyPrediction(
                    model_id='OfflineDetector',
                    score=score,
                    label=label,
                    threshold=pinned_threshold,
                    info='',
                    source_predictions=None)
            ])) for row, (score, label) in zip(rows, pinned_outcomes)]

    def to_row(values):
      # Pair every value (cast to int) with its field name.
      return beam.Row(**dict(zip(field_names, map(int, values))))

    with beam.Pipeline(options=PipelineOptions([])) as p:
      out = (
          p | beam.Create(input_data)
          | beam.Map(to_row)
          | beam.WithKeys(0)
          | AnomalyDetection(detector=detector))
      assert_that(out, equal_to(expected_out, _keyed_result_is_equal_to))

  def test_scoring_with_unmatched_features(self):
    # The model is trained with two features: a, b, but the input features of
    # scoring has one more feature (target).
    # In this case, we should either get rid of the extra feature(s) from
    # the scoring input or set `features` when creating the offline detector
    # (see the `test_scoring_with_matched_features`)
    detector = PyODFactory.create_detector(self.pickled_model_uri)
    column_names = ["a", "b", "target"]

    def to_row(values):
      return beam.Row(**dict(zip(column_names, map(int, values))))

    p = beam.Pipeline(options=PipelineOptions([]))
    _ = (
        p | beam.Create(self.get_test_data_with_target())
        | beam.Map(to_row)
        | beam.WithKeys(0)
        | AnomalyDetection(detector=detector))

    # This should raise a ValueError with message
    # "X has 3 features, but IsolationForest is expecting 2 features as input."
    with self.assertRaises(ValueError):
      p.run()
|
||||
|
||||
# Script entry point: run this module's unit tests, with the root logger
# limited to WARNING and above.
if __name__ == '__main__':
  logging.getLogger().setLevel(logging.WARNING)
  unittest.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we remove underscores from these (e.g. _PostProcessingModelHandler -> PostProcessingModelHandler) since they are being imported?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That one cannot be changed, because in the following code the class will have to be turned into specifiable so that we can create a spec (see https://github.com/apache/beam/blob/release-2.64/sdks/python/apache_beam/ml/anomaly/transforms_test.py#L336) for the whole model handler object without pickling it.