Upgrade torchmetrics #2017
Conversation
LGTM, approving to unblock. Please paste in proof that torchmetrics 0.10 and 0.11 both work.
LGTM!
What does this PR do?
Upgrades torchmetrics to support 0.11.x. This is not straightforward because torchmetrics.Accuracy() now requires a task argument (either 'binary', 'multiclass', or 'multilabel'). This PR replaces all instances of torchmetrics.Accuracy() in Composer with multiclass Accuracy, with the correct number of classes provided; torchmetrics.MatthewsCorrCoef() is updated similarly.
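For reference, a minimal sketch of the API change (the num_classes=10 value and the tensors are illustrative placeholders, not taken from this PR):

```python
import torch
import torchmetrics

# torchmetrics <= 0.10.x: no task argument was required.
# acc = torchmetrics.Accuracy()

# torchmetrics 0.11.x: a task must be specified, and the multiclass task
# also needs the number of classes.
acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
mcc = torchmetrics.MatthewsCorrCoef(task="multiclass", num_classes=10)

preds = torch.randn(8, 10)           # (batch, num_classes) logits
target = torch.randint(0, 10, (8,))  # class indices
print(acc(preds, target), mcc(preds, target))
```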
Because torchmetrics.Accuracy(task='multiclass') requires a num_classes argument, we add a new, optional num_classes argument to ComposerClassifier. If we are to use Accuracy as our default training and validation metric, we have to know how many classes the user wants, so users must either
(1) pass num_classes to the ComposerClassifier, or
(2) specify a num_classes parameter in the PyTorch network submodule.
Alternatively, if users do not want to use our default training/validation metrics, they can
(3) pass both train_metrics and val_metrics to ComposerClassifier,
in which case num_classes is not required, since we no longer need to create a default metric. If none of these three options is satisfied, ComposerClassifier now raises an error. A sketch of all three options follows below.
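A rough usage sketch of the three options. The network and metric choices are placeholder assumptions, and the ComposerClassifier keyword names follow the description above rather than a verified signature:

```python
import torch.nn as nn
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy

from composer.models import ComposerClassifier

# Placeholder 10-class network; any nn.Module works here.
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

# (1) Pass num_classes directly so the default multiclass Accuracy can be built.
model = ComposerClassifier(net, num_classes=10)

# (2) Expose num_classes on the network submodule instead.
net.num_classes = 10
model = ComposerClassifier(net)

# (3) Supply explicit train/val metrics; num_classes is not needed because
#     no default metric has to be created.
metrics = MetricCollection([MulticlassAccuracy(num_classes=10)])
model = ComposerClassifier(net, train_metrics=metrics, val_metrics=metrics)
```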
The rest of the changes in this PR fix the various models, tests, notebooks, and documentation snippets that depend on ComposerClassifier and on properties of its Accuracy metric, such as the class name or shape.
What issue(s) does this change relate to?
CO-1836