source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gatherv.cpp @ 1289

Last change on this file since 1289 was 1289, checked in by yushan, 7 years ago

EP update part 2

File size: 35.7 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gatherv, MPI_Allgatherv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15namespace ep_lib
16{
17   int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], int local_root, MPI_Comm comm)
18  {
19    assert(valid_type(datatype));
20
21    ::MPI_Aint datasize, lb;
22    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
23
24    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
25    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
26
27    //if(ep_rank_loc == local_root) printf("local_gatherv : recvcounts = %d %d\n\n", recvcounts[0], recvcounts[1]);
28    //if(ep_rank_loc == local_root) printf("local_gatherv : displs = %d %d\n\n", displs[0], displs[1]);
29
30    #pragma omp critical (_gatherv)
31    comm.my_buffer->void_buffer[ep_rank_loc] = const_cast< void* >(sendbuf);
32
33    MPI_Barrier_local(comm);
34
35    if(ep_rank_loc == local_root)
36    {
37      for(int i=0; i<num_ep; i++)
38        memcpy(recvbuf + datasize*displs[i], comm.my_buffer->void_buffer[i], datasize*recvcounts[i]);
39
40    }
41
42    MPI_Barrier_local(comm);
43  }
44
45  int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int input_recvcounts[], const int input_displs[],
46                  MPI_Datatype recvtype, int root, MPI_Comm comm)
47  {
48 
49    if(!comm.is_ep)
50    {
51      ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(input_recvcounts), const_cast<int*>(input_displs),
52                    static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
53      return 0;
54    }
55
56
57    assert(sendtype == recvtype);
58
59   
60    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
61    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
62    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
63    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
64    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
65    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
66
67    int root_mpi_rank = comm.rank_map->at(root).second;
68    int root_ep_loc = comm.rank_map->at(root).first;
69
70    ::MPI_Aint datasize, lb;
71    ::MPI_Type_get_extent(to_mpi_type(sendtype), &lb, &datasize);
72
73    int *recvcounts;
74    int* displs;
75
76    recvcounts = new int[ep_size];
77    displs = new int[ep_size];
78
79
80    bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root;
81    bool is_root = ep_rank == root;
82
83    void* local_recvbuf;
84    std::vector<int>local_recvcounts(num_ep, 0);
85    std::vector<int>local_displs(num_ep, 0);
86
87
88    if(is_root)
89    { 
90      copy(input_recvcounts, input_recvcounts+ep_size, recvcounts);
91      copy(input_displs, input_displs+ep_size, displs);
92    }
93
94    MPI_Bcast(recvcounts, ep_size, MPI_INT, root, comm);
95    MPI_Bcast(displs, ep_size, MPI_INT, root, comm);
96
97    if(mpi_rank == root_mpi_rank) MPI_Gather_local(&sendcount, 1, MPI_INT, local_recvcounts.data(), root_ep_loc, comm);
98    else                          MPI_Gather_local(&sendcount, 1, MPI_INT, local_recvcounts.data(), 0, comm);
99
100
101
102    if(is_master)
103    {
104      int local_recvbuf_size = std::accumulate(local_recvcounts.begin(), local_recvcounts.end(), 0);
105     
106      for(int i=1; i<num_ep; i++)
107        local_displs[i] = local_displs[i-1] + local_recvcounts[i-1];
108
109      local_recvbuf = new void*[datasize * local_recvbuf_size];
110    }
111
112    if(mpi_rank == root_mpi_rank) MPI_Gatherv_local(sendbuf, sendcount, sendtype, local_recvbuf, local_recvcounts.data(), local_displs.data(), root_ep_loc, comm);
113    else                          MPI_Gatherv_local(sendbuf, sendcount, sendtype, local_recvbuf, local_recvcounts.data(), local_displs.data(), 0, comm);
114
115    //if(is_master) printf("local_recvbuf = %d %d %d %d\n", static_cast<int*>(local_recvbuf)[0], static_cast<int*>(local_recvbuf)[1], static_cast<int*>(local_recvbuf)[2], static_cast<int*>(local_recvbuf)[3]);
116
117    void* tmp_recvbuf;
118    int tmp_recvbuf_size = std::accumulate(recvcounts, recvcounts+ep_size, 0);
119
120    if(is_root) tmp_recvbuf = new void*[datasize * tmp_recvbuf_size];
121
122
123    std::vector<int> mpi_recvcounts(mpi_size, 0);
124    std::vector<int> mpi_displs(mpi_size, 0);
125
126
127    if(is_master)
128    {
129      for(int i=0; i<ep_size; i++)
130      {
131        mpi_recvcounts[comm.rank_map->at(i).second]+=recvcounts[i];
132      }
133
134
135
136      for(int i=1; i<mpi_size; i++)
137        mpi_displs[i] = mpi_displs[i-1] + mpi_recvcounts[i-1];
138
139
140
141      ::MPI_Gatherv(local_recvbuf, sendcount*num_ep, sendtype, tmp_recvbuf, mpi_recvcounts.data(), mpi_displs.data(), recvtype, root_mpi_rank, to_mpi_comm(comm.mpi_comm));
142    }   
143
144
145    // reorder data
146    if(is_root)
147    {
148      // printf("tmp_recvbuf =\n");
149      // for(int i=0; i<ep_size*sendcount; i++) printf("%d\t", static_cast<int*>(tmp_recvbuf)[i]);
150      // printf("\n");
151
152      int offset;
153      for(int i=0; i<ep_size; i++)
154      {
155        int extra = 0;
156        for(int j=0, k=0; j<ep_size, k<comm.rank_map->at(i).first; j++)
157          if(comm.rank_map->at(i).second == comm.rank_map->at(j).second)
158          {
159            extra += recvcounts[j];
160            k++;
161          } 
162
163        offset = mpi_displs[comm.rank_map->at(i).second] +  extra;
164
165        memcpy(recvbuf+displs[i]*datasize, tmp_recvbuf+offset*datasize, recvcounts[i]*datasize);
166
167        //printf("recvbuf[%d] = tmp_recvbuf[%d] \n", i, offset);
168       
169      }
170
171      // printf("recvbuf =\n");
172      // for(int i=0; i<ep_size*sendcount; i++) printf("%d\t", static_cast<int*>(recvbuf)[i]);
173      // printf("\n");
174
175    }
176
177    delete[] recvcounts;
178    delete[] displs;
179
180    if(is_master)
181    {
182      delete[] local_recvbuf;
183    }
184    if(is_root) delete[] tmp_recvbuf;
185  }
186
187  // int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype, MPI_Comm comm)
188  // {
189
190  //   if(!comm.is_ep && comm.mpi_comm)
191  //   {
192  //     ::MPI_Allgatherv(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcounts, displs, to_mpi_type(recvtype), to_mpi_comm(comm.mpi_comm));
193  //     return 0;
194  //   }
195
196  //   if(!comm.mpi_comm) return 0;
197
198
199
200
201  //   assert(valid_type(sendtype) && valid_type(recvtype));
202
203  //   MPI_Datatype datatype = sendtype;
204  //   int count = sendcount;
205
206  //   ::MPI_Aint datasize, lb;
207
208  //   ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
209
210
211  //   int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
212  //   int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
213  //   int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
214  //   int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
215  //   int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
216  //   int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
217
218
219  //   assert(sendcount == recvcounts[ep_rank]);
220
221  //   bool is_master = ep_rank_loc==0;
222
223  //   void* local_recvbuf;
224  //   void* tmp_recvbuf;
225
226  //   int recvbuf_size = 0;
227  //   for(int i=0; i<ep_size; i++)
228  //     recvbuf_size = max(recvbuf_size, displs[i]+recvcounts[i]);
229
230
231  //   vector<int>local_recvcounts(num_ep, 0);
232  //   vector<int>local_displs(num_ep, 0);
233
234  //   MPI_Gather_local(&sendcount, 1, MPI_INT, local_recvcounts.data(), 0, comm);
235  //   for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_recvcounts[i-1];
236
237
238  //   if(is_master)
239  //   {
240  //     local_recvbuf = new void*[datasize * std::accumulate(local_recvcounts.begin(), local_recvcounts.begin()+num_ep, 0)];
241  //     tmp_recvbuf = new void*[datasize * std::accumulate(recvcounts, recvcounts+ep_size, 0)];
242  //   }
243
244  //   MPI_Gatherv_local(sendbuf, count, datatype, local_recvbuf, local_recvcounts.data(), local_displs.data(), 0, comm);
245
246
247  //   if(is_master)
248  //   {
249  //     std::vector<int>mpi_recvcounts(mpi_size, 0);
250  //     std::vector<int>mpi_displs(mpi_size, 0);
251
252  //     int local_sendcount = std::accumulate(local_recvcounts.begin(), local_recvcounts.begin()+num_ep, 0);
253  //     MPI_Allgather(&local_sendcount, 1, MPI_INT, mpi_recvcounts.data(), 1, MPI_INT, to_mpi_comm(comm.mpi_comm));
254
255  //     for(int i=1; i<mpi_size; i++)
256  //       mpi_displs[i] = mpi_displs[i-1] + mpi_recvcounts[i-1];
257
258
259  //     ::MPI_Allgatherv(local_recvbuf, local_sendcount, to_mpi_type(datatype), tmp_recvbuf, mpi_recvcounts.data(), mpi_displs.data(), to_mpi_type(datatype), to_mpi_comm(comm.mpi_comm));
260
261
262
263  //     // reorder
264  //     int offset;
265  //     for(int i=0; i<ep_size; i++)
266  //     {
267  //       int extra = 0;
268  //       for(int j=0, k=0; j<ep_size, k<comm.rank_map->at(i).first; j++)
269  //         if(comm.rank_map->at(i).second == comm.rank_map->at(j).second)
270  //         {
271  //           extra += recvcounts[j];
272  //           k++;
273  //         } 
274
275  //       offset = mpi_displs[comm.rank_map->at(i).second] +  extra;
276
277  //       memcpy(recvbuf+displs[i]*datasize, tmp_recvbuf+offset*datasize, recvcounts[i]*datasize);
278       
279  //     }
280
281  //   }
282
283  //   MPI_Bcast_local(recvbuf, recvbuf_size, datatype, 0, comm);
284
285  //   if(is_master)
286  //   {
287  //     delete[] local_recvbuf;
288  //     delete[] tmp_recvbuf;
289  //   }
290
291  // }
292
293
294  int MPI_Gatherv_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
295  {
296    if(datatype == MPI_INT)
297    {
298      Debug("datatype is INT\n");
299      return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm);
300    }
301    else if(datatype == MPI_FLOAT)
302    {
303      Debug("datatype is FLOAT\n");
304      return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm);
305    }
306    else if(datatype == MPI_DOUBLE)
307    {
308      Debug("datatype is DOUBLE\n");
309      return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm);
310    }
311    else if(datatype == MPI_LONG)
312    {
313      Debug("datatype is LONG\n");
314      return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm);
315    }
316    else if(datatype == MPI_UNSIGNED_LONG)
317    {
318      Debug("datatype is uLONG\n");
319      return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm);
320    }
321    else if(datatype == MPI_CHAR)
322    {
323      Debug("datatype is CHAR\n");
324      return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm);
325    }
326    else
327    {
328      printf("MPI_Gatherv Datatype not supported!\n");
329      exit(0);
330    }
331  }
332
333  int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
334  {
335    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
336    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
337
338    int *buffer = comm.my_buffer->buf_int;
339    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
340    int *recv_buf = static_cast<int*>(recvbuf);
341
342    if(my_rank == 0)
343    {
344      assert(count == recvcounts[0]);
345      copy(send_buf, send_buf+count, recv_buf + displs[0]);
346    }
347
348    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
349    {
350      for(int k=1; k<num_ep; k++)
351      {
352        if(my_rank == k)
353        {
354          #pragma omp critical (write_to_buffer)
355          {
356            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
357            #pragma omp flush
358          }
359        }
360
361        MPI_Barrier_local(comm);
362
363        if(my_rank == 0)
364        {
365          #pragma omp flush
366          #pragma omp critical (read_from_buffer)
367          {
368            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
369          }
370        }
371
372        MPI_Barrier_local(comm);
373      }
374    }
375  }
376
377  int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
378  {
379    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
380    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
381
382    float *buffer = comm.my_buffer->buf_float;
383    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
384    float *recv_buf = static_cast<float*>(recvbuf);
385
386    if(my_rank == 0)
387    {
388      assert(count == recvcounts[0]);
389      copy(send_buf, send_buf+count, recv_buf + displs[0]);
390    }
391
392    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
393    {
394      for(int k=1; k<num_ep; k++)
395      {
396        if(my_rank == k)
397        {
398          #pragma omp critical (write_to_buffer)
399          {
400            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
401            #pragma omp flush
402          }
403        }
404
405        MPI_Barrier_local(comm);
406
407        if(my_rank == 0)
408        {
409          #pragma omp flush
410          #pragma omp critical (read_from_buffer)
411          {
412            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
413          }
414        }
415
416        MPI_Barrier_local(comm);
417      }
418    }
419  }
420
421  int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
422  {
423    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
424    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
425
426    double *buffer = comm.my_buffer->buf_double;
427    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
428    double *recv_buf = static_cast<double*>(recvbuf);
429
430    if(my_rank == 0)
431    {
432      assert(count == recvcounts[0]);
433      copy(send_buf, send_buf+count, recv_buf + displs[0]);
434    }
435
436    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
437    {
438      for(int k=1; k<num_ep; k++)
439      {
440        if(my_rank == k)
441        {
442          #pragma omp critical (write_to_buffer)
443          {
444            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
445            #pragma omp flush
446          }
447        }
448
449        MPI_Barrier_local(comm);
450
451        if(my_rank == 0)
452        {
453          #pragma omp flush
454          #pragma omp critical (read_from_buffer)
455          {
456            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
457          }
458        }
459
460        MPI_Barrier_local(comm);
461      }
462    }
463  }
464
465  int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
466  {
467    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
468    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
469
470    long *buffer = comm.my_buffer->buf_long;
471    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
472    long *recv_buf = static_cast<long*>(recvbuf);
473
474    if(my_rank == 0)
475    {
476      assert(count == recvcounts[0]);
477      copy(send_buf, send_buf+count, recv_buf + displs[0]);
478    }
479
480    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
481    {
482      for(int k=1; k<num_ep; k++)
483      {
484        if(my_rank == k)
485        {
486          #pragma omp critical (write_to_buffer)
487          {
488            if(count!=0)copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
489            #pragma omp flush
490          }
491        }
492
493        MPI_Barrier_local(comm);
494
495        if(my_rank == 0)
496        {
497          #pragma omp flush
498          #pragma omp critical (read_from_buffer)
499          {
500            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
501          }
502        }
503
504        MPI_Barrier_local(comm);
505      }
506    }
507  }
508
509  int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
510  {
511    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
512    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
513
514    unsigned long *buffer = comm.my_buffer->buf_ulong;
515    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
516    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
517
518    if(my_rank == 0)
519    {
520      assert(count == recvcounts[0]);
521      copy(send_buf, send_buf+count, recv_buf + displs[0]);
522    }
523
524    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
525    {
526      for(int k=1; k<num_ep; k++)
527      {
528        if(my_rank == k)
529        {
530          #pragma omp critical (write_to_buffer)
531          {
532            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
533            #pragma omp flush
534          }
535        }
536
537        MPI_Barrier_local(comm);
538
539        if(my_rank == 0)
540        {
541          #pragma omp flush
542          #pragma omp critical (read_from_buffer)
543          {
544            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
545          }
546        }
547
548        MPI_Barrier_local(comm);
549      }
550    }
551  }
552
553  int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
554  {
555    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
556    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
557
558    char *buffer = comm.my_buffer->buf_char;
559    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
560    char *recv_buf = static_cast<char*>(recvbuf);
561
562    if(my_rank == 0)
563    {
564      assert(count == recvcounts[0]);
565      copy(send_buf, send_buf+count, recv_buf + displs[0]);
566    }
567
568    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
569    {
570      for(int k=1; k<num_ep; k++)
571      {
572        if(my_rank == k)
573        {
574          #pragma omp critical (write_to_buffer)
575          {
576            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
577            #pragma omp flush
578          }
579        }
580
581        MPI_Barrier_local(comm);
582
583        if(my_rank == 0)
584        {
585          #pragma omp flush
586          #pragma omp critical (read_from_buffer)
587          {
588            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
589          }
590        }
591
592        MPI_Barrier_local(comm);
593      }
594    }
595  }
596
597
598  int MPI_Gatherv2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
599                  MPI_Datatype recvtype, int root, MPI_Comm comm)
600  {
601 
602    if(!comm.is_ep && comm.mpi_comm)
603    {
604      ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs),
605                    static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
606      return 0;
607    }
608
609    if(!comm.mpi_comm) return 0;
610
611    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
612
613    MPI_Datatype datatype = sendtype;
614    int count = sendcount;
615
616    int ep_rank, ep_rank_loc, mpi_rank;
617    int ep_size, num_ep, mpi_size;
618
619    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
620    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
621    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
622    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
623    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
624    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
625   
626   
627   
628    if(ep_size == mpi_size) 
629      return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
630                              static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
631
632    if(ep_rank != root)
633    {
634      recvcounts = new int[ep_size];
635      displs = new int[ep_size];
636    }
637   
638    MPI_Bcast(const_cast< int* >(displs),     ep_size, MPI_INT, root, comm);
639    MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm);
640                             
641
642    int recv_plus_displs[ep_size];
643    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
644
645    for(int j=0; j<mpi_size; j++)
646    {
647      if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] ||
648         recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2]) 
649      { 
650        Debug("Call special implementation of mpi_gatherv. 1st condition not OK\n");
651        return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
652      }
653
654      for(int i=1; i<num_ep-1; i++)
655      {
656        if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] || 
657           recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1])
658        {
659          Debug("Call special implementation of mpi_gatherv. 2nd condition not OK\n");
660          return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
661        }
662      }
663    }
664
665
666    int root_mpi_rank = comm.rank_map->at(root).second;
667    int root_ep_loc = comm.rank_map->at(root).first;
668
669
670    ::MPI_Aint datasize, lb;
671
672    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
673
674    void *local_gather_recvbuf;
675    int buffer_size;
676    void *master_recvbuf;
677
678    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) 
679    {
680      master_recvbuf = new void*[sizeof(recvbuf)];
681      assert(root_ep_loc == 0);
682    }
683
684    if(ep_rank_loc==0)
685    {
686      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
687
688      local_gather_recvbuf = new void*[datasize*buffer_size];
689    }
690
691    MPI_Gatherv_local2(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
692
693    //MPI_Gather
694    if(ep_rank_loc == 0)
695    {
696      int *mpi_recvcnt= new int[mpi_size];
697      int *mpi_displs= new int[mpi_size];
698
699      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
700      int buff_end = buffer_size;
701
702      int mpi_sendcnt = buff_end - buff_start;
703
704
705      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_recvcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
706      ::MPI_Gather(&buff_start,  1, MPI_INT, mpi_displs,  1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
707
708      if(root_ep_loc == 0)
709      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
710                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
711      }
712      else  // gatherv to master_recvbuf
713      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, mpi_recvcnt,
714                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
715      }
716
717      delete[] mpi_recvcnt;
718      delete[] mpi_displs;
719    }
720
721    int global_min_displs = *std::min_element(displs, displs+ep_size);
722    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
723
724
725    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
726    {
727      innode_memcpy(0, master_recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
728      if(ep_rank_loc == 0) delete[] master_recvbuf;
729    }
730
731
732
733    if(ep_rank_loc==0)
734    {
735      if(datatype == MPI_INT)
736      {
737        delete[] static_cast<int*>(local_gather_recvbuf);
738      }
739      else if(datatype == MPI_FLOAT)
740      {
741        delete[] static_cast<float*>(local_gather_recvbuf);
742      }
743      else if(datatype == MPI_DOUBLE)
744      {
745        delete[] static_cast<double*>(local_gather_recvbuf);
746      }
747      else if(datatype == MPI_LONG)
748      {
749        delete[] static_cast<long*>(local_gather_recvbuf);
750      }
751      else if(datatype == MPI_UNSIGNED_LONG)
752      {
753        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
754      }
755      else // if(datatype == MPI_CHAR)
756      {
757        delete[] static_cast<char*>(local_gather_recvbuf);
758      }
759    }
760    else
761    {
762      delete[] recvcounts;
763      delete[] displs;
764    }
765    return 0;
766  }
767
768
769
770  int MPI_Allgatherv2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
771                  MPI_Datatype recvtype, MPI_Comm comm)
772  {
773
774    if(!comm.is_ep && comm.mpi_comm)
775    {
776      ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs,
777                       static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
778      return 0;
779    }
780
781    if(!comm.mpi_comm) return 0;
782
783    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
784
785
786    MPI_Datatype datatype = sendtype;
787    int count = sendcount;
788
789    int ep_rank, ep_rank_loc, mpi_rank;
790    int ep_size, num_ep, mpi_size;
791
792    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
793    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
794    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
795    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
796    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
797    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
798   
799    if(ep_size == mpi_size)  // needed by servers
800      return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
801                              static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
802
803    int recv_plus_displs[ep_size];
804    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
805
806
807    for(int j=0; j<mpi_size; j++)
808    {
809      if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] ||
810         recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2]) 
811      { 
812        printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size);
813        for(int k=0; k<ep_size; k++)
814          printf("recv_plus_displs[%d] = %d\t displs[%d] = %d\n", k, recv_plus_displs[k], k, displs[k]);
815
816        return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
817      }
818
819      for(int i=1; i<num_ep-1; i++)
820      {
821        if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] || 
822           recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1])
823        {
824          printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size);
825          return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
826        }
827      }
828    }
829
830    ::MPI_Aint datasize, lb;
831
832    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
833
834    void *local_gather_recvbuf;
835    int buffer_size;
836
837    if(ep_rank_loc==0)
838    {
839      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
840
841      local_gather_recvbuf = new void*[datasize*buffer_size];
842    }
843
844    // local gather to master
845    MPI_Gatherv_local2(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
846
847    //MPI_Gather
848    if(ep_rank_loc == 0)
849    {
850      int *mpi_recvcnt= new int[mpi_size];
851      int *mpi_displs= new int[mpi_size];
852
853      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
854      int buff_end = buffer_size;
855
856      int mpi_sendcnt = buff_end - buff_start;
857
858
859      ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT, mpi_recvcnt, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm));
860      ::MPI_Allgather(&buff_start,  1, MPI_INT, mpi_displs,  1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm));
861
862
863      ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
864                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
865
866      delete[] mpi_recvcnt;
867      delete[] mpi_displs;
868    }
869
870    int global_min_displs = *std::min_element(displs, displs+ep_size);
871    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
872
873    MPI_Bcast_local2(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
874
875    if(ep_rank_loc==0)
876    {
877      if(datatype == MPI_INT)
878      {
879        delete[] static_cast<int*>(local_gather_recvbuf);
880      }
881      else if(datatype == MPI_FLOAT)
882      {
883        delete[] static_cast<float*>(local_gather_recvbuf);
884      }
885      else if(datatype == MPI_DOUBLE)
886      {
887        delete[] static_cast<double*>(local_gather_recvbuf);
888      }
889      else if(datatype == MPI_LONG)
890      {
891        delete[] static_cast<long*>(local_gather_recvbuf);
892      }
893      else if(datatype == MPI_UNSIGNED_LONG)
894      {
895        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
896      }
897      else // if(datatype == MPI_CHAR)
898      {
899        delete[] static_cast<char*>(local_gather_recvbuf);
900      }
901    }
902  }
903
904  int MPI_Gatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
905                          MPI_Datatype recvtype, int root, MPI_Comm comm)
906  {
907    int ep_rank, ep_rank_loc, mpi_rank;
908    int ep_size, num_ep, mpi_size;
909
910    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
911    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
912    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
913    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
914    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
915    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
916
917    int root_mpi_rank = comm.rank_map->at(root).second;
918    int root_ep_loc = comm.rank_map->at(root).first;
919
920    ::MPI_Aint datasize, lb;
921    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize);
922
923    void *local_gather_recvbuf;
924    int buffer_size;
925
926    int *local_displs = new int[num_ep];
927    int *local_rvcnts = new int[num_ep];
928    for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i];
929    local_displs[0] = 0;
930    for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1];
931
932    if(ep_rank_loc==0)
933    {
934      buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1];
935      local_gather_recvbuf = new void*[datasize*buffer_size];
936    }
937
938    // local gather to master
939    MPI_Gatherv_local2(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master
940
941    int **mpi_recvcnts = new int*[num_ep];
942    int **mpi_displs   = new int*[num_ep];
943    for(int i=0; i<num_ep; i++) 
944    {
945      mpi_recvcnts[i] = new int[mpi_size];
946      mpi_displs[i]   = new int[mpi_size];
947      for(int j=0; j<mpi_size; j++)
948      {
949        mpi_recvcnts[i][j] = recvcounts[j*num_ep + i];
950        mpi_displs[i][j]   = displs[j*num_ep + i];
951      }
952    } 
953
954    void *master_recvbuf;
955    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)];
956
957    if(ep_rank_loc == 0 && root_ep_loc == 0) // master in MPI_Allgatherv loop
958      for(int i=0; i<num_ep; i++)
959      {
960        ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i],
961                    static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
962      }
963    if(ep_rank_loc == 0 && root_ep_loc != 0)
964      for(int i=0; i<num_ep; i++)
965      {
966        ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), master_recvbuf, mpi_recvcnts[i], mpi_displs[i],
967                    static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
968      }
969
970
971    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
972    {
973      for(int i=0; i<ep_size; i++)
974        innode_memcpy(0, master_recvbuf + datasize*displs[i], root_ep_loc, recvbuf + datasize*displs[i], recvcounts[i], sendtype, comm);
975
976      if(ep_rank_loc == 0) delete[] master_recvbuf;
977    }
978
979   
980    delete[] local_displs;
981    delete[] local_rvcnts;
982    for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i]; 
983                                  delete[] mpi_displs[i]; }
984    delete[] mpi_recvcnts;
985    delete[] mpi_displs;
986    if(ep_rank_loc==0)
987    {
988      if(sendtype == MPI_INT)
989      {
990        delete[] static_cast<int*>(local_gather_recvbuf);
991      }
992      else if(sendtype == MPI_FLOAT)
993      {
994        delete[] static_cast<float*>(local_gather_recvbuf);
995      }
996      else if(sendtype == MPI_DOUBLE)
997      {
998        delete[] static_cast<double*>(local_gather_recvbuf);
999      }
1000      else if(sendtype == MPI_LONG)
1001      {
1002        delete[] static_cast<long*>(local_gather_recvbuf);
1003      }
1004      else if(sendtype == MPI_UNSIGNED_LONG)
1005      {
1006        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
1007      }
1008      else // if(sendtype == MPI_CHAR)
1009      {
1010        delete[] static_cast<char*>(local_gather_recvbuf);
1011      }
1012    }
1013  }
1014
1015  int MPI_Allgatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
1016                             MPI_Datatype recvtype, MPI_Comm comm)
1017  {
1018    int ep_rank, ep_rank_loc, mpi_rank;
1019    int ep_size, num_ep, mpi_size;
1020
1021    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
1022    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
1023    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
1024    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
1025    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
1026    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
1027
1028
1029    ::MPI_Aint datasize, lb;
1030    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize);
1031
1032    void *local_gather_recvbuf;
1033    int buffer_size;
1034
1035    int *local_displs = new int[num_ep];
1036    int *local_rvcnts = new int[num_ep];
1037    for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i];
1038    local_displs[0] = 0;
1039    for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1];
1040
1041    if(ep_rank_loc==0)
1042    {
1043      buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1];
1044      local_gather_recvbuf = new void*[datasize*buffer_size];
1045    }
1046
1047    // local gather to master
1048    MPI_Gatherv_local2(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master
1049
1050    int **mpi_recvcnts = new int*[num_ep];
1051    int **mpi_displs   = new int*[num_ep];
1052    for(int i=0; i<num_ep; i++) 
1053    {
1054      mpi_recvcnts[i] = new int[mpi_size];
1055      mpi_displs[i]   = new int[mpi_size];
1056      for(int j=0; j<mpi_size; j++)
1057      {
1058        mpi_recvcnts[i][j] = recvcounts[j*num_ep + i];
1059        mpi_displs[i][j]   = displs[j*num_ep + i];
1060      }
1061    } 
1062
1063    if(ep_rank_loc == 0) // master in MPI_Allgatherv loop
1064    for(int i=0; i<num_ep; i++)
1065    {
1066      ::MPI_Allgatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i],
1067                  static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
1068    }
1069
1070    for(int i=0; i<ep_size; i++)
1071      MPI_Bcast_local2(recvbuf + datasize*displs[i], recvcounts[i], recvtype, comm);
1072
1073   
1074    delete[] local_displs;
1075    delete[] local_rvcnts;
1076    for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i]; 
1077                                  delete[] mpi_displs[i]; }
1078    delete[] mpi_recvcnts;
1079    delete[] mpi_displs;
1080    if(ep_rank_loc==0)
1081    {
1082      if(sendtype == MPI_INT)
1083      {
1084        delete[] static_cast<int*>(local_gather_recvbuf);
1085      }
1086      else if(sendtype == MPI_FLOAT)
1087      {
1088        delete[] static_cast<float*>(local_gather_recvbuf);
1089      }
1090      else if(sendtype == MPI_DOUBLE)
1091      {
1092        delete[] static_cast<double*>(local_gather_recvbuf);
1093      }
1094      else if(sendtype == MPI_LONG)
1095      {
1096        delete[] static_cast<long*>(local_gather_recvbuf);
1097      }
1098      else if(sendtype == MPI_UNSIGNED_LONG)
1099      {
1100        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
1101      }
1102      else // if(sendtype == MPI_CHAR)
1103      {
1104        delete[] static_cast<char*>(local_gather_recvbuf);
1105      }
1106    }
1107  }
1108
1109
1110}
Note: See TracBrowser for help on using the repository browser.