|
12 | 12 |
|
13 | 13 | import albumentations as A
|
14 | 14 | from pandas import DataFrame
|
| 15 | +from torchvision.datasets.folder import IMG_EXTENSIONS |
15 | 16 |
|
16 | 17 | from anomalib.data.base import AnomalibDataModule, AnomalibDataset
|
17 | 18 | from anomalib.data.task_type import TaskType
|
|
22 | 23 | ValSplitMode,
|
23 | 24 | get_transforms,
|
24 | 25 | )
|
25 |
| -from anomalib.data.utils.path import _prepare_files_labels, _resolve_path |
| 26 | + |
| 27 | + |
| 28 | +def _check_and_convert_path(path: str | Path) -> Path: |
| 29 | + """Check an input path, and convert to Pathlib object. |
| 30 | +
|
| 31 | + Args: |
| 32 | + path (str | Path): Input path. |
| 33 | +
|
| 34 | + Returns: |
| 35 | + Path: Output path converted to pathlib object. |
| 36 | + """ |
| 37 | + if not isinstance(path, Path): |
| 38 | + path = Path(path) |
| 39 | + return path |
| 40 | + |
| 41 | + |
| 42 | +def _prepare_files_labels( |
| 43 | + path: str | Path, path_type: str, extensions: tuple[str, ...] | None = None |
| 44 | +) -> tuple[list, list]: |
| 45 | + """Return a list of filenames and list corresponding labels. |
| 46 | +
|
| 47 | + Args: |
| 48 | + path (str | Path): Path to the directory containing images. |
| 49 | + path_type (str): Type of images in the provided path ("normal", "abnormal", "normal_test") |
| 50 | + extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the |
| 51 | + directory. |
| 52 | +
|
| 53 | + Returns: |
| 54 | + List, List: Filenames of the images provided in the paths, labels of the images provided in the paths |
| 55 | + """ |
| 56 | + path = _check_and_convert_path(path) |
| 57 | + if extensions is None: |
| 58 | + extensions = IMG_EXTENSIONS |
| 59 | + |
| 60 | + if isinstance(extensions, str): |
| 61 | + extensions = (extensions,) |
| 62 | + |
| 63 | + filenames = [f for f in path.glob(r"**/*") if f.suffix in extensions and not f.is_dir()] |
| 64 | + if not filenames: |
| 65 | + raise RuntimeError(f"Found 0 {path_type} images in {path}") |
| 66 | + |
| 67 | + labels = [path_type] * len(filenames) |
| 68 | + |
| 69 | + return filenames, labels |
| 70 | + |
| 71 | + |
| 72 | +def _resolve_path(folder: str | Path, root: str | Path | None = None) -> Path: |
| 73 | + """Combines root and folder and returns the absolute path. |
| 74 | +
|
| 75 | + This allows users to pass either a root directory and relative paths, or absolute paths to each of the |
| 76 | + image sources. This function makes sure that the samples dataframe always contains absolute paths. |
| 77 | +
|
| 78 | + Args: |
| 79 | + folder (str | Path | None): Folder location containing image or mask data. |
| 80 | + root (str | Path | None): Root directory for the dataset. |
| 81 | + """ |
| 82 | + folder = Path(folder) |
| 83 | + if folder.is_absolute(): |
| 84 | + # path is absolute; return unmodified |
| 85 | + path = folder |
| 86 | + # path is relative. |
| 87 | + elif root is None: |
| 88 | + # no root provided; return absolute path |
| 89 | + path = folder.resolve() |
| 90 | + else: |
| 91 | + # root provided; prepend root and return absolute path |
| 92 | + path = (Path(root) / folder).resolve() |
| 93 | + return path |
26 | 94 |
|
27 | 95 |
|
28 | 96 | def make_folder_dataset(
|
@@ -69,42 +137,31 @@ def make_folder_dataset(
|
69 | 137 | if normal_test_dir:
|
70 | 138 | dirs = {**dirs, **{"normal_test": normal_test_dir}}
|
71 | 139 |
|
72 |
| - if mask_dir: |
73 |
| - dirs = {**dirs, **{"mask_dir": mask_dir}} |
74 |
| - |
75 | 140 | for dir_type, path in dirs.items():
|
76 | 141 | filename, label = _prepare_files_labels(path, dir_type, extensions)
|
77 | 142 | filenames += filename
|
78 | 143 | labels += label
|
79 | 144 |
|
80 |
| - samples = DataFrame({"image_path": filenames, "label": labels}) |
81 |
| - samples = samples.sort_values(by="image_path", ignore_index=True) |
| 145 | + samples = DataFrame({"image_path": filenames, "label": labels, "mask_path": ""}) |
82 | 146 |
|
83 | 147 | # Create label index for normal (0) and abnormal (1) images.
|
84 | 148 | samples.loc[(samples.label == "normal") | (samples.label == "normal_test"), "label_index"] = 0
|
85 | 149 | samples.loc[(samples.label == "abnormal"), "label_index"] = 1
|
86 |
| - samples.label_index = samples.label_index.astype("Int64") |
| 150 | + samples.label_index = samples.label_index.astype(int) |
87 | 151 |
|
88 | 152 | # If a path to mask is provided, add it to the sample dataframe.
|
89 | 153 | if mask_dir is not None:
|
90 |
| - samples.loc[samples.label == "abnormal", "mask_path"] = samples.loc[ |
91 |
| - samples.label == "mask_dir" |
92 |
| - ].image_path.values |
93 |
| - samples = samples.astype({"mask_path": "str"}) |
94 |
| - |
95 |
| - # make sure all every rgb image has a corresponding mask image. |
96 |
| - assert ( |
97 |
| - samples.loc[samples.label_index == 1] |
98 |
| - .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) |
99 |
| - .all() |
100 |
| - ), "Mismatch between anomalous images and mask images. Make sure the mask files \ |
101 |
| - folder follow the same naming convention as the anomalous images in the dataset \ |
102 |
| - (e.g. image: '000.png', mask: '000.png')." |
103 |
| - |
104 |
| - # remove all the rows with temporal image samples that have already been assigned |
105 |
| - samples = samples.loc[ |
106 |
| - (samples.label == "normal") | (samples.label == "abnormal") | (samples.label == "normal_test") |
107 |
| - ] |
| 154 | + mask_dir = _check_and_convert_path(mask_dir) |
| 155 | + for index, row in samples.iterrows(): |
| 156 | + if row.label_index == 1: |
| 157 | + rel_image_path = row.image_path.relative_to(abnormal_dir) |
| 158 | + samples.loc[index, "mask_path"] = str(mask_dir / rel_image_path) |
| 159 | + |
| 160 | + # make sure all the files exist |
| 161 | + # samples.image_path does NOT need to be checked because we build the df based on that |
| 162 | + assert samples.mask_path.apply( |
| 163 | + lambda x: Path(x).exists() if x != "" else True |
| 164 | + ).all(), f"missing mask files, mask_dir={mask_dir}" |
108 | 165 |
|
109 | 166 | # Ensure the pathlib objects are converted to str.
|
110 | 167 | # This is because torch dataloader doesn't like pathlib.
|
|
0 commit comments