Skip to content

Commit d47c57e

Browse files
committed
fix tests
1 parent 0449865 commit d47c57e

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

.github/workflows/python-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
python -m pip install --upgrade pip
2929
python -m pip install pytest
3030
python -m pip install wheel
31-
python -m pip install torch==2.5.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
31+
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
3232
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
3333
- name: Test with pytest
3434
run: |

vit_pytorch/na_vit_nested_tensor.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
import torch
77
import packaging.version as pkg_version
88

9-
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
10-
print('nested tensor NaViT was tested on pytorch 2.5')
11-
129
from torch import nn, Tensor
1310
import torch.nn.functional as F
1411
from torch.nn import Module, ModuleList
@@ -152,6 +149,11 @@ def __init__(
152149
token_dropout_prob: float | None = None
153150
):
154151
super().__init__()
152+
153+
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
154+
print('nested tensor NaViT was tested on pytorch 2.5')
155+
156+
155157
image_height, image_width = pair(image_size)
156158

157159
# what percent of tokens to dropout

vit_pytorch/na_vit_nested_tensor_3d.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
import torch
77
import packaging.version as pkg_version
88

9-
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
10-
print('nested tensor NaViT was tested on pytorch 2.5')
11-
129
from torch import nn, Tensor
1310
import torch.nn.functional as F
1411
from torch.nn import Module, ModuleList
@@ -169,6 +166,9 @@ def __init__(
169166
super().__init__()
170167
image_height, image_width = pair(image_size)
171168

169+
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
170+
print('nested tensor NaViT was tested on pytorch 2.5')
171+
172172
# what percent of tokens to dropout
173173
# if int or float given, then assume constant dropout prob
174174
# otherwise accept a callback that in turn calculates dropout prob from height and width

0 commit comments

Comments
 (0)