Commit f6baa63

Merge pull request #94 from kozistr/refactor/lr_scheduler
[Feature] Implement GSAM optimizer
2 parents 8a31b1e + a29fd3d commit f6baa63

42 files changed: +720, -38 lines changed

README.rst (+4)

@@ -112,6 +112,8 @@ Supported Optimizers
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | Adai | *Disentangling the Effects of Adaptive Learning Rate and Momentum* | `github <https://github.com/zeke-xie/adaptive-inertia-adai>`__ | `https://arxiv.org/abs/2006.15815 <https://arxiv.org/abs/2006.15815>`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| GSAM | *Surrogate Gap Guided Sharpness-Aware Minimization* | `github <https://github.com/juntang-zhuang/GSAM>`__ | `https://openreview.net/pdf?id=edONMAnhLu- <https://openreview.net/pdf?id=edONMAnhLu->`__ |
++--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 
 Useful Resources
 ----------------

@@ -303,6 +305,8 @@ Citations
 
 `Adai <https://github.com/zeke-xie/adaptive-inertia-adai#citing>`__
 
+`GSAM <https://github.com/juntang-zhuang/GSAM#citation>`__
+
 Citation
 --------
 
docs/optimizer_api.rst (+10 -2)

@@ -1,5 +1,5 @@
-Implemented Optimizers
-====================
+Optimizers
+==========
 
 .. _AdaBelief:
 
@@ -192,3 +192,11 @@ Shampoo
 
 .. autoclass:: pytorch_optimizer.Shampoo
     :members:
+
+.. _GSAM:
+
+GSAM
+----
+
+.. autoclass:: pytorch_optimizer.GSAM
+    :members:

docs/scheduler_api.rst (+2 -2)

@@ -1,5 +1,5 @@
-Implemented LR Schedulers
-=========================
+LR Schedulers
+=============
 
 .. _get_chebyshev_schedule:
 
docs/util_api.rst (+19 -2)

@@ -1,5 +1,5 @@
-Implemented utilizations
-========================
+Utilizations
+============
 
 .. _clip_grad_norm:
 
@@ -56,3 +56,20 @@ SafeFP16Optimizer
 
 .. autoclass:: pytorch_optimizer.SafeFP16Optimizer
     :members:
+
+.. _enable_running_stats:
+
+enable_running_stats
+--------------------
+
+.. autoclass:: pytorch_optimizer.enable_running_stats
+    :members:
+
+
+.. _disable_running_stats:
+
+disable_running_stats
+---------------------
+
+.. autoclass:: pytorch_optimizer.disable_running_stats
+    :members:

pyproject.toml (+7 -2)

@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.1.1"
-description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
+version = "2.2.0"
+description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
 maintainers = ["kozistr <[email protected]>"]
@@ -51,6 +51,11 @@ name = "torch"
 url = "https://download.pytorch.org/whl/cpu"
 secondary = true
 
+[tool.coverage.run]
+omit = [
+    "./pytorch_optimizer/optimizer/gsam.py",
+]
+
 [build-system]
 requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"

pytorch_optimizer/__init__.py (+9)

@@ -11,6 +11,8 @@
 )
 from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_schedule
 from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts
+from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
+from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
 from pytorch_optimizer.optimizer.adabelief import AdaBelief
 from pytorch_optimizer.optimizer.adabound import AdaBound
 from pytorch_optimizer.optimizer.adai import Adai
@@ -22,6 +24,7 @@
 from pytorch_optimizer.optimizer.diffrgrad import DiffRGrad
 from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
 from pytorch_optimizer.optimizer.gc import centralize_gradient
+from pytorch_optimizer.optimizer.gsam import GSAM
 from pytorch_optimizer.optimizer.lamb import Lamb
 from pytorch_optimizer.optimizer.lars import LARS
 from pytorch_optimizer.optimizer.lookahead import Lookahead
@@ -38,6 +41,8 @@
 from pytorch_optimizer.optimizer.shampoo import Shampoo
 from pytorch_optimizer.optimizer.utils import (
     clip_grad_norm,
+    disable_running_stats,
+    enable_running_stats,
     get_optimizer_parameters,
     matrix_power,
     normalize_gradient,
@@ -74,6 +79,10 @@
     CosineAnnealingWarmRestarts,
     CyclicLR,
     OneCycleLR,
+    CosineScheduler,
+    PolyScheduler,
+    LinearScheduler,
+    ProportionScheduler,
 ]
 LR_SCHEDULERS: Dict[str, SCHEDULER] = {
     str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST
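
After this change, GSAM and the new schedulers are importable from the package top level, and the hunk above suggests they are also collected into the LR_SCHEDULERS registry, which is keyed by lowercase class name. A minimal import sketch; GSAM's own constructor arguments live in pytorch_optimizer/optimizer/gsam.py, which is not shown in this excerpt:

# top-level names added by this commit (GSAM's signature is defined in gsam.py,
# which this diff excerpt does not show)
from pytorch_optimizer import (
    GSAM,
    CosineScheduler,
    LinearScheduler,
    PolyScheduler,
    ProportionScheduler,
)

# the registry at the bottom of __init__.py maps lowercase class names to
# scheduler classes, e.g. 'linearscheduler' -> LinearScheduler
from pytorch_optimizer import LR_SCHEDULERS
print(sorted(LR_SCHEDULERS.keys()))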

pytorch_optimizer/base/exception.py (+18)

@@ -25,3 +25,21 @@ class NoClosureError(Exception):
     def __init__(self, optimizer_name: str):
         self.message: str = f'[-] {optimizer_name} requires closure.'
         super().__init__(self.message)
+
+
+class NegativeLRError(Exception):
+    """Raised when learning rate is negative"""
+
+    def __init__(self, lr: float, lr_type: str = ''):
+        self.note: str = 'learning rate' if lr_type == '' else lr_type
+        self.message: str = f'[-] {self.note} must be positive. ({lr} > 0)'
+        super().__init__(self.message)
+
+
+class NegativeStepError(Exception):
+    """Raised when step is negative"""
+
+    def __init__(self, num_steps: int, step_type: str = ''):
+        self.note: str = 'step' if step_type == '' else step_type
+        self.message: str = f'[-] {self.note} must be positive. ({num_steps} > 0)'
+        super().__init__(self.message)

pytorch_optimizer/base/base_optimizer.py renamed to pytorch_optimizer/base/optimizer.py (+2 -1)

@@ -2,14 +2,15 @@
 
 import torch
 
+from pytorch_optimizer.base.exception import NegativeLRError
 from pytorch_optimizer.base.types import BETAS
 
 
 class BaseOptimizer(ABC):
     @staticmethod
     def validate_learning_rate(learning_rate: float):
         if learning_rate < 0.0:
-            raise ValueError(f'[-] learning rate {learning_rate} must be positive')
+            raise NegativeLRError(learning_rate)
 
     @staticmethod
     def validate_beta(beta: float):
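
The practical effect of this swap is that an invalid hyperparameter now raises a dedicated, catchable exception type with a uniform message format instead of a bare ValueError. A small sketch of how the new error surfaces, using only the validator shown in this hunk:

from pytorch_optimizer.base.exception import NegativeLRError
from pytorch_optimizer.base.optimizer import BaseOptimizer

try:
    BaseOptimizer.validate_learning_rate(-1e-3)  # negative lr is rejected
except NegativeLRError as error:
    print(error)  # [-] learning rate must be positive. (-0.001 > 0)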

pytorch_optimizer/base/scheduler.py (new file, +91)

from abc import ABC, abstractmethod
from typing import List

from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
from pytorch_optimizer.base.types import OPTIMIZER


class BaseLinearWarmupScheduler(ABC):
    r"""BaseLinearWarmupScheduler class. The LR Scheduler class based on this class has linear warmup strategy.

    :param optimizer: Optimizer. OPTIMIZER. It will set learning rate to all trainable parameters in optimizer.
    :param t_max: int. total steps to train.
    :param max_lr: float. maximum lr.
    :param min_lr: float. minimum lr.
    :param init_lr: float. initial lr.
    :param warmup_steps: int. steps to warm-up.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER,
        t_max: int,
        max_lr: float,
        min_lr: float = 0.0,
        init_lr: float = 0.0,
        warmup_steps: int = 0,
    ):
        self.optimizer = optimizer
        self.total_steps = t_max
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.init_lr = init_lr
        self.warmup_steps = warmup_steps

        self.step_t: int = 0
        self.base_lrs: List[float] = []

        # record current value in self._last_lr to match API from torch.optim.lr_scheduler
        self.last_lr: List[float] = [init_lr]

        self.validate_parameters()

        self._init_lr()

    def validate_parameters(self):
        if self.min_lr < 0:
            raise NegativeLRError(self.min_lr, 'min_lr')

        if self.max_lr < 0:
            raise NegativeLRError(self.max_lr, 'max_lr')

        if self.init_lr < 0:
            raise NegativeLRError(self.init_lr, 'init_lr')

        if self.total_steps < 0:
            raise NegativeStepError(self.total_steps, 't_max')

        if self.warmup_steps < 0:
            raise NegativeStepError(self.warmup_steps, 'warmup_steps')

    def _init_lr(self):
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def step(self):
        if self.step_t < self.warmup_steps:
            value = self.init_lr + (self.max_lr - self.init_lr) * self.step_t / self.warmup_steps
        elif self.step_t == self.warmup_steps:
            value = self.max_lr
        else:
            value = self._step()

        self.step_t += 1

        # apply the lr to optimizer if it's provided
        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = value

        self.last_lr = [value]

        return value

    @abstractmethod
    def _step(self) -> float:
        raise NotImplementedError

    def get_lr(self) -> float:
        return self.last_lr[0]
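
The base class owns the linear warmup, the step_t/last_lr bookkeeping, and the write-back of the lr into the optimizer's param_groups; concrete schedulers only implement _step() for the post-warmup phase. A minimal hypothetical subclass, purely to illustrate that contract (ConstantScheduler is not part of this commit):

from pytorch_optimizer.base.scheduler import BaseLinearWarmupScheduler


class ConstantScheduler(BaseLinearWarmupScheduler):
    # hypothetical subclass for illustration: hold max_lr once warmup is done
    def _step(self) -> float:
        return self.max_lr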

pytorch_optimizer/lr_scheduler/linear_warmup.py (new file, +36)

import math

import numpy as np

from pytorch_optimizer.base.scheduler import BaseLinearWarmupScheduler


class LinearScheduler(BaseLinearWarmupScheduler):
    def _step(self) -> float:
        return self.max_lr + (self.min_lr - self.max_lr) * (self.step_t - self.warmup_steps) / (
            self.total_steps - self.warmup_steps
        )


class CosineScheduler(BaseLinearWarmupScheduler):
    def _step(self) -> float:
        phase: float = (self.step_t - self.warmup_steps) / (self.total_steps - self.warmup_steps) * math.pi
        return self.min_lr + (self.max_lr - self.min_lr) * (np.cos(phase) + 1.0) / 2.0


class PolyScheduler(BaseLinearWarmupScheduler):
    r"""Poly LR Scheduler

    :param: poly_order: float. lr scheduler decreases with steps.
    """

    def __init__(self, poly_order: float = 0.5, **kwargs):
        self.poly_order = poly_order

        if poly_order <= 0:
            raise ValueError(f'[-] poly_order must be positive. {poly_order}')

        super().__init__(**kwargs)

    def _step(self) -> float:
        return self.min_lr + (self.max_lr - self.min_lr) * (self.step_t - self.warmup_steps) ** self.poly_order
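
A short usage sketch of the warmup schedulers above against an arbitrary optimizer; the SGD optimizer and toy parameter here are illustrative only and not part of this commit:

import torch

from pytorch_optimizer import LinearScheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)

# warm up linearly to max_lr over the first 5 steps, then decay linearly to min_lr
lr_scheduler = LinearScheduler(optimizer, t_max=20, max_lr=1e-3, min_lr=1e-6, warmup_steps=5)

for _ in range(20):
    optimizer.step()
    lr_scheduler.step()  # writes the new lr into optimizer.param_groups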

pytorch_optimizer/lr_scheduler/proportion.py (new file, +49)

from typing import List


class ProportionScheduler:
    r"""ProportionScheduler (Rho Scheduler of GSAM)
    This scheduler outputs a value that evolves proportional to lr_scheduler.

    :param lr_scheduler: learning rate scheduler.
    :param max_lr: float. maximum lr.
    :param min_lr: float. minimum lr.
    :param max_value: float. maximum of rho.
    :param min_value: float. minimum of rho.
    """

    def __init__(
        self, lr_scheduler, max_lr: float, min_lr: float = 0.0, max_value: float = 2.0, min_value: float = 2.0
    ):
        self.lr_scheduler = lr_scheduler
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_value = max_value
        self.min_value = min_value

        self.step_t: int = 0
        self.last_lr: List[float] = []

        self.step()

    def get_lr(self) -> float:
        return self.last_lr[0]

    def step(self) -> float:
        self.step_t += 1

        if hasattr(self.lr_scheduler, 'last_lr'):
            lr = self.lr_scheduler.last_lr[0]
        else:
            lr = self.lr_scheduler.optimizer.param_groups[0]['lr']

        if self.max_lr > self.min_lr:
            value = self.min_value + (self.max_value - self.min_value) * (lr - self.min_lr) / (
                self.max_lr - self.min_lr
            )
        else:
            value = self.max_value

        self.last_lr = [value]

        return value
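
ProportionScheduler reads the current learning rate from an lr scheduler and maps it linearly from [min_lr, max_lr] to [min_value, max_value], which GSAM uses as its rho schedule. A hedged sketch pairing it with the CosineScheduler above; the optimizer setup and the explicit max_value/min_value numbers are illustrative choices, not values taken from this commit:

import torch

from pytorch_optimizer import CosineScheduler, ProportionScheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)

# lr: 5-step linear warmup to 1e-3, then cosine decay to 1e-6 over 20 total steps
lr_scheduler = CosineScheduler(optimizer, t_max=20, max_lr=1e-3, min_lr=1e-6, warmup_steps=5)

# rho tracks the lr proportionally: 2.0 at max_lr, 0.01 at min_lr (illustrative values)
rho_scheduler = ProportionScheduler(
    lr_scheduler, max_lr=1e-3, min_lr=1e-6, max_value=2.0, min_value=0.01
)

for _ in range(20):
    optimizer.step()
    lr_scheduler.step()
    rho = rho_scheduler.step()  # the value GSAM would use as rho for this step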

pytorch_optimizer/optimizer/adabelief.py (+1 -1)

@@ -3,8 +3,8 @@
 import torch
 from torch.optim.optimizer import Optimizer
 
-from pytorch_optimizer.base.base_optimizer import BaseOptimizer
 from pytorch_optimizer.base.exception import NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
pytorch_optimizer/optimizer/adabound.py (+1 -1)

@@ -4,8 +4,8 @@
 import torch
 from torch.optim.optimizer import Optimizer
 
-from pytorch_optimizer.base.base_optimizer import BaseOptimizer
 from pytorch_optimizer.base.exception import NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
pytorch_optimizer/optimizer/adai.py (+1 -1)

@@ -3,8 +3,8 @@
 import torch
 from torch.optim.optimizer import Optimizer
 
-from pytorch_optimizer.base.base_optimizer import BaseOptimizer
 from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.gc import centralize_gradient
 
pytorch_optimizer/optimizer/adamp.py (+1 -1)

@@ -3,8 +3,8 @@
 import torch
 from torch.optim.optimizer import Optimizer
 
-from pytorch_optimizer.base.base_optimizer import BaseOptimizer
 from pytorch_optimizer.base.exception import NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.gc import centralize_gradient
 from pytorch_optimizer.optimizer.utils import projection
