
KeyError: 'exp_avg' when using SPAM #342

Closed

liveck opened this issue Feb 7, 2025 · 7 comments

Labels: bug (Something isn't working)

liveck commented Feb 7, 2025

Describe the bug

I'm using SPAM to train my new model.
Config is as below:

from pytorch_optimizer import Lookahead, create_optimizer

optimizer = Lookahead(
    optimizer=create_optimizer(
        self.hf_model, 'SPAM', self.lr, weight_decay=0.01,
        wd_ban_list=[],
    ),
    k=5,
    alpha=0.5,
)

[screenshot: KeyError: 'exp_avg' traceback]

Env

  • GPU: 8x RTX 3090
  • Docker image: nvcr.io/nvidia/pytorch:24.05-py3
  • pytorch-optimizer version: 3.4.0
liveck added the bug label on Feb 7, 2025
kozistr (Owner) commented Feb 9, 2025

Thanks for reporting! I'll look into it later.

kozistr (Owner) commented Feb 26, 2025

@liveck Sorry for the late reply. I still haven't been able to reproduce that error with my toy example, but I think I know what the problem is. In the meantime, if you're still interested in SPAM, how about using the StableSPAM optimizer, which is a stabilized version of SPAM? Here's the implementation code.
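
For reference, a minimal sketch of what that swap might look like, assuming StableSPAM can be selected by name through create_optimizer the same way 'SPAM' is; the name string and the Lookahead-related keyword arguments here mirror the SPAM config above and are assumptions, so check them against the library docs:

import torch

from pytorch_optimizer import create_optimizer

model = torch.nn.Linear(16, 1)  # stand-in for the actual hf_model

# Hedged sketch: 'StableSPAM' in place of 'SPAM'; the Lookahead kwargs mirror
# the original config. Verify the exact supported parameters in the docs.
optimizer = create_optimizer(
    model,
    'StableSPAM',
    lr=1e-3,
    weight_decay=1e-2,
    use_lookahead=True,  # keeps the Lookahead wrapping from the original setup
    k=5,
    alpha=0.5,
)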

liveck (Author) commented Feb 26, 2025

> @liveck Sorry for the late reply. I still haven't been able to reproduce that error with my toy example, but I think I know what the problem is. In the meantime, if you're still interested in SPAM, how about using the StableSPAM optimizer, which is a stabilized version of SPAM? Here's the implementation code.

thanks

liveck (Author) commented Mar 6, 2025

> @liveck Sorry for the late reply. I still haven't been able to reproduce that error with my toy example, but I think I know what the problem is. In the meantime, if you're still interested in SPAM, how about using the StableSPAM optimizer, which is a stabilized version of SPAM? Here's the implementation code.

Is there any fix now?

StableSPAM could not reach the same accuracy as SPAM for my task.

kozistr (Owner) commented Mar 6, 2025

> Is there any fix now?
>
> StableSPAM could not reach the same accuracy as SPAM for my task.

Sorry, I've had a lot on my plate recently.

It looks like the error occurs when the mask is updated (i.e. when current_step % update_proj_gap == 0) and the exp_avg state is missing at that point, which is a little weird. I've just pushed a quick fix that potentially resolves this issue to the fix/spam-optimizer branch, dea159a. Would you mind testing with that branch and letting me know whether it resolves the problem?
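
To illustrate the suspected failure mode (a hypothetical sketch only, not the actual change in dea159a): if the per-parameter state is cleared or rebuilt when the sparse mask is refreshed, the next read of exp_avg raises exactly this KeyError unless the buffer is re-initialized first.

import torch

def toy_step(p, grad, state, step, update_proj_gap=3, beta=0.9, lr=1e-3):
    # Hypothetical sketch of the suspected issue, not the SPAM implementation.
    if step % update_proj_gap == 0:
        state.clear()  # mask refresh drops the momentum buffers (suspected cause)

    # Defensive re-initialization: without this check, state['exp_avg'] below
    # raises KeyError: 'exp_avg' on the step right after a mask update.
    if 'exp_avg' not in state:
        state['exp_avg'] = torch.zeros_like(p)

    state['exp_avg'].mul_(beta).add_(grad, alpha=1.0 - beta)
    p.add_(state['exp_avg'], alpha=-lr)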

For your reference, here's my toy example to reproduce the issue; however, it ran fine in this scenario.

from transformers import AutoModelForSequenceClassification
from tqdm import tqdm

import torch
from torch.nn import functional as F

from pytorch_optimizer import create_optimizer

bs = 8

model = AutoModelForSequenceClassification.from_pretrained('answerdotai/ModernBERT-base', trust_remote_code=True, num_labels=1)
model.train()
model.cuda()

optimizer = create_optimizer(
    model,
    'SPAM',
    lr=1e-3,
    weight_decay=1e-2,
    use_lookahead=True,
    k=5,
    alpha=0.5,
    warmup_epoch=1,
    grad_accu_steps=2,
    update_proj_gap=3,  # arbitrarily set lower value to trigger
)

input_ids = torch.arange(16, dtype=torch.long)[None, :].repeat_interleave(bs, dim=0)
attention_mask = torch.ones_like(input_ids)
y = torch.randn((bs, 1), dtype=torch.float32)

input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
y = y.cuda()

for _ in tqdm(range(25)):
    optimizer.zero_grad()

    out = model(input_ids, attention_mask).logits
    F.binary_cross_entropy_with_logits(out, y).backward()

    optimizer.step()

kozistr (Owner) commented Mar 31, 2025

@liveck Hi! I've released a new version, v3.5.0, which includes the fix. Could you please test with the latest version when you have a chance?
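
If it helps, a quick sanity check before re-running training (a small sketch; it assumes the installed distribution is named pytorch_optimizer):

# Confirm the upgrade picked up the fixed release, e.g. after
#   pip install -U pytorch_optimizer
from importlib.metadata import version

print(version('pytorch_optimizer'))  # should report 3.5.0 or newer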

liveck (Author) commented Mar 31, 2025

Although I haven't had a chance to test this yet, I plan to close this issue for now.

If I encounter the problem again in the future, I'll be sure to reach out.

Thank you for your attention to this issue.

liveck closed this as completed on Mar 31, 2025