source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gatherv.cpp @ 1147

Last change on this file since 1147 was 1147, checked in by yushan, 7 years ago

bug fixed in MPI_Gather(v)

File size: 18.3 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gatherv, MPI_Allgatherv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17
18  int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
19  {
20    if(datatype == MPI_INT)
21    {
22      Debug("datatype is INT\n");
23      return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm);
24    }
25    else if(datatype == MPI_FLOAT)
26    {
27      Debug("datatype is FLOAT\n");
28      return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm);
29    }
30    else if(datatype == MPI_DOUBLE)
31    {
32      Debug("datatype is DOUBLE\n");
33      return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      Debug("datatype is LONG\n");
38      return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm);
39    }
40    else if(datatype == MPI_UNSIGNED_LONG)
41    {
42      Debug("datatype is uLONG\n");
43      return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm);
44    }
45    else if(datatype == MPI_CHAR)
46    {
47      Debug("datatype is CHAR\n");
48      return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm);
49    }
50    else
51    {
52      printf("MPI_Gatherv Datatype not supported!\n");
53      exit(0);
54    }
55  }
56
57  int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
58  {
59    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
60    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    if(my_rank == 0)
67    {
68      assert(count == recvcounts[0]);
69      copy(send_buf, send_buf+count, recv_buf + displs[0]);
70    }
71
72    for(int j=0; j<count; j+=BUFFER_SIZE)
73    {
74      for(int k=1; k<num_ep; k++)
75      {
76        if(my_rank == k)
77        {
78          #pragma omp critical (write_to_buffer)
79          {
80            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
81            #pragma omp flush
82          }
83        }
84
85        MPI_Barrier_local(comm);
86
87        if(my_rank == 0)
88        {
89          #pragma omp flush
90          #pragma omp critical (read_from_buffer)
91          {
92            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
93          }
94        }
95
96        MPI_Barrier_local(comm);
97      }
98    }
99  }
100
101  int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
102  {
103    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
104    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
105
106    float *buffer = comm.my_buffer->buf_float;
107    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
108    float *recv_buf = static_cast<float*>(recvbuf);
109
110    if(my_rank == 0)
111    {
112      assert(count == recvcounts[0]);
113      copy(send_buf, send_buf+count, recv_buf + displs[0]);
114    }
115
116    for(int j=0; j<count; j+=BUFFER_SIZE)
117    {
118      for(int k=1; k<num_ep; k++)
119      {
120        if(my_rank == k)
121        {
122          #pragma omp critical (write_to_buffer)
123          {
124            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
125            #pragma omp flush
126          }
127        }
128
129        MPI_Barrier_local(comm);
130
131        if(my_rank == 0)
132        {
133          #pragma omp flush
134          #pragma omp critical (read_from_buffer)
135          {
136            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
137          }
138        }
139
140        MPI_Barrier_local(comm);
141      }
142    }
143  }
144
145  int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
146  {
147    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
148    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
149
150    double *buffer = comm.my_buffer->buf_double;
151    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
152    double *recv_buf = static_cast<double*>(recvbuf);
153
154    if(my_rank == 0)
155    {
156      assert(count == recvcounts[0]);
157      copy(send_buf, send_buf+count, recv_buf + displs[0]);
158    }
159
160    for(int j=0; j<count; j+=BUFFER_SIZE)
161    {
162      for(int k=1; k<num_ep; k++)
163      {
164        if(my_rank == k)
165        {
166          #pragma omp critical (write_to_buffer)
167          {
168            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
169            #pragma omp flush
170          }
171        }
172
173        MPI_Barrier_local(comm);
174
175        if(my_rank == 0)
176        {
177          #pragma omp flush
178          #pragma omp critical (read_from_buffer)
179          {
180            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
181          }
182        }
183
184        MPI_Barrier_local(comm);
185      }
186    }
187  }
188
189  int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
190  {
191    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
192    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
193
194    long *buffer = comm.my_buffer->buf_long;
195    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
196    long *recv_buf = static_cast<long*>(recvbuf);
197
198    if(my_rank == 0)
199    {
200      assert(count == recvcounts[0]);
201      copy(send_buf, send_buf+count, recv_buf + displs[0]);
202    }
203
204    for(int j=0; j<count; j+=BUFFER_SIZE)
205    {
206      for(int k=1; k<num_ep; k++)
207      {
208        if(my_rank == k)
209        {
210          #pragma omp critical (write_to_buffer)
211          {
212            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
213            #pragma omp flush
214          }
215        }
216
217        MPI_Barrier_local(comm);
218
219        if(my_rank == 0)
220        {
221          #pragma omp flush
222          #pragma omp critical (read_from_buffer)
223          {
224            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
225          }
226        }
227
228        MPI_Barrier_local(comm);
229      }
230    }
231  }
232
233  int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
234  {
235    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
236    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
237
238    unsigned long *buffer = comm.my_buffer->buf_ulong;
239    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
240    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
241
242    if(my_rank == 0)
243    {
244      assert(count == recvcounts[0]);
245      copy(send_buf, send_buf+count, recv_buf + displs[0]);
246    }
247
248    for(int j=0; j<count; j+=BUFFER_SIZE)
249    {
250      for(int k=1; k<num_ep; k++)
251      {
252        if(my_rank == k)
253        {
254          #pragma omp critical (write_to_buffer)
255          {
256            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
257            #pragma omp flush
258          }
259        }
260
261        MPI_Barrier_local(comm);
262
263        if(my_rank == 0)
264        {
265          #pragma omp flush
266          #pragma omp critical (read_from_buffer)
267          {
268            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
269          }
270        }
271
272        MPI_Barrier_local(comm);
273      }
274    }
275  }
276
277  int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
278  {
279    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
280    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
281
282    char *buffer = comm.my_buffer->buf_char;
283    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
284    char *recv_buf = static_cast<char*>(recvbuf);
285
286    if(my_rank == 0)
287    {
288      assert(count == recvcounts[0]);
289      copy(send_buf, send_buf+count, recv_buf + displs[0]);
290    }
291
292    for(int j=0; j<count; j+=BUFFER_SIZE)
293    {
294      for(int k=1; k<num_ep; k++)
295      {
296        if(my_rank == k)
297        {
298          #pragma omp critical (write_to_buffer)
299          {
300            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
301            #pragma omp flush
302          }
303        }
304
305        MPI_Barrier_local(comm);
306
307        if(my_rank == 0)
308        {
309          #pragma omp flush
310          #pragma omp critical (read_from_buffer)
311          {
312            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
313          }
314        }
315
316        MPI_Barrier_local(comm);
317      }
318    }
319  }
320
321
322  int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
323                  MPI_Datatype recvtype, int root, MPI_Comm comm)
324  {
325 
326    if(!comm.is_ep && comm.mpi_comm)
327    {
328      ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs),
329                    static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
330      return 0;
331    }
332
333    if(!comm.mpi_comm) return 0;
334
335    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
336
337    MPI_Datatype datatype = sendtype;
338    int count = sendcount;
339
340    int ep_rank, ep_rank_loc, mpi_rank;
341    int ep_size, num_ep, mpi_size;
342
343    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
344    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
345    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
346    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
347    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
348    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
349   
350    if(ep_size == mpi_size) 
351      return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
352                              static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
353
354    int recv_plus_displs[ep_size];
355    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
356   
357    #pragma omp single nowait
358    {
359      assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]);
360      for(int i=1; i<num_ep-1; i++)
361      {
362        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]);
363        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]);
364      }
365      assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]);
366    }
367
368    if(ep_rank != root)
369    {
370      recvcounts = new int[ep_size];
371      displs = new int[ep_size];
372    }
373   
374    MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm);
375    MPI_Bcast(const_cast< int* >(displs), ep_size, MPI_INT, root, comm);
376
377
378    int root_mpi_rank = comm.rank_map->at(root).second;
379    int root_ep_loc = comm.rank_map->at(root).first;
380
381
382    ::MPI_Aint datasize, lb;
383
384    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
385
386    void *local_gather_recvbuf;
387    int buffer_size;
388    void *master_recvbuf;
389
390    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)];
391
392    if(ep_rank_loc==0)
393    {
394      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
395
396      local_gather_recvbuf = new void*[datasize*buffer_size];
397    }
398
399    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
400
401    //MPI_Gather
402    if(ep_rank_loc == 0)
403    {
404      int *mpi_recvcnt= new int[mpi_size];
405      int *mpi_displs= new int[mpi_size];
406
407      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
408      int buff_end = buffer_size;
409
410      int mpi_sendcnt = buff_end - buff_start;
411
412
413      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
414      ::MPI_Gather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
415
416      if(root_ep_loc == 0)
417      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
418                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
419      }
420      else  // gatherv to master_recvbuf
421      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, mpi_recvcnt,
422                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
423      }
424
425      delete[] mpi_recvcnt;
426      delete[] mpi_displs;
427    }
428
429    int global_min_displs = *std::min_element(displs, displs+ep_size);
430    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
431
432
433    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
434    {
435      innode_memcpy(0, master_recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
436      if(ep_rank_loc == 0) delete[] master_recvbuf;
437    }
438
439
440
441    if(ep_rank_loc==0)
442    {
443      if(datatype == MPI_INT)
444      {
445        delete[] static_cast<int*>(local_gather_recvbuf);
446      }
447      else if(datatype == MPI_FLOAT)
448      {
449        delete[] static_cast<float*>(local_gather_recvbuf);
450      }
451      else if(datatype == MPI_DOUBLE)
452      {
453        delete[] static_cast<double*>(local_gather_recvbuf);
454      }
455      else if(datatype == MPI_LONG)
456      {
457        delete[] static_cast<long*>(local_gather_recvbuf);
458      }
459      else if(datatype == MPI_UNSIGNED_LONG)
460      {
461        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
462      }
463      else // if(datatype == MPI_CHAR)
464      {
465        delete[] static_cast<char*>(local_gather_recvbuf);
466      }
467    }
468    else
469    {
470      delete[] recvcounts;
471      delete[] displs;
472    }
473    return 0;
474  }
475
476
477
478  int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
479                  MPI_Datatype recvtype, MPI_Comm comm)
480  {
481
482    if(!comm.is_ep && comm.mpi_comm)
483    {
484      ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs,
485                       static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
486      return 0;
487    }
488
489    if(!comm.mpi_comm) return 0;
490
491    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
492
493
494    MPI_Datatype datatype = sendtype;
495    int count = sendcount;
496
497    int ep_rank, ep_rank_loc, mpi_rank;
498    int ep_size, num_ep, mpi_size;
499
500    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
501    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
502    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
503    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
504    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
505    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
506   
507    if(ep_size == mpi_size) 
508      return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
509                              static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
510   
511
512    int recv_plus_displs[ep_size];
513    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
514
515    #pragma omp single nowait
516    {
517      assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]);
518      for(int i=1; i<num_ep-1; i++)
519      {
520        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]);
521        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]);
522      }
523      assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]);
524    }
525
526
527    ::MPI_Aint datasize, lb;
528
529    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
530
531    void *local_gather_recvbuf;
532    int buffer_size;
533
534    if(ep_rank_loc==0)
535    {
536      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
537
538      local_gather_recvbuf = new void*[datasize*buffer_size];
539    }
540
541    // local gather to master
542    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
543
544    //MPI_Gather
545    if(ep_rank_loc == 0)
546    {
547      int *mpi_recvcnt= new int[mpi_size];
548      int *mpi_displs= new int[mpi_size];
549
550      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
551      int buff_end = buffer_size;
552
553      int mpi_sendcnt = buff_end - buff_start;
554
555
556      ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
557      ::MPI_Allgather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
558
559
560      ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
561                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
562
563      delete[] mpi_recvcnt;
564      delete[] mpi_displs;
565    }
566
567    int global_min_displs = *std::min_element(displs, displs+ep_size);
568    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
569
570    MPI_Bcast_local(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
571
572    if(ep_rank_loc==0)
573    {
574      if(datatype == MPI_INT)
575      {
576        delete[] static_cast<int*>(local_gather_recvbuf);
577      }
578      else if(datatype == MPI_FLOAT)
579      {
580        delete[] static_cast<float*>(local_gather_recvbuf);
581      }
582      else if(datatype == MPI_DOUBLE)
583      {
584        delete[] static_cast<double*>(local_gather_recvbuf);
585      }
586      else if(datatype == MPI_LONG)
587      {
588        delete[] static_cast<long*>(local_gather_recvbuf);
589      }
590      else if(datatype == MPI_UNSIGNED_LONG)
591      {
592        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
593      }
594      else // if(datatype == MPI_CHAR)
595      {
596        delete[] static_cast<char*>(local_gather_recvbuf);
597      }
598    }
599  }
600
601
602}
Note: See TracBrowser for help on using the repository browser.