Skip to content

Commit bce2a80

Browse files
committed
Fixes
Signed-off-by: Joaquin Anton Guirao <[email protected]>
1 parent 3f1f2b8 commit bce2a80

File tree

2 files changed

+22
-34
lines changed

2 files changed

+22
-34
lines changed

dali/test/python/test_dali_variable_batch_size.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from nose_utils import SkipTest, attr, nottest
2525
from nvidia.dali.pipeline import Pipeline, pipeline_def
2626
from nvidia.dali.pipeline.experimental import pipeline_def as experimental_pipeline_def
27-
from nvidia.dali.plugin.numba.fn.experimental import numba_function
2827

2928
import test_utils
3029
from segmentation_test_utils import make_batch_select_masks
@@ -375,41 +374,30 @@ def numba_setup_out_shape(out_shape, in_shape):
375374
(fn.ones_like, {"devices": ["cpu"]}),
376375
]
377376

377+
numba_compatible_devices = []
378378
if check_numba_compatibility_gpu(False):
379-
ops_image_custom_args.append(
380-
(
381-
numba_function,
382-
{
383-
"batch_processing": True,
384-
"devices": ["cpu"],
385-
"in_types": [types.UINT8],
386-
"ins_ndim": [3],
387-
"out_types": [types.UINT8],
388-
"outs_ndim": [3],
389-
"run_fn": numba_set_all_values_to_255_batch,
390-
"setup_fn": numba_setup_out_shape,
391-
},
392-
)
393-
)
394-
395-
379+
numba_compatible_devices.append("gpu")
396380
if check_numba_compatibility_cpu(False):
397-
ops_image_custom_args.append(
398-
(
399-
numba_function,
400-
{
401-
"batch_processing": False,
402-
"devices": ["cpu"],
403-
"in_types": [types.UINT8],
404-
"ins_ndim": [3],
405-
"out_types": [types.UINT8],
406-
"outs_ndim": [3],
407-
"run_fn": numba_set_all_values_to_255_batch,
408-
"setup_fn": numba_setup_out_shape,
409-
},
381+
numba_compatible_devices.append("cpu")
382+
383+
if len(numba_compatible_devices) > 0:
384+
from nvidia.dali.plugin.numba.fn.experimental import numba_function
385+
for device in numba_compatible_devices:
386+
ops_image_custom_args.append(
387+
(
388+
numba_function,
389+
{
390+
"batch_processing": True if device == 'gpu' else False,
391+
"devices": [device],
392+
"in_types": [types.UINT8],
393+
"ins_ndim": [3],
394+
"out_types": [types.UINT8],
395+
"outs_ndim": [3],
396+
"run_fn": numba_set_all_values_to_255_batch,
397+
"setup_fn": numba_setup_out_shape,
398+
},
399+
)
410400
)
411-
)
412-
413401

414402
def test_ops_image_custom_args():
415403
for op, args in ops_image_custom_args:

qa/TL0_self_test_Ampere/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash -ex
22

3-
pip_packages='${python_test_runner_package} numpy opencv-python nvidia-ml-py==11.450.51'
3+
pip_packages='${python_test_runner_package} numpy opencv-python nvidia-ml-py==11.450.51 numba'
44

55
target_dir=./dali/test/python
66

0 commit comments

Comments
 (0)