
Commit a1b6a2a

[AnomalyDetection] Add main and auxiliary transforms. (#34234)
* Add main and auxiliary transforms.
* Minor fix per reviewer's feedback and fix lints.
1 parent cad3dc0 commit a1b6a2a

File tree

2 files changed: +635 -0 lines changed

@@ -0,0 +1,380 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import uuid
from typing import Callable
from typing import Iterable
from typing import Optional
from typing import Tuple
from typing import TypeVar

import apache_beam as beam
from apache_beam.coders import DillCoder
from apache_beam.ml.anomaly import aggregations
from apache_beam.ml.anomaly.base import AggregationFn
from apache_beam.ml.anomaly.base import AnomalyDetector
from apache_beam.ml.anomaly.base import AnomalyPrediction
from apache_beam.ml.anomaly.base import AnomalyResult
from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector
from apache_beam.ml.anomaly.specifiable import Spec
from apache_beam.ml.anomaly.specifiable import Specifiable
from apache_beam.ml.anomaly.thresholds import StatefulThresholdDoFn
from apache_beam.ml.anomaly.thresholds import StatelessThresholdDoFn
from apache_beam.ml.anomaly.thresholds import ThresholdFn
from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
KeyT = TypeVar('KeyT')
# The temporary key is a uuid string (see AnomalyDetection.expand below),
# so the TypeVar is bound to str rather than int.
TempKeyT = TypeVar('TempKeyT', bound=str)
InputT = Tuple[KeyT, beam.Row]
KeyedInputT = Tuple[KeyT, Tuple[TempKeyT, beam.Row]]
KeyedOutputT = Tuple[KeyT, Tuple[TempKeyT, AnomalyResult]]
OutputT = Tuple[KeyT, AnomalyResult]
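
# A minimal sketch of the element shapes these aliases describe (the key and
# values are hypothetical; the temporary key is a uuid string):
#
#   InputT:       ("sensor-1", beam.Row(x1=0.5))
#   KeyedInputT:  ("sensor-1", ("temp-key-uuid", beam.Row(x1=0.5)))
#   KeyedOutputT: ("sensor-1", ("temp-key-uuid", AnomalyResult(...)))
#   OutputT:      ("sensor-1", AnomalyResult(...))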


class _ScoreAndLearnDoFn(beam.DoFn):
  """Scores and learns from incoming data using an anomaly detection model.

  This DoFn applies an anomaly detection model to score incoming data and
  then updates the model with the same data. It maintains the model state
  using Beam's state management.
  """
  MODEL_STATE_INDEX = ReadModifyWriteStateSpec('saved_model', DillCoder())

  def __init__(self, detector_spec: Spec):
    self._detector_spec = detector_spec
    self._detector_spec.config["_run_init"] = True

  def score_and_learn(self, data):
    """Scores and learns from a single data point.

    Args:
      data: A `beam.Row` representing the input data point.

    Returns:
      float: The anomaly score predicted by the model.
    """
    assert self._underlying
    if self._underlying._features is not None:
      x = beam.Row(**{f: getattr(data, f) for f in self._underlying._features})
    else:
      x = beam.Row(**data._asdict())

    # score the incoming data using the existing model
    y_pred = self._underlying.score_one(x)

    # then update the model with the same data
    self._underlying.learn_one(x)

    return y_pred

  def process(
      self,
      element: KeyedInputT,
      model_state=beam.DoFn.StateParam(MODEL_STATE_INDEX),
      **kwargs) -> Iterable[KeyedOutputT]:

    k1, (k2, data) = element
    self._underlying: AnomalyDetector = model_state.read()
    if self._underlying is None:
      self._underlying = Specifiable.from_spec(self._detector_spec)

    yield k1, (k2,
               AnomalyResult(
                   example=data,
                   predictions=[AnomalyPrediction(
                       model_id=self._underlying._model_id,
                       score=self.score_and_learn(data))]))

    model_state.write(self._underlying)


class RunScoreAndLearn(beam.PTransform[beam.PCollection[KeyedInputT],
                                       beam.PCollection[KeyedOutputT]]):
  """Applies the _ScoreAndLearnDoFn to a PCollection of data.

  This PTransform scores and learns from data points using an anomaly
  detection model.

  Args:
    detector: The anomaly detection model to use.
  """
  def __init__(self, detector: AnomalyDetector):
    self._detector = detector

  def expand(
      self,
      input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]:
    return input | beam.ParDo(_ScoreAndLearnDoFn(self._detector.to_spec()))
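
# Usage sketch for RunScoreAndLearn, assuming the ZScore detector referenced
# in the AnomalyDetection docstring below; elements must already be keyed as
# (key, (temp_key, row)):
#
#   scored = keyed_input | RunScoreAndLearn(ZScore(features=["x1"]))
#
# Because the DoFn keeps the model in per-key state, all rows sharing a key
# are scored and learned by the same model instance.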


class RunThresholdCriterion(beam.PTransform[beam.PCollection[KeyedOutputT],
                                            beam.PCollection[KeyedOutputT]]):
  """Applies a threshold criterion to anomaly detection results.

  This PTransform applies a `ThresholdFn` to the anomaly scores in
  `AnomalyResult` objects, updating the prediction labels. It handles both
  stateful and stateless `ThresholdFn` implementations.

  Args:
    threshold_criterion: The `ThresholdFn` to apply.
  """
  def __init__(self, threshold_criterion: ThresholdFn):
    self._threshold_fn = threshold_criterion

  def expand(
      self,
      input: beam.PCollection[KeyedOutputT]) -> beam.PCollection[KeyedOutputT]:

    if self._threshold_fn.is_stateful:
      return (
          input
          | beam.ParDo(StatefulThresholdDoFn(self._threshold_fn.to_spec())))
    else:
      return (
          input
          | beam.ParDo(StatelessThresholdDoFn(self._threshold_fn.to_spec())))
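
# Usage sketch, assuming a hypothetical ThresholdFn subclass named
# QuantileThreshold; the transform inspects is_stateful to pick between the
# stateful and stateless threshold DoFns:
#
#   labeled = scored | RunThresholdCriterion(QuantileThreshold(quantile=0.99))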


class RunAggregationStrategy(beam.PTransform[beam.PCollection[KeyedOutputT],
                                             beam.PCollection[KeyedOutputT]]):
  """Applies an aggregation strategy to grouped anomaly detection results.

  This PTransform aggregates anomaly predictions from multiple models or
  data points using an `AggregationFn`. It handles both custom and simple
  aggregation strategies.

  Args:
    aggregation_strategy: The `AggregationFn` to use.
    agg_model_id: The model ID for aggregation.
  """
  def __init__(
      self, aggregation_strategy: Optional[AggregationFn], agg_model_id: str):
    self._aggregation_fn = aggregation_strategy
    self._agg_model_id = agg_model_id

  def expand(
      self,
      input: beam.PCollection[KeyedOutputT]) -> beam.PCollection[KeyedOutputT]:
    post_gbk = (
        input | beam.MapTuple(lambda k, v: ((k, v[0]), v[1]))
        | beam.GroupByKey())

    if self._aggregation_fn is None:
      # simply put predictions into an iterable (list)
      ret = (
          post_gbk | beam.MapTuple(
              lambda k,
              v: (
                  k[0],
                  (
                      k[1],
                      AnomalyResult(
                          example=v[0].example,
                          predictions=[
                              prediction for result in v
                              for prediction in result.predictions
                          ])))))
      return ret

    # create a new aggregation_fn from spec and make sure it is initialized
    aggregation_fn_spec = self._aggregation_fn.to_spec()
    aggregation_fn_spec.config["_run_init"] = True
    aggregation_fn = Specifiable.from_spec(aggregation_fn_spec)

    # if no _agg_model_id is set in the aggregation function, use
    # model id from the ensemble instance
    if isinstance(aggregation_fn, aggregations._AggModelIdMixin):
      aggregation_fn._set_agg_model_id_if_unset(self._agg_model_id)

    # post_gbk is a PCollection of
    # ((original_key, temp_key), Iterable[AnomalyResult]).
    # We use (original_key, temp_key) as the key for GroupByKey() so that
    # scores from multiple detectors per data point are grouped.
    ret = (
        post_gbk | beam.MapTuple(
            lambda k,
            v,
            agg=aggregation_fn: (
                k[0],
                (
                    k[1],
                    AnomalyResult(
                        example=v[0].example,
                        predictions=[
                            agg.apply([
                                prediction for result in v
                                for prediction in result.predictions
                            ])
                        ])))))
    return ret
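
# Usage sketch, assuming the AnyVote aggregation referenced in the
# AnomalyDetection docstring below and a hypothetical model id "ensemble":
#
#   aggregated = flattened | RunAggregationStrategy(AnyVote(), "ensemble")
#
# Passing aggregation_strategy=None skips aggregation and only collects the
# per-detector predictions for each data point into a single list.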


class RunOneDetector(beam.PTransform[beam.PCollection[KeyedInputT],
                                     beam.PCollection[KeyedOutputT]]):
  """Runs a single anomaly detector on a PCollection of data.

  This PTransform applies a single `AnomalyDetector` to the input data,
  including scoring, learning, and thresholding.

  Args:
    detector: The `AnomalyDetector` to run.
  """
  def __init__(self, detector):
    self._detector = detector

  def expand(
      self,
      input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]:
    model_id = getattr(
        self._detector,
        "_model_id",
        getattr(self._detector, "_key", "unknown_model"))
    model_uuid = f"{model_id}:{uuid.uuid4().hex[:6]}"

    ret = (
        input
        | beam.Reshuffle()
        | f"Score and Learn ({model_uuid})" >> RunScoreAndLearn(self._detector))

    if self._detector._threshold_criterion:
      ret = (
          ret | f"Run Threshold Criterion ({model_uuid})" >>
          RunThresholdCriterion(self._detector._threshold_criterion))

    return ret
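
# Usage sketch (the detector is illustrative): RunOneDetector composes the
# reshuffle, score-and-learn, and optional threshold steps for one detector,
# so a keyed input can be run through it directly:
#
#   results = keyed_input | RunOneDetector(ZScore(features=["x1"]))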


class RunEnsembleDetector(beam.PTransform[beam.PCollection[KeyedInputT],
                                          beam.PCollection[KeyedOutputT]]):
  """Runs an ensemble of anomaly detectors on a PCollection of data.

  This PTransform applies an `EnsembleAnomalyDetector` to the input data,
  running each sub-detector and aggregating the results.

  Args:
    ensemble_detector: The `EnsembleAnomalyDetector` to run.
  """
  def __init__(self, ensemble_detector: EnsembleAnomalyDetector):
    self._ensemble_detector = ensemble_detector

  def expand(
      self,
      input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]:
    model_uuid = f"{self._ensemble_detector._model_id}:{uuid.uuid4().hex[:6]}"

    assert self._ensemble_detector._sub_detectors is not None
    if not self._ensemble_detector._sub_detectors:
      raise ValueError(f"No detectors found at {model_uuid}")

    results = []
    for idx, detector in enumerate(self._ensemble_detector._sub_detectors):
      if isinstance(detector, EnsembleAnomalyDetector):
        results.append(
            input | f"Run Ensemble Detector at index {idx} ({model_uuid})" >>
            RunEnsembleDetector(detector))
      else:
        results.append(
            input
            | f"Run One Detector at index {idx} ({model_uuid})" >>
            RunOneDetector(detector))

    if self._ensemble_detector._aggregation_strategy is None:
      aggregation_type = "Simple"
    else:
      aggregation_type = "Custom"

    ret = (
        results | beam.Flatten()
        | f"Run {aggregation_type} Aggregation Strategy ({model_uuid})" >>
        RunAggregationStrategy(
            self._ensemble_detector._aggregation_strategy,
            self._ensemble_detector._model_id))

    if self._ensemble_detector._threshold_criterion:
      ret = (
          ret | f"Run Threshold Criterion ({model_uuid})" >>
          RunThresholdCriterion(self._ensemble_detector._threshold_criterion))

    return ret


class AnomalyDetection(beam.PTransform[beam.PCollection[InputT],
                                       beam.PCollection[OutputT]]):
  """Performs anomaly detection on a PCollection of data.

  This PTransform applies an `AnomalyDetector` or `EnsembleAnomalyDetector` to
  the input data and returns a PCollection of `AnomalyResult` objects.

  Examples::

    # Run a single anomaly detector
    p | AnomalyDetection(ZScore(features=["x1"]))

    # Run an ensemble anomaly detector
    sub_detectors = [ZScore(features=["x1"]), IQR(features=["x2"])]
    p | AnomalyDetection(
        EnsembleAnomalyDetector(sub_detectors, aggregation_strategy=AnyVote()))

  Args:
    detector: The `AnomalyDetector` or `EnsembleAnomalyDetector` to use.
  """
  def __init__(
      self,
      detector: AnomalyDetector,
  ) -> None:
    self._root_detector = detector

  def expand(
      self,
      input: beam.PCollection[InputT],
  ) -> beam.PCollection[OutputT]:

    # Add a temporary unique key per data point to facilitate grouping the
    # outputs from multiple anomaly detectors for the same data point.
    #
    # Unique key generation options:
    # (1) Timestamp-based methods: https://docs.python.org/3/library/time.html
    # (2) UUID module: https://docs.python.org/3/library/uuid.html
    #
    # Timestamp precision on Windows can lead to key collisions (see PEP 564:
    # https://peps.python.org/pep-0564/#windows). Only time.perf_counter_ns()
    # provides sufficient precision for our needs.
    #
    # Performance note:
    # $ python -m timeit -n 100000 "import uuid; uuid.uuid1()"
    # 100000 loops, best of 5: 806 nsec per loop
    # $ python -m timeit -n 100000 "import uuid; uuid.uuid4()"
    # 100000 loops, best of 5: 1.53 usec per loop
    # $ python -m timeit -n 100000 "import time; time.perf_counter_ns()"
    # 100000 loops, best of 5: 82.3 nsec per loop
    #
    # We select uuid.uuid1() for its inclusion of node information, making it
    # more suitable for parallel execution environments.
    add_temp_key_fn: Callable[[InputT], KeyedInputT] \
        = lambda e: (e[0], (str(uuid.uuid1()), e[1]))
    keyed_input = (input | "Add temp key" >> beam.Map(add_temp_key_fn))

    if isinstance(self._root_detector, EnsembleAnomalyDetector):
      keyed_output = (keyed_input | RunEnsembleDetector(self._root_detector))
    else:
      keyed_output = (keyed_input | RunOneDetector(self._root_detector))

    # remove the temporary key and simplify the output.
    remove_temp_key_fn: Callable[[KeyedOutputT], OutputT] \
        = lambda e: (e[0], e[1][1])
    ret = keyed_output | "Remove temp key" >> beam.Map(remove_temp_key_fn)

    return ret
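
A minimal end-to-end sketch of the new AnomalyDetection transform, assuming
the ZScore detector referenced in the docstring above (the import paths and
detector configuration are illustrative, not prescribed by this commit):

    import apache_beam as beam
    from apache_beam.ml.anomaly.detectors.zscore import ZScore
    from apache_beam.ml.anomaly.transforms import AnomalyDetection

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([("k1", beam.Row(x1=1.0)), ("k1", beam.Row(x1=2.0)),
                         ("k1", beam.Row(x1=100.0))])
          | AnomalyDetection(ZScore(features=["x1"]))
          | beam.Map(print))  # each element is (key, AnomalyResult)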
