6
6
*
7
7
* Copyright (c) 2020 Cisco Systems, Inc. All rights reserved.
8
8
* Copyright (c) 2022 IBM Corporation. All rights reserved
9
+ * Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV)
10
+ * Laboratory, ICS Forth. All rights reserved.
9
11
* $COPYRIGHT$
10
12
*
11
13
* Additional copyrights may follow
22
24
23
25
#include "coll_han.h"
24
26
#include "ompi/mca/coll/base/coll_base_functions.h"
27
+ #include "ompi/mca/coll/base/coll_base_util.h"
25
28
#include "ompi/mca/coll/base/coll_tags.h"
26
29
#include "ompi/mca/pml/pml.h"
27
30
#include "coll_han_trigger.h"
@@ -43,6 +46,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
43
46
struct ompi_op_t * op ,
44
47
int root_up_rank ,
45
48
int root_low_rank ,
49
+ int root_reduce_low_rank ,
46
50
struct ompi_communicator_t * up_comm ,
47
51
struct ompi_communicator_t * low_comm ,
48
52
int num_segments ,
@@ -59,6 +63,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
59
63
args -> op = op ;
60
64
args -> root_up_rank = root_up_rank ;
61
65
args -> root_low_rank = root_low_rank ;
66
+ args -> root_reduce_low_rank = root_reduce_low_rank ;
62
67
args -> up_comm = up_comm ;
63
68
args -> low_comm = low_comm ;
64
69
args -> num_segments = num_segments ;
@@ -139,15 +144,26 @@ mca_coll_han_allreduce_intra(const void *sbuf,
139
144
int low_rank = ompi_comm_rank (low_comm );
140
145
int root_up_rank = 0 ;
141
146
int root_low_rank = 0 ;
147
+ int root_reduce_low_rank = 0 ;
148
+
149
+ mca_coll_base_avail_coll_t * low_1st_module = (mca_coll_base_avail_coll_t * )
150
+ opal_list_get_last (low_comm -> c_coll -> module_list );
151
+
152
+ // If XHC is the module serving low_comm's reduce, pass root -1 to invoke XHC's "special" Reduce
153
+ if (0 == strcmp (low_1st_module -> ac_component_name , "xhc" )
154
+ && low_comm -> c_coll -> coll_reduce_module == low_1st_module -> ac_module ) {
155
+ root_reduce_low_rank = -1 ;
156
+ }
157
+
142
158
/* Create t0 task for the first segment */
143
159
mca_coll_task_t * t0 = OBJ_NEW (mca_coll_task_t );
144
160
/* Set up t0 task arguments */
145
161
int * completed = (int * ) malloc (sizeof (int ));
146
162
completed [0 ] = 0 ;
147
163
mca_coll_han_allreduce_args_t * t = malloc (sizeof (mca_coll_han_allreduce_args_t ));
148
164
mca_coll_han_set_allreduce_args (t , t0 , (char * ) sbuf , (char * ) rbuf , seg_count , dtype , op ,
149
- root_up_rank , root_low_rank , up_comm , low_comm , num_segments , 0 ,
150
- w_rank , count - (num_segments - 1 ) * seg_count ,
165
+ root_up_rank , root_low_rank , root_reduce_low_rank , up_comm ,
166
+ low_comm , num_segments , 0 , w_rank , count - (num_segments - 1 ) * seg_count ,
151
167
low_rank != root_low_rank , NULL , completed );
152
168
/* Init t0 task */
153
169
init_task (t0 , mca_coll_han_allreduce_t0_task , (void * ) (t ));
@@ -215,18 +231,18 @@ int mca_coll_han_allreduce_t0_task(void *task_args)
215
231
if (MPI_IN_PLACE == t -> sbuf ) {
216
232
if (!t -> noop ) {
217
233
t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE , (char * ) t -> rbuf , t -> seg_count , t -> dtype ,
218
- t -> op , t -> root_low_rank , t -> low_comm ,
234
+ t -> op , t -> root_reduce_low_rank , t -> low_comm ,
219
235
t -> low_comm -> c_coll -> coll_reduce_module );
220
236
}
221
237
else {
222
238
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf , NULL , t -> seg_count , t -> dtype ,
223
- t -> op , t -> root_low_rank , t -> low_comm ,
239
+ t -> op , t -> root_reduce_low_rank , t -> low_comm ,
224
240
t -> low_comm -> c_coll -> coll_reduce_module );
225
241
}
226
242
}
227
243
else {
228
244
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf , (char * ) t -> rbuf , t -> seg_count , t -> dtype ,
229
- t -> op , t -> root_low_rank , t -> low_comm ,
245
+ t -> op , t -> root_reduce_low_rank , t -> low_comm ,
230
246
t -> low_comm -> c_coll -> coll_reduce_module );
231
247
}
232
248
return OMPI_SUCCESS ;
@@ -267,21 +283,20 @@ int mca_coll_han_allreduce_t1_task(void *task_args)
267
283
if (!t -> noop ) {
268
284
t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE ,
269
285
(char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
270
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
286
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
271
287
t -> low_comm -> c_coll -> coll_reduce_module );
272
288
} else {
273
289
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf + extent * t -> seg_count ,
274
290
NULL , tmp_count ,
275
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
291
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
276
292
t -> low_comm -> c_coll -> coll_reduce_module );
277
-
278
293
}
279
294
} else {
280
295
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + extent * t -> seg_count ,
281
296
(char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
282
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
297
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
283
298
t -> low_comm -> c_coll -> coll_reduce_module );
284
- }
299
+ }
285
300
}
286
301
if (!t -> noop ) {
287
302
ompi_request_wait (& ireduce_req , MPI_STATUS_IGNORE );
@@ -337,25 +352,25 @@ int mca_coll_han_allreduce_t2_task(void *task_args)
337
352
tmp_count = t -> last_seg_count ;
338
353
}
339
354
340
- if (t -> sbuf == MPI_IN_PLACE ) {
341
- if (!t -> noop ) {
342
- t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE ,
355
+ if (t -> sbuf == MPI_IN_PLACE ) {
356
+ if (!t -> noop ) {
357
+ t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE ,
358
+ (char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
359
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
360
+ t -> low_comm -> c_coll -> coll_reduce_module );
361
+ } else {
362
+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf + 2 * extent * t -> seg_count ,
363
+ NULL , tmp_count ,
364
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
365
+ t -> low_comm -> c_coll -> coll_reduce_module );
366
+
367
+ }
368
+ } else {
369
+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 2 * extent * t -> seg_count ,
343
370
(char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
344
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
345
- t -> low_comm -> c_coll -> coll_reduce_module );
346
- } else {
347
- t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf + 2 * extent * t -> seg_count ,
348
- NULL , tmp_count ,
349
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
371
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
350
372
t -> low_comm -> c_coll -> coll_reduce_module );
351
-
352
- }
353
- } else {
354
- t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 2 * extent * t -> seg_count ,
355
- (char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
356
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
357
- t -> low_comm -> c_coll -> coll_reduce_module );
358
- }
373
+ }
359
374
}
360
375
if (!t -> noop && req_count > 0 ) {
361
376
ompi_request_wait_all (req_count , reqs , MPI_STATUSES_IGNORE );
@@ -421,18 +436,18 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
421
436
if (!t -> noop ) {
422
437
t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE ,
423
438
(char * ) t -> rbuf + 3 * extent * t -> seg_count , tmp_count ,
424
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
439
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
425
440
t -> low_comm -> c_coll -> coll_reduce_module );
426
- } else {
441
+ } else {
427
442
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf + 3 * extent * t -> seg_count ,
428
443
NULL , tmp_count ,
429
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
444
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
430
445
t -> low_comm -> c_coll -> coll_reduce_module );
431
446
}
432
447
} else {
433
448
t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 3 * extent * t -> seg_count ,
434
449
(char * ) t -> rbuf + 3 * extent * t -> seg_count , tmp_count ,
435
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
450
+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
436
451
t -> low_comm -> c_coll -> coll_reduce_module );
437
452
}
438
453
}
@@ -473,6 +488,7 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
473
488
ompi_communicator_t * low_comm ;
474
489
ompi_communicator_t * up_comm ;
475
490
int root_low_rank = 0 ;
491
+ int root_reduce_low_rank = 0 ;
476
492
int low_rank ;
477
493
int ret ;
478
494
mca_coll_han_module_t * han_module = (mca_coll_han_module_t * )module ;
@@ -504,22 +520,31 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
504
520
up_comm = han_module -> sub_comm [INTER_NODE ];
505
521
low_rank = ompi_comm_rank (low_comm );
506
522
523
+ mca_coll_base_avail_coll_t * low_1st_module = (mca_coll_base_avail_coll_t * )
524
+ opal_list_get_last (low_comm -> c_coll -> module_list );
525
+
526
+ // If XHC is the module serving low_comm's reduce, pass root -1 to invoke XHC's "special" Reduce
527
+ if (0 == strcmp (low_1st_module -> ac_component_name , "xhc" )
528
+ && low_comm -> c_coll -> coll_reduce_module == low_1st_module -> ac_module ) {
529
+ root_reduce_low_rank = -1 ;
530
+ }
531
+
507
532
/* Low_comm reduce */
508
533
if (MPI_IN_PLACE == sbuf ) {
509
534
if (low_rank == root_low_rank ) {
510
535
ret = low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE , (char * )rbuf ,
511
- count , dtype , op , root_low_rank ,
536
+ count , dtype , op , root_reduce_low_rank ,
512
537
low_comm , low_comm -> c_coll -> coll_reduce_module );
513
538
}
514
539
else {
515
540
ret = low_comm -> c_coll -> coll_reduce ((char * )rbuf , NULL ,
516
- count , dtype , op , root_low_rank ,
541
+ count , dtype , op , root_reduce_low_rank ,
517
542
low_comm , low_comm -> c_coll -> coll_reduce_module );
518
543
}
519
544
}
520
545
else {
521
546
ret = low_comm -> c_coll -> coll_reduce ((char * )sbuf , (char * )rbuf ,
522
- count , dtype , op , root_low_rank ,
547
+ count , dtype , op , root_reduce_low_rank ,
523
548
low_comm , low_comm -> c_coll -> coll_reduce_module );
524
549
}
525
550
if (OPAL_UNLIKELY (OMPI_SUCCESS != ret )) {
0 commit comments