Skip to content

Commit 433729d

Browse files
authored
Revert "Mvtec 3d (#907)"
This reverts commit 052a3ad.
1 parent b789448 commit 433729d

File tree

15 files changed

+260
-1530
lines changed

15 files changed

+260
-1530
lines changed

anomalib/data/__init__.py

-47
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313
from .base import AnomalibDataModule, AnomalibDataset
1414
from .btech import BTech
1515
from .folder import Folder
16-
from .folder_3d import Folder3D
1716
from .inference import InferenceDataset
1817
from .mvtec import MVTec
19-
from .mvtec_3d import MVTec3D
2018
from .shanghaitech import ShanghaiTech
2119
from .task_type import TaskType
2220
from .ucsd_ped import UCSDped
@@ -61,24 +59,6 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
6159
val_split_mode=config.dataset.val_split_mode,
6260
val_split_ratio=config.dataset.val_split_ratio,
6361
)
64-
elif config.dataset.format.lower() == "mvtec_3d":
65-
datamodule = MVTec3D(
66-
root=config.dataset.path,
67-
category=config.dataset.category,
68-
image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
69-
center_crop=center_crop,
70-
normalization=config.dataset.normalization,
71-
train_batch_size=config.dataset.train_batch_size,
72-
eval_batch_size=config.dataset.eval_batch_size,
73-
num_workers=config.dataset.num_workers,
74-
task=config.dataset.task,
75-
transform_config_train=config.dataset.transform_config.train,
76-
transform_config_eval=config.dataset.transform_config.eval,
77-
test_split_mode=config.dataset.test_split_mode,
78-
test_split_ratio=config.dataset.test_split_ratio,
79-
val_split_mode=config.dataset.val_split_mode,
80-
val_split_ratio=config.dataset.val_split_ratio,
81-
)
8262
elif config.dataset.format.lower() == "btech":
8363
datamodule = BTech(
8464
root=config.dataset.path,
@@ -119,31 +99,6 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
11999
val_split_mode=config.dataset.val_split_mode,
120100
val_split_ratio=config.dataset.val_split_ratio,
121101
)
122-
elif config.dataset.format.lower() == "folder_3d":
123-
datamodule = Folder3D(
124-
root=config.dataset.root,
125-
normal_dir=config.dataset.normal_dir,
126-
normal_depth_dir=config.dataset.normal_depth_dir,
127-
abnormal_dir=config.dataset.abnormal_dir,
128-
abnormal_depth_dir=config.dataset.abnormal_depth_dir,
129-
task=config.dataset.task,
130-
normal_test_dir=config.dataset.normal_test_dir,
131-
normal_test_depth_dir=config.dataset.normal_test_depth_dir,
132-
mask_dir=config.dataset.mask_dir,
133-
extensions=config.dataset.extensions,
134-
image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
135-
center_crop=center_crop,
136-
normalization=config.dataset.normalization,
137-
train_batch_size=config.dataset.train_batch_size,
138-
eval_batch_size=config.dataset.eval_batch_size,
139-
num_workers=config.dataset.num_workers,
140-
transform_config_train=config.dataset.transform_config.train,
141-
transform_config_eval=config.dataset.transform_config.eval,
142-
test_split_mode=config.dataset.test_split_mode,
143-
test_split_ratio=config.dataset.test_split_ratio,
144-
val_split_mode=config.dataset.val_split_mode,
145-
val_split_ratio=config.dataset.val_split_ratio,
146-
)
147102
elif config.dataset.format.lower() == "ucsdped":
148103
datamodule = UCSDped(
149104
root=config.dataset.path,
@@ -232,10 +187,8 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
232187
"get_datamodule",
233188
"BTech",
234189
"Folder",
235-
"Folder3D",
236190
"InferenceDataset",
237191
"MVTec",
238-
"MVTec3D",
239192
"Avenue",
240193
"UCSDped",
241194
"TaskType",

anomalib/data/base/__init__.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,6 @@
66

77
from .datamodule import AnomalibDataModule
88
from .dataset import AnomalibDataset
9-
from .depth import AnomalibDepthDataset
109
from .video import AnomalibVideoDataModule, AnomalibVideoDataset
1110

12-
__all__ = [
13-
"AnomalibDataset",
14-
"AnomalibDataModule",
15-
"AnomalibVideoDataset",
16-
"AnomalibVideoDataModule",
17-
"AnomalibDepthDataset",
18-
]
11+
__all__ = ["AnomalibDataset", "AnomalibDataModule", "AnomalibVideoDataset", "AnomalibVideoDataModule"]

anomalib/data/base/datamodule.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pandas import DataFrame
1313
from pytorch_lightning import LightningDataModule
1414
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
15-
from torch.utils.data.dataloader import DataLoader, default_collate
15+
from torch.utils.data import DataLoader, default_collate
1616

1717
from anomalib.data.base.dataset import AnomalibDataset
1818
from anomalib.data.synthetic import SyntheticAnomalyDataset

anomalib/data/base/dataset.py

-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]:
126126
elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
127127
# Only Anomalous (1) images have masks in anomaly datasets
128128
# Therefore, create empty mask for Normal (0) images.
129-
130129
if label_index == 0:
131130
mask = np.zeros(shape=image.shape[:2])
132131
else:

anomalib/data/base/depth.py

-68
This file was deleted.

anomalib/data/folder.py

+82-25
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import albumentations as A
1414
from pandas import DataFrame
15+
from torchvision.datasets.folder import IMG_EXTENSIONS
1516

1617
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
1718
from anomalib.data.task_type import TaskType
@@ -22,7 +23,74 @@
2223
ValSplitMode,
2324
get_transforms,
2425
)
25-
from anomalib.data.utils.path import _prepare_files_labels, _resolve_path
26+
27+
28+
def _check_and_convert_path(path: str | Path) -> Path:
29+
"""Check an input path, and convert to Pathlib object.
30+
31+
Args:
32+
path (str | Path): Input path.
33+
34+
Returns:
35+
Path: Output path converted to pathlib object.
36+
"""
37+
if not isinstance(path, Path):
38+
path = Path(path)
39+
return path
40+
41+
42+
def _prepare_files_labels(
43+
path: str | Path, path_type: str, extensions: tuple[str, ...] | None = None
44+
) -> tuple[list, list]:
45+
"""Return a list of filenames and list corresponding labels.
46+
47+
Args:
48+
path (str | Path): Path to the directory containing images.
49+
path_type (str): Type of images in the provided path ("normal", "abnormal", "normal_test")
50+
extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the
51+
directory.
52+
53+
Returns:
54+
List, List: Filenames of the images provided in the paths, labels of the images provided in the paths
55+
"""
56+
path = _check_and_convert_path(path)
57+
if extensions is None:
58+
extensions = IMG_EXTENSIONS
59+
60+
if isinstance(extensions, str):
61+
extensions = (extensions,)
62+
63+
filenames = [f for f in path.glob(r"**/*") if f.suffix in extensions and not f.is_dir()]
64+
if not filenames:
65+
raise RuntimeError(f"Found 0 {path_type} images in {path}")
66+
67+
labels = [path_type] * len(filenames)
68+
69+
return filenames, labels
70+
71+
72+
def _resolve_path(folder: str | Path, root: str | Path | None = None) -> Path:
73+
"""Combines root and folder and returns the absolute path.
74+
75+
This allows users to pass either a root directory and relative paths, or absolute paths to each of the
76+
image sources. This function makes sure that the samples dataframe always contains absolute paths.
77+
78+
Args:
79+
folder (str | Path | None): Folder location containing image or mask data.
80+
root (str | Path | None): Root directory for the dataset.
81+
"""
82+
folder = Path(folder)
83+
if folder.is_absolute():
84+
# path is absolute; return unmodified
85+
path = folder
86+
# path is relative.
87+
elif root is None:
88+
# no root provided; return absolute path
89+
path = folder.resolve()
90+
else:
91+
# root provided; prepend root and return absolute path
92+
path = (Path(root) / folder).resolve()
93+
return path
2694

2795

2896
def make_folder_dataset(
@@ -69,42 +137,31 @@ def make_folder_dataset(
69137
if normal_test_dir:
70138
dirs = {**dirs, **{"normal_test": normal_test_dir}}
71139

72-
if mask_dir:
73-
dirs = {**dirs, **{"mask_dir": mask_dir}}
74-
75140
for dir_type, path in dirs.items():
76141
filename, label = _prepare_files_labels(path, dir_type, extensions)
77142
filenames += filename
78143
labels += label
79144

80-
samples = DataFrame({"image_path": filenames, "label": labels})
81-
samples = samples.sort_values(by="image_path", ignore_index=True)
145+
samples = DataFrame({"image_path": filenames, "label": labels, "mask_path": ""})
82146

83147
# Create label index for normal (0) and abnormal (1) images.
84148
samples.loc[(samples.label == "normal") | (samples.label == "normal_test"), "label_index"] = 0
85149
samples.loc[(samples.label == "abnormal"), "label_index"] = 1
86-
samples.label_index = samples.label_index.astype("Int64")
150+
samples.label_index = samples.label_index.astype(int)
87151

88152
# If a path to mask is provided, add it to the sample dataframe.
89153
if mask_dir is not None:
90-
samples.loc[samples.label == "abnormal", "mask_path"] = samples.loc[
91-
samples.label == "mask_dir"
92-
].image_path.values
93-
samples = samples.astype({"mask_path": "str"})
94-
95-
# make sure all every rgb image has a corresponding mask image.
96-
assert (
97-
samples.loc[samples.label_index == 1]
98-
.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
99-
.all()
100-
), "Mismatch between anomalous images and mask images. Make sure the mask files \
101-
folder follow the same naming convention as the anomalous images in the dataset \
102-
(e.g. image: '000.png', mask: '000.png')."
103-
104-
# remove all the rows with temporal image samples that have already been assigned
105-
samples = samples.loc[
106-
(samples.label == "normal") | (samples.label == "abnormal") | (samples.label == "normal_test")
107-
]
154+
mask_dir = _check_and_convert_path(mask_dir)
155+
for index, row in samples.iterrows():
156+
if row.label_index == 1:
157+
rel_image_path = row.image_path.relative_to(abnormal_dir)
158+
samples.loc[index, "mask_path"] = str(mask_dir / rel_image_path)
159+
160+
# make sure all the files exist
161+
# samples.image_path does NOT need to be checked because we build the df based on that
162+
assert samples.mask_path.apply(
163+
lambda x: Path(x).exists() if x != "" else True
164+
).all(), f"missing mask files, mask_dir={mask_dir}"
108165

109166
# Ensure the pathlib objects are converted to str.
110167
# This is because torch dataloader doesn't like pathlib.

0 commit comments

Comments
 (0)