Skip to content

Commit 4c3cbee

Browse files
committed
Fix device functions to accept pointers instead of values
1 parent eaca2d6 commit 4c3cbee

15 files changed

+153
-73
lines changed

c/parallel/src/kernels/operators.cpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ constexpr std::string_view binary_op_template = R"XXX(
3434
)XXX";
3535

3636
constexpr std::string_view stateless_binary_op_template = R"XXX(
37-
extern "C" __device__ {0} OP_NAME(LHS_T lhs, RHS_T rhs);
37+
extern "C" __device__ void OP_NAME(LHS_T* lhs, RHS_T* rhs, {0}* out);
3838
struct op_wrapper {{
3939
__device__ {0} operator()(LHS_T lhs, RHS_T rhs) const {{
40-
return OP_NAME(lhs, rhs);
40+
{0} ret;
41+
OP_NAME(&lhs, &rhs, &ret);
42+
return ret;
4143
}}
4244
}};
4345
)XXX";
@@ -105,10 +107,12 @@ std::string make_kernel_user_unary_operator(std::string_view input_t, std::strin
105107
)XXX";
106108

107109
constexpr std::string_view stateless_op = R"XXX(
108-
extern "C" __device__ OUTPUT_T OP_NAME(INPUT_T val);
110+
extern "C" __device__ void OP_NAME(INPUT_T* val, OUTPUT_T* result);
109111
struct op_wrapper {
110112
__device__ OUTPUT_T operator()(INPUT_T val) const {
111-
return OP_NAME(val);
113+
OUTPUT_T out;
114+
OP_NAME(&val, &out);
115+
return out;
112116
}
113117
};
114118
)XXX";
@@ -117,7 +121,9 @@ struct op_wrapper {
117121
struct __align__(OP_ALIGNMENT) op_state {
118122
char data[OP_SIZE];
119123
};
124+
120125
extern "C" __device__ OUTPUT_T OP_NAME(op_state *state, INPUT_T val);
126+
121127
struct op_wrapper {
122128
op_state state;
123129
__device__ OUTPUT_T operator()(INPUT_T val) {

c/parallel/test/test_merge_sort.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ TEST_CASE("DeviceMergeSort:SortPairsCopy works with custom types", "[merge_sort]
177177
operation_t op = make_operation(
178178
"op",
179179
"struct key_pair { short a; size_t b; };\n"
180-
"extern \"C\" __device__ bool op(key_pair lhs, key_pair rhs) {\n"
181-
" return lhs.a == rhs.a ? lhs.b < rhs.b : lhs.a < rhs.a;\n"
180+
"extern \"C\" __device__ void op(key_pair* lhs, key_pair* rhs, bool* out) {\n"
181+
" *out = lhs->a == rhs->a ? lhs->b < rhs->b : lhs->a < rhs->a;\n"
182182
"}");
183183
const std::vector<short> a = generate<short>(num_items);
184184
const std::vector<size_t> b = generate<size_t>(num_items);

c/parallel/test/test_reduce.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ TEST_CASE("Reduce works with custom types", "[reduce]")
7777
operation_t op = make_operation(
7878
"op",
7979
"struct pair { short a; size_t b; };\n"
80-
"extern \"C\" __device__ pair op(pair lhs, pair rhs) {\n"
81-
" return pair{ lhs.a + rhs.a, lhs.b + rhs.b };\n"
80+
"extern \"C\" __device__ void op(pair* lhs, pair* rhs, pair* out) {\n"
81+
" *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b };\n"
8282
"}");
8383
const std::vector<short> a = generate<short>(num_items);
8484
const std::vector<size_t> b = generate<size_t>(num_items);

c/parallel/test/test_scan.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ TEST_CASE("Scan works with custom types", "[scan]")
125125
operation_t op = make_operation(
126126
"op",
127127
"struct pair { short a; size_t b; };\n"
128-
"extern \"C\" __device__ pair op(pair lhs, pair rhs) {\n"
129-
" return pair{ lhs.a + rhs.a, lhs.b + rhs.b };\n"
128+
"extern \"C\" __device__ void op(pair* lhs, pair* rhs, pair* out) {\n"
129+
" *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b };\n"
130130
"}");
131131
const std::vector<short> a = generate<short>(num_items);
132132
const std::vector<size_t> b = generate<size_t>(num_items);

c/parallel/test/test_segmented_reduce.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ struct pair {{
208208
short a;
209209
size_t b;
210210
}};
211-
extern "C" __device__ pair {0}(pair lhs, pair rhs) {{
212-
return pair{{ lhs.a + rhs.a, lhs.b + rhs.b }};
211+
extern "C" __device__ void {0}(pair* lhs, pair* rhs, pair* out) {{
212+
*out = pair{{ lhs->a + rhs->a, lhs->b + rhs->b }};
213213
}}
214214
)XXX";
215215
std::string plus_pair_op_src = std::format(plus_pair_op_template, device_op_name);

c/parallel/test/test_transform.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ TEST_CASE("Transform works with output of different type", "[transform]")
105105
operation_t op = make_operation(
106106
"op",
107107
"struct pair { short a; size_t b; };\n"
108-
"extern \"C\" __device__ pair op(int x) {\n"
109-
" return pair{ short(x), size_t(x) };\n"
108+
"extern \"C\" __device__ void op(int* x, pair* out) {\n"
109+
" *out = pair{ short(*x), size_t(*x) };\n"
110110
"}");
111111
const std::vector<int> input = generate<int>(num_items);
112112
std::vector<pair> expected(num_items);
@@ -132,8 +132,8 @@ TEST_CASE("Transform works with custom types", "[transform]")
132132
operation_t op = make_operation(
133133
"op",
134134
"struct pair { short a; size_t b; };\n"
135-
"extern \"C\" __device__ pair op(pair x) {\n"
136-
" return pair{ x.a * 2, x.b * 2 };\n"
135+
"extern \"C\" __device__ void op(pair* x, pair* out) {\n"
136+
" *out = pair{ x->a * 2, x->b * 2 };\n"
137137
"}");
138138
const std::vector<short> a = generate<short>(num_items);
139139
const std::vector<size_t> b = generate<size_t>(num_items);
@@ -217,8 +217,8 @@ TEST_CASE("Transform with binary operator", "[transform]")
217217

218218
operation_t op = make_operation(
219219
"op",
220-
"extern \"C\" __device__ int op(int x, int y) {\n"
221-
" return (x > y) ? x : y;\n"
220+
"extern \"C\" __device__ void op(int* x, int* y, int* out) {\n"
221+
" *out = (*x > *y) ? *x : *y;\n"
222222
"}");
223223

224224
binary_transform(input1_ptr, input2_ptr, output_ptr, num_items, op);
@@ -248,8 +248,8 @@ TEST_CASE("Binary transform with one iterator", "[transform]")
248248

249249
operation_t op = make_operation(
250250
"op",
251-
"extern \"C\" __device__ int op(int x, int y) {\n"
252-
" return (x > y) ? x : y;\n"
251+
"extern \"C\" __device__ void op(int* x, int* y, int* out) {\n"
252+
" *out = (*x > *y) ? *x : *y;\n"
253253
"}");
254254

255255
binary_transform(input1_ptr, input2_it, output_ptr, num_items, op);

c/parallel/test/test_unique_by_key.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ TEST_CASE("DeviceSelect::UniqueByKey works with custom types", "[device][select_
215215
operation_t op = make_operation(
216216
"op",
217217
"struct key_pair { short a; size_t b; };\n"
218-
"extern \"C\" __device__ bool op(key_pair lhs, key_pair rhs) {\n"
219-
" return lhs.a == rhs.a && lhs.b == rhs.b;\n"
218+
"extern \"C\" __device__ void op(key_pair* lhs, key_pair* rhs, bool* out) {\n"
219+
" *out = (lhs->a == rhs->a && lhs->b == rhs->b);\n"
220220
"}");
221221
const std::vector<short> a = generate<short>(num_items);
222222
const std::vector<size_t> b = generate<size_t>(num_items);

c/parallel/test/test_util.h

+44-39
Original file line numberDiff line numberDiff line change
@@ -207,21 +207,20 @@ static std::string get_reduce_op(cccl_type_enum t)
207207
switch (t)
208208
{
209209
case cccl_type_enum::CCCL_INT8:
210-
return "extern \"C\" __device__ char op(char a, char b) { return a + b; }";
210+
return "extern \"C\" __device__ void op(char* a, char* b, char* out) { *out = *a + *b; }";
211211
case cccl_type_enum::CCCL_INT32:
212-
return "extern \"C\" __device__ int op(int a, int b) { return a + b; }";
212+
return "extern \"C\" __device__ void op(int* a, int* b, int* out) { *out = *a + *b; }";
213213
case cccl_type_enum::CCCL_UINT32:
214-
return "extern \"C\" __device__ unsigned int op(unsigned int a, unsigned int b) { return a + b; }";
214+
return "extern \"C\" __device__ void op(unsigned int* a, unsigned int* b, unsigned int* out) { *out = *a + *b; }";
215215
case cccl_type_enum::CCCL_INT64:
216-
return "extern \"C\" __device__ long long op(long long a, long long b) { return a + b; }";
216+
return "extern \"C\" __device__ void op(long long* a, long long* b, long long* out) { *out = *a + *b; }";
217217
case cccl_type_enum::CCCL_UINT64:
218-
return "extern \"C\" __device__ unsigned long long op(unsigned long long a, unsigned long long b) { "
219-
" return a + b; "
220-
"}";
218+
return "extern \"C\" __device__ void op(unsigned long long* a, unsigned long long* b, unsigned long long* out) { "
219+
"*out = *a + *b; }";
221220
case cccl_type_enum::CCCL_FLOAT32:
222-
return "extern \"C\" __device__ float op(float a, float b) { return a + b; }";
221+
return "extern \"C\" __device__ void op(float* a, float* b, float* out) { *out = *a + *b; }";
223222
case cccl_type_enum::CCCL_FLOAT64:
224-
return "extern \"C\" __device__ double op(double a, double b) { return a + b; }";
223+
return "extern \"C\" __device__ void op(double* a, double* b, double* out) { *out = *a + *b; }";
225224
default:
226225
throw std::runtime_error("Unsupported type");
227226
}
@@ -253,26 +252,29 @@ static std::string get_merge_sort_op(cccl_type_enum t)
253252
switch (t)
254253
{
255254
case cccl_type_enum::CCCL_INT8:
256-
return "extern \"C\" __device__ bool op(char lhs, char rhs) { return lhs < rhs; }";
255+
return "extern \"C\" __device__ void op(char* lhs, char* rhs, bool* result) { *result = *lhs < *rhs; }";
257256
case cccl_type_enum::CCCL_UINT8:
258-
return "extern \"C\" __device__ bool op(unsigned char lhs, unsigned char rhs) { return lhs < rhs; }";
257+
return "extern \"C\" __device__ void op(unsigned char* lhs, unsigned char* rhs, bool* result) { *result = *lhs < "
258+
"*rhs; }";
259259
case cccl_type_enum::CCCL_INT16:
260-
return "extern \"C\" __device__ bool op(short lhs, short rhs) { return lhs < rhs; }";
260+
return "extern \"C\" __device__ void op(short* lhs, short* rhs, bool* result) { *result = *lhs < *rhs; }";
261261
case cccl_type_enum::CCCL_UINT16:
262-
return "extern \"C\" __device__ bool op(unsigned short lhs, unsigned short rhs) { return lhs < rhs; }";
262+
return "extern \"C\" __device__ void op(unsigned short* lhs, unsigned short* rhs, bool* result) { *result = *lhs "
263+
"< *rhs; }";
263264
case cccl_type_enum::CCCL_INT32:
264-
return "extern \"C\" __device__ bool op(int lhs, int rhs) { return lhs < rhs; }";
265+
return "extern \"C\" __device__ void op(int* lhs, int* rhs, bool* result) { *result = *lhs < *rhs; }";
265266
case cccl_type_enum::CCCL_UINT32:
266-
return "extern \"C\" __device__ bool op(unsigned int lhs, unsigned int rhs) { return lhs < rhs; }";
267+
return "extern \"C\" __device__ void op(unsigned int* lhs, unsigned int* rhs, bool* result) { *result = *lhs < "
268+
"*rhs; }";
267269
case cccl_type_enum::CCCL_INT64:
268-
return "extern \"C\" __device__ bool op(long long lhs, long long rhs) { return lhs < rhs; }";
270+
return "extern \"C\" __device__ void op(long long* lhs, long long* rhs, bool* result) { *result = *lhs < *rhs; }";
269271
case cccl_type_enum::CCCL_UINT64:
270-
return "extern \"C\" __device__ bool op(unsigned long long lhs, unsigned long long rhs) { return lhs < rhs; }";
272+
return "extern \"C\" __device__ void op(unsigned long long* lhs, unsigned long long* rhs, bool* result) { "
273+
"*result = *lhs < *rhs; }";
271274
case cccl_type_enum::CCCL_FLOAT32:
272-
return "extern \"C\" __device__ bool op(float lhs, float rhs) { return lhs < rhs; }";
275+
return "extern \"C\" __device__ void op(float* lhs, float* rhs, bool* result) { *result = *lhs < *rhs; }";
273276
case cccl_type_enum::CCCL_FLOAT64:
274-
return "extern \"C\" __device__ bool op(double lhs, double rhs) { return lhs < rhs; }";
275-
277+
return "extern \"C\" __device__ void op(double* lhs, double* rhs, bool* result) { *result = *lhs < *rhs; }";
276278
default:
277279
throw std::runtime_error("Unsupported type");
278280
}
@@ -284,25 +286,30 @@ static std::string get_unique_by_key_op(cccl_type_enum t)
284286
switch (t)
285287
{
286288
case cccl_type_enum::CCCL_INT8:
287-
return "extern \"C\" __device__ bool op(char lhs, char rhs) { return lhs == rhs; }";
289+
return "extern \"C\" __device__ void op(char* lhs, char* rhs, bool* result) { *result = *lhs == *rhs; }";
288290
case cccl_type_enum::CCCL_UINT8:
289-
return "extern \"C\" __device__ bool op(unsigned char lhs, unsigned char rhs) { return lhs == rhs; }";
291+
return "extern \"C\" __device__ void op(unsigned char* lhs, unsigned char* rhs, bool* result) { *result = *lhs "
292+
"== *rhs; }";
290293
case cccl_type_enum::CCCL_INT16:
291-
return "extern \"C\" __device__ bool op(short lhs, short rhs) { return lhs == rhs; }";
294+
return "extern \"C\" __device__ void op(short* lhs, short* rhs, bool* result) { *result = *lhs == *rhs; }";
292295
case cccl_type_enum::CCCL_UINT16:
293-
return "extern \"C\" __device__ bool op(unsigned short lhs, unsigned short rhs) { return lhs == rhs; }";
296+
return "extern \"C\" __device__ void op(unsigned short* lhs, unsigned short* rhs, bool* result) { *result = *lhs "
297+
"== *rhs; }";
294298
case cccl_type_enum::CCCL_INT32:
295-
return "extern \"C\" __device__ bool op(int lhs, int rhs) { return lhs == rhs; }";
299+
return "extern \"C\" __device__ void op(int* lhs, int* rhs, bool* result) { *result = *lhs == *rhs; }";
296300
case cccl_type_enum::CCCL_UINT32:
297-
return "extern \"C\" __device__ bool op(unsigned int lhs, unsigned int rhs) { return lhs == rhs; }";
301+
return "extern \"C\" __device__ void op(unsigned int* lhs, unsigned int* rhs, bool* result) { *result = *lhs == "
302+
"*rhs; }";
298303
case cccl_type_enum::CCCL_INT64:
299-
return "extern \"C\" __device__ bool op(long long lhs, long long rhs) { return lhs == rhs; }";
304+
return "extern \"C\" __device__ void op(long long* lhs, long long* rhs, bool* result) { *result = *lhs == *rhs; "
305+
"}";
300306
case cccl_type_enum::CCCL_UINT64:
301-
return "extern \"C\" __device__ bool op(unsigned long long lhs, unsigned long long rhs) { return lhs == rhs; }";
307+
return "extern \"C\" __device__ void op(unsigned long long* lhs, unsigned long long* rhs, bool* result) { "
308+
"*result = *lhs == *rhs; }";
302309
case cccl_type_enum::CCCL_FLOAT32:
303-
return "extern \"C\" __device__ bool op(float lhs, float rhs) { return lhs == rhs; }";
310+
return "extern \"C\" __device__ void op(float* lhs, float* rhs, bool* result) { *result = *lhs == *rhs; }";
304311
case cccl_type_enum::CCCL_FLOAT64:
305-
return "extern \"C\" __device__ bool op(double lhs, double rhs) { return lhs == rhs; }";
312+
return "extern \"C\" __device__ void op(double* lhs, double* rhs, bool* result) { *result = *lhs == *rhs; }";
306313
default:
307314
throw std::runtime_error("Unsupported type");
308315
}
@@ -314,21 +321,19 @@ static std::string get_unary_op(cccl_type_enum t)
314321
switch (t)
315322
{
316323
case cccl_type_enum::CCCL_INT8:
317-
return "extern \"C\" __device__ char op(char a) { return 2 * a; }";
324+
return "extern \"C\" __device__ void op(char* a, char* result) { *result = 2 * *a; }";
318325
case cccl_type_enum::CCCL_INT32:
319-
return "extern \"C\" __device__ int op(int a) { return 2 * a; }";
326+
return "extern \"C\" __device__ void op(int* a, int* result) { *result = 2 * *a; }";
320327
case cccl_type_enum::CCCL_UINT32:
321-
return "extern \"C\" __device__ unsigned int op(unsigned int a) { return 2 * a; }";
328+
return "extern \"C\" __device__ void op(unsigned int* a, unsigned int* result) { *result = 2 * *a; }";
322329
case cccl_type_enum::CCCL_INT64:
323-
return "extern \"C\" __device__ long long op(long long a) { return 2 * a; }";
330+
return "extern \"C\" __device__ void op(long long* a, long long* result) { *result = 2 * *a; }";
324331
case cccl_type_enum::CCCL_UINT64:
325-
return "extern \"C\" __device__ unsigned long long op(unsigned long long a) { "
326-
" return 2 * a; "
327-
"}";
332+
return "extern \"C\" __device__ void op(unsigned long long* a, unsigned long long* result) { *result = 2 * *a; }";
328333
case cccl_type_enum::CCCL_FLOAT32:
329-
return "extern \"C\" __device__ float op(float a) { return 2 * a; }";
334+
return "extern \"C\" __device__ void op(float* a, float* result) { *result = 2 * *a; }";
330335
case cccl_type_enum::CCCL_FLOAT64:
331-
return "extern \"C\" __device__ double op(double a) { return 2 * a; }";
336+
return "extern \"C\" __device__ void op(double* a, double* result) { *result = 2 * *a; }";
332337
default:
333338
throw std::runtime_error("Unsupported type");
334339
}

python/cuda_parallel/cuda/parallel/experimental/_cccl_interop.py

+39-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from __future__ import annotations
66

77
import functools
8-
from typing import Callable, List
8+
import textwrap
9+
from typing import TYPE_CHECKING, Callable, List
910

1011
import numba
1112
import numpy as np
@@ -46,6 +47,10 @@
4647
}
4748

4849

50+
if TYPE_CHECKING:
51+
from numba.core.typing import Signature
52+
53+
4954
def _type_to_enum(numba_type: types.Type) -> IntEnumerationMember:
5055
if numba_type in _TYPE_TO_ENUM:
5156
return _TYPE_TO_ENUM[numba_type]
@@ -175,7 +180,9 @@ def to_cccl_value(array_or_struct: np.ndarray | GpuStruct) -> Value:
175180
return to_cccl_value(array_or_struct._data)
176181

177182

178-
def to_cccl_op(op: Callable, sig) -> Op:
183+
def _to_cccl_op(op: Callable, sig: Signature) -> Op:
184+
# Internal helper that doesn't do any wrapping of `op` (see
185+
# `to_cccl_op` below).
179186
ltoir, _ = cuda.compile(op, sig=sig, output="ltoir")
180187
return Op(
181188
operator_type=OpKind.STATELESS,
@@ -186,6 +193,36 @@ def to_cccl_op(op: Callable, sig) -> Op:
186193
)
187194

188195

196+
def to_cccl_op(op: Callable, sig: Signature) -> Op:
197+
# Return an `Op` object corresponding to the given callable `op`.
198+
# Importantly, this wraps the callable in a device function that
199+
# takes pointers to the arguments and a pointer to the return
200+
# value.
201+
202+
op = cuda.jit(op, device=True)
203+
204+
def deref(s: str) -> str:
205+
return f"{s}[0]"
206+
207+
arg_names = [f"arg_{i}" for i in range(len(sig.args))]
208+
209+
wrapped_op_src = textwrap.dedent(f"""
210+
def wrapped_op({', '.join(arg_names)}, ret):
211+
ret[0] = op({', '.join(map(deref, arg_names))})
212+
""")
213+
local_ns = {"op": op}
214+
print(wrapped_op_src)
215+
exec(wrapped_op_src, local_ns)
216+
wrapped_op = local_ns["wrapped_op"]
217+
218+
# Construct the signature: n pointer args + 1 pointer return
219+
cccl_sig = types.void(
220+
*(types.CPointer(arg) for arg in sig.args), types.CPointer(sig.return_type)
221+
)
222+
223+
return _to_cccl_op(wrapped_op, cccl_sig)
224+
225+
189226
def get_value_type(d_in: IteratorBase | DeviceArrayLike):
190227
from .struct import gpu_struct_from_numpy_dtype
191228

python/cuda_parallel/cuda/parallel/experimental/algorithms/_merge_sort.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Callable
77

88
import numba
9+
from numba import types
910

1011
from .. import _bindings
1112
from .. import _cccl_interop as cccl
@@ -84,8 +85,7 @@ def __init__(
8485
else:
8586
value_type = numba.from_dtype(protocols.get_dtype(d_in_keys))
8687

87-
sig = (value_type, value_type)
88-
self.op_wrapper = cccl.to_cccl_op(op, sig)
88+
self.op_wrapper = cccl.to_cccl_op(op, types.uint8(value_type, value_type))
8989

9090
self.build_result = call_build(
9191
_bindings.DeviceMergeSortBuildResult,

python/cuda_parallel/cuda/parallel/experimental/algorithms/_reduce.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def __init__(
4444
value_type = numba.from_dtype(h_init.dtype)
4545
else:
4646
value_type = numba.typeof(h_init)
47-
sig = (value_type, value_type)
48-
self.op_wrapper = cccl.to_cccl_op(op, sig)
47+
self.op_wrapper = cccl.to_cccl_op(op, value_type(value_type, value_type))
4948
self.build_result = call_build(
5049
_bindings.DeviceReduceBuildResult,
5150
self.d_in_cccl,

python/cuda_parallel/cuda/parallel/experimental/algorithms/_scan.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ def __init__(
4646
value_type = numba.from_dtype(h_init.dtype)
4747
else:
4848
value_type = numba.typeof(h_init)
49-
sig = (value_type, value_type)
50-
self.op_wrapper = cccl.to_cccl_op(op, sig)
49+
self.op_wrapper = cccl.to_cccl_op(op, value_type(value_type, value_type))
5150
self.build_result = call_build(
5251
_bindings.DeviceScanBuildResult,
5352
self.d_in_cccl,

python/cuda_parallel/cuda/parallel/experimental/algorithms/_segmented_reduce.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def __init__(
4242
value_type = numba.from_dtype(h_init.dtype)
4343
else:
4444
value_type = numba.typeof(h_init)
45-
sig = (value_type, value_type)
46-
self.op_wrapper = cccl.to_cccl_op(op, sig)
45+
self.op_wrapper = cccl.to_cccl_op(op, value_type(value_type, value_type))
4746
self.build_result = call_build(
4847
_bindings.DeviceSegmentedReduceBuildResult,
4948
self.d_in_cccl,

0 commit comments

Comments
 (0)