Skip to content

Commit b8bd3f9

Browse files
phcarvalAshwin Vaidya
authored and
Ashwin Vaidya
committed
* fixed DSR squeeze bug * added comment
1 parent ee1f659 commit b8bd3f9

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/anomalib/models/dsr/torch_model.py

+4
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ def forward(self, batch: Tensor, anomaly_map_to_generate: Tensor | None = None)
154154
).detach()
155155
image_score = torch.amax(out_mask_averaged, dim=(2, 3)).squeeze()
156156

157+
# prevent crash when image_score is a single value (batch size of 1)
158+
if image_score.size() == torch.Size([]):
159+
image_score = image_score.unsqueeze(0)
160+
157161
out_mask_cv = out_mask_sm_up[:, 1, :, :]
158162

159163
outputs = {"anomaly_map": out_mask_cv, "pred_score": image_score}

0 commit comments

Comments
 (0)