
Commit 40ec30d

update: test cases
1 parent 81eb2d1 commit 40ec30d

File tree

2 files changed: +28 −3 lines changed


tests/test_gradients.py (+4 −3)

@@ -2,7 +2,7 @@
 import torch
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
-from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, OrthoGrad, load_optimizer
+from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, LookSAM, OrthoGrad, load_optimizer
 from tests.constants import NO_SPARSE_OPTIMIZERS, SPARSE_OPTIMIZERS, VALID_OPTIMIZER_NAMES
 from tests.utils import build_environment, simple_parameter, simple_sparse_parameter, sphere_loss
 

@@ -116,12 +116,13 @@ def test_sparse_supported(sparse_optimizer):
     optimizer.step()
 
 
-def test_sam_no_gradient():
+@pytest.mark.parametrize('optimizer', [SAM, LookSAM])
+def test_sam_no_gradient(optimizer):
     (x_data, y_data), model, loss_fn = build_environment()
     model.fc1.weight.requires_grad = False
     model.fc1.weight.grad = None
 
-    optimizer = SAM(model.parameters(), AdamP)
+    optimizer = optimizer(model.parameters(), AdamP)
     optimizer.zero_grad()
 
     loss = loss_fn(y_data, model(x_data))
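Since SAM and LookSAM share the same wrapper interface (construct with a base optimizer class, then drive with a closure), one parametrized test body can cover both. A minimal stand-alone sketch of that pattern, for context only: the test name and toy parameters below are hypothetical, while the constructor signature, zero_grad(), and step(closure) calls all appear in this commit's diffs.

import pytest
import torch
from pytorch_optimizer.optimizer import SAM, LookSAM, AdamP

@pytest.mark.parametrize('optimizer_class', [SAM, LookSAM])
def test_wrapper_handles_gradient_free_parameter(optimizer_class):
    # one parameter never receives a gradient, mimicking the frozen fc1.weight
    frozen = torch.nn.Parameter(torch.zeros(2), requires_grad=False)
    trainable = torch.nn.Parameter(torch.ones(2))

    optimizer = optimizer_class([frozen, trainable], AdamP)
    optimizer.zero_grad()

    def closure():
        loss = (trainable ** 2).sum()
        loss.backward()
        return loss

    closure()                # populate gradients for the perturbation step
    optimizer.step(closure)  # closure supplies the second forward/backward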

tests/test_optimizers.py (+24 −0)

@@ -232,6 +232,30 @@ def test_looksam_optimizer(environment):
     assert tensor_to_numpy(init_loss) > 2.0 * tensor_to_numpy(loss)
 
 
+def test_looksam_optimizer_with_closure(environment):
+    (x_data, y_data), model, loss_fn = environment
+
+    optimizer = LookSAM(model.parameters(), load_optimizer('adamw'), lr=5e-1)
+
+    def closure():
+        first_loss = loss_fn(y_data, model(x_data))
+        first_loss.backward()
+        return first_loss
+
+    init_loss, loss = np.inf, np.inf
+    for _ in range(5):
+        loss = loss_fn(y_data, model(x_data))
+        loss.backward()
+
+        optimizer.step(closure)
+        optimizer.zero_grad()
+
+        if init_loss == np.inf:
+            init_loss = loss
+
+    assert tensor_to_numpy(init_loss) > 2.0 * tensor_to_numpy(loss)
+
+
 @pytest.mark.parametrize('adaptive', ADAPTIVE_FLAGS)
 @pytest.mark.parametrize('decouple', DECOUPLE_FLAGS)
 def test_wsam_optimizer(adaptive, decouple, environment):
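The new test exercises LookSAM's closure path: the first backward populates gradients for the ascent (perturbation) step, and the closure re-runs forward/backward so the wrapper can evaluate gradients at the perturbed weights before the base optimizer's real update. A minimal training-loop sketch of that usage; the model, data, and loop length here are illustrative, not from the repository, while the LookSAM/load_optimizer calls mirror the diff above.

import torch
from pytorch_optimizer.optimizer import LookSAM, load_optimizer

model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss()
optimizer = LookSAM(model.parameters(), load_optimizer('adamw'), lr=5e-1)

x, y = torch.randn(8, 4), torch.randn(8, 1)

def closure():
    # second forward/backward, evaluated by LookSAM at the perturbed weights
    loss = loss_fn(model(x), y)
    loss.backward()
    return loss

for _ in range(5):
    loss = loss_fn(model(x), y)
    loss.backward()          # gradients for the perturbation step

    optimizer.step(closure)  # closure provides the second pass
    optimizer.zero_grad()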
