4
4
* reserved.
5
5
* Copyright (c) 2022 IBM Corporation. All rights reserved
6
6
* Copyright (c) 2020-2022 Bull S.A.S. All rights reserved.
7
+ * Copyright (c) Amazon.com, Inc. or its affiliates.
8
+ * All rights reserved.
7
9
* $COPYRIGHT$
8
10
*
9
11
* Additional copyrights may follow
@@ -189,7 +191,9 @@ typedef struct mca_coll_han_op_module_name_t {
189
191
mca_coll_han_op_up_low_module_name_t allreduce ;
190
192
mca_coll_han_op_up_low_module_name_t allgather ;
191
193
mca_coll_han_op_up_low_module_name_t gather ;
194
+ mca_coll_han_op_up_low_module_name_t gatherv ;
192
195
mca_coll_han_op_up_low_module_name_t scatter ;
196
+ mca_coll_han_op_up_low_module_name_t scatterv ;
193
197
} mca_coll_han_op_module_name_t ;
194
198
195
199
/**
@@ -233,10 +237,18 @@ typedef struct mca_coll_han_component_t {
233
237
uint32_t han_gather_up_module ;
234
238
/* low level module for gather */
235
239
uint32_t han_gather_low_module ;
240
+ /* up level module for gatherv */
241
+ uint32_t han_gatherv_up_module ;
242
+ /* low level module for gatherv */
243
+ uint32_t han_gatherv_low_module ;
236
244
/* up level module for scatter */
237
245
uint32_t han_scatter_up_module ;
238
246
/* low level module for scatter */
239
247
uint32_t han_scatter_low_module ;
248
+ /* up level module for scatterv */
249
+ uint32_t han_scatterv_up_module ;
250
+ /* low level module for scatterv */
251
+ uint32_t han_scatterv_low_module ;
240
252
/* name of the modules */
241
253
mca_coll_han_op_module_name_t han_op_module_name ;
242
254
/* whether we need reproducible results
@@ -277,8 +289,10 @@ typedef struct mca_coll_han_single_collective_fallback_s {
277
289
mca_coll_base_module_barrier_fn_t barrier ;
278
290
mca_coll_base_module_bcast_fn_t bcast ;
279
291
mca_coll_base_module_gather_fn_t gather ;
292
+ mca_coll_base_module_gatherv_fn_t gatherv ;
280
293
mca_coll_base_module_reduce_fn_t reduce ;
281
294
mca_coll_base_module_scatter_fn_t scatter ;
295
+ mca_coll_base_module_scatterv_fn_t scatterv ;
282
296
} module_fn ;
283
297
mca_coll_base_module_t * module ;
284
298
} mca_coll_han_single_collective_fallback_t ;
@@ -296,7 +310,9 @@ typedef struct mca_coll_han_collectives_fallback_s {
296
310
mca_coll_han_single_collective_fallback_t bcast ;
297
311
mca_coll_han_single_collective_fallback_t reduce ;
298
312
mca_coll_han_single_collective_fallback_t gather ;
313
+ mca_coll_han_single_collective_fallback_t gatherv ;
299
314
mca_coll_han_single_collective_fallback_t scatter ;
315
+ mca_coll_han_single_collective_fallback_t scatterv ;
300
316
} mca_coll_han_collectives_fallback_t ;
301
317
302
318
/** Coll han module */
@@ -369,9 +385,14 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
369
385
#define previous_gather fallback.gather.module_fn.gather
370
386
#define previous_gather_module fallback.gather.module
371
387
388
+ #define previous_gatherv fallback.gatherv.module_fn.gatherv
389
+ #define previous_gatherv_module fallback.gatherv.module
390
+
372
391
#define previous_scatter fallback.scatter.module_fn.scatter
373
392
#define previous_scatter_module fallback.scatter.module
374
393
394
+ #define previous_scatterv fallback.scatterv.module_fn.scatterv
395
+ #define previous_scatterv_module fallback.scatterv.module
375
396
376
397
/* macro to correctly load a fallback collective module */
377
398
#define HAN_LOAD_FALLBACK_COLLECTIVE (HANM , COMM , COLL ) \
@@ -391,7 +412,9 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
391
412
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, barrier); \
392
413
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, bcast); \
393
414
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatter); \
415
+ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatterv); \
394
416
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gather); \
417
+ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gatherv); \
395
418
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, reduce); \
396
419
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allreduce); \
397
420
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allgather); \
@@ -432,11 +455,16 @@ int *mca_coll_han_topo_init(struct ompi_communicator_t *comm, mca_coll_han_modul
432
455
433
456
/* Utils */
434
457
static inline void
435
- mca_coll_han_get_ranks (int * vranks , int root , int low_size ,
436
- int * root_low_rank , int * root_up_rank )
458
+ mca_coll_han_get_ranks (int * vranks , int w_rank , int low_size ,
459
+ int * low_rank , int * up_rank )
437
460
{
438
- * root_up_rank = vranks [root ] / low_size ;
439
- * root_low_rank = vranks [root ] % low_size ;
461
+ if (up_rank ) {
462
+ * up_rank = vranks [w_rank ] / low_size ;
463
+ }
464
+
465
+ if (low_rank ) {
466
+ * low_rank = vranks [w_rank ] % low_size ;
467
+ }
440
468
}
441
469
442
470
const char * mca_coll_han_topo_lvl_to_str (TOPO_LVL_T topo_lvl );
@@ -469,11 +497,17 @@ int
469
497
mca_coll_han_gather_intra_dynamic (GATHER_BASE_ARGS ,
470
498
mca_coll_base_module_t * module );
471
499
int
500
+ mca_coll_han_gatherv_intra_dynamic (GATHERV_BASE_ARGS ,
501
+ mca_coll_base_module_t * module );
502
+ int
472
503
mca_coll_han_reduce_intra_dynamic (REDUCE_BASE_ARGS ,
473
504
mca_coll_base_module_t * module );
474
505
int
475
506
mca_coll_han_scatter_intra_dynamic (SCATTER_BASE_ARGS ,
476
507
mca_coll_base_module_t * module );
508
+ int
509
+ mca_coll_han_scatterv_intra_dynamic (SCATTERV_BASE_ARGS ,
510
+ mca_coll_base_module_t * module );
477
511
478
512
int mca_coll_han_barrier_intra_simple (struct ompi_communicator_t * comm ,
479
513
mca_coll_base_module_t * module );
@@ -486,4 +520,10 @@ ompi_coll_han_reorder_gather(const void *sbuf,
486
520
struct ompi_communicator_t * comm ,
487
521
int * topo );
488
522
523
+ size_t
524
+ coll_han_utils_gcd (const size_t * numerators , const size_t size );
525
+
526
+ int
527
+ coll_han_utils_create_contiguous_datatype (size_t count , const ompi_datatype_t * oldType ,
528
+ ompi_datatype_t * * newType );
489
529
#endif /* MCA_COLL_HAN_EXPORT_H */
0 commit comments