source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gather.cpp @ 1289

Last change on this file since 1289 was 1289, checked in by yushan, 7 years ago

EP update part 2

File size: 21.2 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gather, MPI_Allgather
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15namespace ep_lib
16{
17  int MPI_Gather_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, int local_root, MPI_Comm comm)
18  {
19    assert(valid_type(datatype));
20
21    ::MPI_Aint datasize, lb;
22    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
23
24    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
25    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
26
27    #pragma omp critical (_gather)
28    comm.my_buffer->void_buffer[ep_rank_loc] = const_cast< void* >(sendbuf);
29
30    MPI_Barrier_local(comm);
31
32    if(ep_rank_loc == local_root)
33    {
34      for(int i=0; i<num_ep; i++)
35        memcpy(recvbuf + datasize * i * count, comm.my_buffer->void_buffer[i], datasize * count);
36
37      //printf("local_recvbuf = %d %d \n", static_cast<int*>(recvbuf)[0], static_cast<int*>(recvbuf)[1] );
38    }
39
40    MPI_Barrier_local(comm);
41  }
42
43  int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
44  {
45    if(!comm.is_ep)
46    {
47      return ::MPI_Gather(const_cast<void*>(sendbuf), sendcount, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype),
48                   root, to_mpi_comm(comm.mpi_comm));
49    }
50
51    assert(sendcount == recvcount && sendtype == recvtype);
52
53    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
54    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
55    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
56    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
57    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
58    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
59
60    int root_mpi_rank = comm.rank_map->at(root).second;
61    int root_ep_loc = comm.rank_map->at(root).first;
62
63    ::MPI_Aint datasize, lb;
64    ::MPI_Type_get_extent(to_mpi_type(sendtype), &lb, &datasize);
65
66    bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root;
67    bool is_root = ep_rank == root;
68
69    void* local_recvbuf;
70
71    if(is_master)
72    {
73      local_recvbuf = new void*[datasize * num_ep * sendcount];
74    }
75
76    void* tmp_recvbuf;
77    if(is_root) tmp_recvbuf = new void*[datasize * recvcount * ep_size];
78
79
80    if(mpi_rank == root_mpi_rank) MPI_Gather_local(sendbuf, sendcount, sendtype, local_recvbuf, root_ep_loc, comm);
81    else                          MPI_Gather_local(sendbuf, sendcount, sendtype, local_recvbuf, 0, comm);
82
83    std::vector<int> recvcounts(mpi_size, 0);
84    std::vector<int> displs(mpi_size, 0);
85
86
87    if(is_master)
88    {
89      for(int i=0; i<ep_size; i++)
90      {
91        recvcounts[comm.rank_map->at(i).second]+=sendcount;
92      }
93
94      for(int i=1; i<mpi_size; i++)
95        displs[i] = displs[i-1] + recvcounts[i-1];
96
97      ::MPI_Gatherv(local_recvbuf, sendcount*num_ep, sendtype, tmp_recvbuf, recvcounts.data(), displs.data(), recvtype, root_mpi_rank, to_mpi_comm(comm.mpi_comm));
98    }   
99
100
101    // reorder data
102    if(is_root)
103    {
104      // printf("tmp_recvbuf = %d %d %d %d %d %d %d %d\n", static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1],
105      //                                                   static_cast<int*>(tmp_recvbuf)[2], static_cast<int*>(tmp_recvbuf)[3],
106      //                                                   static_cast<int*>(tmp_recvbuf)[4], static_cast<int*>(tmp_recvbuf)[5],
107      //                                                   static_cast<int*>(tmp_recvbuf)[6], static_cast<int*>(tmp_recvbuf)[7] );
108
109      int offset;
110      for(int i=0; i<ep_size; i++)
111      {
112        offset = displs[comm.rank_map->at(i).second] + comm.rank_map->at(i).first * sendcount; 
113        memcpy(recvbuf + i*sendcount*datasize, tmp_recvbuf+offset*datasize, sendcount*datasize);
114
115
116      }
117
118    }
119
120
121    if(is_master)
122    {
123      delete[] local_recvbuf;
124    }
125    if(is_root) delete[] tmp_recvbuf;
126   
127  }
128
129  // int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
130  // {
131
132  //   if(!comm.is_ep && comm.mpi_comm)
133  //   {
134  //     ::MPI_Allgather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
135  //                     static_cast< ::MPI_Comm>(comm.mpi_comm));
136  //     return 0;
137  //   }
138
139  //   if(!comm.mpi_comm) return 0;
140
141  //   assert(sendcount == recvcount);
142
143  //   assert(valid_type(sendtype) && valid_type(recvtype));
144
145  //   MPI_Datatype datatype = sendtype;
146  //   int count = sendcount;
147
148  //   ::MPI_Aint datasize, lb;
149
150  //   ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
151
152
153  //   int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
154  //   int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
155  //   int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
156  //   int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
157  //   int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
158  //   int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
159
160  //   bool is_master = ep_rank_loc==0;
161
162  //   void* local_recvbuf;
163  //   void* tmp_recvbuf;
164
165
166  //   if(is_master)
167  //   {
168  //     local_recvbuf = new void*[datasize * num_ep * count];
169  //     tmp_recvbuf = new void*[datasize * count * ep_size];
170  //   }
171
172  //   MPI_Gather_local(sendbuf, count, datatype, local_recvbuf, 0, comm);
173
174
175  //   int* mpi_recvcounts;
176  //   int *mpi_displs;
177   
178  //   if(is_master)
179  //   {
180     
181  //     mpi_recvcounts = new int[mpi_size];
182  //     mpi_displs = new int[mpi_size];
183
184  //     int local_sendcount = num_ep * count;
185
186  //     ::MPI_Allgather(&local_sendcount, 1, to_mpi_type(MPI_INT), mpi_recvcounts, 1, to_mpi_type(MPI_INT), to_mpi_comm(comm.mpi_comm));
187
188  //     mpi_displs[0] = 0;
189  //     for(int i=1; i<mpi_size; i++)
190  //     {
191  //       mpi_displs[i] = mpi_displs[i-1] + mpi_recvcounts[i-1];
192  //     }
193
194   
195  //     ::MPI_Allgatherv(local_recvbuf, num_ep * count, to_mpi_type(datatype), tmp_recvbuf, mpi_recvcounts, mpi_displs, to_mpi_type(datatype), to_mpi_comm(comm.mpi_comm));
196
197
198  //     // reorder
199  //     int offset;
200  //     for(int i=0; i<ep_size; i++)
201  //     {
202  //       offset = mpi_displs[comm.rank_map->at(i).second] + comm.rank_map->at(i).first * sendcount;
203  //       memcpy(recvbuf + i*sendcount*datasize, tmp_recvbuf+offset*datasize, sendcount*datasize);
204  //     }
205
206  //     delete[] mpi_recvcounts;
207  //     delete[] mpi_displs;
208  //   }
209
210  //   MPI_Bcast_local(recvbuf, count*ep_size, datatype, 0, comm);
211
212  //   MPI_Barrier(comm);
213
214
215  //   if(is_master)
216  //   {
217  //     delete[] local_recvbuf;
218  //     delete[] tmp_recvbuf;
219
220  //   }
221
222  // }
223
224  int MPI_Gather_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
225  {
226    if(datatype == MPI_INT)
227    {
228      Debug("datatype is INT\n");
229      return MPI_Gather_local_int(sendbuf, count, recvbuf, comm);
230    }
231    else if(datatype == MPI_FLOAT)
232    {
233      Debug("datatype is FLOAT\n");
234      return MPI_Gather_local_float(sendbuf, count, recvbuf, comm);
235    }
236    else if(datatype == MPI_DOUBLE)
237    {
238      Debug("datatype is DOUBLE\n");
239      return MPI_Gather_local_double(sendbuf, count, recvbuf, comm);
240    }
241    else if(datatype == MPI_LONG)
242    {
243      Debug("datatype is LONG\n");
244      return MPI_Gather_local_long(sendbuf, count, recvbuf, comm);
245    }
246    else if(datatype == MPI_UNSIGNED_LONG)
247    {
248      Debug("datatype is uLONG\n");
249      return MPI_Gather_local_ulong(sendbuf, count, recvbuf, comm);
250    }
251    else if(datatype == MPI_CHAR)
252    {
253      Debug("datatype is CHAR\n");
254      return MPI_Gather_local_char(sendbuf, count, recvbuf, comm);
255    }
256    else
257    {
258      printf("MPI_Gather Datatype not supported!\n");
259      exit(0);
260    }
261  }
262
263  int MPI_Gather_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
264  {
265    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
266    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
267
268    int *buffer = comm.my_buffer->buf_int;
269    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
270    int *recv_buf = static_cast<int*>(recvbuf);
271
272    if(my_rank == 0)
273    {
274      copy(send_buf, send_buf+count, recv_buf);
275    }
276
277    for(int j=0; j<count; j+=BUFFER_SIZE)
278    {
279      for(int k=1; k<num_ep; k++)
280      {
281        if(my_rank == k)
282        {
283          #pragma omp critical (write_to_buffer)
284          {
285            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
286            #pragma omp flush
287          }
288        }
289
290        MPI_Barrier_local(comm);
291
292        if(my_rank == 0)
293        {
294          #pragma omp flush
295          #pragma omp critical (read_from_buffer)
296          {
297            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
298          }
299        }
300
301        MPI_Barrier_local(comm);
302      }
303    }
304  }
305
306  int MPI_Gather_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
307  {
308    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
309    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
310
311    float *buffer = comm.my_buffer->buf_float;
312    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
313    float *recv_buf = static_cast<float*>(recvbuf);
314
315    if(my_rank == 0)
316    {
317      copy(send_buf, send_buf+count, recv_buf);
318    }
319
320    for(int j=0; j<count; j+=BUFFER_SIZE)
321    {
322      for(int k=1; k<num_ep; k++)
323      {
324        if(my_rank == k)
325        {
326          #pragma omp critical (write_to_buffer)
327          {
328            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
329            #pragma omp flush
330          }
331        }
332
333        MPI_Barrier_local(comm);
334
335        if(my_rank == 0)
336        {
337          #pragma omp flush
338          #pragma omp critical (read_from_buffer)
339          {
340            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
341          }
342        }
343
344        MPI_Barrier_local(comm);
345      }
346    }
347  }
348
349  int MPI_Gather_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
350  {
351    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
352    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
353
354    double *buffer = comm.my_buffer->buf_double;
355    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
356    double *recv_buf = static_cast<double*>(recvbuf);
357
358    if(my_rank == 0)
359    {
360      copy(send_buf, send_buf+count, recv_buf);
361    }
362
363    for(int j=0; j<count; j+=BUFFER_SIZE)
364    {
365      for(int k=1; k<num_ep; k++)
366      {
367        if(my_rank == k)
368        {
369          #pragma omp critical (write_to_buffer)
370          {
371            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
372            #pragma omp flush
373          }
374        }
375
376        MPI_Barrier_local(comm);
377
378        if(my_rank == 0)
379        {
380          #pragma omp flush
381          #pragma omp critical (read_from_buffer)
382          {
383            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
384          }
385        }
386
387        MPI_Barrier_local(comm);
388      }
389    }
390  }
391
392  int MPI_Gather_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
393  {
394    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
395    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
396
397    long *buffer = comm.my_buffer->buf_long;
398    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
399    long *recv_buf = static_cast<long*>(recvbuf);
400
401    if(my_rank == 0)
402    {
403      copy(send_buf, send_buf+count, recv_buf);
404    }
405
406    for(int j=0; j<count; j+=BUFFER_SIZE)
407    {
408      for(int k=1; k<num_ep; k++)
409      {
410        if(my_rank == k)
411        {
412          #pragma omp critical (write_to_buffer)
413          {
414            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
415            #pragma omp flush
416          }
417        }
418
419        MPI_Barrier_local(comm);
420
421        if(my_rank == 0)
422        {
423          #pragma omp flush
424          #pragma omp critical (read_from_buffer)
425          {
426            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
427          }
428        }
429
430        MPI_Barrier_local(comm);
431      }
432    }
433  }
434
435  int MPI_Gather_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
436  {
437    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
438    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
439
440    unsigned long *buffer = comm.my_buffer->buf_ulong;
441    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
442    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
443
444    if(my_rank == 0)
445    {
446      copy(send_buf, send_buf+count, recv_buf);
447    }
448
449    for(int j=0; j<count; j+=BUFFER_SIZE)
450    {
451      for(int k=1; k<num_ep; k++)
452      {
453        if(my_rank == k)
454        {
455          #pragma omp critical (write_to_buffer)
456          {
457            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
458            #pragma omp flush
459          }
460        }
461
462        MPI_Barrier_local(comm);
463
464        if(my_rank == 0)
465        {
466          #pragma omp flush
467          #pragma omp critical (read_from_buffer)
468          {
469            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
470          }
471        }
472
473        MPI_Barrier_local(comm);
474      }
475    }
476  }
477
478
479  int MPI_Gather_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
480  {
481    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
482    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
483
484    char *buffer = comm.my_buffer->buf_char;
485    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
486    char *recv_buf = static_cast<char*>(recvbuf);
487
488    if(my_rank == 0)
489    {
490      copy(send_buf, send_buf+count, recv_buf);
491    }
492
493    for(int j=0; j<count; j+=BUFFER_SIZE)
494    {
495      for(int k=1; k<num_ep; k++)
496      {
497        if(my_rank == k)
498        {
499          #pragma omp critical (write_to_buffer)
500          {
501            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
502            #pragma omp flush
503          }
504        }
505
506        MPI_Barrier_local(comm);
507
508        if(my_rank == 0)
509        {
510          #pragma omp flush
511          #pragma omp critical (read_from_buffer)
512          {
513            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
514          }
515        }
516
517        MPI_Barrier_local(comm);
518      }
519    }
520  }
521
522
523
524  int MPI_Gather2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
525  {
526    if(!comm.is_ep && comm.mpi_comm)
527    {
528      ::MPI_Gather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
529                   root, static_cast< ::MPI_Comm>(comm.mpi_comm));
530      return 0;
531    }
532
533    if(!comm.mpi_comm) return 0;
534   
535    MPI_Bcast(&recvcount, 1, MPI_INT, root, comm);
536
537    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
538
539    MPI_Datatype datatype = sendtype;
540    int count = sendcount;
541
542    int ep_rank, ep_rank_loc, mpi_rank;
543    int ep_size, num_ep, mpi_size;
544
545    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
546    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
547    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
548    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
549    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
550    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
551
552
553    int root_mpi_rank = comm.rank_map->at(root).second;
554    int root_ep_loc = comm.rank_map->at(root).first;
555
556
557    ::MPI_Aint datasize, lb;
558
559    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
560
561    void *local_gather_recvbuf;
562    void *master_recvbuf;
563    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) 
564    {
565      master_recvbuf = new void*[datasize*ep_size*count];
566    }
567
568    if(ep_rank_loc==0)
569    {
570      local_gather_recvbuf = new void*[datasize*num_ep*count];
571    }
572
573    // local gather to master
574    MPI_Gather_local2(sendbuf, count, datatype, local_gather_recvbuf, comm);
575
576    //MPI_Gather
577
578    if(ep_rank_loc == 0)
579    {
580      int *gatherv_recvcnt;
581      int *gatherv_displs;
582      int gatherv_cnt = count*num_ep;
583
584      gatherv_recvcnt = new int[mpi_size];
585      gatherv_displs = new int[mpi_size];
586
587
588      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT, gatherv_recvcnt, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm));
589
590      gatherv_displs[0] = 0;
591      for(int i=1; i<mpi_size; i++)
592      {
593        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
594      }
595
596      if(root_ep_loc != 0) // gather to root_master
597      {
598        ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, gatherv_recvcnt,
599                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
600      }
601      else
602      {
603        ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
604                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
605      }
606
607      delete[] gatherv_recvcnt;
608      delete[] gatherv_displs;
609    }
610
611
612    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
613    {
614      innode_memcpy(0, master_recvbuf, root_ep_loc, recvbuf, count*ep_size, datatype, comm);
615    }
616
617
618
619    if(ep_rank_loc==0)
620    {
621      if(datatype == MPI_INT)
622      {
623        delete[] static_cast<int*>(local_gather_recvbuf);
624      }
625      else if(datatype == MPI_FLOAT)
626      {
627        delete[] static_cast<float*>(local_gather_recvbuf);
628      }
629      else if(datatype == MPI_DOUBLE)
630      {
631        delete[] static_cast<double*>(local_gather_recvbuf);
632      }
633      else if(datatype == MPI_CHAR)
634      {
635        delete[] static_cast<char*>(local_gather_recvbuf);
636      }
637      else if(datatype == MPI_LONG)
638      {
639        delete[] static_cast<long*>(local_gather_recvbuf);
640      }
641      else// if(datatype == MPI_UNSIGNED_LONG)
642      {
643        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
644      }
645     
646      if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) delete[] master_recvbuf;
647    }
648  }
649
650
651  int MPI_Allgather2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
652  {
653    if(!comm.is_ep && comm.mpi_comm)
654    {
655      ::MPI_Allgather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
656                      static_cast< ::MPI_Comm>(comm.mpi_comm));
657      return 0;
658    }
659
660    if(!comm.mpi_comm) return 0;
661
662    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
663
664    MPI_Datatype datatype = sendtype;
665    int count = sendcount;
666
667    int ep_rank, ep_rank_loc, mpi_rank;
668    int ep_size, num_ep, mpi_size;
669
670    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
671    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
672    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
673    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
674    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
675    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
676
677
678    ::MPI_Aint datasize, lb;
679
680    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
681
682    void *local_gather_recvbuf;
683
684    if(ep_rank_loc==0)
685    {
686      local_gather_recvbuf = new void*[datasize*num_ep*count];
687    }
688
689    // local gather to master
690    MPI_Gather_local2(sendbuf, count, datatype, local_gather_recvbuf, comm);
691
692    //MPI_Gather
693
694    if(ep_rank_loc == 0)
695    {
696      int *gatherv_recvcnt;
697      int *gatherv_displs;
698      int gatherv_cnt = count*num_ep;
699
700      gatherv_recvcnt = new int[mpi_size];
701      gatherv_displs = new int[mpi_size];
702
703      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT, gatherv_recvcnt, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm));
704
705      gatherv_displs[0] = 0;
706      for(int i=1; i<mpi_size; i++)
707      {
708        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
709      }
710
711      ::MPI_Allgatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
712                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
713
714      delete[] gatherv_recvcnt;
715      delete[] gatherv_displs;
716    }
717
718    MPI_Bcast_local2(recvbuf, count*ep_size, datatype, comm);
719
720
721    if(ep_rank_loc==0)
722    {
723      if(datatype == MPI_INT)
724      {
725        delete[] static_cast<int*>(local_gather_recvbuf);
726      }
727      else if(datatype == MPI_FLOAT)
728      {
729        delete[] static_cast<float*>(local_gather_recvbuf);
730      }
731      else if(datatype == MPI_DOUBLE)
732      {
733        delete[] static_cast<double*>(local_gather_recvbuf);
734      }
735      else if(datatype == MPI_CHAR)
736      {
737        delete[] static_cast<char*>(local_gather_recvbuf);
738      }
739      else if(datatype == MPI_LONG)
740      {
741        delete[] static_cast<long*>(local_gather_recvbuf);
742      }
743      else// if(datatype == MPI_UNSIGNED_LONG)
744      {
745        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
746      }
747    }
748  }
749
750
751}
Note: See TracBrowser for help on using the repository browser.