24
24
import torchvision
25
25
from kornia import gaussian_blur2d
26
26
from omegaconf import ListConfig
27
- from sklearn import random_projection
28
27
from torch import Tensor , nn
29
28
30
29
from anomalib .core .model import AnomalyModule
31
30
from anomalib .core .model .dynamic_module import DynamicBufferModule
32
31
from anomalib .core .model .feature_extractor import FeatureExtractor
33
32
from anomalib .data .tiler import Tiler
33
+ from anomalib .models .patchcore .utils .sampling import (
34
+ KCenterGreedy ,
35
+ NearestNeighbors ,
36
+ SparseRandomProjection ,
37
+ )
34
38
35
39
36
40
class AnomalyMapGenerator :
@@ -123,6 +127,7 @@ def __init__(
123
127
124
128
self .feature_extractor = FeatureExtractor (backbone = self .backbone (pretrained = True ), layers = self .layers )
125
129
self .feature_pooler = torch .nn .AvgPool2d (3 , 1 , 1 )
130
+ self .nn_search = NearestNeighbors (n_neighbors = 9 )
126
131
self .anomaly_map_generator = AnomalyMapGenerator (input_size = input_size )
127
132
128
133
if apply_tiling :
@@ -165,8 +170,7 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
165
170
if self .training :
166
171
output = embedding
167
172
else :
168
- distances = torch .cdist (embedding , self .memory_bank , p = 2.0 ) # euclidean norm
169
- patch_scores , _ = distances .topk (k = 9 , largest = False , dim = 1 )
173
+ patch_scores , _ = self .nn_search .kneighbors (embedding )
170
174
171
175
anomaly_map , anomaly_score = self .anomaly_map_generator (patch_scores = patch_scores )
172
176
output = (anomaly_map , anomaly_score )
@@ -209,48 +213,25 @@ def reshape_embedding(embedding: Tensor) -> Tensor:
209
213
embedding = embedding .permute (0 , 2 , 3 , 1 ).reshape (- 1 , embedding_size )
210
214
return embedding
211
215
212
- def create_coreset (
213
- self ,
214
- embedding : Tensor ,
215
- sample_count : int = 500 ,
216
- eps : float = 0.90 ,
217
- ):
218
- """Creates n subsampled coreset for given sample_set.
216
+ @staticmethod
217
+ def subsample_embedding (embedding : torch .Tensor , sampling_ratio : float ) -> torch .Tensor :
218
+ """Subsample embedding based on coreset sampling.
219
219
220
220
Args:
221
- embedding (Tensor): (sample_count, d) tensor of patches.
222
- sample_count (int): Number of patches to select.
223
- eps (float): Parameter for spare projection aggression.
221
+ embedding (torch.Tensor): Embedding tensor from the CNN
222
+ sampling_ratio (float): Coreset sampling ratio
223
+
224
+ Returns:
225
+ torch.Tensor: Subsampled embedding whose dimensionality is reduced.
224
226
"""
225
- # TODO: https://github.com/openvinotoolkit/anomalib/issues/54
226
- # Replace print statement with logger.
227
- print ("Fitting random projections..." )
228
- try :
229
- transformer = random_projection .SparseRandomProjection (eps = eps )
230
- sample_set = torch .tensor (transformer .fit_transform (embedding .cpu ())).to ( # pylint: disable=not-callable
231
- embedding .device
232
- )
233
- except ValueError :
234
- # TODO: https://github.com/openvinotoolkit/anomalib/issues/54
235
- # Replace print statement with logger.
236
- print (" Error: could not project vectors. Please increase `eps` value." )
237
-
238
- select_idx = 0
239
- last_item = sample_set [select_idx : select_idx + 1 ]
240
- coreset_idx = [torch .tensor (select_idx ).to (embedding .device )] # pylint: disable=not-callable
241
- min_distances = torch .linalg .norm (sample_set - last_item , dim = 1 , keepdims = True )
242
-
243
- for _ in range (sample_count - 1 ):
244
- distances = torch .linalg .norm (sample_set - last_item , dim = 1 , keepdims = True ) # broadcast
245
- min_distances = torch .minimum (distances , min_distances ) # iterate
246
- select_idx = torch .argmax (min_distances ) # select
247
-
248
- last_item = sample_set [select_idx : select_idx + 1 ]
249
- min_distances [select_idx ] = 0
250
- coreset_idx .append (select_idx )
251
-
252
- coreset_idx = torch .stack (coreset_idx )
253
- self .memory_bank = embedding [coreset_idx ]
227
+ # Random projection
228
+ random_projector = SparseRandomProjection (eps = 0.9 )
229
+ random_projector .fit (embedding )
230
+
231
+ # Coreset Subsampling
232
+ sampler = KCenterGreedy (model = random_projector , embedding = embedding , sampling_ratio = sampling_ratio )
233
+ coreset = sampler .sample_coreset ()
234
+ return coreset
254
235
255
236
256
237
class PatchcoreLightning (AnomalyModule ):
@@ -311,10 +292,12 @@ def training_epoch_end(self, outputs):
311
292
outputs (List[Dict[str, np.ndarray]]): List of embedding vectors
312
293
"""
313
294
embedding = torch .vstack ([output ["embedding" ] for output in outputs ])
314
-
315
295
sampling_ratio = self .hparams .model .coreset_sampling_ratio
316
296
317
- self .model .create_coreset (embedding = embedding , sample_count = int (sampling_ratio * embedding .shape [0 ]), eps = 0.9 )
297
+ embedding = self .model .subsample_embedding (embedding , sampling_ratio )
298
+
299
+ self .model .nn_search .fit (embedding )
300
+ self .model .memory_bank = embedding
318
301
319
302
def validation_step (self , batch , _ ): # pylint: disable=arguments-differ
320
303
"""Get batch of anomaly maps from input image batch.
@@ -328,8 +311,7 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ
328
311
Dict[str, Any]: Image filenames, test images, GT and predicted label/masks
329
312
"""
330
313
331
- anomaly_maps , anomaly_score = self .model (batch ["image" ])
314
+ anomaly_maps , _ = self .model (batch ["image" ])
332
315
batch ["anomaly_maps" ] = anomaly_maps
333
- batch ["pred_scores" ] = anomaly_score .unsqueeze (0 )
334
316
335
317
return batch
0 commit comments