Skip to content

Upgrade torchmetrics #2017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 74 commits into from
Mar 1, 2023
Merged

Conversation

nik-mosaic
Copy link
Contributor

@nik-mosaic nik-mosaic commented Mar 1, 2023

What does this PR do?

Upgrades torchmetrics to support 0.11.x. This is not straightforward because torchmetrics.Accuracy() now requires a task argument (either 'binary', 'multiclass', or 'multilabel'). This PR replaces all instances of torchmetrics.Accuracy() in Composer with multiclass Accuracy, with the correct number of classes provided. torchmetrics.MatthewsCorrCoef() is also updated similarly.

Because torchmetrics.Accuracy(task='multiclass') requires a num_classes argument, we add a new, optional 'num_classes' argument to ComposerClassifier. Users must either
(1) Pass in num_classes to the ComposerClassifier, or
(2) Specify a num_classes parameter in the PyTorch network submodule.

This is because if we are to use Accuracy as our default training and validation metric, we have to know how many classes the user wants. Alternatively, if users do not want to use our default training/validation metric, they can
(3) Pass in both train_metrics and val_metrics to Composer Classifier
and we will not require num_classes since we no longer need to create a default metric.

If none of these three options are satisfied, ComposerClassifier will now raise an error.

The rest of the changes in this PR are fixing the various models, tests, notebooks, and documentation snippets that depend on ComposerClassifier and properties of its Accuracy metric, such as the class name or shape.

What issue(s) does this change relate to?

CO-1836

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Contributor

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, approving to unblock. Please paste in proof that 0.10 and 0.11 work

@mvpatel2000
Copy link
Contributor

mvpatel2000 commented Mar 1, 2023

V11 Pass
image
image

V10. Added pip install torchmetrics==0.10.0 to MCP script
image
image

@mvpatel2000 mvpatel2000 self-requested a review March 1, 2023 22:41
Copy link
Contributor

@mvpatel2000 mvpatel2000 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@nik-mosaic nik-mosaic merged commit 28bf919 into mosaicml:dev Mar 1, 2023
@nik-mosaic nik-mosaic deleted the upgrade-torchmetrics branch March 1, 2023 23:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants