Skip to content

Commit b08f94a

Browse files
committed
coll/HAN: Add support for XHC's "special" Reduce for the low-comm in Allreduce
MPI_Reduce in XHC is not complete; it is implemented as a sub-case of Allreduce, and requires that the rbuf parameter is always present and appropriately sized for all ranks (not only for the root). This implementation is disabled by default and falls back to another coll component, but can be manually enabled for a single operation by invoking it with root=-1, which will do a reduce to rank 0. Inside HAN's Allreduce, the rbuf parameter restriction is satisfied, so it's safe to use this partially implemented Reduce. This patch is temporary (TM) until XHC's Reduce is fully implemented. The reason for its existence is the improved Allreduce performance potential with XHC for the intra-comm. Signed-off-by: George Katevenis <[email protected]>
1 parent 17c4dba commit b08f94a

File tree

2 files changed

+60
-34
lines changed

2 files changed

+60
-34
lines changed

ompi/mca/coll/han/coll_han.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ struct mca_coll_han_allreduce_args_s {
102102
int seg_count;
103103
int root_up_rank;
104104
int root_low_rank;
105+
int root_reduce_low_rank;
105106
int num_segments;
106107
int cur_seg;
107108
int w_rank;

ompi/mca/coll/han/coll_han_allreduce.c

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
*
77
* Copyright (c) 2020 Cisco Systems, Inc. All rights reserved.
88
* Copyright (c) 2022 IBM Corporation. All rights reserved
9+
* Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV)
10+
* Laboratory, ICS Forth. All rights reserved.
911
* $COPYRIGHT$
1012
*
1113
* Additional copyrights may follow
@@ -22,6 +24,7 @@
2224

2325
#include "coll_han.h"
2426
#include "ompi/mca/coll/base/coll_base_functions.h"
27+
#include "ompi/mca/coll/base/coll_base_util.h"
2528
#include "ompi/mca/coll/base/coll_tags.h"
2629
#include "ompi/mca/pml/pml.h"
2730
#include "coll_han_trigger.h"
@@ -43,6 +46,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
4346
struct ompi_op_t *op,
4447
int root_up_rank,
4548
int root_low_rank,
49+
int root_reduce_low_rank,
4650
struct ompi_communicator_t *up_comm,
4751
struct ompi_communicator_t *low_comm,
4852
int num_segments,
@@ -59,6 +63,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
5963
args->op = op;
6064
args->root_up_rank = root_up_rank;
6165
args->root_low_rank = root_low_rank;
66+
args->root_reduce_low_rank = root_reduce_low_rank;
6267
args->up_comm = up_comm;
6368
args->low_comm = low_comm;
6469
args->num_segments = num_segments;
@@ -139,15 +144,26 @@ mca_coll_han_allreduce_intra(const void *sbuf,
139144
int low_rank = ompi_comm_rank(low_comm);
140145
int root_up_rank = 0;
141146
int root_low_rank = 0;
147+
int root_reduce_low_rank = 0;
148+
149+
mca_coll_base_avail_coll_t *low_1st_module = (mca_coll_base_avail_coll_t *)
150+
opal_list_get_last(low_comm->c_coll->module_list);
151+
152+
// Invoke XHC's "special" Reduce
153+
if(0 == strcmp(low_1st_module->ac_component_name, "xhc")
154+
&& low_comm->c_coll->coll_reduce_module == low_1st_module->ac_module) {
155+
root_reduce_low_rank = -1;
156+
}
157+
142158
/* Create t0 task for the first segment */
143159
mca_coll_task_t *t0 = OBJ_NEW(mca_coll_task_t);
144160
/* Setup up t0 task arguments */
145161
int *completed = (int *) malloc(sizeof(int));
146162
completed[0] = 0;
147163
mca_coll_han_allreduce_args_t *t = malloc(sizeof(mca_coll_han_allreduce_args_t));
148164
mca_coll_han_set_allreduce_args(t, t0, (char *) sbuf, (char *) rbuf, seg_count, dtype, op,
149-
root_up_rank, root_low_rank, up_comm, low_comm, num_segments, 0,
150-
w_rank, count - (num_segments - 1) * seg_count,
165+
root_up_rank, root_low_rank, root_reduce_low_rank, up_comm,
166+
low_comm, num_segments, 0, w_rank, count - (num_segments - 1) * seg_count,
151167
low_rank != root_low_rank, NULL, completed);
152168
/* Init t0 task */
153169
init_task(t0, mca_coll_han_allreduce_t0_task, (void *) (t));
@@ -215,18 +231,18 @@ int mca_coll_han_allreduce_t0_task(void *task_args)
215231
if (MPI_IN_PLACE == t->sbuf) {
216232
if (!t->noop) {
217233
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE, (char *) t->rbuf, t->seg_count, t->dtype,
218-
t->op, t->root_low_rank, t->low_comm,
234+
t->op, t->root_reduce_low_rank, t->low_comm,
219235
t->low_comm->c_coll->coll_reduce_module);
220236
}
221237
else {
222238
t->low_comm->c_coll->coll_reduce((char *) t->rbuf, NULL, t->seg_count, t->dtype,
223-
t->op, t->root_low_rank, t->low_comm,
239+
t->op, t->root_reduce_low_rank, t->low_comm,
224240
t->low_comm->c_coll->coll_reduce_module);
225241
}
226242
}
227243
else {
228244
t->low_comm->c_coll->coll_reduce((char *) t->sbuf, (char *) t->rbuf, t->seg_count, t->dtype,
229-
t->op, t->root_low_rank, t->low_comm,
245+
t->op, t->root_reduce_low_rank, t->low_comm,
230246
t->low_comm->c_coll->coll_reduce_module);
231247
}
232248
return OMPI_SUCCESS;
@@ -267,21 +283,20 @@ int mca_coll_han_allreduce_t1_task(void *task_args)
267283
if (!t->noop) {
268284
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
269285
(char *) t->rbuf + extent * t->seg_count, tmp_count,
270-
t->dtype, t->op, t->root_low_rank, t->low_comm,
286+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
271287
t->low_comm->c_coll->coll_reduce_module);
272288
} else {
273289
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + extent * t->seg_count,
274290
NULL, tmp_count,
275-
t->dtype, t->op, t->root_low_rank, t->low_comm,
291+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
276292
t->low_comm->c_coll->coll_reduce_module);
277-
278293
}
279294
} else {
280295
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + extent * t->seg_count,
281296
(char *) t->rbuf + extent * t->seg_count, tmp_count,
282-
t->dtype, t->op, t->root_low_rank, t->low_comm,
297+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
283298
t->low_comm->c_coll->coll_reduce_module);
284-
}
299+
}
285300
}
286301
if (!t->noop) {
287302
ompi_request_wait(&ireduce_req, MPI_STATUS_IGNORE);
@@ -337,25 +352,25 @@ int mca_coll_han_allreduce_t2_task(void *task_args)
337352
tmp_count = t->last_seg_count;
338353
}
339354

340-
if (t->sbuf == MPI_IN_PLACE) {
341-
if (!t->noop) {
342-
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
355+
if (t->sbuf == MPI_IN_PLACE) {
356+
if (!t->noop) {
357+
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
358+
(char *) t->rbuf + 2 * extent * t->seg_count, tmp_count,
359+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
360+
t->low_comm->c_coll->coll_reduce_module);
361+
} else {
362+
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 2 * extent * t->seg_count,
363+
NULL, tmp_count,
364+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
365+
t->low_comm->c_coll->coll_reduce_module);
366+
367+
}
368+
} else {
369+
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 2 * extent * t->seg_count,
343370
(char *) t->rbuf + 2 * extent * t->seg_count, tmp_count,
344-
t->dtype, t->op, t->root_low_rank, t->low_comm,
345-
t->low_comm->c_coll->coll_reduce_module);
346-
} else {
347-
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 2 * extent * t->seg_count,
348-
NULL, tmp_count,
349-
t->dtype, t->op, t->root_low_rank, t->low_comm,
371+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
350372
t->low_comm->c_coll->coll_reduce_module);
351-
352-
}
353-
} else {
354-
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 2 * extent * t->seg_count,
355-
(char *) t->rbuf + 2 * extent * t->seg_count, tmp_count,
356-
t->dtype, t->op, t->root_low_rank, t->low_comm,
357-
t->low_comm->c_coll->coll_reduce_module);
358-
}
373+
}
359374
}
360375
if (!t->noop && req_count > 0) {
361376
ompi_request_wait_all(req_count, reqs, MPI_STATUSES_IGNORE);
@@ -421,18 +436,18 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
421436
if (!t->noop) {
422437
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
423438
(char *) t->rbuf + 3 * extent * t->seg_count, tmp_count,
424-
t->dtype, t->op, t->root_low_rank, t->low_comm,
439+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
425440
t->low_comm->c_coll->coll_reduce_module);
426-
} else {
441+
} else {
427442
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 3 * extent * t->seg_count,
428443
NULL, tmp_count,
429-
t->dtype, t->op, t->root_low_rank, t->low_comm,
444+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
430445
t->low_comm->c_coll->coll_reduce_module);
431446
}
432447
} else {
433448
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 3 * extent * t->seg_count,
434449
(char *) t->rbuf + 3 * extent * t->seg_count, tmp_count,
435-
t->dtype, t->op, t->root_low_rank, t->low_comm,
450+
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
436451
t->low_comm->c_coll->coll_reduce_module);
437452
}
438453
}
@@ -473,6 +488,7 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
473488
ompi_communicator_t *low_comm;
474489
ompi_communicator_t *up_comm;
475490
int root_low_rank = 0;
491+
int root_reduce_low_rank = 0;
476492
int low_rank;
477493
int ret;
478494
mca_coll_han_module_t *han_module = (mca_coll_han_module_t *)module;
@@ -504,22 +520,31 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
504520
up_comm = han_module->sub_comm[INTER_NODE];
505521
low_rank = ompi_comm_rank(low_comm);
506522

523+
mca_coll_base_avail_coll_t *low_1st_module = (mca_coll_base_avail_coll_t *)
524+
opal_list_get_last(low_comm->c_coll->module_list);
525+
526+
// Invoke XHC's "special" Reduce
527+
if(0 == strcmp(low_1st_module->ac_component_name, "xhc")
528+
&& low_comm->c_coll->coll_reduce_module == low_1st_module->ac_module) {
529+
root_reduce_low_rank = -1;
530+
}
531+
507532
/* Low_comm reduce */
508533
if (MPI_IN_PLACE == sbuf) {
509534
if (low_rank == root_low_rank) {
510535
ret = low_comm->c_coll->coll_reduce(MPI_IN_PLACE, (char *)rbuf,
511-
count, dtype, op, root_low_rank,
536+
count, dtype, op, root_reduce_low_rank,
512537
low_comm, low_comm->c_coll->coll_reduce_module);
513538
}
514539
else {
515540
ret = low_comm->c_coll->coll_reduce((char *)rbuf, NULL,
516-
count, dtype, op, root_low_rank,
541+
count, dtype, op, root_reduce_low_rank,
517542
low_comm, low_comm->c_coll->coll_reduce_module);
518543
}
519544
}
520545
else {
521546
ret = low_comm->c_coll->coll_reduce((char *)sbuf, (char *)rbuf,
522-
count, dtype, op, root_low_rank,
547+
count, dtype, op, root_reduce_low_rank,
523548
low_comm, low_comm->c_coll->coll_reduce_module);
524549
}
525550
if (OPAL_UNLIKELY(OMPI_SUCCESS != ret)) {

0 commit comments

Comments
 (0)