
Commit 3c7a89f

Merge pull request #21 from kozistr/feature/adabelief-optimizer
[Feature] Implement AdaBelief optimizer
2 parents 278c29e + fea55f2 commit 3c7a89f

5 files changed: +261 −147 lines changed


README.md (+16)
@@ -33,6 +33,7 @@ for input, output in data:

 | Optimizer | Description | Official Code | Paper |
 | :---: | :---: | :---: | :---: |
+| AdaBelief | *Adapting Stepsizes by the Belief in Observed Gradients* | [github](https://github.com/juntang-zhuang/Adabelief-Optimizer) | [https://arxiv.org/abs/2010.07468](https://arxiv.org/abs/2010.07468) |
 | AdaBound | *Adaptive Gradient Methods with Dynamic Bound of Learning Rate* | [github](https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py) | [https://openreview.net/forum?id=Bkg3g2R9FX](https://openreview.net/forum?id=Bkg3g2R9FX) |
 | AdaHessian | *An Adaptive Second Order Optimizer for Machine Learning* | [github](https://github.com/amirgholami/adahessian) | [https://arxiv.org/abs/2006.00719](https://arxiv.org/abs/2006.00719) |
 | AdamP | *Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights* | [github](https://github.com/clovaai/AdamP) | [https://arxiv.org/abs/2006.08217](https://arxiv.org/abs/2006.08217) |
@@ -354,6 +355,21 @@ Acceleration via Fractal Learning Rate Schedules

 </details>

+<details>
+
+<summary>AdaBelief</summary>
+
+```
+@article{zhuang2020adabelief,
+  title={Adabelief optimizer: Adapting stepsizes by the belief in observed gradients},
+  author={Zhuang, Juntang and Tang, Tommy and Ding, Yifan and Tatikonda, Sekhar and Dvornek, Nicha and Papademetris, Xenophon and Duncan, James S},
+  journal={arXiv preprint arXiv:2010.07468},
+  year={2020}
+}
+```
+
+</details>
+
 ## Author

 Hyeongchan Kim / [@kozistr](http://kozistr.tech/about)
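The new table row and BibTeX entry above only name the idea behind AdaBelief. As an editor's illustration (not part of this diff), the short snippet below sketches the difference from Adam that the new pytorch_optimizer/adabelief.py file implements: Adam scales its step by an EMA of the squared gradient, while AdaBelief scales it by an EMA of the squared deviation of the gradient from its own running mean (the "belief"). The toy tensors and variable names are assumptions made for the sketch.

```python
import torch

# One EMA update for each statistic, on toy values (illustrative only).
beta1, beta2 = 0.9, 0.999
g = torch.randn(5)          # current gradient
m = torch.zeros(5)          # EMA of the gradient (exp_avg in the new file)
s_adam = torch.zeros(5)     # Adam: EMA of g**2
s_belief = torch.zeros(5)   # AdaBelief: EMA of (g - m)**2 (exp_avg_var in the new file)

m = beta1 * m + (1 - beta1) * g
s_adam = beta2 * s_adam + (1 - beta2) * g * g
s_belief = beta2 * s_belief + (1 - beta2) * (g - m) * (g - m)
```

Both statistics end up in the update's denominator; the belief term shrinks when the observed gradient agrees with its running mean, which lets AdaBelief take larger steps in that case.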

pytorch_optimizer/__init__.py (+3 −2)
@@ -1,4 +1,5 @@
-from pytorch_optimizer.adabound import AdaBound, AdaBoundW
+from pytorch_optimizer.adabelief import AdaBelief
+from pytorch_optimizer.adabound import AdaBound
 from pytorch_optimizer.adahessian import AdaHessian
 from pytorch_optimizer.adamp import AdamP
 from pytorch_optimizer.agc import agc
@@ -11,4 +12,4 @@
 from pytorch_optimizer.ranger21 import Ranger21
 from pytorch_optimizer.sgdp import SGDP

-__VERSION__ = '0.0.4'
+__VERSION__ = '0.0.5'
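For context, a minimal sketch (not part of the commit) of what this __init__.py change enables: AdaBelief is now re-exported from the package root next to the existing optimizers, and the package version string is bumped. The expected outputs noted in the comments are assumptions based on the diff above.

```python
# Minimal sketch of the new top-level export.
import pytorch_optimizer
from pytorch_optimizer import AdaBelief, AdaBound, AdamP

print(pytorch_optimizer.__VERSION__)  # expected: '0.0.5' after this commit
print(AdaBelief)                      # <class 'pytorch_optimizer.adabelief.AdaBelief'>
```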

pytorch_optimizer/adabelief.py (new file, +227)
@@ -0,0 +1,227 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.types import (
    BETAS,
    CLOSURE,
    DEFAULT_PARAMETERS,
    LOSS,
    PARAMS,
    STATE,
)


class AdaBelief(Optimizer):
    """
    Reference : https://github.com/juntang-zhuang/Adabelief-Optimizer
    Example :
        from pytorch_optimizer import AdaBelief
        ...
        model = YourModel()
        optimizer = AdaBelief(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        eps: float = 1e-16,
        weight_decay: float = 0.0,
        n_sma_threshold: int = 5,
        amsgrad: bool = False,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = True,
        degenerated_to_sgd: bool = True,
    ):
        """AdaBelief optimizer
        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param betas: BETAS. coefficients used for computing running averages of the gradient and its variance
        :param eps: float. term added to the denominator to improve numerical stability
        :param weight_decay: float. weight decay (L2 penalty)
        :param n_sma_threshold: int. SMA length threshold for the rectified update (recommended: 5)
        :param amsgrad: bool. whether to use the AMSGrad variant
        :param weight_decouple: bool. use decoupled weight decay as in AdamW
        :param fixed_decay: bool. apply the weight decay directly instead of scaling it by the learning rate
        :param rectify: bool. perform the rectified update similar to RAdam
        :param degenerated_to_sgd: bool. perform an SGD update when the variance of the gradient is high
        """
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.weight_decouple = weight_decouple
        self.rectify = rectify
        self.fixed_decay = fixed_decay

        # give groups with their own betas a private step-size buffer
        if (
            isinstance(params, (list, tuple))
            and len(params) > 0
            and isinstance(params[0], dict)
        ):
            for param in params:
                if 'betas' in param and (
                    param['betas'][0] != betas[0]
                    or param['betas'][1] != betas[1]
                ):
                    param['buffer'] = [[None, None, None] for _ in range(10)]

        defaults: DEFAULT_PARAMETERS = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            buffer=[[None, None, None] for _ in range(10)],
        )
        super().__init__(params, defaults)

    def __setstate__(self, state: STATE):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                amsgrad = group['amsgrad']

                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_var'] = torch.zeros_like(p.data)
                if amsgrad:
                    state['max_exp_avg_var'] = torch.zeros_like(p.data)

    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # fp16 parameters are updated in fp32 and cast back afterwards
                half_precision: bool = False
                if p.data.dtype == torch.float16:
                    half_precision = True
                    p.data = p.data.float()
                    p.grad = p.grad.float()

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBelief does not support sparse gradients'
                    )

                amsgrad = group['amsgrad']

                state = self.state[p]

                beta1, beta2 = group['betas']

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_var'] = torch.zeros_like(p.data)
                    if amsgrad:
                        state['max_exp_avg_var'] = torch.zeros_like(p.data)

                # decoupled (AdamW-style) weight decay vs. classic L2 penalty
                if self.weight_decouple:
                    if not self.fixed_decay:
                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        p.data.mul_(1.0 - group['weight_decay'])
                else:
                    if group['weight_decay'] != 0:
                        grad.add_(p.data, alpha=group['weight_decay'])

                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # first moment, and the "belief": EMA of the squared deviation
                # of the gradient from its own running mean
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(
                    grad_residual, grad_residual, value=1 - beta2
                )

                if amsgrad:
                    max_exp_avg_var = state['max_exp_avg_var']

                    torch.max(
                        max_exp_avg_var,
                        exp_avg_var.add_(group['eps']),
                        out=max_exp_avg_var,
                    )

                    denom = (
                        max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)
                    ).add_(group['eps'])
                else:
                    denom = (
                        exp_avg_var.add_(group['eps']).sqrt()
                        / math.sqrt(bias_correction2)
                    ).add_(group['eps'])

                if not self.rectify:
                    step_size = group['lr'] / bias_correction1
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    # rectified update in the style of RAdam
                    buffered = group['buffer'][int(state['step'] % 10)]
                    if state['step'] == buffered[0]:
                        n_sma, step_size = buffered[1], buffered[2]
                    else:
                        buffered[0] = state['step']
                        beta2_t = beta2 ** state['step']
                        n_sma_max = 2 / (1 - beta2) - 1
                        n_sma = n_sma_max - 2 * state['step'] * beta2_t / (
                            1 - beta2_t
                        )
                        buffered[1] = n_sma

                        if n_sma >= self.n_sma_threshold:
                            step_size = math.sqrt(
                                (1 - beta2_t)
                                * (n_sma - 4)
                                / (n_sma_max - 4)
                                * (n_sma - 2)
                                / n_sma
                                * n_sma_max
                                / (n_sma_max - 2)
                            ) / (1 - beta1 ** state['step'])
                        elif self.degenerated_to_sgd:
                            step_size = 1.0 / (1 - beta1 ** state['step'])
                        else:
                            step_size = -1
                        buffered[2] = step_size

                    if n_sma >= self.n_sma_threshold:
                        denom = exp_avg_var.sqrt().add_(group['eps'])
                        p.data.addcdiv_(
                            exp_avg, denom, value=-step_size * group['lr']
                        )
                    elif step_size > 0:
                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])

                if half_precision:
                    p.data = p.data.half()
                    p.grad = p.grad.half()

        return loss
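As a usage illustration only (not part of the commit), the sketch below fits a small linear model with the AdaBelief class defined above. The toy data, learning rate, weight decay, and step count are arbitrary assumptions; the keyword arguments mirror options documented in the docstring.

```python
import torch
import torch.nn.functional as F

from pytorch_optimizer import AdaBelief

torch.manual_seed(0)
x = torch.randn(64, 10)                                  # toy inputs
y = x @ torch.randn(10, 1) + 0.1 * torch.randn(64, 1)    # toy targets

model = torch.nn.Linear(10, 1)
optimizer = AdaBelief(
    model.parameters(),
    lr=1e-2,
    weight_decay=1e-4,
    weight_decouple=True,  # decoupled (AdamW-style) weight decay, the default
    rectify=True,          # RAdam-style rectified step size, the default
)

for _ in range(200):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

print(f'final loss: {loss.item():.4f}')
```

With `rectify=True`, the first few steps fall back to the SGD-style update (via `degenerated_to_sgd=True`) until the variance estimate is considered reliable, as implemented in the rectified branch above.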
