source: XIOS/dev/branch_openmp/extern/ep_dev/ep_intercomm.cpp @ 1527

Last change on this file since 1527 was 1527, checked in by yushan, 6 years ago

save dev

File size: 12.9 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4#include "ep_mpi.hpp"
5
6using namespace std;
7
8extern std::map<std::pair<int, int>, MPI_Group* > * tag_group_map;
9
10extern std::map<int, std::pair<ep_lib::MPI_Comm*, std::pair<int, int> > > * tag_comm_map;
11
12extern MPI_Group MPI_GROUP_WORLD;
13
14namespace ep_lib
15{
16  int MPI_Intercomm_create(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
17  {
18    assert(local_comm->is_ep);
19
20    int ep_rank, ep_rank_loc, mpi_rank;
21    int ep_size, num_ep, mpi_size;
22
23    ep_rank = local_comm->ep_comm_ptr->size_rank_info[0].first;
24    ep_rank_loc = local_comm->ep_comm_ptr->size_rank_info[1].first;
25    mpi_rank = local_comm->ep_comm_ptr->size_rank_info[2].first;
26    ep_size = local_comm->ep_comm_ptr->size_rank_info[0].second;
27    num_ep = local_comm->ep_comm_ptr->size_rank_info[1].second;
28    mpi_size = local_comm->ep_comm_ptr->size_rank_info[2].second;
29
30    int world_rank_and_num_ep[2];
31    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank_and_num_ep[0]);
32    world_rank_and_num_ep[1] = num_ep;
33
34    int remote_mpi_size;
35    int remote_ep_size;
36
37    int *local_world_rank_and_num_ep;
38    int *remote_world_rank_and_num_ep;
39    int *summed_world_rank_and_num_ep;
40
41
42    bool is_leader = ep_rank==local_leader? true : false;
43    bool is_local_leader = is_leader? true: (ep_rank_loc==0 && mpi_rank!=local_comm->ep_rank_map->at(local_leader).second ? true : false);
44    bool priority;
45
46    if(is_leader)
47    {
48      int leader_mpi_rank_in_peer;
49      MPI_Comm_rank(peer_comm, &leader_mpi_rank_in_peer);
50      if(leader_mpi_rank_in_peer == remote_leader) 
51      {
52        printf("same leader in peer_comm\n");
53        exit(1);
54      }
55      priority = leader_mpi_rank_in_peer<remote_leader? true : false;
56    }
57
58
59    MPI_Bcast(&priority, 1, MPI_INT, local_leader, local_comm);
60
61    if(is_local_leader)
62    {
63      local_world_rank_and_num_ep = new int[2*mpi_size];
64      ::MPI_Allgather(world_rank_and_num_ep, 2, to_mpi_type(MPI_INT), local_world_rank_and_num_ep, 2, to_mpi_type(MPI_INT), to_mpi_comm(local_comm->mpi_comm));
65    }
66
67   
68   
69    if(is_leader)
70    {
71      MPI_Request request;
72      MPI_Status status;
73
74      if(priority)
75      {
76        MPI_Isend(&mpi_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
77        MPI_Wait(&request, &status);
78       
79        MPI_Irecv(&remote_mpi_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
80        MPI_Wait(&request, &status);
81
82        MPI_Isend(&ep_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
83        MPI_Wait(&request, &status);
84       
85        MPI_Irecv(&remote_ep_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
86        MPI_Wait(&request, &status);
87      }
88      else
89      {
90        MPI_Irecv(&remote_mpi_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
91        MPI_Wait(&request, &status);
92         
93        MPI_Isend(&mpi_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
94        MPI_Wait(&request, &status);
95
96        MPI_Irecv(&remote_ep_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
97        MPI_Wait(&request, &status);
98         
99        MPI_Isend(&ep_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
100        MPI_Wait(&request, &status);
101      }
102    }
103
104    MPI_Bcast(&remote_mpi_size, 1, MPI_INT, local_leader, local_comm);
105    MPI_Bcast(&remote_ep_size, 1, MPI_INT, local_leader, local_comm);
106
107    remote_world_rank_and_num_ep = new int[2*remote_mpi_size];
108
109
110    if(is_leader)
111    {
112      MPI_Request request;
113      MPI_Status status;
114
115      if(priority)
116      {
117        MPI_Isend(local_world_rank_and_num_ep, 2*mpi_size, MPI_INT, remote_leader, tag, peer_comm, &request);
118        MPI_Wait(&request, &status);
119       
120        MPI_Irecv(remote_world_rank_and_num_ep, 2*remote_mpi_size, MPI_INT, remote_leader, tag, peer_comm, &request);
121        MPI_Wait(&request, &status);
122      }
123      else
124      {
125        MPI_Irecv(remote_world_rank_and_num_ep, 2*remote_mpi_size, MPI_INT, remote_leader, tag, peer_comm, &request);
126        MPI_Wait(&request, &status);
127         
128        MPI_Isend(local_world_rank_and_num_ep, 2*mpi_size, MPI_INT, remote_leader, tag, peer_comm, &request);
129        MPI_Wait(&request, &status);
130      }
131    }
132
133   
134    MPI_Bcast(remote_world_rank_and_num_ep, 2*remote_mpi_size, MPI_INT, local_leader, local_comm);
135   
136
137
138    bool is_new_leader  = is_local_leader;
139
140    if(is_local_leader && !priority)
141    {
142      for(int i=0; i<remote_mpi_size; i++)
143      {
144        if(world_rank_and_num_ep[0] == remote_world_rank_and_num_ep[2*i])
145        {
146          is_new_leader = false;
147          break;
148        }
149      }
150    } 
151   
152
153    ::MPI_Group *empty_group;
154    ::MPI_Group *local_group;
155    ::MPI_Group union_group;
156
157    if(is_local_leader)
158    {
159
160      int *ranks = new int[mpi_size];
161      for(int i=0; i<mpi_size; i++)
162      {
163        ranks[i] = local_world_rank_and_num_ep[2*i];
164      }
165
166      local_group = new ::MPI_Group;
167      ::MPI_Group_incl(MPI_GROUP_WORLD, mpi_size, ranks, local_group);
168
169      delete[] ranks;
170
171     
172      #pragma omp flush
173      #pragma omp critical (write_to_tag_group_map)
174      {
175        if(tag_group_map == 0) tag_group_map = new map< std::pair<int, int>, ::MPI_Group * >;
176       
177        tag_group_map->insert(std::make_pair(std::make_pair(tag, priority? 1 : 2), local_group));       
178      }
179    }
180
181    MPI_Barrier(local_comm);
182
183    if(is_leader)
184    {
185      MPI_Request request;
186      MPI_Status status;
187
188      int send_signal=0;
189      int recv_signal;
190
191      if(priority)
192      {
193        MPI_Isend(&send_signal, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
194        MPI_Wait(&request, &status);
195       
196        MPI_Irecv(&recv_signal, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
197        MPI_Wait(&request, &status);
198      }
199      else
200      {
201        MPI_Irecv(&recv_signal, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
202        MPI_Wait(&request, &status);
203         
204        MPI_Isend(&send_signal, 1, MPI_INT, remote_leader, tag, peer_comm, &request);
205        MPI_Wait(&request, &status);
206      }
207    }
208
209
210    MPI_Barrier(local_comm);
211
212    if(is_new_leader)
213    {
214      ::MPI_Group *group1;
215      ::MPI_Group *group2;
216
217      empty_group = new ::MPI_Group;
218      *empty_group = MPI_GROUP_EMPTY;
219
220      #pragma omp flush
221      #pragma omp critical (read_from_tag_group_map)
222      {
223        group1 = tag_group_map->find(make_pair(tag, 1)) != tag_group_map->end()? tag_group_map->at(std::make_pair(tag, 1)) : empty_group;
224        group2 = tag_group_map->find(make_pair(tag, 2)) != tag_group_map->end()? tag_group_map->at(std::make_pair(tag, 2)) : empty_group;
225      }
226
227     
228#ifdef _showinfo
229
230      int group1_rank, group1_size;
231      int group2_rank, group2_size;
232      ::MPI_Group_rank(*group1, &group1_rank);
233      ::MPI_Group_size(*group1, &group1_size);
234      ::MPI_Group_rank(*group2, &group2_rank);
235      ::MPI_Group_size(*group2, &group2_size);
236
237#endif
238
239      ::MPI_Group_union(*group1, *group2, &union_group);
240
241
242      #pragma omp critical (read_from_tag_group_map)
243      {
244        tag_group_map->erase(make_pair(tag, 1));
245        tag_group_map->erase(make_pair(tag, 2));
246      }
247
248#ifdef _showinfo
249
250      int group_rank, group_size;
251      ::MPI_Group_rank(union_group, &group_rank);
252      ::MPI_Group_size(union_group, &group_size);
253      printf("rank = %d : map = %p, group1_rank/size = %d/%d, group2_rank/size = %d/%d, union_rank/size = %d/%d\n", ep_rank, tag_group_map, group1_rank, group1_size, group2_rank, group2_size, group_rank, group_size);
254#endif
255
256    }
257
258    int summed_world_rank_and_num_ep_size=mpi_size;
259    summed_world_rank_and_num_ep = new int[2*(mpi_size+remote_mpi_size)];
260
261
262    if(is_leader)
263    {
264     
265      for(int i=0; i<mpi_size; i++)
266      {
267        summed_world_rank_and_num_ep[2*i] = local_world_rank_and_num_ep[2*i];
268        summed_world_rank_and_num_ep[2*i+1] = local_world_rank_and_num_ep[2*i+1]; 
269      }
270
271      for(int i=0; i<remote_mpi_size; i++)
272      {
273        bool found=false;
274        for(int j=0; j<mpi_size; j++)
275        {
276          if(remote_world_rank_and_num_ep[2*i] == local_world_rank_and_num_ep[2*j])
277          {
278            found=true;
279            summed_world_rank_and_num_ep[2*j+1] += remote_world_rank_and_num_ep[2*i+1]; 
280          }
281        }
282        if(!found)
283        {
284          summed_world_rank_and_num_ep[2*summed_world_rank_and_num_ep_size] = remote_world_rank_and_num_ep[2*i];
285          summed_world_rank_and_num_ep[2*summed_world_rank_and_num_ep_size+1] = remote_world_rank_and_num_ep[2*i+1];
286          summed_world_rank_and_num_ep_size++;
287        }
288
289      }
290    }
291
292    MPI_Bcast(&summed_world_rank_and_num_ep_size, 1, MPI_INT, local_leader, local_comm);
293
294    MPI_Bcast(summed_world_rank_and_num_ep, 2*summed_world_rank_and_num_ep_size, MPI_INT, local_leader, local_comm);
295   
296   
297
298    int remote_num_ep;
299    for(int i=0; i<remote_mpi_size; i++)
300    {
301      if(remote_world_rank_and_num_ep[2*i] == world_rank_and_num_ep[0])
302      {
303        remote_num_ep = remote_world_rank_and_num_ep[2*i+1];
304        break;
305      }
306    }
307
308    int new_ep_rank_loc = priority? ep_rank_loc : ep_rank_loc+remote_num_ep;
309
310#ifdef _showinfo
311    printf("rank = %d, priority = %d, remote_num_ep = %d, new_ep_rank_loc = %d\n", ep_rank, priority, remote_num_ep, new_ep_rank_loc);
312#endif
313   
314    if(is_new_leader)
315    {
316      int new_num_ep;
317      for(int i=0; i<summed_world_rank_and_num_ep_size; i++)
318      {
319        if(summed_world_rank_and_num_ep[2*i] == world_rank_and_num_ep[0])
320        {
321          new_num_ep = summed_world_rank_and_num_ep[2*i+1];
322          break;
323        }
324      }
325
326      ::MPI_Comm mpi_comm;
327      ::MPI_Comm_create_group(to_mpi_comm(MPI_COMM_WORLD->mpi_comm), union_group, tag, &mpi_comm);
328
329
330      MPI_Comm *ep_comm;
331      MPI_Info info;
332      MPI_Comm_create_endpoints(&mpi_comm, new_num_ep, info, ep_comm);
333
334      #pragma omp critical (write_to_tag_comm_map)
335      {
336        if(tag_comm_map == 0) tag_comm_map = new std::map<int, std::pair<ep_lib::MPI_Comm*, std::pair<int, int> > >;
337        tag_comm_map->insert(std::make_pair(tag, std::make_pair(ep_comm, std::make_pair(new_num_ep, 0))));
338      }
339      #pragma omp flush
340    }
341
342
343    bool found=false;
344    while(!found)
345    {
346      #pragma omp flush
347      #pragma omp critical (read_from_tag_comm_map)
348      {
349        if(tag_comm_map!=0)
350        {
351          if(tag_comm_map->find(tag) != tag_comm_map->end())
352          {             
353            *newintercomm = tag_comm_map->at(tag).first[new_ep_rank_loc];
354           
355            tag_comm_map->at(tag).second.second++;
356            if(tag_comm_map->at(tag).second.second == tag_comm_map->at(tag).second.first)
357            {
358              tag_comm_map->erase(tag_comm_map->find(tag));
359            }
360
361            found=true;
362          }
363        }
364      } 
365    }
366
367    (*newintercomm)->is_intercomm = true;
368
369   
370
371
372    (*newintercomm)->inter_rank_map = new INTER_RANK_MAP;
373 
374
375    int rank_info[2];
376    rank_info[0] = ep_rank;
377    rank_info[1] = (*newintercomm)->ep_comm_ptr->size_rank_info[0].first;
378
379#ifdef _showinfo
380    printf("priority = %d, ep_rank = %d, new_ep_rank = %d\n", priority, rank_info[0], rank_info[1]);
381#endif
382
383    int *local_rank_info = new int[2*ep_size];
384    int *remote_rank_info = new int[2*remote_ep_size];
385
386    MPI_Allgather(rank_info, 2, MPI_INT, local_rank_info, 2, MPI_INT, local_comm);
387
388    if(is_leader)
389    {
390      MPI_Request request;
391      MPI_Status status;
392
393      if(priority)
394      {
395        MPI_Isend(local_rank_info, 2*ep_size, MPI_INT, remote_leader, tag, peer_comm, &request);
396        MPI_Wait(&request, &status);
397       
398        MPI_Irecv(remote_rank_info, 2*remote_ep_size, MPI_INT, remote_leader, tag, peer_comm, &request);
399        MPI_Wait(&request, &status);
400      }
401      else
402      {
403        MPI_Irecv(remote_rank_info, 2*remote_ep_size, MPI_INT, remote_leader, tag, peer_comm, &request);
404        MPI_Wait(&request, &status);
405         
406        MPI_Isend(local_rank_info, 2*ep_size, MPI_INT, remote_leader, tag, peer_comm, &request);
407        MPI_Wait(&request, &status);
408      }
409    }
410
411    MPI_Bcast(remote_rank_info, 2*remote_ep_size, MPI_INT, local_leader, local_comm);
412
413    for(int i=0; i<remote_ep_size; i++)
414    {
415      (*newintercomm)->inter_rank_map->insert(make_pair(remote_rank_info[2*i], remote_rank_info[2*i+1]));
416    }
417
418#ifdef _showinfo
419    if(ep_rank==4 && !priority)
420    {
421      for(std::map<int, int > :: iterator it=(*newintercomm)->inter_rank_map->begin(); it != (*newintercomm)->inter_rank_map->end(); it++)
422      {
423        printf("inter_rank_map[%d] = %d\n", it->first, it->second);
424      }
425    }
426#endif
427
428    (*newintercomm)->ep_comm_ptr->size_rank_info[0] = local_comm->ep_comm_ptr->size_rank_info[0];
429
430    if(is_local_leader)
431    {
432      delete[] local_world_rank_and_num_ep; 
433   
434      MPI_Group_free(local_group);
435      delete local_group;
436    }
437
438    if(is_new_leader)
439    {
440      MPI_Group_free(&union_group);
441      delete empty_group;
442    }
443
444    delete[] remote_world_rank_and_num_ep;
445    delete[] summed_world_rank_and_num_ep;
446    delete[] local_rank_info;
447    delete[] remote_rank_info;
448
449
450  }
451
452}
Note: See TracBrowser for help on using the repository browser.