
Commit aca2ef2

Merge pull request #366 from kozistr/feature/optimizers
[Feature] AdaGC and SimplifiedAdEMAMix optimizers
2 parents 487200d + 771541a commit aca2ef2

16 files changed, +537 -223 lines changed

README.md (+107 -105)

Large diffs are not rendered by default.

docs/changelogs/v3.4.3.md renamed to docs/changelogs/v3.5.0.md (+4)

@@ -5,6 +5,10 @@
 * Support `StableSPAM` optimizer. (#358, #359)
     * [How to Train in 4-Bit More Stably than 16-Bit Adam](https://arxiv.org/abs/2502.17055?)
 * Support `ScheduleFreeWrapper`. (#334, #360)
+* Implement `AdaGC` optimizer. (#364, #366)
+    * [Improving Training Stability for Large Language Model Pretraining](https://arxiv.org/abs/2502.11034)
+* Implement `Simplified-AdEMAMix` optimizer. (#364, #366)
+    * [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants](https://arxiv.org/abs/2502.02431)

 ### Update
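Both new optimizers are exported from the package top level (see the `pytorch_optimizer/__init__.py` diff below), so they can be constructed like any other optimizer in the library. A hedged usage sketch: AdaGC's keyword arguments come from the `adagc.py` diff further down, while `SimplifiedAdEMAMix` is shown with defaults only, since its signature is not part of this diff.

```python
# Hedged usage sketch for the two optimizers introduced in this PR.
import torch

from pytorch_optimizer import AdaGC, SimplifiedAdEMAMix

model = torch.nn.Linear(10, 1)

optimizer = AdaGC(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    lambda_abs=1.0,     # absolute clipping threshold (warmup phase)
    lambda_rel=1.05,    # relative clipping threshold (after warmup)
    warmup_steps=100,
    weight_decay=1e-1,
)
# optimizer = SimplifiedAdEMAMix(model.parameters(), lr=1e-3)  # drop-in alternative; defaults only

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```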

docs/index.md (+107 -105)

Large diffs are not rendered by default.

docs/optimizer.md (+8)

@@ -28,6 +28,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.AdaGC
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.AdaHessian
     :docstring:
     :members:
@@ -92,6 +96,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.SimplifiedAdEMAMix
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.ADOPT
     :docstring:
     :members:

docs/visualization.md (+16)

@@ -22,6 +22,10 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaFactor.png)

+### AdaGC
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaGC.png)
+
 ### AdaHessian

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaHessian.png)
@@ -326,6 +330,10 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SignSGD.png)

+### SimplifiedAdEMAMix
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SimplifiedAdEMAMix.png)
+
 ### SM3

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SM3.png)
@@ -392,6 +400,10 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaFactor.png)

+### AdaGC
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaGC.png)
+
 ### AdaHessian

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaHessian.png)
@@ -696,6 +708,10 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SignSGD.png)

+### SimplifiedAdEMAMix
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SimplifiedAdEMAMix.png)
+
 ### SM3

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SM3.png)
Binary visualization images added (634 KB and 141 KB); previews are not rendered here.
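The new visualization entries above plot optimizer trajectories on the Rastrigin and Rosenbrock test functions. A rough sketch of how a comparable AdaGC trajectory on the 2-D Rosenbrock function could be traced; this is not the repository's plotting script, and the starting point, hyperparameters, and step count are illustrative.

```python
# Illustrative sketch only: trace AdaGC on the 2-D Rosenbrock function.
import torch

from pytorch_optimizer import AdaGC


def rosenbrock(xy: torch.Tensor) -> torch.Tensor:
    x, y = xy[0], xy[1]
    return (1.0 - x) ** 2 + 100.0 * (y - x ** 2) ** 2


xy = torch.tensor([-2.0, 2.0], requires_grad=True)
optimizer = AdaGC([xy], lr=1e-2, weight_decay=0.0)  # decay disabled for a cleaner trajectory

trajectory = []
for _ in range(500):
    optimizer.zero_grad()
    rosenbrock(xy).backward()
    optimizer.step()
    trajectory.append(xy.detach().clone())

# `trajectory` can then be plotted over Rosenbrock contours to produce a figure
# similar to the ones linked above.
```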

pyproject.toml (+11 -11)

@@ -11,17 +11,17 @@ repository = "https://github.com/kozistr/pytorch_optimizer"
 documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
 keywords = [
     "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
-    "AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
-    "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
-    "Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD",
-    "DAdaptLion", "DeMo", "DiffGrad", "EXAdam", "FAdam", "FOCUS", "Fromage", "FTRL", "GaLore", "Grams", "Gravity",
-    "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG",
-    "Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "PSGD", "QHAdam", "QHM",
-    "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM", "LookSAM", "ScheduleFreeSGD", "ScheduleFreeAdamW",
-    "ScheduleFreeRAdam", "SCION", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH",
-    "SPAM", "StableSPAM", "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal",
-    "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky",
-    "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
+    "AdaDelta", "AdaFactor", "AdaGC", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix",
+    "Simplified-AdEMAMix", "ADOPT", "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan",
+    "AggMo", "Aida", "AliG", "Amos", "Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam",
+    "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DeMo", "DiffGrad", "EXAdam", "FAdam", "FOCUS", "Fromage", "FTRL",
+    "GaLore", "Grams", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead",
+    "MADGRAD", "MARS", "MSVAG", "Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy",
+    "PSGD", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM", "LookSAM", "ScheduleFreeSGD",
+    "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SCION", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3",
+    "SOAP", "SopihaH", "SPAM", "StableSPAM", "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi",
+    "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky",
+    "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py (+2)

@@ -75,6 +75,7 @@
     AdaBound,
     AdaDelta,
     AdaFactor,
+    AdaGC,
     AdaHessian,
     Adai,
     Adalite,
@@ -143,6 +144,7 @@
     SGDSaI,
     Shampoo,
     SignSGD,
+    SimplifiedAdEMAMix,
     SophiaH,
     StableAdamW,
     StableSPAM,

pytorch_optimizer/optimizer/__init__.py (+4 -1)

@@ -13,6 +13,7 @@
 from pytorch_optimizer.optimizer.adabound import AdaBound
 from pytorch_optimizer.optimizer.adadelta import AdaDelta
 from pytorch_optimizer.optimizer.adafactor import AdaFactor
+from pytorch_optimizer.optimizer.adagc import AdaGC
 from pytorch_optimizer.optimizer.adahessian import AdaHessian
 from pytorch_optimizer.optimizer.adai import Adai
 from pytorch_optimizer.optimizer.adalite import Adalite
@@ -28,7 +29,7 @@
 from pytorch_optimizer.optimizer.adapnm import AdaPNM
 from pytorch_optimizer.optimizer.adashift import AdaShift
 from pytorch_optimizer.optimizer.adasmooth import AdaSmooth
-from pytorch_optimizer.optimizer.ademamix import AdEMAMix
+from pytorch_optimizer.optimizer.ademamix import AdEMAMix, SimplifiedAdEMAMix
 from pytorch_optimizer.optimizer.adopt import ADOPT
 from pytorch_optimizer.optimizer.agc import agc
 from pytorch_optimizer.optimizer.aggmo import AggMo
@@ -292,6 +293,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     AdaLOMO,
     AdamG,
     AdEMAMix,
+    SimplifiedAdEMAMix,
     SOAP,
     ADOPT,
     FTRL,
@@ -308,6 +310,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     EXAdam,
     SCION,
     StableSPAM,
+    AdaGC,
     Ranger25,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
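Because both classes are appended to `OPTIMIZER_LIST`, they can also be resolved by name through `load_optimizer`, whose registry is keyed by the lowercased class name (as the last line above shows). A small sketch, with a placeholder model:

```python
# Sketch: resolving the new optimizers by their lowercased class names.
import torch

from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(10, 1)  # placeholder model

optimizer = load_optimizer('adagc')(model.parameters(), lr=1e-3)
optimizer = load_optimizer('simplifiedademamix')(model.parameters(), lr=1e-3)
```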

pytorch_optimizer/optimizer/adagc.py (new file, +139)

@@ -0,0 +1,139 @@
import math

import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.optimizer.utils import get_global_gradient_norm


class AdaGC(BaseOptimizer):
    r"""Improving Training Stability for Large Language Model Pretraining.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param beta: float. smoothing coefficient for the EMA of the clipped gradient norm.
    :param lambda_abs: float. absolute clipping threshold to prevent unstable updates from gradient explosions.
    :param lambda_rel: float. relative clipping threshold to prevent unstable updates from gradient explosions.
    :param warmup_steps: int. warmup steps.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        beta: float = 0.98,
        lambda_abs: float = 1.0,
        lambda_rel: float = 1.05,
        warmup_steps: int = 100,
        weight_decay: float = 1e-1,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(beta, 'beta', 0.0, 1.0, '[)')
        self.validate_positive(lambda_abs, 'lambda_abs')
        self.validate_positive(lambda_rel, 'lambda_rel')
        self.validate_non_negative(warmup_steps, 'warmup_steps')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'beta': beta,
            'lambda_abs': lambda_abs,
            'lambda_rel': lambda_rel,
            'warmup_steps': warmup_steps,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaGC'

    @torch.no_grad()
    def reset(self):
        pass

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if 'exp_avg' not in state:
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)
                    state['gamma'] = torch.empty((1,), device=grad.device, dtype=grad.dtype)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                gamma = state['gamma']

                if group['step'] < group['warmup_steps']:
                    # Warmup phase: absolute clipping against the global gradient norm.
                    grad_norm = get_global_gradient_norm(self.param_groups).add_(group['eps'])

                    h_t = min(group['lambda_abs'] / grad_norm, 1.0)
                    g_hat = grad.mul(h_t)

                    g_hat_norm = g_hat.norm()

                    # gamma tracks the smallest clipped-gradient norm seen during warmup.
                    gamma.copy_(g_hat_norm if group['step'] == 1 else min(gamma, g_hat_norm))
                else:
                    # Post-warmup: relative clipping against lambda_rel * gamma,
                    # then update gamma as an EMA of the clipped norms.
                    h_t = min(group['lambda_rel'] * gamma / grad.norm(), 1.0)
                    g_hat = grad.mul(h_t)

                    gamma.mul_(group['beta']).add_(g_hat.norm(), alpha=1.0 - group['beta'])

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(g_hat, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(g_hat, g_hat, value=1.0 - beta2)

                update = (exp_avg / bias_correction1) / exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                p.add_(update, alpha=-group['lr'])

        return loss
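To make the clipping logic in `step` concrete: for the first `warmup_steps` steps the gradient is scaled by `h_t = min(lambda_abs / ||g||, 1)` using the global gradient norm, and `gamma` tracks the smallest clipped-gradient norm seen so far; after warmup, each parameter's gradient is clipped against `lambda_rel * gamma` relative to its own norm, and `gamma` becomes an EMA of the clipped norms. A toy numeric check (values are arbitrary, and the small `eps` added to the norm is omitted):

```python
# Toy check of the two clipping phases (arbitrary numbers, not from the paper).
import torch

lambda_abs, lambda_rel, beta = 1.0, 1.05, 0.98

# Warmup phase: absolute clipping against the gradient norm.
grad = torch.tensor([3.0, 4.0])           # ||grad|| = 5.0
h_t = min(lambda_abs / grad.norm(), 1.0)  # 0.2
g_hat = grad * h_t                        # rescaled to norm 1.0 (= lambda_abs)
gamma = g_hat.norm()                      # first step: gamma = ||g_hat||

# Post-warmup phase: relative clipping against lambda_rel * gamma.
grad = torch.tensor([30.0, 40.0])         # a gradient spike, ||grad|| = 50.0
h_t = min(lambda_rel * gamma / grad.norm(), 1.0)
g_hat = grad * h_t                        # norm capped at lambda_rel * gamma = 1.05
gamma = beta * gamma + (1.0 - beta) * g_hat.norm()  # EMA update of gamma
```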
