-
Notifications
You must be signed in to change notification settings - Fork 285
Only load a full task config when load_task_extras
is passed
#1812
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -656,7 +656,7 @@ def load_tokenizer(self, cls, **kwargs): | |
"""Load a tokenizer layer from the preset.""" | ||
raise NotImplementedError | ||
|
||
def load_task(self, cls, load_weights, **kwargs): | ||
def load_task(self, cls, load_weights, load_task, **kwargs): | ||
"""Load a task model from the preset. | ||
|
||
By default, we create a task from a backbone and preprocessor with | ||
|
@@ -671,11 +671,12 @@ def load_task(self, cls, load_weights, **kwargs): | |
) | ||
if "preprocessor" not in kwargs: | ||
kwargs["preprocessor"] = self.load_preprocessor( | ||
cls.preprocessor_cls | ||
cls.preprocessor_cls, | ||
load_task=load_task, | ||
) | ||
return cls(**kwargs) | ||
|
||
def load_preprocessor(self, cls, **kwargs): | ||
def load_preprocessor(self, cls, load_task, **kwargs): | ||
"""Load a prepocessor layer from the preset. | ||
|
||
By default, we create a preprocessor from a tokenizer with default | ||
|
@@ -704,35 +705,54 @@ def load_tokenizer(self, cls, **kwargs): | |
tokenizer.load_preset_assets(self.preset) | ||
return tokenizer | ||
|
||
def load_task(self, cls, load_weights, **kwargs): | ||
def load_task(self, cls, load_weights, load_task, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this a little bit confusing that the function is called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah good call! This function isn't exposed at least but I'll think. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Trying this out. |
||
# If there is no `task.json` or it's for the wrong class delegate to the | ||
# super class loader. | ||
if not load_task: | ||
return super().load_task(cls, load_weights, load_task, **kwargs) | ||
if not check_file_exists(self.preset, TASK_CONFIG_FILE): | ||
return super().load_task(cls, load_weights, **kwargs) | ||
raise ValueError( | ||
"Saved preset has no `task.json`, cannot load the task config " | ||
"from a file. Call `from_preset()` with `load_task=False` to " | ||
"load the task from a backbone with library defaults." | ||
) | ||
task_config = load_json(self.preset, TASK_CONFIG_FILE) | ||
if not issubclass(check_config_class(task_config), cls): | ||
return super().load_task(cls, load_weights, **kwargs) | ||
raise ValueError( | ||
f"Saved `task.json`does not match calling cls {cls}. Call " | ||
"`from_preset()` with `load_task=False` to load the task from " | ||
"a backbone with library defaults." | ||
) | ||
# We found a `task.json` with a complete config for our class. | ||
task = load_serialized_object(task_config, **kwargs) | ||
if task.preprocessor is not None: | ||
task.preprocessor.tokenizer.load_preset_assets(self.preset) | ||
if load_weights: | ||
jax_memory_cleanup(task) | ||
jax_memory_cleanup(task.backbone) | ||
mattdangerw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if check_file_exists(self.preset, TASK_WEIGHTS_FILE): | ||
task_weights = get_file(self.preset, TASK_WEIGHTS_FILE) | ||
task.load_task_weights(task_weights) | ||
backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE) | ||
task.backbone.load_weights(backbone_weights) | ||
return task | ||
|
||
def load_preprocessor(self, cls, **kwargs): | ||
# If there is no `preprocessing.json` or it's for the wrong class, | ||
# delegate to the super class loader. | ||
def load_preprocessor(self, cls, load_task, **kwargs): | ||
if not load_task: | ||
return super().load_preprocessor(cls, load_task, **kwargs) | ||
if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE): | ||
return super().load_preprocessor(cls, **kwargs) | ||
raise ValueError( | ||
"Saved preset has no `preprocessor.json`, cannot load the task " | ||
"preprocessing config from a file. Call `from_preset()` with " | ||
"`load_task=False` to load the preprocessor with library " | ||
"defaults." | ||
) | ||
preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE) | ||
if not issubclass(check_config_class(preprocessor_json), cls): | ||
return super().load_preprocessor(cls, **kwargs) | ||
raise ValueError( | ||
f"Saved `preprocessor.json`does not match calling cls {cls}. " | ||
"Call `from_preset()` with `load_task=False` to " | ||
"load the the preprocessor with library defaults." | ||
) | ||
# We found a `preprocessing.json` with a complete config for our class. | ||
preprocessor = load_serialized_object(preprocessor_json, **kwargs) | ||
preprocessor.tokenizer.load_preset_assets(self.preset) | ||
|
Uh oh!
There was an error while loading. Please reload this page.