20
20
import hashlib
21
21
import re
22
22
from abc import ABC , abstractmethod
23
- from typing import Any , Callable , Dict , List , Union
23
+ from typing import Any , Callable , Dict , List , Type , Union
24
24
25
25
import yaml
26
26
from yaml import ScalarNode
@@ -822,13 +822,22 @@ class ModifierYAML(object):
822
822
823
823
:param framework: the string representing the ML framework the modifier should
824
824
be stored under
825
+ :param swap_class_by_state_fn: optional function to provide a different class
826
+ to construct on yaml load based on the state given (ie provide a
827
+ legacy class to load if certain parameters are passed). Expected format
828
+ is to take a dict of kwargs, expects a class to be returned
825
829
"""
826
830
827
- def __init__ (self , framework : str ):
831
+ def __init__ (
832
+ self ,
833
+ framework : str ,
834
+ swap_class_by_state_fn : Callable [[Dict [str , Any ]], Type [BaseModifier ]] = None ,
835
+ ):
828
836
if not framework :
829
837
raise ValueError ("a framework is required" )
830
838
831
839
self ._framework = framework
840
+ self ._swap_class_by_state_fn = swap_class_by_state_fn
832
841
833
842
def __call__ (self , clazz ):
834
843
"""
@@ -838,13 +847,18 @@ def __call__(self, clazz):
838
847
yaml_key = "{}" .format (BaseModifier .yaml_key (clazz , self ._framework ))
839
848
840
849
def constructor (loader , node ):
841
- instance = clazz .__new__ (clazz )
842
- yield instance
843
850
state = (
844
851
loader .construct_mapping (node , deep = True )
845
852
if not isinstance (node , ScalarNode )
846
853
else {}
847
854
)
855
+ target_class = (
856
+ self ._swap_class_by_state_fn (state )
857
+ if self ._swap_class_by_state_fn is not None
858
+ else clazz
859
+ )
860
+ instance = target_class .__new__ (target_class )
861
+ yield instance
848
862
# ignore the log_types arg in recipes to maintain backwards compatability
849
863
# while recipes are updated
850
864
if "log_types" in state :
0 commit comments