@@ -207,21 +207,20 @@ static std::string get_reduce_op(cccl_type_enum t)
207
207
switch (t)
208
208
{
209
209
case cccl_type_enum::CCCL_INT8:
210
- return " extern \" C\" __device__ char op(char a, char b ) { return a + b; }" ;
210
+ return " extern \" C\" __device__ void op(char* a, char* b, char* out ) { *out = * a + * b; }" ;
211
211
case cccl_type_enum::CCCL_INT32:
212
- return " extern \" C\" __device__ int op(int a, int b ) { return a + b; }" ;
212
+ return " extern \" C\" __device__ void op(int* a, int* b, int* out ) { *out = * a + * b; }" ;
213
213
case cccl_type_enum::CCCL_UINT32:
214
- return " extern \" C\" __device__ unsigned int op(unsigned int a, unsigned int b ) { return a + b; }" ;
214
+ return " extern \" C\" __device__ void op(unsigned int* a, unsigned int* b, unsigned int* out ) { *out = * a + * b; }" ;
215
215
case cccl_type_enum::CCCL_INT64:
216
- return " extern \" C\" __device__ long long op(long long a, long long b ) { return a + b; }" ;
216
+ return " extern \" C\" __device__ void op(long long* a, long long* b, long long* out ) { *out = * a + * b; }" ;
217
217
case cccl_type_enum::CCCL_UINT64:
218
- return " extern \" C\" __device__ unsigned long long op(unsigned long long a, unsigned long long b) { "
219
- " return a + b; "
220
- " }" ;
218
+ return " extern \" C\" __device__ void op(unsigned long long* a, unsigned long long* b, unsigned long long* out) { "
219
+ " *out = *a + *b; }" ;
221
220
case cccl_type_enum::CCCL_FLOAT32:
222
- return " extern \" C\" __device__ float op(float a, float b ) { return a + b; }" ;
221
+ return " extern \" C\" __device__ void op(float* a, float* b, float* out ) { *out = * a + * b; }" ;
223
222
case cccl_type_enum::CCCL_FLOAT64:
224
- return " extern \" C\" __device__ double op(double a, double b ) { return a + b; }" ;
223
+ return " extern \" C\" __device__ void op(double* a, double* b, double* out ) { *out = * a + * b; }" ;
225
224
default :
226
225
throw std::runtime_error (" Unsupported type" );
227
226
}
@@ -253,26 +252,29 @@ static std::string get_merge_sort_op(cccl_type_enum t)
253
252
switch (t)
254
253
{
255
254
case cccl_type_enum::CCCL_INT8:
256
- return " extern \" C\" __device__ bool op(char lhs, char rhs) { return lhs < rhs; }" ;
255
+ return " extern \" C\" __device__ void op(char* lhs, char* rhs, bool* result ) { *result = * lhs < * rhs; }" ;
257
256
case cccl_type_enum::CCCL_UINT8:
258
- return " extern \" C\" __device__ bool op(unsigned char lhs, unsigned char rhs) { return lhs < rhs; }" ;
257
+ return " extern \" C\" __device__ void op(unsigned char* lhs, unsigned char* rhs, bool* result) { *result = *lhs < "
258
+ " *rhs; }" ;
259
259
case cccl_type_enum::CCCL_INT16:
260
- return " extern \" C\" __device__ bool op(short lhs, short rhs) { return lhs < rhs; }" ;
260
+ return " extern \" C\" __device__ void op(short* lhs, short* rhs, bool* result ) { *result = * lhs < * rhs; }" ;
261
261
case cccl_type_enum::CCCL_UINT16:
262
- return " extern \" C\" __device__ bool op(unsigned short lhs, unsigned short rhs) { return lhs < rhs; }" ;
262
+ return " extern \" C\" __device__ void op(unsigned short* lhs, unsigned short* rhs, bool* result) { *result = *lhs "
263
+ " < *rhs; }" ;
263
264
case cccl_type_enum::CCCL_INT32:
264
- return " extern \" C\" __device__ bool op(int lhs, int rhs) { return lhs < rhs; }" ;
265
+ return " extern \" C\" __device__ void op(int* lhs, int* rhs, bool* result ) { *result = * lhs < * rhs; }" ;
265
266
case cccl_type_enum::CCCL_UINT32:
266
- return " extern \" C\" __device__ bool op(unsigned int lhs, unsigned int rhs) { return lhs < rhs; }" ;
267
+ return " extern \" C\" __device__ void op(unsigned int* lhs, unsigned int* rhs, bool* result) { *result = *lhs < "
268
+ " *rhs; }" ;
267
269
case cccl_type_enum::CCCL_INT64:
268
- return " extern \" C\" __device__ bool op(long long lhs, long long rhs) { return lhs < rhs; }" ;
270
+ return " extern \" C\" __device__ void op(long long* lhs, long long* rhs, bool* result ) { *result = * lhs < * rhs; }" ;
269
271
case cccl_type_enum::CCCL_UINT64:
270
- return " extern \" C\" __device__ bool op(unsigned long long lhs, unsigned long long rhs) { return lhs < rhs; }" ;
272
+ return " extern \" C\" __device__ void op(unsigned long long* lhs, unsigned long long* rhs, bool* result) { "
273
+ " *result = *lhs < *rhs; }" ;
271
274
case cccl_type_enum::CCCL_FLOAT32:
272
- return " extern \" C\" __device__ bool op(float lhs, float rhs) { return lhs < rhs; }" ;
275
+ return " extern \" C\" __device__ void op(float* lhs, float* rhs, bool* result ) { *result = * lhs < * rhs; }" ;
273
276
case cccl_type_enum::CCCL_FLOAT64:
274
- return " extern \" C\" __device__ bool op(double lhs, double rhs) { return lhs < rhs; }" ;
275
-
277
+ return " extern \" C\" __device__ void op(double* lhs, double* rhs, bool* result) { *result = *lhs < *rhs; }" ;
276
278
default :
277
279
throw std::runtime_error (" Unsupported type" );
278
280
}
@@ -284,25 +286,30 @@ static std::string get_unique_by_key_op(cccl_type_enum t)
284
286
switch (t)
285
287
{
286
288
case cccl_type_enum::CCCL_INT8:
287
- return " extern \" C\" __device__ bool op(char lhs, char rhs) { return lhs == rhs; }" ;
289
+ return " extern \" C\" __device__ void op(char* lhs, char* rhs, bool* result ) { *result = * lhs == * rhs; }" ;
288
290
case cccl_type_enum::CCCL_UINT8:
289
- return " extern \" C\" __device__ bool op(unsigned char lhs, unsigned char rhs) { return lhs == rhs; }" ;
291
+ return " extern \" C\" __device__ void op(unsigned char* lhs, unsigned char* rhs, bool* result) { *result = *lhs "
292
+ " == *rhs; }" ;
290
293
case cccl_type_enum::CCCL_INT16:
291
- return " extern \" C\" __device__ bool op(short lhs, short rhs) { return lhs == rhs; }" ;
294
+ return " extern \" C\" __device__ void op(short* lhs, short* rhs, bool* result ) { *result = * lhs == * rhs; }" ;
292
295
case cccl_type_enum::CCCL_UINT16:
293
- return " extern \" C\" __device__ bool op(unsigned short lhs, unsigned short rhs) { return lhs == rhs; }" ;
296
+ return " extern \" C\" __device__ void op(unsigned short* lhs, unsigned short* rhs, bool* result) { *result = *lhs "
297
+ " == *rhs; }" ;
294
298
case cccl_type_enum::CCCL_INT32:
295
- return " extern \" C\" __device__ bool op(int lhs, int rhs) { return lhs == rhs; }" ;
299
+ return " extern \" C\" __device__ void op(int* lhs, int* rhs, bool* result ) { *result = * lhs == * rhs; }" ;
296
300
case cccl_type_enum::CCCL_UINT32:
297
- return " extern \" C\" __device__ bool op(unsigned int lhs, unsigned int rhs) { return lhs == rhs; }" ;
301
+ return " extern \" C\" __device__ void op(unsigned int* lhs, unsigned int* rhs, bool* result) { *result = *lhs == "
302
+ " *rhs; }" ;
298
303
case cccl_type_enum::CCCL_INT64:
299
- return " extern \" C\" __device__ bool op(long long lhs, long long rhs) { return lhs == rhs; }" ;
304
+ return " extern \" C\" __device__ void op(long long* lhs, long long* rhs, bool* result) { *result = *lhs == *rhs; "
305
+ " }" ;
300
306
case cccl_type_enum::CCCL_UINT64:
301
- return " extern \" C\" __device__ bool op(unsigned long long lhs, unsigned long long rhs) { return lhs == rhs; }" ;
307
+ return " extern \" C\" __device__ void op(unsigned long long* lhs, unsigned long long* rhs, bool* result) { "
308
+ " *result = *lhs == *rhs; }" ;
302
309
case cccl_type_enum::CCCL_FLOAT32:
303
- return " extern \" C\" __device__ bool op(float lhs, float rhs) { return lhs == rhs; }" ;
310
+ return " extern \" C\" __device__ void op(float* lhs, float* rhs, bool* result ) { *result = * lhs == * rhs; }" ;
304
311
case cccl_type_enum::CCCL_FLOAT64:
305
- return " extern \" C\" __device__ bool op(double lhs, double rhs) { return lhs == rhs; }" ;
312
+ return " extern \" C\" __device__ void op(double* lhs, double* rhs, bool* result ) { *result = * lhs == * rhs; }" ;
306
313
default :
307
314
throw std::runtime_error (" Unsupported type" );
308
315
}
@@ -314,21 +321,19 @@ static std::string get_unary_op(cccl_type_enum t)
314
321
switch (t)
315
322
{
316
323
case cccl_type_enum::CCCL_INT8:
317
- return " extern \" C\" __device__ char op(char a ) { return 2 * a; }" ;
324
+ return " extern \" C\" __device__ void op(char* a, char* result ) { *result = 2 * * a; }" ;
318
325
case cccl_type_enum::CCCL_INT32:
319
- return " extern \" C\" __device__ int op(int a ) { return 2 * a; }" ;
326
+ return " extern \" C\" __device__ void op(int* a, int* result ) { *result = 2 * * a; }" ;
320
327
case cccl_type_enum::CCCL_UINT32:
321
- return " extern \" C\" __device__ unsigned int op(unsigned int a ) { return 2 * a; }" ;
328
+ return " extern \" C\" __device__ void op(unsigned int* a, unsigned int* result ) { *result = 2 * * a; }" ;
322
329
case cccl_type_enum::CCCL_INT64:
323
- return " extern \" C\" __device__ long long op(long long a ) { return 2 * a; }" ;
330
+ return " extern \" C\" __device__ void op(long long* a, long long* result ) { *result = 2 * * a; }" ;
324
331
case cccl_type_enum::CCCL_UINT64:
325
- return " extern \" C\" __device__ unsigned long long op(unsigned long long a) { "
326
- " return 2 * a; "
327
- " }" ;
332
+ return " extern \" C\" __device__ void op(unsigned long long* a, unsigned long long* result) { *result = 2 * *a; }" ;
328
333
case cccl_type_enum::CCCL_FLOAT32:
329
- return " extern \" C\" __device__ float op(float a ) { return 2 * a; }" ;
334
+ return " extern \" C\" __device__ void op(float* a, float* result ) { *result = 2 * * a; }" ;
330
335
case cccl_type_enum::CCCL_FLOAT64:
331
- return " extern \" C\" __device__ double op(double a ) { return 2 * a; }" ;
336
+ return " extern \" C\" __device__ void op(double* a, double* result ) { *result = 2 * * a; }" ;
332
337
default :
333
338
throw std::runtime_error (" Unsupported type" );
334
339
}
0 commit comments