Skip to content

Commit 05decc3

Browse files
feats: add loader in classification task datasets. (#8939)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 7b2addf commit 05decc3

17 files changed

+167
-71
lines changed

test/test_datasets.py

+37
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import torch.nn.functional as F
2525
from common_utils import combinations_grid
2626
from torchvision import datasets
27+
from torchvision.io import decode_image
2728
from torchvision.transforms import v2
2829

2930

@@ -1175,6 +1176,8 @@ class SBUTestCase(datasets_utils.ImageDatasetTestCase):
11751176
DATASET_CLASS = datasets.SBU
11761177
FEATURE_TYPES = (PIL.Image.Image, str)
11771178

1179+
SUPPORT_TV_IMAGE_DECODE = True
1180+
11781181
def inject_fake_data(self, tmpdir, config):
11791182
num_images = 3
11801183

@@ -1413,6 +1416,8 @@ class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
14131416
_IMAGES_FOLDER = "images"
14141417
_ANNOTATIONS_FILE = "captions.html"
14151418

1419+
SUPPORT_TV_IMAGE_DECODE = True
1420+
14161421
def dataset_args(self, tmpdir, config):
14171422
tmpdir = pathlib.Path(tmpdir)
14181423
root = tmpdir / self._IMAGES_FOLDER
@@ -1482,6 +1487,8 @@ class Flickr30kTestCase(Flickr8kTestCase):
14821487

14831488
_ANNOTATIONS_FILE = "captions.token"
14841489

1490+
SUPPORT_TV_IMAGE_DECODE = True
1491+
14851492
def _image_file_name(self, idx):
14861493
return f"{idx}.jpg"
14871494

@@ -1942,6 +1949,8 @@ class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
19421949
_IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"}
19431950
_file_id = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
19441951

1952+
SUPPORT_TV_IMAGE_DECODE = True
1953+
19451954
def inject_fake_data(self, tmpdir, config):
19461955
tmpdir = pathlib.Path(tmpdir) / "lfw-py"
19471956
os.makedirs(tmpdir, exist_ok=True)
@@ -1978,6 +1987,18 @@ def _create_random_id(self):
19781987
part2 = datasets_utils.create_random_string(random.randint(4, 7))
19791988
return f"{part1}_{part2}"
19801989

1990+
def test_tv_decode_image_support(self):
1991+
if not self.SUPPORT_TV_IMAGE_DECODE:
1992+
pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")
1993+
1994+
with self.create_dataset(
1995+
config=dict(
1996+
loader=decode_image,
1997+
)
1998+
) as (dataset, _):
1999+
image = dataset[0][0]
2000+
assert isinstance(image, torch.Tensor)
2001+
19812002

19822003
class LFWPairsTestCase(LFWPeopleTestCase):
19832004
DATASET_CLASS = datasets.LFWPairs
@@ -2335,6 +2356,8 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):
23352356

23362357
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
23372358

2359+
SUPPORT_TV_IMAGE_DECODE = True
2360+
23382361
def inject_fake_data(self, tmpdir: str, config):
23392362
root_folder = pathlib.Path(tmpdir) / "food-101"
23402363
image_folder = root_folder / "images"
@@ -2371,6 +2394,7 @@ class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
23712394
ADDITIONAL_CONFIGS = combinations_grid(
23722395
split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
23732396
)
2397+
SUPPORT_TV_IMAGE_DECODE = True
23742398

23752399
def inject_fake_data(self, tmpdir: str, config):
23762400
split = config["split"]
@@ -2420,6 +2444,8 @@ def inject_fake_data(self, tmpdir: str, config):
24202444
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
24212445
DATASET_CLASS = datasets.SUN397
24222446

2447+
SUPPORT_TV_IMAGE_DECODE = True
2448+
24232449
def inject_fake_data(self, tmpdir: str, config):
24242450
data_dir = pathlib.Path(tmpdir) / "SUN397"
24252451
data_dir.mkdir()
@@ -2451,6 +2477,8 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
24512477
DATASET_CLASS = datasets.DTD
24522478
FEATURE_TYPES = (PIL.Image.Image, int)
24532479

2480+
SUPPORT_TV_IMAGE_DECODE = True
2481+
24542482
ADDITIONAL_CONFIGS = combinations_grid(
24552483
split=("train", "test", "val"),
24562484
# There is no need to test the whole matrix here, since each fold is treated exactly the same
@@ -2611,6 +2639,7 @@ class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
26112639
FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))
26122640

26132641
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
2642+
SUPPORT_TV_IMAGE_DECODE = True
26142643

26152644
def inject_fake_data(self, tmpdir, config):
26162645
data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"
@@ -2708,6 +2737,8 @@ class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
27082737
REQUIRED_PACKAGES = ("scipy",)
27092738
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
27102739

2740+
SUPPORT_TV_IMAGE_DECODE = True
2741+
27112742
def inject_fake_data(self, tmpdir, config):
27122743
import scipy.io as io
27132744
from numpy.core.records import fromarrays
@@ -2782,6 +2813,8 @@ class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
27822813
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
27832814
REQUIRED_PACKAGES = ("scipy",)
27842815

2816+
SUPPORT_TV_IMAGE_DECODE = True
2817+
27852818
def inject_fake_data(self, tmpdir: str, config):
27862819
base_folder = pathlib.Path(tmpdir) / "flowers-102"
27872820

@@ -2840,6 +2873,8 @@ class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
28402873
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
28412874
SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"}
28422875

2876+
SUPPORT_TV_IMAGE_DECODE = True
2877+
28432878
def inject_fake_data(self, tmpdir: str, config):
28442879
root_folder = pathlib.Path(tmpdir) / "rendered-sst2"
28452880
image_folder = root_folder / self.SPLIT_TO_FOLDER[config["split"]]
@@ -3500,6 +3535,8 @@ class ImagenetteTestCase(datasets_utils.ImageDatasetTestCase):
35003535
DATASET_CLASS = datasets.Imagenette
35013536
ADDITIONAL_CONFIGS = combinations_grid(split=["train", "val"], size=["full", "320px", "160px"])
35023537

3538+
SUPPORT_TV_IMAGE_DECODE = True
3539+
35033540
_WNIDS = [
35043541
"n01440764",
35053542
"n02102040",

torchvision/datasets/clevr.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, List, Optional, Tuple, Union
44
from urllib.parse import urlparse
55

6-
from PIL import Image
6+
from .folder import default_loader
77

88
from .utils import download_and_extract_archive, verify_str_arg
99
from .vision import VisionDataset
@@ -18,11 +18,14 @@ class CLEVRClassification(VisionDataset):
1818
root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
1919
set to True.
2020
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
21-
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
22-
version. E.g, ``transforms.RandomCrop``
21+
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
22+
and returns a transformed version. E.g, ``transforms.RandomCrop``
2323
target_transform (callable, optional): A function/transform that takes in them target and transforms it.
2424
download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If
2525
dataset is already downloaded, it is not downloaded again.
26+
loader (callable, optional): A function to load an image given its path.
27+
By default, it uses PIL as its image loader, but users could also pass in
28+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
2629
"""
2730

2831
_URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
@@ -35,9 +38,11 @@ def __init__(
3538
transform: Optional[Callable] = None,
3639
target_transform: Optional[Callable] = None,
3740
download: bool = False,
41+
loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
3842
) -> None:
3943
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
4044
super().__init__(root, transform=transform, target_transform=target_transform)
45+
self.loader = loader
4146
self._base_folder = pathlib.Path(self.root) / "clevr"
4247
self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem
4348

@@ -65,7 +70,7 @@ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
6570
image_file = self._image_files[idx]
6671
label = self._labels[idx]
6772

68-
image = Image.open(image_file).convert("RGB")
73+
image = self.loader(image_file)
6974

7075
if self.transform:
7176
image = self.transform(image)

torchvision/datasets/country211.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ class Country211(ImageFolder):
1616
Args:
1717
root (str or ``pathlib.Path``): Root directory of the dataset.
1818
split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
19-
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
20-
version. E.g, ``transforms.RandomCrop``.
19+
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
20+
and returns a transformed version. E.g, ``transforms.RandomCrop``
2121
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
2222
download (bool, optional): If True, downloads the dataset from the internet and puts it into
2323
``root/country211/``. If dataset is already downloaded, it is not downloaded again.

torchvision/datasets/dtd.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pathlib
33
from typing import Any, Callable, Optional, Tuple, Union
44

5-
import PIL.Image
5+
from .folder import default_loader
66

77
from .utils import download_and_extract_archive, verify_str_arg
88
from .vision import VisionDataset
@@ -21,12 +21,15 @@ class DTD(VisionDataset):
2121
The partition only changes which split each image belongs to. Thus, regardless of the selected
2222
partition, combining all splits will result in all images.
2323
24-
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
25-
version. E.g, ``transforms.RandomCrop``.
24+
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
25+
and returns a transformed version. E.g, ``transforms.RandomCrop``
2626
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
2727
download (bool, optional): If True, downloads the dataset from the internet and
2828
puts it in root directory. If dataset is already downloaded, it is not
2929
downloaded again. Default is False.
30+
loader (callable, optional): A function to load an image given its path.
31+
By default, it uses PIL as its image loader, but users could also pass in
32+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
3033
"""
3134

3235
_URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
@@ -40,6 +43,7 @@ def __init__(
4043
transform: Optional[Callable] = None,
4144
target_transform: Optional[Callable] = None,
4245
download: bool = False,
46+
loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
4347
) -> None:
4448
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
4549
if not isinstance(partition, int) and not (1 <= partition <= 10):
@@ -72,13 +76,14 @@ def __init__(
7276
self.classes = sorted(set(classes))
7377
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
7478
self._labels = [self.class_to_idx[cls] for cls in classes]
79+
self.loader = loader
7580

7681
def __len__(self) -> int:
7782
return len(self._image_files)
7883

7984
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
8085
image_file, label = self._image_files[idx], self._labels[idx]
81-
image = PIL.Image.open(image_file).convert("RGB")
86+
image = self.loader(image_file)
8287

8388
if self.transform:
8489
image = self.transform(image)

torchvision/datasets/eurosat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class EuroSAT(ImageFolder):
1414
1515
Args:
1616
root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists.
17-
transform (callable, optional): A function/transform that takes in a PIL image
17+
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
1818
and returns a transformed version. E.g, ``transforms.RandomCrop``
1919
target_transform (callable, optional): A function/transform that takes in the
2020
target and transforms it.

torchvision/datasets/fgvc_aircraft.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pathlib import Path
55
from typing import Any, Callable, Optional, Tuple, Union
66

7-
import PIL.Image
7+
from .folder import default_loader
88

99
from .utils import download_and_extract_archive, verify_str_arg
1010
from .vision import VisionDataset
@@ -29,13 +29,16 @@ class FGVCAircraft(VisionDataset):
2929
``trainval`` and ``test``.
3030
annotation_level (str, optional): The annotation level, supports ``variant``,
3131
``family`` and ``manufacturer``.
32-
transform (callable, optional): A function/transform that takes in a PIL image
32+
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
3333
and returns a transformed version. E.g, ``transforms.RandomCrop``
3434
target_transform (callable, optional): A function/transform that takes in the
3535
target and transforms it.
3636
download (bool, optional): If True, downloads the dataset from the internet and
3737
puts it in root directory. If dataset is already downloaded, it is not
3838
downloaded again.
39+
loader (callable, optional): A function to load an image given its path.
40+
By default, it uses PIL as its image loader, but users could also pass in
41+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
3942
"""
4043

4144
_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
@@ -48,6 +51,7 @@ def __init__(
4851
transform: Optional[Callable] = None,
4952
target_transform: Optional[Callable] = None,
5053
download: bool = False,
54+
loader: Callable[[str], Any] = default_loader,
5155
) -> None:
5256
super().__init__(root, transform=transform, target_transform=target_transform)
5357
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
@@ -87,13 +91,14 @@ def __init__(
8791
image_name, label_name = line.strip().split(" ", 1)
8892
self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
8993
self._labels.append(self.class_to_idx[label_name])
94+
self.loader = loader
9095

9196
def __len__(self) -> int:
9297
return len(self._image_files)
9398

9499
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
95100
image_file, label = self._image_files[idx], self._labels[idx]
96-
image = PIL.Image.open(image_file).convert("RGB")
101+
image = self.loader(image_file)
97102

98103
if self.transform:
99104
image = self.transform(image)

0 commit comments

Comments
 (0)