source: XIOS/dev/branch_yushan/extern/src_ep_dev/ep_intercomm_kernel.cpp @ 1037

Last change on this file since 1037 was 1037, checked in by yushan, 7 years ago

initialize the branch

File size: 25.9 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4
5using namespace std;
6
7
8namespace ep_lib
9{
10  int MPI_Intercomm_create_kernel(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
11  {
12
13    int ep_rank, ep_rank_loc, mpi_rank;
14    int ep_size, num_ep, mpi_size;
15
16    ep_rank = local_comm.ep_comm_ptr->size_rank_info[0].first;
17    ep_rank_loc = local_comm.ep_comm_ptr->size_rank_info[1].first;
18    mpi_rank = local_comm.ep_comm_ptr->size_rank_info[2].first;
19    ep_size = local_comm.ep_comm_ptr->size_rank_info[0].second;
20    num_ep = local_comm.ep_comm_ptr->size_rank_info[1].second;
21    mpi_size = local_comm.ep_comm_ptr->size_rank_info[2].second;
22
23    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
24                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
25
26    int rank_in_world;
27    int rank_in_local_parent;
28
29    int rank_in_peer_mpi[2];
30
31    int local_ep_size = ep_size;
32    int remote_ep_size;
33
34
35    ::MPI_Comm local_mpi_comm = static_cast< ::MPI_Comm>(local_comm.mpi_comm);
36
37    #ifdef _serialized
38    #pragma omp critical (_mpi_call)
39    #endif // _serialized
40    {
41      ::MPI_Comm_rank(MPI_COMM_WORLD_STD, &rank_in_world);
42      ::MPI_Comm_rank(static_cast< ::MPI_Comm>(local_comm.mpi_comm), &rank_in_local_parent);
43    }
44
45    bool is_proc_master = false;
46    bool is_local_leader = false;
47    bool is_final_master = false;
48
49
50    if(ep_rank == local_leader) { is_proc_master = true; is_local_leader = true; is_final_master = true;}
51    if(ep_rank_loc == 0 && mpi_rank != local_comm.rank_map->at(local_leader).second) is_proc_master = true;
52
53
54    int size_info[4]; //! used for choose size of rank_info 0-> mpi_size of local_comm, 1-> mpi_size of remote_comm
55
56    int leader_info[4]; //! 0->world rank of local_leader, 1->world rank of remote leader
57
58
59    std::vector<int> ep_info[2]; //! 0-> num_ep in local_comm, 1->num_ep in remote_comm
60
61    std::vector<int> new_rank_info[4];
62    std::vector<int> new_ep_info[2];
63
64    std::vector<int> offset;
65
66    if(is_proc_master)
67    {
68
69      size_info[0] = mpi_size;
70
71      rank_info[0].resize(size_info[0]);
72      rank_info[1].resize(size_info[0]);
73
74
75
76      ep_info[0].resize(size_info[0]);
77
78      vector<int> send_buf(6);
79      vector<int> recv_buf(3*size_info[0]);
80
81      send_buf[0] = rank_in_world;
82      send_buf[1] = rank_in_local_parent;
83      send_buf[2] = num_ep;
84
85      #ifdef _serialized
86      #pragma omp critical (_mpi_call)
87      #endif // _serialized
88      ::MPI_Allgather(send_buf.data(), 3, MPI_INT_STD, recv_buf.data(), 3, MPI_INT_STD, local_mpi_comm);
89
90      for(int i=0; i<size_info[0]; i++)
91      {
92        rank_info[0][i] = recv_buf[3*i];
93        rank_info[1][i] = recv_buf[3*i+1];
94        ep_info[0][i]   = recv_buf[3*i+2];
95      }
96
97      if(is_local_leader)
98      {
99        leader_info[0] = rank_in_world;
100        leader_info[1] = remote_leader;
101
102        #ifdef _serialized
103        #pragma omp critical (_mpi_call)
104        #endif // _serialized
105        ::MPI_Comm_rank(static_cast< ::MPI_Comm>(peer_comm.mpi_comm), &rank_in_peer_mpi[0]);
106
107        MPI_Status status;
108
109        send_buf[0] = size_info[0];
110        send_buf[1] = local_ep_size;
111        send_buf[2] = rank_in_peer_mpi[0];
112
113        MPI_Send(send_buf.data(), 3, MPI_INT_STD, remote_leader, tag, peer_comm);
114        MPI_Recv(recv_buf.data(), 3, MPI_INT_STD, remote_leader, tag, peer_comm, &status);
115
116        size_info[1] = recv_buf[0];
117        remote_ep_size = recv_buf[1];
118        rank_in_peer_mpi[1] = recv_buf[2];
119
120      }
121
122      send_buf[0] = size_info[1];
123      send_buf[1] = leader_info[0];
124      send_buf[2] = leader_info[1];
125      send_buf[3] = rank_in_peer_mpi[0];
126      send_buf[4] = rank_in_peer_mpi[1];
127
128      #ifdef _serialized
129      #pragma omp critical (_mpi_call)
130      #endif // _serialized
131      ::MPI_Bcast(send_buf.data(), 5, MPI_INT_STD, local_comm.rank_map->at(local_leader).second, local_mpi_comm);
132
133      size_info[1] = send_buf[0];
134      leader_info[0] = send_buf[1];
135      leader_info[1] = send_buf[2];
136      rank_in_peer_mpi[0] = send_buf[3];
137      rank_in_peer_mpi[1] = send_buf[4];
138
139
140      rank_info[2].resize(size_info[1]);
141      rank_info[3].resize(size_info[1]);
142
143      ep_info[1].resize(size_info[1]);
144
145      send_buf.resize(3*size_info[0]);
146      recv_buf.resize(3*size_info[1]);
147
148      if(is_local_leader)
149      {
150        MPI_Status status;
151
152        std::copy ( rank_info[0].data(), rank_info[0].data() + size_info[0], send_buf.begin() );
153        std::copy ( rank_info[1].data(), rank_info[1].data() + size_info[0], send_buf.begin() + size_info[0] );
154        std::copy ( ep_info[0].data(),   ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[0] );
155
156        MPI_Send(send_buf.data(), 3*size_info[0], MPI_INT_STD, remote_leader, tag, peer_comm);
157        MPI_Recv(recv_buf.data(), 3*size_info[1], MPI_INT_STD, remote_leader, tag, peer_comm, &status);
158
159      }
160
161      #ifdef _serialized
162      #pragma omp critical (_mpi_call)
163      #endif // _serialized
164      ::MPI_Bcast(recv_buf.data(), 3*size_info[1], MPI_INT_STD, local_comm.rank_map->at(local_leader).second, local_mpi_comm);
165
166      std::copy ( recv_buf.data(), recv_buf.data() + size_info[1], rank_info[2].begin() );
167      std::copy ( recv_buf.data() + size_info[1], recv_buf.data() + 2*size_info[1], rank_info[3].begin()  );
168      std::copy ( recv_buf.data() + 2*size_info[1], recv_buf.data() + 3*size_info[1], ep_info[1].begin() );
169
170
171      offset.resize(size_info[0]);
172
173      if(leader_info[0]<leader_info[1]) // erase all ranks doubled with remote_comm, except the local leader
174      {
175
176        bool found = false;
177        int ep_local;
178        int ep_remote;
179        for(int i=0; i<size_info[0]; i++)
180        {
181          int target = rank_info[0][i];
182          found = false;
183          for(int j=0; j<size_info[1]; j++)
184          {
185            if(target == rank_info[2][j])
186            {
187              found = true;
188              ep_local = ep_info[0][j];
189              ep_remote = ep_info[1][j];
190              break;
191            }
192          }
193          if(found)
194          {
195
196            if(target == leader_info[0]) // the leader is doubled in remote
197            {
198              new_rank_info[0].push_back(target);
199              new_rank_info[1].push_back(rank_info[1][i]);
200
201              new_ep_info[0].push_back(ep_local + ep_remote);
202              offset[i] = 0;
203            }
204            else
205            {
206              offset[i] = ep_local;
207            }
208          }
209          else
210          {
211            new_rank_info[0].push_back(target);
212            new_rank_info[1].push_back(rank_info[1][i]);
213
214            new_ep_info[0].push_back(ep_info[0][i]);
215
216            offset[i] = 0;
217          }
218
219        }
220      }
221
222      else // erase rank doubled with remote leader
223      {
224
225        bool found = false;
226        int ep_local;
227        int ep_remote;
228        for(int i=0; i<size_info[0]; i++)
229        {
230          int target = rank_info[0][i];
231          found = false;
232          for(int j=0; j<size_info[1]; j++)
233          {
234
235            if(target == rank_info[2][j])
236            {
237              found = true;
238              ep_local = ep_info[0][j];
239              ep_remote = ep_info[1][j];
240              break;
241            }
242          }
243          if(found)
244          {
245            if(target != leader_info[1])
246            {
247              new_rank_info[0].push_back(target);
248              new_rank_info[1].push_back(rank_info[1][i]);
249
250              new_ep_info[0].push_back(ep_local + ep_remote);
251              offset[i] = 0;
252            }
253            else // found remote leader
254            {
255              offset[i] = ep_remote;
256            }
257          }
258          else
259          {
260            new_rank_info[0].push_back(target);
261            new_rank_info[1].push_back(rank_info[1][i]);
262
263            new_ep_info[0].push_back(ep_info[0][i]);
264            offset[i] = 0;
265          }
266        }
267      }
268
269      if(offset[mpi_rank] == 0)
270      {
271        is_final_master = true;
272      }
273
274
275      //! size_info[4]: 2->size of new_ep_info for local, 3->size of new_ep_info for remote
276
277      if(is_local_leader)
278      {
279        size_info[2] = new_ep_info[0].size();
280        MPI_Status status;
281        MPI_Send(&size_info[2], 1, MPI_INT_STD, remote_leader, tag, peer_comm);
282        MPI_Recv(&size_info[3], 1, MPI_INT_STD, remote_leader, tag, peer_comm, &status);
283      }
284
285      #ifdef _serialized
286      #pragma omp critical (_mpi_call)
287      #endif // _serialized
288      ::MPI_Bcast(&size_info[2], 2, MPI_INT_STD, local_comm.rank_map->at(local_leader).second, local_mpi_comm);
289
290      new_rank_info[2].resize(size_info[3]);
291      new_rank_info[3].resize(size_info[3]);
292      new_ep_info[1].resize(size_info[3]);
293
294      send_buf.resize(size_info[2]);
295      recv_buf.resize(size_info[3]);
296
297      if(is_local_leader)
298      {
299        MPI_Status status;
300
301        std::copy ( new_rank_info[0].data(), new_rank_info[0].data() + size_info[2], send_buf.begin() );
302        std::copy ( new_rank_info[1].data(), new_rank_info[1].data() + size_info[2], send_buf.begin() + size_info[2] );
303        std::copy ( new_ep_info[0].data(),   new_ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[2] );
304
305        MPI_Send(send_buf.data(), 3*size_info[2], MPI_INT_STD, remote_leader, tag, peer_comm);
306        MPI_Recv(recv_buf.data(), 3*size_info[3], MPI_INT_STD, remote_leader, tag, peer_comm, &status);
307      }
308
309
310      #ifdef _serialized
311      #pragma omp critical (_mpi_call)
312      #endif // _serialized
313      ::MPI_Bcast(recv_buf.data(),   3*size_info[3], MPI_INT_STD, local_comm.rank_map->at(local_leader).second, local_mpi_comm);
314
315      std::copy ( recv_buf.data(), recv_buf.data() + size_info[3], new_rank_info[2].begin() );
316      std::copy ( recv_buf.data() + size_info[3], recv_buf.data() + 2*size_info[3], new_rank_info[3].begin()  );
317      std::copy ( recv_buf.data() + 2*size_info[3], recv_buf.data() + 3*size_info[3], new_ep_info[1].begin() );
318
319    }
320
321
322
323    if(is_proc_master)
324    {
325      //! leader_info[4]: 2-> rank of local leader in new_group generated comm;
326                      // 3-> rank of remote leader in new_group generated comm;
327      ::MPI_Group local_group;
328      ::MPI_Group new_group;
329      ::MPI_Comm new_comm;
330      ::MPI_Comm intercomm;
331
332      #ifdef _serialized
333      #pragma omp critical (_mpi_call)
334      #endif // _serialized
335      ::MPI_Comm_group(local_mpi_comm, &local_group);
336
337      #ifdef _serialized
338      #pragma omp critical (_mpi_call)
339      #endif // _serialized
340      ::MPI_Group_incl(local_group, size_info[2], new_rank_info[1].data(), &new_group);
341
342      #ifdef _serialized
343      #pragma omp critical (_mpi_call)
344      #endif // _serialized
345      ::MPI_Comm_create(local_mpi_comm, new_group, &new_comm);
346
347
348
349      if(is_local_leader)
350      {
351        #ifdef _serialized
352        #pragma omp critical (_mpi_call)
353        #endif // _serialized
354        ::MPI_Comm_rank(new_comm, &leader_info[2]);
355      }
356
357      #ifdef _serialized
358      #pragma omp critical (_mpi_call)
359      #endif // _serialized
360      ::MPI_Bcast(&leader_info[2], 1, MPI_INT_STD, local_comm.rank_map->at(local_leader).second, local_mpi_comm);
361
362      if(new_comm != MPI_COMM_NULL_STD)
363      {
364
365        #ifdef _serialized
366        #pragma omp critical (_mpi_call)
367        #endif // _serialized
368        ::MPI_Barrier(new_comm);
369
370        #ifdef _serialized
371        #pragma omp critical (_mpi_call)
372        #endif // _serialized
373        ::MPI_Intercomm_create(new_comm, leader_info[2], static_cast< ::MPI_Comm>(peer_comm.mpi_comm), rank_in_peer_mpi[1], tag, &intercomm);
374
375        int id;
376        #ifdef _serialized
377        #pragma omp critical (_mpi_call)
378        #endif // _serialized
379        ::MPI_Comm_rank(new_comm, &id);
380        int my_num_ep = new_ep_info[0][id];
381
382        MPI_Comm *ep_intercomm;
383        MPI_Info info;
384        MPI_Comm_create_endpoints(new_comm, my_num_ep, info, ep_intercomm);
385
386
387        for(int i= 0; i<my_num_ep; i++)
388        {
389          ep_intercomm[i].is_intercomm = true;
390
391          ep_intercomm[i].ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
392          ep_intercomm[i].ep_comm_ptr->intercomm->mpi_inter_comm = intercomm;
393          ep_intercomm[i].ep_comm_ptr->comm_label = leader_info[0];
394        }
395
396
397        #pragma omp critical (write_to_tag_list)
398        tag_list.push_back(make_pair( make_pair(tag, min(leader_info[0], leader_info[1])) , ep_intercomm));
399      }
400
401
402    }
403
404
405    MPI_Barrier_local(local_comm);
406
407    vector<int> bcast_buf(8);
408    if(is_local_leader)
409    {
410      std::copy(size_info, size_info+4, bcast_buf.begin());
411      std::copy(leader_info, leader_info+4, bcast_buf.begin()+4);
412    }
413
414    MPI_Bcast(bcast_buf.data(), 8, MPI_INT_STD, local_leader, local_comm);
415
416    if(!is_local_leader)
417    {
418      std::copy(bcast_buf.begin(), bcast_buf.begin()+4, size_info);
419      std::copy(bcast_buf.begin()+4, bcast_buf.begin()+8, leader_info);
420    }
421
422    if(!is_local_leader)
423    {
424      new_rank_info[1].resize(size_info[2]);
425      ep_info[1].resize(size_info[1]);
426      offset.resize(size_info[0]);
427    }
428
429    bcast_buf.resize(size_info[2]+size_info[1]+size_info[0]+1);
430
431    if(is_local_leader)
432    {
433      bcast_buf[0] = remote_ep_size;
434      std::copy(new_rank_info[1].data(), new_rank_info[1].data()+size_info[2], bcast_buf.begin()+1);
435      std::copy(ep_info[1].data(), ep_info[1].data()+size_info[1], bcast_buf.begin()+size_info[2]+1);
436      std::copy(offset.data(), offset.data()+size_info[0], bcast_buf.begin()+size_info[2]+size_info[1]+1);
437    }
438
439    MPI_Bcast(bcast_buf.data(), size_info[2]+size_info[1]+size_info[0]+1, MPI_INT_STD, local_leader, local_comm);
440
441    if(!is_local_leader)
442    {
443      remote_ep_size = bcast_buf[0];
444      std::copy(bcast_buf.data()+1, bcast_buf.data()+1+size_info[2], new_rank_info[1].begin());
445      std::copy(bcast_buf.data()+1+size_info[2], bcast_buf.data()+1+size_info[2]+size_info[1], ep_info[1].begin());
446      std::copy(bcast_buf.data()+1+size_info[2]+size_info[1], bcast_buf.data()+1+size_info[2]+size_info[1]+size_info[0], offset.begin());
447    }
448
449    int my_position = offset[rank_in_local_parent]+ep_rank_loc;
450
451
452    MPI_Barrier_local(local_comm);
453    #pragma omp flush
454
455
456    #pragma omp critical (read_from_tag_list)
457    {
458      bool found = false;
459      while(!found)
460      {
461        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
462        {
463          if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
464          {
465            *newintercomm =  iter->second[my_position];
466            found = true;
467            break;
468          }
469        }
470      }
471    }
472
473    MPI_Barrier_local(local_comm);
474
475    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
476    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
477
478    intercomm_ep_rank = newintercomm->ep_comm_ptr->size_rank_info[0].first;
479    intercomm_ep_rank_loc = newintercomm->ep_comm_ptr->size_rank_info[1].first;
480    intercomm_mpi_rank = newintercomm->ep_comm_ptr->size_rank_info[2].first;
481    intercomm_ep_size = newintercomm->ep_comm_ptr->size_rank_info[0].second;
482    intercomm_num_ep = newintercomm->ep_comm_ptr->size_rank_info[1].second;
483    intercomm_mpi_size = newintercomm->ep_comm_ptr->size_rank_info[2].second;
484
485    MPI_Bcast(&remote_ep_size, 1, MPI_INT_STD, local_leader, local_comm);
486
487    int my_rank_map_elem[2];
488
489    my_rank_map_elem[0] = intercomm_ep_rank;
490    my_rank_map_elem[1] = (*newintercomm).ep_comm_ptr->comm_label;
491
492    vector<pair<int, int> > local_rank_map_array;
493    vector<pair<int, int> > remote_rank_map_array;
494
495
496    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map = new RANK_MAP;
497    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->resize(local_ep_size);
498
499    MPI_Allgather(my_rank_map_elem, 2, MPI_INT_STD, (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2, MPI_INT_STD, local_comm);
500
501    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
502    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->resize(remote_ep_size);
503
504    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[0] = local_comm.ep_comm_ptr->size_rank_info[0];
505    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[1] = local_comm.ep_comm_ptr->size_rank_info[1];
506    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[2] = local_comm.ep_comm_ptr->size_rank_info[2];
507
508    int local_intercomm_size = intercomm_ep_size;
509    int remote_intercomm_size;
510
511    int new_bcast_root_0 = 0;
512    int new_bcast_root = 0;
513
514
515    if(is_local_leader)
516    {
517      MPI_Status status;
518      MPI_Send((*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_ep_size, MPI_INT_STD, remote_leader, tag, peer_comm);
519      MPI_Recv((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, MPI_INT_STD, remote_leader, tag, peer_comm, &status);
520
521      MPI_Send(&local_intercomm_size, 1, MPI_INT_STD, remote_leader, tag, peer_comm);
522      MPI_Recv(&remote_intercomm_size, 1, MPI_INT_STD, remote_leader, tag, peer_comm, &status);
523
524      new_bcast_root_0 = intercomm_ep_rank;
525    }
526
527    MPI_Allreduce(&new_bcast_root_0, &new_bcast_root, 1, MPI_INT_STD, MPI_SUM_STD, *newintercomm);
528
529
530    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, MPI_INT_STD, local_leader, local_comm);
531    MPI_Bcast(&remote_intercomm_size, 1, MPI_INT_STD, new_bcast_root, *newintercomm);
532
533
534    (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map = new RANK_MAP;
535    (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->resize(remote_intercomm_size);
536
537
538
539
540    if(is_local_leader)
541    {
542      MPI_Status status;
543      MPI_Send((*newintercomm).rank_map->data(), 2*local_intercomm_size, MPI_INT_STD, remote_leader, tag, peer_comm);
544      MPI_Recv((*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, MPI_INT_STD, remote_leader, tag, peer_comm, &status);
545    }
546
547    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, MPI_INT_STD, new_bcast_root, *newintercomm);
548
549    (*newintercomm).ep_comm_ptr->intercomm->local_comm = &(local_comm.ep_comm_ptr->comm_list[ep_rank_loc]);
550    (*newintercomm).ep_comm_ptr->intercomm->intercomm_tag = tag;
551
552/*
553    for(int i=0; i<local_ep_size; i++)
554    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, local_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
555          (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->at(i).second);
556
557    for(int i=0; i<remote_ep_size; i++)
558    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, remote_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
559          (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->at(i).second);
560
561    for(int i=0; i<remote_intercomm_size; i++)
562    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, intercomm_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
563          (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->at(i).second);
564*/
565
566//    for(int i=0; i<(*newintercomm).rank_map->size(); i++)
567//    if(local_comm.ep_comm_ptr->comm_label != 99) printf("ep_rank = %d, rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
568//          (*newintercomm).rank_map->at(i).first, (*newintercomm).rank_map->at(i).second);
569
570//    MPI_Comm *test_comm = newintercomm->ep_comm_ptr->intercomm->local_comm;
571//    int test_rank;
572//    MPI_Comm_rank(*test_comm, &test_rank);
573//    printf("=================test_rank = %d\n", test_rank);
574
575    return MPI_SUCCESS;
576
577  }
578
579
580
581
582  int MPI_Intercomm_create_unique_leader(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
583  {
584    //! mpi_size of local comm = 1
585    //! same world rank of leaders
586
587    int ep_rank, ep_rank_loc, mpi_rank;
588    int ep_size, num_ep, mpi_size;
589
590    ep_rank = local_comm.ep_comm_ptr->size_rank_info[0].first;
591    ep_rank_loc = local_comm.ep_comm_ptr->size_rank_info[1].first;
592    mpi_rank = local_comm.ep_comm_ptr->size_rank_info[2].first;
593    ep_size = local_comm.ep_comm_ptr->size_rank_info[0].second;
594    num_ep = local_comm.ep_comm_ptr->size_rank_info[1].second;
595    mpi_size = local_comm.ep_comm_ptr->size_rank_info[2].second;
596
597
598
599    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
600                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
601
602    int rank_in_world;
603
604    int rank_in_peer_mpi[2];
605
606    #ifdef _serialized
607    #pragma omp critical (_mpi_call)
608    #endif // _serialized
609    ::MPI_Comm_rank(MPI_COMM_WORLD_STD, &rank_in_world);
610
611
612    int local_num_ep = num_ep;
613    int remote_num_ep;
614    int total_num_ep;
615
616    int leader_rank_in_peer[2];
617
618    int my_position;
619    int tag_label[2];
620
621    vector<int> send_buf(4);
622    vector<int> recv_buf(4);
623
624
625    if(ep_rank == local_leader)
626    {
627      MPI_Status status;
628
629
630
631      MPI_Comm_rank(peer_comm, &leader_rank_in_peer[0]);
632
633      send_buf[0] = local_num_ep;
634      send_buf[1] = leader_rank_in_peer[0];
635
636      MPI_Request req_s, req_r;
637
638      MPI_Isend(send_buf.data(), 2, MPI_INT_STD, remote_leader, tag, peer_comm, &req_s);
639      MPI_Irecv(recv_buf.data(), 2, MPI_INT_STD, remote_leader, tag, peer_comm, &req_r);
640
641
642      MPI_Wait(&req_s, &status);
643      MPI_Wait(&req_r, &status);
644
645      recv_buf[2] = leader_rank_in_peer[0];
646
647    }
648
649    MPI_Bcast(recv_buf.data(), 3, MPI_INT_STD, local_leader, local_comm);
650
651    remote_num_ep = recv_buf[0];
652    leader_rank_in_peer[1] = recv_buf[1];
653    leader_rank_in_peer[0] = recv_buf[2];
654
655    total_num_ep = local_num_ep + remote_num_ep;
656
657
658    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
659    {
660      my_position = ep_rank_loc;
661      //! LEADER create EP
662      if(ep_rank == local_leader)
663      {
664        ::MPI_Comm mpi_dup;
665
666        #ifdef _serialized
667        #pragma omp critical (_mpi_call)
668        #endif // _serialized
669        ::MPI_Comm_dup(static_cast< ::MPI_Comm>(local_comm.mpi_comm), &mpi_dup);
670
671        MPI_Comm *ep_intercomm;
672        MPI_Info info;
673        MPI_Comm_create_endpoints(mpi_dup, total_num_ep, info, ep_intercomm);
674
675
676        for(int i=0; i<total_num_ep; i++)
677        {
678          ep_intercomm[i].is_intercomm = true;
679          ep_intercomm[i].ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
680          ep_intercomm[i].ep_comm_ptr->intercomm->mpi_inter_comm = 0;
681
682          ep_intercomm[i].ep_comm_ptr->comm_label = leader_rank_in_peer[0];
683        }
684
685        tag_label[0] = TAG++;
686        tag_label[1] = rank_in_world;
687
688        #pragma omp critical (write_to_tag_list)
689        tag_list.push_back(make_pair( make_pair(tag_label[0], tag_label[1]) , ep_intercomm));
690
691        MPI_Request req_s;
692        MPI_Status sta_s;
693        MPI_Isend(tag_label, 2, MPI_INT_STD, remote_leader, tag, peer_comm, &req_s);
694
695        MPI_Wait(&req_s, &sta_s);
696
697      }
698    }
699    else
700    {
701      //! Wait for EP creation
702      my_position = remote_num_ep + ep_rank_loc;
703      if(ep_rank == local_leader)
704      {
705        MPI_Status status;
706        MPI_Request req_r;
707        MPI_Irecv(tag_label, 2, MPI_INT_STD, remote_leader, tag, peer_comm, &req_r);
708        MPI_Wait(&req_r, &status);
709      }
710    }
711
712    MPI_Bcast(tag_label, 2, MPI_INT_STD, local_leader, local_comm);
713
714
715
716
717    #pragma omp critical (read_from_tag_list)
718    {
719      bool found = false;
720      while(!found)
721      {
722        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
723        {
724          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
725          {
726            *newintercomm =  iter->second[my_position];
727            found = true;
728            break;
729          }
730        }
731      }
732    }
733
734
735
736    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
737    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
738
739    intercomm_ep_rank = newintercomm->ep_comm_ptr->size_rank_info[0].first;
740    intercomm_ep_rank_loc = newintercomm->ep_comm_ptr->size_rank_info[1].first;
741    intercomm_mpi_rank = newintercomm->ep_comm_ptr->size_rank_info[2].first;
742    intercomm_ep_size = newintercomm->ep_comm_ptr->size_rank_info[0].second;
743    intercomm_num_ep = newintercomm->ep_comm_ptr->size_rank_info[1].second;
744    intercomm_mpi_size = newintercomm->ep_comm_ptr->size_rank_info[2].second;
745
746
747
748    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map  = new RANK_MAP;
749    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
750    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->resize(local_num_ep);
751    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->resize(remote_num_ep);
752
753    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[0] = local_comm.ep_comm_ptr->size_rank_info[0];
754    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[1] = local_comm.ep_comm_ptr->size_rank_info[1];
755    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[2] = local_comm.ep_comm_ptr->size_rank_info[2];
756
757
758
759    int local_rank_map_ele[2];
760    local_rank_map_ele[0] = intercomm_ep_rank;
761    local_rank_map_ele[1] = (*newintercomm).ep_comm_ptr->comm_label;
762
763    MPI_Allgather(local_rank_map_ele, 2, MPI_INT_STD, (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2, MPI_INT_STD, local_comm);
764
765    if(ep_rank == local_leader)
766    {
767      MPI_Status status;
768      MPI_Request req_s, req_r;
769
770      MPI_Isend((*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_num_ep, MPI_INT_STD, remote_leader, tag, peer_comm, &req_s);
771      MPI_Irecv((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, MPI_INT_STD, remote_leader, tag, peer_comm, &req_r);
772
773
774      MPI_Wait(&req_s, &status);
775      MPI_Wait(&req_r, &status);
776
777    }
778
779    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, MPI_INT_STD, local_leader, local_comm);
780    (*newintercomm).ep_comm_ptr->intercomm->local_comm = &(local_comm.ep_comm_ptr->comm_list[ep_rank_loc]);
781    (*newintercomm).ep_comm_ptr->intercomm->intercomm_tag = tag;
782
783
784    return MPI_SUCCESS;
785  }
786
787
788}
Note: See TracBrowser for help on using the repository browser.