@@ -673,7 +673,7 @@ def load_image_converter(self, cls, **kwargs):
673
673
"""Load an image converter layer from the preset."""
674
674
raise NotImplementedError
675
675
676
- def load_task (self , cls , load_weights , load_task_extras , ** kwargs ):
676
+ def load_task (self , cls , load_weights , load_task_weights , ** kwargs ):
677
677
"""Load a task model from the preset.
678
678
679
679
By default, we create a task from a backbone and preprocessor with
@@ -689,11 +689,10 @@ def load_task(self, cls, load_weights, load_task_extras, **kwargs):
689
689
if "preprocessor" not in kwargs :
690
690
kwargs ["preprocessor" ] = self .load_preprocessor (
691
691
cls .preprocessor_cls ,
692
- load_task_extras = load_task_extras ,
693
692
)
694
693
return cls (** kwargs )
695
694
696
- def load_preprocessor (self , cls , load_task_extras , ** kwargs ):
695
+ def load_preprocessor (self , cls , ** kwargs ):
697
696
"""Load a prepocessor layer from the preset.
698
697
699
698
By default, we create a preprocessor from a tokenizer with default
@@ -738,33 +737,25 @@ def load_image_converter(self, cls, **kwargs):
738
737
converter_config = load_json (self .preset , IMAGE_CONVERTER_CONFIG_FILE )
739
738
return load_serialized_object (converter_config , ** kwargs )
740
739
741
- def load_task (self , cls , load_weights , load_task_extras , ** kwargs ):
740
+ def load_task (self , cls , load_weights , load_task_weights , ** kwargs ):
742
741
# If there is no `task.json` or it's for the wrong class delegate to the
743
742
# super class loader.
744
- if not load_task_extras :
745
- return super ().load_task (
746
- cls , load_weights , load_task_extras , ** kwargs
747
- )
748
743
if not check_file_exists (self .preset , TASK_CONFIG_FILE ):
749
- raise ValueError (
750
- "Saved preset has no `task.json`, cannot load the task config "
751
- "from a file. Call `from_preset()` with "
752
- "`load_task_extras=False` to load the task from a backbone "
753
- "with library defaults."
744
+ return super ().load_task (
745
+ cls , load_weights , load_task_weights , ** kwargs
754
746
)
755
747
task_config = load_json (self .preset , TASK_CONFIG_FILE )
756
748
if not issubclass (check_config_class (task_config ), cls ):
757
- raise ValueError (
758
- f"Saved `task.json`does not match calling cls { cls } . Call "
759
- "`from_preset()` with `load_task_extras=False` to load the "
760
- "task from a backbone with library defaults."
749
+ return super ().load_task (
750
+ cls , load_weights , load_task_weights , ** kwargs
761
751
)
762
752
# We found a `task.json` with a complete config for our class.
763
753
task = load_serialized_object (task_config , ** kwargs )
764
754
if task .preprocessor is not None :
765
755
task .preprocessor .tokenizer .load_preset_assets (self .preset )
766
756
if load_weights :
767
- if check_file_exists (self .preset , TASK_WEIGHTS_FILE ):
757
+ has_task_weights = check_file_exists (self .preset , TASK_WEIGHTS_FILE )
758
+ if has_task_weights and load_task_weights :
768
759
jax_memory_cleanup (task )
769
760
task_weights = get_file (self .preset , TASK_WEIGHTS_FILE )
770
761
task .load_task_weights (task_weights )
@@ -774,23 +765,14 @@ def load_task(self, cls, load_weights, load_task_extras, **kwargs):
774
765
task .backbone .load_weights (backbone_weights )
775
766
return task
776
767
777
- def load_preprocessor (self , cls , load_task_extras , ** kwargs ):
778
- if not load_task_extras :
779
- return super (). load_preprocessor ( cls , load_task_extras , ** kwargs )
768
+ def load_preprocessor (self , cls , ** kwargs ):
769
+ # If there is no `preprocessing.json` or it's for the wrong class,
770
+ # delegate to the super class loader.
780
771
if not check_file_exists (self .preset , PREPROCESSOR_CONFIG_FILE ):
781
- raise ValueError (
782
- "Saved preset has no `preprocessor.json`, cannot load the task "
783
- "preprocessing config from a file. Call `from_preset()` with "
784
- "`load_task_extras=False` to load the preprocessor with "
785
- "library defaults."
786
- )
772
+ return super ().load_preprocessor (cls , ** kwargs )
787
773
preprocessor_json = load_json (self .preset , PREPROCESSOR_CONFIG_FILE )
788
774
if not issubclass (check_config_class (preprocessor_json ), cls ):
789
- raise ValueError (
790
- f"Saved `preprocessor.json`does not match calling cls { cls } . "
791
- "Call `from_preset()` with `load_task_extras=False` to "
792
- "load the the preprocessor with library defaults."
793
- )
775
+ return super ().load_preprocessor (cls , ** kwargs )
794
776
# We found a `preprocessing.json` with a complete config for our class.
795
777
preprocessor = load_serialized_object (preprocessor_json , ** kwargs )
796
778
preprocessor .tokenizer .load_preset_assets (self .preset )
0 commit comments