Amulet #6

Open · wants to merge 211 commits into base: main

Commits (211)
c36d7de
Update README.md
AnanyaKumar Feb 18, 2021
a8de5ae
Update README.md
AnanyaKumar Feb 18, 2021
185149d
Update README.md
AnanyaKumar Feb 18, 2021
743a617
Merge branch 'main' of github.com:AnanyaKumar/cifar_training into main
AnanyaKumar Feb 18, 2021
e341c80
Add pretrained model, config, add assert to utils to check for class.
AnanyaKumar Feb 20, 2021
cf15b05
Add support for multiple test datasets, support for imagenet c
AnanyaKumar Feb 21, 2021
eeabce5
Test on more cifar10c datasets, add stl and cinic wrappers.
AnanyaKumar Feb 23, 2021
97d2b03
Support linear probe and max samples in testing for CINIC.
AnanyaKumar Feb 23, 2021
c001d39
Linear probe, log more often to wandb.
AnanyaKumar Feb 24, 2021
e44f4d8
Simplify baseline config.
AnanyaKumar Feb 24, 2021
902df7b
Add more configs.
AnanyaKumar Mar 1, 2021
832f60d
Add configs, and use Karan's tool.
AnanyaKumar Mar 31, 2021
884a120
Add Imnet intersect cifar 10 classes.
AnanyaKumar Mar 31, 2021
2f794c1
Add breeds.
AnanyaKumar Mar 31, 2021
1252963
Use quinine and update best stats.
AnanyaKumar Mar 31, 2021
3a267a3
Add breeds example.
AnanyaKumar Mar 31, 2021
cd02aae
Update script?
AnanyaKumar Mar 31, 2021
5d3dd0d
Reorganize configs.
AnanyaKumar Apr 1, 2021
79c1098
Fix breeds dataset.
AnanyaKumar Apr 1, 2021
b35a288
Add support for moco initialization.
AnanyaKumar Apr 1, 2021
b209fb9
Test loading moco
AnanyaKumar Apr 1, 2021
2d7d329
Add more configs for breeds.
AnanyaKumar Apr 4, 2021
f320de6
ImageNet can have less than 1000 images in some classes.
AnanyaKumar Apr 4, 2021
d4300b8
Add support for swav.
AnanyaKumar Apr 4, 2021
d303c8e
Add script for moco loading
AnanyaKumar Apr 4, 2021
1e2151f
Add uncertainty ensembling script.
AnanyaKumar Apr 4, 2021
840d6b8
Refactor baseline_train, add support for best checkpoints.
AnanyaKumar Apr 4, 2021
1824a6c
Save tsv files.
AnanyaKumar Apr 4, 2021
c2b2d7c
Move functions
AnanyaKumar Apr 4, 2021
1cf7418
Fix indent.
AnanyaKumar Apr 4, 2021
27dfc69
Fix naming bug.
AnanyaKumar Apr 4, 2021
057a895
Save best checkpoints for living17.
AnanyaKumar Apr 4, 2021
b54a27f
Add setup file
AnanyaKumar Apr 5, 2021
86f8459
Release (#2)
AnanyaKumar Apr 4, 2022
dad75b1
Update README
AnanyaKumar Apr 4, 2022
cea5bb2
Update
AnanyaKumar Apr 4, 2022
f6fb976
Update readme
AnanyaKumar Apr 4, 2022
c0f5122
Training -> fine-tuning
AnanyaKumar Apr 4, 2022
e1adf22
Update README.md
AnanyaKumar Jun 30, 2022
9543919
Add mixup, tuning middle layers, waterbirds, worst group acc
AnanyaKumar Jul 10, 2022
e62a38f
Add support for unlocking all layers and full fine-tuning after a cer…
AnanyaKumar Jul 14, 2022
e85c42d
Add print command option
AnanyaKumar Aug 9, 2022
775b4f7
Merge branch 'main' of github.com:AnanyaKumar/transfer_learning into …
AnanyaKumar Aug 9, 2022
c172641
Add no norm waterbirds.
AnanyaKumar Aug 9, 2022
2479af5
Merge branch 'main' of github.com:AnanyaKumar/transfer_learning into …
AnanyaKumar Aug 9, 2022
8aac7ee
Updates.
AnanyaKumar Aug 9, 2022
308d87a
Get waterbirds working on sandbox.
AnanyaKumar Aug 9, 2022
df51eea
Get waterbirds working on sandbox.
AnanyaKumar Aug 9, 2022
2873b32
Add waterbirds background prediction task.
AnanyaKumar Aug 10, 2022
04b4c8d
Add waterbirds background prediction task.
AnanyaKumar Aug 10, 2022
549b532
Support waterbirds background sweeps.
AnanyaKumar Aug 10, 2022
8857485
Support waterbirds background sweeps.
AnanyaKumar Aug 10, 2022
44c1685
Add get_layers to clip model.
AnanyaKumar Aug 10, 2022
64f71b6
Add get_layers to clip model.
AnanyaKumar Aug 10, 2022
d698832
Fix typo with get layer, only include classifier layer if it exists.
AnanyaKumar Aug 10, 2022
d8ab1d7
Fix typo with get layer, only include classifier layer if it exists.
AnanyaKumar Aug 10, 2022
95dc790
Add dino layers change.
AnanyaKumar Aug 11, 2022
21234e3
Add dino layers change.
AnanyaKumar Aug 11, 2022
2973555
Add bit resnet model.
AnanyaKumar Aug 11, 2022
a1f2d2a
Add bit resnet model.
AnanyaKumar Aug 11, 2022
497c1e7
bit resnet fixes.
AnanyaKumar Aug 11, 2022
53825d8
bit resnet fixes.
AnanyaKumar Aug 11, 2022
0f9cce8
Use checkpoint_path to streamline with other models.
AnanyaKumar Aug 11, 2022
8be6fb8
Use checkpoint_path to streamline with other models.
AnanyaKumar Aug 11, 2022
ffedf7e
Add bit model to sweep
AnanyaKumar Aug 11, 2022
7966d6e
Add bit model to sweep
AnanyaKumar Aug 11, 2022
3a6c053
Bit fixes.
AnanyaKumar Aug 12, 2022
1ae1067
Bit fixes.
AnanyaKumar Aug 12, 2022
9a0a529
Add support for subsampled wilds.
AnanyaKumar Aug 16, 2022
2175f69
Add support for subsampled wilds.
AnanyaKumar Aug 16, 2022
cc558f8
Add support for more optimizers.
AnanyaKumar Aug 16, 2022
fcfd209
Add support for more optimizers.
AnanyaKumar Aug 16, 2022
413fd2d
Reduce batch size for larger models.
AnanyaKumar Aug 16, 2022
6e77472
Reduce batch size for larger models.
AnanyaKumar Aug 16, 2022
1ccdb67
Do normalization inside vit.
AnanyaKumar Aug 16, 2022
3ba5517
Do normalization inside vit.
AnanyaKumar Aug 16, 2022
5ae9794
update
AnanyaKumar Aug 16, 2022
756ee70
update
AnanyaKumar Aug 16, 2022
5515b81
Add per step scheduling and gradient clipping.
AnanyaKumar Aug 19, 2022
64980e9
Add per step scheduling and gradient clipping.
AnanyaKumar Aug 19, 2022
1a53342
Make it easy to summarize results.
AnanyaKumar Aug 20, 2022
c72c172
Make it easy to summarize results.
AnanyaKumar Aug 20, 2022
79d3efb
Add linear warmup and gradient clipping.
AnanyaKumar Aug 20, 2022
b2f432d
Add linear warmup and gradient clipping.
AnanyaKumar Aug 20, 2022
a913346
Implement gradient accumulation.
AnanyaKumar Aug 22, 2022
143abfe
Implement gradient accumulation.
AnanyaKumar Aug 22, 2022
52187d2
fix loss computation issue, add warmup_frac option.
AnanyaKumar Aug 23, 2022
c322dd7
fix loss computation issue, add warmup_frac option.
AnanyaKumar Aug 23, 2022
bb6bce7
log warmup steps
AnanyaKumar Aug 23, 2022
216a0d3
log warmup steps
AnanyaKumar Aug 23, 2022
e16e29a
Fine-grained layers
AnanyaKumar Aug 24, 2022
e216a35
Fine-grained layers
AnanyaKumar Aug 24, 2022
537256e
Log warmup steps.
AnanyaKumar Aug 24, 2022
b5a61df
Log warmup steps.
AnanyaKumar Aug 24, 2022
f6ae4a3
Tune bottom k
AnanyaKumar Aug 25, 2022
792ead2
Tune bottom k
AnanyaKumar Aug 25, 2022
b7e33de
Add no train option, useful to e.g., get gradient stats across train …
AnanyaKumar Aug 25, 2022
3f97134
Add no train option, useful to e.g., get gradient stats across train …
AnanyaKumar Aug 25, 2022
908cfbc
Add option to not trian.
AnanyaKumar Aug 26, 2022
b0f127d
Add option to not trian.
AnanyaKumar Aug 26, 2022
4ceaa67
Add amlt template files.
AnanyaKumar Aug 27, 2022
7662a44
Add amlt template files.
AnanyaKumar Aug 27, 2022
cb4452b
Update amulet output dir
AnanyaKumar Aug 29, 2022
bd4912d
Update amulet output dir
AnanyaKumar Aug 29, 2022
9e030f8
Add wilds downloader, fix model layers
AnanyaKumar Sep 1, 2022
bb0ebca
Add wilds downloader, fix model layers
AnanyaKumar Sep 1, 2022
3b18056
Add wilds downloader, fix model layers
AnanyaKumar Sep 1, 2022
1565eee
Add amulet sweep capability.
AnanyaKumar Sep 2, 2022
3b5504d
Add amulet sweep capability.
AnanyaKumar Sep 2, 2022
879effa
Add amulet sweep capability.
AnanyaKumar Sep 2, 2022
0862ea5
Fix edge case with batch splitting, and add timm model
AnanyaKumar Sep 6, 2022
fd42695
Fix edge case with batch splitting, and add timm model
AnanyaKumar Sep 6, 2022
4e811d5
Fix edge case with batch splitting, and add timm model
AnanyaKumar Sep 6, 2022
33de4e4
Add option to run all hyper sweeps in one job
AnanyaKumar Sep 7, 2022
acf6136
Add option to run all hyper sweeps in one job
AnanyaKumar Sep 7, 2022
6d21738
Add option to run all hyper sweeps in one job
AnanyaKumar Sep 7, 2022
85b1194
Add amlt_data_cmd to run copy data bash script.
AnanyaKumar Sep 7, 2022
d92c08d
Add amlt_data_cmd to run copy data bash script.
AnanyaKumar Sep 7, 2022
1b037f0
Add amlt_data_cmd to run copy data bash script.
AnanyaKumar Sep 7, 2022
e04c4a3
Copy imagenet and domainnet on amulet.
AnanyaKumar Sep 7, 2022
c13e543
Copy imagenet and domainnet on amulet.
AnanyaKumar Sep 7, 2022
2876b60
Copy imagenet and domainnet on amulet.
AnanyaKumar Sep 7, 2022
dcf3805
Add modified CLIP repo inside
AnanyaKumar Sep 7, 2022
015fd34
Add modified CLIP repo inside
AnanyaKumar Sep 7, 2022
a2223bc
Add modified CLIP repo inside
AnanyaKumar Sep 7, 2022
7c66464
Add CLIP
AnanyaKumar Sep 7, 2022
ac6e038
Add CLIP
AnanyaKumar Sep 7, 2022
67d69df
Add CLIP
AnanyaKumar Sep 7, 2022
4e6b23a
Add big table script, use v100
AnanyaKumar Sep 7, 2022
433a593
Add big table script, use v100
AnanyaKumar Sep 7, 2022
698bd43
Add big table script, use v100
AnanyaKumar Sep 7, 2022
f005701
Automate results collection a bit more.
AnanyaKumar Sep 13, 2022
18b3a62
Automate results collection a bit more.
AnanyaKumar Sep 13, 2022
9f14e30
Automate results collection a bit more.
AnanyaKumar Sep 13, 2022
062253b
Log parameter norm, so we can compute gradient normalized by paramete…
AnanyaKumar Sep 13, 2022
75a7a76
Log parameter norm, so we can compute gradient normalized by paramete…
AnanyaKumar Sep 13, 2022
ae09256
Log parameter norm, so we can compute gradient normalized by paramete…
AnanyaKumar Sep 13, 2022
f23b88a
Add more wilds support
AnanyaKumar Sep 15, 2022
87201c5
Add more wilds support
AnanyaKumar Sep 15, 2022
2f833db
Add more wilds support
AnanyaKumar Sep 15, 2022
b07ff74
Add layer wise tuning.
AnanyaKumar Sep 15, 2022
d594473
Add layer wise tuning.
AnanyaKumar Sep 15, 2022
72c60d6
Add layer wise tuning.
AnanyaKumar Sep 15, 2022
bbd3306
Layerwise tuning
AnanyaKumar Sep 21, 2022
9c00b9a
Layerwise tuning
AnanyaKumar Sep 21, 2022
627836f
Layerwise tuning
AnanyaKumar Sep 21, 2022
4d884ae
Improve layer wise tuner, add arbitrary options to run script.
AnanyaKumar Sep 21, 2022
ce411ce
Improve layer wise tuner, add arbitrary options to run script.
AnanyaKumar Sep 21, 2022
46ce62f
Improve layer wise tuner, add arbitrary options to run script.
AnanyaKumar Sep 21, 2022
1c317c8
Final paper updates
AnanyaKumar Oct 7, 2022
55026d2
Final paper updates
AnanyaKumar Oct 7, 2022
312fc85
Final paper updates
AnanyaKumar Oct 7, 2022
0d84a28
Remove files
AnanyaKumar Oct 7, 2022
3ba355c
Update gitignore
AnanyaKumar Oct 7, 2022
c466dd8
Merge conflicts.
AnanyaKumar Oct 10, 2022
62fd8a8
Merge conflicts.
AnanyaKumar Oct 10, 2022
3ddc9c5
Start prep for wilds leaderboard.
AnanyaKumar Oct 10, 2022
9a095b3
Merge branch 'amulet' of github.com:AnanyaKumar/transfer_learning int…
AnanyaKumar Oct 10, 2022
630d766
Merge branch 'amulet' of github.com:AnanyaKumar/transfer_learning int…
AnanyaKumar Oct 10, 2022
233b155
Add best to avoid name conflict.
AnanyaKumar Oct 10, 2022
09b35a9
Add best to avoid name conflict.
AnanyaKumar Oct 10, 2022
558037c
Fix arg name.
AnanyaKumar Oct 10, 2022
f32a864
Fix arg name.
AnanyaKumar Oct 10, 2022
efc99a5
Debug
AnanyaKumar Oct 10, 2022
5ec49e9
Debug
AnanyaKumar Oct 10, 2022
866072e
Add high rest fmow
AnanyaKumar Oct 12, 2022
dc28590
Add high rest fmow
AnanyaKumar Oct 12, 2022
0f4b6da
address reviewer comments
AnanyaKumar Nov 6, 2022
0315a9b
address reviewer comments
AnanyaKumar Nov 6, 2022
c55423e
Updates
AnanyaKumar Nov 9, 2022
e6db406
Updates
AnanyaKumar Nov 9, 2022
a112530
Add big table data.
AnanyaKumar Nov 9, 2022
d3bd0b1
Add big table data.
AnanyaKumar Nov 9, 2022
6583268
big table format change
sgunasekar Nov 11, 2022
4668c0c
big table format change
sgunasekar Nov 11, 2022
dca4ca0
made non significant gain/loss lighter
sgunasekar Nov 11, 2022
115dcbe
made non significant gain/loss lighter
sgunasekar Nov 11, 2022
84566ae
Change name.
AnanyaKumar Nov 11, 2022
0380e11
Change name.
AnanyaKumar Nov 11, 2022
248f439
Clear outputs.
AnanyaKumar Nov 11, 2022
c502625
Clear outputs.
AnanyaKumar Nov 11, 2022
9629490
Merge branch 'main' of github.com:AnanyaKumar/transfer_learning into …
AnanyaKumar Nov 11, 2022
8981329
Update, and keep big table source.
AnanyaKumar Nov 11, 2022
2cbf388
Update, and keep big table source.
AnanyaKumar Nov 11, 2022
f3ee8b4
Add more pickle files for results.
AnanyaKumar Nov 14, 2022
873ef11
Add more pickle files for results.
AnanyaKumar Nov 14, 2022
5ab5c47
Fix typo.
AnanyaKumar Nov 14, 2022
6a7dc5c
Fix typo.
AnanyaKumar Nov 14, 2022
bb02830
Add group name, add models.
AnanyaKumar Dec 13, 2022
0b8944e
Add group name, add models.
AnanyaKumar Dec 13, 2022
29acb9b
Codalab
AnanyaKumar Dec 26, 2022
40d6bfc
Codalab
AnanyaKumar Dec 26, 2022
30dd39d
More codalab updates.
AnanyaKumar Dec 27, 2022
80bc68b
More codalab updates.
AnanyaKumar Dec 27, 2022
c5f51a0
Add cluster group.
AnanyaKumar Dec 28, 2022
3b16a3e
Restore hyper for fmow.
AnanyaKumar Dec 28, 2022
8dab1c4
Restore hyper for fmow.
AnanyaKumar Dec 28, 2022
b60d17f
Updates for surgical fine tuning checkpoints.
AnanyaKumar Dec 28, 2022
406687a
Updates for surgical fine tuning checkpoints.
AnanyaKumar Dec 28, 2022
382f167
Merge branch 'main' into amulet
AnanyaKumar Dec 28, 2022
1317bff
Merge branch 'main' into amulet
AnanyaKumar Dec 28, 2022
925100a
Use scr-sync
AnanyaKumar Dec 29, 2022
a5724e5
Use scr-sync
AnanyaKumar Dec 29, 2022
fe42a5c
Merge branch 'amulet' of github.com:AnanyaKumar/transfer_learning int…
AnanyaKumar Dec 29, 2022
cc2d7f9
Merge branch 'amulet' of github.com:AnanyaKumar/transfer_learning int…
AnanyaKumar Dec 29, 2022
3a35b16
Merge branch 'amulet' of github.com:AnanyaKumar/transfer_learning int…
AnanyaKumar Dec 30, 2022
46b21ad
Merge branch 'amulet' of github.com:AnanyaKumar/transfer_learning int…
AnanyaKumar Dec 30, 2022
e295e78
Updates
AnanyaKumar Dec 30, 2022
22bfc57
Updates
AnanyaKumar Dec 30, 2022
6c4b27d
Merge branch 'amulet' of github.com:AnanyaKumar/transfer_learning int…
AnanyaKumar Dec 30, 2022
d343311
Remove printing in summarize results.
AnanyaKumar Dec 30, 2022
Files changed
7 changes: 7 additions & 0 deletions .amltignore
@@ -0,0 +1,7 @@
**wandb**
**.tsv**
**egg-info**
**slurm_outputs**
**logs**
**amlt_results**
**amlt_configs**
6 changes: 6 additions & 0 deletions .gitignore
@@ -1,3 +1,9 @@
**.swn
**.swo
**amlt_configs/**
.amltconfig
**amlt_scripts**

**tmp.tsv**
**slurm_outputs**
**outdated**
10 changes: 10 additions & 0 deletions CLIP/.gitignore
@@ -0,0 +1,10 @@
__pycache__/
*.py[cod]
*$py.class
*.egg-info
.pytest_cache
.ipynb_checkpoints

thumbs.db
.DS_Store
.idea
Binary file added CLIP/CLIP.png
22 changes: 22 additions & 0 deletions CLIP/LICENSE
@@ -0,0 +1,22 @@
MIT License

Copyright (c) 2021 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

1 change: 1 addition & 0 deletions CLIP/MANIFEST.in
@@ -0,0 +1 @@
include clip/bpe_simple_vocab_16e6.txt.gz
193 changes: 193 additions & 0 deletions CLIP/README.md
@@ -0,0 +1,193 @@
# CLIP

[[Blog]](https://openai.com/blog/clip/) [[Paper]](https://arxiv.org/abs/2103.00020) [[Model Card]](model-card.md) [[Colab]](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb)

CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet for a given image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and GPT-3. We found that CLIP matches the performance of the original ResNet-50 on ImageNet “zero-shot”, without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision.



## Approach

![CLIP](CLIP.png)



## Usage

First, [install PyTorch 1.7.1](https://pytorch.org/get-started/locally/) (or later) and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick:

```bash
$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install ftfy regex tqdm
$ pip install git+https://github.com/openai/CLIP.git
```

Replace `cudatoolkit=11.0` above with the appropriate CUDA version on your machine or `cpuonly` when installing on a machine without a GPU.
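
For instance, a CPU-only install might look like the following (a sketch; pin versions to match your environment):

```bash
$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cpuonly
$ pip install ftfy regex tqdm
$ pip install git+https://github.com/openai/CLIP.git
```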

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]]
```


## API

The CLIP module `clip` provides the following methods:

#### `clip.available_models()`

Returns the names of the available CLIP models.
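
For example, a quick way to check which names `clip.load()` accepts (the printed list is illustrative and depends on the release):

```python
import clip

# Names accepted by clip.load(), e.g. ['RN50', 'RN101', 'RN50x4', 'ViT-B/32']
print(clip.available_models())
```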

#### `clip.load(name, device=..., jit=False)`

Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint.

The device on which to run the model can optionally be specified; the default is the first CUDA device if one is available, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model is loaded.

#### `clip.tokenize(text: Union[str, List[str]], context_length=77)`

Returns a LongTensor containing tokenized sequences of given text input(s). This can be used as the input to the model.
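
For example, a minimal sketch of the shape contract: each input string becomes one row of length `context_length`, zero-padded after the end-of-text token.

```python
import clip

# Tokenize a small batch of prompts.
tokens = clip.tokenize(["a diagram", "a dog", "a cat"])
print(tokens.shape)  # torch.Size([3, 77]): one row per string, 77 = context_length
print(tokens.dtype)  # torch.int64, i.e. a LongTensor
```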

---

The model returned by `clip.load()` supports the following methods:

#### `model.encode_image(image: Tensor)`

Given a batch of images, returns the image features encoded by the vision portion of the CLIP model.

#### `model.encode_text(text: Tensor)`

Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model.

#### `model(image: Tensor, text: Tensor)`

Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100.
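
A minimal sketch of that relationship, reusing the setup from the Usage snippet above; `logit_scale` is the learned temperature stored in the checkpoint (roughly 100 for the released models):

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    logits_per_image, _ = model(image, text)

    # Reconstruct the same scores from the two encoders by hand:
    # cosine similarity of unit-normalized features, scaled by the
    # model's learned temperature.
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    cosine = image_features @ text_features.T

print(logits_per_image)                  # logits from the forward pass
print(model.logit_scale.exp() * cosine)  # matches up to float precision
```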



## More Examples

### Zero-Shot Prediction

The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset.

```python
import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
```

The output will look like the following (the exact numbers may be slightly different depending on the compute device):

```
Top predictions:

snake: 65.31%
turtle: 12.29%
sweet_pepper: 3.83%
lizard: 1.88%
crocodile: 1.75%
```

Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs.


### Linear-probe evaluation

The example below uses [scikit-learn](https://scikit-learn.org/) to perform logistic regression on image features.

```python
import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)


def get_features(dataset):
    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.  # np.float was removed in NumPy 1.24; use the builtin float
print(f"Accuracy = {accuracy:.3f}")
```

Note that the `C` value should be determined via a hyperparameter sweep using a validation split.
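
One way such a sweep could look, reusing `train_features` and `train_labels` from the block above; the log-spaced grid and the 90/10 split are illustrative choices, not the procedure used in the paper:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Hold out 10% of the training features as a validation split.
train_x, val_x, train_y, val_y = train_test_split(
    train_features, train_labels, test_size=0.1, random_state=0)

best_c, best_acc = None, 0.0
for c in np.logspace(-3, 3, num=7):  # 1e-3, 1e-2, ..., 1e3
    clf = LogisticRegression(random_state=0, C=c, max_iter=1000)
    clf.fit(train_x, train_y)
    acc = clf.score(val_x, val_y)
    if acc > best_acc:
        best_c, best_acc = c, acc

print(f"Best C = {best_c:g} (validation accuracy {100 * best_acc:.2f}%)")
```

With `best_c` in hand, refit on the full training set and evaluate on the test features as in the block above.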
1 change: 1 addition & 0 deletions CLIP/clip/__init__.py
@@ -0,0 +1 @@
from .clip import *
Binary file added CLIP/clip/bpe_simple_vocab_16e6.txt.gz
Binary file not shown.