source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_intercomm_kernel.cpp @ 1287

Last change on this file since 1287 was 1287, checked in by yushan, 4 years ago

EP updated

File size: 26.8 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4
5using namespace std;
6
7
8namespace ep_lib
9{
10  int MPI_Intercomm_create_kernel(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
11  {
12    int ep_rank, ep_rank_loc, mpi_rank;
13    int ep_size, num_ep, mpi_size;
14
15    ep_rank = local_comm.ep_comm_ptr->size_rank_info[0].first;
16    ep_rank_loc = local_comm.ep_comm_ptr->size_rank_info[1].first;
17    mpi_rank = local_comm.ep_comm_ptr->size_rank_info[2].first;
18    ep_size = local_comm.ep_comm_ptr->size_rank_info[0].second;
19    num_ep = local_comm.ep_comm_ptr->size_rank_info[1].second;
20    mpi_size = local_comm.ep_comm_ptr->size_rank_info[2].second;
21
22    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
23                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
24
25    int rank_in_world;
26    int rank_in_local_parent;
27
28    int rank_in_peer_mpi[2];
29
30    int local_ep_size = ep_size;
31    int remote_ep_size;
32
33
34    ::MPI_Comm local_mpi_comm = static_cast< ::MPI_Comm>(local_comm.mpi_comm);
35
36   
37    ::MPI_Comm_rank(static_cast< ::MPI_Comm>(MPI_COMM_WORLD.mpi_comm), &rank_in_world);
38    ::MPI_Comm_rank(static_cast< ::MPI_Comm>(local_comm.mpi_comm), &rank_in_local_parent);
39   
40
41    bool is_proc_master = false;
42    bool is_local_leader = false;
43    bool is_final_master = false;
44
45
46    if(ep_rank == local_leader) { is_proc_master = true; is_local_leader = true; is_final_master = true;}
47    if(ep_rank_loc == 0 && mpi_rank != local_comm.rank_map->at(local_leader).second) is_proc_master = true;
48
49
50    int size_info[4]; //! used for choose size of rank_info 0-> mpi_size of local_comm, 1-> mpi_size of remote_comm
51
52    int leader_info[4]; //! 0->world rank of local_leader, 1->world rank of remote leader
53
54
55    std::vector<int> ep_info[2]; //! 0-> num_ep in local_comm, 1->num_ep in remote_comm
56
57    std::vector<int> new_rank_info[4];
58    std::vector<int> new_ep_info[2];
59
60    std::vector<int> offset;
61
62    if(is_proc_master)
63    {
64
65      size_info[0] = mpi_size;
66
67      rank_info[0].resize(size_info[0]);
68      rank_info[1].resize(size_info[0]);
69
70
71
72      ep_info[0].resize(size_info[0]);
73
74      vector<int> send_buf(6);
75      vector<int> recv_buf(3*size_info[0]);
76
77      send_buf[0] = rank_in_world;
78      send_buf[1] = rank_in_local_parent;
79      send_buf[2] = num_ep;
80
81      ::MPI_Allgather(send_buf.data(), 3, static_cast< ::MPI_Datatype> (MPI_INT), recv_buf.data(), 3, static_cast< ::MPI_Datatype> (MPI_INT), local_mpi_comm);
82
83      for(int i=0; i<size_info[0]; i++)
84      {
85        rank_info[0][i] = recv_buf[3*i];
86        rank_info[1][i] = recv_buf[3*i+1];
87        ep_info[0][i]   = recv_buf[3*i+2];
88      }
89
90      if(is_local_leader)
91      {
92        leader_info[0] = rank_in_world;
93        leader_info[1] = remote_leader;
94
95        ::MPI_Comm_rank(static_cast< ::MPI_Comm>(peer_comm.mpi_comm), &rank_in_peer_mpi[0]);
96
97       
98
99        send_buf[0] = size_info[0];
100        send_buf[1] = local_ep_size;
101        send_buf[2] = rank_in_peer_mpi[0];
102
103        MPI_Request req_send, req_recv;
104        MPI_Status sta_send, sta_recv;
105       
106        MPI_Isend(send_buf.data(), 3, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_send);
107        MPI_Irecv(recv_buf.data(), 3, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_recv);
108
109
110        MPI_Wait(&req_send, &sta_send);
111        MPI_Wait(&req_recv, &sta_recv);
112       
113        size_info[1] = recv_buf[0];
114        remote_ep_size = recv_buf[1];
115        rank_in_peer_mpi[1] = recv_buf[2];
116
117      }
118
119
120
121      send_buf[0] = size_info[1];
122      send_buf[1] = leader_info[0];
123      send_buf[2] = leader_info[1];
124      send_buf[3] = rank_in_peer_mpi[0];
125      send_buf[4] = rank_in_peer_mpi[1];
126
127      ::MPI_Bcast(send_buf.data(), 5, static_cast< ::MPI_Datatype> (MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
128
129      size_info[1] = send_buf[0];
130      leader_info[0] = send_buf[1];
131      leader_info[1] = send_buf[2];
132      rank_in_peer_mpi[0] = send_buf[3];
133      rank_in_peer_mpi[1] = send_buf[4];
134
135
136      rank_info[2].resize(size_info[1]);
137      rank_info[3].resize(size_info[1]);
138
139      ep_info[1].resize(size_info[1]);
140
141      send_buf.resize(3*size_info[0]);
142      recv_buf.resize(3*size_info[1]);
143
144      if(is_local_leader)
145      {
146        MPI_Status status;
147
148        std::copy ( rank_info[0].data(), rank_info[0].data() + size_info[0], send_buf.begin() );
149        std::copy ( rank_info[1].data(), rank_info[1].data() + size_info[0], send_buf.begin() + size_info[0] );
150        std::copy ( ep_info[0].data(),   ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[0] );
151
152        MPI_Send(send_buf.data(), 3*size_info[0], static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+1, peer_comm);
153        MPI_Recv(recv_buf.data(), 3*size_info[1], static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+1, peer_comm, &status);
154      }
155
156      ::MPI_Bcast(recv_buf.data(), 3*size_info[1], static_cast< ::MPI_Datatype> (MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
157
158      std::copy ( recv_buf.data(), recv_buf.data() + size_info[1], rank_info[2].begin() );
159      std::copy ( recv_buf.data() + size_info[1], recv_buf.data() + 2*size_info[1], rank_info[3].begin()  );
160      std::copy ( recv_buf.data() + 2*size_info[1], recv_buf.data() + 3*size_info[1], ep_info[1].begin() );
161
162
163      offset.resize(size_info[0]);
164
165      if(leader_info[0]<leader_info[1]) // erase all ranks doubled with remote_comm, except the local leader
166      {
167
168        bool found = false;
169        int ep_local;
170        int ep_remote;
171        for(int i=0; i<size_info[0]; i++)
172        {
173          int target = rank_info[0][i];
174          found = false;
175          for(int j=0; j<size_info[1]; j++)
176          {
177            if(target == rank_info[2][j])
178            {
179              found = true;
180              ep_local = ep_info[0][j];
181              ep_remote = ep_info[1][j];
182              break;
183            }
184          }
185          if(found)
186          {
187
188            if(target == leader_info[0]) // the leader is doubled in remote
189            {
190              new_rank_info[0].push_back(target);
191              new_rank_info[1].push_back(rank_info[1][i]);
192
193              new_ep_info[0].push_back(ep_local + ep_remote);
194              offset[i] = 0;
195            }
196            else
197            {
198              offset[i] = ep_local;
199            }
200          }
201          else
202          {
203            new_rank_info[0].push_back(target);
204            new_rank_info[1].push_back(rank_info[1][i]);
205
206            new_ep_info[0].push_back(ep_info[0][i]);
207
208            offset[i] = 0;
209          }
210
211        }
212      }
213
214      else // erase rank doubled with remote leader
215      {
216
217        bool found = false;
218        int ep_local;
219        int ep_remote;
220        for(int i=0; i<size_info[0]; i++)
221        {
222          int target = rank_info[0][i];
223          found = false;
224          for(int j=0; j<size_info[1]; j++)
225          {
226
227            if(target == rank_info[2][j])
228            {
229              found = true;
230              ep_local = ep_info[0][j];
231              ep_remote = ep_info[1][j];
232              break;
233            }
234          }
235          if(found)
236          {
237            if(target != leader_info[1])
238            {
239              new_rank_info[0].push_back(target);
240              new_rank_info[1].push_back(rank_info[1][i]);
241
242              new_ep_info[0].push_back(ep_local + ep_remote);
243              offset[i] = 0;
244            }
245            else // found remote leader
246            {
247              offset[i] = ep_remote;
248            }
249          }
250          else
251          {
252            new_rank_info[0].push_back(target);
253            new_rank_info[1].push_back(rank_info[1][i]);
254
255            new_ep_info[0].push_back(ep_info[0][i]);
256            offset[i] = 0;
257          }
258        }
259      }
260
261      if(offset[mpi_rank] == 0)
262      {
263        is_final_master = true;
264      }
265
266
267      //! size_info[4]: 2->size of new_ep_info for local, 3->size of new_ep_info for remote
268
269      if(is_local_leader)
270      {
271        size_info[2] = new_ep_info[0].size();
272        MPI_Status status;
273        MPI_Send(&size_info[2], 1, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+2, peer_comm);
274        MPI_Recv(&size_info[3], 1, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+2, peer_comm, &status);
275      }
276
277      ::MPI_Bcast(&size_info[2], 2, static_cast< ::MPI_Datatype> (MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
278
279      new_rank_info[2].resize(size_info[3]);
280      new_rank_info[3].resize(size_info[3]);
281      new_ep_info[1].resize(size_info[3]);
282
283      send_buf.resize(size_info[2]);
284      recv_buf.resize(size_info[3]);
285
286      if(is_local_leader)
287      {
288        MPI_Status status;
289
290        std::copy ( new_rank_info[0].data(), new_rank_info[0].data() + size_info[2], send_buf.begin() );
291        std::copy ( new_rank_info[1].data(), new_rank_info[1].data() + size_info[2], send_buf.begin() + size_info[2] );
292        std::copy ( new_ep_info[0].data(),   new_ep_info[0].data()   + size_info[0], send_buf.begin() + 2*size_info[2] );
293
294        MPI_Send(send_buf.data(), 3*size_info[2], static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+3, peer_comm);
295        MPI_Recv(recv_buf.data(), 3*size_info[3], static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+3, peer_comm, &status);
296      }
297
298      ::MPI_Bcast(recv_buf.data(),   3*size_info[3], static_cast< ::MPI_Datatype> (MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
299
300      std::copy ( recv_buf.data(), recv_buf.data() + size_info[3], new_rank_info[2].begin() );
301      std::copy ( recv_buf.data() + size_info[3], recv_buf.data() + 2*size_info[3], new_rank_info[3].begin()  );
302      std::copy ( recv_buf.data() + 2*size_info[3], recv_buf.data() + 3*size_info[3], new_ep_info[1].begin() );
303
304    }
305
306   
307
308    if(is_proc_master)
309    {
310      //! leader_info[4]: 2-> rank of local leader in new_group generated comm;
311                      // 3-> rank of remote leader in new_group generated comm;
312      ::MPI_Group local_group;
313      ::MPI_Group new_group;
314      ::MPI_Comm new_comm;
315      ::MPI_Comm intercomm;
316
317      ::MPI_Comm_group(local_mpi_comm, &local_group);
318
319      ::MPI_Group_incl(local_group, size_info[2], new_rank_info[1].data(), &new_group);
320
321      ::MPI_Comm_create(local_mpi_comm, new_group, &new_comm);
322
323
324
325      if(is_local_leader)
326      {
327        ::MPI_Comm_rank(new_comm, &leader_info[2]);
328      }
329
330      ::MPI_Bcast(&leader_info[2], 1, static_cast< ::MPI_Datatype> (MPI_INT), local_comm.rank_map->at(local_leader).second, local_mpi_comm);
331
332      if(new_comm != static_cast< ::MPI_Comm>(MPI_COMM_NULL.mpi_comm))
333      {
334
335        ::MPI_Barrier(new_comm);
336
337        ::MPI_Intercomm_create(new_comm, leader_info[2], static_cast< ::MPI_Comm>(peer_comm.mpi_comm), rank_in_peer_mpi[1], tag, &intercomm);
338
339        int id;
340
341        ::MPI_Comm_rank(new_comm, &id);
342        int my_num_ep = new_ep_info[0][id];
343
344        MPI_Comm *ep_intercomm;
345        MPI_Info info;
346        MPI_Comm_create_endpoints(new_comm, my_num_ep, info, ep_intercomm);
347
348
349        for(int i= 0; i<my_num_ep; i++)
350        {
351          ep_intercomm[i].is_intercomm = true;
352
353          ep_intercomm[i].ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
354          ep_intercomm[i].ep_comm_ptr->intercomm->mpi_inter_comm = intercomm;
355          ep_intercomm[i].ep_comm_ptr->comm_label = leader_info[0];
356        }
357
358
359        #pragma omp critical (write_to_tag_list)
360        tag_list.push_back(make_pair( make_pair(tag, min(leader_info[0], leader_info[1])) , ep_intercomm));
361        //printf("tag_list size = %lu\n", tag_list.size());
362      }
363    }
364
365    vector<int> bcast_buf(8);
366    if(is_local_leader)
367    {
368      std::copy(size_info, size_info+4, bcast_buf.begin());
369      std::copy(leader_info, leader_info+4, bcast_buf.begin()+4);
370    }
371
372    MPI_Bcast(bcast_buf.data(), 8, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
373
374    if(!is_local_leader)
375    {
376      std::copy(bcast_buf.begin(), bcast_buf.begin()+4, size_info);
377      std::copy(bcast_buf.begin()+4, bcast_buf.begin()+8, leader_info);
378    }
379
380    if(!is_local_leader)
381    {
382      new_rank_info[1].resize(size_info[2]);
383      ep_info[1].resize(size_info[1]);
384      offset.resize(size_info[0]);
385    }
386
387    bcast_buf.resize(size_info[2]+size_info[1]+size_info[0]+1);
388
389    if(is_local_leader)
390    {
391      bcast_buf[0] = remote_ep_size;
392      std::copy(new_rank_info[1].data(), new_rank_info[1].data()+size_info[2], bcast_buf.begin()+1);
393      std::copy(ep_info[1].data(), ep_info[1].data()+size_info[1], bcast_buf.begin()+size_info[2]+1);
394      std::copy(offset.data(), offset.data()+size_info[0], bcast_buf.begin()+size_info[2]+size_info[1]+1);
395    }
396
397    MPI_Bcast(bcast_buf.data(), size_info[2]+size_info[1]+size_info[0]+1, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
398
399    if(!is_local_leader)
400    {
401      remote_ep_size = bcast_buf[0];
402      std::copy(bcast_buf.data()+1, bcast_buf.data()+1+size_info[2], new_rank_info[1].begin());
403      std::copy(bcast_buf.data()+1+size_info[2], bcast_buf.data()+1+size_info[2]+size_info[1], ep_info[1].begin());
404      std::copy(bcast_buf.data()+1+size_info[2]+size_info[1], bcast_buf.data()+1+size_info[2]+size_info[1]+size_info[0], offset.begin());
405    }
406
407    int my_position = offset[rank_in_local_parent]+ep_rank_loc;
408   
409    MPI_Barrier_local(local_comm);
410    #pragma omp flush
411
412
413    #pragma omp critical (read_from_tag_list)
414    {
415      bool found = false;
416      while(!found)
417      {
418        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
419        {
420          if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
421          {
422            *newintercomm = iter->second[my_position];
423            found = true;
424            break;
425          }
426        }
427      }
428    }
429
430    MPI_Barrier(local_comm);
431
432    if(is_local_leader)
433    {
434      int local_flag = true;
435      int remote_flag = false;
436      MPI_Status mpi_status;
437     
438      MPI_Send(&local_flag, 1, MPI_INT, remote_leader, tag, peer_comm);
439
440      MPI_Recv(&remote_flag, 1, MPI_INT, remote_leader, tag, peer_comm, &mpi_status);
441    }
442
443
444    MPI_Barrier(local_comm);
445
446    if(is_proc_master)
447    {
448      for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
449      {
450        if((*iter).first == make_pair(tag, min(leader_info[0], leader_info[1])))
451        {
452          tag_list.erase(iter);
453          break;
454        }
455      }
456    }
457
458    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
459    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
460
461    intercomm_ep_rank = newintercomm->ep_comm_ptr->size_rank_info[0].first;
462    intercomm_ep_rank_loc = newintercomm->ep_comm_ptr->size_rank_info[1].first;
463    intercomm_mpi_rank = newintercomm->ep_comm_ptr->size_rank_info[2].first;
464    intercomm_ep_size = newintercomm->ep_comm_ptr->size_rank_info[0].second;
465    intercomm_num_ep = newintercomm->ep_comm_ptr->size_rank_info[1].second;
466    intercomm_mpi_size = newintercomm->ep_comm_ptr->size_rank_info[2].second;
467
468    MPI_Bcast(&remote_ep_size, 1, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
469
470    int my_rank_map_elem[2];
471
472    my_rank_map_elem[0] = intercomm_ep_rank;
473    my_rank_map_elem[1] = (*newintercomm).ep_comm_ptr->comm_label;
474
475    vector<pair<int, int> > local_rank_map_array;
476    vector<pair<int, int> > remote_rank_map_array;
477
478
479    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map = new RANK_MAP;
480    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->resize(local_ep_size);
481
482    MPI_Allgather(my_rank_map_elem, 2, static_cast< ::MPI_Datatype> (MPI_INT), 
483      (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2, static_cast< ::MPI_Datatype> (MPI_INT), local_comm);
484
485    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
486    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->resize(remote_ep_size);
487
488    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[0] = local_comm.ep_comm_ptr->size_rank_info[0];
489    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[1] = local_comm.ep_comm_ptr->size_rank_info[1];
490    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[2] = local_comm.ep_comm_ptr->size_rank_info[2];
491
492    int local_intercomm_size = intercomm_ep_size;
493    int remote_intercomm_size;
494
495    int new_bcast_root_0 = 0;
496    int new_bcast_root = 0;
497
498
499    if(is_local_leader)
500    {
501      MPI_Status status;
502      MPI_Send((*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_ep_size, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+4, peer_comm);
503      MPI_Recv((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+4, peer_comm, &status);
504
505      MPI_Send(&local_intercomm_size, 1, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+5, peer_comm);
506      MPI_Recv(&remote_intercomm_size, 1, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+5, peer_comm, &status);
507
508      new_bcast_root_0 = intercomm_ep_rank;
509    }
510
511    MPI_Allreduce(&new_bcast_root_0, &new_bcast_root, 1, static_cast< ::MPI_Datatype> (MPI_INT), static_cast< ::MPI_Op>(MPI_SUM), *newintercomm);
512
513
514    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_ep_size, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
515    MPI_Bcast(&remote_intercomm_size, 1, static_cast< ::MPI_Datatype> (MPI_INT), new_bcast_root, *newintercomm);
516
517
518    (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map = new RANK_MAP;
519    (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->resize(remote_intercomm_size);
520
521
522
523
524    if(is_local_leader)
525    {
526      MPI_Status status;
527      MPI_Send((*newintercomm).rank_map->data(), 2*local_intercomm_size, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+6, peer_comm);
528      MPI_Recv((*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag+6, peer_comm, &status);
529    }
530
531    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->data(), 2*remote_intercomm_size, static_cast< ::MPI_Datatype> (MPI_INT), new_bcast_root, *newintercomm);
532
533    (*newintercomm).ep_comm_ptr->intercomm->local_comm = &(local_comm.ep_comm_ptr->comm_list[ep_rank_loc]);
534    (*newintercomm).ep_comm_ptr->intercomm->intercomm_tag = tag;
535
536/*
537    for(int i=0; i<local_ep_size; i++)
538    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, local_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
539          (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->at(i).second);
540
541    for(int i=0; i<remote_ep_size; i++)
542    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, remote_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
543          (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->at(i).second);
544
545    for(int i=0; i<remote_intercomm_size; i++)
546    if(local_comm.ep_comm_ptr->comm_label == 0) printf("ep_rank (from EP) = %d, intercomm_rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
547          (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->at(i).first, (*newintercomm).ep_comm_ptr->intercomm->intercomm_rank_map->at(i).second);
548*/
549
550//    for(int i=0; i<(*newintercomm).rank_map->size(); i++)
551//    if(local_comm.ep_comm_ptr->comm_label != 99) printf("ep_rank = %d, rank_map[%d] = (%d,%d)\n", intercomm_ep_rank, i,
552//          (*newintercomm).rank_map->at(i).first, (*newintercomm).rank_map->at(i).second);
553
554//    MPI_Comm *test_comm = newintercomm->ep_comm_ptr->intercomm->local_comm;
555//    int test_rank;
556//    MPI_Comm_rank(*test_comm, &test_rank);
557//    printf("=================test_rank = %d\n", test_rank);
558   
559   
560
561    return MPI_SUCCESS;
562
563  }
564
565
566
567
568  int MPI_Intercomm_create_unique_leader(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm)
569  {
570    //! mpi_size of local comm = 1
571    //! same world rank of leaders
572
573    int ep_rank, ep_rank_loc, mpi_rank;
574    int ep_size, num_ep, mpi_size;
575
576    ep_rank = local_comm.ep_comm_ptr->size_rank_info[0].first;
577    ep_rank_loc = local_comm.ep_comm_ptr->size_rank_info[1].first;
578    mpi_rank = local_comm.ep_comm_ptr->size_rank_info[2].first;
579    ep_size = local_comm.ep_comm_ptr->size_rank_info[0].second;
580    num_ep = local_comm.ep_comm_ptr->size_rank_info[1].second;
581    mpi_size = local_comm.ep_comm_ptr->size_rank_info[2].second;
582
583
584
585    std::vector<int> rank_info[4];  //! 0->rank_in_world of local_comm,  1->rank_in_local_parent of local_comm
586                                    //! 2->rank_in_world of remote_comm, 3->rank_in_local_parent of remote_comm
587
588    int rank_in_world;
589
590    int rank_in_peer_mpi[2];
591
592    ::MPI_Comm_rank(static_cast< ::MPI_Comm >(MPI_COMM_WORLD.mpi_comm), &rank_in_world);
593
594
595    int local_num_ep = num_ep;
596    int remote_num_ep;
597    int total_num_ep;
598
599    int leader_rank_in_peer[2];
600
601    int my_position;
602    int tag_label[2];
603
604    vector<int> send_buf(4);
605    vector<int> recv_buf(4);
606
607
608    if(ep_rank == local_leader)
609    {
610      MPI_Status status;
611
612
613
614      MPI_Comm_rank(peer_comm, &leader_rank_in_peer[0]);
615
616      send_buf[0] = local_num_ep;
617      send_buf[1] = leader_rank_in_peer[0];
618
619      MPI_Request req_s, req_r;
620
621      MPI_Isend(send_buf.data(), 2, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_s);
622      MPI_Irecv(recv_buf.data(), 2, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_r);
623
624
625      MPI_Wait(&req_s, &status);
626      MPI_Wait(&req_r, &status);
627
628      recv_buf[2] = leader_rank_in_peer[0];
629
630    }
631
632    MPI_Bcast(recv_buf.data(), 3, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
633
634    remote_num_ep = recv_buf[0];
635    leader_rank_in_peer[1] = recv_buf[1];
636    leader_rank_in_peer[0] = recv_buf[2];
637
638    total_num_ep = local_num_ep + remote_num_ep;
639
640
641    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
642    {
643      my_position = ep_rank_loc;
644      //! LEADER create EP
645      if(ep_rank == local_leader)
646      {
647        ::MPI_Comm mpi_dup;
648       
649        ::MPI_Comm_dup(static_cast< ::MPI_Comm>(local_comm.mpi_comm), &mpi_dup);
650
651        MPI_Comm *ep_intercomm;
652        MPI_Info info;
653        MPI_Comm_create_endpoints(mpi_dup, total_num_ep, info, ep_intercomm);
654
655
656        for(int i=0; i<total_num_ep; i++)
657        {
658          ep_intercomm[i].is_intercomm = true;
659          ep_intercomm[i].ep_comm_ptr->intercomm = new ep_lib::ep_intercomm;
660          ep_intercomm[i].ep_comm_ptr->intercomm->mpi_inter_comm = 0;
661
662          ep_intercomm[i].ep_comm_ptr->comm_label = leader_rank_in_peer[0];
663        }
664
665        tag_label[0] = TAG++;
666        tag_label[1] = rank_in_world;
667
668        #pragma omp critical (write_to_tag_list)
669        tag_list.push_back(make_pair( make_pair(tag_label[0], tag_label[1]) , ep_intercomm));
670
671        MPI_Request req_s;
672        MPI_Status sta_s;
673        MPI_Isend(tag_label, 2, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_s);
674
675        MPI_Wait(&req_s, &sta_s);
676
677      }
678    }
679    else
680    {
681      //! Wait for EP creation
682      my_position = remote_num_ep + ep_rank_loc;
683      if(ep_rank == local_leader)
684      {
685        MPI_Status status;
686        MPI_Request req_r;
687        MPI_Irecv(tag_label, 2, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_r);
688        MPI_Wait(&req_r, &status);
689      }
690    }
691
692    MPI_Bcast(tag_label, 2, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
693
694
695
696
697    #pragma omp critical (read_from_tag_list)
698    {
699      bool found = false;
700      while(!found)
701      {
702        for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
703        {
704          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
705          {
706            *newintercomm =  iter->second[my_position];
707            found = true;
708            // tag_list.erase(iter);
709            break;
710          }
711        }
712      }
713    }
714
715    MPI_Barrier_local(local_comm);
716
717    if(leader_rank_in_peer[0] < leader_rank_in_peer[1])
718    {
719      for(std::list<std::pair < std::pair<int,int>, MPI_Comm* > >::iterator iter = tag_list.begin(); iter!=tag_list.end(); iter++)
720        {
721          if((*iter).first == make_pair(tag_label[0], tag_label[1]))
722          {
723            tag_list.erase(iter);
724            break;
725          }
726        }
727    }
728
729
730
731    int intercomm_ep_rank, intercomm_ep_rank_loc, intercomm_mpi_rank;
732    int intercomm_ep_size, intercomm_num_ep, intercomm_mpi_size;
733
734    intercomm_ep_rank = newintercomm->ep_comm_ptr->size_rank_info[0].first;
735    intercomm_ep_rank_loc = newintercomm->ep_comm_ptr->size_rank_info[1].first;
736    intercomm_mpi_rank = newintercomm->ep_comm_ptr->size_rank_info[2].first;
737    intercomm_ep_size = newintercomm->ep_comm_ptr->size_rank_info[0].second;
738    intercomm_num_ep = newintercomm->ep_comm_ptr->size_rank_info[1].second;
739    intercomm_mpi_size = newintercomm->ep_comm_ptr->size_rank_info[2].second;
740
741
742
743    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map  = new RANK_MAP;
744    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map = new RANK_MAP;
745    (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->resize(local_num_ep);
746    (*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->resize(remote_num_ep);
747
748    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[0] = local_comm.ep_comm_ptr->size_rank_info[0];
749    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[1] = local_comm.ep_comm_ptr->size_rank_info[1];
750    (*newintercomm).ep_comm_ptr->intercomm->size_rank_info[2] = local_comm.ep_comm_ptr->size_rank_info[2];
751
752
753
754    int local_rank_map_ele[2];
755    local_rank_map_ele[0] = intercomm_ep_rank;
756    local_rank_map_ele[1] = (*newintercomm).ep_comm_ptr->comm_label;
757
758    MPI_Allgather(local_rank_map_ele, 2, static_cast< ::MPI_Datatype> (MPI_INT), 
759      (*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2, static_cast< ::MPI_Datatype> (MPI_INT), local_comm);
760
761    if(ep_rank == local_leader)
762    {
763      MPI_Status status;
764      MPI_Request req_s, req_r;
765
766      MPI_Isend((*newintercomm).ep_comm_ptr->intercomm->local_rank_map->data(), 2*local_num_ep, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_s);
767      MPI_Irecv((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, static_cast< ::MPI_Datatype> (MPI_INT), remote_leader, tag, peer_comm, &req_r);
768
769
770      MPI_Wait(&req_s, &status);
771      MPI_Wait(&req_r, &status);
772
773    }
774
775    MPI_Bcast((*newintercomm).ep_comm_ptr->intercomm->remote_rank_map->data(), 2*remote_num_ep, static_cast< ::MPI_Datatype> (MPI_INT), local_leader, local_comm);
776    (*newintercomm).ep_comm_ptr->intercomm->local_comm = &(local_comm.ep_comm_ptr->comm_list[ep_rank_loc]);
777    (*newintercomm).ep_comm_ptr->intercomm->intercomm_tag = tag;
778
779
780    return MPI_SUCCESS;
781  }
782
783
784}
Note: See TracBrowser for help on using the repository browser.