Skip to content

Commit a203458

Browse files
committed
coll/HAN: Add support for XHC's "special" Reduce for the low-comm in Allreduce
MPI_Reduce in XHC is not complete; it is implemented as a sub-case of Allreduce, and requires that the rbuf parameter is always present and appropriately sized for all ranks (not only for the root). This implementation is disabled by default and falls back to another coll component, but can be manually enabled for a single operation by invoking it with root=-1, which will do a reduce to rank 0. Inside HAN's Allreduce, the rbuf parameter restriction is satisfied, so it's safe to use this partially implemented Reduce. This patch is temporary (TM) until XHC's Reduce is fully implemented. The reason for its existence is the improved Allreduce performance potential with XHC for the intra-comm. Signed-off-by: George Katevenis <[email protected]>
1 parent 2f279b5 commit a203458

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

ompi/mca/coll/han/coll_han.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ struct mca_coll_han_allreduce_args_s {
102102
int seg_count;
103103
int root_up_rank;
104104
int root_low_rank;
105+
int root_reduce_low_rank;
105106
int num_segments;
106107
int cur_seg;
107108
int w_rank;

ompi/mca/coll/han/coll_han_allreduce.c

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
*
77
* Copyright (c) 2020 Cisco Systems, Inc. All rights reserved.
88
* Copyright (c) 2022 IBM Corporation. All rights reserved
9+
* Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV)
10+
* Laboratory, ICS Forth. All rights reserved.
911
* $COPYRIGHT$
1012
*
1113
* Additional copyrights may follow
@@ -22,6 +24,7 @@
2224

2325
#include "coll_han.h"
2426
#include "ompi/mca/coll/base/coll_base_functions.h"
27+
#include "ompi/mca/coll/base/coll_base_util.h"
2528
#include "ompi/mca/coll/base/coll_tags.h"
2629
#include "ompi/mca/pml/pml.h"
2730
#include "coll_han_trigger.h"
@@ -43,6 +46,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
4346
struct ompi_op_t *op,
4447
int root_up_rank,
4548
int root_low_rank,
49+
int root_reduce_low_rank,
4650
struct ompi_communicator_t *up_comm,
4751
struct ompi_communicator_t *low_comm,
4852
int num_segments,
@@ -59,6 +63,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
5963
args->op = op;
6064
args->root_up_rank = root_up_rank;
6165
args->root_low_rank = root_low_rank;
66+
args->root_reduce_low_rank = root_reduce_low_rank;
6267
args->up_comm = up_comm;
6368
args->low_comm = low_comm;
6469
args->num_segments = num_segments;
@@ -139,15 +144,26 @@ mca_coll_han_allreduce_intra(const void *sbuf,
139144
int low_rank = ompi_comm_rank(low_comm);
140145
int root_up_rank = 0;
141146
int root_low_rank = 0;
147+
int root_reduce_low_rank = 0;
148+
149+
mca_coll_base_avail_coll_t *low_1st_module = (mca_coll_base_avail_coll_t *)
150+
opal_list_get_last(low_comm->c_coll->module_list);
151+
152+
// Invoke XHC's "special" Reduce
153+
if(0 == strcmp(low_1st_module->ac_component_name, "xhc")
154+
&& low_comm->c_coll->coll_reduce_module == low_1st_module->ac_module) {
155+
root_reduce_low_rank = -1;
156+
}
157+
142158
/* Create t0 task for the first segment */
143159
mca_coll_task_t *t0 = OBJ_NEW(mca_coll_task_t);
144160
/* Setup up t0 task arguments */
145161
int *completed = (int *) malloc(sizeof(int));
146162
completed[0] = 0;
147163
mca_coll_han_allreduce_args_t *t = malloc(sizeof(mca_coll_han_allreduce_args_t));
148164
mca_coll_han_set_allreduce_args(t, t0, (char *) sbuf, (char *) rbuf, seg_count, dtype, op,
149-
root_up_rank, root_low_rank, up_comm, low_comm, num_segments, 0,
150-
w_rank, count - (num_segments - 1) * seg_count,
165+
root_up_rank, root_low_rank, root_reduce_low_rank, up_comm,
166+
low_comm, num_segments, 0, w_rank, count - (num_segments - 1) * seg_count,
151167
low_rank != root_low_rank, NULL, completed);
152168
/* Init t0 task */
153169
init_task(t0, mca_coll_han_allreduce_t0_task, (void *) (t));
@@ -215,18 +231,18 @@ int mca_coll_han_allreduce_t0_task(void *task_args)
215231
if (MPI_IN_PLACE == t->sbuf) {
216232
if (!t->noop) {
217233
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE, (char *) t->rbuf, t->seg_count, t->dtype,
218-
t->op, t->root_low_rank, t->low_comm,
234+
t->op, t->root_reduce_low_rank, t->low_comm,
219235
t->low_comm->c_coll->coll_reduce_module);
220236
}
221237
else {
222238
t->low_comm->c_coll->coll_reduce((char *) t->rbuf, NULL, t->seg_count, t->dtype,
223-
t->op, t->root_low_rank, t->low_comm,
239+
t->op, t->root_reduce_low_rank, t->low_comm,
224240
t->low_comm->c_coll->coll_reduce_module);
225241
}
226242
}
227243
else {
228244
t->low_comm->c_coll->coll_reduce((char *) t->sbuf, (char *) t->rbuf, t->seg_count, t->dtype,
229-
t->op, t->root_low_rank, t->low_comm,
245+
t->op, t->root_reduce_low_rank, t->low_comm,
230246
t->low_comm->c_coll->coll_reduce_module);
231247
}
232248
return OMPI_SUCCESS;
@@ -264,7 +280,7 @@ int mca_coll_han_allreduce_t1_task(void *task_args)
264280
}
265281
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + extent * t->seg_count,
266282
(char *) t->rbuf + extent * t->seg_count, tmp_count,
267-
t->dtype, t->op, t->root_low_rank, t->low_comm,
283+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
268284
t->low_comm->c_coll->coll_reduce_module);
269285

270286
}
@@ -323,7 +339,7 @@ int mca_coll_han_allreduce_t2_task(void *task_args)
323339
}
324340
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 2 * extent * t->seg_count,
325341
(char *) t->rbuf + 2 * extent * t->seg_count, tmp_count,
326-
t->dtype, t->op, t->root_low_rank, t->low_comm,
342+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
327343
t->low_comm->c_coll->coll_reduce_module);
328344
}
329345
if (!t->noop && req_count > 0) {
@@ -387,7 +403,7 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
387403
}
388404
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 3 * extent * t->seg_count,
389405
(char *) t->rbuf + 3 * extent * t->seg_count, tmp_count,
390-
t->dtype, t->op, t->root_low_rank, t->low_comm,
406+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
391407
t->low_comm->c_coll->coll_reduce_module);
392408
}
393409
/* lb of cur_seg */
@@ -421,6 +437,7 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
421437
ompi_communicator_t *low_comm;
422438
ompi_communicator_t *up_comm;
423439
int root_low_rank = 0;
440+
int root_reduce_low_rank = 0;
424441
int low_rank;
425442
int ret;
426443
mca_coll_han_module_t *han_module = (mca_coll_han_module_t *)module;
@@ -452,22 +469,31 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
452469
up_comm = han_module->sub_comm[INTER_NODE];
453470
low_rank = ompi_comm_rank(low_comm);
454471

472+
mca_coll_base_avail_coll_t *low_1st_module = (mca_coll_base_avail_coll_t *)
473+
opal_list_get_last(low_comm->c_coll->module_list);
474+
475+
// Invoke XHC's "special" Reduce
476+
if(0 == strcmp(low_1st_module->ac_component_name, "xhc")
477+
&& low_comm->c_coll->coll_reduce_module == low_1st_module->ac_module) {
478+
root_reduce_low_rank = -1;
479+
}
480+
455481
/* Low_comm reduce */
456482
if (MPI_IN_PLACE == sbuf) {
457483
if (low_rank == root_low_rank) {
458484
ret = low_comm->c_coll->coll_reduce(MPI_IN_PLACE, (char *)rbuf,
459-
count, dtype, op, root_low_rank,
485+
count, dtype, op, root_reduce_low_rank,
460486
low_comm, low_comm->c_coll->coll_reduce_module);
461487
}
462488
else {
463489
ret = low_comm->c_coll->coll_reduce((char *)rbuf, NULL,
464-
count, dtype, op, root_low_rank,
490+
count, dtype, op, root_reduce_low_rank,
465491
low_comm, low_comm->c_coll->coll_reduce_module);
466492
}
467493
}
468494
else {
469495
ret = low_comm->c_coll->coll_reduce((char *)sbuf, (char *)rbuf,
470-
count, dtype, op, root_low_rank,
496+
count, dtype, op, root_reduce_low_rank,
471497
low_comm, low_comm->c_coll->coll_reduce_module);
472498
}
473499
if (OPAL_UNLIKELY(OMPI_SUCCESS != ret)) {

0 commit comments

Comments
 (0)