Skip to content

Commit 6cdc585

Browse files
authored
Fix quant model re-load bug (#978) (#1027)
1 parent 3c80cde commit 6cdc585

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/sparseml/pytorch/utils/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def load_model(
8080
if path.startswith("zoo:"):
8181
path = download_framework_model_by_recipe_type(Model(path))
8282
model_dict = torch.load(path, map_location="cpu")
83-
current_dict = model.state_dict()
8483
recipe = model_dict.get("recipe")
8584

8685
if recipe:
@@ -90,6 +89,7 @@ def load_model(
9089
checkpoint_manager = ScheduledModifierManager.from_yaml(recipe)
9190
checkpoint_manager.apply_structure(module=model, epoch=epoch)
9291

92+
current_dict = model.state_dict()
9393
if "state_dict" in model_dict:
9494
model_dict = model_dict["state_dict"]
9595

0 commit comments

Comments
 (0)