
Commit b26a78e

Merge pull request #378 from kozistr/feature/vsgd-optimizer

[Feature] Implement VSGD optimizer

2 parents: 87ab0e6 + c59c95d

17 files changed, +204 -29 lines

README.md (+2 -1)

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **105 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **106 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -213,6 +213,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | Simplified-Ademamix | *Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants* | [github](https://github.com/DepenM/Simplified-AdEMAMix/) | <https://arxiv.org/abs/2502.02431> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250202431M/exportcitation) |
 | Fira | *Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?* | [github](https://github.com/xichen-fy/Fira) | <https://arxiv.org/abs/2410.01623> | [cite](https://github.com/xichen-fy/Fira/tree/main?tab=readme-ov-file#citation) |
 | RACS & Alice | *Towards Efficient Optimizer Design for LLM via Structured Fisher Approximation with a Low-Rank Extension* | | <https://arxiv.org/pdf/2502.07752> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250207752G/exportcitation) |
+| VSGD | *Variational Stochastic Gradient Descent for Deep Neural Networks* | [github](https://github.com/generativeai-tue/vsgd) | <https://openreview.net/forum?id=xu4ATNjcdy> | [cite](https://github.com/generativeai-tue/vsgd/tree/main?tab=readme-ov-file#cite) |
 
 ## Supported LR Scheduler
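
With the table row and the 105 → 106 count change in place, the new optimizer should be discoverable through the same `get_supported_optimizers` helper that appears in the hunk header above. A minimal sketch, assuming the helper accepts a wildcard filter list (as in the README's `['adam*', 'ranger*']` example) and returns lowercase optimizer names:

```python
from pytorch_optimizer import get_supported_optimizers

# Wildcard filtering, mirroring the README's get_supported_optimizers(['adam*', 'ranger*']) example.
# Assumption: the helper returns lowercase names, so the list should contain 'vsgd'.
print(get_supported_optimizers(['vsgd*']))

# Assumption: calling without a filter returns every registered optimizer,
# so the length should match the README's updated count of 106.
print(len(get_supported_optimizers()))
```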

docs/changelogs/v3.5.2.md (+2)

@@ -6,6 +6,8 @@
 * [Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)
 * Implement `RACS` and `Alice` optimizer. (#376)
 * [Towards Efficient Optimizer Design for LLM via Structured Fisher Approximation with a Low-Rank Extension](https://arxiv.org/abs/2502.07752)
+* Implement `VSGD` optimizer. (#377, #378)
+* [Variational Stochastic Gradient Descent for Deep Neural Networks](https://openreview.net/forum?id=xu4ATNjcdy)
 
 ### Fix

docs/index.md (+2 -1)

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **105 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **106 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -213,6 +213,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | Simplified-Ademamix | *Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants* | [github](https://github.com/DepenM/Simplified-AdEMAMix/) | <https://arxiv.org/abs/2502.02431> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250202431M/exportcitation) |
 | Fira | *Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?* | [github](https://github.com/xichen-fy/Fira) | <https://arxiv.org/abs/2410.01623> | [cite](https://github.com/xichen-fy/Fira/tree/main?tab=readme-ov-file#citation) |
 | RACS & Alice | *Towards Efficient Optimizer Design for LLM via Structured Fisher Approximation with a Low-Rank Extension* | | <https://arxiv.org/pdf/2502.07752> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250207752G/exportcitation) |
+| VSGD | *Variational Stochastic Gradient Descent for Deep Neural Networks* | [github](https://github.com/generativeai-tue/vsgd) | <https://openreview.net/forum?id=xu4ATNjcdy> | [cite](https://github.com/generativeai-tue/vsgd/tree/main?tab=readme-ov-file#cite) |
 
 ## Supported LR Scheduler

docs/optimizer.md (+4)

@@ -436,6 +436,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.VSGD
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.WSAM
     :docstring:
     :members:

docs/visualization.md (+16)

@@ -274,6 +274,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_QHM.png)
 
+### RACS
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_RACS.png)
+
 ### RAdam
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_RAdam.png)
@@ -382,6 +386,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Tiger.png)
 
+### VSGD
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_VSGD.png)
+
 ### Yogi
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Yogi.png)
@@ -660,6 +668,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_QHM.png)
 
+### RACS
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_RACS.png)
+
 ### RAdam
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_RAdam.png)
@@ -768,6 +780,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Tiger.png)
 
+### VSGD
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_VSGD.png)
+
 ### Yogi
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Yogi.png)
4 binary image files changed (631 KB, 633 KB, 144 KB, 132 KB); previews not rendered.

examples/visualize_optimizers.py (+1 -1)

@@ -16,7 +16,7 @@
 
 filterwarnings('ignore', category=UserWarning)
 
-OPTIMIZERS_IGNORE = ('lomo', 'adalomo', 'demo', 'a2grad', 'muon')
+OPTIMIZERS_IGNORE = ('lomo', 'adalomo', 'demo', 'a2grad', 'muon', 'alice')
 OPTIMIZERS_MODEL_INPUT_NEEDED = ('lomo', 'adalomo', 'adammini')
 OPTIMIZERS_GRAPH_NEEDED = ('adahessian', 'sophiah')
 OPTIMIZERS_CLOSURE_NEEDED = ('alig', 'bsam')
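
The only change here is adding `'alice'` to `OPTIMIZERS_IGNORE`; how the tuple is consumed is outside this diff. A hypothetical sketch of the kind of filtering the script presumably does, assuming `get_supported_optimizers()` returns lowercase optimizer names:

```python
from pytorch_optimizer import get_supported_optimizers

OPTIMIZERS_IGNORE = ('lomo', 'adalomo', 'demo', 'a2grad', 'muon', 'alice')

# Hypothetical filtering step: skip the ignored optimizers before producing the
# Rastrigin/Rosenbrock visualizations (the real loop is not shown in this diff).
optimizer_names = [
    name for name in get_supported_optimizers() if name not in OPTIMIZERS_IGNORE
]
print(optimizer_names)
```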

poetry.lock (+42 -22): generated file, not rendered by default.

pyproject.toml (+3 -2)

@@ -20,8 +20,9 @@ keywords = [
     "PNM", "Prodigy", "PSGD", "QHAdam", "QHM", "RACS", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM",
     "LookSAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SCION", "SGDP", "Shampoo",
     "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "StableSPAM", "SRMM", "StableAdamW",
-    "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice",
-    "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
+    "SWATS", "TAM", "Tiger", "TRAC", "VSGD", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
+    "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
+    "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py (+1)

@@ -69,6 +69,7 @@
     SWATS,
     TAM,
     TRAC,
+    VSGD,
     WSAM,
     A2Grad,
     AccSGD,
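
Re-exporting `VSGD` from the package root lets it be used like any other optimizer in the library. A minimal sketch, assuming the constructor follows the library's usual `Optimizer(params, lr=...)` pattern and leaving the VSGD-specific hyperparameters (not shown in this diff) at their defaults:

```python
import torch
import torch.nn.functional as F

from pytorch_optimizer import VSGD  # top-level import enabled by the export above

model = torch.nn.Linear(10, 1)
# Assumption: standard (params, lr) signature; the lr value is illustrative only.
optimizer = VSGD(model.parameters(), lr=1e-1)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = F.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```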

pytorch_optimizer/optimizer/__init__.py (+2 -1)

@@ -90,7 +90,7 @@
     ScheduleFreeWrapper,
 )
 from pytorch_optimizer.optimizer.scion import SCION, SCIONLight
-from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SGDSaI, SignSGD
+from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, VSGD, AccSGD, SGDSaI, SignSGD
 from pytorch_optimizer.optimizer.sgdp import SGDP
 from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
 from pytorch_optimizer.optimizer.sm3 import SM3
@@ -318,6 +318,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     Fira,
     RACS,
     Alice,
+    VSGD,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
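
Since `OPTIMIZERS` keys each class in `OPTIMIZER_LIST` by its lowercased `__name__`, adding `VSGD` to the list makes it resolvable by the string `'vsgd'`. A short sketch using `load_optimizer`, which (per the annotated signature in the hunk context) appears to return the optimizer class rather than an instance:

```python
from pytorch_optimizer import load_optimizer

# 'vsgd' is the registry key produced by str(VSGD.__name__).lower().
optimizer_class = load_optimizer('vsgd')
print(optimizer_class.__name__)  # expected: 'VSGD'
```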
