
Commit c35d513

jeanm authored and facebook-github-bot committed
Inverse Sqrt Scheduler (#1150)
Summary:
Pull Request resolved: #1150

Currently, WarmupScheduler does this during the warm-up period:

    lr = base_lr * current_step / warmup_steps

This diff adds the option of applying LR decay after the warm-up period:

    lr = base_lr * sqrt(warmup_steps) / sqrt(current_step)

This is similar to [Fairseq's implementation](https://github.com/pytorch/fairseq/blob/master/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py).

Reviewed By: ccsasuke

Differential Revision: D18491650

fbshipit-source-id: d42cea2e2cbd169297508403300fb686c8664d68
1 parent 13a0b7a commit c35d513
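
To make the schedule concrete, here is a small standalone sketch of the multiplier described in the summary (plain Python, not PyText's API; the function name and the printed step values are illustrative). Note that both formulas give a multiplier of 1.0 at current_step == warmup_steps, so the switch into inverse-sqrt decay is continuous.

    import math

    def lr_multiplier(current_step, warmup_steps, inverse_sqrt_decay):
        """Multiplier applied to base_lr, mirroring the two formulas above."""
        if current_step >= warmup_steps:
            if inverse_sqrt_decay:
                # decay phase: sqrt(warmup_steps) / sqrt(current_step)
                return math.sqrt(warmup_steps) / math.sqrt(current_step)
            return 1.0
        # warm-up phase: current_step / warmup_steps
        return current_step / warmup_steps

    warmup = 10000  # the Config default in this diff
    for step in (100, 5000, 10000, 40000, 90000):
        print(step, lr_multiplier(step, warmup, inverse_sqrt_decay=True))
    # 100 -> 0.01, 5000 -> 0.5, 10000 -> 1.0, 40000 -> 0.5, 90000 -> ~0.333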

File tree

1 file changed, +20 -5 lines


pytext/optimizer/scheduler.py

Lines changed: 20 additions & 5 deletions
@@ -304,22 +304,34 @@ def step_epoch(self, metrics=None, epoch=None):
 
 class WarmupScheduler(_LRScheduler, BatchScheduler):
     """
-    Scheduler to linearly increase learning rate from 0 to final value at the beginning
-    of training.
+    Scheduler to linearly increase the learning rate from 0 to its final value over
+    a number of steps:
+
+        lr = base_lr * current_step / warmup_steps
+
+    After the warm-up phase, the scheduler has the option of decaying the learning
+    rate as the inverse square root of the number of training steps taken:
+
+        lr = base_lr * sqrt(warmup_steps) / sqrt(current_step)
     """
 
     class Config(BatchScheduler.Config):
         #: number of training steps over which to increase learning rate
         warmup_steps: int = 10000
 
+        #: whether to perform inverse sqrt decay after the warmup phase
+        inverse_sqrt_decay: bool = False
+
     @classmethod
     def from_config(cls, config: Config, optimizer: Optimizer):
-        return cls(optimizer, config.warmup_steps)
+        return cls(optimizer, config.warmup_steps, config.inverse_sqrt_decay)
 
-    def __init__(self, optimizer, warmup_steps):
+    def __init__(self, optimizer, warmup_steps, inverse_sqrt_decay):
         assert warmup_steps > 0
         self.warmup_steps = warmup_steps
         self.current_steps = 0
+        self.inverse_sqrt_decay = inverse_sqrt_decay
+        self.decay_factor = warmup_steps ** 0.5
         super().__init__(optimizer)
 
     def prepare(self, train_iter, total_epochs):
@@ -332,7 +344,10 @@ def step_batch(self):
 
     def get_lr(self):
         if self.current_steps >= self.warmup_steps:
-            lr_multiplier = 1.0
+            if self.inverse_sqrt_decay:
+                lr_multiplier = self.decay_factor / (self.current_steps ** 0.5)
+            else:
+                lr_multiplier = 1.0
         else:
             lr_multiplier = self.current_steps / self.warmup_steps
         return [lr_multiplier * base_lr for base_lr in self.base_lrs]
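
As a rough usage sketch, the same two-phase schedule can be reproduced outside PyText with a stock PyTorch optimizer via torch.optim.lr_scheduler.LambdaLR; the optimizer, base LR, and warmup_steps below are illustrative assumptions, not part of this commit, and PyText itself drives the scheduler through its own Config/step_batch machinery instead.

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    warmup_steps = 10000   # illustrative; PyText reads this from WarmupScheduler.Config
    base_lr = 1e-3         # illustrative base learning rate

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=base_lr)

    def multiplier(step):
        # Same branching as get_lr() in the diff, with inverse_sqrt_decay enabled.
        if step >= warmup_steps:
            return (warmup_steps ** 0.5) / (step ** 0.5)
        return step / warmup_steps

    scheduler = LambdaLR(optimizer, lr_lambda=multiplier)

    for _ in range(3):       # one iteration per training batch
        optimizer.step()
        scheduler.step()     # advances the step count and rescales the LR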

0 commit comments
