Skip to content

[Fix] singular value in compute_power_svd() #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 6, 2023
Merged

[Fix] singular value in compute_power_svd() #104

merged 6 commits into from
Feb 6, 2023

Conversation

kozistr
Copy link
Owner

@kozistr kozistr commented Feb 6, 2023

Problem (Why?)

Fix diag() API at compute_power_svd().

Solution (What/How?)

  • fix diag() at compute_power_svd()

Benchmark

tested on i7700K + GTX 1060 6GB.

backbone: resmlp_12_distilled_224, bs: 16

x2.5 faster (A -> B)

  • AdamP: 3.73 iter / s
  • (old) Shampoo: over 25s / iter
  • Scalable Shampoo w/ Schur-Newton (block size = 256): 1.68 s / iter
  • Scalable Shampoo w/ Schur-Newton (block size = 512): 1.12 iter / s
  • Scalable Shampoo w/ SVD (block size = 256): 1.60 iter / s
  • Scalable Shampoo w/ SVD (block size = 512): 2.50 iter / s

backbone: mixer_b16_224, bs: 8

x0.5 faster (A -> B)

  • AdamP: 3.15 iter / s
  • (old) Shampoo: over 2 mins / iter
  • Scalable Shampoo w/ Schur-Newton (block size = 256): 5.33 s / iter
  • Scalable Shampoo w/ Schur-Newton (block size = 512): 2.97 s / iter
  • Scalable Shampoo w/ SVD (block size = 256): 11.26 s / iter
  • Scalable Shampoo w/ SVD (block size = 512): 21.15 s / iter

code

    from timm import create_model
    from tqdm import tqdm

    model = create_model(backbone, pretrained=False, num_classes=1)
    model.train()
    model.cuda()

    optimizer = load_optimizer('scalableshampoo')(
        model.parameters(), 
        start_preconditioning_step=1,
        block_size=block_size,
        use_svd=use_svd,
    )

    inp = torch.randn((bs, 3, 224, 224), dtype=torch.float32).cuda()
    y = torch.randn((bs, 1), dtype=torch.float32).cuda()

    for _ in tqdm(range(100)):
        optimizer.zero_grad()

        torch.nn.functional.binary_cross_entropy_with_logits(model(inp), y).backward()

        optimizer.step()

Other changes (bug fixes, small refactors)

nope

Notes

bump version to v2.4.1

@kozistr kozistr added bug Something isn't working refactoring Refactoring labels Feb 6, 2023
@kozistr kozistr self-assigned this Feb 6, 2023
@codecov
Copy link

codecov bot commented Feb 6, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.71%. Comparing base (19c3df6) to head (55dcb36).
Report is 1638 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #104   +/-   ##
=======================================
  Coverage   99.71%   99.71%           
=======================================
  Files          39       39           
  Lines        3125     3125           
=======================================
  Hits         3116     3116           
  Misses          9        9           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@kozistr kozistr changed the title [Fix] diag() in compute_power_svd() [Fix] singular value in compute_power_svd() Feb 6, 2023
@kozistr kozistr merged commit 06dce18 into main Feb 6, 2023
@kozistr kozistr deleted the fix/svd branch February 6, 2023 06:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working refactoring Refactoring size/M
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant