
Commit b82f7c4

Merge pull request #344 from kozistr/feature/looksam-optimizer

[Feature] Implement `GCSAM` and `LookSAM` optimizers

2 parents: 111249d + 40ec30d

12 files changed: +285, -15 lines


README.md (+3 -1)

@@ -10,7 +10,7 @@

 ## The reasons why you use `pytorch-optimizer`.

-* Wide range of supported optimizers. Currently, **97 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **98 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -204,6 +204,8 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | <https://arxiv.org/abs/2501.12243> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) |
 | PSGD | *Preconditioned Stochastic Gradient Descent* | [github](https://github.com/lixilinx/psgd_torch) | <https://arxiv.org/abs/1512.04202> | [cite](https://github.com/lixilinx/psgd_torch?tab=readme-ov-file#resources) |
 | EXAdam | *The Power of Adaptive Cross-Moments* | [github](https://github.com/AhmedMostafa16/EXAdam) | <https://arxiv.org/abs/2412.20302> | [cite](https://github.com/AhmedMostafa16/EXAdam?tab=readme-ov-file#citation) |
+| GCSAM | *Gradient Centralized Sharpness Aware Minimization* | [github](https://github.com/mhassann22/GCSAM) | <https://arxiv.org/abs/2501.11584> | [cite](https://github.com/mhassann22/GCSAM?tab=readme-ov-file#citation) |
+| LookSAM | *Towards Efficient and Scalable Sharpness-Aware Minimization* | [github](https://github.com/rollovd/LookSAM) | <https://arxiv.org/abs/2203.02714> | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220302714L/exportcitation) |

 ## Supported LR Scheduler

docs/changelogs/v3.4.1.md (+8)

@@ -1,5 +1,13 @@
 ### Change Log

+### Feature
+
+* Support `GCSAM` optimizer. (#343, #344)
+    * [Gradient Centralized Sharpness Aware Minimization](https://arxiv.org/abs/2501.11584)
+    * you can use it from `SAM` optimizer by setting `use_gc=True`.
+* Support `LookSAM` optimizer. (#343, #344)
+    * [Towards Efficient and Scalable Sharpness-Aware Minimization](https://arxiv.org/abs/2203.02714)
+
 ### Update

 * Support alternative precision training for `Shampoo` optimizer. (#339)
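
For orientation, the two feature entries above translate into usage roughly like the sketch below. It follows the two-step SAM pattern from the docstrings added in this commit; `YourModel`, `data`, and `loss_function` stand in for user code, and `AdamP` is just one possible base optimizer, so treat this as an illustrative sketch rather than part of the commit.

```python
from pytorch_optimizer import AdamP, LookSAM, SAM

model = YourModel()  # placeholder for the user's model

# GCSAM: plain SAM whose ascent step centralizes each gradient
# (removes its mean component) before computing the perturbation.
optimizer = SAM(model.parameters(), AdamP, rho=0.05, use_gc=True, lr=1e-3)

# LookSAM: the SAM direction is refreshed only every `k` steps
# and reused (scaled by `alpha`) in between.
# optimizer = LookSAM(model.parameters(), AdamP, rho=0.1, k=10, alpha=0.7, lr=1e-3)

for inputs, targets in data:  # placeholder data loader
    # first forward-backward pass: move to the (approximate) worst-case weights
    loss = loss_function(model(inputs), targets)
    loss.backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass at the perturbed weights, then the base update
    loss_function(model(inputs), targets).backward()
    optimizer.second_step(zero_grad=True)
```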

docs/index.md (+3 -1)

@@ -10,7 +10,7 @@

 ## The reasons why you use `pytorch-optimizer`.

-* Wide range of supported optimizers. Currently, **97 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **98 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -204,6 +204,8 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | <https://arxiv.org/abs/2501.12243> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) |
 | PSGD | *Preconditioned Stochastic Gradient Descent* | [github](https://github.com/lixilinx/psgd_torch) | <https://arxiv.org/abs/1512.04202> | [cite](https://github.com/lixilinx/psgd_torch?tab=readme-ov-file#resources) |
 | EXAdam | *The Power of Adaptive Cross-Moments* | [github](https://github.com/AhmedMostafa16/EXAdam) | <https://arxiv.org/abs/2412.20302> | [cite](https://github.com/AhmedMostafa16/EXAdam?tab=readme-ov-file#citation) |
+| GCSAM | *Gradient Centralized Sharpness Aware Minimization* | [github](https://github.com/mhassann22/GCSAM) | <https://arxiv.org/abs/2501.11584> | [cite](https://github.com/mhassann22/GCSAM?tab=readme-ov-file#citation) |
+| LookSAM | *Towards Efficient and Scalable Sharpness-Aware Minimization* | [github](https://github.com/rollovd/LookSAM) | <https://arxiv.org/abs/2203.02714> | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220302714L/exportcitation) |

 ## Supported LR Scheduler

docs/optimizer.md (+4)

@@ -240,6 +240,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.LookSAM
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.MADGRAD
     :docstring:
     :members:

pyproject.toml (+5 -4)

@@ -17,10 +17,11 @@ keywords = [
     "DAdaptLion", "DeMo", "DiffGrad", "EXAdam", "FAdam", "FOCUS", "Fromage", "FTRL", "GaLore", "Grams", "Gravity",
     "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG",
     "Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "PSGD", "QHAdam", "QHM",
-    "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam",
-    "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW",
-    "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice",
-    "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
+    "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM", "LookSAM", "ScheduleFreeSGD", "ScheduleFreeAdamW",
+    "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM",
+    "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine",
+    "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
+    "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py (+1)

@@ -119,6 +119,7 @@
     LaProp,
     Lion,
     Lookahead,
+    LookSAM,
     Muon,
     Nero,
     NovoGrad,

pytorch_optimizer/optimizer/__init__.py (+1 -1)

@@ -79,7 +79,7 @@
 from pytorch_optimizer.optimizer.ranger import Ranger
 from pytorch_optimizer.optimizer.ranger21 import Ranger21
 from pytorch_optimizer.optimizer.rotograd import RotoGrad
-from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM
+from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM, LookSAM
 from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
 from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SGDSaI, SignSGD
 from pytorch_optimizer.optimizer.sgdp import SGDP

pytorch_optimizer/optimizer/sam.py (+195 -3)

@@ -11,6 +11,7 @@
 from pytorch_optimizer.base.exception import NoClosureError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, OPTIMIZER, PARAMETERS
+from pytorch_optimizer.optimizer.gc import centralize_gradient
 from pytorch_optimizer.optimizer.utils import disable_running_stats, enable_running_stats


@@ -58,6 +59,7 @@ def closure():
     :param base_optimizer: OPTIMIZER. base optimizer.
     :param rho: float. size of the neighborhood for computing the max loss.
     :param adaptive: bool. element-wise Adaptive SAM.
+    :param use_gc: bool. perform gradient centralization, GCSAM variant.
     :param perturb_eps: float. eps for perturbation.
     :param kwargs: Dict. parameters for optimizer.
     """
@@ -68,12 +70,14 @@ def __init__(
         base_optimizer: OPTIMIZER,
         rho: float = 0.05,
         adaptive: bool = False,
+        use_gc: bool = False,
         perturb_eps: float = 1e-12,
         **kwargs,
     ):
         self.validate_non_negative(rho, 'rho')
         self.validate_non_negative(perturb_eps, 'perturb_eps')

+        self.use_gc = use_gc
         self.perturb_eps = perturb_eps

         defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive}
@@ -92,16 +96,20 @@ def reset(self):

     @torch.no_grad()
     def first_step(self, zero_grad: bool = False):
-        grad_norm = self.grad_norm()
+        grad_norm = self.grad_norm().add_(self.perturb_eps)
         for group in self.param_groups:
-            scale = group['rho'] / (grad_norm + self.perturb_eps)
+            scale = group['rho'] / grad_norm

             for p in group['params']:
                 if p.grad is None:
                     continue

+                grad = p.grad
+                if self.use_gc:
+                    centralize_gradient(grad, gc_conv_only=False)
+
                 self.state[p]['old_p'] = p.clone()
-                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * p.grad * scale.to(p)
+                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * grad * scale.to(p)

                 p.add_(e_w)

@@ -670,3 +678,187 @@ def step(self, closure: CLOSURE = None):
         self.third_step()

         return loss
+
+
+class LookSAM(BaseOptimizer):
+    r"""Towards Efficient and Scalable Sharpness-Aware Minimization.
+
+    Example:
+    -------
+        Here's an example::
+
+            model = YourModel()
+            base_optimizer = Ranger21
+            optimizer = LookSAM(model.parameters(), base_optimizer)
+
+            for input, output in data:
+                # first forward-backward pass
+
+                loss = loss_function(output, model(input))
+                loss.backward()
+                optimizer.first_step(zero_grad=True)
+
+                # second forward-backward pass
+                # make sure to do a full forward pass
+                loss_function(output, model(input)).backward()
+                optimizer.second_step(zero_grad=True)
+
+        Alternative example with a single closure-based step function::
+
+            model = YourModel()
+            base_optimizer = Ranger21
+            optimizer = LookSAM(model.parameters(), base_optimizer)
+
+            def closure():
+                loss = loss_function(output, model(input))
+                loss.backward()
+                return loss
+
+            for input, output in data:
+                loss = loss_function(output, model(input))
+                loss.backward()
+                optimizer.step(closure)
+                optimizer.zero_grad()
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param base_optimizer: OPTIMIZER. base optimizer.
+    :param rho: float. size of the neighborhood for computing the max loss.
+    :param k: int. lookahead step.
+    :param alpha: float. lookahead blending alpha.
+    :param adaptive: bool. element-wise Adaptive SAM.
+    :param use_gc: bool. perform gradient centralization, GCSAM variant.
+    :param perturb_eps: float. eps for perturbation.
+    :param kwargs: Dict. parameters for optimizer.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        base_optimizer: OPTIMIZER,
+        rho: float = 0.1,
+        k: int = 10,
+        alpha: float = 0.7,
+        adaptive: bool = False,
+        use_gc: bool = False,
+        perturb_eps: float = 1e-12,
+        **kwargs,
+    ):
+        self.validate_non_negative(rho, 'rho')
+        self.validate_positive(k, 'k')
+        self.validate_range(alpha, 'alpha', 0.0, 1.0, '()')
+        self.validate_non_negative(perturb_eps, 'perturb_eps')
+
+        self.k = k
+        self.alpha = alpha
+        self.use_gc = use_gc
+        self.perturb_eps = perturb_eps
+
+        defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive}
+        defaults.update(kwargs)
+
+        super().__init__(params, defaults)
+
+        self.base_optimizer: Optimizer = base_optimizer(self.param_groups, **kwargs)
+        self.param_groups = self.base_optimizer.param_groups
+
+    def __str__(self) -> str:
+        return 'LookSAM'
+
+    @torch.no_grad()
+    def reset(self):
+        pass
+
+    def get_step(self):
+        return (
+            self.param_groups[0]['step']
+            if 'step' in self.param_groups[0]
+            else next(iter(self.base_optimizer.state.values()))['step'] if self.base_optimizer.state else 0
+        )
+
+    @torch.no_grad()
+    def first_step(self, zero_grad: bool = False) -> None:
+        if self.get_step() % self.k != 0:
+            return
+
+        grad_norm = self.grad_norm().add_(self.perturb_eps)
+        for group in self.param_groups:
+            scale = group['rho'] / grad_norm
+
+            for i, p in enumerate(group['params']):
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if self.use_gc:
+                    centralize_gradient(grad, gc_conv_only=False)
+
+                self.state[p]['old_p'] = p.clone()
+                self.state[f'old_grad_p_{i}']['old_grad_p'] = grad.clone()
+
+                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * grad * scale.to(p)
+                p.add_(e_w)
+
+        if zero_grad:
+            self.zero_grad()
+
+    @torch.no_grad()
+    def second_step(self, zero_grad: bool = False):
+        step = self.get_step()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                grad_norm = grad.norm(p=2)
+
+                if step % self.k == 0:
+                    old_grad_p = self.state[f'old_grad_p_{i}']['old_grad_p']
+
+                    g_grad_norm = old_grad_p / old_grad_p.norm(p=2)
+                    g_s_grad_norm = grad / grad_norm
+
+                    self.state[f'gv_{i}']['gv'] = torch.sub(
+                        grad, grad_norm * torch.sum(g_grad_norm * g_s_grad_norm) * g_grad_norm
+                    )
+                else:
+                    gv = self.state[f'gv_{i}']['gv']
+                    grad.add_(grad_norm / (gv.norm(p=2) + 1e-8) * gv, alpha=self.alpha)
+
+                p.data = self.state[p]['old_p']
+
+        self.base_optimizer.step()
+
+        if zero_grad:
+            self.zero_grad()
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None):
+        if closure is None:
+            raise NoClosureError(str(self))
+
+        self.first_step(zero_grad=True)
+
+        with torch.enable_grad():
+            closure()
+
+        self.second_step()
+
+    def grad_norm(self) -> torch.Tensor:
+        shared_device = self.param_groups[0]['params'][0].device
+        return torch.norm(
+            torch.stack(
+                [
+                    ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(shared_device)
+                    for group in self.param_groups
+                    for p in group['params']
+                    if p.grad is not None
+                ]
+            ),
+            p=2,
+        )
+
+    def load_state_dict(self, state_dict: Dict):
+        super().load_state_dict(state_dict)
+        self.base_optimizer.param_groups = self.param_groups
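
The most involved part of the new `LookSAM.second_step` above is the gradient reuse from the LookSAM paper: on every k-th step the gradient at the perturbed weights, g_s, is decomposed against the plain gradient g saved in `first_step`, and only the component orthogonal to g (called g_v) is stored; on the intermediate steps g_v is rescaled and added back onto g instead of re-running the ascent pass. A small standalone sketch of that arithmetic follows; it uses plain tensors and illustrative function names (not part of the library), with no optimizer state.

```python
import torch


def looksam_gv(g: torch.Tensor, g_s: torch.Tensor) -> torch.Tensor:
    """Component of the SAM gradient g_s orthogonal to the plain gradient g (refresh step)."""
    g_unit = g / g.norm(p=2)
    g_s_norm = g_s.norm(p=2)
    # remove the projection of g_s onto the direction of g
    return g_s - g_s_norm * torch.sum(g_unit * (g_s / g_s_norm)) * g_unit


def looksam_reuse(g: torch.Tensor, g_v: torch.Tensor, alpha: float = 0.7) -> torch.Tensor:
    """Approximate SAM gradient on a non-refresh step: g plus the rescaled stored g_v."""
    return g + alpha * (g.norm(p=2) / (g_v.norm(p=2) + 1e-8)) * g_v


g = torch.randn(16)    # gradient at the current weights
g_s = torch.randn(16)  # gradient at the SAM-perturbed weights (computed every k steps)

g_v = looksam_gv(g, g_s)          # stored as state[f'gv_{i}']['gv'] in the diff
g_approx = looksam_reuse(g, g_v)  # what the intermediate steps feed to the base optimizer
```

In the optimizer itself these quantities are kept per parameter index, under `self.state[f'old_grad_p_{i}']` and `self.state[f'gv_{i}']`.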

tests/constants.py (+1)

@@ -102,6 +102,7 @@
     'sam',
     'gsam',
     'wsam',
+    'looksam',
     'pcgrad',
     'lookahead',
     'trac',

tests/test_gradients.py (+4 -3)

@@ -2,7 +2,7 @@
 import torch

 from pytorch_optimizer.base.exception import NoSparseGradientError
-from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, OrthoGrad, load_optimizer
+from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, LookSAM, OrthoGrad, load_optimizer
 from tests.constants import NO_SPARSE_OPTIMIZERS, SPARSE_OPTIMIZERS, VALID_OPTIMIZER_NAMES
 from tests.utils import build_environment, simple_parameter, simple_sparse_parameter, sphere_loss

@@ -116,12 +116,13 @@ def test_sparse_supported(sparse_optimizer):
     optimizer.step()


-def test_sam_no_gradient():
+@pytest.mark.parametrize('optimizer', [SAM, LookSAM])
+def test_sam_no_gradient(optimizer):
     (x_data, y_data), model, loss_fn = build_environment()
     model.fc1.weight.requires_grad = False
     model.fc1.weight.grad = None

-    optimizer = SAM(model.parameters(), AdamP)
+    optimizer = optimizer(model.parameters(), AdamP)
     optimizer.zero_grad()

     loss = loss_fn(y_data, model(x_data))

tests/test_optimizer_parameters.py (+7)

@@ -7,6 +7,7 @@
     TRAC,
     WSAM,
     Lookahead,
+    LookSAM,
     OrthoGrad,
     PCGrad,
     Ranger21,
@@ -110,6 +111,12 @@ def test_wsam_methods():
     optimizer.load_state_dict(optimizer.state_dict())


+def test_looksam_methods():
+    optimizer = LookSAM([simple_parameter()], load_optimizer('adamp'))
+    optimizer.reset()
+    optimizer.load_state_dict(optimizer.state_dict())
+
+
 def test_safe_fp16_methods():
     optimizer = SafeFP16Optimizer(load_optimizer('adamp')([simple_parameter()], lr=5e-1))
     optimizer.load_state_dict(optimizer.state_dict())
