
Commit 26b8b19

Merge pull request #113 from kozistr/feature/lion-optimizer
[Feature] Lion optimizer
2 parents 7bce7c2 + d62d303 commit 26b8b19

8 files changed: +115, -3 lines changed

README.rst (+4)
@@ -122,6 +122,8 @@ Supported Optimizers
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | NovoGrad | *Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks* | `github <https://github.com/lonePatient/NovoGrad-pytorch>`__ | `https://arxiv.org/abs/1905.11286 <https://arxiv.org/abs/1905.11286>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| Lion | *Symbolic Discovery of Optimization Algorithms* | `github <https://github.com/google/automl/tree/master/lion>`__ | `https://arxiv.org/abs/2302.06675 <https://arxiv.org/abs/2302.06675>`__ |
++--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 
 Useful Resources
 ----------------
@@ -323,6 +325,8 @@ Citations
 
 `NovoGrad <https://ui.adsabs.harvard.edu/abs/2019arXiv190511286G/exportcitation>`__
 
+`Lion <https://github.com/google/automl/tree/master/lion#citation>`__
+
 Citation
 --------
 
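As a quick illustration of the newly listed optimizer, here is a minimal usage sketch. It is not part of this commit: the toy model and data are illustrative, and it assumes pytorch_optimizer 2.5.0, where Lion is exported from the package root (see pytorch_optimizer/__init__.py below).

import torch
from pytorch_optimizer import Lion  # exported from the package root in this commit

model = torch.nn.Linear(10, 1)  # illustrative toy model
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)

x, y = torch.randn(32, 10), torch.randn(32, 1)  # illustrative data
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()  # sign-based Lion update
optimizer.zero_grad()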

docs/optimizer_api.rst (+8)
@@ -264,3 +264,11 @@ NovoGrad
 
 .. autoclass:: pytorch_optimizer.NovoGrad
     :members:
+
+.. _Lion:
+
+Lion
+----
+
+.. autoclass:: pytorch_optimizer.Lion
+    :members:

pyproject.toml (+1, -1)
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.4.2"
+version = "2.5.0"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/__init__.py (+2)
@@ -32,6 +32,7 @@
 from pytorch_optimizer.optimizer.gsam import GSAM
 from pytorch_optimizer.optimizer.lamb import Lamb
 from pytorch_optimizer.optimizer.lars import LARS
+from pytorch_optimizer.optimizer.lion import Lion
 from pytorch_optimizer.optimizer.lookahead import Lookahead
 from pytorch_optimizer.optimizer.madgrad import MADGRAD
 from pytorch_optimizer.optimizer.nero import Nero
@@ -98,6 +99,7 @@
     AdaFactor,
     Apollo,
     NovoGrad,
+    Lion,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
 
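Because OPTIMIZERS keys each optimizer by its lowercased class name, the new entry is also reachable by string lookup. A small sketch, assuming the load_optimizer and get_supported_optimizers helpers that the test modules below already rely on are importable from the package root:

from pytorch_optimizer import get_supported_optimizers, load_optimizer

optimizer_class = load_optimizer('lion')  # resolved via OPTIMIZERS['lion']
print(optimizer_class.__name__)           # 'Lion'
print(len(get_supported_optimizers()))    # 28 after this commit (see tests/test_load_optimizers.py below)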

pytorch_optimizer/optimizer/lion.py (+94, new file)
@@ -0,0 +1,94 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class Lion(Optimizer, BaseOptimizer):
    r"""Symbolic Discovery of Optimization Algorithms.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        betas: BETAS = (0.9, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
    ):
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.weight_decouple = weight_decouple

        self.validate_parameters()

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
        }
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_betas(self.betas)
        self.validate_weight_decay(self.weight_decay)

    @property
    def __str__(self) -> str:
        return 'Lion'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(self.__str__)

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                update = exp_avg = state['exp_avg']

                if weight_decay > 0.0:
                    if self.weight_decouple:
                        p.mul_(1.0 - group['lr'] * weight_decay)
                    else:
                        grad.add_(p, alpha=weight_decay)

                update.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2)

                p.add_(update.sign(), alpha=-group['lr'])

        return loss
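
For reference, the update rule described in the Lion paper (https://arxiv.org/abs/2302.06675) builds the signed parameter update from a beta1 interpolation of the momentum buffer and the gradient, and then refreshes the momentum buffer separately with beta2. A minimal standalone sketch of that rule with decoupled weight decay; the function name and signature are illustrative and not part of this commit:

import torch


def lion_step(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor,
              lr: float = 1e-4, beta1: float = 0.9, beta2: float = 0.99,
              weight_decay: float = 0.0) -> None:
    # decoupled (AdamW-style) weight decay
    p.mul_(1.0 - lr * weight_decay)
    # signed update from the beta1 interpolation; the out-of-place mul() leaves exp_avg untouched here
    update = exp_avg.mul(beta1).add_(grad, alpha=1.0 - beta1).sign_()
    p.add_(update, alpha=-lr)
    # momentum refresh with beta2
    exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2)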

tests/constants.py (+4)
@@ -21,6 +21,7 @@
     DiffGrad,
     DiffRGrad,
     Lamb,
+    Lion,
     Nero,
     NovoGrad,
     RAdam,
@@ -70,6 +71,7 @@
     'adams',
     'adafactor',
     'novograd',
+    'lion',
 ]
 
 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -161,6 +163,8 @@
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10),
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decay_type': 'stable', 'warmup_steps': 0}, 50),
     (NovoGrad, {'lr': 5e-1, 'weight_decay': 1e-3, 'grad_averaging': True}, 50),
+    (Lion, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
+    (Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 10),
 ]
 ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 10),
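
Each tuple in the list above pairs an optimizer class with its constructor kwargs and an iteration budget. A hypothetical sketch of how such an entry could be driven; the toy model, data, and loop are illustrative and not the repository's actual test harness:

import torch
from pytorch_optimizer import Lion


def run_config(optimizer_class, config, iterations):
    # hypothetical helper: run `iterations` optimizer steps on a toy regression problem
    torch.manual_seed(42)
    model = torch.nn.Linear(2, 1)
    optimizer = optimizer_class(model.parameters(), **config)
    x, y = torch.randn(64, 2), torch.randn(64, 1)

    for _ in range(iterations):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
    return torch.nn.functional.mse_loss(model(x), y).item()


# e.g. the first new entry: (Lion, {'lr': 5e-1, 'weight_decay': 1e-3}, 10)
final_loss = run_config(Lion, {'lr': 5e-1, 'weight_decay': 1e-3}, 10)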

tests/test_load_optimizers.py (+1, -1)
@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 27
+    assert len(get_supported_optimizers()) == 28

tests/test_optimizer_parameters.py (+1, -1)
@@ -22,7 +22,7 @@ def test_learning_rate(optimizer_name):
 
 @pytest.mark.parametrize('optimizer_name', VALID_OPTIMIZER_NAMES)
 def test_epsilon(optimizer_name):
-    if optimizer_name in ('nero', 'shampoo', 'scalableshampoo', 'dadaptsgd', 'adafactor'):
+    if optimizer_name in ('nero', 'shampoo', 'scalableshampoo', 'dadaptsgd', 'adafactor', 'lion'):
         pytest.skip(f'skip {optimizer_name} optimizer')
 
     optimizer = load_optimizer(optimizer_name)
