|
24 | 24 | from nose_utils import SkipTest, attr, nottest
|
25 | 25 | from nvidia.dali.pipeline import Pipeline, pipeline_def
|
26 | 26 | from nvidia.dali.pipeline.experimental import pipeline_def as experimental_pipeline_def
|
27 |
| -from nvidia.dali.plugin.numba.fn.experimental import numba_function |
28 | 27 |
|
29 | 28 | import test_utils
|
30 | 29 | from segmentation_test_utils import make_batch_select_masks
|
@@ -375,41 +374,30 @@ def numba_setup_out_shape(out_shape, in_shape):
|
375 | 374 | (fn.ones_like, {"devices": ["cpu"]}),
|
376 | 375 | ]
|
377 | 376 |
|
| 377 | +numba_compatible_devices = [] |
378 | 378 | if check_numba_compatibility_gpu(False):
|
379 |
| - ops_image_custom_args.append( |
380 |
| - ( |
381 |
| - numba_function, |
382 |
| - { |
383 |
| - "batch_processing": True, |
384 |
| - "devices": ["cpu"], |
385 |
| - "in_types": [types.UINT8], |
386 |
| - "ins_ndim": [3], |
387 |
| - "out_types": [types.UINT8], |
388 |
| - "outs_ndim": [3], |
389 |
| - "run_fn": numba_set_all_values_to_255_batch, |
390 |
| - "setup_fn": numba_setup_out_shape, |
391 |
| - }, |
392 |
| - ) |
393 |
| - ) |
394 |
| - |
395 |
| - |
| 379 | + numba_compatible_devices.append("gpu") |
396 | 380 | if check_numba_compatibility_cpu(False):
|
397 |
| - ops_image_custom_args.append( |
398 |
| - ( |
399 |
| - numba_function, |
400 |
| - { |
401 |
| - "batch_processing": False, |
402 |
| - "devices": ["cpu"], |
403 |
| - "in_types": [types.UINT8], |
404 |
| - "ins_ndim": [3], |
405 |
| - "out_types": [types.UINT8], |
406 |
| - "outs_ndim": [3], |
407 |
| - "run_fn": numba_set_all_values_to_255_batch, |
408 |
| - "setup_fn": numba_setup_out_shape, |
409 |
| - }, |
| 381 | + numba_compatible_devices.append("cpu") |
| 382 | + |
| 383 | +if len(numba_compatible_devices) > 0: |
| 384 | + from nvidia.dali.plugin.numba.fn.experimental import numba_function |
| 385 | + for device in numba_compatible_devices: |
| 386 | + ops_image_custom_args.append( |
| 387 | + ( |
| 388 | + numba_function, |
| 389 | + { |
| 390 | + "batch_processing": True if device == 'gpu' else False, |
| 391 | + "devices": [device], |
| 392 | + "in_types": [types.UINT8], |
| 393 | + "ins_ndim": [3], |
| 394 | + "out_types": [types.UINT8], |
| 395 | + "outs_ndim": [3], |
| 396 | + "run_fn": numba_set_all_values_to_255_batch, |
| 397 | + "setup_fn": numba_setup_out_shape, |
| 398 | + }, |
| 399 | + ) |
410 | 400 | )
|
411 |
| - ) |
412 |
| - |
413 | 401 |
|
414 | 402 | def test_ops_image_custom_args():
|
415 | 403 | for op, args in ops_image_custom_args:
|
|
0 commit comments