Skip to content

Commit 550ac58

Browse files
amd-nithyavswenduwan
authored andcommitted
Fix mpi4py failures
Corner cases are handled to fix mpi4py failures. Signed-off-by: Nithya V S <[email protected]>
1 parent bccc940 commit 550ac58

File tree

7 files changed

+76
-29
lines changed

7 files changed

+76
-29
lines changed

ompi/mca/coll/acoll/coll_acoll_allgather.c

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,9 @@ static inline int mca_coll_acoll_allgather_intra(const void *sbuf, int scount,
268268
data_blk_size[0] = bcount * (num_sgs - 2) + last_subgrp_rcnt;
269269
blk_ofst[0] = bcount;
270270
} else if (sg_id == num_sgs - 1) {
271+
if (last_subgrp_size < 2) {
272+
return err;
273+
}
271274
num_data_blks = 1;
272275
data_blk_size[0] = bcount * (num_sgs - 1);
273276
blk_ofst[0] = 0;
@@ -329,8 +332,7 @@ int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_
329332
int i;
330333
int err;
331334
int size;
332-
int rank, adj_rank;
333-
int num_sgs;
335+
int rank;
334336
int sg_size, log2_sg_size;
335337
int num_nodes, node_start, node_end, node_id;
336338
int node_size, last_node_size;
@@ -388,7 +390,9 @@ int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_
388390
if (size <= 2) {
389391
intra_comm = comm;
390392
} else {
391-
assert(subc->local_r_comm != NULL);
393+
if (num_nodes > 1) {
394+
assert(subc->local_r_comm != NULL);
395+
}
392396
intra_comm = num_nodes == 1 ? comm : subc->local_r_comm;
393397
}
394398
err = mca_coll_acoll_allgather_intra(sbuf, scount, sdtype, local_rbuf, rcount, rdtype,
@@ -454,12 +458,14 @@ int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_
454458
} /* End of if inter leader */
455459

456460
/* Do intra node broadcast */
457-
num_sgs = (node_size + sg_size - 1) >> log2_sg_size;
458461
if (node_id == 0) {
459462
num_data_blks = 1;
460463
data_blk_size[0] = bcount * (num_nodes - 2) + last_subgrp_rcnt;
461464
blk_ofst[0] = bcount;
462465
} else if (node_id == num_nodes - 1) {
466+
if (last_node_size < 2) {
467+
return err;
468+
}
463469
num_data_blks = 1;
464470
data_blk_size[0] = bcount * (num_nodes - 1);
465471
blk_ofst[0] = 0;

ompi/mca/coll/acoll/coll_acoll_barrier.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ static int mca_coll_acoll_barrier_send_subc(struct ompi_communicator_t *comm,
125125
int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base_module_t *module)
126126
{
127127
int size, ssize, bsize;
128-
int srank;
129128
int err = MPI_SUCCESS;
130129
int nreqs = 0;
131130
ompi_request_t **reqs;
@@ -141,6 +140,9 @@ int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base
141140

142141
subc = &acoll_module->subc[cid];
143142
size = ompi_comm_size(comm);
143+
if (size == 1) {
144+
return err;
145+
}
144146
if (!subc->initialized && size > 1) {
145147
err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0);
146148
if (MPI_SUCCESS != err) {

ompi/mca/coll/acoll/coll_acoll_bcast.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static int bcast_binomial(void *buff, int count, struct ompi_datatype_t *datatyp
3737
struct ompi_communicator_t *comm, ompi_request_t **preq, int *nreqs,
3838
int world_rank)
3939
{
40-
int msb_pos, sub_rank, peer, err;
40+
int msb_pos, sub_rank, peer, err = MPI_SUCCESS;
4141
int size, rank, dim;
4242
int i, mask;
4343

@@ -83,7 +83,7 @@ static int bcast_flat_tree(void *buff, int count, struct ompi_datatype_t *dataty
8383
int world_rank)
8484
{
8585
int peer;
86-
int err;
86+
int err = MPI_SUCCESS;
8787
int rank = ompi_comm_rank(comm);
8888
int size = ompi_comm_size(comm);
8989

ompi/mca/coll/acoll/coll_acoll_gather.c

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,16 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
4343
int i, err, rank, size;
4444
char *wkg = NULL, *workbuf = NULL;
4545
MPI_Status status;
46-
MPI_Aint incr, extent, lb;
4746
MPI_Aint sextent, sgap = 0, ssize;
48-
MPI_Aint rextent, rgap = 0, rsize;
47+
MPI_Aint rextent;
4948
int total_recv = 0;
5049
int sg_cnt, node_cnt;
5150
int cur_sg, root_sg;
5251
int cur_node, root_node;
5352
int is_base, is_local_root;
5453
int startr, endr, inc;
55-
int startn, endn, incn;
56-
int num_nodes, node_id;
54+
int startn, endn;
55+
int num_nodes;
5756
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
5857
coll_acoll_reserve_mem_t *reserve_mem_gather = &(acoll_module->reserve_mem_s);
5958

@@ -70,17 +69,13 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
7069
num_nodes = 1;
7170
}
7271

73-
ompi_datatype_get_extent(rdtype, &lb, &extent);
74-
incr = extent * (ptrdiff_t) rcount;
75-
76-
/* Setup root for reveive */
72+
/* Setup root for receive */
7773
if (rank == root) {
7874
ompi_datatype_type_extent(rdtype, &rextent);
79-
rsize = opal_datatype_span(&rdtype->super, (int64_t) rcount * size, &rgap);
8075
/* Just use the recv buffer */
8176
wkg = (char *) rbuf;
8277
if (sbuf != MPI_IN_PLACE) {
83-
MPI_Aint root_ofst = extent * (ptrdiff_t) (rcount * root);
78+
MPI_Aint root_ofst = rextent * (ptrdiff_t) (rcount * root);
8479
err = ompi_datatype_sndrcv((void *) sbuf, scount, sdtype, wkg + (ptrdiff_t) root_ofst,
8580
rcount, rdtype);
8681
if (MPI_SUCCESS != err) {
@@ -100,7 +95,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
10095
is_local_root = (rank % node_cnt == 0) && (cur_node != root_node);
10196
startn = (rank / node_cnt) * node_cnt;
10297

103-
if (is_base || (rank == root)) {
98+
if (is_base) {
10499
int64_t buf_size = is_local_root ? (int64_t) scount * node_cnt : (int64_t) scount * sg_cnt;
105100
ompi_datatype_type_extent(sdtype, &sextent);
106101
ssize = opal_datatype_span(&sdtype->super, buf_size, &sgap);
@@ -111,7 +106,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
111106
return OMPI_ERR_OUT_OF_RESOURCE;
112107
}
113108
wkg = workbuf - sgap;
114-
tmprecv = wkg + extent * (ptrdiff_t) (rcount * (rank - startr));
109+
tmprecv = wkg + sextent * (ptrdiff_t) (rcount * (rank - startr));
115110
/* local copy to workbuf */
116111
err = ompi_datatype_sndrcv((void *) sbuf, scount, sdtype, tmprecv, scount, sdtype);
117112
if (MPI_SUCCESS != err) {
@@ -123,7 +118,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
123118
rcount = scount;
124119
rextent = sextent;
125120
total_recv = rcount;
126-
} else {
121+
} else if (rank != root) {
127122
wkg = (char *) sbuf;
128123
total_recv = scount;
129124
}
@@ -141,9 +136,9 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
141136
continue;
142137
}
143138
if (rank == root) {
144-
tmprecv = wkg + extent * (ptrdiff_t) (rcount * i);
139+
tmprecv = wkg + rextent * (ptrdiff_t) (rcount * i);
145140
} else {
146-
tmprecv = wkg + extent * (ptrdiff_t) (rcount * (i - startr));
141+
tmprecv = wkg + rextent * (ptrdiff_t) (rcount * (i - startr));
147142
}
148143
err = MCA_PML_CALL(
149144
recv(tmprecv, rcount, rdtype, i, MCA_COLL_BASE_TAG_GATHER, comm, &status));
@@ -161,10 +156,9 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
161156
if (endn > size) {
162157
endn = size;
163158
}
164-
incn = (rank == root) ? ((root != startn) ? 0 : sg_cnt) : sg_cnt;
165159
if (sg_cnt < size) {
166160
int local_root = (root_node == cur_node) ? root : startn;
167-
for (i = startn + incn; i < endn; i += sg_cnt) {
161+
for (i = startn; i < endn; i += sg_cnt) {
168162
int i_sg = i / sg_cnt;
169163
if ((rank != local_root) && (rank == i) && is_base) {
170164
err = MCA_PML_CALL(send(workbuf - sgap, total_recv, sdtype, local_root,
@@ -173,7 +167,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
173167
}
174168
if ((rank == local_root) && (rank != i) && (i_sg != root_sg)) {
175169
int recv_amt = (i + sg_cnt > size) ? rcount * (size - i) : rcount * sg_cnt;
176-
MPI_Aint rcv_ofst = extent * (ptrdiff_t) (rcount * (i - startn));
170+
MPI_Aint rcv_ofst = rextent * (ptrdiff_t) (rcount * (i - startn));
177171

178172
err = MCA_PML_CALL(recv(wkg + (ptrdiff_t) rcv_ofst, recv_amt, rdtype, i,
179173
MCA_COLL_BASE_TAG_GATHER, comm, &status));
@@ -189,7 +183,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
189183
}
190184

191185
/* All local roots ranks send to root */
192-
if (node_cnt < size) {
186+
if (node_cnt < size && num_nodes > 1) {
193187
for (i = 0; i < size; i += node_cnt) {
194188
int i_node = i / node_cnt;
195189
if ((rank != root) && (rank == i) && is_base) {
@@ -199,7 +193,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
199193
}
200194
if ((rank == root) && (rank != i) && (i_node != root_node)) {
201195
int recv_amt = (i + node_cnt > size) ? rcount * (size - i) : rcount * node_cnt;
202-
MPI_Aint rcv_ofst = extent * (ptrdiff_t) (rcount * i);
196+
MPI_Aint rcv_ofst = rextent * (ptrdiff_t) (rcount * i);
203197

204198
err = MCA_PML_CALL(recv((char *) rbuf + (ptrdiff_t) rcv_ofst, recv_amt, rdtype, i,
205199
MCA_COLL_BASE_TAG_GATHER, comm, &status));

ompi/mca/coll/acoll/coll_acoll_module.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ mca_coll_base_module_t *mca_coll_acoll_comm_query(struct ompi_communicator_t *co
4141
return NULL;
4242
}
4343

44+
if (OMPI_COMM_IS_INTER(comm)) {
45+
*priority = 0;
46+
return NULL;
47+
}
48+
if (OMPI_COMM_IS_INTRA(comm) && ompi_comm_size(comm) < 2) {
49+
*priority = 0;
50+
return NULL;
51+
}
52+
4453
*priority = mca_coll_acoll_priority;
4554

4655
/* Set topology params */

ompi/mca/coll/acoll/coll_acoll_reduce.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,11 +382,11 @@ int mca_coll_acoll_reduce_intra(const void *sbuf, void *rbuf, int count,
382382
module);
383383
} else {
384384
return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op,
385-
root, comm, module, 0, 0);
385+
root, comm, module, 0, 0);
386386
}
387387
#else
388388
return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root,
389-
comm, module, 0, 0);
389+
comm, module, 0, 0);
390390
#endif
391391
}
392392
} else {

ompi/mca/coll/acoll/coll_acoll_utils.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,9 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm,
262262
mca_coll_base_module_allreduce_fn_t coll_allreduce_org = (comm)->c_coll->coll_allreduce;
263263
mca_coll_base_module_allgather_fn_t coll_allgather_org = (comm)->c_coll->coll_allgather;
264264
mca_coll_base_module_bcast_fn_t coll_bcast_org = (comm)->c_coll->coll_bcast;
265+
mca_coll_base_module_allreduce_fn_t coll_allreduce_loc, coll_allreduce_soc;
266+
mca_coll_base_module_allgather_fn_t coll_allgather_loc, coll_allgather_soc;
267+
mca_coll_base_module_bcast_fn_t coll_bcast_loc, coll_bcast_soc;
265268
coll_acoll_subcomms_t *subc;
266269
int err;
267270
int size = ompi_comm_size(comm);
@@ -362,6 +365,21 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm,
362365
subc->base_root[MCA_COLL_ACOLL_L3CACHE][i] = -1;
363366
subc->base_root[MCA_COLL_ACOLL_NUMA][i] = -1;
364367
}
368+
/* Store original collectives for local and socket comms */
369+
coll_allreduce_loc = (subc->local_comm)->c_coll->coll_allreduce;
370+
coll_allgather_loc = (subc->local_comm)->c_coll->coll_allgather;
371+
coll_bcast_loc = (subc->local_comm)->c_coll->coll_bcast;
372+
(subc->local_comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring;
373+
(subc->local_comm)->c_coll->coll_allreduce
374+
= ompi_coll_base_allreduce_intra_recursivedoubling;
375+
(subc->local_comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear;
376+
coll_allreduce_soc = (subc->socket_comm)->c_coll->coll_allreduce;
377+
coll_allgather_soc = (subc->socket_comm)->c_coll->coll_allgather;
378+
coll_bcast_soc = (subc->socket_comm)->c_coll->coll_bcast;
379+
(subc->socket_comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring;
380+
(subc->socket_comm)->c_coll->coll_allreduce
381+
= ompi_coll_base_allreduce_intra_recursivedoubling;
382+
(subc->socket_comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear;
365383
}
366384

367385
/* Further subcommunicators based on root */
@@ -519,6 +537,14 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm,
519537
}
520538
}
521539

540+
/* Restore originals for local and socket comms */
541+
(subc->local_comm)->c_coll->coll_allreduce = coll_allreduce_loc;
542+
(subc->local_comm)->c_coll->coll_allgather = coll_allgather_loc;
543+
(subc->local_comm)->c_coll->coll_bcast = coll_bcast_loc;
544+
(subc->socket_comm)->c_coll->coll_allreduce = coll_allreduce_soc;
545+
(subc->socket_comm)->c_coll->coll_allgather = coll_allgather_soc;
546+
(subc->socket_comm)->c_coll->coll_bcast = coll_bcast_soc;
547+
522548
/* For collectives where order is important (like gather, allgather),
523549
* split based on ranks. This is optimal for global communicators with
524550
* equal split among nodes, but suboptimal for other cases.
@@ -590,6 +616,7 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica
590616
data = (coll_acoll_data_t *) malloc(sizeof(coll_acoll_data_t));
591617
if (NULL == data) {
592618
line = __LINE__;
619+
ret = OMPI_ERR_OUT_OF_RESOURCE;
593620
goto error_hndl;
594621
}
595622
size = ompi_comm_size(comm);
@@ -601,6 +628,7 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica
601628
data->scratch = (char *) malloc(subc->xpmem_buf_size);
602629
if (NULL == data->scratch) {
603630
line = __LINE__;
631+
ret = OMPI_ERR_OUT_OF_RESOURCE;
604632
goto error_hndl;
605633
}
606634
} else {
@@ -611,41 +639,49 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica
611639
data->allseg_id = (xpmem_segid_t *) malloc(sizeof(xpmem_segid_t) * size);
612640
if (NULL == data->allseg_id) {
613641
line = __LINE__;
642+
ret = OMPI_ERR_OUT_OF_RESOURCE;
614643
goto error_hndl;
615644
}
616645
data->all_apid = (xpmem_apid_t *) malloc(sizeof(xpmem_apid_t) * size);
617646
if (NULL == data->all_apid) {
618647
line = __LINE__;
648+
ret = OMPI_ERR_OUT_OF_RESOURCE;
619649
goto error_hndl;
620650
}
621651
data->allshm_sbuf = (void **) malloc(sizeof(void *) * size);
622652
if (NULL == data->allshm_sbuf) {
623653
line = __LINE__;
654+
ret = OMPI_ERR_OUT_OF_RESOURCE;
624655
goto error_hndl;
625656
}
626657
data->allshm_rbuf = (void **) malloc(sizeof(void *) * size);
627658
if (NULL == data->allshm_rbuf) {
628659
line = __LINE__;
660+
ret = OMPI_ERR_OUT_OF_RESOURCE;
629661
goto error_hndl;
630662
}
631663
data->xpmem_saddr = (void **) malloc(sizeof(void *) * size);
632664
if (NULL == data->xpmem_saddr) {
633665
line = __LINE__;
666+
ret = OMPI_ERR_OUT_OF_RESOURCE;
634667
goto error_hndl;
635668
}
636669
data->xpmem_raddr = (void **) malloc(sizeof(void *) * size);
637670
if (NULL == data->xpmem_raddr) {
638671
line = __LINE__;
672+
ret = OMPI_ERR_OUT_OF_RESOURCE;
639673
goto error_hndl;
640674
}
641675
data->rcache = (mca_rcache_base_module_t **) malloc(sizeof(mca_rcache_base_module_t *) * size);
642676
if (NULL == data->rcache) {
643677
line = __LINE__;
678+
ret = OMPI_ERR_OUT_OF_RESOURCE;
644679
goto error_hndl;
645680
}
646681
seg_id = xpmem_make(0, XPMEM_MAXADDR_SIZE, XPMEM_PERMIT_MODE, (void *) 0666);
647682
if (seg_id == -1) {
648683
line = __LINE__;
684+
ret = -1;
649685
goto error_hndl;
650686
}
651687

0 commit comments

Comments
 (0)