21
21
from typing import Callable
22
22
from typing import Dict
23
23
from typing import Iterable
24
- from typing import List
25
24
from typing import Optional
26
25
from typing import Tuple
27
26
from typing import TypeVar
@@ -427,8 +426,8 @@ class RunOfflineDetector(beam.PTransform[beam.PCollection[KeyedInputT],
427
426
def __init__ (self , offline_detector : OfflineDetector ):
428
427
self ._offline_detector = offline_detector
429
428
430
- def unnest_and_convert (
431
- self , nested : Tuple [Tuple [Any , Any ], dict [ str , List ] ]) -> KeyedOutputT :
429
+ def restore_and_convert (
430
+ self , elem : Tuple [Tuple [Any , Any , beam . Row ], float ]) -> KeyedOutputT :
432
431
"""Unnests and converts the model output to AnomalyResult.
433
432
434
433
Args:
@@ -438,15 +437,14 @@ def unnest_and_convert(
438
437
Returns:
439
438
A tuple containing the original key and AnomalyResult.
440
439
"""
441
- key , value_dict = nested
442
- score = value_dict ['output' ][0 ]
440
+ (orig_key , temp_key , row ), score = elem
443
441
result = AnomalyResult (
444
- example = value_dict [ 'input' ][ 0 ] ,
442
+ example = row ,
445
443
predictions = [
446
444
AnomalyPrediction (
447
445
model_id = self ._offline_detector ._model_id , score = score )
448
446
])
449
- return key [ 0 ] , (key [ 1 ] , result )
447
+ return orig_key , (temp_key , result )
450
448
451
449
def expand (
452
450
self ,
@@ -458,23 +456,18 @@ def expand(
458
456
self ._offline_detector ._keyed_model_handler ,
459
457
** self ._offline_detector ._run_inference_args )
460
458
461
- # ((orig_key, temp_key), beam.Row)
459
+ # ((orig_key, temp_key, beam.Row ), beam.Row)
462
460
rekeyed_model_input = input | "Rekey" >> beam .Map (
463
- lambda x : ((x [0 ], x [1 ][0 ]), x [1 ][1 ]))
461
+ lambda x : ((x [0 ], x [1 ][0 ], x [ 1 ][ 1 ] ), x [1 ][1 ]))
464
462
465
- # ((orig_key, temp_key), float)
463
+ # ((orig_key, temp_key, beam.Row ), float)
466
464
rekeyed_model_output = (
467
465
rekeyed_model_input
468
466
| f"Call RunInference ({ model_uuid } )" >> run_inference )
469
467
470
- # ((orig_key, temp_key), {'input':[row], 'output:[float]})
471
- rekeyed_cogbk = {
472
- 'input' : rekeyed_model_input , 'output' : rekeyed_model_output
473
- } | beam .CoGroupByKey ()
474
-
475
468
ret = (
476
- rekeyed_cogbk |
477
- "Unnest and convert model output" >> beam .Map (self .unnest_and_convert ))
469
+ rekeyed_model_output | "Restore keys and convert model output" >>
470
+ beam .Map (self .restore_and_convert ))
478
471
479
472
if self ._offline_detector ._threshold_criterion :
480
473
ret = (
0 commit comments