@@ -454,16 +454,6 @@ def load_json(preset, config_file=CONFIG_FILE):
454
454
return config
455
455
456
456
457
- def load_serialized_object (config , ** kwargs ):
458
- # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
459
- # Ensure that `dtype` is properly configured.
460
- dtype = kwargs .pop ("dtype" , None )
461
- config = set_dtype_in_config (config , dtype )
462
-
463
- config ["config" ] = {** config ["config" ], ** kwargs }
464
- return keras .saving .deserialize_keras_object (config )
465
-
466
-
467
457
def check_config_class (config ):
468
458
"""Validate a preset is being loaded on the correct class."""
469
459
registered_name = config ["registered_name" ]
@@ -631,26 +621,26 @@ def check_backbone_class(self):
631
621
return check_config_class (self .config )
632
622
633
623
def load_backbone (self , cls , load_weights , ** kwargs ):
634
- backbone = load_serialized_object (self .config , ** kwargs )
624
+ backbone = self . _load_serialized_object (self .config , ** kwargs )
635
625
if load_weights :
636
626
jax_memory_cleanup (backbone )
637
627
backbone .load_weights (get_file (self .preset , MODEL_WEIGHTS_FILE ))
638
628
return backbone
639
629
640
630
def load_tokenizer (self , cls , config_file = TOKENIZER_CONFIG_FILE , ** kwargs ):
641
631
tokenizer_config = load_json (self .preset , config_file )
642
- tokenizer = load_serialized_object (tokenizer_config , ** kwargs )
632
+ tokenizer = self . _load_serialized_object (tokenizer_config , ** kwargs )
643
633
if hasattr (tokenizer , "load_preset_assets" ):
644
634
tokenizer .load_preset_assets (self .preset )
645
635
return tokenizer
646
636
647
637
def load_audio_converter (self , cls , ** kwargs ):
648
638
converter_config = load_json (self .preset , AUDIO_CONVERTER_CONFIG_FILE )
649
- return load_serialized_object (converter_config , ** kwargs )
639
+ return self . _load_serialized_object (converter_config , ** kwargs )
650
640
651
641
def load_image_converter (self , cls , ** kwargs ):
652
642
converter_config = load_json (self .preset , IMAGE_CONVERTER_CONFIG_FILE )
653
- return load_serialized_object (converter_config , ** kwargs )
643
+ return self . _load_serialized_object (converter_config , ** kwargs )
654
644
655
645
def load_task (self , cls , load_weights , load_task_weights , ** kwargs ):
656
646
# If there is no `task.json` or it's for the wrong class delegate to the
@@ -671,7 +661,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
671
661
backbone_config = task_config ["config" ]["backbone" ]["config" ]
672
662
backbone_config = {** backbone_config , ** backbone_kwargs }
673
663
task_config ["config" ]["backbone" ]["config" ] = backbone_config
674
- task = load_serialized_object (task_config , ** kwargs )
664
+ task = self . _load_serialized_object (task_config , ** kwargs )
675
665
if task .preprocessor and hasattr (
676
666
task .preprocessor , "load_preset_assets"
677
667
):
@@ -699,11 +689,20 @@ def load_preprocessor(
699
689
if not issubclass (check_config_class (preprocessor_json ), cls ):
700
690
return super ().load_preprocessor (cls , ** kwargs )
701
691
# We found a `preprocessing.json` with a complete config for our class.
702
- preprocessor = load_serialized_object (preprocessor_json , ** kwargs )
692
+ preprocessor = self . _load_serialized_object (preprocessor_json , ** kwargs )
703
693
if hasattr (preprocessor , "load_preset_assets" ):
704
694
preprocessor .load_preset_assets (self .preset )
705
695
return preprocessor
706
696
697
+ def _load_serialized_object (self , config , ** kwargs ):
698
+ # `dtype` in config might be a serialized `DTypePolicy` or
699
+ # `DTypePolicyMap`. Ensure that `dtype` is properly configured.
700
+ dtype = kwargs .pop ("dtype" , None )
701
+ config = set_dtype_in_config (config , dtype )
702
+
703
+ config ["config" ] = {** config ["config" ], ** kwargs }
704
+ return keras .saving .deserialize_keras_object (config )
705
+
707
706
708
707
class KerasPresetSaver :
709
708
def __init__ (self , preset_dir ):
0 commit comments