
Commit c092631

ArmenAg authored and facebook-github-bot committed
Implement LAMB optimizer
Summary: Implement the LAMB optimizer (https://arxiv.org/abs/1904.00962). The reference implementation can be found here: https://github.com/cybertronai/pytorch-lamb

Reviewed By: AkshatSh

Differential Revision: D18725798

fbshipit-source-id: 91b7de6ced02dac2a85f7d9ef6c8d875a7779cd1
1 parent 1289b24 commit c092631

File tree

pytext/optimizer/__init__.py
pytext/optimizer/lamb.py
pytext/optimizer/swa.py
pytext/optimizer/tests/test_swa.py

4 files changed: +140 −2 lines changed

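For orientation before the diffs, here is a minimal usage sketch of the new optimizer outside the PyText config system. The model, data, and hyperparameter values below are placeholders for illustration and are not part of this commit:

import torch
from pytext.optimizer import Lamb

# Toy model and batch, used only to exercise the optimizer.
model = torch.nn.Linear(10, 2)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=1e-5)

inputs = torch.randn(32, 10)
targets = torch.randint(0, 2, (32,))

optimizer.zero_grad()
loss = torch.nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
optimizer.step()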

pytext/optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
     FP16OptimizerApex,
     FP16OptimizerFairseq,
 )
+from pytext.optimizer.lamb import Lamb  # noqa
 from pytext.optimizer.optimizers import (  # noqa
     SGD,
     Adagrad,

pytext/optimizer/lamb.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import torch
from pytext.optimizer.optimizers import Optimizer
from torch.optim import Optimizer as PT_Optimizer


class Lamb(Optimizer, PT_Optimizer):
    r"""Implements Lamb algorithm.
    THIS WAS DIRECTLY COPIED OVER FROM pytorch/contrib
    https://github.com/cybertronai/pytorch-lamb
    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    class Config(Optimizer.Config):
        lr: float = 0.001
        weight_decay: float = 0.00001
        eps: float = 1e-8

    @classmethod
    def from_config(cls, config: Config, model: torch.nn.Module):
        return cls(
            model.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay,
            eps=config.eps,
        )

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        PT_Optimizer.__init__(
            self,
            params,
            {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay},
        )

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Lamb does not support sparse gradients, consider SparseAdam instead."
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                # Paper v3 does not use debiasing.
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']
                # Apply bias to lr to avoid broadcast.
                step_size = group["lr"]
                # * math.sqrt(bias_correction2) / bias_correction1

                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"])
                if group["weight_decay"] != 0:
                    adam_step.add_(group["weight_decay"], p.data)

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                state["weight_norm"] = weight_norm
                state["adam_norm"] = adam_norm
                state["trust_ratio"] = trust_ratio
                p.data.add_(-step_size * trust_ratio, adam_step)

        return loss
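For reference, the per-parameter update that step() implements can be summarized as follows, written to match the code above (which skips the paper's bias correction); here g_t is the gradient, \eta the learning rate, \lambda the weight decay, and \theta the parameter tensor:

\begin{aligned}
m_t      &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t      &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^{2} \\
r_t      &= \frac{m_t}{\sqrt{v_t} + \epsilon} + \lambda\, \theta_{t-1} \\
\phi_t   &= \frac{\min\!\left(\lVert \theta_{t-1} \rVert_2,\, 10\right)}{\lVert r_t \rVert_2} \\
\theta_t &= \theta_{t-1} - \eta\, \phi_t\, r_t
\end{aligned}

The trust ratio \phi_t is set to 1 whenever either norm is zero, exactly as in the code.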

pytext/optimizer/swa.py

Lines changed: 7 additions & 1 deletion
@@ -7,6 +7,7 @@
 
 import torch
 from pytext.config.component import create_optimizer
+from pytext.optimizer.lamb import Lamb
 from pytext.optimizer.optimizers import SGD, Adagrad, Adam, AdamW, Optimizer
 from pytext.optimizer.radam import RAdam
 from torch.optim import Optimizer as PT_Optimizer

@@ -15,7 +16,12 @@
 class StochasticWeightAveraging(Optimizer, PT_Optimizer):
     class Config(Optimizer.Config):
         optimizer: Union[
-            SGD.Config, Adam.Config, AdamW.Config, Adagrad.Config, RAdam.Config
+            SGD.Config,
+            Adam.Config,
+            AdamW.Config,
+            Adagrad.Config,
+            RAdam.Config,
+            Lamb.Config,
         ] = SGD.Config()
         start: int = 10
         frequency: int = 5
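With Lamb.Config added to the Union above, LAMB can now be selected as the inner optimizer for stochastic weight averaging. A minimal sketch of what such a config might look like; values are illustrative, any Config fields not shown in this diff are assumed to keep their defaults, and the wiring through create_optimizer is outside this change:

from pytext.optimizer import Lamb
from pytext.optimizer.swa import StochasticWeightAveraging

# Illustrative values; lr/weight_decay/eps defaults come from Lamb.Config,
# start/frequency defaults from StochasticWeightAveraging.Config shown above.
swa_config = StochasticWeightAveraging.Config(
    optimizer=Lamb.Config(lr=0.001, weight_decay=1e-5, eps=1e-8),
    start=10,
    frequency=5,
)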

pytext/optimizer/tests/test_swa.py

Lines changed: 9 additions & 1 deletion
@@ -10,7 +10,7 @@
 from unittest import TestCase
 
 import torch
-from pytext.optimizer import StochasticWeightAveraging
+from pytext.optimizer import Lamb, StochasticWeightAveraging
 from torch import nn, optim, sparse
 from torch.autograd import Variable
 from torch.utils import data

@@ -675,3 +675,11 @@ def forward(self, x):
         test(CNN, (objects, channels, height, width), objects, "cpu")
         if torch.cuda.is_available():
             test(CNN, (objects, channels, height, width), objects, "cuda")
+
+    def test_lamb(self):
+        def lamb_constructor(params):
+            return StochasticWeightAveraging(
+                Lamb(params, weight_decay=0.01), swa_start=1000, swa_freq=1, swa_lr=1e-2
+            )
+
+        self._test_rosenbrock(lamb_constructor, automode=False)
