source: XIOS/dev/branch_openmp/extern/ep_dev/ep_intercomm_kernel.cpp @ 1501

Last change on this file since 1501 was 1500, checked in by yushan, 6 years ago

save dev

File size: 26.4 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4#include "ep_mpi.hpp"
5
6using namespace std;
7
8
9namespace ep_lib
10{
11  int MPI_Intercomm_create_kernel(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
12  {
13    int ep_rank, ep_rank_loc, mpi_rank;
14    int ep_size, num_ep, mpi_size;
15
16    ep_rank = local_comm->ep_comm_ptr->size_rank_info[0].first;
17    ep_rank_loc = local_comm->ep_comm_ptr->size_rank_info[1].first;
18    mpi_rank = local_comm->ep_comm_ptr->size_rank_info[2].first;
19    ep_size = local_comm->ep_comm_ptr->size_rank_info[0].second;
20    num_ep = local_comm->ep_comm_ptr->size_rank_info[1].second;
21    mpi_size = local_comm->ep_comm_ptr->size_rank_info[2].second;
22
23    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
24                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
25
26    int rank_in_world;
27    int rank_in_local_parent;
28
29    int rank_in_peer_mpi[2];
30
31    int local_ep_size = ep_size;
32    int remote_ep_size;
33
34
35    ::MPI_Comm local_mpi_comm = to_mpi_comm(local_comm->mpi_comm);
36
37   
38    ::MPI_Comm_rank(to_mpi_comm(MPI_COMM_WORLD->mpi_comm), &rank_in_world);
39    ::MPI_Comm_rank(local_mpi_comm, &rank_in_local_parent);
40   
41
42    bool is_proc_master = false;
43    bool is_local_leader = false;
44    bool is_final_master = false;
45
46
47    if(ep_rank == local_leader) { is_proc_master = true; is_local_leader = true; is_final_master = true;}
48    if(ep_rank_loc == 0 && mpi_rank != local_comm->rank_map->at(local_leader).second) is_proc_master = true;
49
50
51    int size_info[4]; //! used for choose size of rank_info 0-> mpi_size of local_comm, 1-> mpi_size of remote_comm
52
53    int leader_info[4]; //! 0->world rank of local_leader, 1->world rank of remote leader
54
55
56    std::vector<int> ep_info[2]; //! 0-> num_ep in local_comm, 1->num_ep in remote_comm
57
58    std::vector<int> new_rank_info[4];
59    std::vector<int> new_ep_info[2];
60
61    std::vector<int> offset;
62
63    if(is_proc_master)
64    {
65
66      size_info[0] = mpi_size;
67
68      rank_info[0].resize(size_info[0]);
69      rank_info[1].resize(size_info[0]);
70
71
72
73      ep_info[0].resize(size_info[0]);
74
75      vector<int> send_buf(6);
76      vector<int> recv_buf(3*size_info[0]);
77
78      send_buf[0] = rank_in_world;
79      send_buf[1] = rank_in_local_parent;
80      send_buf[2] = num_ep;
81
82      ::MPI_Allgather(send_buf.data(), 3, to_mpi_type(MPI_INT), recv_buf.data(), 3, to_mpi_type(MPI_INT), local_mpi_comm);
83
84      for(int i=0; i<size_info[0]; i++)
85      {
86        rank_info[0][i] = recv_buf[3*i];
87        rank_info[1][i] = recv_buf[3*i+1];
88        ep_info[0][i]   = recv_buf[3*i+2];
89      }
90
91      if(is_local_leader)
92      {
93        leader_info[0] = rank_in_world;
94        leader_info[1] = remote_leader;
95
96        ::MPI_Comm_rank(to_mpi_comm(peer_comm->mpi_comm), &rank_in_peer_mpi[0]);
97
98        send_buf[0] = size_info[0];
99        send_buf[1] = local_ep_size;
100        send_buf[2] = rank_in_peer_mpi[0];
101
102       
103       
104        MPI_Request requests[2];
105        MPI_Status statuses[2];
106       
107        MPI_Isend(send_buf.data(), 3, MPI_INT, remote_leader, tag, peer_comm, &requests[0]);
108        MPI_Irecv(recv_buf.data(), 3, MPI_INT, remote_leader, tag, peer_comm, &requests[1]);
109
110
111        MPI_Waitall(2, requests, statuses);
112       
113        size_info[1] = recv_buf[0];
114        remote_ep_size = recv_buf[1];
115        rank_in_peer_mpi[1] = recv_buf[2];
116
117      }
118
119
120
121      send_buf[0] = size_info[1];
122      send_buf[1] = leader_info[0];
123      send_buf[2] = leader_info[1];
124      send_buf[3] = rank_in_peer_mpi[0];
125      send_buf[4] = rank_in_peer_mpi[1];
126
127      ::MPI_Bcast(send_buf.data(), 5, to_mpi_type(MPI_INT), local_comm->rank_map->at(local_leader).second, local_mpi_comm);
128
129      size_info[1] = send_buf[0];
130      leader_info[0] = send_buf[1];
131      leader_info[1] = send_buf[2];
132      rank_in_peer_mpi[0] = send_buf[3];
133      rank_in_peer_mpi[1] = send_buf[4];
134
135
136      rank_info[2].resize(size_info[1]);
137      rank_info[3].resize(size_info[1]);
138
139      ep_info[1].resize(size_info[1]);
140
141      send_buf.resize(3*size_info[0]);
142      recv_buf.resize(3*size_info[1]);
143
144      if(is_local_leader)
145      {
146        MPI_Request requests[2];
147        MPI_Status statuses[2];
148
149        std::copy ( rank_info[0].data(), rank_info[0].data() + size_info[0], send_buf.begin() );
150        std::copy ( rank_info[1].data(), rank_info[1].data() + size_info[0], send_buf.begin() + size_info[0] );
151        std::copy ( ep_info[0].data(),   ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[0] );
152
153        MPI_Isend(send_buf.data(), 3*size_info[0], MPI_INT, remote_leader, tag+1, peer_comm, &requests[0]);
154        MPI_Irecv(recv_buf.data(), 3*size_info[1], MPI_INT, remote_leader, tag+1, peer_comm, &requests[1]);
155       
156        MPI_Waitall(2, requests, statuses);
157      }
158
159      ::MPI_Bcast(recv_buf.data(), 3*size_info[1], to_mpi_type(MPI_INT), local_comm->rank_map->at(local_leader).second, local_mpi_comm);
160
161      std::copy ( recv_buf.data(), recv_buf.data() + size_info[1], rank_info[2].begin() );
162      std::copy ( recv_buf.data() + size_info[1], recv_buf.data() + 2*size_info[1], rank_info[3].begin()  );
163      std::copy ( recv_buf.data() + 2*size_info[1], recv_buf.data() + 3*size_info[1], ep_info[1].begin() );
164
165
166      offset.resize(size_info[0]);
167
168      if(leader_info[0]<leader_info[1]) // erase all ranks doubled with remote_comm, except the local leader
169      {
170
171        bool found = false;
172        int ep_local;
173        int ep_remote;
174        for(int i=0; i<size_info[0]; i++)
175        {
176          int target = rank_info[0][i];
177          found = false;
178          for(int j=0; j<size_info[1]; j++)
179          {
180            if(target == rank_info[2][j])
181            {
182              found = true;
183              ep_local = ep_info[0][j];
184              ep_remote = ep_info[1][j];
185              break;
186            }
187          }
188          if(found)
189          {
190
191            if(target == leader_info[0]) // the leader is doubled in remote
192            {
193              new_rank_info[0].push_back(target);
194              new_rank_info[1].push_back(rank_info[1][i]);
195
196              new_ep_info[0].push_back(ep_local + ep_remote);
197              offset[i] = 0;
198            }
199            else
200            {
201              offset[i] = ep_local;
202            }
203          }
204          else
205          {
206            new_rank_info[0].push_back(target);
207            new_rank_info[1].push_back(rank_info[1][i]);
208
209            new_ep_info[0].push_back(ep_info[0][i]);
210
211            offset[i] = 0;
212          }
213
214        }
215      }
216
217      else // erase rank doubled with remote leader
218      {
219
220        bool found = false;
221        int ep_local;
222        int ep_remote;
223        for(int i=0; i<size_info[0]; i++)
224        {
225          int target = rank_info[0][i];
226          found = false;
227          for(int j=0; j<size_info[1]; j++)
228          {
229
230            if(target == rank_info[2][j])
231            {
232              found = true;
233              ep_local = ep_info[0][j];
234              ep_remote = ep_info[1][j];
235              break;
236            }
237          }
238          if(found)
239          {
240            if(target != leader_info[1])
241            {
242              new_rank_info[0].push_back(target);
243              new_rank_info[1].push_back(rank_info[1][i]);
244
245              new_ep_info[0].push_back(ep_local + ep_remote);
246              offset[i] = 0;
247            }
248            else // found remote leader
249            {
250              offset[i] = ep_remote;
251            }
252          }
253          else
254          {
255            new_rank_info[0].push_back(target);
256            new_rank_info[1].push_back(rank_info[1][i]);
257
258            new_ep_info[0].push_back(ep_info[0][i]);
259            offset[i] = 0;
260          }
261        }
262      }
263
264      if(offset[mpi_rank] == 0)
265      {
266        is_final_master = true;
267      }
268
269
270      //! size_info[4]: 2->size of new_ep_info for local, 3->size of new_ep_info for remote
271
272      if(is_local_leader)
273      {
274        size_info[2] = new_ep_info[0].size();
275        MPI_Request requests[2];
276        MPI_Status statuses[2];
277        MPI_Isend(&size_info[2], 1, MPI_INT, remote_leader, tag+2, peer_comm, &requests[0]);
278        MPI_Irecv(&size_info[3], 1, MPI_INT, remote_leader, tag+2, peer_comm, &requests[1]);
279         
280        MPI_Waitall(2, requests, statuses);
281      }
282
283      ::MPI_Bcast(&size_info[2], 2, to_mpi_type(MPI_INT), local_comm->rank_map->at(local_leader).second, local_mpi_comm);
284
285      new_rank_info[2].resize(size_info[3]);
286      new_rank_info[3].resize(size_info[3]);
287      new_ep_info[1].resize(size_info[3]);
288
289      send_buf.resize(size_info[2]);
290      recv_buf.resize(size_info[3]);
291
292      if(is_local_leader)
293      {
294        MPI_Request requests[2];
295        MPI_Status statuses[2];
296
297        std::copy ( new_rank_info[0].data(), new_rank_info[0].data() + size_info[2], send_buf.begin() );
298        std::copy ( new_rank_info[1].data(), new_rank_info[1].data() + size_info[2], send_buf.begin() + size_info[2] );
299        std::copy ( new_ep_info[0].data(),   new_ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[2] );
300
301        MPI_Isend(send_buf.data(), 3*size_info[2], MPI_INT, remote_leader, tag+3, peer_comm, &requests[0]);
302        MPI_Irecv(recv_buf.data(), 3*size_info[3], MPI_INT, remote_leader, tag+3, peer_comm, &requests[1]);
303       
304        MPI_Waitall(2, requests, statuses);
305      }
306
307      ::MPI_Bcast(recv_buf.data(),   3*size_info[3], to_mpi_type(MPI_INT), local_comm->rank_map->at(local_leader).second, local_mpi_comm);
308
309      std::copy ( recv_buf.data(), recv_buf.data() + size_info[3], new_rank_info[2].begin() );
310      std::copy ( recv_buf.data() + size_info[3], recv_buf.data() + 2*size_info[3], new_rank_info[3].begin()  );
311      std::copy ( recv_buf.data() + 2*size_info[3], recv_buf.data() + 3*size_info[3], new_ep_info[1].begin() );
312
313    }
314
315   
316
317    if(is_proc_master)
318    {
319      //! leader_info[4]: 2-> rank of local leader in new_group generated comm;
320                      // 3-> rank of remote leader in new_group generated comm;
321      ::MPI_Group local_group;
322      ::MPI_Group new_group;
323      ::MPI_Comm *new_comm = new ::MPI_Comm;
324      ::MPI_Comm *intercomm = new ::MPI_Comm;
325
326      ::MPI_Comm_group(local_mpi_comm, &local_group);
327
328      ::MPI_Group_incl(local_group, size_info[2], new_rank_info[1].data(), &new_group);
329
330      ::MPI_Comm_create(local_mpi_comm, new_group, new_comm);
331
332
333
334      if(is_local_leader)
335      {
336        ::MPI_Comm_rank(*new_comm, &leader_info[2]);
337      }
338
339      ::MPI_Bcast(&leader_info[2], 1, to_mpi_type(MPI_INT), local_comm->rank_map->at(local_leader).second, local_mpi_comm);
340
341      if(new_comm != static_cast< ::MPI_Comm*>(MPI_COMM_NULL->mpi_comm))
342      {
343
344        ::MPI_Barrier(*new_comm);
345
346        ::MPI_Intercomm_create(*new_comm, leader_info[2], to_mpi_comm(peer_comm->mpi_comm), rank_in_peer_mpi[1], tag, intercomm);
347
348        int id;
349
350        ::MPI_Comm_rank(*new_comm, &id);
351        int my_num_ep = new_ep_info[0][id];
352
353        MPI_Comm *ep_intercomm;
354        MPI_Info info;
355        MPI_Comm_create_endpoints(new_comm, my_num_ep, info, ep_intercomm);
356
357
358        for(int i= 0; i<my_num_ep; i++)
359        {
360          ep_intercomm[i]->is_intercomm = true;
361
362          ep_intercomm[i]->ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
363          ep_intercomm[i]->ep_comm_ptr->intercomm->mpi_inter_comm = intercomm;
364          ep_intercomm[i]->ep_comm_ptr->comm_label = leader_info[0];
365        }
366
367
368        #pragma omp critical (write_to_tag_list)
369        tag_list.push_back(make_pair( make_pair(tag, min(leader_info[0], leader_info[1])) , ep_intercomm));
370        //printf("tag_list size = %lu\n", tag_list.size());
371      }
372    }
373
374    vector<int> bcast_buf(8);
375    if(is_local_leader)
376    {
377      std::copy(size_info, size_info+4, bcast_buf.begin());
378      std::copy(leader_info, leader_info+4, bcast_buf.begin()+4);
379    }
380
381    MPI_Bcast(bcast_buf.data(), 8, MPI_INT, local_leader, local_comm);
382
383    if(!is_local_leader)
384    {
385      std::copy(bcast_buf.begin(), bcast_buf.begin()+4, size_info);
386      std::copy(bcast_buf.begin()+4, bcast_buf.begin()+8, leader_info);
387    }
388
389    if(!is_local_leader)
390    {
391      new_rank_info[1].resize(size_info[2]);
392      ep_info[1].resize(size_info[1]);
393      offset.resize(size_info[0]);
394    }
395
396    bcast_buf.resize(size_info[2]+size_info[1]+size_info[0]+1);
397
398    if(is_local_leader)
399    {
400      bcast_buf[0] = remote_ep_size;
401      std::copy(new_rank_info[1].data(), new_rank_info[1].data()+size_info[2], bcast_buf.begin()+1);
402      std::copy(ep_info[1].data(), ep_info[1].data()+size_info[1], bcast_buf.begin()+size_info[2]+1);
403      std::copy(offset.data(), offset.data()+size_info[0], bcast_buf.begin()+size_info[2]+size_info[1]+1);
404    }
405
406    MPI_Bcast(bcast_buf.data(), size_info[2]+size_info[1]+size_info[0]+1, MPI_INT, local_leader, local_comm);
407
408    if(!is_local_leader)
409    {
410      remote_ep_size = bcast_buf[0];
411      std::copy(bcast_buf.data()+1, bcast_buf.data()+1+size_info[2], new_rank_info[1].begin());
412      std::copy(bcast_buf.data()+1+size_info[2], bcast_buf.data()+1+size_info[2]+size_info[1], ep_info[1].begin());
413      std::copy(bcast_buf.data()+1+size_info[2]+size_info[1], bcast_buf.data()+1+size_info[2]+size_info[1]+size_info[0], offset.begin());
414    }
415
416    int my_position = offset[rank_in_local_parent]+ep_rank_loc;
417   
418    MPI_Barrier_local(local_comm);
419    #pragma omp flush
420
421
422    #pragma omp critical (read_from_tag_list)
423    {
424      bool found = false;
425      while(!found)
426      {
427        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
428        {
429          if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
430          {
431            *newintercomm = iter->second[my_position];
432            found = true;
433            break;
434          }
435        }
436      }
437    }
438
439    MPI_Barrier(local_comm);
440
441    if(is_local_leader)
442    {
443      int local_flag = true;
444      int remote_flag = false;
445      MPI_Request mpi_requests[2];
446      MPI_Status mpi_statuses[2];
447     
448      MPI_Isend(&local_flag, 1, MPI_INT, remote_leader, tag, peer_comm, &mpi_requests[0]);
449      MPI_Irecv(&remote_flag, 1, MPI_INT, remote_leader, tag, peer_comm, &mpi_requests[1]);
450     
451      MPI_Waitall(2, mpi_requests, mpi_statuses);
452    }
453
454
455    MPI_Barrier(local_comm);
456
457    if(is_proc_master)
458    {
459      for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
460      {
461        if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
462        {
463          tag_list.erase(iter);
464          break;
465        }
466      }
467    }
468
469    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
470    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
471
472    intercomm_ep_rank = (*newintercomm)->ep_comm_ptr->size_rank_info[0].first;
473    intercomm_ep_rank_loc = (*newintercomm)->ep_comm_ptr->size_rank_info[1].first;
474    intercomm_mpi_rank = (*newintercomm)->ep_comm_ptr->size_rank_info[2].first;
475    intercomm_ep_size = (*newintercomm)->ep_comm_ptr->size_rank_info[0].second;
476    intercomm_num_ep = (*newintercomm)->ep_comm_ptr->size_rank_info[1].second;
477    intercomm_mpi_size = (*newintercomm)->ep_comm_ptr->size_rank_info[2].second;
478
479    MPI_Bcast(&remote_ep_size, 1, MPI_INT, local_leader, local_comm);
480
481    int my_rank_map_elem[2];
482
483    my_rank_map_elem[0] = intercomm_ep_rank;
484    my_rank_map_elem[1] = (*newintercomm)->ep_comm_ptr->comm_label;
485
486    vector<pair<int, int> > local_rank_map_array;
487    vector<pair<int, int> > remote_rank_map_array;
488
489
490    (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map = new RANK_MAP;
491    (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->resize(local_ep_size);
492
493    MPI_Allgather(my_rank_map_elem, 2, MPI_INT, 
494      (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->data(), 2, MPI_INT, local_comm);
495
496    (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
497    (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->resize(remote_ep_size);
498
499    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[0] = local_comm->ep_comm_ptr->size_rank_info[0];
500    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[1] = local_comm->ep_comm_ptr->size_rank_info[1];
501    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[2] = local_comm->ep_comm_ptr->size_rank_info[2];
502
503    int local_intercomm_size = intercomm_ep_size;
504    int remote_intercomm_size;
505
506    int new_bcast_root_0 = 0;
507    int new_bcast_root = 0;
508
509
510    if(is_local_leader)
511    {
512      MPI_Request requests[4];
513      MPI_Status statuses[4];
514     
515      MPI_Isend((*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_ep_size, MPI_INT, remote_leader, tag+4, peer_comm, &requests[0]);
516      MPI_Irecv((*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, MPI_INT, remote_leader, tag+4, peer_comm, &requests[1]);
517
518      MPI_Isend(&local_intercomm_size, 1, MPI_INT, remote_leader, tag+5, peer_comm, &requests[2]);
519      MPI_Irecv(&remote_intercomm_size, 1, MPI_INT, remote_leader, tag+5, peer_comm, &requests[3]);
520     
521      MPI_Waitall(4, requests, statuses);
522
523      new_bcast_root_0 = intercomm_ep_rank;
524    }
525
526    MPI_Allreduce(&new_bcast_root_0, &new_bcast_root, 1, MPI_INT, MPI_SUM, *newintercomm);
527
528
529    MPI_Bcast((*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, MPI_INT, local_leader, local_comm);
530    MPI_Bcast(&remote_intercomm_size, 1, MPI_INT, new_bcast_root, *newintercomm);
531
532
533    (*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map = new RANK_MAP;
534    (*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->resize(remote_intercomm_size);
535
536
537
538
539    if(is_local_leader)
540    {
541      MPI_Request requests[2];
542      MPI_Status statuses[2];
543     
544      MPI_Isend((*newintercomm)->rank_map->data(), 2*local_intercomm_size, MPI_INT, remote_leader, tag+6, peer_comm, &requests[0]);
545      MPI_Irecv((*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, MPI_INT, remote_leader, tag+6, peer_comm, &requests[1]);
546     
547      MPI_Waitall(2, requests, statuses);
548    }
549
550    MPI_Bcast((*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, MPI_INT, new_bcast_root, *newintercomm);
551
552    (*newintercomm)->ep_comm_ptr->intercomm->local_comm = (local_comm->ep_comm_ptr->comm_list[ep_rank_loc]);
553    (*newintercomm)->ep_comm_ptr->intercomm->intercomm_tag = tag;
554
555/*
556    for(int i=0; i<local_ep_size; i++)
557    if(local_comm->ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, local_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
558          (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->at(i).first, (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->at(i).second);
559
560    for(int i=0; i<remote_ep_size; i++)
561    if(local_comm->ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, remote_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
562          (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->at(i).first, (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->at(i).second);
563
564    for(int i=0; i<remote_intercomm_size; i++)
565    if(local_comm->ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, intercomm_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
566          (*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->at(i).first, (*newintercomm)->ep_comm_ptr->intercomm->intercomm_rank_map->at(i).second);
567*/
568
569//    for(int i=0; i<(*newintercomm)->rank_map->size(); i++)
570//    if(local_comm->ep_comm_ptr->comm_label != 99) printf("ep_rank = %d, rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
571//          (*newintercomm)->rank_map->at(i).first, (*newintercomm)->rank_map->at(i).second);
572
573//    MPI_Comm *test_comm = newintercomm->ep_comm_ptr->intercomm->local_comm;
574//    int test_rank;
575//    MPI_Comm_rank(*test_comm, &test_rank);
576//    printf("=================test_rank = %d\n", test_rank);
577   
578   
579
580    return MPI_SUCCESS;
581
582  }
583
584
585
586
587  int MPI_Intercomm_create_unique_leader(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
588  {
589    //! mpi_size of local comm = 1
590    //! same world rank of leaders
591
592    int ep_rank, ep_rank_loc, mpi_rank;
593    int ep_size, num_ep, mpi_size;
594
595    ep_rank = local_comm->ep_comm_ptr->size_rank_info[0].first;
596    ep_rank_loc = local_comm->ep_comm_ptr->size_rank_info[1].first;
597    mpi_rank = local_comm->ep_comm_ptr->size_rank_info[2].first;
598    ep_size = local_comm->ep_comm_ptr->size_rank_info[0].second;
599    num_ep = local_comm->ep_comm_ptr->size_rank_info[1].second;
600    mpi_size = local_comm->ep_comm_ptr->size_rank_info[2].second;
601
602
603
604    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
605                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
606
607    int rank_in_world;
608
609    int rank_in_peer_mpi[2];
610
611    ::MPI_Comm_rank(to_mpi_comm(MPI_COMM_WORLD->mpi_comm), &rank_in_world);
612
613
614    int local_num_ep = num_ep;
615    int remote_num_ep;
616    int total_num_ep;
617
618    int leader_rank_in_peer[2];
619
620    int my_position;
621    int tag_label[2];
622
623    vector<int> send_buf(4);
624    vector<int> recv_buf(4);
625
626
627    if(ep_rank == local_leader)
628    {
629      MPI_Status status;
630
631
632
633      MPI_Comm_rank(peer_comm, &leader_rank_in_peer[0]);
634
635      send_buf[0] = local_num_ep;
636      send_buf[1] = leader_rank_in_peer[0];
637
638      MPI_Request req_s, req_r;
639
640      MPI_Isend(send_buf.data(), 2, MPI_INT, remote_leader, tag, peer_comm, &req_s);
641      MPI_Irecv(recv_buf.data(), 2, MPI_INT, remote_leader, tag, peer_comm, &req_r);
642
643
644      MPI_Wait(&req_s, &status);
645      MPI_Wait(&req_r, &status);
646
647      recv_buf[2] = leader_rank_in_peer[0];
648
649    }
650
651    MPI_Bcast(recv_buf.data(), 3, MPI_INT, local_leader, local_comm);
652
653    remote_num_ep = recv_buf[0];
654    leader_rank_in_peer[1] = recv_buf[1];
655    leader_rank_in_peer[0] = recv_buf[2];
656
657    total_num_ep = local_num_ep + remote_num_ep;
658
659
660    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
661    {
662      my_position = ep_rank_loc;
663      //! LEADER create EP
664      if(ep_rank == local_leader)
665      {
666        ::MPI_Comm *mpi_dup = new ::MPI_Comm;
667       
668        ::MPI_Comm_dup(to_mpi_comm(local_comm->mpi_comm), mpi_dup);
669
670        MPI_Comm *ep_intercomm;
671        MPI_Info info;
672        MPI_Comm_create_endpoints(mpi_dup, total_num_ep, info, ep_intercomm);
673
674
675        for(int i=0; i<total_num_ep; i++)
676        {
677          ep_intercomm[i]->is_intercomm = true;
678          ep_intercomm[i]->ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
679          ep_intercomm[i]->ep_comm_ptr->intercomm->mpi_inter_comm = 0;
680
681          ep_intercomm[i]->ep_comm_ptr->comm_label = leader_rank_in_peer[0];
682        }
683
684        tag_label[0] = TAG++;
685        tag_label[1] = rank_in_world;
686
687        #pragma omp critical (write_to_tag_list)
688        tag_list.push_back(make_pair( make_pair(tag_label[0], tag_label[1]) , ep_intercomm));
689
690        MPI_Request req_s;
691        MPI_Status sta_s;
692        MPI_Isend(tag_label, 2, MPI_INT, remote_leader, tag, peer_comm, &req_s);
693
694        MPI_Wait(&req_s, &sta_s);
695
696      }
697    }
698    else
699    {
700      //! Wait for EP creation
701      my_position = remote_num_ep + ep_rank_loc;
702      if(ep_rank == local_leader)
703      {
704        MPI_Status status;
705        MPI_Request req_r;
706        MPI_Irecv(tag_label, 2, MPI_INT, remote_leader, tag, peer_comm, &req_r);
707        MPI_Wait(&req_r, &status);
708      }
709    }
710
711    MPI_Bcast(tag_label, 2, MPI_INT, local_leader, local_comm);
712
713
714
715
716    #pragma omp critical (read_from_tag_list)
717    {
718      bool found = false;
719      while(!found)
720      {
721        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
722        {
723          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
724          {
725            *newintercomm =  iter->second[my_position];
726            found = true;
727            // tag_list.erase(iter);
728            break;
729          }
730        }
731      }
732    }
733
734    MPI_Barrier_local(local_comm);
735
736    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
737    {
738      for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
739        {
740          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
741          {
742            tag_list.erase(iter);
743            break;
744          }
745        }
746    }
747
748
749
750    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
751    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
752
753    intercomm_ep_rank = (*newintercomm)->ep_comm_ptr->size_rank_info[0].first;
754    intercomm_ep_rank_loc = (*newintercomm)->ep_comm_ptr->size_rank_info[1].first;
755    intercomm_mpi_rank = (*newintercomm)->ep_comm_ptr->size_rank_info[2].first;
756    intercomm_ep_size = (*newintercomm)->ep_comm_ptr->size_rank_info[0].second;
757    intercomm_num_ep = (*newintercomm)->ep_comm_ptr->size_rank_info[1].second;
758    intercomm_mpi_size = (*newintercomm)->ep_comm_ptr->size_rank_info[2].second;
759
760
761
762    (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map  = new RANK_MAP;
763    (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
764    (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->resize(local_num_ep);
765    (*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->resize(remote_num_ep);
766
767    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[0] = local_comm->ep_comm_ptr->size_rank_info[0];
768    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[1] = local_comm->ep_comm_ptr->size_rank_info[1];
769    (*newintercomm)->ep_comm_ptr->intercomm->size_rank_info[2] = local_comm->ep_comm_ptr->size_rank_info[2];
770
771
772
773    int local_rank_map_ele[2];
774    local_rank_map_ele[0] = intercomm_ep_rank;
775    local_rank_map_ele[1] = (*newintercomm)->ep_comm_ptr->comm_label;
776
777    MPI_Allgather(local_rank_map_ele, 2, MPI_INT, 
778      (*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->data(), 2, MPI_INT, local_comm);
779
780    if(ep_rank == local_leader)
781    {
782      MPI_Status status;
783      MPI_Request req_s, req_r;
784
785      MPI_Isend((*newintercomm)->ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_num_ep, MPI_INT, remote_leader, tag, peer_comm, &req_s);
786      MPI_Irecv((*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, MPI_INT, remote_leader, tag, peer_comm, &req_r);
787
788
789      MPI_Wait(&req_s, &status);
790      MPI_Wait(&req_r, &status);
791
792    }
793
794    MPI_Bcast((*newintercomm)->ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, MPI_INT, local_leader, local_comm);
795    (*newintercomm)->ep_comm_ptr->intercomm->local_comm = (local_comm->ep_comm_ptr->comm_list[ep_rank_loc]);
796    (*newintercomm)->ep_comm_ptr->intercomm->intercomm_tag = tag;
797
798
799    return MPI_SUCCESS;
800  }
801
802
803}
Note: See TracBrowser for help on using the repository browser.