Skip to content

Commit ba59533

Browse files
committed
op/aarch64: refactor SVE functions
Refactor SVE functions and incidentally make NVIDIA compilers a happy panda again.

Signed-off-by: Gilles Gouaillardet <[email protected]>
1 parent 1b95379 commit ba59533

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

ompi/mca/op/aarch64/op_aarch64_functions.c

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
* reserved.
55
* Copyright (c) 2019 Arm Ltd. All rights reserved.
66
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
7+
* Copyright (c) 2024 Research Organization for Information Science
8+
* and Technology (RIST). All rights reserved.
79
*
810
* $COPYRIGHT$
911
*
@@ -140,20 +142,18 @@ _Generic((*(out)), \
140142
struct ompi_datatype_t **dtype, \
141143
struct ompi_op_base_module_1_0_0_t *module) \
142144
{ \
143-
int types_per_step = svcnt(*((type##type_size##_t *) _in)); \
144-
size_t idx = 0, left_over = *count; \
145+
const int types_per_step = svcnt(*((type##type_size##_t *) _in)); \
146+
const int cnt = *count; \
145147
type##type_size##_t *in = (type##type_size##_t *) _in, \
146148
*out = (type##type_size##_t *) _out; \
147149
OP_CONCAT(OMPI_OP_TYPE_PREPEND, type##type_size##_t) vsrc, vdst; \
148-
svbool_t pred = svwhilelt_b##type_size(idx, left_over); \
149-
do { \
150+
for (int idx=0; idx < cnt; idx += types_per_step) { \
151+
svbool_t pred = svwhilelt_b##type_size(idx, cnt); \
150152
vsrc = svld1(pred, &in[idx]); \
151153
vdst = svld1(pred, &out[idx]); \
152154
vdst = OP_CONCAT(OMPI_OP_OP_PREPEND, op##_x)(pred, vdst, vsrc); \
153155
OP_CONCAT(OMPI_OP_OP_PREPEND, st1)(pred, &out[idx], vdst); \
154-
idx += types_per_step; \
155-
pred = svwhilelt_b##type_size(idx, left_over); \
156-
} while (svptest_any(svptrue_b##type_size(), pred)); \
156+
} \
157157
}
158158
#endif
159159

@@ -308,21 +308,19 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
308308
struct ompi_datatype_t **dtype, \
309309
struct ompi_op_base_module_1_0_0_t *module) \
310310
{ \
311-
int types_per_step = svcnt(*((type##type_size##_t *) _in1)); \
311+
const int types_per_step = svcnt(*((type##type_size##_t *) _in1)); \
312312
type##type_size##_t *in1 = (type##type_size##_t *) _in1, \
313313
*in2 = (type##type_size##_t *) _in2, \
314314
*out = (type##type_size##_t *) _out; \
315-
size_t idx = 0, left_over = *count; \
315+
const int cnt = *count; \
316316
OP_CONCAT(OMPI_OP_TYPE_PREPEND, type##type_size##_t) vsrc, vdst; \
317-
svbool_t pred = svwhilelt_b##type_size(idx, left_over); \
318-
do { \
317+
for (int idx=0; idx < cnt; idx += types_per_step) { \
318+
svbool_t pred = svwhilelt_b##type_size(idx, cnt); \
319319
vsrc = svld1(pred, &in1[idx]); \
320320
vdst = svld1(pred, &in2[idx]); \
321321
vdst = OP_CONCAT(OMPI_OP_OP_PREPEND, op##_x)(pred, vdst, vsrc); \
322322
OP_CONCAT(OMPI_OP_OP_PREPEND, st1)(pred, &out[idx], vdst); \
323-
idx += types_per_step; \
324-
pred = svwhilelt_b##type_size(idx, left_over); \
325-
} while (svptest_any(svptrue_b##type_size(), pred)); \
323+
} \
326324
}
327325
#endif /* defined(GENERATE_SVE_CODE) */
328326

0 commit comments

Comments
 (0)