source: XIOS/dev/branch_openmp/extern/ep_dev/ep_intercomm_kernel.cpp @ 1381

Last change on this file since 1381 was 1381, checked in by yushan, 3 years ago

add folder for MPI EP-RMA development. Current: MPI_Win, MPI_win_create, MPI_win_fence, MPI_win_free

File size: 26.0 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4#include "ep_mpi.hpp"
5
6using namespace std;
7
8
9namespace ep_lib
10{
11  int MPI_Intercomm_create_kernel(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
12  {
13    int ep_rank, ep_rank_loc, mpi_rank;
14    int ep_size, num_ep, mpi_size;
15
16    ep_rank = local_comm.ep_comm_ptr->size_rank_info[0].first;
17    ep_rank_loc = local_comm.ep_comm_ptr->size_rank_info[1].first;
18    mpi_rank = local_comm.ep_comm_ptr->size_rank_info[2].first;
19    ep_size = local_comm.ep_comm_ptr->size_rank_info[0].second;
20    num_ep = local_comm.ep_comm_ptr->size_rank_info[1].second;
21    mpi_size = local_comm.ep_comm_ptr->size_rank_info[2].second;
22
23    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
24                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
25
26    int rank_in_world;
27    int rank_in_local_parent;
28
29    int rank_in_peer_mpi[2];
30
31    int local_ep_size = ep_size;
32    int remote_ep_size;
33
34
35    ::MPI_Comm local_mpi_comm = to_mpi_comm(local_comm.mpi_comm);
36
37   
38    ::MPI_Comm_rank(to_mpi_comm(MPI_COMM_WORLD.mpi_comm), &rank_in_world);
39    ::MPI_Comm_rank(local_mpi_comm, &rank_in_local_parent);
40   
41
42    bool is_proc_master = false;
43    bool is_local_leader = false;
44    bool is_final_master = false;
45
46
47    if(ep_rank == local_leader) { is_proc_master = true; is_local_leader = true; is_final_master = true;}
48    if(ep_rank_loc == 0 && mpi_rank != local_comm.rank_map->at(local_leader).second) is_proc_master = true;
49
50
51    int size_info[4]; //! used for choose size of rank_info 0-> mpi_size of local_comm, 1-> mpi_size of remote_comm
52
53    int leader_info[4]; //! 0->world rank of local_leader, 1->world rank of remote leader
54
55
56    std::vector<int> ep_info[2]; //! 0-> num_ep in local_comm, 1->num_ep in remote_comm
57
58    std::vector<int> new_rank_info[4];
59    std::vector<int> new_ep_info[2];
60
61    std::vector<int> offset;
62
63    if(is_proc_master)
64    {
65
66      size_info[0] = mpi_size;
67
68      rank_info[0].resize(size_info[0]);
69      rank_info[1].resize(size_info[0]);
70
71
72
73      ep_info[0].resize(size_info[0]);
74
75      vector<int> send_buf(6);
76      vector<int> recv_buf(3*size_info[0]);
77
78      send_buf[0] = rank_in_world;
79      send_buf[1] = rank_in_local_parent;
80      send_buf[2] = num_ep;
81
82      ::MPI_Allgather(send_buf.data(), 3, to_mpi_type(MPI_INT), recv_buf.data(), 3, to_mpi_type(MPI_INT), local_mpi_comm);
83
84      for(int i=0; i<size_info[0]; i++)
85      {
86        rank_info[0][i] = recv_buf[3*i];
87        rank_info[1][i] = recv_buf[3*i+1];
88        ep_info[0][i]   = recv_buf[3*i+2];
89      }
90
91      if(is_local_leader)
92      {
93        leader_info[0] = rank_in_world;
94        leader_info[1] = remote_leader;
95
96        ::MPI_Comm_rank(to_mpi_comm(peer_comm.mpi_comm), &rank_in_peer_mpi[0]);
97
98        send_buf[0] = size_info[0];
99        send_buf[1] = local_ep_size;
100        send_buf[2] = rank_in_peer_mpi[0];
101
102       
103       
104        MPI_Request requests[2];
105        MPI_Status statuses[2];
106       
107        MPI_Isend(send_buf.data(), 3, MPI_INT, remote_leader, tag, peer_comm, &requests[0]);
108        MPI_Irecv(recv_buf.data(), 3, MPI_INT, remote_leader, tag, peer_comm, &requests[1]);
109
110
111        MPI_Waitall(2, requests, statuses);
112       
113        size_info[1] = recv_buf[0];
114        remote_ep_size = recv_buf[1];
115        rank_in_peer_mpi[1] = recv_buf[2];
116
117      }
118
119      send_buf[0] = size_info[1];
120      send_buf[1] = leader_info[0];
121      send_buf[2] = leader_info[1];
122      send_buf[3] = rank_in_peer_mpi[0];
123      send_buf[4] = rank_in_peer_mpi[1];
124
125      ::MPI_Bcast(send_buf.data(), 5, to_mpi_type(MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
126
127      size_info[1] = send_buf[0];
128      leader_info[0] = send_buf[1];
129      leader_info[1] = send_buf[2];
130      rank_in_peer_mpi[0] = send_buf[3];
131      rank_in_peer_mpi[1] = send_buf[4];
132
133
134      rank_info[2].resize(size_info[1]);
135      rank_info[3].resize(size_info[1]);
136
137      ep_info[1].resize(size_info[1]);
138
139      send_buf.resize(3*size_info[0]);
140      recv_buf.resize(3*size_info[1]);
141
142      if(is_local_leader)
143      {
144        MPI_Request requests[2];
145        MPI_Status statuses[2];
146
147        std::copy ( rank_info[0].data(), rank_info[0].data() + size_info[0], send_buf.begin() );
148        std::copy ( rank_info[1].data(), rank_info[1].data() + size_info[0], send_buf.begin() + size_info[0] );
149        std::copy ( ep_info[0].data(),   ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[0] );
150
151        MPI_Isend(send_buf.data(), 3*size_info[0], MPI_INT, remote_leader, tag+1, peer_comm, &requests[0]);
152        MPI_Irecv(recv_buf.data(), 3*size_info[1], MPI_INT, remote_leader, tag+1, peer_comm, &requests[1]);
153       
154        MPI_Waitall(2, requests, statuses);
155      }
156
157      ::MPI_Bcast(recv_buf.data(), 3*size_info[1], to_mpi_type(MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
158
159      std::copy ( recv_buf.data(), recv_buf.data() + size_info[1], rank_info[2].begin() );
160      std::copy ( recv_buf.data() + size_info[1], recv_buf.data() + 2*size_info[1], rank_info[3].begin()  );
161      std::copy ( recv_buf.data() + 2*size_info[1], recv_buf.data() + 3*size_info[1], ep_info[1].begin() );
162
163
164      offset.resize(size_info[0]);
165
166      if(leader_info[0]<leader_info[1]) // erase all ranks doubled with remote_comm, except the local leader
167      {
168
169        bool found = false;
170        int ep_local;
171        int ep_remote;
172        for(int i=0; i<size_info[0]; i++)
173        {
174          int target = rank_info[0][i];
175          found = false;
176          for(int j=0; j<size_info[1]; j++)
177          {
178            if(target == rank_info[2][j])
179            {
180              found = true;
181              ep_local = ep_info[0][j];
182              ep_remote = ep_info[1][j];
183              break;
184            }
185          }
186          if(found)
187          {
188
189            if(target == leader_info[0]) // the leader is doubled in remote
190            {
191              new_rank_info[0].push_back(target);
192              new_rank_info[1].push_back(rank_info[1][i]);
193
194              new_ep_info[0].push_back(ep_local + ep_remote);
195              offset[i] = 0;
196            }
197            else
198            {
199              offset[i] = ep_local;
200            }
201          }
202          else
203          {
204            new_rank_info[0].push_back(target);
205            new_rank_info[1].push_back(rank_info[1][i]);
206
207            new_ep_info[0].push_back(ep_info[0][i]);
208
209            offset[i] = 0;
210          }
211
212        }
213      }
214
215      else // erase rank doubled with remote leader
216      {
217
218        bool found = false;
219        int ep_local;
220        int ep_remote;
221        for(int i=0; i<size_info[0]; i++)
222        {
223          int target = rank_info[0][i];
224          found = false;
225          for(int j=0; j<size_info[1]; j++)
226          {
227
228            if(target == rank_info[2][j])
229            {
230              found = true;
231              ep_local = ep_info[0][j];
232              ep_remote = ep_info[1][j];
233              break;
234            }
235          }
236          if(found)
237          {
238            if(target != leader_info[1])
239            {
240              new_rank_info[0].push_back(target);
241              new_rank_info[1].push_back(rank_info[1][i]);
242
243              new_ep_info[0].push_back(ep_local + ep_remote);
244              offset[i] = 0;
245            }
246            else // found remote leader
247            {
248              offset[i] = ep_remote;
249            }
250          }
251          else
252          {
253            new_rank_info[0].push_back(target);
254            new_rank_info[1].push_back(rank_info[1][i]);
255
256            new_ep_info[0].push_back(ep_info[0][i]);
257            offset[i] = 0;
258          }
259        }
260      }
261
262      if(offset[mpi_rank] == 0)
263      {
264        is_final_master = true;
265      }
266
267
268      //! size_info[4]: 2->size of new_ep_info for local, 3->size of new_ep_info for remote
269
270      if(is_local_leader)
271      {
272        size_info[2] = new_ep_info[0].size();
273        MPI_Request requests[2];
274        MPI_Status statuses[2];
275        MPI_Isend(&size_info[2], 1, MPI_INT, remote_leader, tag+2, peer_comm, &requests[0]);
276        MPI_Irecv(&size_info[3], 1, MPI_INT, remote_leader, tag+2, peer_comm, &requests[1]);
277         
278        MPI_Waitall(2, requests, statuses);
279      }
280
281      ::MPI_Bcast(&size_info[2], 2, to_mpi_type(MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
282
283      new_rank_info[2].resize(size_info[3]);
284      new_rank_info[3].resize(size_info[3]);
285      new_ep_info[1].resize(size_info[3]);
286
287      send_buf.resize(size_info[2]);
288      recv_buf.resize(size_info[3]);
289
290      if(is_local_leader)
291      {
292        MPI_Request requests[2];
293        MPI_Status statuses[2];
294
295        std::copy ( new_rank_info[0].data(), new_rank_info[0].data() + size_info[2], send_buf.begin() );
296        std::copy ( new_rank_info[1].data(), new_rank_info[1].data() + size_info[2], send_buf.begin() + size_info[2] );
297        std::copy ( new_ep_info[0].data(),   new_ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[2] );
298
299        MPI_Isend(send_buf.data(), 3*size_info[2], MPI_INT, remote_leader, tag+3, peer_comm, &requests[0]);
300        MPI_Irecv(recv_buf.data(), 3*size_info[3], MPI_INT, remote_leader, tag+3, peer_comm, &requests[1]);
301       
302        MPI_Waitall(2, requests, statuses);
303      }
304
305      ::MPI_Bcast(recv_buf.data(),   3*size_info[3], to_mpi_type(MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
306
307      std::copy ( recv_buf.data(), recv_buf.data() + size_info[3], new_rank_info[2].begin() );
308      std::copy ( recv_buf.data() + size_info[3], recv_buf.data() + 2*size_info[3], new_rank_info[3].begin()  );
309      std::copy ( recv_buf.data() + 2*size_info[3], recv_buf.data() + 3*size_info[3], new_ep_info[1].begin() );
310
311    }
312
313   
314
315    if(is_proc_master)
316    {
317      //! leader_info[4]: 2-> rank of local leader in new_group generated comm;
318                      // 3-> rank of remote leader in new_group generated comm;
319      ::MPI_Group local_group;
320      ::MPI_Group new_group;
321      ::MPI_Comm *new_comm = new ::MPI_Comm;
322      ::MPI_Comm *intercomm = new ::MPI_Comm;
323
324      ::MPI_Comm_group(local_mpi_comm, &local_group);
325
326      ::MPI_Group_incl(local_group, size_info[2], new_rank_info[1].data(), &new_group);
327
328      ::MPI_Comm_create(local_mpi_comm, new_group, new_comm);
329
330
331
332      if(is_local_leader)
333      {
334        ::MPI_Comm_rank(*new_comm, &leader_info[2]);
335      }
336
337      ::MPI_Bcast(&leader_info[2], 1, to_mpi_type(MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
338
339      if(new_comm != static_cast< ::MPI_Comm*>(MPI_COMM_NULL.mpi_comm))
340      {
341
342        ::MPI_Barrier(*new_comm);
343
344        ::MPI_Intercomm_create(*new_comm, leader_info[2], to_mpi_comm(peer_comm.mpi_comm), rank_in_peer_mpi[1], tag, intercomm);
345
346        int id;
347
348        ::MPI_Comm_rank(*new_comm, &id);
349        int my_num_ep = new_ep_info[0][id];
350
351        MPI_Comm *ep_intercomm;
352        MPI_Info info;
353        MPI_Comm_create_endpoints(new_comm, my_num_ep, info, ep_intercomm);
354
355
356        for(int i= 0; i<my_num_ep; i++)
357        {
358          ep_intercomm[i].is_intercomm = true;
359
360          ep_intercomm[i].ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
361          ep_intercomm[i].ep_comm_ptr->intercomm->mpi_inter_comm = intercomm;
362          ep_intercomm[i].ep_comm_ptr->comm_label = leader_info[0];
363        }
364
365
366        #pragma omp critical (write_to_tag_list)
367        tag_list.push_back(make_pair( make_pair(tag, min(leader_info[0], leader_info[1])) , ep_intercomm));
368      }
369    }
370
371    vector<int> bcast_buf(8);
372    if(is_local_leader)
373    {
374      std::copy(size_info, size_info+4, bcast_buf.begin());
375      std::copy(leader_info, leader_info+4, bcast_buf.begin()+4);
376    }
377
378    MPI_Bcast(bcast_buf.data(), 8, MPI_INT, local_leader, local_comm);
379
380    if(!is_local_leader)
381    {
382      std::copy(bcast_buf.begin(), bcast_buf.begin()+4, size_info);
383      std::copy(bcast_buf.begin()+4, bcast_buf.begin()+8, leader_info);
384    }
385
386    if(!is_local_leader)
387    {
388      new_rank_info[1].resize(size_info[2]);
389      ep_info[1].resize(size_info[1]);
390      offset.resize(size_info[0]);
391    }
392
393    bcast_buf.resize(size_info[2]+size_info[1]+size_info[0]+1);
394
395    if(is_local_leader)
396    {
397      bcast_buf[0] = remote_ep_size;
398      std::copy(new_rank_info[1].data(), new_rank_info[1].data()+size_info[2], bcast_buf.begin()+1);
399      std::copy(ep_info[1].data(), ep_info[1].data()+size_info[1], bcast_buf.begin()+size_info[2]+1);
400      std::copy(offset.data(), offset.data()+size_info[0], bcast_buf.begin()+size_info[2]+size_info[1]+1);
401    }
402
403    MPI_Bcast(bcast_buf.data(), size_info[2]+size_info[1]+size_info[0]+1, MPI_INT, local_leader, local_comm);
404
405    if(!is_local_leader)
406    {
407      remote_ep_size = bcast_buf[0];
408      std::copy(bcast_buf.data()+1, bcast_buf.data()+1+size_info[2], new_rank_info[1].begin());
409      std::copy(bcast_buf.data()+1+size_info[2], bcast_buf.data()+1+size_info[2]+size_info[1], ep_info[1].begin());
410      std::copy(bcast_buf.data()+1+size_info[2]+size_info[1], bcast_buf.data()+1+size_info[2]+size_info[1]+size_info[0], offset.begin());
411    }
412
413    int my_position = offset[rank_in_local_parent]+ep_rank_loc;
414   
415    MPI_Barrier_local(local_comm);
416    #pragma omp flush
417
418
419    #pragma omp critical (read_from_tag_list)
420    {
421      bool found = false;
422      while(!found)
423      {
424        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
425        {
426          if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
427          {
428            *newintercomm = iter->second[my_position];
429            found = true;
430            break;
431          }
432        }
433      }
434    }
435
436    MPI_Barrier(local_comm);
437
438    if(is_local_leader)
439    {
440      int local_flag = true;
441      int remote_flag = false;
442      MPI_Request mpi_requests[2];
443      MPI_Status mpi_statuses[2];
444     
445      MPI_Isend(&local_flag, 1, MPI_INT, remote_leader, tag, peer_comm, &mpi_requests[0]);
446      MPI_Irecv(&remote_flag, 1, MPI_INT, remote_leader, tag, peer_comm, &mpi_requests[1]);
447     
448      MPI_Waitall(2, mpi_requests, mpi_statuses);
449    }
450
451
452    MPI_Barrier(local_comm);
453
454    if(is_proc_master)
455    {
456      for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
457      {
458        if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
459        {
460          tag_list.erase(iter);
461          break;
462        }
463      }
464    }
465
466    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
467    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
468
469    intercomm_ep_rank = newintercomm->ep_comm_ptr->size_rank_info[0].first;
470    intercomm_ep_rank_loc = newintercomm->ep_comm_ptr->size_rank_info[1].first;
471    intercomm_mpi_rank = newintercomm->ep_comm_ptr->size_rank_info[2].first;
472    intercomm_ep_size = newintercomm->ep_comm_ptr->size_rank_info[0].second;
473    intercomm_num_ep = newintercomm->ep_comm_ptr->size_rank_info[1].second;
474    intercomm_mpi_size = newintercomm->ep_comm_ptr->size_rank_info[2].second;
475
476    MPI_Bcast(&remote_ep_size, 1, MPI_INT, local_leader, local_comm);
477
478    int my_rank_map_elem[2];
479
480    my_rank_map_elem[0] = intercomm_ep_rank;
481    my_rank_map_elem[1] = (*newintercomm).ep_comm_ptr->comm_label;
482
483    vector<pair<int, int> > local_rank_map_array;
484    vector<pair<int, int> > remote_rank_map_array;
485
486
487    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map = new RANK_MAP;
488    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->resize(local_ep_size);
489
490    MPI_Allgather(my_rank_map_elem, 2, MPI_INT, 
491      (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2, MPI_INT, local_comm);
492
493    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
494    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->resize(remote_ep_size);
495
496    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[0] = local_comm.ep_comm_ptr->size_rank_info[0];
497    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[1] = local_comm.ep_comm_ptr->size_rank_info[1];
498    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[2] = local_comm.ep_comm_ptr->size_rank_info[2];
499
500    int local_intercomm_size = intercomm_ep_size;
501    int remote_intercomm_size;
502
503    int new_bcast_root_0 = 0;
504    int new_bcast_root = 0;
505
506
507    if(is_local_leader)
508    {
509      MPI_Request requests[4];
510      MPI_Status statuses[4];
511     
512      MPI_Isend((*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_ep_size, MPI_INT, remote_leader, tag+4, peer_comm, &requests[0]);
513      MPI_Irecv((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, MPI_INT, remote_leader, tag+4, peer_comm, &requests[1]);
514
515      MPI_Isend(&local_intercomm_size, 1, MPI_INT, remote_leader, tag+5, peer_comm, &requests[2]);
516      MPI_Irecv(&remote_intercomm_size, 1, MPI_INT, remote_leader, tag+5, peer_comm, &requests[3]);
517     
518      MPI_Waitall(4, requests, statuses);
519
520      new_bcast_root_0 = intercomm_ep_rank;
521    }
522
523    MPI_Allreduce(&new_bcast_root_0, &new_bcast_root, 1, MPI_INT, MPI_SUM, *newintercomm);
524
525
526    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, MPI_INT, local_leader, local_comm);
527    MPI_Bcast(&remote_intercomm_size, 1, MPI_INT, new_bcast_root, *newintercomm);
528
529
530    (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map = new RANK_MAP;
531    (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->resize(remote_intercomm_size);
532
533
534
535
536    if(is_local_leader)
537    {
538      MPI_Request requests[2];
539      MPI_Status statuses[2];
540     
541      MPI_Isend((*newintercomm).rank_map->data(), 2*local_intercomm_size, MPI_INT, remote_leader, tag+6, peer_comm, &requests[0]);
542      MPI_Irecv((*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, MPI_INT, remote_leader, tag+6, peer_comm, &requests[1]);
543     
544      MPI_Waitall(2, requests, statuses);
545    }
546
547    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, MPI_INT, new_bcast_root, *newintercomm);
548
549    (*newintercomm).ep_comm_ptr->intercomm->local_comm = &(local_comm.ep_comm_ptr->comm_list[ep_rank_loc]);
550    (*newintercomm).ep_comm_ptr->intercomm->intercomm_tag = tag;
551
552/*
553    for(int i=0; i<local_ep_size; i++)
554    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, local_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
555          (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->at(i).second);
556
557    for(int i=0; i<remote_ep_size; i++)
558    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, remote_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
559          (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->at(i).second);
560
561    for(int i=0; i<remote_intercomm_size; i++)
562    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, intercomm_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
563          (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->at(i).second);
564*/
565
566//    for(int i=0; i<(*newintercomm).rank_map->size(); i++)
567//    if(local_comm.ep_comm_ptr->comm_label != 99) printf("ep_rank = %d, rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
568//          (*newintercomm).rank_map->at(i).first, (*newintercomm).rank_map->at(i).second);
569
570
571    return MPI_SUCCESS;
572
573  }
574
575
576
577
578  int MPI_Intercomm_create_unique_leader(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
579  {
580    //! mpi_size of local comm = 1
581    //! same world rank of leaders
582
583    int ep_rank, ep_rank_loc, mpi_rank;
584    int ep_size, num_ep, mpi_size;
585
586    ep_rank = local_comm.ep_comm_ptr->size_rank_info[0].first;
587    ep_rank_loc = local_comm.ep_comm_ptr->size_rank_info[1].first;
588    mpi_rank = local_comm.ep_comm_ptr->size_rank_info[2].first;
589    ep_size = local_comm.ep_comm_ptr->size_rank_info[0].second;
590    num_ep = local_comm.ep_comm_ptr->size_rank_info[1].second;
591    mpi_size = local_comm.ep_comm_ptr->size_rank_info[2].second;
592
593
594
595    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
596                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
597
598    int rank_in_world;
599
600    int rank_in_peer_mpi[2];
601
602    ::MPI_Comm_rank(to_mpi_comm(MPI_COMM_WORLD.mpi_comm), &rank_in_world);
603
604
605    int local_num_ep = num_ep;
606    int remote_num_ep;
607    int total_num_ep;
608
609    int leader_rank_in_peer[2];
610
611    int my_position;
612    int tag_label[2];
613
614    vector<int> send_buf(4);
615    vector<int> recv_buf(4);
616
617
618    if(ep_rank == local_leader)
619    {
620      MPI_Status status;
621
622      MPI_Comm_rank(peer_comm, &leader_rank_in_peer[0]);
623
624      send_buf[0] = local_num_ep;
625      send_buf[1] = leader_rank_in_peer[0];
626
627      MPI_Request req_s, req_r;
628
629      MPI_Isend(send_buf.data(), 2, MPI_INT, remote_leader, tag, peer_comm, &req_s);
630      MPI_Irecv(recv_buf.data(), 2, MPI_INT, remote_leader, tag, peer_comm, &req_r);
631
632
633      MPI_Wait(&req_s, &status);
634      MPI_Wait(&req_r, &status);
635
636      recv_buf[2] = leader_rank_in_peer[0];
637
638    }
639
640    MPI_Bcast(recv_buf.data(), 3, MPI_INT, local_leader, local_comm);
641
642    remote_num_ep = recv_buf[0];
643    leader_rank_in_peer[1] = recv_buf[1];
644    leader_rank_in_peer[0] = recv_buf[2];
645
646    total_num_ep = local_num_ep + remote_num_ep;
647
648
649    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
650    {
651      my_position = ep_rank_loc;
652      //! LEADER create EP
653      if(ep_rank == local_leader)
654      {
655        ::MPI_Comm *mpi_dup = new ::MPI_Comm;
656       
657        ::MPI_Comm_dup(to_mpi_comm(local_comm.mpi_comm), mpi_dup);
658
659        MPI_Comm *ep_intercomm;
660        MPI_Info info;
661        MPI_Comm_create_endpoints(mpi_dup, total_num_ep, info, ep_intercomm);
662
663
664        for(int i=0; i<total_num_ep; i++)
665        {
666          ep_intercomm[i].is_intercomm = true;
667          ep_intercomm[i].ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
668          ep_intercomm[i].ep_comm_ptr->intercomm->mpi_inter_comm = 0;
669
670          ep_intercomm[i].ep_comm_ptr->comm_label = leader_rank_in_peer[0];
671        }
672
673        tag_label[0] = TAG++;
674        tag_label[1] = rank_in_world;
675
676        #pragma omp critical (write_to_tag_list)
677        tag_list.push_back(make_pair( make_pair(tag_label[0], tag_label[1]) , ep_intercomm));
678
679        MPI_Request req_s;
680        MPI_Status sta_s;
681        MPI_Isend(tag_label, 2, MPI_INT, remote_leader, tag, peer_comm, &req_s);
682
683        MPI_Wait(&req_s, &sta_s);
684
685      }
686    }
687    else
688    {
689      //! Wait for EP creation
690      my_position = remote_num_ep + ep_rank_loc;
691      if(ep_rank == local_leader)
692      {
693        MPI_Status status;
694        MPI_Request req_r;
695        MPI_Irecv(tag_label, 2, MPI_INT, remote_leader, tag, peer_comm, &req_r);
696        MPI_Wait(&req_r, &status);
697      }
698    }
699
700    MPI_Bcast(tag_label, 2, MPI_INT, local_leader, local_comm);
701
702
703
704
705    #pragma omp critical (read_from_tag_list)
706    {
707      bool found = false;
708      while(!found)
709      {
710        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
711        {
712          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
713          {
714            *newintercomm =  iter->second[my_position];
715            found = true;
716            break;
717          }
718        }
719      }
720    }
721
722    MPI_Barrier_local(local_comm);
723
724    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
725    {
726      for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
727        {
728          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
729          {
730            tag_list.erase(iter);
731            break;
732          }
733        }
734    }
735
736
737
738    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
739    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
740
741    intercomm_ep_rank = newintercomm->ep_comm_ptr->size_rank_info[0].first;
742    intercomm_ep_rank_loc = newintercomm->ep_comm_ptr->size_rank_info[1].first;
743    intercomm_mpi_rank = newintercomm->ep_comm_ptr->size_rank_info[2].first;
744    intercomm_ep_size = newintercomm->ep_comm_ptr->size_rank_info[0].second;
745    intercomm_num_ep = newintercomm->ep_comm_ptr->size_rank_info[1].second;
746    intercomm_mpi_size = newintercomm->ep_comm_ptr->size_rank_info[2].second;
747
748
749
750    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map  = new RANK_MAP;
751    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
752    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->resize(local_num_ep);
753    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->resize(remote_num_ep);
754
755    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[0] = local_comm.ep_comm_ptr->size_rank_info[0];
756    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[1] = local_comm.ep_comm_ptr->size_rank_info[1];
757    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[2] = local_comm.ep_comm_ptr->size_rank_info[2];
758
759
760
761    int local_rank_map_ele[2];
762    local_rank_map_ele[0] = intercomm_ep_rank;
763    local_rank_map_ele[1] = (*newintercomm).ep_comm_ptr->comm_label;
764
765    MPI_Allgather(local_rank_map_ele, 2, MPI_INT, 
766      (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2, MPI_INT, local_comm);
767
768    if(ep_rank == local_leader)
769    {
770      MPI_Status status;
771      MPI_Request req_s, req_r;
772
773      MPI_Isend((*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_num_ep, MPI_INT, remote_leader, tag, peer_comm, &req_s);
774      MPI_Irecv((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, MPI_INT, remote_leader, tag, peer_comm, &req_r);
775
776
777      MPI_Wait(&req_s, &status);
778      MPI_Wait(&req_r, &status);
779
780    }
781
782    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, MPI_INT, local_leader, local_comm);
783    (*newintercomm).ep_comm_ptr->intercomm->local_comm = &(local_comm.ep_comm_ptr->comm_list[ep_rank_loc]);
784    (*newintercomm).ep_comm_ptr->intercomm->intercomm_tag = tag;
785
786
787    return MPI_SUCCESS;
788  }
789
790
791}
Note: See TracBrowser for help on using the repository browser.