
Commit 99500b2

psuzhanhy authored and facebook-github-bot committed
Implement loss aware sparsifier (#1204)

Summary:
Pull Request resolved: #1204

Implement a new loss-aware sparsifier based on estimating the expected loss after removing a parameter, using a Taylor series approximation.

Reviewed By: hudeven

Differential Revision: D18947888

fbshipit-source-id: db709f6a68933e5ba364f26035b00b7934ce3ddb
1 parent dcedc2d commit 99500b2
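
The Taylor-series idea mentioned in the summary can be sketched in a few lines: a first-order expansion estimates the loss change from zeroing a weight as |gradient x weight|, so entries with the smallest scores are the cheapest to remove. The function name taylor_importance and the toy tensors below are illustrative assumptions for this page, not code from this commit.

import torch

def taylor_importance(param: torch.Tensor) -> torch.Tensor:
    # First-order Taylor estimate of how much the loss would change if an
    # element were zeroed out: |dL/dw * w|. Assumes a backward pass has
    # already populated param.grad.
    assert param.grad is not None, "run a backward pass before scoring"
    return (param.grad * param).abs()

# Toy usage: score a 2-D weight after one forward/backward pass.
w = torch.nn.Parameter(torch.randn(4, 3))
loss = (w @ torch.randn(3, 2)).pow(2).sum()
loss.backward()
scores = taylor_importance(w)  # low scores = pruning candidates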

File tree

2 files changed: +30, -1 lines changed

pytext/optimizer/sparsifiers/blockwise_sparsifier.py

Lines changed: 22 additions & 0 deletions
@@ -91,6 +91,28 @@ def from_config(cls, config: Config):
             config.layerwise_pruning,
         )

+    def get_sparsifiable_params(self, model, requires_name=False):
+        sparsifiable_params = [
+            p
+            for n, p in model.named_parameters()
+            if p.requires_grad and len(p.shape) == 2
+        ]
+        sparsifiable_params_name = [
+            n
+            for n, p in model.named_parameters()
+            if p.requires_grad and len(p.shape) == 2
+        ]
+        if requires_name:
+            return sparsifiable_params_name, sparsifiable_params
+        else:
+            return sparsifiable_params
+
+    def get_current_sparsity(self, model):
+        sparsifiable_params = self.get_sparsifiable_params(model)
+        sparsifiable_params_count = sum(p.numel() for p in sparsifiable_params)
+        nonzero_params = sum(p.nonzero().size(0) for p in sparsifiable_params)
+        return (sparsifiable_params_count - nonzero_params) / sparsifiable_params_count
+
     def _padding_into_full_blocks(self, param):
         nrows, ncols = param.shape
         ncols_pad = math.ceil(ncols / self.block_size) * self.block_size
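
For a sense of what the two new helpers report, here is a self-contained sketch that mirrors their logic outside of PyText: only trainable 2-D weights count as sparsifiable, and sparsity is the fraction of their entries that are exactly zero. The helper name count_sparsity and the toy nn.Linear model are assumptions for illustration only.

import torch
import torch.nn as nn

def count_sparsity(model: nn.Module) -> float:
    # Mirror of get_current_sparsity: consider only trainable 2-D weights,
    # and report the fraction of their entries that are exactly zero.
    params = [
        p for _, p in model.named_parameters()
        if p.requires_grad and len(p.shape) == 2
    ]
    total = sum(p.numel() for p in params)
    nonzero = sum(p.nonzero().size(0) for p in params)
    return (total - nonzero) / total

model = nn.Linear(8, 4)
with torch.no_grad():
    model.weight[:, :4] = 0.0  # zero out half of the weight matrix
print(count_sparsity(model))   # ~0.5; the 1-D bias is ignored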

pytext/trainers/trainer.py

Lines changed: 8 additions & 1 deletion
@@ -96,6 +96,9 @@ class Config(ConfigBase):
         target_time_limit_seconds: Optional[int] = None
         #: Whether to do evaluation and model selection based on it.
         do_eval: bool = True
+        #: if do_eval, do we load the best model state dict after training or just
+        # use the latest model state
+        load_best_model_after_train: bool = True
         #: Number of samples for logging training progress.
         num_samples_to_log_progress: int = 1000
         #: Number of forward & backward per batch before update gradients, the
@@ -465,7 +468,11 @@ def train_from_state(
             if should_update_model or train_config.save_all_checkpoints:
                 self.save_checkpoint(state, train_config)
         # Only bother loading the best model for master worker
-        if rank == 0 and state.best_model_state is not None:
+        if (
+            rank == 0
+            and state.best_model_state is not None
+            and self.config.load_best_model_after_train
+        ):
             self.load_best_model(state)

         return state.model, state.best_model_metric
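
To make the trainer change concrete, below is a small stand-alone sketch of the new gate: the best model is only restored on the master worker, only when a best snapshot exists, and only when the new flag is left enabled. TrainerConfig and should_load_best_model are hypothetical stand-ins, not PyText classes.

from dataclasses import dataclass

@dataclass
class TrainerConfig:
    # Stand-in for the trainer Config; only the fields relevant here.
    do_eval: bool = True
    load_best_model_after_train: bool = True  # flag added by this commit

def should_load_best_model(rank, best_model_state, config) -> bool:
    # Mirror of the new condition in train_from_state.
    return (
        rank == 0
        and best_model_state is not None
        and config.load_best_model_after_train
    )

print(should_load_best_model(0, {"w": 0.1}, TrainerConfig()))                                   # True
print(should_load_best_model(0, {"w": 0.1}, TrainerConfig(load_best_model_after_train=False)))  # False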
