Commit 278c29e

Merge pull request #20 from kozistr/feature/adabound-optimizer
[Feature] Implement AdaBound/AdaBoundW optimizers
2 parents 29a9dd3 + c3eab9b commit 278c29e

5 files changed: +321 -3 lines

README.md

+17
@@ -33,6 +33,7 @@ for input, output in data:
 
 | Optimizer | Description | Official Code | Paper |
 | :---: | :---: | :---: | :---: |
+| AdaBound | *Adaptive Gradient Methods with Dynamic Bound of Learning Rate* | [github](https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py) | [https://openreview.net/forum?id=Bkg3g2R9FX](https://openreview.net/forum?id=Bkg3g2R9FX) |
 | AdaHessian | *An Adaptive Second Order Optimizer for Machine Learning* | [github](https://github.com/amirgholami/adahessian) | [https://arxiv.org/abs/2006.00719](https://arxiv.org/abs/2006.00719) |
 | AdamP | *Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights* | [github](https://github.com/clovaai/AdamP) | [https://arxiv.org/abs/2006.08217](https://arxiv.org/abs/2006.08217) |
 | MADGRAD | *A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic* | [github](https://github.com/facebookresearch/madgrad) | [https://arxiv.org/abs/2101.11075](https://arxiv.org/abs/2101.11075) |

@@ -336,6 +337,22 @@ Acceleration via Fractal Learning Rate Schedules
 
 </details>
 
+<details>
+
+<summary>AdaBound</summary>
+
+```
+@inproceedings{Luo2019AdaBound,
+    author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
+    title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
+    booktitle = {Proceedings of the 7th International Conference on Learning Representations},
+    month = {May},
+    year = {2019},
+    address = {New Orleans, Louisiana}
+}
+```
+
+</details>
 
 ## Author

pytorch_optimizer/__init__.py

+2 -1

@@ -1,3 +1,4 @@
+from pytorch_optimizer.adabound import AdaBound, AdaBoundW
 from pytorch_optimizer.adahessian import AdaHessian
 from pytorch_optimizer.adamp import AdamP
 from pytorch_optimizer.agc import agc

@@ -10,4 +11,4 @@
 from pytorch_optimizer.ranger21 import Ranger21
 from pytorch_optimizer.sgdp import SGDP
 
-__VERSION__ = '0.0.3'
+__VERSION__ = '0.0.4'
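With this export in place, both new optimizers are importable from the package root, and the package version string is bumped. A minimal sanity check, as a sketch that assumes a build of `pytorch-optimizer` at this commit is installed:

```python
# Quick check of the new root-level exports and the version bump in this diff.
# Assumes pytorch-optimizer as of this commit is installed in the environment.
import pytorch_optimizer
from pytorch_optimizer import AdaBound, AdaBoundW

print(pytorch_optimizer.__VERSION__)  # '0.0.4'
print(AdaBound.__name__, AdaBoundW.__name__)  # AdaBound AdaBoundW
```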

pytorch_optimizer/adabound.py

+298 (new file)

import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.types import (
    BETAS,
    CLOSURE,
    DEFAULT_PARAMETERS,
    LOSS,
    PARAMS,
    STATE,
)


class AdaBound(Optimizer):
    """
    Reference : https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py
    Example :
        from pytorch_optimizer import AdaBound
        ...
        model = YourModel()
        optimizer = AdaBound(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        final_lr: float = 0.1,
        gamma: float = 1e-3,
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        amsbound: bool = False,
    ):
        """AdaBound optimizer
        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param final_lr: float. final learning rate
        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
        :param gamma: float. convergence speed of the bound functions
        :param eps: float. term added to the denominator to improve numerical stability
        :param weight_decay: float. weight decay (L2 penalty)
        :param amsbound: bool. whether to use the AMSBound variant
        """
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay

        defaults: DEFAULT_PARAMETERS = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super().__init__(params, defaults)

        self.base_lrs = [group['lr'] for group in self.param_groups]

    def check_valid_parameters(self):
        if 0.0 > self.lr:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if 0.0 > self.eps:
            raise ValueError(f'Invalid eps : {self.eps}')
        if 0.0 > self.weight_decay:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')

    def __setstate__(self, state: STATE):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsbound', False)

    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBound does not support sparse gradients'
                    )

                amsbound = group['amsbound']

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if amsbound:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsbound:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsbound:
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = (
                    group['lr']
                    * math.sqrt(bias_correction2)
                    / bias_correction1
                )

                final_lr = group['final_lr'] * group['lr'] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group['gamma'] * state['step'] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group['gamma'] * state['step'])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                p.data.add_(-step_size)

        return loss


class AdaBoundW(Optimizer):
    """
    Reference : https://github.com/Luolc/AdaBound
    Example :
        from pytorch_optimizer import AdaBoundW
        ...
        model = YourModel()
        optimizer = AdaBoundW(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        final_lr: float = 0.1,
        gamma: float = 1e-3,
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        amsbound: bool = False,
    ):
        """AdaBound optimizer with decoupled weight decay
        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param final_lr: float. final learning rate
        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
        :param gamma: float. convergence speed of the bound functions
        :param eps: float. term added to the denominator to improve numerical stability
        :param weight_decay: float. weight decay (L2 penalty)
        :param amsbound: bool. whether to use the AMSBound variant
        """
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay

        defaults: DEFAULT_PARAMETERS = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super().__init__(params, defaults)

        self.base_lrs = [group['lr'] for group in self.param_groups]

    def check_valid_parameters(self):
        if 0.0 > self.lr:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if 0.0 > self.eps:
            raise ValueError(f'Invalid eps : {self.eps}')
        if 0.0 > self.weight_decay:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')

    def __setstate__(self, state: STATE):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsbound', False)

    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group['params']:
                if p.grad is None:
                    continue

                p.mul_(1 - base_lr * group['weight_decay'])

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBound does not support sparse gradients'
                    )

                amsbound = group['amsbound']

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if amsbound:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsbound:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsbound:
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = (
                    group['lr']
                    * math.sqrt(bias_correction2)
                    / bias_correction1
                )

                final_lr = group['final_lr'] * group['lr'] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group['gamma'] * state['step'] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group['gamma'] * state['step'])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                p.data.add_(-step_size)

        return loss
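For reference, the update that both `step()` methods above implement can be written compactly; the formula below is a direct transcription of that code (with $\alpha$ = `lr`, $\eta^{*}$ = the per-group `final_lr` rescaled by `lr / base_lr`, $\gamma$ = `gamma`, $t$ = `state['step']`, $m_t$ = `exp_avg`, $v_t$ = `exp_avg_sq`, and $v_t$ replaced by its running maximum when `amsbound=True`):

$$
\eta_l(t) = \eta^{*}\left(1 - \frac{1}{\gamma t + 1}\right), \qquad
\eta_u(t) = \eta^{*}\left(1 + \frac{1}{\gamma t}\right)
$$

$$
\theta_t = \theta_{t-1} - \operatorname{Clip}\!\left(\frac{\alpha \sqrt{1 - \beta_2^{\,t}} \,/\, (1 - \beta_1^{\,t})}{\sqrt{v_t} + \epsilon},\; \eta_l(t),\; \eta_u(t)\right) \cdot m_t
$$

Since both bounds approach $\eta^{*}$ as $t$ grows, the step smoothly tightens from an Adam-like element-wise rate toward a single SGD-like rate of roughly `final_lr`.

Below is a short usage sketch exercising the new constructor arguments. The model, data, and hyperparameter values are placeholders for illustration, not part of this commit; the behavioural difference between the two classes is how `weight_decay` is applied (`AdaBound` folds it into the gradient as an L2 penalty, `AdaBoundW` multiplicatively decays the weights before the update, i.e. decoupled weight decay):

```python
import torch
from torch import nn

from pytorch_optimizer import AdaBound, AdaBoundW

# Placeholder model and data, for illustration only.
model = nn.Linear(10, 1)
criterion = nn.MSELoss()

# L2-style weight decay folded into the gradient.
optimizer = AdaBound(
    model.parameters(),
    lr=1e-3,
    final_lr=0.1,       # rate the lower/upper bounds converge to
    gamma=1e-3,         # how quickly the bounds tighten
    weight_decay=1e-4,
    amsbound=False,     # set True for the AMSBound variant
)

# Decoupled weight decay variant: same signature, swap the class.
# optimizer = AdaBoundW(model.parameters(), lr=1e-3, final_lr=0.1, weight_decay=1e-4)

for _ in range(10):
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```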

pytorch_optimizer/adamp.py

+1 -1

@@ -21,7 +21,7 @@ class AdamP(Optimizer):
         from pytorch_optimizer import AdamP
         ...
         model = YourModel()
-        optimizer = AdaHessian(model.parameters())
+        optimizer = AdamP(model.parameters())
         ...
         for input, output in data:
             optimizer.zero_grad()

setup.py

+3 -1

@@ -34,7 +34,6 @@ def read_version() -> str:
         'Intended Audience :: Developers',
         'Intended Audience :: Science/Research',
         'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
         'Operating System :: OS Independent',

@@ -55,6 +54,9 @@ def read_version() -> str:
         'chebyshev_schedule',
         'lookahead',
         'radam',
+        'adabound',
+        'adaboundw',
+        'adahessian',
     ]
 )
