
Commit 8ef27c3

Merge pull request #22 from kozistr/feature/sam-optimizer
[Feature] Implement SAM optimizer
2 parents 7588ee2 + 2749a08 · commit 8ef27c3

4 files changed: +204, -7 lines

README.md (+45, -6)
@@ -53,7 +53,7 @@ Also, most of the captures are taken from `Ranger21` paper.
 This idea originally proposed in `NFNet (Normalized-Free Network)` paper.
 AGC (Adaptive Gradient Clipping) clips gradients based on the `unit-wise ratio of gradient norms to parameter norms`.
 
-* github : [code](https://github.com/deepmind/deepmind-research/tree/master/nfnets)
+* code : [github](https://github.com/deepmind/deepmind-research/tree/master/nfnets)
 * paper : [arXiv](https://arxiv.org/abs/2102.06171)
 
 ### Gradient Centralization (GC)
@@ -62,7 +62,7 @@ AGC (Adaptive Gradient Clipping) clips gradients based on the `unit-wise ratio o
 
 Gradient Centralization (GC) operates directly on gradients by centralizing the gradient to have zero mean.
 
-* github : [code](https://github.com/Yonghongwei/Gradient-Centralization)
+* code : [github](https://github.com/Yonghongwei/Gradient-Centralization)
 * paper : [arXiv](https://arxiv.org/abs/2004.01461)
 
 ### Softplus Transformation
@@ -83,7 +83,7 @@ By running the final variance denom through the softplus function, it lifts extr
 
 ![positive_negative_momentum](assets/positive_negative_momentum.png)
 
-* github : [code](https://github.com/zeke-xie/Positive-Negative-Momentum)
+* code : [github](https://github.com/zeke-xie/Positive-Negative-Momentum)
 * paper : [arXiv](https://arxiv.org/abs/2103.17182)
 
 ### Linear learning-rate warm-up
@@ -96,22 +96,22 @@ By running the final variance denom through the softplus function, it lifts extr
 
 ![stable_weight_decay](assets/stable_weight_decay.png)
 
-* github : [code](https://github.com/zeke-xie/stable-weight-decay-regularization)
+* code : [github](https://github.com/zeke-xie/stable-weight-decay-regularization)
 * paper : [arXiv](https://arxiv.org/abs/2011.11152)
 
 ### Explore-exploit learning-rate schedule
 
 ![explore_exploit_lr_schedule](assets/explore_exploit_lr_schedule.png)
 
-* github : [code](https://github.com/nikhil-iyer-97/wide-minima-density-hypothesis)
+* code : [github](https://github.com/nikhil-iyer-97/wide-minima-density-hypothesis)
 * paper : [arXiv](https://arxiv.org/abs/2003.03977)
 
 ### Lookahead
 
 `k` steps forward, 1 step back. `Lookahead` consisting of keeping an exponential moving average of the weights that is
 updated and substituted to the current weights every `k_{lookahead}` steps (5 by default).
 
-* github : [code](https://github.com/alphadl/lookahead.pytorch)
+* code : [github](https://github.com/alphadl/lookahead.pytorch)
 * paper : [arXiv](https://arxiv.org/abs/1907.08610v2)
 
 ### Chebyshev learning rate schedule
@@ -120,6 +120,15 @@ Acceleration via Fractal Learning Rate Schedules
 
 * paper : [arXiv](https://arxiv.org/abs/2103.01338v1)
 
+### (Adaptive) Sharpness-Aware Minimization (A/SAM)
+
+Sharpness-Aware Minimization (SAM) simultaneously minimizes loss value and loss sharpness.
+In particular, it seeks parameters that lie in neighborhoods having uniformly low loss.
+
+* SAM paper : [paper](https://arxiv.org/abs/2010.01412)
+* ASAM paper : [paper](https://arxiv.org/abs/2102.11600)
+* A/SAM code : [github](https://github.com/davda54/sam)
+
 ## Citations
 
 <details>
@@ -370,6 +379,36 @@ Acceleration via Fractal Learning Rate Schedules
 
 </details>
 
+<details>
+
+<summary>Sharpness-Aware Minimization</summary>
+
+```
+@article{foret2020sharpness,
+  title={Sharpness-aware minimization for efficiently improving generalization},
+  author={Foret, Pierre and Kleiner, Ariel and Mobahi, Hossein and Neyshabur, Behnam},
+  journal={arXiv preprint arXiv:2010.01412},
+  year={2020}
+}
+```
+
+</details>
+
+<details>
+
+<summary>Adaptive Sharpness-Aware Minimization</summary>
+
+```
+@article{kwon2021asam,
+  title={ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks},
+  author={Kwon, Jungmin and Kim, Jeongseop and Park, Hyunseo and Choi, In Kwon},
+  journal={arXiv preprint arXiv:2102.11600},
+  year={2021}
+}
+```
+
+</details>
+
 ## Author
 
 Hyeongchan Kim / [@kozistr](http://kozistr.tech/about)
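
Background for the (A/SAM) section added above, taken from the SAM paper (arXiv:2010.01412) rather than from the commit itself: SAM minimizes the worst-case loss inside a rho-ball around the weights, and in practice replaces the inner maximization with a first-order approximation.

```latex
% SAM objective (arXiv:2010.01412): minimize the worst-case loss in a rho-neighborhood
\min_{w} \; \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon)

% first-order approximation of the inner maximizer, realized by the two-step update below
\hat{\epsilon}(w) = \rho \, \frac{\nabla_w L(w)}{\lVert \nabla_w L(w) \rVert_2}
```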

pytorch_optimizer/__init__.py (+2, -1)
@@ -10,6 +10,7 @@
 from pytorch_optimizer.radam import RAdam
 from pytorch_optimizer.ranger import Ranger
 from pytorch_optimizer.ranger21 import Ranger21
+from pytorch_optimizer.sam import SAM
 from pytorch_optimizer.sgdp import SGDP
 
-__VERSION__ = '0.0.5'
+__VERSION__ = '0.0.6'
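
For orientation (not part of the diff): the new import line re-exports `SAM` from the package root next to the existing optimizers. A minimal sanity check, assuming a local install of version 0.0.6:

```python
# Minimal check of the new export; assumes pytorch-optimizer 0.0.6 is installed locally.
import pytorch_optimizer
from pytorch_optimizer import SAM, Ranger21  # Ranger21 was already exported; SAM is new in this commit

print(pytorch_optimizer.__VERSION__)   # '0.0.6'
print(SAM.__name__, Ranger21.__name__)
```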

pytorch_optimizer/sam.py (+155, new file)
@@ -0,0 +1,155 @@
from typing import Dict

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.types import (
    CLOSURE,
    DEFAULT_PARAMETERS,
    PARAM_GROUPS,
    PARAMS,
)


class SAM(Optimizer):
    """
    Reference : https://github.com/davda54/sam
    Example :
        from pytorch_optimizer import SAM
        ...
        model = YourModel()
        base_optimizer = Ranger21
        optimizer = SAM(model.parameters(), base_optimizer)
        ...
        for input, output in data:
            # first forward-backward pass
            loss = loss_function(output, model(input))  # use this loss for any training statistics
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward pass
            loss_function(output, model(input)).backward()  # make sure to do a full forward pass
            optimizer.second_step(zero_grad=True)

    Alternative Example with a single closure-based step function:
        from pytorch_optimizer import SAM
        ...
        model = YourModel()
        base_optimizer = Ranger21
        optimizer = SAM(model.parameters(), base_optimizer)

        def closure():
            loss = loss_function(output, model(input))
            loss.backward()
            return loss
        ...

        for input, output in data:
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step(closure)
            optimizer.zero_grad()
    """

    def __init__(
        self,
        params: PARAMS,
        base_optimizer,
        rho: float = 0.05,
        adaptive: bool = False,
        **kwargs,
    ):
        """(Adaptive) Sharpness-Aware Minimization
        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
        :param base_optimizer:
        :param rho: float. size of the neighborhood for computing the max loss
        :param adaptive: bool. element-wise Adaptive SAM
        :param kwargs: Dict. parameters for optimizer.
        """
        self.rho = rho

        self.check_valid_parameters()

        defaults: DEFAULT_PARAMETERS = dict(
            rho=rho, adaptive=adaptive, **kwargs
        )
        super().__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups: PARAM_GROUPS = self.base_optimizer.param_groups

    def check_valid_parameters(self):
        if 0.0 > self.rho:
            raise ValueError(f'Invalid rho : {self.rho}')

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        grad_norm = self.grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)

            for p in group['params']:
                if p.grad is None:
                    continue
                self.state[p]['old_p'] = p.data.clone()
                e_w = (
                    (torch.pow(p, 2) if group['adaptive'] else 1.0)
                    * p.grad
                    * scale.to(p)
                )
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.data = self.state[p][
                    'old_p'
                ]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None):
        if closure is None:
            raise RuntimeError(
                'Sharpness Aware Minimization requires closure, but it was not provided'
            )

        # the closure should do a full forward-backward pass
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def grad_norm(self) -> torch.Tensor:
        shared_device = self.param_groups[0]['params'][
            0
        ].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack(
                [
                    ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad)
                    .norm(p=2)
                    .to(shared_device)
                    for group in self.param_groups
                    for p in group['params']
                    if p.grad is not None
                ]
            ),
            p=2,
        )
        return norm

    def load_state_dict(self, state_dict: Dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
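
As a reading aid for `first_step` above (a standalone sketch, not part of the commit): with `adaptive=False`, the perturbation applied to each parameter is `e(w) = rho * grad / (||grad|| + 1e-12)`; `second_step` then restores the saved weights and lets the base optimizer apply the gradient taken at `w + e(w)`. The numbers below are made up for illustration:

```python
# Standalone illustration of the perturbation computed in SAM.first_step()
# for a single parameter tensor, assuming non-adaptive SAM and the default rho = 0.05.
import torch

rho = 0.05
p = torch.tensor([1.0, -2.0])    # parameter w
grad = torch.tensor([0.3, 0.4])  # gradient from the first forward-backward pass; ||grad|| = 0.5

scale = rho / (grad.norm(p=2) + 1e-12)  # rho / (||g|| + eps), matching first_step()
e_w = grad * scale                      # e(w) = rho * g / ||g||
w_adv = p + e_w                         # "climb to the local maximum w + e(w)"

print(w_adv)  # tensor([1.0300, -1.9600]); second_step() restores w and calls base_optimizer.step()
```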

setup.py (+2)
@@ -57,6 +57,8 @@ def read_version() -> str:
         'adabound',
         'adahessian',
         'adabelief',
+        'sam',
+        'asam',
     ]
 )
