Skip to content

Commit 984944d

Browse files
authored
Merge pull request #12376 from wenduwan/han_gatherv
Implement hierarchical MPI_Gatherv and MPI_Scatterv
2 parents 0353f7e + 2152b61 commit 984944d

11 files changed

+1184
-7
lines changed

ompi/mca/coll/han/Makefile.am

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,9 @@ coll_han_barrier.c \
2121
coll_han_bcast.c \
2222
coll_han_reduce.c \
2323
coll_han_scatter.c \
24+
coll_han_scatterv.c \
2425
coll_han_gather.c \
26+
coll_han_gatherv.c \
2527
coll_han_allreduce.c \
2628
coll_han_allgather.c \
2729
coll_han_component.c \
@@ -31,7 +33,8 @@ coll_han_algorithms.c \
3133
coll_han_dynamic.c \
3234
coll_han_dynamic_file.c \
3335
coll_han_topo.c \
34-
coll_han_subcomms.c
36+
coll_han_subcomms.c \
37+
coll_han_utils.c
3538

3639
# Make the output library in this directory, and name it either
3740
# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la

ompi/mca/coll/han/coll_han.h

Lines changed: 44 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
* reserved.
55
* Copyright (c) 2022 IBM Corporation. All rights reserved
66
* Copyright (c) 2020-2022 Bull S.A.S. All rights reserved.
7+
* Copyright (c) Amazon.com, Inc. or its affiliates.
8+
* All rights reserved.
79
* $COPYRIGHT$
810
*
911
* Additional copyrights may follow
@@ -189,7 +191,9 @@ typedef struct mca_coll_han_op_module_name_t {
189191
mca_coll_han_op_up_low_module_name_t allreduce;
190192
mca_coll_han_op_up_low_module_name_t allgather;
191193
mca_coll_han_op_up_low_module_name_t gather;
194+
mca_coll_han_op_up_low_module_name_t gatherv;
192195
mca_coll_han_op_up_low_module_name_t scatter;
196+
mca_coll_han_op_up_low_module_name_t scatterv;
193197
} mca_coll_han_op_module_name_t;
194198

195199
/**
@@ -233,10 +237,18 @@ typedef struct mca_coll_han_component_t {
233237
uint32_t han_gather_up_module;
234238
/* low level module for gather */
235239
uint32_t han_gather_low_module;
240+
/* up level module for gatherv */
241+
uint32_t han_gatherv_up_module;
242+
/* low level module for gatherv */
243+
uint32_t han_gatherv_low_module;
236244
/* up level module for scatter */
237245
uint32_t han_scatter_up_module;
238246
/* low level module for scatter */
239247
uint32_t han_scatter_low_module;
248+
/* up level module for scatterv */
249+
uint32_t han_scatterv_up_module;
250+
/* low level module for scatterv */
251+
uint32_t han_scatterv_low_module;
240252
/* name of the modules */
241253
mca_coll_han_op_module_name_t han_op_module_name;
242254
/* whether we need reproducible results
@@ -277,8 +289,10 @@ typedef struct mca_coll_han_single_collective_fallback_s {
277289
mca_coll_base_module_barrier_fn_t barrier;
278290
mca_coll_base_module_bcast_fn_t bcast;
279291
mca_coll_base_module_gather_fn_t gather;
292+
mca_coll_base_module_gatherv_fn_t gatherv;
280293
mca_coll_base_module_reduce_fn_t reduce;
281294
mca_coll_base_module_scatter_fn_t scatter;
295+
mca_coll_base_module_scatterv_fn_t scatterv;
282296
} module_fn;
283297
mca_coll_base_module_t* module;
284298
} mca_coll_han_single_collective_fallback_t;
@@ -296,7 +310,9 @@ typedef struct mca_coll_han_collectives_fallback_s {
296310
mca_coll_han_single_collective_fallback_t bcast;
297311
mca_coll_han_single_collective_fallback_t reduce;
298312
mca_coll_han_single_collective_fallback_t gather;
313+
mca_coll_han_single_collective_fallback_t gatherv;
299314
mca_coll_han_single_collective_fallback_t scatter;
315+
mca_coll_han_single_collective_fallback_t scatterv;
300316
} mca_coll_han_collectives_fallback_t;
301317

302318
/** Coll han module */
@@ -369,9 +385,14 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
369385
#define previous_gather fallback.gather.module_fn.gather
370386
#define previous_gather_module fallback.gather.module
371387

388+
#define previous_gatherv fallback.gatherv.module_fn.gatherv
389+
#define previous_gatherv_module fallback.gatherv.module
390+
372391
#define previous_scatter fallback.scatter.module_fn.scatter
373392
#define previous_scatter_module fallback.scatter.module
374393

394+
#define previous_scatterv fallback.scatterv.module_fn.scatterv
395+
#define previous_scatterv_module fallback.scatterv.module
375396

376397
/* macro to correctly load a fallback collective module */
377398
#define HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, COLL) \
@@ -391,7 +412,9 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
391412
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, barrier); \
392413
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, bcast); \
393414
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatter); \
415+
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatterv); \
394416
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gather); \
417+
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gatherv); \
395418
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, reduce); \
396419
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allreduce); \
397420
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allgather); \
@@ -432,11 +455,16 @@ int *mca_coll_han_topo_init(struct ompi_communicator_t *comm, mca_coll_han_modul
432455

433456
/* Utils */
434457
static inline void
435-
mca_coll_han_get_ranks(int *vranks, int root, int low_size,
436-
int *root_low_rank, int *root_up_rank)
458+
mca_coll_han_get_ranks(int *vranks, int w_rank, int low_size,
459+
int *low_rank, int *up_rank)
437460
{
438-
*root_up_rank = vranks[root] / low_size;
439-
*root_low_rank = vranks[root] % low_size;
461+
if (up_rank) {
462+
*up_rank = vranks[w_rank] / low_size;
463+
}
464+
465+
if (low_rank) {
466+
*low_rank = vranks[w_rank] % low_size;
467+
}
440468
}
441469

442470
const char* mca_coll_han_topo_lvl_to_str(TOPO_LVL_T topo_lvl);
@@ -469,11 +497,17 @@ int
469497
mca_coll_han_gather_intra_dynamic(GATHER_BASE_ARGS,
470498
mca_coll_base_module_t *module);
471499
int
500+
mca_coll_han_gatherv_intra_dynamic(GATHERV_BASE_ARGS,
501+
mca_coll_base_module_t *module);
502+
int
472503
mca_coll_han_reduce_intra_dynamic(REDUCE_BASE_ARGS,
473504
mca_coll_base_module_t *module);
474505
int
475506
mca_coll_han_scatter_intra_dynamic(SCATTER_BASE_ARGS,
476507
mca_coll_base_module_t *module);
508+
int
509+
mca_coll_han_scatterv_intra_dynamic(SCATTERV_BASE_ARGS,
510+
mca_coll_base_module_t *module);
477511

478512
int mca_coll_han_barrier_intra_simple(struct ompi_communicator_t *comm,
479513
mca_coll_base_module_t *module);
@@ -486,4 +520,10 @@ ompi_coll_han_reorder_gather(const void *sbuf,
486520
struct ompi_communicator_t *comm,
487521
int * topo);
488522

523+
size_t
524+
coll_han_utils_gcd(const size_t *numerators, const size_t size);
525+
526+
int
527+
coll_han_utils_create_contiguous_datatype(size_t count, const ompi_datatype_t *oldType,
528+
ompi_datatype_t **newType);
489529
#endif /* MCA_COLL_HAN_EXPORT_H */

ompi/mca/coll/han/coll_han_algorithms.c

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,11 +59,19 @@ mca_coll_han_algorithm_value_t* mca_coll_han_available_algorithms[COLLCOUNT] =
5959
{"simple", (fnptr_t) &mca_coll_han_scatter_intra_simple}, // 2-level
6060
{ 0 }
6161
},
62+
[SCATTERV] = (mca_coll_han_algorithm_value_t[]){
63+
{"intra", (fnptr_t) &mca_coll_han_scatterv_intra}, // 2-level
64+
{ 0 }
65+
},
6266
[GATHER] = (mca_coll_han_algorithm_value_t[]){
6367
{"intra", (fnptr_t) &mca_coll_han_gather_intra}, // 2-level
6468
{"simple", (fnptr_t) &mca_coll_han_gather_intra_simple}, // 2-level
6569
{ 0 }
6670
},
71+
[GATHERV] = (mca_coll_han_algorithm_value_t[]){
72+
{"intra", (fnptr_t) &mca_coll_han_gatherv_intra}, // 2-level
73+
{ 0 }
74+
},
6775
[ALLGATHER] = (mca_coll_han_algorithm_value_t[]){
6876
{"intra", (fnptr_t)&mca_coll_han_allgather_intra}, // 2-level
6977
{"simple", (fnptr_t)&mca_coll_han_allgather_intra_simple}, // 2-level

ompi/mca/coll/han/coll_han_algorithms.h

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -159,6 +159,16 @@ mca_coll_han_scatter_intra_simple(const void *sbuf, int scount,
159159
struct ompi_communicator_t *comm,
160160
mca_coll_base_module_t * module);
161161

162+
/* Scatterv */
163+
int
164+
mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts,
165+
const int *displs, struct ompi_datatype_t *sdtype,
166+
void *rbuf, int rcount,
167+
struct ompi_datatype_t *rdtype,
168+
int root,
169+
struct ompi_communicator_t *comm,
170+
mca_coll_base_module_t *module);
171+
162172
/* Gather */
163173
int
164174
mca_coll_han_gather_intra(const void *sbuf, int scount,
@@ -176,6 +186,13 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount,
176186
struct ompi_communicator_t *comm,
177187
mca_coll_base_module_t *module);
178188

189+
/* Gatherv */
190+
int
191+
mca_coll_han_gatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
192+
void *rbuf, const int *rcounts, const int *displs,
193+
struct ompi_datatype_t *rdtype, int root,
194+
struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
195+
179196
/* Allgather */
180197
int
181198
mca_coll_han_allgather_intra(const void *sbuf, int scount,

ompi/mca/coll/han/coll_han_component.c

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -146,11 +146,21 @@ static int han_close(void)
146146
free(mca_coll_han_component.han_op_module_name.gather.han_op_low_module_name);
147147
mca_coll_han_component.han_op_module_name.gather.han_op_low_module_name = NULL;
148148

149+
free(mca_coll_han_component.han_op_module_name.gatherv.han_op_up_module_name);
150+
mca_coll_han_component.han_op_module_name.gatherv.han_op_up_module_name = NULL;
151+
free(mca_coll_han_component.han_op_module_name.gatherv.han_op_low_module_name);
152+
mca_coll_han_component.han_op_module_name.gatherv.han_op_low_module_name = NULL;
153+
149154
free(mca_coll_han_component.han_op_module_name.scatter.han_op_up_module_name);
150155
mca_coll_han_component.han_op_module_name.scatter.han_op_up_module_name = NULL;
151156
free(mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name);
152157
mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name = NULL;
153158

159+
free(mca_coll_han_component.han_op_module_name.scatterv.han_op_up_module_name);
160+
mca_coll_han_component.han_op_module_name.scatterv.han_op_up_module_name = NULL;
161+
free(mca_coll_han_component.han_op_module_name.scatterv.han_op_low_module_name);
162+
mca_coll_han_component.han_op_module_name.scatterv.han_op_low_module_name = NULL;
163+
154164
return OMPI_SUCCESS;
155165
}
156166

@@ -344,6 +354,18 @@ static int han_register(void)
344354
OPAL_INFO_LVL_9, &cs->han_gather_low_module,
345355
&cs->han_op_module_name.gather.han_op_low_module_name);
346356

357+
cs->han_gatherv_up_module = 0;
358+
(void) mca_coll_han_query_module_from_mca(c, "gatherv_up_module",
359+
"up level module for gatherv, 0 basic",
360+
OPAL_INFO_LVL_9, &cs->han_gatherv_up_module,
361+
&cs->han_op_module_name.gatherv.han_op_up_module_name);
362+
363+
cs->han_gatherv_low_module = 0;
364+
(void) mca_coll_han_query_module_from_mca(c, "gatherv_low_module",
365+
"low level module for gatherv, 0 basic",
366+
OPAL_INFO_LVL_9, &cs->han_gatherv_low_module,
367+
&cs->han_op_module_name.gatherv.han_op_low_module_name);
368+
347369
cs->han_scatter_up_module = 0;
348370
(void) mca_coll_han_query_module_from_mca(c, "scatter_up_module",
349371
"up level module for scatter, 0 libnbc, 1 adapt",
@@ -356,6 +378,18 @@ static int han_register(void)
356378
OPAL_INFO_LVL_9, &cs->han_scatter_low_module,
357379
&cs->han_op_module_name.scatter.han_op_low_module_name);
358380

381+
cs->han_scatterv_up_module = 0;
382+
(void) mca_coll_han_query_module_from_mca(c, "scatterv_up_module",
383+
"up level module for scatterv, 0 basic",
384+
OPAL_INFO_LVL_9, &cs->han_scatterv_up_module,
385+
&cs->han_op_module_name.scatterv.han_op_up_module_name);
386+
387+
cs->han_scatterv_low_module = 0;
388+
(void) mca_coll_han_query_module_from_mca(c, "scatterv_low_module",
389+
"low level module for scatterv, 0 basic",
390+
OPAL_INFO_LVL_9, &cs->han_scatterv_low_module,
391+
&cs->han_op_module_name.scatterv.han_op_low_module_name);
392+
359393
cs->han_reproducible = 0;
360394
(void) mca_base_component_var_register(c, "reproducible",
361395
"whether we need reproducible results "

0 commit comments

Comments
 (0)