Skip to content

Commit c6ab4a0

Browse files
committed
Fix some issues in Python bindings + skip sass verification in failing tests
1 parent 2885624 commit c6ab4a0

File tree

7 files changed

+24
-5
lines changed

7 files changed

+24
-5
lines changed

python/cuda_parallel/cuda/parallel/experimental/_bindings.pyx

+5-1
Original file line numberDiff line numberDiff line change
@@ -1671,13 +1671,17 @@ cdef class DeviceUniqueByKeyBuildResult:
16711671
)
16721672
return storage_sz
16731673

1674+
def _get_cubin(self):
1675+
return self.build_data.cubin[:self.build_data.cubin_size]
1676+
16741677
# -----------------
16751678
# DeviceRadixSort
16761679
# -----------------
16771680

16781681
cdef extern from "cccl/c/radix_sort.h":
16791682
cdef struct cccl_device_radix_sort_build_result_t 'cccl_device_radix_sort_build_result_t':
1680-
pass
1683+
const char* cubin
1684+
size_t cubin_size
16811685

16821686
cdef CUresult cccl_device_radix_sort_build(
16831687
cccl_device_radix_sort_build_result_t *build_ptr,

python/cuda_parallel/cuda/parallel/experimental/algorithms/_merge_sort.py

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Callable
77

88
import numba
9-
from numba import types
109

1110
from .. import _bindings
1211
from .. import _cccl_interop as cccl

python/cuda_parallel/cuda/parallel/experimental/algorithms/_radix_sort.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,13 @@ def __init__(
114114
self.d_out_values_cccl = cccl.to_cccl_iter(d_out_values_array)
115115

116116
# TODO: decomposer op is not supported for now
117-
self.decomposer_op = cccl.to_cccl_op(None, None)
117+
self.decomposer_op = cccl.Op(
118+
name="",
119+
operator_type=cccl.OpKind.STATELESS,
120+
ltoir=b"",
121+
state_alignment=1,
122+
state=None,
123+
)
118124
decomposer_return_type = "".encode("utf-8")
119125

120126
self.build_result = call_build(

python/cuda_parallel/pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,6 @@ extend = "../../pyproject.toml"
7777

7878
[tool.ruff.lint.isort]
7979
known-first-party = ["cuda.parallel"]
80+
81+
[tool.pytest.ini_options]
82+
markers = ["no_verify_sass: skip SASS verification check"]

python/cuda_parallel/tests/conftest.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,14 @@ def cuda_stream() -> Stream:
7070

7171

7272
@pytest.fixture(scope="function", autouse=True)
73-
def verify_sass(monkeypatch):
73+
def verify_sass(request, monkeypatch):
74+
if request.node.get_closest_marker("no_verify_sass"):
75+
return
76+
7477
import cuda.parallel.experimental._cccl_interop
7578

7679
monkeypatch.setattr(
7780
cuda.parallel.experimental._cccl_interop,
7881
"_check_sass",
79-
False, # todo: change to True
82+
True,
8083
)

python/cuda_parallel/tests/test_reduce_api.py

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
import pytest
6+
57

68
def test_device_reduce():
79
# example-begin reduce-min
@@ -215,6 +217,7 @@ def max_g_value(x, y):
215217
# example-end reduce-struct
216218

217219

220+
@pytest.mark.no_verify_sass(reason="LDL/STL instructions emitted for this test.")
218221
def test_reduce_struct_type_minmax():
219222
# example-begin reduce-minmax
220223
import cupy as cp

python/cuda_parallel/tests/test_scan.py

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def op(a, b):
120120
"force_inclusive",
121121
[True, False],
122122
)
123+
@pytest.mark.no_verify_sass(reason="LDL/STL instructions emitted for this test.")
123124
def test_scan_struct_type(force_inclusive):
124125
@gpu_struct
125126
class XY:

0 commit comments

Comments
 (0)