source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_intercomm_kernel.cpp @ 1304

Last change on this file since 1304 was 1304, checked in by yushan, 4 years ago

change blocking calls to non-blocking calls for intercomm creation. Tested with test_client and test_complete for upto 100 clients

File size: 27.5 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4
5using namespace std;
6
7
8namespace ep_lib
9{
10  int MPI_Intercomm_create_kernel(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
11  {
12    int ep_rank, ep_rank_loc, mpi_rank;
13    int ep_size, num_ep, mpi_size;
14
15    ep_rank = local_comm.ep_comm_ptr->size_rank_info[0].first;
16    ep_rank_loc = local_comm.ep_comm_ptr->size_rank_info[1].first;
17    mpi_rank = local_comm.ep_comm_ptr->size_rank_info[2].first;
18    ep_size = local_comm.ep_comm_ptr->size_rank_info[0].second;
19    num_ep = local_comm.ep_comm_ptr->size_rank_info[1].second;
20    mpi_size = local_comm.ep_comm_ptr->size_rank_info[2].second;
21
22    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
23                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
24
25    int rank_in_world;
26    int rank_in_local_parent;
27
28    int rank_in_peer_mpi[2];
29
30    int local_ep_size = ep_size;
31    int remote_ep_size;
32
33
34    ::MPI_Comm local_mpi_comm = static_cast< ::MPI_Comm>(local_comm.mpi_comm);
35
36   
37    ::MPI_Comm_rank(static_cast< ::MPI_Comm>(MPI_COMM_WORLD.mpi_comm), &rank_in_world);
38    ::MPI_Comm_rank(static_cast< ::MPI_Comm>(local_comm.mpi_comm), &rank_in_local_parent);
39   
40
41    bool is_proc_master = false;
42    bool is_local_leader = false;
43    bool is_final_master = false;
44
45
46    if(ep_rank == local_leader) { is_proc_master = true; is_local_leader = true; is_final_master = true;}
47    if(ep_rank_loc == 0 && mpi_rank != local_comm.rank_map->at(local_leader).second) is_proc_master = true;
48
49
50    int size_info[4]; //! used for choose size of rank_info 0-> mpi_size of local_comm, 1-> mpi_size of remote_comm
51
52    int leader_info[4]; //! 0->world rank of local_leader, 1->world rank of remote leader
53
54
55    std::vector<int> ep_info[2]; //! 0-> num_ep in local_comm, 1->num_ep in remote_comm
56
57    std::vector<int> new_rank_info[4];
58    std::vector<int> new_ep_info[2];
59
60    std::vector<int> offset;
61
62    if(is_proc_master)
63    {
64
65      size_info[0] = mpi_size;
66
67      rank_info[0].resize(size_info[0]);
68      rank_info[1].resize(size_info[0]);
69
70
71
72      ep_info[0].resize(size_info[0]);
73
74      vector<int> send_buf(6);
75      vector<int> recv_buf(3*size_info[0]);
76
77      send_buf[0] = rank_in_world;
78      send_buf[1] = rank_in_local_parent;
79      send_buf[2] = num_ep;
80
81      ::MPI_Allgather(send_buf.data(), 3, static_cast< ::MPI_Datatype> (MPI_INT), recv_buf.data(), 3, static_cast< ::MPI_Datatype> (MPI_INT), local_mpi_comm);
82
83      for(int i=0; i<size_info[0]; i++)
84      {
85        rank_info[0][i] = recv_buf[3*i];
86        rank_info[1][i] = recv_buf[3*i+1];
87        ep_info[0][i]   = recv_buf[3*i+2];
88      }
89
90      if(is_local_leader)
91      {
92        leader_info[0] = rank_in_world;
93        leader_info[1] = remote_leader;
94
95        ::MPI_Comm_rank(static_cast< ::MPI_Comm>(peer_comm.mpi_comm), &rank_in_peer_mpi[0]);
96
97       
98
99        send_buf[0] = size_info[0];
100        send_buf[1] = local_ep_size;
101        send_buf[2] = rank_in_peer_mpi[0];
102
103       
104       
105        MPI_Request requests[2];
106        MPI_Status statuses[2];
107       
108        MPI_Isend(send_buf.data(), 3, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &requests[0]);
109        MPI_Irecv(recv_buf.data(), 3, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &requests[1]);
110
111
112        MPI_Waitall(2, requests, statuses);
113       
114        size_info[1] = recv_buf[0];
115        remote_ep_size = recv_buf[1];
116        rank_in_peer_mpi[1] = recv_buf[2];
117
118      }
119
120
121
122      send_buf[0] = size_info[1];
123      send_buf[1] = leader_info[0];
124      send_buf[2] = leader_info[1];
125      send_buf[3] = rank_in_peer_mpi[0];
126      send_buf[4] = rank_in_peer_mpi[1];
127
128      ::MPI_Bcast(send_buf.data(), 5, static_cast< ::MPI_Datatype> (MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
129
130      size_info[1] = send_buf[0];
131      leader_info[0] = send_buf[1];
132      leader_info[1] = send_buf[2];
133      rank_in_peer_mpi[0] = send_buf[3];
134      rank_in_peer_mpi[1] = send_buf[4];
135
136
137      rank_info[2].resize(size_info[1]);
138      rank_info[3].resize(size_info[1]);
139
140      ep_info[1].resize(size_info[1]);
141
142      send_buf.resize(3*size_info[0]);
143      recv_buf.resize(3*size_info[1]);
144
145      if(is_local_leader)
146      {
147        MPI_Request requests[2];
148        MPI_Status statuses[2];
149
150        std::copy ( rank_info[0].data(), rank_info[0].data() + size_info[0], send_buf.begin() );
151        std::copy ( rank_info[1].data(), rank_info[1].data() + size_info[0], send_buf.begin() + size_info[0] );
152        std::copy ( ep_info[0].data(),   ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[0] );
153
154        MPI_Isend(send_buf.data(), 3*size_info[0], static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+1, peer_comm, &requests[0]);
155        MPI_Irecv(recv_buf.data(), 3*size_info[1], static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+1, peer_comm, &requests[1]);
156       
157        MPI_Waitall(2, requests, statuses);
158      }
159
160      ::MPI_Bcast(recv_buf.data(), 3*size_info[1], static_cast< ::MPI_Datatype> (MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
161
162      std::copy ( recv_buf.data(), recv_buf.data() + size_info[1], rank_info[2].begin() );
163      std::copy ( recv_buf.data() + size_info[1], recv_buf.data() + 2*size_info[1], rank_info[3].begin()  );
164      std::copy ( recv_buf.data() + 2*size_info[1], recv_buf.data() + 3*size_info[1], ep_info[1].begin() );
165
166
167      offset.resize(size_info[0]);
168
169      if(leader_info[0]<leader_info[1]) // erase all ranks doubled with remote_comm, except the local leader
170      {
171
172        bool found = false;
173        int ep_local;
174        int ep_remote;
175        for(int i=0; i<size_info[0]; i++)
176        {
177          int target = rank_info[0][i];
178          found = false;
179          for(int j=0; j<size_info[1]; j++)
180          {
181            if(target == rank_info[2][j])
182            {
183              found = true;
184              ep_local = ep_info[0][j];
185              ep_remote = ep_info[1][j];
186              break;
187            }
188          }
189          if(found)
190          {
191
192            if(target == leader_info[0]) // the leader is doubled in remote
193            {
194              new_rank_info[0].push_back(target);
195              new_rank_info[1].push_back(rank_info[1][i]);
196
197              new_ep_info[0].push_back(ep_local + ep_remote);
198              offset[i] = 0;
199            }
200            else
201            {
202              offset[i] = ep_local;
203            }
204          }
205          else
206          {
207            new_rank_info[0].push_back(target);
208            new_rank_info[1].push_back(rank_info[1][i]);
209
210            new_ep_info[0].push_back(ep_info[0][i]);
211
212            offset[i] = 0;
213          }
214
215        }
216      }
217
218      else // erase rank doubled with remote leader
219      {
220
221        bool found = false;
222        int ep_local;
223        int ep_remote;
224        for(int i=0; i<size_info[0]; i++)
225        {
226          int target = rank_info[0][i];
227          found = false;
228          for(int j=0; j<size_info[1]; j++)
229          {
230
231            if(target == rank_info[2][j])
232            {
233              found = true;
234              ep_local = ep_info[0][j];
235              ep_remote = ep_info[1][j];
236              break;
237            }
238          }
239          if(found)
240          {
241            if(target != leader_info[1])
242            {
243              new_rank_info[0].push_back(target);
244              new_rank_info[1].push_back(rank_info[1][i]);
245
246              new_ep_info[0].push_back(ep_local + ep_remote);
247              offset[i] = 0;
248            }
249            else // found remote leader
250            {
251              offset[i] = ep_remote;
252            }
253          }
254          else
255          {
256            new_rank_info[0].push_back(target);
257            new_rank_info[1].push_back(rank_info[1][i]);
258
259            new_ep_info[0].push_back(ep_info[0][i]);
260            offset[i] = 0;
261          }
262        }
263      }
264
265      if(offset[mpi_rank] == 0)
266      {
267        is_final_master = true;
268      }
269
270
271      //! size_info[4]: 2->size of new_ep_info for local, 3->size of new_ep_info for remote
272
273      if(is_local_leader)
274      {
275        size_info[2] = new_ep_info[0].size();
276        MPI_Request requests[2];
277        MPI_Status statuses[2];
278        MPI_Isend(&size_info[2], 1, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+2, peer_comm, &requests[0]);
279        MPI_Irecv(&size_info[3], 1, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+2, peer_comm, &requests[1]);
280         
281        MPI_Waitall(2, requests, statuses);
282      }
283
284      ::MPI_Bcast(&size_info[2], 2, static_cast< ::MPI_Datatype> (MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
285
286      new_rank_info[2].resize(size_info[3]);
287      new_rank_info[3].resize(size_info[3]);
288      new_ep_info[1].resize(size_info[3]);
289
290      send_buf.resize(size_info[2]);
291      recv_buf.resize(size_info[3]);
292
293      if(is_local_leader)
294      {
295        MPI_Request requests[2];
296        MPI_Status statuses[2];
297
298        std::copy ( new_rank_info[0].data(), new_rank_info[0].data() + size_info[2], send_buf.begin() );
299        std::copy ( new_rank_info[1].data(), new_rank_info[1].data() + size_info[2], send_buf.begin() + size_info[2] );
300        std::copy ( new_ep_info[0].data(),   new_ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[2] );
301
302        MPI_Isend(send_buf.data(), 3*size_info[2], static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+3, peer_comm, &requests[0]);
303        MPI_Irecv(recv_buf.data(), 3*size_info[3], static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+3, peer_comm, &requests[1]);
304       
305        MPI_Waitall(2, requests, statuses);
306      }
307
308      ::MPI_Bcast(recv_buf.data(),   3*size_info[3], static_cast< ::MPI_Datatype> (MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
309
310      std::copy ( recv_buf.data(), recv_buf.data() + size_info[3], new_rank_info[2].begin() );
311      std::copy ( recv_buf.data() + size_info[3], recv_buf.data() + 2*size_info[3], new_rank_info[3].begin()  );
312      std::copy ( recv_buf.data() + 2*size_info[3], recv_buf.data() + 3*size_info[3], new_ep_info[1].begin() );
313
314    }
315
316   
317
318    if(is_proc_master)
319    {
320      //! leader_info[4]: 2-> rank of local leader in new_group generated comm;
321                      // 3-> rank of remote leader in new_group generated comm;
322      ::MPI_Group local_group;
323      ::MPI_Group new_group;
324      ::MPI_Comm new_comm;
325      ::MPI_Comm intercomm;
326
327      ::MPI_Comm_group(local_mpi_comm, &local_group);
328
329      ::MPI_Group_incl(local_group, size_info[2], new_rank_info[1].data(), &new_group);
330
331      ::MPI_Comm_create(local_mpi_comm, new_group, &new_comm);
332
333
334
335      if(is_local_leader)
336      {
337        ::MPI_Comm_rank(new_comm, &leader_info[2]);
338      }
339
340      ::MPI_Bcast(&leader_info[2], 1, static_cast< ::MPI_Datatype> (MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
341
342      if(new_comm != static_cast< ::MPI_Comm>(MPI_COMM_NULL.mpi_comm))
343      {
344
345        ::MPI_Barrier(new_comm);
346
347        ::MPI_Intercomm_create(new_comm, leader_info[2], static_cast< ::MPI_Comm>(peer_comm.mpi_comm), rank_in_peer_mpi[1], tag, &intercomm);
348
349        int id;
350
351        ::MPI_Comm_rank(new_comm, &id);
352        int my_num_ep = new_ep_info[0][id];
353
354        MPI_Comm *ep_intercomm;
355        MPI_Info info;
356        MPI_Comm_create_endpoints(new_comm, my_num_ep, info, ep_intercomm);
357
358
359        for(int i= 0; i<my_num_ep; i++)
360        {
361          ep_intercomm[i].is_intercomm = true;
362
363          ep_intercomm[i].ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
364          ep_intercomm[i].ep_comm_ptr->intercomm->mpi_inter_comm = intercomm;
365          ep_intercomm[i].ep_comm_ptr->comm_label = leader_info[0];
366        }
367
368
369        #pragma omp critical (write_to_tag_list)
370        tag_list.push_back(make_pair( make_pair(tag, min(leader_info[0], leader_info[1])) , ep_intercomm));
371        //printf("tag_list size = %lu\n", tag_list.size());
372      }
373    }
374
375    vector<int> bcast_buf(8);
376    if(is_local_leader)
377    {
378      std::copy(size_info, size_info+4, bcast_buf.begin());
379      std::copy(leader_info, leader_info+4, bcast_buf.begin()+4);
380    }
381
382    MPI_Bcast(bcast_buf.data(), 8, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
383
384    if(!is_local_leader)
385    {
386      std::copy(bcast_buf.begin(), bcast_buf.begin()+4, size_info);
387      std::copy(bcast_buf.begin()+4, bcast_buf.begin()+8, leader_info);
388    }
389
390    if(!is_local_leader)
391    {
392      new_rank_info[1].resize(size_info[2]);
393      ep_info[1].resize(size_info[1]);
394      offset.resize(size_info[0]);
395    }
396
397    bcast_buf.resize(size_info[2]+size_info[1]+size_info[0]+1);
398
399    if(is_local_leader)
400    {
401      bcast_buf[0] = remote_ep_size;
402      std::copy(new_rank_info[1].data(), new_rank_info[1].data()+size_info[2], bcast_buf.begin()+1);
403      std::copy(ep_info[1].data(), ep_info[1].data()+size_info[1], bcast_buf.begin()+size_info[2]+1);
404      std::copy(offset.data(), offset.data()+size_info[0], bcast_buf.begin()+size_info[2]+size_info[1]+1);
405    }
406
407    MPI_Bcast(bcast_buf.data(), size_info[2]+size_info[1]+size_info[0]+1, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
408
409    if(!is_local_leader)
410    {
411      remote_ep_size = bcast_buf[0];
412      std::copy(bcast_buf.data()+1, bcast_buf.data()+1+size_info[2], new_rank_info[1].begin());
413      std::copy(bcast_buf.data()+1+size_info[2], bcast_buf.data()+1+size_info[2]+size_info[1], ep_info[1].begin());
414      std::copy(bcast_buf.data()+1+size_info[2]+size_info[1], bcast_buf.data()+1+size_info[2]+size_info[1]+size_info[0], offset.begin());
415    }
416
417    int my_position = offset[rank_in_local_parent]+ep_rank_loc;
418   
419    MPI_Barrier_local(local_comm);
420    #pragma omp flush
421
422
423    #pragma omp critical (read_from_tag_list)
424    {
425      bool found = false;
426      while(!found)
427      {
428        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
429        {
430          if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
431          {
432            *newintercomm = iter->second[my_position];
433            found = true;
434            break;
435          }
436        }
437      }
438    }
439
440    MPI_Barrier(local_comm);
441
442    if(is_local_leader)
443    {
444      int local_flag = true;
445      int remote_flag = false;
446      MPI_Request mpi_requests[2];
447      MPI_Status mpi_statuses[2];
448     
449      MPI_Isend(&local_flag, 1, MPI_INT, remote_leader, tag, peer_comm, &mpi_requests[0]);
450      MPI_Irecv(&remote_flag, 1, MPI_INT, remote_leader, tag, peer_comm, &mpi_requests[1]);
451     
452      MPI_Waitall(2, mpi_requests, mpi_statuses);
453    }
454
455
456    MPI_Barrier(local_comm);
457
458    if(is_proc_master)
459    {
460      for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
461      {
462        if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
463        {
464          tag_list.erase(iter);
465          break;
466        }
467      }
468    }
469
470    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
471    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
472
473    intercomm_ep_rank = newintercomm->ep_comm_ptr->size_rank_info[0].first;
474    intercomm_ep_rank_loc = newintercomm->ep_comm_ptr->size_rank_info[1].first;
475    intercomm_mpi_rank = newintercomm->ep_comm_ptr->size_rank_info[2].first;
476    intercomm_ep_size = newintercomm->ep_comm_ptr->size_rank_info[0].second;
477    intercomm_num_ep = newintercomm->ep_comm_ptr->size_rank_info[1].second;
478    intercomm_mpi_size = newintercomm->ep_comm_ptr->size_rank_info[2].second;
479
480    MPI_Bcast(&remote_ep_size, 1, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
481
482    int my_rank_map_elem[2];
483
484    my_rank_map_elem[0] = intercomm_ep_rank;
485    my_rank_map_elem[1] = (*newintercomm).ep_comm_ptr->comm_label;
486
487    vector<pair<int, int> > local_rank_map_array;
488    vector<pair<int, int> > remote_rank_map_array;
489
490
491    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map = new RANK_MAP;
492    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->resize(local_ep_size);
493
494    MPI_Allgather(my_rank_map_elem, 2, static_cast< ::MPI_Datatype> (MPI_INT), 
495      (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2, static_cast< ::MPI_Datatype> (MPI_INT), local_comm);
496
497    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
498    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->resize(remote_ep_size);
499
500    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[0] = local_comm.ep_comm_ptr->size_rank_info[0];
501    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[1] = local_comm.ep_comm_ptr->size_rank_info[1];
502    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[2] = local_comm.ep_comm_ptr->size_rank_info[2];
503
504    int local_intercomm_size = intercomm_ep_size;
505    int remote_intercomm_size;
506
507    int new_bcast_root_0 = 0;
508    int new_bcast_root = 0;
509
510
511    if(is_local_leader)
512    {
513      MPI_Request requests[4];
514      MPI_Status statuses[4];
515     
516      MPI_Isend((*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_ep_size, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+4, peer_comm, &requests[0]);
517      MPI_Irecv((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+4, peer_comm, &requests[1]);
518
519      MPI_Isend(&local_intercomm_size, 1, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+5, peer_comm, &requests[2]);
520      MPI_Irecv(&remote_intercomm_size, 1, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+5, peer_comm, &requests[3]);
521     
522      MPI_Waitall(4, requests, statuses);
523
524      new_bcast_root_0 = intercomm_ep_rank;
525    }
526
527    MPI_Allreduce(&new_bcast_root_0, &new_bcast_root, 1, static_cast< ::MPI_Datatype> (MPI_INT), static_cast< ::MPI_Op>(MPI_SUM), *newintercomm);
528
529
530    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
531    MPI_Bcast(&remote_intercomm_size, 1, static_cast< ::MPI_Datatype> (MPI_INT), new_bcast_root, *newintercomm);
532
533
534    (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map = new RANK_MAP;
535    (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->resize(remote_intercomm_size);
536
537
538
539
540    if(is_local_leader)
541    {
542      MPI_Request requests[2];
543      MPI_Status statuses[2];
544     
545      MPI_Isend((*newintercomm).rank_map->data(), 2*local_intercomm_size, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+6, peer_comm, &requests[0]);
546      MPI_Irecv((*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+6, peer_comm, &requests[1]);
547     
548      MPI_Waitall(2, requests, statuses);
549    }
550
551    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, static_cast< ::MPI_Datatype> (MPI_INT), new_bcast_root, *newintercomm);
552
553    (*newintercomm).ep_comm_ptr->intercomm->local_comm = &(local_comm.ep_comm_ptr->comm_list[ep_rank_loc]);
554    (*newintercomm).ep_comm_ptr->intercomm->intercomm_tag = tag;
555
556/*
557    for(int i=0; i<local_ep_size; i++)
558    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, local_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
559          (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->at(i).second);
560
561    for(int i=0; i<remote_ep_size; i++)
562    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, remote_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
563          (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->at(i).second);
564
565    for(int i=0; i<remote_intercomm_size; i++)
566    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, intercomm_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
567          (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->at(i).second);
568*/
569
570//    for(int i=0; i<(*newintercomm).rank_map->size(); i++)
571//    if(local_comm.ep_comm_ptr->comm_label != 99) printf("ep_rank = %d, rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
572//          (*newintercomm).rank_map->at(i).first, (*newintercomm).rank_map->at(i).second);
573
574//    MPI_Comm *test_comm = newintercomm->ep_comm_ptr->intercomm->local_comm;
575//    int test_rank;
576//    MPI_Comm_rank(*test_comm, &test_rank);
577//    printf("=================test_rank = %d\n", test_rank);
578   
579   
580
581    return MPI_SUCCESS;
582
583  }
584
585
586
587
588  int MPI_Intercomm_create_unique_leader(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
589  {
590    //! mpi_size of local comm = 1
591    //! same world rank of leaders
592
593    int ep_rank, ep_rank_loc, mpi_rank;
594    int ep_size, num_ep, mpi_size;
595
596    ep_rank = local_comm.ep_comm_ptr->size_rank_info[0].first;
597    ep_rank_loc = local_comm.ep_comm_ptr->size_rank_info[1].first;
598    mpi_rank = local_comm.ep_comm_ptr->size_rank_info[2].first;
599    ep_size = local_comm.ep_comm_ptr->size_rank_info[0].second;
600    num_ep = local_comm.ep_comm_ptr->size_rank_info[1].second;
601    mpi_size = local_comm.ep_comm_ptr->size_rank_info[2].second;
602
603
604
605    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
606                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
607
608    int rank_in_world;
609
610    int rank_in_peer_mpi[2];
611
612    ::MPI_Comm_rank(static_cast< ::MPI_Comm >(MPI_COMM_WORLD.mpi_comm), &rank_in_world);
613
614
615    int local_num_ep = num_ep;
616    int remote_num_ep;
617    int total_num_ep;
618
619    int leader_rank_in_peer[2];
620
621    int my_position;
622    int tag_label[2];
623
624    vector<int> send_buf(4);
625    vector<int> recv_buf(4);
626
627
628    if(ep_rank == local_leader)
629    {
630      MPI_Status status;
631
632
633
634      MPI_Comm_rank(peer_comm, &leader_rank_in_peer[0]);
635
636      send_buf[0] = local_num_ep;
637      send_buf[1] = leader_rank_in_peer[0];
638
639      MPI_Request req_s, req_r;
640
641      MPI_Isend(send_buf.data(), 2, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_s);
642      MPI_Irecv(recv_buf.data(), 2, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_r);
643
644
645      MPI_Wait(&req_s, &status);
646      MPI_Wait(&req_r, &status);
647
648      recv_buf[2] = leader_rank_in_peer[0];
649
650    }
651
652    MPI_Bcast(recv_buf.data(), 3, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
653
654    remote_num_ep = recv_buf[0];
655    leader_rank_in_peer[1] = recv_buf[1];
656    leader_rank_in_peer[0] = recv_buf[2];
657
658    total_num_ep = local_num_ep + remote_num_ep;
659
660
661    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
662    {
663      my_position = ep_rank_loc;
664      //! LEADER create EP
665      if(ep_rank == local_leader)
666      {
667        ::MPI_Comm mpi_dup;
668       
669        ::MPI_Comm_dup(static_cast< ::MPI_Comm>(local_comm.mpi_comm), &mpi_dup);
670
671        MPI_Comm *ep_intercomm;
672        MPI_Info info;
673        MPI_Comm_create_endpoints(mpi_dup, total_num_ep, info, ep_intercomm);
674
675
676        for(int i=0; i<total_num_ep; i++)
677        {
678          ep_intercomm[i].is_intercomm = true;
679          ep_intercomm[i].ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
680          ep_intercomm[i].ep_comm_ptr->intercomm->mpi_inter_comm = 0;
681
682          ep_intercomm[i].ep_comm_ptr->comm_label = leader_rank_in_peer[0];
683        }
684
685        tag_label[0] = TAG++;
686        tag_label[1] = rank_in_world;
687
688        #pragma omp critical (write_to_tag_list)
689        tag_list.push_back(make_pair( make_pair(tag_label[0], tag_label[1]) , ep_intercomm));
690
691        MPI_Request req_s;
692        MPI_Status sta_s;
693        MPI_Isend(tag_label, 2, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_s);
694
695        MPI_Wait(&req_s, &sta_s);
696
697      }
698    }
699    else
700    {
701      //! Wait for EP creation
702      my_position = remote_num_ep + ep_rank_loc;
703      if(ep_rank == local_leader)
704      {
705        MPI_Status status;
706        MPI_Request req_r;
707        MPI_Irecv(tag_label, 2, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_r);
708        MPI_Wait(&req_r, &status);
709      }
710    }
711
712    MPI_Bcast(tag_label, 2, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
713
714
715
716
717    #pragma omp critical (read_from_tag_list)
718    {
719      bool found = false;
720      while(!found)
721      {
722        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
723        {
724          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
725          {
726            *newintercomm =  iter->second[my_position];
727            found = true;
728            // tag_list.erase(iter);
729            break;
730          }
731        }
732      }
733    }
734
735    MPI_Barrier_local(local_comm);
736
737    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
738    {
739      for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
740        {
741          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
742          {
743            tag_list.erase(iter);
744            break;
745          }
746        }
747    }
748
749
750
751    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
752    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
753
754    intercomm_ep_rank = newintercomm->ep_comm_ptr->size_rank_info[0].first;
755    intercomm_ep_rank_loc = newintercomm->ep_comm_ptr->size_rank_info[1].first;
756    intercomm_mpi_rank = newintercomm->ep_comm_ptr->size_rank_info[2].first;
757    intercomm_ep_size = newintercomm->ep_comm_ptr->size_rank_info[0].second;
758    intercomm_num_ep = newintercomm->ep_comm_ptr->size_rank_info[1].second;
759    intercomm_mpi_size = newintercomm->ep_comm_ptr->size_rank_info[2].second;
760
761
762
763    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map  = new RANK_MAP;
764    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
765    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->resize(local_num_ep);
766    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->resize(remote_num_ep);
767
768    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[0] = local_comm.ep_comm_ptr->size_rank_info[0];
769    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[1] = local_comm.ep_comm_ptr->size_rank_info[1];
770    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[2] = local_comm.ep_comm_ptr->size_rank_info[2];
771
772
773
774    int local_rank_map_ele[2];
775    local_rank_map_ele[0] = intercomm_ep_rank;
776    local_rank_map_ele[1] = (*newintercomm).ep_comm_ptr->comm_label;
777
778    MPI_Allgather(local_rank_map_ele, 2, static_cast< ::MPI_Datatype> (MPI_INT), 
779      (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2, static_cast< ::MPI_Datatype> (MPI_INT), local_comm);
780
781    if(ep_rank == local_leader)
782    {
783      MPI_Status status;
784      MPI_Request req_s, req_r;
785
786      MPI_Isend((*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_num_ep, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_s);
787      MPI_Irecv((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_r);
788
789
790      MPI_Wait(&req_s, &status);
791      MPI_Wait(&req_r, &status);
792
793    }
794
795    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
796    (*newintercomm).ep_comm_ptr->intercomm->local_comm = &(local_comm.ep_comm_ptr->comm_list[ep_rank_loc]);
797    (*newintercomm).ep_comm_ptr->intercomm->intercomm_tag = tag;
798
799
800    return MPI_SUCCESS;
801  }
802
803
804}
Note: See TracBrowser for help on using the repository browser.