Skip to content

Commit 4bf4970

Browse files
ashwinvaidya17 (Ashwin Vaidya)
and
Ashwin Vaidya
authored
🐞 Fix tensor detach and gpu count issues in benchmarking script (#100)
Co-authored-by: Ashwin Vaidya <[email protected]>
1 parent c611e43 commit 4bf4970

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

anomalib/utils/sweep/helpers/inference.py

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Dict, Iterable, List, Tuple, Union
2020

2121
import numpy as np
22+
import torch
2223
from omegaconf import DictConfig, ListConfig
2324
from torch.utils.data import DataLoader
2425

@@ -106,6 +107,7 @@ def get_torch_throughput(
106107
Returns:
107108
float: Inference throughput
108109
"""
110+
torch.set_grad_enabled(False)
109111
model.eval()
110112
inferencer = TorchInferencer(config, model)
111113
torch_dataloader = MockImageLoader(config.dataset.image_size, len(test_dataset))
@@ -118,6 +120,7 @@ def get_torch_throughput(
118120
inference_time = time.time() - start_time
119121
throughput = len(test_dataset) / inference_time
120122

123+
torch.set_grad_enabled(True)
121124
return throughput
122125

123126

tools/benchmarking/benchmark.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717

1818
import logging
19+
import math
1920
import multiprocessing
2021
import time
2122
from concurrent.futures import ProcessPoolExecutor, as_completed
@@ -136,12 +137,12 @@ def distribute_over_gpus():
136137
run_configs = list(get_run_config(sweep_config.grid_search))
137138
jobs = []
138139
for device_id, run_split in enumerate(
139-
range(0, len(run_configs), len(run_configs) // torch.cuda.device_count())
140+
range(0, len(run_configs), math.ceil(len(run_configs) / torch.cuda.device_count()))
140141
):
141142
jobs.append(
142143
executor.submit(
143144
compute_on_gpu,
144-
run_configs[run_split : run_split + len(run_configs) // torch.cuda.device_count()],
145+
run_configs[run_split : run_split + math.ceil(len(run_configs) / torch.cuda.device_count())],
145146
device_id + 1,
146147
sweep_config.seed,
147148
sweep_config.writer,

0 commit comments

Comments (0)