@@ -92,6 +92,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, mvtec_data_module):
92
92
assert sorted (["image_path" , "mask_path" , "image" , "label" , "mask" ]) == sorted (val_data .keys ())
93
93
assert sorted (["image_path" , "mask_path" , "image" , "label" , "mask" ]) == sorted (test_data .keys ())
94
94
95
def test_non_overlapping_splits(self, mvtec_data_module):
    """Ensure that the train and test splits generated are non-overlapping.

    Collects the ``image_path`` entries of ``test_data.samples`` and
    ``train_data.samples`` into sets and fails if any path appears in both.
    """
    test_paths = set(mvtec_data_module.test_data.samples["image_path"].values)
    train_paths = set(mvtec_data_module.train_data.samples["image_path"].values)
    # Keep the overlap itself rather than only its size, so that a failing
    # assertion can show which paths are shared between the two splits.
    leaked_paths = test_paths & train_paths
    assert not leaked_paths, "Found train and test split contamination"
105
+
95
106
96
107
class TestBTechDataModule :
97
108
"""Test BTech Data Module."""
@@ -111,6 +122,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, btech_data_module):
111
122
assert sorted (["image_path" , "mask_path" , "image" , "label" , "mask" ]) == sorted (val_data .keys ())
112
123
assert sorted (["image_path" , "mask_path" , "image" , "label" , "mask" ]) == sorted (test_data .keys ())
113
124
125
def test_non_overlapping_splits(self, btech_data_module):
    """Ensure that the train and test splits generated are non-overlapping.

    Collects the ``image_path`` entries of ``test_data.samples`` and
    ``train_data.samples`` into sets and fails if any path appears in both.
    """
    test_paths = set(btech_data_module.test_data.samples["image_path"].values)
    train_paths = set(btech_data_module.train_data.samples["image_path"].values)
    # Keep the overlap itself rather than only its size, so that a failing
    # assertion can show which paths are shared between the two splits.
    leaked_paths = test_paths & train_paths
    assert not leaked_paths, "Found train and test split contamination"
135
+
114
136
115
137
class TestFolderDataModule :
116
138
"""Test Folder Data Module."""
@@ -130,6 +152,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, folder_data_module):
130
152
assert sorted (["image_path" , "mask_path" , "image" , "label" , "mask" ]) == sorted (val_data .keys ())
131
153
assert sorted (["image_path" , "mask_path" , "image" , "label" , "mask" ]) == sorted (test_data .keys ())
132
154
155
def test_non_overlapping_splits(self, folder_data_module):
    """Ensure that the train and test splits generated are non-overlapping.

    Collects the ``image_path`` entries of ``test_data.samples`` and
    ``train_data.samples`` into sets and fails if any path appears in both.
    """
    test_paths = set(folder_data_module.test_data.samples["image_path"].values)
    train_paths = set(folder_data_module.train_data.samples["image_path"].values)
    # Keep the overlap itself rather than only its size, so that a failing
    # assertion can show which paths are shared between the two splits.
    leaked_paths = test_paths & train_paths
    assert not leaked_paths, "Found train and test split contamination"
165
+
133
166
134
167
class TestDenormalize :
135
168
"""Test Denormalize Util."""
0 commit comments