|
33 | 33 | from torch.nn import Module
|
34 | 34 | from transformers import Trainer as HFTransformersTrainer
|
35 | 35 | from transformers import TrainerCallback, TrainerControl, TrainingArguments
|
36 |
| -from transformers.file_utils import WEIGHTS_NAME, PaddingStrategy |
| 36 | +from transformers.file_utils import PaddingStrategy |
37 | 37 | from transformers.integrations import TensorBoardCallback
|
38 | 38 | from transformers.trainer_callback import TrainerState
|
39 | 39 | from transformers.trainer_pt_utils import reissue_pt_warnings
|
@@ -218,12 +218,13 @@ def apply_manager(self, epoch: float, checkpoint: Optional[str]) -> bool:
|
218 | 218 |
|
219 | 219 | # reload the state dict for the model now that architecture matches expected
|
220 | 220 | load_path = checkpoint or self.model_state_path
|
221 |
| - self._reload_model_state(load_path, orig_state_dict) |
| 221 | + if self._reload_model_state(load_path, orig_state_dict): |
| 222 | + _LOGGER.info( |
| 223 | + "Reloaded model state after SparseML recipe structure modifications " |
| 224 | + f"from {load_path}" |
| 225 | + ) |
| 226 | + |
222 | 227 | self.manager_applied = True
|
223 |
| - _LOGGER.info( |
224 |
| - "Reloaded model state after SparseML recipe structure modifications " |
225 |
| - f"from {load_path}" |
226 |
| - ) |
227 | 228 |
|
228 | 229 | return True
|
229 | 230 |
|
@@ -652,27 +653,36 @@ def _setup_manager(
|
652 | 653 | return manager, arch_manager
|
653 | 654 |
|
654 | 655 | def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
|
655 |
| - if ( |
656 |
| - not load_path |
657 |
| - or not os.path.isdir(load_path) |
658 |
| - or not os.path.isfile(os.path.join(load_path, WEIGHTS_NAME)) |
659 |
| - ): |
| 656 | + """ |
| 657 | + Reload the weights after model arch changes due to recipe application |
| 658 | + Return True if weights are successfully reloaded; False otherwise |
| 659 | + """ |
| 660 | + invalid_load_path = not load_path or not os.path.isdir(load_path) |
| 661 | + files = os.listdir(load_path) if not invalid_load_path else [] |
| 662 | + weight_files = [ |
| 663 | + os.path.join(load_path, f) |
| 664 | + for f in files |
| 665 | + if f.startswith("pytorch_model") and f.endswith("bin") |
| 666 | + ] |
| 667 | + if not weight_files: |
660 | 668 | _LOGGER.warning(
|
661 | 669 | "Model state was not reloaded for SparseML: "
|
662 |
| - f"could not find model weights for model_path {load_path}" |
| 670 | + f"could not find model weights for {load_path}" |
663 | 671 | )
|
664 |
| - return |
| 672 | + return False |
665 | 673 |
|
666 | 674 | current_state_dict = self.model.state_dict()
|
667 | 675 |
|
668 | 676 | if set(orig_state_dict.keys()) == set(current_state_dict):
|
669 | 677 | # no change in keys, ignore reload
|
670 |
| - return |
| 678 | + return False |
671 | 679 |
|
672 | 680 | # change in keys due to architecture changes, reload statedict
|
673 |
| - loaded_state_dict = torch.load( |
674 |
| - os.path.join(load_path, WEIGHTS_NAME), map_location="cpu" |
675 |
| - ) |
| 681 | + loaded_state_dict = {} |
| 682 | + for f in weight_files: |
| 683 | + dd = torch.load(os.path.join(load_path, f), map_location="cpu") |
| 684 | + loaded_state_dict.update(dd) |
| 685 | + |
676 | 686 | _, missing, unexpected, _, _ = self.model._load_pretrained_model(
|
677 | 687 | model=self.model,
|
678 | 688 | state_dict=loaded_state_dict,
|
@@ -704,6 +714,7 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
|
704 | 714 | model_type="student" if self.teacher else "model",
|
705 | 715 | delayed_load=False,
|
706 | 716 | )
|
| 717 | + return True |
707 | 718 |
|
708 | 719 | def _data_loader_builder(self, kwargs: Optional[Dict[str, Any]] = None):
|
709 | 720 | default_loader = self.get_train_dataloader()
|
|
0 commit comments