source: XIOS/dev/branch_yushan/extern/src_ep_dev/ep_merge.cpp @ 1037

Last change on this file since 1037 was 1037, checked in by yushan, 7 years ago

initialize the branch

File size: 9.9 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4
5using namespace std;
6
7
8namespace ep_lib {
9
10  int MPI_Intercomm_merge_unique_leader(MPI_Comm inter_comm, bool high, MPI_Comm *newintracomm)
11  {
12    Debug("intercomm_merge with unique leader\n");
13
14
15
16    int ep_rank, ep_rank_loc, mpi_rank;
17    int ep_size, num_ep, mpi_size;
18
19    ep_rank = inter_comm.ep_comm_ptr->size_rank_info[0].first;
20    ep_rank_loc = inter_comm.ep_comm_ptr->size_rank_info[1].first;
21    mpi_rank = inter_comm.ep_comm_ptr->size_rank_info[2].first;
22    ep_size = inter_comm.ep_comm_ptr->size_rank_info[0].second;
23    num_ep = inter_comm.ep_comm_ptr->size_rank_info[1].second;
24    mpi_size = inter_comm.ep_comm_ptr->size_rank_info[2].second;
25
26    int local_high = high;
27    int remote_high;
28
29    int remote_ep_size = inter_comm.ep_comm_ptr->intercomm->remote_rank_map->size();
30
31    int local_ep_rank, local_ep_rank_loc, local_mpi_rank;
32    int local_ep_size, local_num_ep, local_mpi_size;
33
34    local_ep_rank = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[0].first;
35    local_ep_rank_loc = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[1].first;
36    local_mpi_rank = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[2].first;
37    local_ep_size = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[0].second;
38    local_num_ep = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[1].second;
39    local_mpi_size = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[2].second;
40
41
42//    if(local_ep_rank == 0 && high == false)
43//    {
44//      MPI_Status status;
45//      MPI_Send(&local_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm);
46//      MPI_Recv(&remote_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &status);
47//    }
48//
49//    if(local_ep_rank == 0 && high == true)
50//    {
51//      MPI_Status status;
52//      MPI_Recv(&remote_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &status);
53//      MPI_Send(&local_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm);
54//    }
55
56    if(local_ep_rank == 0)
57    {
58      MPI_Status status;
59      MPI_Request req_s, req_r;
60      MPI_Isend(&local_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &req_s);
61      MPI_Irecv(&remote_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &req_r);
62
63      MPI_Wait(&req_s, &status);
64      MPI_Wait(&req_r, &status);
65    }
66
67
68    MPI_Bcast(&remote_high, 1, MPI_INT, 0, *(inter_comm.ep_comm_ptr->intercomm->local_comm));
69
70//    printf("%d, %d, %d, %d\n", local_ep_size, remote_ep_size, local_high, remote_high);
71
72
73    MPI_Comm_dup(inter_comm, newintracomm);
74
75    int my_ep_rank = local_high<remote_high? local_ep_rank: local_ep_rank+remote_ep_size;
76
77
78    int intra_ep_rank, intra_ep_rank_loc, intra_mpi_rank;
79    int intra_ep_size, intra_num_ep, intra_mpi_size;
80
81    intra_ep_rank = newintracomm->ep_comm_ptr->size_rank_info[0].first;
82    intra_ep_rank_loc = newintracomm->ep_comm_ptr->size_rank_info[1].first;
83    intra_mpi_rank = newintracomm->ep_comm_ptr->size_rank_info[2].first;
84    intra_ep_size = newintracomm->ep_comm_ptr->size_rank_info[0].second;
85    intra_num_ep = newintracomm->ep_comm_ptr->size_rank_info[1].second;
86    intra_mpi_size = newintracomm->ep_comm_ptr->size_rank_info[2].second;
87
88
89    MPI_Barrier_local(*newintracomm);
90
91
92    int *reorder;
93    if(intra_ep_rank_loc == 0)
94    {
95      reorder = new int[intra_ep_size];
96    }
97
98
99    MPI_Gather(&my_ep_rank, 1, MPI_INT, reorder, 1, MPI_INT, 0, *newintracomm);
100    if(intra_ep_rank_loc == 0)
101    {
102      #ifdef _serialized
103      #pragma omp critical (_mpi_call)
104      #endif // _serialized
105      ::MPI_Bcast(reorder, intra_ep_size, MPI_INT_STD, 0, static_cast< ::MPI_Comm>(newintracomm->mpi_comm));
106
107      vector< pair<int, int> > tmp_rank_map(intra_ep_size);
108
109
110      for(int i=0; i<intra_ep_size; i++)
111      {
112        tmp_rank_map[reorder[i]] = newintracomm->rank_map->at(i) ;
113      }
114
115      newintracomm->rank_map->swap(tmp_rank_map);
116
117      tmp_rank_map.clear();
118    }
119
120    MPI_Barrier_local(*newintracomm);
121
122    (*newintracomm).ep_comm_ptr->size_rank_info[0].first = my_ep_rank;
123
124    if(intra_ep_rank_loc == 0)
125    {
126      delete[] reorder;
127    }
128
129    return MPI_SUCCESS;
130  }
131
132
133
134
135
136  int MPI_Intercomm_merge(MPI_Comm inter_comm, bool high, MPI_Comm *newintracomm)
137  {
138
139    assert(inter_comm.is_intercomm);
140
141    if(inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->comm_label == -99)
142    {
143        return MPI_Intercomm_merge_unique_leader(inter_comm, high, newintracomm);
144    }
145
146
147    Debug("intercomm_merge kernel\n");
148
149    int ep_rank, ep_rank_loc, mpi_rank;
150    int ep_size, num_ep, mpi_size;
151
152    ep_rank = inter_comm.ep_comm_ptr->size_rank_info[0].first;
153    ep_rank_loc = inter_comm.ep_comm_ptr->size_rank_info[1].first;
154    mpi_rank = inter_comm.ep_comm_ptr->size_rank_info[2].first;
155    ep_size = inter_comm.ep_comm_ptr->size_rank_info[0].second;
156    num_ep = inter_comm.ep_comm_ptr->size_rank_info[1].second;
157    mpi_size = inter_comm.ep_comm_ptr->size_rank_info[2].second;
158
159
160    int local_ep_rank, local_ep_rank_loc, local_mpi_rank;
161    int local_ep_size, local_num_ep, local_mpi_size;
162
163
164    local_ep_rank = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[0].first;
165    local_ep_rank_loc = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[1].first;
166    local_mpi_rank = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[2].first;
167    local_ep_size = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[0].second;
168    local_num_ep = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[1].second;
169    local_mpi_size = inter_comm.ep_comm_ptr->intercomm->local_comm->ep_comm_ptr->size_rank_info[2].second;
170
171    int remote_ep_size = inter_comm.ep_comm_ptr->intercomm->remote_rank_map->size();
172
173    int local_high = high;
174    int remote_high;
175
176    MPI_Barrier(inter_comm);
177
178//    if(local_ep_rank == 0 && high == false)
179//    {
180//      MPI_Status status;
181//      MPI_Send(&local_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm);
182//      MPI_Recv(&remote_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &status);
183//    }
184//
185//    if(local_ep_rank == 0 && high == true)
186//    {
187//      MPI_Status status;
188//      MPI_Recv(&remote_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &status);
189//      MPI_Send(&local_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm);
190//    }
191
192    if(local_ep_rank == 0)
193    {
194      MPI_Status status;
195      MPI_Request req_s, req_r;
196      MPI_Isend(&local_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &req_s);
197      MPI_Irecv(&remote_high, 1, MPI_INT, 0, inter_comm.ep_comm_ptr->intercomm->intercomm_tag, inter_comm, &req_r);
198
199      MPI_Wait(&req_s, &status);
200      MPI_Wait(&req_r, &status);
201    }
202
203    MPI_Bcast(&remote_high, 1, MPI_INT, 0, *(inter_comm.ep_comm_ptr->intercomm->local_comm));
204
205    int intercomm_high;
206    if(ep_rank == 0) intercomm_high = local_high;
207    MPI_Bcast(&intercomm_high, 1, MPI_INT, 0, inter_comm);
208
209    //printf("remote_ep_size = %d, local_high = %d, remote_high = %d, intercomm_high = %d\n", remote_ep_size, local_high, remote_high, intercomm_high);
210
211
212    ::MPI_Comm mpi_intracomm;
213    MPI_Comm *ep_intracomm;
214
215    if(ep_rank_loc == 0)
216    {
217
218      ::MPI_Comm mpi_comm = static_cast< ::MPI_Comm>(inter_comm.ep_comm_ptr->intercomm->mpi_inter_comm);
219      #ifdef _serialized
220      #pragma omp critical (_mpi_call)
221      #endif // _serialized
222      ::MPI_Intercomm_merge(mpi_comm, intercomm_high, &mpi_intracomm);
223      MPI_Info info;
224      MPI_Comm_create_endpoints(mpi_intracomm, num_ep, info, ep_intracomm);
225
226      inter_comm.ep_comm_ptr->comm_list->mem_bridge = ep_intracomm;
227
228    }
229
230
231
232    MPI_Barrier_local(inter_comm);
233
234    *newintracomm = inter_comm.ep_comm_ptr->comm_list->mem_bridge[ep_rank_loc];
235
236    int my_ep_rank = local_high<remote_high? local_ep_rank: local_ep_rank+remote_ep_size;
237
238    int intra_ep_rank, intra_ep_rank_loc, intra_mpi_rank;
239    int intra_ep_size, intra_num_ep, intra_mpi_size;
240
241    intra_ep_rank = newintracomm->ep_comm_ptr->size_rank_info[0].first;
242    intra_ep_rank_loc = newintracomm->ep_comm_ptr->size_rank_info[1].first;
243    intra_mpi_rank = newintracomm->ep_comm_ptr->size_rank_info[2].first;
244    intra_ep_size = newintracomm->ep_comm_ptr->size_rank_info[0].second;
245    intra_num_ep = newintracomm->ep_comm_ptr->size_rank_info[1].second;
246    intra_mpi_size = newintracomm->ep_comm_ptr->size_rank_info[2].second;
247
248
249
250    MPI_Barrier_local(*newintracomm);
251
252
253    int *reorder;
254    if(intra_ep_rank_loc == 0)
255    {
256      reorder = new int[intra_ep_size];
257    }
258
259
260
261    MPI_Gather(&my_ep_rank, 1, MPI_INT, reorder, 1, MPI_INT, 0, *newintracomm);
262    if(intra_ep_rank_loc == 0)
263    {
264      #ifdef _serialized
265      #pragma omp critical (_mpi_call)
266      #endif // _serialized
267      ::MPI_Bcast(reorder, intra_ep_size, MPI_INT_STD, 0, static_cast< ::MPI_Comm>(newintracomm->mpi_comm));
268
269      vector< pair<int, int> > tmp_rank_map(intra_ep_size);
270
271
272      for(int i=0; i<intra_ep_size; i++)
273      {
274        tmp_rank_map[reorder[i]] = newintracomm->rank_map->at(i) ;
275      }
276
277      newintracomm->rank_map->swap(tmp_rank_map);
278
279      tmp_rank_map.clear();
280    }
281
282    MPI_Barrier_local(*newintracomm);
283
284    (*newintracomm).ep_comm_ptr->size_rank_info[0].first = my_ep_rank;
285
286
287    if(intra_ep_rank_loc == 0)
288    {
289      delete[] reorder;
290
291    }
292
293    /*
294    if(intra_ep_rank == 0)
295    {
296      for(int i=0; i<intra_ep_size; i++)
297      {
298        printf("intra rank_map[%d] = (%d, %d)\n", i, newintracomm->rank_map->at(i).first, newintracomm->rank_map->at(i).second);
299      }
300    }
301*/
302    return MPI_SUCCESS;
303
304  }
305
306
307}
Note: See TracBrowser for help on using the repository browser.