Skip to content

Commit 972e9c4

Browse files
committed
Update ompi_op_reduce to take size_t count
* If the `count` is greater than `INT_MAX` then we call the operation in chunks that fit into an `int`. * This moves the functionality out of libnbc and into the common reduction operation so that all collectives may pass larger counts than `INT_MAX` into the internal reduction operation. Signed-off-by: Joshua Hursey <[email protected]>
1 parent 6b8e368 commit 972e9c4

File tree

2 files changed

+38
-29
lines changed

2 files changed

+38
-29
lines changed

ompi/mca/coll/libnbc/nbc.c

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -536,30 +536,7 @@ static inline int NBC_Start_round(NBC_Handle *handle) {
536536
buf2=opargs.buf2;
537537
}
538538

539-
/* If the count is > INT_MAX then we need to call ompi_op_reduce()
540-
* in iterations of counts <= INT_MAX since it has an `int count`
541-
* parameter.
542-
*/
543-
if( OPAL_UNLIKELY(opargs.count > INT_MAX) ) {
544-
size_t done_count = 0, shift;
545-
int iter_count;
546-
ptrdiff_t ext, lb;
547-
548-
ompi_datatype_get_extent (opargs.datatype, &lb, &ext);
549-
550-
while(done_count < opargs.count) {
551-
if( done_count + INT_MAX > opargs.count ) {
552-
iter_count = opargs.count - done_count;
553-
} else {
554-
iter_count = INT_MAX;
555-
}
556-
shift = done_count * ext;
557-
ompi_op_reduce(opargs.op, buf1 + shift, buf2 + shift, iter_count, opargs.datatype);
558-
done_count += iter_count;
559-
}
560-
} else {
561-
ompi_op_reduce(opargs.op, buf1, buf2, opargs.count, opargs.datatype);
562-
}
539+
ompi_op_reduce(opargs.op, buf1, buf2, opargs.count, opargs.datatype);
563540
break;
564541
case COPY:
565542
NBC_DEBUG(5, " COPY (offset %li) ", offset);

ompi/op/op.h

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
* and Technology (RIST). All rights reserved.
2121
* Copyright (c) 2018 Triad National Security, LLC. All rights
2222
* reserved.
23+
* Copyright (c) 2021 IBM Corporation. All rights reserved.
2324
* $COPYRIGHT$
2425
*
2526
* Additional copyrights may follow
@@ -510,10 +511,41 @@ static inline bool ompi_op_is_valid(ompi_op_t * op, ompi_datatype_t * ddt,
510511
* is not defined to have that operation, it is likely to seg fault.
511512
*/
512513
static inline void ompi_op_reduce(ompi_op_t * op, void *source,
513-
void *target, int count,
514+
void *target, size_t count,
514515
ompi_datatype_t * dtype)
515516
{
516517
MPI_Fint f_dtype, f_count;
518+
int int_count = count;
519+
520+
/*
521+
* If the count is > INT_MAX then we need to call the reduction op
522+
* in iterations of counts <= INT_MAX since it has an `int *len`
523+
* parameter.
524+
*
525+
* Note: When we add BigCount support then we can distinguish between
526+
* a reduction operation with `int *len` and `MPI_Count *len`. At which
527+
* point we can avoid this loop.
528+
*/
529+
if( OPAL_UNLIKELY(count > INT_MAX) ) {
530+
size_t done_count = 0, shift;
531+
int iter_count;
532+
ptrdiff_t ext, lb;
533+
534+
ompi_datatype_get_extent(dtype, &lb, &ext);
535+
536+
while(done_count < count) {
537+
if(done_count + INT_MAX > count) {
538+
iter_count = count - done_count;
539+
} else {
540+
iter_count = INT_MAX;
541+
}
542+
shift = done_count * ext;
543+
// Recurse one level in iterations of 'int'
544+
ompi_op_reduce(op, (char*)source + shift, (char*)target + shift, iter_count, dtype);
545+
done_count += iter_count;
546+
}
547+
return;
548+
}
517549

518550
/*
519551
* Call the reduction function. Two dimensions: a) if both the op
@@ -548,25 +580,25 @@ static inline void ompi_op_reduce(ompi_op_t * op, void *source,
548580
dtype_id = ompi_op_ddt_map[dtype->id];
549581
}
550582
op->o_func.intrinsic.fns[dtype_id](source, target,
551-
&count, &dtype,
583+
&int_count, &dtype,
552584
op->o_func.intrinsic.modules[dtype_id]);
553585
return;
554586
}
555587

556588
/* User-defined function */
557589
if (0 != (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC)) {
558590
f_dtype = OMPI_INT_2_FINT(dtype->d_f_to_c_index);
559-
f_count = OMPI_INT_2_FINT(count);
591+
f_count = OMPI_INT_2_FINT(int_count);
560592
op->o_func.fort_fn(source, target, &f_count, &f_dtype);
561593
return;
562594
} else if (0 != (op->o_flags & OMPI_OP_FLAGS_JAVA_FUNC)) {
563-
op->o_func.java_data.intercept_fn(source, target, &count, &dtype,
595+
op->o_func.java_data.intercept_fn(source, target, &int_count, &dtype,
564596
op->o_func.java_data.baseType,
565597
op->o_func.java_data.jnienv,
566598
op->o_func.java_data.object);
567599
return;
568600
}
569-
op->o_func.c_fn(source, target, &count, &dtype);
601+
op->o_func.c_fn(source, target, &int_count, &dtype);
570602
return;
571603
}
572604

0 commit comments

Comments
 (0)