We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3c80cde commit 6cdc585Copy full SHA for 6cdc585
src/sparseml/pytorch/utils/model.py
@@ -80,7 +80,6 @@ def load_model(
80
if path.startswith("zoo:"):
81
path = download_framework_model_by_recipe_type(Model(path))
82
model_dict = torch.load(path, map_location="cpu")
83
- current_dict = model.state_dict()
84
recipe = model_dict.get("recipe")
85
86
if recipe:
@@ -90,6 +89,7 @@ def load_model(
90
89
checkpoint_manager = ScheduledModifierManager.from_yaml(recipe)
91
checkpoint_manager.apply_structure(module=model, epoch=epoch)
92
+ current_dict = model.state_dict()
93
if "state_dict" in model_dict:
94
model_dict = model_dict["state_dict"]
95
0 commit comments