Skip to content

Commit d0775cf

Browse files
committed
Code review fixes
Signed-off-by: Joaquin Anton Guirao <[email protected]>
1 parent c32b46c commit d0775cf

File tree

5 files changed

+30
-5
lines changed

5 files changed

+30
-5
lines changed

dali/test/python/checkpointing/test_dali_stateless_operators.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
check_numba_compatibility_cpu,
2525
has_operator,
2626
restrict_platform,
27+
is_of_supported
2728
)
2829
from nose2.tools import params, cartesian_params
2930
from nose_utils import assert_raises, SkipTest, attr
@@ -575,8 +576,6 @@ def test_preemphasis_filter_stateless(device):
575576

576577
@stateless_signed_off("optical_flow")
577578
def test_optical_flow_stateless():
578-
from test_optical_flow import is_of_supported
579-
580579
if not is_of_supported():
581580
raise SkipTest("Optical Flow is not supported on this platform")
582581
check_single_sequence_input(fn.optical_flow, "gpu")

dali/test/python/test_dali_variable_batch_size.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@
2929
import test_utils
3030
from segmentation_test_utils import make_batch_select_masks
3131
from test_detection_pipeline import coco_anchors
32-
from test_optical_flow import load_frames, is_of_supported
3332
from test_utils import (
3433
module_functions,
3534
has_operator,
3635
restrict_platform,
3736
check_numba_compatibility_cpu,
3837
check_numba_compatibility_gpu,
38+
is_of_supported
3939
)
4040

4141
"""

dali/test/python/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
from nose_utils import SkipTest
3131

3232

33+
is_of_supported_var = None
34+
35+
3336
def get_arch(device_id=0):
3437
compute_cap = 0
3538
try:
@@ -989,3 +992,25 @@ def load_test_operator_plugin():
989992
except RuntimeError:
990993
# in conda "libtestoperatorplugin" lands inside lib/ dir
991994
plugin_manager.load_library("libtestoperatorplugin.so")
995+
996+
997+
def is_of_supported(device_id=0):
998+
global is_of_supported_var
999+
if is_of_supported_var is not None:
1000+
return is_of_supported_var
1001+
1002+
driver_version_major = 0
1003+
try:
1004+
import pynvml
1005+
1006+
pynvml.nvmlInit()
1007+
driver_version = pynvml.nvmlSystemGetDriverVersion().decode("utf-8")
1008+
driver_version_major = int(driver_version.split(".")[0])
1009+
except ModuleNotFoundError:
1010+
print("NVML not found")
1011+
1012+
# there is an issue with OpticalFlow driver in R495 and newer on aarch64 platform
1013+
is_of_supported_var = get_arch(device_id) >= 7.5 and (
1014+
platform.machine() == "x86_64" or driver_version_major < 495
1015+
)
1016+
return is_of_supported_var

qa/TL0_self_test_Ampere/test.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash -ex
22

3-
pip_packages='${python_test_runner_package} numpy'
3+
pip_packages='${python_test_runner_package} numpy opencv-python nvidia-ml-py==11.450.51'
44

55
target_dir=./dali/test/python
66

@@ -33,8 +33,9 @@ test_body() {
3333
${python_new_invoke_test} -s decoder test_image
3434

3535
# test Optical Flow
36-
${python_invoke_test} test_optical_flow.py
36+
${python_new_invoke_test} -s operator_1 test_optical_flow
3737
${python_invoke_test} test_dali_variable_batch_size.py:test_optical_flow
38+
${python_invoke_test} test_dali_stateless_operators.py:test_optical_flow_stateless
3839
}
3940

4041
pushd ../..

0 commit comments

Comments
 (0)