Only load a full task config when load_task_extras is passed #1812


Merged
mattdangerw merged 3 commits into keras-team:master on Sep 10, 2024

Conversation

@mattdangerw (Member) commented Sep 5, 2024

This switches the way we load task configuration and "head weights" to better accommodate upcoming vision models.

For many vision models, like ResNet trained on ImageNet, or DeepLabV3, we have head weights that some users will want but others will not. We need to add an option for loading head weights.

With this change, we will be able to do the following...

```python
# Load a random head.
classifier = ImageClassifier.from_preset("resnet50", num_classes=2)
# Load the ImageNet head with 1000 output classes.
classifier = ImageClassifier.from_preset("resnet50", load_task_extras=True)
```

We could do this other ways as well, or flip the default, but I think we need to add an option to control whether to load just the backbone with a randomly initialized head, or to load the full task.

@SamanehSaadat (Member) commented Sep 5, 2024

Users can do this for NLP models without the need for the load_task flag, right?
load_task=False is like loading the backbone only:

```python
backbone = BertBackbone.from_preset("bert_preset")
```

and then they can create their own classifier with the loaded backbone:

```python
classifier = BertClassifier(backbone=backbone, num_classes=4)
```

I was wondering if it's possible to use a similar design for the vision models!
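
Concretely, the same two-step pattern applied to a vision model might look like the sketch below. The class names and import path are assumptions by analogy with the BERT example above; this illustrates the design being asked about, not an API confirmed in this PR.

```python
# Hypothetical: import path and class names assumed by analogy with the
# BertBackbone / BertClassifier example above.
from keras_hub.models import ImageClassifier, ResNetBackbone

# Load only the pretrained backbone...
backbone = ResNetBackbone.from_preset("resnet50")
# ...then build a classifier with a fresh head around it.
classifier = ImageClassifier(backbone=backbone, num_classes=4)
```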

```diff
@@ -704,35 +705,54 @@ def load_tokenizer(self, cls, **kwargs):
         tokenizer.load_preset_assets(self.preset)
         return tokenizer

-    def load_task(self, cls, load_weights, **kwargs):
+    def load_task(self, cls, load_weights, load_task, **kwargs):
```
@SamanehSaadat (Member) commented:

Isn't it a little confusing that the function is called `load_task` and the new flag is also called `load_task`?
If the `load_task` flag is false, it only loads the backbone, so it's not loading the task anymore and the function name becomes confusing 🤔

@mattdangerw (Member Author) commented:

Yeah, good call! This function isn't exposed, at least, but I'll think on it.

@mattdangerw (Member Author) commented:

Maybe `load_task_extras` would be a better name for this.

@mattdangerw (Member Author) commented:

Trying this out.
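
For illustration, a minimal sketch of how such a loader might gate on the renamed flag. The `task.json` file layout, helper structure, and error messages here are assumptions for the sake of the example, not keras-hub's actual internals:

```python
import json
import os

def load_task(preset_dir, task_cls, load_weights=True,
              load_task_extras=False, **kwargs):
    """Hypothetical sketch of the gating discussed above."""
    task_config_path = os.path.join(preset_dir, "task.json")
    if not load_task_extras:
        # Ignore any saved task config: load the (pretrained) backbone and
        # build the head fresh from kwargs (e.g. num_classes), so the head
        # weights are randomly initialized.
        backbone = task_cls.backbone_cls.from_preset(
            preset_dir, load_weights=load_weights
        )
        return task_cls(backbone=backbone, **kwargs)
    # load_task_extras=True is strict: a matching task config must exist.
    if not os.path.exists(task_config_path):
        raise ValueError(f"No task config found at {task_config_path!r}.")
    with open(task_config_path) as f:
        config = json.load(f)
    if config.get("class_name") != task_cls.__name__:
        raise ValueError("Saved task config is for a different task class.")
    return task_cls.from_config(config["config"])
```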

@mattdangerw (Member Author) commented Sep 6, 2024

@SamanehSaadat Before this PR, the only way to load a BERT model that was saved with classifier weights, without pulling in the extra classifier stuff, would be...

```python
backbone = BertBackbone.from_preset("bert_preset")
tokenizer = BertTokenizer.from_preset("bert_preset")
preprocessor = BertTextClassifierPreprocessor(tokenizer)
classifier = BertTextClassifier(backbone, preprocessor, num_classes=5)
```

I think that's too clunky to be the main way we suggest loading vision classifiers. The difference here is that the main presets we provide for vision models will have associated classifier weights that some people will want but most will not. For NLP, a base pretrained model usually has no associated classifier weights. So suddenly our awkward usage would become a lot more prominent.

Another option instead of this is to upload separate presets "with classifier head" and "without classifier head" for most basic vision backbones. Keras Applications essentially does this (separate "no top" weights and "top" weights). But that seems very redundant given that we split the backbone and task weights in our format.
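
For comparison, the Keras Applications pattern mentioned above exposes that choice through an `include_top` flag on a single class rather than two separate presets; a brief sketch using the real `keras.applications` API:

```python
from keras.applications import ResNet50

# "Top" weights: the pretrained ImageNet head with 1000 classes.
model_with_top = ResNet50(weights="imagenet", include_top=True)

# "No top" weights: backbone only, ready for a custom head.
backbone_only = ResNet50(weights="imagenet", include_top=False)
```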

@mattdangerw (Member Author) commented Sep 6, 2024

Maybe a more concise way to say this: loading a model that was pretrained as a classifier and reusing it as a different classifier is very common in computer vision. It is uncommon in NLP. We need to make it easier.

@mattdangerw changed the title from "Only load a full task config when load_task is passed" to "Only load a full task config when load_task_extras is passed" on Sep 6, 2024
@SamanehSaadat (Member) commented:

@mattdangerw Got it! Thanks for the explanation!

"top" might also be a good name then! load_with_top or load_top!

```diff
@@ -146,6 +146,7 @@ def from_preset(
         cls,
         preset,
         load_weights=True,
+        load_task_extras=False,
```
@SamanehSaadat (Member) commented:

I was wondering what's the reason behind setting the default to False here! 🤔 Since it's loading a task, loading the head weights by default might not be a bad idea!

@mattdangerw (Member Author) commented:

My biggest reason for not flipping the default is to have parallel quickstarts for vision and text models:

```python
classifier = TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
classifier.fit(text_dataset)

classifier = ImageClassifier.from_preset(
    "res_net_50",
    num_classes=2,
)
classifier.fit(image_dataset)
```

I think those snippets are important; they will be front and center. We should avoid needing to introduce more concepts there.

This also makes the arg more explicit (explicit is good, I think). Passing True will now error if a task.json does not exist, or if it is for the wrong class. But we definitely could not keep this strict behavior if we flipped the default; it would break our current quickstart!
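
To make that strict behavior concrete, a hedged sketch in the style of the snippets above (the exception type is an assumption; the comment only says passing True "will error"):

```python
# Illustration only: per the comment above, load_task_extras=True is strict.
# The exception class below is an assumption, not confirmed by this PR.
try:
    classifier = ImageClassifier.from_preset(
        "resnet50",  # assume this preset was saved without a task.json
        load_task_extras=True,
    )
except ValueError as err:
    print("Strict loading failed:", err)
```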

@mattdangerw merged commit 84a6b66 into keras-team:master on Sep 10, 2024
10 checks passed
mattdangerw added a commit to mattdangerw/keras-hub that referenced this pull request Sep 12, 2024
Let's get rid of `load_task_extras`, which is a bad and confusing name.
Instead, we will adopt some behavior that is specific to classifiers,
but a lot simpler.

```python
# Random head.
classifier = ImageClassifier.from_preset("resnet50", num_classes=2)
# Pretrained head.
classifier = ImageClassifier.from_preset("resnet50")
# Error, must provide num_classes.
classifier = TextClassifier.from_preset("bert_base_en")
```
mattdangerw added a commit that referenced this pull request Sep 12, 2024
* Take two of #1812, simpler classifier head loading

Let's get rid of `load_task_extras`, which is a bad and confusing name.
Instead, we will adopt some behavior that is specific to classifiers,
but a lot simpler.

```python
# Random head.
classifier = ImageClassifier.from_preset("resnet50", num_classes=2)
# Pretrained head.
classifier = ImageClassifier.from_preset("resnet50")
# Error, must provide num_classes.
classifier = TextClassifier.from_preset("bert_base_en")
```

* address review comments