6
6
*
7
7
* Copyright (c) 2020 Cisco Systems, Inc. All rights reserved.
8
8
* Copyright (c) 2022 IBM Corporation. All rights reserved
9
+ * Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV)
10
+ * Laboratory, ICS Forth. All rights reserved.
9
11
* $COPYRIGHT$
10
12
*
11
13
* Additional copyrights may follow
22
24
23
25
#include "coll_han.h"
24
26
#include "ompi/mca/coll/base/coll_base_functions.h"
27
+ #include "ompi/mca/coll/base/coll_base_util.h"
25
28
#include "ompi/mca/coll/base/coll_tags.h"
26
29
#include "ompi/mca/pml/pml.h"
27
30
#include "coll_han_trigger.h"
@@ -43,6 +46,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
43
46
struct ompi_op_t * op ,
44
47
int root_up_rank ,
45
48
int root_low_rank ,
49
+ int root_reduce_low_rank ,
46
50
struct ompi_communicator_t * up_comm ,
47
51
struct ompi_communicator_t * low_comm ,
48
52
int num_segments ,
@@ -59,6 +63,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
59
63
args -> op = op ;
60
64
args -> root_up_rank = root_up_rank ;
61
65
args -> root_low_rank = root_low_rank ;
66
+ args -> root_reduce_low_rank = root_reduce_low_rank ;
62
67
args -> up_comm = up_comm ;
63
68
args -> low_comm = low_comm ;
64
69
args -> num_segments = num_segments ;
@@ -139,15 +144,26 @@ mca_coll_han_allreduce_intra(const void *sbuf,
139
144
int low_rank = ompi_comm_rank (low_comm );
140
145
int root_up_rank = 0 ;
141
146
int root_low_rank = 0 ;
147
+ int root_reduce_low_rank = 0 ;
148
+
149
+ mca_coll_base_avail_coll_t * low_1st_module = (mca_coll_base_avail_coll_t * )
150
+ opal_list_get_last (low_comm -> c_coll -> module_list );
151
+
152
+ // Invoke XHC's "special" Reduce
153
+ if (0 == strcmp (low_1st_module -> ac_component_name , "xhc" )
154
+ && low_comm -> c_coll -> coll_reduce_module == low_1st_module -> ac_module ) {
155
+ root_reduce_low_rank = -1 ;
156
+ }
157
+
142
158
/* Create t0 task for the first segment */
143
159
mca_coll_task_t * t0 = OBJ_NEW (mca_coll_task_t );
144
160
/* Set up t0 task arguments */
145
161
int * completed = (int * ) malloc (sizeof (int ));
146
162
completed [0 ] = 0 ;
147
163
mca_coll_han_allreduce_args_t * t = malloc (sizeof (mca_coll_han_allreduce_args_t ));
148
164
mca_coll_han_set_allreduce_args (t , t0 , (char * ) sbuf , (char * ) rbuf , seg_count , dtype , op ,
149
- root_up_rank , root_low_rank , up_comm , low_comm , num_segments , 0 ,
150
- w_rank , count - (num_segments - 1 ) * seg_count ,
165
+ root_up_rank , root_low_rank , root_reduce_low_rank , up_comm ,
166
+ low_comm , num_segments , 0 , w_rank , count - (num_segments - 1 ) * seg_count ,
151
167
low_rank != root_low_rank , NULL , completed );
152
168
/* Init t0 task */
153
169
init_task (t0 , mca_coll_han_allreduce_t0_task , (void * ) (t ));
@@ -215,18 +231,18 @@ int mca_coll_han_allreduce_t0_task(void *task_args)
215
231
if (MPI_IN_PLACE == t -> sbuf ) {
216
232
if (!t -> noop ) {
217
233
t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE , (char * ) t -> rbuf , t -> seg_count , t -> dtype ,
218
- t -> op , t -> root_low_rank , t -> low_comm ,
234
+ t -> op , t -> root_reduce_low_rank , t -> low_comm ,
219
235
t -> low_comm -> c_coll -> coll_reduce_module );
220
236
}
221
237
else {
222
238
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf , NULL , t -> seg_count , t -> dtype ,
223
- t -> op , t -> root_low_rank , t -> low_comm ,
239
+ t -> op , t -> root_reduce_low_rank , t -> low_comm ,
224
240
t -> low_comm -> c_coll -> coll_reduce_module );
225
241
}
226
242
}
227
243
else {
228
244
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf , (char * ) t -> rbuf , t -> seg_count , t -> dtype ,
229
- t -> op , t -> root_low_rank , t -> low_comm ,
245
+ t -> op , t -> root_reduce_low_rank , t -> low_comm ,
230
246
t -> low_comm -> c_coll -> coll_reduce_module );
231
247
}
232
248
return OMPI_SUCCESS ;
@@ -264,7 +280,7 @@ int mca_coll_han_allreduce_t1_task(void *task_args)
264
280
}
265
281
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + extent * t -> seg_count ,
266
282
(char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
267
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
283
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
268
284
t -> low_comm -> c_coll -> coll_reduce_module );
269
285
270
286
}
@@ -323,7 +339,7 @@ int mca_coll_han_allreduce_t2_task(void *task_args)
323
339
}
324
340
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 2 * extent * t -> seg_count ,
325
341
(char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
326
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
342
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
327
343
t -> low_comm -> c_coll -> coll_reduce_module );
328
344
}
329
345
if (!t -> noop && req_count > 0 ) {
@@ -387,7 +403,7 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
387
403
}
388
404
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 3 * extent * t -> seg_count ,
389
405
(char * ) t -> rbuf + 3 * extent * t -> seg_count , tmp_count ,
390
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
406
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
391
407
t -> low_comm -> c_coll -> coll_reduce_module );
392
408
}
393
409
/* lb of cur_seg */
@@ -421,6 +437,7 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
421
437
ompi_communicator_t * low_comm ;
422
438
ompi_communicator_t * up_comm ;
423
439
int root_low_rank = 0 ;
440
+ int root_reduce_low_rank = 0 ;
424
441
int low_rank ;
425
442
int ret ;
426
443
mca_coll_han_module_t * han_module = (mca_coll_han_module_t * )module ;
@@ -452,22 +469,31 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
452
469
up_comm = han_module -> sub_comm [INTER_NODE ];
453
470
low_rank = ompi_comm_rank (low_comm );
454
471
472
+ mca_coll_base_avail_coll_t * low_1st_module = (mca_coll_base_avail_coll_t * )
473
+ opal_list_get_last (low_comm -> c_coll -> module_list );
474
+
475
+ // Invoke XHC's "special" Reduce
476
+ if (0 == strcmp (low_1st_module -> ac_component_name , "xhc" )
477
+ && low_comm -> c_coll -> coll_reduce_module == low_1st_module -> ac_module ) {
478
+ root_reduce_low_rank = -1 ;
479
+ }
480
+
455
481
/* Low_comm reduce */
456
482
if (MPI_IN_PLACE == sbuf ) {
457
483
if (low_rank == root_low_rank ) {
458
484
ret = low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE , (char * )rbuf ,
459
- count , dtype , op , root_low_rank ,
485
+ count , dtype , op , root_reduce_low_rank ,
460
486
low_comm , low_comm -> c_coll -> coll_reduce_module );
461
487
}
462
488
else {
463
489
ret = low_comm -> c_coll -> coll_reduce ((char * )rbuf , NULL ,
464
- count , dtype , op , root_low_rank ,
490
+ count , dtype , op , root_reduce_low_rank ,
465
491
low_comm , low_comm -> c_coll -> coll_reduce_module );
466
492
}
467
493
}
468
494
else {
469
495
ret = low_comm -> c_coll -> coll_reduce ((char * )sbuf , (char * )rbuf ,
470
- count , dtype , op , root_low_rank ,
496
+ count , dtype , op , root_reduce_low_rank ,
471
497
low_comm , low_comm -> c_coll -> coll_reduce_module );
472
498
}
473
499
if (OPAL_UNLIKELY (OMPI_SUCCESS != ret )) {
0 commit comments