
Implement loss aware sparsifier #1204

Closed
Changes from all commits
22 changes: 22 additions & 0 deletions pytext/optimizer/sparsifiers/blockwise_sparsifier.py
@@ -91,6 +91,28 @@ def from_config(cls, config: Config):
            config.layerwise_pruning,
        )

    def get_sparsifiable_params(self, model, requires_name=False):
        # Only trainable 2-D weight matrices are eligible for blockwise
        # sparsification; biases and other 1-D parameters are skipped.
        sparsifiable_params = [
            p
            for n, p in model.named_parameters()
            if p.requires_grad and len(p.shape) == 2
        ]
        sparsifiable_params_name = [
            n
            for n, p in model.named_parameters()
            if p.requires_grad and len(p.shape) == 2
        ]
        if requires_name:
            return sparsifiable_params_name, sparsifiable_params
        else:
            return sparsifiable_params

    def get_current_sparsity(self, model):
        # Fraction of sparsifiable weights that are exactly zero.
        sparsifiable_params = self.get_sparsifiable_params(model)
        sparsifiable_params_count = sum(p.numel() for p in sparsifiable_params)
        nonzero_params = sum(p.nonzero().size(0) for p in sparsifiable_params)
        return (sparsifiable_params_count - nonzero_params) / sparsifiable_params_count

    def _padding_into_full_blocks(self, param):
        nrows, ncols = param.shape
        ncols_pad = math.ceil(ncols / self.block_size) * self.block_size
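For reviewers, a minimal usage sketch of the two new helpers (not part of the diff). Construction of the sparsifier is omitted since it depends on the sparsifier config; `sparsifier` below is an assumed, already-built instance of the blockwise sparsifier defined in this file, and the model is a throwaway example.

import torch.nn as nn

# Toy model for illustration only; any nn.Module with trainable 2-D weights works.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# get_sparsifiable_params keeps only trainable 2-D weight matrices (biases are
# 1-D and skipped); requires_name=True also returns the matching parameter names.
names, params = sparsifier.get_sparsifiable_params(model, requires_name=True)
print(names)  # ['0.weight', '2.weight']

# get_current_sparsity is the fraction of those weights that are exactly zero,
# i.e. (total - nonzero) / total. Computed directly for a fresh model:
total = sum(p.numel() for p in params)
nonzero = sum(p.nonzero().size(0) for p in params)
print((total - nonzero) / total)  # ~0.0 before any pruning
print(sparsifier.get_current_sparsity(model))  # same value via the new method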
9 changes: 8 additions & 1 deletion pytext/trainers/trainer.py
@@ -96,6 +96,9 @@ class Config(ConfigBase):
        target_time_limit_seconds: Optional[int] = None
        #: Whether to do evaluation and model selection based on it.
        do_eval: bool = True
        #: If do_eval is set, whether to load the best model state dict after
        #: training or keep the latest model state.
        load_best_model_after_train: bool = True
        #: Number of samples for logging training progress.
        num_samples_to_log_progress: int = 1000
        #: Number of forward & backward per batch before update gradients, the
@@ -465,7 +468,11 @@ def train_from_state(
            if should_update_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)
        # Only bother loading the best model for master worker
-       if rank == 0 and state.best_model_state is not None:
+       if (
+           rank == 0
+           and state.best_model_state is not None
+           and self.config.load_best_model_after_train
+       ):
            self.load_best_model(state)

        return state.model, state.best_model_metric
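A hedged sketch of exercising the new flag from Python (keyword-style ConfigBase construction is assumed here; the two field names come from this diff). Setting load_best_model_after_train to False keeps the latest model state at the end of training even when do_eval has tracked a better snapshot.

from pytext.trainers import Trainer

trainer_config = Trainer.Config(
    do_eval=True,
    # New in this PR: when False, rank 0 skips reloading state.best_model_state,
    # so train_from_state returns the latest (not the best) model.
    load_best_model_after_train=False,
)

With the default of True, behavior is unchanged: the master worker still reloads the best model state, when one exists, before returning.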