Skip to content

Commit b73a173

Browse files
authored
Reload big model with multiple state dict files (#1644)
* Reload big model with multiple state dict files * Add description for reload func
1 parent a61f89a commit b73a173

File tree

1 file changed

+28
-17
lines changed
  • src/sparseml/transformers/sparsification

1 file changed

+28
-17
lines changed

src/sparseml/transformers/sparsification/trainer.py

+28-17
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from torch.nn import Module
3434
from transformers import Trainer as HFTransformersTrainer
3535
from transformers import TrainerCallback, TrainerControl, TrainingArguments
36-
from transformers.file_utils import WEIGHTS_NAME, PaddingStrategy
36+
from transformers.file_utils import PaddingStrategy
3737
from transformers.integrations import TensorBoardCallback
3838
from transformers.trainer_callback import TrainerState
3939
from transformers.trainer_pt_utils import reissue_pt_warnings
@@ -218,12 +218,13 @@ def apply_manager(self, epoch: float, checkpoint: Optional[str]) -> bool:
218218

219219
# reload the state dict for the model now that architecture matches expected
220220
load_path = checkpoint or self.model_state_path
221-
self._reload_model_state(load_path, orig_state_dict)
221+
if self._reload_model_state(load_path, orig_state_dict):
222+
_LOGGER.info(
223+
"Reloaded model state after SparseML recipe structure modifications "
224+
f"from {load_path}"
225+
)
226+
222227
self.manager_applied = True
223-
_LOGGER.info(
224-
"Reloaded model state after SparseML recipe structure modifications "
225-
f"from {load_path}"
226-
)
227228

228229
return True
229230

@@ -652,27 +653,36 @@ def _setup_manager(
652653
return manager, arch_manager
653654

654655
def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
655-
if (
656-
not load_path
657-
or not os.path.isdir(load_path)
658-
or not os.path.isfile(os.path.join(load_path, WEIGHTS_NAME))
659-
):
656+
"""
657+
Reload the weights after model arch changes due to recipe application
658+
Return True if weights are successfully reloaded; False otherwise
659+
"""
660+
invalid_load_path = not load_path or not os.path.isdir(load_path)
661+
files = os.listdir(load_path) if not invalid_load_path else []
662+
weight_files = [
663+
os.path.join(load_path, f)
664+
for f in files
665+
if f.startswith("pytorch_model") and f.endswith("bin")
666+
]
667+
if not weight_files:
660668
_LOGGER.warning(
661669
"Model state was not reloaded for SparseML: "
662-
f"could not find model weights for model_path {load_path}"
670+
f"could not find model weights for {load_path}"
663671
)
664-
return
672+
return False
665673

666674
current_state_dict = self.model.state_dict()
667675

668676
if set(orig_state_dict.keys()) == set(current_state_dict):
669677
# no change in keys, ignore reload
670-
return
678+
return False
671679

672680
# change in keys due to architecture changes, reload statedict
673-
loaded_state_dict = torch.load(
674-
os.path.join(load_path, WEIGHTS_NAME), map_location="cpu"
675-
)
681+
loaded_state_dict = {}
682+
for f in weight_files:
683+
dd = torch.load(os.path.join(load_path, f), map_location="cpu")
684+
loaded_state_dict.update(dd)
685+
676686
_, missing, unexpected, _, _ = self.model._load_pretrained_model(
677687
model=self.model,
678688
state_dict=loaded_state_dict,
@@ -704,6 +714,7 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
704714
model_type="student" if self.teacher else "model",
705715
delayed_load=False,
706716
)
717+
return True
707718

708719
def _data_loader_builder(self, kwargs: Optional[Dict[str, Any]] = None):
709720
default_loader = self.get_train_dataloader()

0 commit comments

Comments
 (0)