source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_merge.cpp @ 1196

Last change on this file since 1196 was 1134, checked in by yushan, 7 years ago

branch merged with trunk r1130

File size: 9.0 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4
5using namespace std;
6
7
8namespace ep_lib {
9
10  int MPI_Intercomm_merge_unique_leader(MPI_Comm inter_comm, bool high, MPI_Comm *newintracomm)
11  {
12    Debug("intercomm_merge with unique leader\n");
13
14
15
16    int ep_rank, ep_rank_loc, mpi_rank;
17    int ep_size, num_ep, mpi_size;
18
19    ep_rank = inter_comm.ep_comm_ptr->size_rank_info[0].first;
20    ep_rank_loc = inter_comm.ep_comm_ptr->size_rank_info[1].first;
21    mpi_rank = inter_comm.ep_comm_ptr->size_rank_info[2].first;
22    ep_size = inter_comm.ep_comm_ptr->size_rank_info[0].second;
23    num_ep = inter_comm.ep_comm_ptr->size_rank_info[1].second;
24    mpi_size = inter_comm.ep_comm_ptr->size_rank_info[2].second;
25
26    int local_high = high;
27    int remote_high;
28
29    int remote_ep_size = inter_comm.ep_comm_ptr->intercomm->remote_rank_map->size();
30
31    int local_ep_rank, local_ep_rank_loc, local_mpi_rank;
32    int local_ep_size, local_num_ep, local_mpi_size;
33
34    local_ep_rank = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[0].first;
35    local_ep_rank_loc = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[1].first;
36    local_mpi_rank = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[2].first;
37    local_ep_size = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[0].second;
38    local_num_ep = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[1].second;
39    local_mpi_size = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[2].second;
40
41
42    if(local_ep_rank == 0)
43    {
44      MPI_Status status;
45      MPI_Request req_s, req_r;
46      MPI_Isend(&local_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &req_s);
47      MPI_Irecv(&remote_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &req_r);
48
49      MPI_Wait(&req_s, &status);
50      MPI_Wait(&req_r, &status);
51    }
52
53
54    MPI_Bcast(&remote_high, 1, MPI_INT, 0, *(inter_comm.ep_comm_ptr->intercomm->local_comm));
55
56//    printf("%d, %d, %d, %d\n", local_ep_size, remote_ep_size, local_high, remote_high);
57
58
59    MPI_Comm_dup(inter_comm, newintracomm);
60
61    int my_ep_rank = local_high<remote_high? local_ep_rank: local_ep_rank+remote_ep_size;
62
63
64    int intra_ep_rank, intra_ep_rank_loc, intra_mpi_rank;
65    int intra_ep_size, intra_num_ep, intra_mpi_size;
66
67    intra_ep_rank = newintracomm->ep_comm_ptr->size_rank_info[0].first;
68    intra_ep_rank_loc = newintracomm->ep_comm_ptr->size_rank_info[1].first;
69    intra_mpi_rank = newintracomm->ep_comm_ptr->size_rank_info[2].first;
70    intra_ep_size = newintracomm->ep_comm_ptr->size_rank_info[0].second;
71    intra_num_ep = newintracomm->ep_comm_ptr->size_rank_info[1].second;
72    intra_mpi_size = newintracomm->ep_comm_ptr->size_rank_info[2].second;
73
74
75    MPI_Barrier_local(*newintracomm);
76
77
78    int *reorder;
79    if(intra_ep_rank_loc == 0)
80    {
81      reorder = new int[intra_ep_size];
82    }
83
84
85    MPI_Gather(&my_ep_rank, 1, MPI_INT, reorder, 1, MPI_INT, 0, *newintracomm);
86    if(intra_ep_rank_loc == 0)
87    {
88      ::MPI_Bcast(reorder, intra_ep_size, MPI_INT_STD, 0, static_cast< ::MPI_Comm>(newintracomm->mpi_comm));
89
90      vector< pair<int, int> > tmp_rank_map(intra_ep_size);
91
92
93      for(int i=0; i<intra_ep_size; i++)
94      {
95        tmp_rank_map[reorder[i]] = newintracomm->rank_map->at(i) ;
96      }
97
98      newintracomm->rank_map->swap(tmp_rank_map);
99
100      tmp_rank_map.clear();
101    }
102
103    MPI_Barrier_local(*newintracomm);
104
105    (*newintracomm).ep_comm_ptr->size_rank_info[0].first = my_ep_rank;
106
107    if(intra_ep_rank_loc == 0)
108    {
109      delete[] reorder;
110    }
111
112    return MPI_SUCCESS;
113  }
114
115
116
117
118
119  int MPI_Intercomm_merge(MPI_Comm inter_comm, bool high, MPI_Comm *newintracomm)
120  {
121
122    assert(inter_comm.is_intercomm);
123
124    if(inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->comm_label == -99)
125    {
126        return MPI_Intercomm_merge_unique_leader(inter_comm, high, newintracomm);
127    }
128
129
130    Debug("intercomm_merge kernel\n");
131
132    int ep_rank, ep_rank_loc, mpi_rank;
133    int ep_size, num_ep, mpi_size;
134
135    ep_rank = inter_comm.ep_comm_ptr->size_rank_info[0].first;
136    ep_rank_loc = inter_comm.ep_comm_ptr->size_rank_info[1].first;
137    mpi_rank = inter_comm.ep_comm_ptr->size_rank_info[2].first;
138    ep_size = inter_comm.ep_comm_ptr->size_rank_info[0].second;
139    num_ep = inter_comm.ep_comm_ptr->size_rank_info[1].second;
140    mpi_size = inter_comm.ep_comm_ptr->size_rank_info[2].second;
141
142
143    int local_ep_rank, local_ep_rank_loc, local_mpi_rank;
144    int local_ep_size, local_num_ep, local_mpi_size;
145
146
147    local_ep_rank = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[0].first;
148    local_ep_rank_loc = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[1].first;
149    local_mpi_rank = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[2].first;
150    local_ep_size = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[0].second;
151    local_num_ep = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[1].second;
152    local_mpi_size = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[2].second;
153
154    int remote_ep_size = inter_comm.ep_comm_ptr->intercomm->remote_rank_map->size();
155
156    int local_high = high;
157    int remote_high;
158
159    MPI_Barrier(inter_comm);
160
161//    if(local_ep_rank == 0 && high == false)
162//    {
163//      MPI_Status status;
164//      MPI_Send(&local_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm);
165//      MPI_Recv(&remote_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &status);
166//    }
167//
168//    if(local_ep_rank == 0 && high == true)
169//    {
170//      MPI_Status status;
171//      MPI_Recv(&remote_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &status);
172//      MPI_Send(&local_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm);
173//    }
174
175    if(local_ep_rank == 0)
176    {
177      MPI_Status status;
178      MPI_Request req_s, req_r;
179      MPI_Isend(&local_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &req_s);
180      MPI_Irecv(&remote_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &req_r);
181
182      MPI_Wait(&req_s, &status);
183      MPI_Wait(&req_r, &status);
184    }
185
186    MPI_Bcast(&remote_high, 1, MPI_INT, 0, *(inter_comm.ep_comm_ptr->intercomm->local_comm));
187
188    int intercomm_high;
189    if(ep_rank == 0) intercomm_high = local_high;
190    MPI_Bcast(&intercomm_high, 1, MPI_INT, 0, inter_comm);
191
192    //printf("remote_ep_size = %d, local_high = %d, remote_high = %d, intercomm_high = %d\n", remote_ep_size, local_high, remote_high, intercomm_high);
193
194
195    ::MPI_Comm mpi_intracomm;
196    MPI_Comm *ep_intracomm;
197
198    if(ep_rank_loc == 0)
199    {
200
201      ::MPI_Comm mpi_comm = static_cast< ::MPI_Comm>(inter_comm.ep_comm_ptr->intercomm->mpi_inter_comm);
202
203      ::MPI_Intercomm_merge(mpi_comm, intercomm_high, &mpi_intracomm);
204      MPI_Info info;
205      MPI_Comm_create_endpoints(mpi_intracomm, num_ep, info, ep_intracomm);
206
207      inter_comm.ep_comm_ptr->comm_list->mem_bridge = ep_intracomm;
208
209    }
210
211
212
213    MPI_Barrier_local(inter_comm);
214
215    *newintracomm = inter_comm.ep_comm_ptr->comm_list->mem_bridge[ep_rank_loc];
216
217    int my_ep_rank = local_high<remote_high? local_ep_rank: local_ep_rank+remote_ep_size;
218
219    int intra_ep_rank, intra_ep_rank_loc, intra_mpi_rank;
220    int intra_ep_size, intra_num_ep, intra_mpi_size;
221
222    intra_ep_rank = newintracomm->ep_comm_ptr->size_rank_info[0].first;
223    intra_ep_rank_loc = newintracomm->ep_comm_ptr->size_rank_info[1].first;
224    intra_mpi_rank = newintracomm->ep_comm_ptr->size_rank_info[2].first;
225    intra_ep_size = newintracomm->ep_comm_ptr->size_rank_info[0].second;
226    intra_num_ep = newintracomm->ep_comm_ptr->size_rank_info[1].second;
227    intra_mpi_size = newintracomm->ep_comm_ptr->size_rank_info[2].second;
228
229
230
231    MPI_Barrier_local(*newintracomm);
232
233
234    int *reorder;
235    if(intra_ep_rank_loc == 0)
236    {
237      reorder = new int[intra_ep_size];
238    }
239
240
241
242    MPI_Gather(&my_ep_rank, 1, MPI_INT, reorder, 1, MPI_INT, 0, *newintracomm);
243    if(intra_ep_rank_loc == 0)
244    {
245
246      ::MPI_Bcast(reorder, intra_ep_size, MPI_INT_STD, 0, static_cast< ::MPI_Comm>(newintracomm->mpi_comm));
247
248      vector< pair<int, int> > tmp_rank_map(intra_ep_size);
249
250
251      for(int i=0; i<intra_ep_size; i++)
252      {
253        tmp_rank_map[reorder[i]] = newintracomm->rank_map->at(i) ;
254      }
255
256      newintracomm->rank_map->swap(tmp_rank_map);
257
258      tmp_rank_map.clear();
259    }
260
261    MPI_Barrier_local(*newintracomm);
262
263    (*newintracomm).ep_comm_ptr->size_rank_info[0].first = my_ep_rank;
264
265
266    if(intra_ep_rank_loc == 0)
267    {
268      delete[] reorder;
269
270    }
271
272    /*
273    if(intra_ep_rank == 0)
274    {
275      for(int i=0; i<intra_ep_size; i++)
276      {
277        printf("intra rank_map[%d] = (%d, %d)\n", i, newintracomm->rank_map->at(i).first, newintracomm->rank_map->at(i).second);
278      }
279    }
280*/
281    return MPI_SUCCESS;
282
283  }
284
285
286}
Note: See TracBrowser for help on using the repository browser.