Skip to content

Commit 58e1616

Browse files
authored
Get rid of unnecessary CoGroupByKey (cogbk) when running offline detector. (#34656)
* Get rid of unnecessary CoGroupByKey (cogbk) when running offline detector.
* Fix lints.
1 parent a847f06 commit 58e1616

File tree

1 file changed

+10
-17
lines changed

1 file changed

+10
-17
lines changed

sdks/python/apache_beam/ml/anomaly/transforms.py

+10-17
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
from typing import Callable
2222
from typing import Dict
2323
from typing import Iterable
24-
from typing import List
2524
from typing import Optional
2625
from typing import Tuple
2726
from typing import TypeVar
@@ -427,8 +426,8 @@ class RunOfflineDetector(beam.PTransform[beam.PCollection[KeyedInputT],
427426
def __init__(self, offline_detector: OfflineDetector):
428427
self._offline_detector = offline_detector
429428

430-
def unnest_and_convert(
431-
self, nested: Tuple[Tuple[Any, Any], dict[str, List]]) -> KeyedOutputT:
429+
def restore_and_convert(
430+
self, elem: Tuple[Tuple[Any, Any, beam.Row], float]) -> KeyedOutputT:
432431
"""Unnests and converts the model output to AnomalyResult.
433432
434433
Args:
@@ -438,15 +437,14 @@ def unnest_and_convert(
438437
Returns:
439438
A tuple containing the original key and AnomalyResult.
440439
"""
441-
key, value_dict = nested
442-
score = value_dict['output'][0]
440+
(orig_key, temp_key, row), score = elem
443441
result = AnomalyResult(
444-
example=value_dict['input'][0],
442+
example=row,
445443
predictions=[
446444
AnomalyPrediction(
447445
model_id=self._offline_detector._model_id, score=score)
448446
])
449-
return key[0], (key[1], result)
447+
return orig_key, (temp_key, result)
450448

451449
def expand(
452450
self,
@@ -458,23 +456,18 @@ def expand(
458456
self._offline_detector._keyed_model_handler,
459457
**self._offline_detector._run_inference_args)
460458

461-
# ((orig_key, temp_key), beam.Row)
459+
# ((orig_key, temp_key, beam.Row), beam.Row)
462460
rekeyed_model_input = input | "Rekey" >> beam.Map(
463-
lambda x: ((x[0], x[1][0]), x[1][1]))
461+
lambda x: ((x[0], x[1][0], x[1][1]), x[1][1]))
464462

465-
# ((orig_key, temp_key), float)
463+
# ((orig_key, temp_key, beam.Row), float)
466464
rekeyed_model_output = (
467465
rekeyed_model_input
468466
| f"Call RunInference ({model_uuid})" >> run_inference)
469467

470-
# ((orig_key, temp_key), {'input':[row], 'output':[float]})
471-
rekeyed_cogbk = {
472-
'input': rekeyed_model_input, 'output': rekeyed_model_output
473-
} | beam.CoGroupByKey()
474-
475468
ret = (
476-
rekeyed_cogbk |
477-
"Unnest and convert model output" >> beam.Map(self.unnest_and_convert))
469+
rekeyed_model_output | "Restore keys and convert model output" >>
470+
beam.Map(self.restore_and_convert))
478471

479472
if self._offline_detector._threshold_criterion:
480473
ret = (

0 commit comments

Comments
 (0)