source: XIOS/dev/branch_openmp/extern/ep_dev/ep_intercomm_kernel.cpp @ 1503

Last change on this file since 1503 was 1503, checked in by yushan, 6 years ago

rank_map is passed from vector to map, in order to have more flexibility in comm_split

File size: 27.1 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4#include "ep_mpi.hpp"
5
6using namespace std;
7
8
9namespace ep_lib
10{
11  int MPI_Intercomm_create_kernel(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
12  {
13    int ep_rank, ep_rank_loc, mpi_rank;
14    int ep_size, num_ep, mpi_size;
15
16    ep_rank = local_comm->ep_comm_ptr->size_rank_info[0].first;
17    ep_rank_loc = local_comm->ep_comm_ptr->size_rank_info[1].first;
18    mpi_rank = local_comm->ep_comm_ptr->size_rank_info[2].first;
19    ep_size = local_comm->ep_comm_ptr->size_rank_info[0].second;
20    num_ep = local_comm->ep_comm_ptr->size_rank_info[1].second;
21    mpi_size = local_comm->ep_comm_ptr->size_rank_info[2].second;
22
23    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
24                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
25
26    int rank_in_world;
27    int rank_in_local_parent;
28
29    int rank_in_peer_mpi[2];
30
31    int local_ep_size = ep_size;
32    int remote_ep_size;
33
34
35    ::MPI_Comm local_mpi_comm = to_mpi_comm(local_comm->mpi_comm);
36
37   
38    ::MPI_Comm_rank(to_mpi_comm(MPI_COMM_WORLD->mpi_comm), &rank_in_world);
39    ::MPI_Comm_rank(local_mpi_comm, &rank_in_local_parent);
40   
41
42    bool is_proc_master = false;
43    bool is_local_leader = false;
44    bool is_final_master = false;
45
46
47    if(ep_rank == local_leader) { is_proc_master = true; is_local_leader = true; is_final_master = true;}
48    if(ep_rank_loc == 0 && mpi_rank != local_comm->ep_rank_map->at(local_leader).second) is_proc_master = true;
49
50
51    int size_info[4]; //! used for choose size of rank_info 0-> mpi_size of local_comm, 1-> mpi_size of remote_comm
52
53    int leader_info[4]; //! 0->world rank of local_leader, 1->world rank of remote leader
54
55
56    std::vector<int> ep_info[2]; //! 0-> num_ep in local_comm, 1->num_ep in remote_comm
57
58    std::vector<int> new_rank_info[4];
59    std::vector<int> new_ep_info[2];
60
61    std::vector<int> offset;
62
63    if(is_proc_master)
64    {
65
66      size_info[0] = mpi_size;
67
68      rank_info[0].resize(size_info[0]);
69      rank_info[1].resize(size_info[0]);
70
71
72
73      ep_info[0].resize(size_info[0]);
74
75      vector<int> send_buf(6);
76      vector<int> recv_buf(3*size_info[0]);
77
78      send_buf[0] = rank_in_world;
79      send_buf[1] = rank_in_local_parent;
80      send_buf[2] = num_ep;
81
82      ::MPI_Allgather(send_buf.data(), 3, to_mpi_type(MPI_INT), recv_buf.data(), 3, to_mpi_type(MPI_INT), local_mpi_comm);
83
84      for(int i=0; i<size_info[0]; i++)
85      {
86        rank_info[0][i] = recv_buf[3*i];
87        rank_info[1][i] = recv_buf[3*i+1];
88        ep_info[0][i]   = recv_buf[3*i+2];
89      }
90
91      if(is_local_leader)
92      {
93        leader_info[0] = rank_in_world;
94        leader_info[1] = remote_leader;
95
96        ::MPI_Comm_rank(to_mpi_comm(peer_comm->mpi_comm), &rank_in_peer_mpi[0]);
97
98        send_buf[0] = size_info[0];
99        send_buf[1] = local_ep_size;
100        send_buf[2] = rank_in_peer_mpi[0];
101
102       
103       
104        MPI_Request requests[2];
105        MPI_Status statuses[2];
106       
107        MPI_Isend(send_buf.data(), 3, MPI_INT, remote_leader, tag, peer_comm, &requests[0]);
108        MPI_Irecv(recv_buf.data(), 3, MPI_INT, remote_leader, tag, peer_comm, &requests[1]);
109
110
111        MPI_Waitall(2, requests, statuses);
112       
113        size_info[1] = recv_buf[0];
114        remote_ep_size = recv_buf[1];
115        rank_in_peer_mpi[1] = recv_buf[2];
116
117      }
118
119
120
121      send_buf[0] = size_info[1];
122      send_buf[1] = leader_info[0];
123      send_buf[2] = leader_info[1];
124      send_buf[3] = rank_in_peer_mpi[0];
125      send_buf[4] = rank_in_peer_mpi[1];
126
127      ::MPI_Bcast(send_buf.data(), 5, to_mpi_type(MPI_INT), local_comm->ep_rank_map->at(local_leader).second, local_mpi_comm);
128
129      size_info[1] = send_buf[0];
130      leader_info[0] = send_buf[1];
131      leader_info[1] = send_buf[2];
132      rank_in_peer_mpi[0] = send_buf[3];
133      rank_in_peer_mpi[1] = send_buf[4];
134
135
136      rank_info[2].resize(size_info[1]);
137      rank_info[3].resize(size_info[1]);
138
139      ep_info[1].resize(size_info[1]);
140
141      send_buf.resize(3*size_info[0]);
142      recv_buf.resize(3*size_info[1]);
143
144      if(is_local_leader)
145      {
146        MPI_Request requests[2];
147        MPI_Status statuses[2];
148
149        std::copy ( rank_info[0].data(), rank_info[0].data() + size_info[0], send_buf.begin() );
150        std::copy ( rank_info[1].data(), rank_info[1].data() + size_info[0], send_buf.begin() + size_info[0] );
151        std::copy ( ep_info[0].data(),   ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[0] );
152
153        MPI_Isend(send_buf.data(), 3*size_info[0], MPI_INT, remote_leader, tag+1, peer_comm, &requests[0]);
154        MPI_Irecv(recv_buf.data(), 3*size_info[1], MPI_INT, remote_leader, tag+1, peer_comm, &requests[1]);
155       
156        MPI_Waitall(2, requests, statuses);
157      }
158
159      ::MPI_Bcast(recv_buf.data(), 3*size_info[1], to_mpi_type(MPI_INT), local_comm->ep_rank_map->at(local_leader).second, local_mpi_comm);
160
161      std::copy ( recv_buf.data(), recv_buf.data() + size_info[1], rank_info[2].begin() );
162      std::copy ( recv_buf.data() + size_info[1], recv_buf.data() + 2*size_info[1], rank_info[3].begin()  );
163      std::copy ( recv_buf.data() + 2*size_info[1], recv_buf.data() + 3*size_info[1], ep_info[1].begin() );
164
165
166      offset.resize(size_info[0]);
167
168      if(leader_info[0]<leader_info[1]) // erase all ranks doubled with remote_comm, except the local leader
169      {
170
171        bool found = false;
172        int ep_local;
173        int ep_remote;
174        for(int i=0; i<size_info[0]; i++)
175        {
176          int target = rank_info[0][i];
177          found = false;
178          for(int j=0; j<size_info[1]; j++)
179          {
180            if(target == rank_info[2][j])
181            {
182              found = true;
183              ep_local = ep_info[0][j];
184              ep_remote = ep_info[1][j];
185              break;
186            }
187          }
188          if(found)
189          {
190
191            if(target == leader_info[0]) // the leader is doubled in remote
192            {
193              new_rank_info[0].push_back(target);
194              new_rank_info[1].push_back(rank_info[1][i]);
195
196              new_ep_info[0].push_back(ep_local + ep_remote);
197              offset[i] = 0;
198            }
199            else
200            {
201              offset[i] = ep_local;
202            }
203          }
204          else
205          {
206            new_rank_info[0].push_back(target);
207            new_rank_info[1].push_back(rank_info[1][i]);
208
209            new_ep_info[0].push_back(ep_info[0][i]);
210
211            offset[i] = 0;
212          }
213
214        }
215      }
216
217      else // erase rank doubled with remote leader
218      {
219
220        bool found = false;
221        int ep_local;
222        int ep_remote;
223        for(int i=0; i<size_info[0]; i++)
224        {
225          int target = rank_info[0][i];
226          found = false;
227          for(int j=0; j<size_info[1]; j++)
228          {
229
230            if(target == rank_info[2][j])
231            {
232              found = true;
233              ep_local = ep_info[0][j];
234              ep_remote = ep_info[1][j];
235              break;
236            }
237          }
238          if(found)
239          {
240            if(target != leader_info[1])
241            {
242              new_rank_info[0].push_back(target);
243              new_rank_info[1].push_back(rank_info[1][i]);
244
245              new_ep_info[0].push_back(ep_local + ep_remote);
246              offset[i] = 0;
247            }
248            else // found remote leader
249            {
250              offset[i] = ep_remote;
251            }
252          }
253          else
254          {
255            new_rank_info[0].push_back(target);
256            new_rank_info[1].push_back(rank_info[1][i]);
257
258            new_ep_info[0].push_back(ep_info[0][i]);
259            offset[i] = 0;
260          }
261        }
262      }
263
264      if(offset[mpi_rank] == 0)
265      {
266        is_final_master = true;
267      }
268
269
270      //! size_info[4]: 2->size of new_ep_info for local, 3->size of new_ep_info for remote
271
272      if(is_local_leader)
273      {
274        size_info[2] = new_ep_info[0].size();
275        MPI_Request requests[2];
276        MPI_Status statuses[2];
277        MPI_Isend(&size_info[2], 1, MPI_INT, remote_leader, tag+2, peer_comm, &requests[0]);
278        MPI_Irecv(&size_info[3], 1, MPI_INT, remote_leader, tag+2, peer_comm, &requests[1]);
279         
280        MPI_Waitall(2, requests, statuses);
281      }
282
283      ::MPI_Bcast(&size_info[2], 2, to_mpi_type(MPI_INT), local_comm->ep_rank_map->at(local_leader).second, local_mpi_comm);
284
285      new_rank_info[2].resize(size_info[3]);
286      new_rank_info[3].resize(size_info[3]);
287      new_ep_info[1].resize(size_info[3]);
288
289      send_buf.resize(size_info[2]);
290      recv_buf.resize(size_info[3]);
291
292      if(is_local_leader)
293      {
294        MPI_Request requests[2];
295        MPI_Status statuses[2];
296
297        std::copy ( new_rank_info[0].data(), new_rank_info[0].data() + size_info[2], send_buf.begin() );
298        std::copy ( new_rank_info[1].data(), new_rank_info[1].data() + size_info[2], send_buf.begin() + size_info[2] );
299        std::copy ( new_ep_info[0].data(),   new_ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[2] );
300
301        MPI_Isend(send_buf.data(), 3*size_info[2], MPI_INT, remote_leader, tag+3, peer_comm, &requests[0]);
302        MPI_Irecv(recv_buf.data(), 3*size_info[3], MPI_INT, remote_leader, tag+3, peer_comm, &requests[1]);
303       
304        MPI_Waitall(2, requests, statuses);
305      }
306
307      ::MPI_Bcast(recv_buf.data(),   3*size_info[3], to_mpi_type(MPI_INT), local_comm->ep_rank_map->at(local_leader).second, local_mpi_comm);
308
309      std::copy ( recv_buf.data(), recv_buf.data() + size_info[3], new_rank_info[2].begin() );
310      std::copy ( recv_buf.data() + size_info[3], recv_buf.data() + 2*size_info[3], new_rank_info[3].begin()  );
311      std::copy ( recv_buf.data() + 2*size_info[3], recv_buf.data() + 3*size_info[3], new_ep_info[1].begin() );
312
313    }
314
315   
316
317    if(is_proc_master)
318    {
319      //! leader_info[4]: 2-> rank of local leader in new_group generated comm;
320                      // 3-> rank of remote leader in new_group generated comm;
321      ::MPI_Group local_group;
322      ::MPI_Group new_group;
323      ::MPI_Comm *new_comm = new ::MPI_Comm;
324      ::MPI_Comm *intercomm = new ::MPI_Comm;
325
326      ::MPI_Comm_group(local_mpi_comm, &local_group);
327
328      ::MPI_Group_incl(local_group, size_info[2], new_rank_info[1].data(), &new_group);
329
330      ::MPI_Comm_create(local_mpi_comm, new_group, new_comm);
331
332
333
334      if(is_local_leader)
335      {
336        ::MPI_Comm_rank(*new_comm, &leader_info[2]);
337      }
338
339      ::MPI_Bcast(&leader_info[2], 1, to_mpi_type(MPI_INT), local_comm->ep_rank_map->at(local_leader).second, local_mpi_comm);
340
341      if(new_comm != static_cast< ::MPI_Comm*>(MPI_COMM_NULL->mpi_comm))
342      {
343
344        ::MPI_Barrier(*new_comm);
345
346        ::MPI_Intercomm_create(*new_comm, leader_info[2], to_mpi_comm(peer_comm->mpi_comm), rank_in_peer_mpi[1], tag, intercomm);
347
348        int id;
349
350        ::MPI_Comm_rank(*new_comm, &id);
351        int my_num_ep = new_ep_info[0][id];
352
353        MPI_Comm *ep_intercomm;
354        MPI_Info info;
355        MPI_Comm_create_endpoints(new_comm, my_num_ep, info, ep_intercomm);
356
357
358        for(int i= 0; i<my_num_ep; i++)
359        {
360          ep_intercomm[i]->is_intercomm = true;
361
362          ep_intercomm[i]->ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
363          ep_intercomm[i]->ep_comm_ptr->intercomm->mpi_inter_comm = intercomm;
364          ep_intercomm[i]->ep_comm_ptr->comm_label = leader_info[0];
365        }
366
367
368        #pragma omp critical (write_to_tag_list)
369        tag_list.push_back(make_pair( make_pair(tag, min(leader_info[0], leader_info[1])) , ep_intercomm));
370        //printf("tag_list size = %lu\n", tag_list.size());
371      }
372    }
373
374    vector<int> bcast_buf(8);
375    if(is_local_leader)
376    {
377      std::copy(size_info, size_info+4, bcast_buf.begin());
378      std::copy(leader_info, leader_info+4, bcast_buf.begin()+4);
379    }
380
381    MPI_Bcast(bcast_buf.data(), 8, MPI_INT, local_leader, local_comm);
382
383    if(!is_local_leader)
384    {
385      std::copy(bcast_buf.begin(), bcast_buf.begin()+4, size_info);
386      std::copy(bcast_buf.begin()+4, bcast_buf.begin()+8, leader_info);
387    }
388
389    if(!is_local_leader)
390    {
391      new_rank_info[1].resize(size_info[2]);
392      ep_info[1].resize(size_info[1]);
393      offset.resize(size_info[0]);
394    }
395
396    bcast_buf.resize(size_info[2]+size_info[1]+size_info[0]+1);
397
398    if(is_local_leader)
399    {
400      bcast_buf[0] = remote_ep_size;
401      std::copy(new_rank_info[1].data(), new_rank_info[1].data()+size_info[2], bcast_buf.begin()+1);
402      std::copy(ep_info[1].data(), ep_info[1].data()+size_info[1], bcast_buf.begin()+size_info[2]+1);
403      std::copy(offset.data(), offset.data()+size_info[0], bcast_buf.begin()+size_info[2]+size_info[1]+1);
404    }
405
406    MPI_Bcast(bcast_buf.data(), size_info[2]+size_info[1]+size_info[0]+1, MPI_INT, local_leader, local_comm);
407
408    if(!is_local_leader)
409    {
410      remote_ep_size = bcast_buf[0];
411      std::copy(bcast_buf.data()+1, bcast_buf.data()+1+size_info[2], new_rank_info[1].begin());
412      std::copy(bcast_buf.data()+1+size_info[2], bcast_buf.data()+1+size_info[2]+size_info[1], ep_info[1].begin());
413      std::copy(bcast_buf.data()+1+size_info[2]+size_info[1], bcast_buf.data()+1+size_info[2]+size_info[1]+size_info[0], offset.begin());
414    }
415
416    int my_position = offset[rank_in_local_parent]+ep_rank_loc;
417   
418    MPI_Barrier_local(local_comm);
419    #pragma omp flush
420
421
422    #pragma omp critical (read_from_tag_list)
423    {
424      bool found = false;
425      while(!found)
426      {
427        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
428        {
429          if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
430          {
431            *newintercomm = iter->second[my_position];
432            found = true;
433            break;
434          }
435        }
436      }
437    }
438
439    MPI_Barrier(local_comm);
440
441    if(is_local_leader)
442    {
443      int local_flag = true;
444      int remote_flag = false;
445      MPI_Request mpi_requests[2];
446      MPI_Status mpi_statuses[2];
447     
448      MPI_Isend(&local_flag, 1, MPI_INT, remote_leader, tag, peer_comm, &mpi_requests[0]);
449      MPI_Irecv(&remote_flag, 1, MPI_INT, remote_leader, tag, peer_comm, &mpi_requests[1]);
450     
451      MPI_Waitall(2, mpi_requests, mpi_statuses);
452    }
453
454
455    MPI_Barrier(local_comm);
456
457    if(is_proc_master)
458    {
459      for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
460      {
461        if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
462        {
463          tag_list.erase(iter);
464          break;
465        }
466      }
467    }
468
469    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
470    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
471
472    intercomm_ep_rank = (*newintercomm)->ep_comm_ptr->size_rank_info[0].first;
473    intercomm_ep_rank_loc = (*newintercomm)->ep_comm_ptr->size_rank_info[1].first;
474    intercomm_mpi_rank = (*newintercomm)->ep_comm_ptr->size_rank_info[2].first;
475    intercomm_ep_size = (*newintercomm)->ep_comm_ptr->size_rank_info[0].second;
476    intercomm_num_ep = (*newintercomm)->ep_comm_ptr->size_rank_info[1].second;
477    intercomm_mpi_size = (*newintercomm)->ep_comm_ptr->size_rank_info[2].second;
478
479    MPI_Bcast(&remote_ep_size, 1, MPI_INT, local_leader, local_comm);
480
481    int my_rank_map_elem[2];
482
483    my_rank_map_elem[0] = intercomm_ep_rank;
484    my_rank_map_elem[1] = (*newintercomm)->ep_comm_ptr->comm_label;
485
486    vector<pair<int, int> > local_rank_map_array;
487    vector<pair<int, int> > remote_rank_map_array;
488
489
490    (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map = new RANK_MAP;
491    (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->resize(local_ep_size);
492
493    MPI_Allgather(my_rank_map_elem, 2, MPI_INT, 
494      (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->data(), 2, MPI_INT, local_comm);
495
496    (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
497    (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->resize(remote_ep_size);
498
499    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[0] = local_comm->ep_comm_ptr->size_rank_info[0];
500    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[1] = local_comm->ep_comm_ptr->size_rank_info[1];
501    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[2] = local_comm->ep_comm_ptr->size_rank_info[2];
502
503    int local_intercomm_size = intercomm_ep_size;
504    int remote_intercomm_size;
505
506    int new_bcast_root_0 = 0;
507    int new_bcast_root = 0;
508
509
510    if(is_local_leader)
511    {
512      MPI_Request requests[4];
513      MPI_Status statuses[4];
514     
515      MPI_Isend((*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_ep_size, MPI_INT, remote_leader, tag+4, peer_comm, &requests[0]);
516      MPI_Irecv((*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, MPI_INT, remote_leader, tag+4, peer_comm, &requests[1]);
517
518      MPI_Isend(&local_intercomm_size, 1, MPI_INT, remote_leader, tag+5, peer_comm, &requests[2]);
519      MPI_Irecv(&remote_intercomm_size, 1, MPI_INT, remote_leader, tag+5, peer_comm, &requests[3]);
520     
521      MPI_Waitall(4, requests, statuses);
522
523      new_bcast_root_0 = intercomm_ep_rank;
524    }
525
526    MPI_Allreduce(&new_bcast_root_0, &new_bcast_root, 1, MPI_INT, MPI_SUM, *newintercomm);
527
528
529    MPI_Bcast((*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, MPI_INT, local_leader, local_comm);
530    MPI_Bcast(&remote_intercomm_size, 1, MPI_INT, new_bcast_root, *newintercomm);
531
532
533    (*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map = new RANK_MAP;
534    (*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->resize(remote_intercomm_size);
535
536
537
538
539    if(is_local_leader)
540    {
541      MPI_Request requests[2];
542      MPI_Status statuses[2];
543     
544      std::vector<std::pair<int, std::pair<int, int> > > map2vec((*newintercomm)->ep_rank_map->size());
545      std::vector<std::pair<int, std::pair<int, int> > > vec2map((*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->size());
546     
547      int ii=0;
548      for(std::map<int, std::pair<int, int> >::iterator it = (*newintercomm)->ep_rank_map->begin(); it != (*newintercomm)->ep_rank_map->end(); it++)
549      {
550        map2vec[ii++] = make_pair(it->first, make_pair(it->second.first, it->second.second));
551      }
552     
553     
554      MPI_Isend(map2vec.data(), 3*local_intercomm_size, MPI_INT, remote_leader, tag+6, peer_comm, &requests[0]);
555      MPI_Irecv(vec2map.data(), 3*remote_intercomm_size, MPI_INT, remote_leader, tag+6, peer_comm, &requests[1]);
556     
557     
558      for(ii=0; ii<vec2map.size(); ii++)
559      {
560        (*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->at(vec2map[ii].first) = make_pair(vec2map[ii].second.first, vec2map[ii].second.second);
561      }
562     
563      MPI_Waitall(2, requests, statuses);
564    }
565
566    MPI_Bcast((*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, MPI_INT, new_bcast_root, *newintercomm);
567
568    (*newintercomm)->ep_comm_ptr->intercomm->local_comm = (local_comm->ep_comm_ptr->comm_list[ep_rank_loc]);
569    (*newintercomm)->ep_comm_ptr->intercomm->intercomm_tag = tag;
570
571/*
572    for(int i=0; i<local_ep_size; i++)
573    if(local_comm->ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, local_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
574          (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->at(i).first, (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->at(i).second);
575
576    for(int i=0; i<remote_ep_size; i++)
577    if(local_comm->ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, remote_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
578          (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->at(i).first, (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->at(i).second);
579
580    for(int i=0; i<remote_intercomm_size; i++)
581    if(local_comm->ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, intercomm_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
582          (*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->at(i).first, (*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->at(i).second);
583*/
584
585//    for(int i=0; i<(*newintercomm)->rank_map->size(); i++)
586//    if(local_comm->ep_comm_ptr->comm_label != 99) printf("ep_rank = %d, rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
587//          (*newintercomm)->rank_map->at(i).first, (*newintercomm)->rank_map->at(i).second);
588
589//    MPI_Comm *test_comm = newintercomm->ep_comm_ptr->intercomm->local_comm;
590//    int test_rank;
591//    MPI_Comm_rank(*test_comm, &test_rank);
592//    printf("=================test_rank = %d\n", test_rank);
593   
594   
595
596    return MPI_SUCCESS;
597
598  }
599
600
601
602
603  int MPI_Intercomm_create_unique_leader(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
604  {
605    //! mpi_size of local comm = 1
606    //! same world rank of leaders
607
608    int ep_rank, ep_rank_loc, mpi_rank;
609    int ep_size, num_ep, mpi_size;
610
611    ep_rank = local_comm->ep_comm_ptr->size_rank_info[0].first;
612    ep_rank_loc = local_comm->ep_comm_ptr->size_rank_info[1].first;
613    mpi_rank = local_comm->ep_comm_ptr->size_rank_info[2].first;
614    ep_size = local_comm->ep_comm_ptr->size_rank_info[0].second;
615    num_ep = local_comm->ep_comm_ptr->size_rank_info[1].second;
616    mpi_size = local_comm->ep_comm_ptr->size_rank_info[2].second;
617
618
619
620    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
621                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
622
623    int rank_in_world;
624
625    int rank_in_peer_mpi[2];
626
627    ::MPI_Comm_rank(to_mpi_comm(MPI_COMM_WORLD->mpi_comm), &rank_in_world);
628
629
630    int local_num_ep = num_ep;
631    int remote_num_ep;
632    int total_num_ep;
633
634    int leader_rank_in_peer[2];
635
636    int my_position;
637    int tag_label[2];
638
639    vector<int> send_buf(4);
640    vector<int> recv_buf(4);
641
642
643    if(ep_rank == local_leader)
644    {
645      MPI_Status status;
646
647
648
649      MPI_Comm_rank(peer_comm, &leader_rank_in_peer[0]);
650
651      send_buf[0] = local_num_ep;
652      send_buf[1] = leader_rank_in_peer[0];
653
654      MPI_Request req_s, req_r;
655
656      MPI_Isend(send_buf.data(), 2, MPI_INT, remote_leader, tag, peer_comm, &req_s);
657      MPI_Irecv(recv_buf.data(), 2, MPI_INT, remote_leader, tag, peer_comm, &req_r);
658
659
660      MPI_Wait(&req_s, &status);
661      MPI_Wait(&req_r, &status);
662
663      recv_buf[2] = leader_rank_in_peer[0];
664
665    }
666
667    MPI_Bcast(recv_buf.data(), 3, MPI_INT, local_leader, local_comm);
668
669    remote_num_ep = recv_buf[0];
670    leader_rank_in_peer[1] = recv_buf[1];
671    leader_rank_in_peer[0] = recv_buf[2];
672
673    total_num_ep = local_num_ep + remote_num_ep;
674
675
676    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
677    {
678      my_position = ep_rank_loc;
679      //! LEADER create EP
680      if(ep_rank == local_leader)
681      {
682        ::MPI_Comm *mpi_dup = new ::MPI_Comm;
683       
684        ::MPI_Comm_dup(to_mpi_comm(local_comm->mpi_comm), mpi_dup);
685
686        MPI_Comm *ep_intercomm;
687        MPI_Info info;
688        MPI_Comm_create_endpoints(mpi_dup, total_num_ep, info, ep_intercomm);
689
690
691        for(int i=0; i<total_num_ep; i++)
692        {
693          ep_intercomm[i]->is_intercomm = true;
694          ep_intercomm[i]->ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
695          ep_intercomm[i]->ep_comm_ptr->intercomm->mpi_inter_comm = 0;
696
697          ep_intercomm[i]->ep_comm_ptr->comm_label = leader_rank_in_peer[0];
698        }
699
700        tag_label[0] = TAG++;
701        tag_label[1] = rank_in_world;
702
703        #pragma omp critical (write_to_tag_list)
704        tag_list.push_back(make_pair( make_pair(tag_label[0], tag_label[1]) , ep_intercomm));
705
706        MPI_Request req_s;
707        MPI_Status sta_s;
708        MPI_Isend(tag_label, 2, MPI_INT, remote_leader, tag, peer_comm, &req_s);
709
710        MPI_Wait(&req_s, &sta_s);
711
712      }
713    }
714    else
715    {
716      //! Wait for EP creation
717      my_position = remote_num_ep + ep_rank_loc;
718      if(ep_rank == local_leader)
719      {
720        MPI_Status status;
721        MPI_Request req_r;
722        MPI_Irecv(tag_label, 2, MPI_INT, remote_leader, tag, peer_comm, &req_r);
723        MPI_Wait(&req_r, &status);
724      }
725    }
726
727    MPI_Bcast(tag_label, 2, MPI_INT, local_leader, local_comm);
728
729
730
731
732    #pragma omp critical (read_from_tag_list)
733    {
734      bool found = false;
735      while(!found)
736      {
737        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
738        {
739          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
740          {
741            *newintercomm =  iter->second[my_position];
742            found = true;
743            // tag_list.erase(iter);
744            break;
745          }
746        }
747      }
748    }
749
750    MPI_Barrier_local(local_comm);
751
752    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
753    {
754      for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
755        {
756          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
757          {
758            tag_list.erase(iter);
759            break;
760          }
761        }
762    }
763
764
765
766    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
767    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
768
769    intercomm_ep_rank = (*newintercomm)->ep_comm_ptr->size_rank_info[0].first;
770    intercomm_ep_rank_loc = (*newintercomm)->ep_comm_ptr->size_rank_info[1].first;
771    intercomm_mpi_rank = (*newintercomm)->ep_comm_ptr->size_rank_info[2].first;
772    intercomm_ep_size = (*newintercomm)->ep_comm_ptr->size_rank_info[0].second;
773    intercomm_num_ep = (*newintercomm)->ep_comm_ptr->size_rank_info[1].second;
774    intercomm_mpi_size = (*newintercomm)->ep_comm_ptr->size_rank_info[2].second;
775
776
777
778    (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map  = new RANK_MAP;
779    (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
780    (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->resize(local_num_ep);
781    (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->resize(remote_num_ep);
782
783    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[0] = local_comm->ep_comm_ptr->size_rank_info[0];
784    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[1] = local_comm->ep_comm_ptr->size_rank_info[1];
785    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[2] = local_comm->ep_comm_ptr->size_rank_info[2];
786
787
788
789    int local_rank_map_ele[2];
790    local_rank_map_ele[0] = intercomm_ep_rank;
791    local_rank_map_ele[1] = (*newintercomm)->ep_comm_ptr->comm_label;
792
793    MPI_Allgather(local_rank_map_ele, 2, MPI_INT, 
794      (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->data(), 2, MPI_INT, local_comm);
795
796    if(ep_rank == local_leader)
797    {
798      MPI_Status status;
799      MPI_Request req_s, req_r;
800
801      MPI_Isend((*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_num_ep, MPI_INT, remote_leader, tag, peer_comm, &req_s);
802      MPI_Irecv((*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, MPI_INT, remote_leader, tag, peer_comm, &req_r);
803
804
805      MPI_Wait(&req_s, &status);
806      MPI_Wait(&req_r, &status);
807
808    }
809
810    MPI_Bcast((*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, MPI_INT, local_leader, local_comm);
811    (*newintercomm)->ep_comm_ptr->intercomm->local_comm = (local_comm->ep_comm_ptr->comm_list[ep_rank_loc]);
812    (*newintercomm)->ep_comm_ptr->intercomm->intercomm_tag = tag;
813
814
815    return MPI_SUCCESS;
816  }
817
818
819}
Note: See TracBrowser for help on using the repository browser.