source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gatherv.cpp @ 1151

Last change on this file since 1151 was 1151, checked in by yushan, 5 years ago

bug fixed in MPI_Gather(v)

File size: 18.4 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gatherv, MPI_Allgatherv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17
18  int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
19  {
20    if(datatype == MPI_INT)
21    {
22      Debug("datatype is INT\n");
23      return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm);
24    }
25    else if(datatype == MPI_FLOAT)
26    {
27      Debug("datatype is FLOAT\n");
28      return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm);
29    }
30    else if(datatype == MPI_DOUBLE)
31    {
32      Debug("datatype is DOUBLE\n");
33      return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      Debug("datatype is LONG\n");
38      return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm);
39    }
40    else if(datatype == MPI_UNSIGNED_LONG)
41    {
42      Debug("datatype is uLONG\n");
43      return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm);
44    }
45    else if(datatype == MPI_CHAR)
46    {
47      Debug("datatype is CHAR\n");
48      return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm);
49    }
50    else
51    {
52      printf("MPI_Gatherv Datatype not supported!\n");
53      exit(0);
54    }
55  }
56
57  int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
58  {
59    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
60    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    if(my_rank == 0)
67    {
68      assert(count == recvcounts[0]);
69      copy(send_buf, send_buf+count, recv_buf + displs[0]);
70    }
71
72    for(int j=0; j<count; j+=BUFFER_SIZE)
73    {
74      for(int k=1; k<num_ep; k++)
75      {
76        if(my_rank == k)
77        {
78          #pragma omp critical (write_to_buffer)
79          {
80            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
81            #pragma omp flush
82          }
83        }
84
85        MPI_Barrier_local(comm);
86
87        if(my_rank == 0)
88        {
89          #pragma omp flush
90          #pragma omp critical (read_from_buffer)
91          {
92            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
93          }
94        }
95
96        MPI_Barrier_local(comm);
97      }
98    }
99  }
100
101  int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
102  {
103    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
104    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
105
106    float *buffer = comm.my_buffer->buf_float;
107    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
108    float *recv_buf = static_cast<float*>(recvbuf);
109
110    if(my_rank == 0)
111    {
112      assert(count == recvcounts[0]);
113      copy(send_buf, send_buf+count, recv_buf + displs[0]);
114    }
115
116    for(int j=0; j<count; j+=BUFFER_SIZE)
117    {
118      for(int k=1; k<num_ep; k++)
119      {
120        if(my_rank == k)
121        {
122          #pragma omp critical (write_to_buffer)
123          {
124            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
125            #pragma omp flush
126          }
127        }
128
129        MPI_Barrier_local(comm);
130
131        if(my_rank == 0)
132        {
133          #pragma omp flush
134          #pragma omp critical (read_from_buffer)
135          {
136            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
137          }
138        }
139
140        MPI_Barrier_local(comm);
141      }
142    }
143  }
144
145  int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
146  {
147    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
148    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
149
150    double *buffer = comm.my_buffer->buf_double;
151    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
152    double *recv_buf = static_cast<double*>(recvbuf);
153
154    if(my_rank == 0)
155    {
156      assert(count == recvcounts[0]);
157      copy(send_buf, send_buf+count, recv_buf + displs[0]);
158    }
159
160    for(int j=0; j<count; j+=BUFFER_SIZE)
161    {
162      for(int k=1; k<num_ep; k++)
163      {
164        if(my_rank == k)
165        {
166          #pragma omp critical (write_to_buffer)
167          {
168            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
169            #pragma omp flush
170          }
171        }
172
173        MPI_Barrier_local(comm);
174
175        if(my_rank == 0)
176        {
177          #pragma omp flush
178          #pragma omp critical (read_from_buffer)
179          {
180            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
181          }
182        }
183
184        MPI_Barrier_local(comm);
185      }
186    }
187  }
188
189  int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
190  {
191    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
192    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
193
194    long *buffer = comm.my_buffer->buf_long;
195    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
196    long *recv_buf = static_cast<long*>(recvbuf);
197
198    if(my_rank == 0)
199    {
200      assert(count == recvcounts[0]);
201      copy(send_buf, send_buf+count, recv_buf + displs[0]);
202    }
203
204    for(int j=0; j<count; j+=BUFFER_SIZE)
205    {
206      for(int k=1; k<num_ep; k++)
207      {
208        if(my_rank == k)
209        {
210          #pragma omp critical (write_to_buffer)
211          {
212            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
213            #pragma omp flush
214          }
215        }
216
217        MPI_Barrier_local(comm);
218
219        if(my_rank == 0)
220        {
221          #pragma omp flush
222          #pragma omp critical (read_from_buffer)
223          {
224            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
225          }
226        }
227
228        MPI_Barrier_local(comm);
229      }
230    }
231  }
232
233  int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
234  {
235    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
236    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
237
238    unsigned long *buffer = comm.my_buffer->buf_ulong;
239    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
240    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
241
242    if(my_rank == 0)
243    {
244      assert(count == recvcounts[0]);
245      copy(send_buf, send_buf+count, recv_buf + displs[0]);
246    }
247
248    for(int j=0; j<count; j+=BUFFER_SIZE)
249    {
250      for(int k=1; k<num_ep; k++)
251      {
252        if(my_rank == k)
253        {
254          #pragma omp critical (write_to_buffer)
255          {
256            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
257            #pragma omp flush
258          }
259        }
260
261        MPI_Barrier_local(comm);
262
263        if(my_rank == 0)
264        {
265          #pragma omp flush
266          #pragma omp critical (read_from_buffer)
267          {
268            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
269          }
270        }
271
272        MPI_Barrier_local(comm);
273      }
274    }
275  }
276
277  int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
278  {
279    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
280    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
281
282    char *buffer = comm.my_buffer->buf_char;
283    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
284    char *recv_buf = static_cast<char*>(recvbuf);
285
286    if(my_rank == 0)
287    {
288      assert(count == recvcounts[0]);
289      copy(send_buf, send_buf+count, recv_buf + displs[0]);
290    }
291
292    for(int j=0; j<count; j+=BUFFER_SIZE)
293    {
294      for(int k=1; k<num_ep; k++)
295      {
296        if(my_rank == k)
297        {
298          #pragma omp critical (write_to_buffer)
299          {
300            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
301            #pragma omp flush
302          }
303        }
304
305        MPI_Barrier_local(comm);
306
307        if(my_rank == 0)
308        {
309          #pragma omp flush
310          #pragma omp critical (read_from_buffer)
311          {
312            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
313          }
314        }
315
316        MPI_Barrier_local(comm);
317      }
318    }
319  }
320
321
322  int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
323                  MPI_Datatype recvtype, int root, MPI_Comm comm)
324  {
325 
326    if(!comm.is_ep && comm.mpi_comm)
327    {
328      ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs),
329                    static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
330      return 0;
331    }
332
333    if(!comm.mpi_comm) return 0;
334
335    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
336
337    MPI_Datatype datatype = sendtype;
338    int count = sendcount;
339
340    int ep_rank, ep_rank_loc, mpi_rank;
341    int ep_size, num_ep, mpi_size;
342
343    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
344    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
345    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
346    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
347    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
348    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
349   
350   
351   
352    if(ep_size == mpi_size) 
353      return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
354                              static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
355
356    if(ep_rank != root)
357    {
358      recvcounts = new int[ep_size];
359      displs = new int[ep_size];
360    }
361   
362    MPI_Bcast(const_cast< int* >(displs),     ep_size, MPI_INT, root, comm);
363    MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm);
364                             
365
366    int recv_plus_displs[ep_size];
367    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
368   
369    #pragma omp single nowait
370    {
371      assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]);
372      for(int i=1; i<num_ep-1; i++)
373      {
374        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]);
375        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]);
376      }
377      assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]);
378    }
379
380
381    int root_mpi_rank = comm.rank_map->at(root).second;
382    int root_ep_loc = comm.rank_map->at(root).first;
383
384
385    ::MPI_Aint datasize, lb;
386
387    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
388
389    void *local_gather_recvbuf;
390    int buffer_size;
391    void *master_recvbuf;
392
393    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)];
394
395    if(ep_rank_loc==0)
396    {
397      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
398
399      local_gather_recvbuf = new void*[datasize*buffer_size];
400    }
401
402    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
403
404    //MPI_Gather
405    if(ep_rank_loc == 0)
406    {
407      int *mpi_recvcnt= new int[mpi_size];
408      int *mpi_displs= new int[mpi_size];
409
410      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
411      int buff_end = buffer_size;
412
413      int mpi_sendcnt = buff_end - buff_start;
414
415
416      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
417      ::MPI_Gather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
418
419      if(root_ep_loc == 0)
420      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
421                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
422      }
423      else  // gatherv to master_recvbuf
424      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, mpi_recvcnt,
425                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
426      }
427
428      delete[] mpi_recvcnt;
429      delete[] mpi_displs;
430    }
431
432    int global_min_displs = *std::min_element(displs, displs+ep_size);
433    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
434
435
436    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
437    {
438      innode_memcpy(0, master_recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
439      if(ep_rank_loc == 0) delete[] master_recvbuf;
440    }
441
442
443
444    if(ep_rank_loc==0)
445    {
446      if(datatype == MPI_INT)
447      {
448        delete[] static_cast<int*>(local_gather_recvbuf);
449      }
450      else if(datatype == MPI_FLOAT)
451      {
452        delete[] static_cast<float*>(local_gather_recvbuf);
453      }
454      else if(datatype == MPI_DOUBLE)
455      {
456        delete[] static_cast<double*>(local_gather_recvbuf);
457      }
458      else if(datatype == MPI_LONG)
459      {
460        delete[] static_cast<long*>(local_gather_recvbuf);
461      }
462      else if(datatype == MPI_UNSIGNED_LONG)
463      {
464        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
465      }
466      else // if(datatype == MPI_CHAR)
467      {
468        delete[] static_cast<char*>(local_gather_recvbuf);
469      }
470    }
471    else
472    {
473      delete[] recvcounts;
474      delete[] displs;
475    }
476    return 0;
477  }
478
479
480
481  int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
482                  MPI_Datatype recvtype, MPI_Comm comm)
483  {
484
485    if(!comm.is_ep && comm.mpi_comm)
486    {
487      ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs,
488                       static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
489      return 0;
490    }
491
492    if(!comm.mpi_comm) return 0;
493
494    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
495
496
497    MPI_Datatype datatype = sendtype;
498    int count = sendcount;
499
500    int ep_rank, ep_rank_loc, mpi_rank;
501    int ep_size, num_ep, mpi_size;
502
503    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
504    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
505    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
506    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
507    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
508    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
509   
510    if(ep_size == mpi_size) 
511      return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
512                              static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
513   
514
515    int recv_plus_displs[ep_size];
516    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
517
518    #pragma omp single nowait
519    {
520      assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]);
521      for(int i=1; i<num_ep-1; i++)
522      {
523        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]);
524        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]);
525      }
526      assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]);
527    }
528
529
530    ::MPI_Aint datasize, lb;
531
532    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
533
534    void *local_gather_recvbuf;
535    int buffer_size;
536
537    if(ep_rank_loc==0)
538    {
539      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
540
541      local_gather_recvbuf = new void*[datasize*buffer_size];
542    }
543
544    // local gather to master
545    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
546
547    //MPI_Gather
548    if(ep_rank_loc == 0)
549    {
550      int *mpi_recvcnt= new int[mpi_size];
551      int *mpi_displs= new int[mpi_size];
552
553      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
554      int buff_end = buffer_size;
555
556      int mpi_sendcnt = buff_end - buff_start;
557
558
559      ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
560      ::MPI_Allgather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
561
562
563      ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
564                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
565
566      delete[] mpi_recvcnt;
567      delete[] mpi_displs;
568    }
569
570    int global_min_displs = *std::min_element(displs, displs+ep_size);
571    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
572
573    MPI_Bcast_local(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
574
575    if(ep_rank_loc==0)
576    {
577      if(datatype == MPI_INT)
578      {
579        delete[] static_cast<int*>(local_gather_recvbuf);
580      }
581      else if(datatype == MPI_FLOAT)
582      {
583        delete[] static_cast<float*>(local_gather_recvbuf);
584      }
585      else if(datatype == MPI_DOUBLE)
586      {
587        delete[] static_cast<double*>(local_gather_recvbuf);
588      }
589      else if(datatype == MPI_LONG)
590      {
591        delete[] static_cast<long*>(local_gather_recvbuf);
592      }
593      else if(datatype == MPI_UNSIGNED_LONG)
594      {
595        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
596      }
597      else // if(datatype == MPI_CHAR)
598      {
599        delete[] static_cast<char*>(local_gather_recvbuf);
600      }
601    }
602  }
603
604
605}
Note: See TracBrowser for help on using the repository browser.