source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gatherv.cpp @ 1145

Last change on this file since 1145 was 1145, checked in by yushan, 7 years ago

bug fixed in MPI_(All)Gatherv

File size: 17.8 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gatherv, MPI_Allgatherv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17
18  int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
19  {
20    if(datatype == MPI_INT)
21    {
22      Debug("datatype is INT\n");
23      return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm);
24    }
25    else if(datatype == MPI_FLOAT)
26    {
27      Debug("datatype is FLOAT\n");
28      return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm);
29    }
30    else if(datatype == MPI_DOUBLE)
31    {
32      Debug("datatype is DOUBLE\n");
33      return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      Debug("datatype is LONG\n");
38      return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm);
39    }
40    else if(datatype == MPI_UNSIGNED_LONG)
41    {
42      Debug("datatype is uLONG\n");
43      return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm);
44    }
45    else if(datatype == MPI_CHAR)
46    {
47      Debug("datatype is CHAR\n");
48      return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm);
49    }
50    else
51    {
52      printf("MPI_Gatherv Datatype not supported!\n");
53      exit(0);
54    }
55  }
56
57  int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
58  {
59    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
60    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    if(my_rank == 0)
67    {
68      assert(count == recvcounts[0]);
69      copy(send_buf, send_buf+count, recv_buf + displs[0]);
70    }
71
72    for(int j=0; j<count; j+=BUFFER_SIZE)
73    {
74      for(int k=1; k<num_ep; k++)
75      {
76        if(my_rank == k)
77        {
78          #pragma omp critical (write_to_buffer)
79          {
80            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
81            #pragma omp flush
82          }
83        }
84
85        MPI_Barrier_local(comm);
86
87        if(my_rank == 0)
88        {
89          #pragma omp flush
90          #pragma omp critical (read_from_buffer)
91          {
92            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
93          }
94        }
95
96        MPI_Barrier_local(comm);
97      }
98    }
99  }
100
101  int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
102  {
103    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
104    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
105
106    float *buffer = comm.my_buffer->buf_float;
107    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
108    float *recv_buf = static_cast<float*>(recvbuf);
109
110    if(my_rank == 0)
111    {
112      assert(count == recvcounts[0]);
113      copy(send_buf, send_buf+count, recv_buf + displs[0]);
114    }
115
116    for(int j=0; j<count; j+=BUFFER_SIZE)
117    {
118      for(int k=1; k<num_ep; k++)
119      {
120        if(my_rank == k)
121        {
122          #pragma omp critical (write_to_buffer)
123          {
124            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
125            #pragma omp flush
126          }
127        }
128
129        MPI_Barrier_local(comm);
130
131        if(my_rank == 0)
132        {
133          #pragma omp flush
134          #pragma omp critical (read_from_buffer)
135          {
136            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
137          }
138        }
139
140        MPI_Barrier_local(comm);
141      }
142    }
143  }
144
145  int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
146  {
147    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
148    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
149
150    double *buffer = comm.my_buffer->buf_double;
151    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
152    double *recv_buf = static_cast<double*>(recvbuf);
153
154    if(my_rank == 0)
155    {
156      assert(count == recvcounts[0]);
157      copy(send_buf, send_buf+count, recv_buf + displs[0]);
158    }
159
160    for(int j=0; j<count; j+=BUFFER_SIZE)
161    {
162      for(int k=1; k<num_ep; k++)
163      {
164        if(my_rank == k)
165        {
166          #pragma omp critical (write_to_buffer)
167          {
168            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
169            #pragma omp flush
170          }
171        }
172
173        MPI_Barrier_local(comm);
174
175        if(my_rank == 0)
176        {
177          #pragma omp flush
178          #pragma omp critical (read_from_buffer)
179          {
180            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
181          }
182        }
183
184        MPI_Barrier_local(comm);
185      }
186    }
187  }
188
189  int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
190  {
191    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
192    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
193
194    long *buffer = comm.my_buffer->buf_long;
195    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
196    long *recv_buf = static_cast<long*>(recvbuf);
197
198    if(my_rank == 0)
199    {
200      assert(count == recvcounts[0]);
201      copy(send_buf, send_buf+count, recv_buf + displs[0]);
202    }
203
204    for(int j=0; j<count; j+=BUFFER_SIZE)
205    {
206      for(int k=1; k<num_ep; k++)
207      {
208        if(my_rank == k)
209        {
210          #pragma omp critical (write_to_buffer)
211          {
212            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
213            #pragma omp flush
214          }
215        }
216
217        MPI_Barrier_local(comm);
218
219        if(my_rank == 0)
220        {
221          #pragma omp flush
222          #pragma omp critical (read_from_buffer)
223          {
224            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
225          }
226        }
227
228        MPI_Barrier_local(comm);
229      }
230    }
231  }
232
233  int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
234  {
235    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
236    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
237
238    unsigned long *buffer = comm.my_buffer->buf_ulong;
239    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
240    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
241
242    if(my_rank == 0)
243    {
244      assert(count == recvcounts[0]);
245      copy(send_buf, send_buf+count, recv_buf + displs[0]);
246    }
247
248    for(int j=0; j<count; j+=BUFFER_SIZE)
249    {
250      for(int k=1; k<num_ep; k++)
251      {
252        if(my_rank == k)
253        {
254          #pragma omp critical (write_to_buffer)
255          {
256            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
257            #pragma omp flush
258          }
259        }
260
261        MPI_Barrier_local(comm);
262
263        if(my_rank == 0)
264        {
265          #pragma omp flush
266          #pragma omp critical (read_from_buffer)
267          {
268            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
269          }
270        }
271
272        MPI_Barrier_local(comm);
273      }
274    }
275  }
276
277  int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
278  {
279    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
280    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
281
282    char *buffer = comm.my_buffer->buf_char;
283    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
284    char *recv_buf = static_cast<char*>(recvbuf);
285
286    if(my_rank == 0)
287    {
288      assert(count == recvcounts[0]);
289      copy(send_buf, send_buf+count, recv_buf + displs[0]);
290    }
291
292    for(int j=0; j<count; j+=BUFFER_SIZE)
293    {
294      for(int k=1; k<num_ep; k++)
295      {
296        if(my_rank == k)
297        {
298          #pragma omp critical (write_to_buffer)
299          {
300            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
301            #pragma omp flush
302          }
303        }
304
305        MPI_Barrier_local(comm);
306
307        if(my_rank == 0)
308        {
309          #pragma omp flush
310          #pragma omp critical (read_from_buffer)
311          {
312            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
313          }
314        }
315
316        MPI_Barrier_local(comm);
317      }
318    }
319  }
320
321
322  int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
323                  MPI_Datatype recvtype, int root, MPI_Comm comm)
324  {
325 
326    if(!comm.is_ep && comm.mpi_comm)
327    {
328      ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs),
329                    static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
330      return 0;
331    }
332
333    if(!comm.mpi_comm) return 0;
334
335    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
336
337    MPI_Datatype datatype = sendtype;
338    int count = sendcount;
339
340    int ep_rank, ep_rank_loc, mpi_rank;
341    int ep_size, num_ep, mpi_size;
342
343    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
344    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
345    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
346    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
347    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
348    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
349   
350    if(ep_size == mpi_size) 
351      return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
352                              static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
353
354    int recv_plus_displs[ep_size];
355    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
356   
357    #pragma omp single nowait
358    {
359      assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]);
360      for(int i=1; i<num_ep-1; i++)
361      {
362        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]);
363        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]);
364      }
365      assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]);
366    }
367
368    if(ep_rank != root)
369    {
370      recvcounts = new int[ep_size];
371      displs = new int[ep_size];
372    }
373   
374    MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm);
375    MPI_Bcast(const_cast< int* >(displs), ep_size, MPI_INT, root, comm);
376
377
378    int root_mpi_rank = comm.rank_map->at(root).second;
379    int root_ep_loc = comm.rank_map->at(root).first;
380
381
382    ::MPI_Aint datasize, lb;
383
384    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
385
386    void *local_gather_recvbuf;
387    int buffer_size;
388
389    if(ep_rank_loc==0)
390    {
391      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
392
393      local_gather_recvbuf = new void*[datasize*buffer_size];
394    }
395
396    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
397
398    //MPI_Gather
399    if(ep_rank_loc == 0)
400    {
401      int *mpi_recvcnt= new int[mpi_size];
402      int *mpi_displs= new int[mpi_size];
403
404      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
405      int buff_end = buffer_size;
406
407      int mpi_sendcnt = buff_end - buff_start;
408
409
410      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
411      ::MPI_Gather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
412
413
414      ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
415                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
416
417      delete[] mpi_recvcnt;
418      delete[] mpi_displs;
419    }
420
421    int global_min_displs = *std::min_element(displs, displs+ep_size);
422    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
423
424
425    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
426    {
427      innode_memcpy(0, recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
428    }
429
430
431
432    if(ep_rank_loc==0)
433    {
434      if(datatype == MPI_INT)
435      {
436        delete[] static_cast<int*>(local_gather_recvbuf);
437      }
438      else if(datatype == MPI_FLOAT)
439      {
440        delete[] static_cast<float*>(local_gather_recvbuf);
441      }
442      else if(datatype == MPI_DOUBLE)
443      {
444        delete[] static_cast<double*>(local_gather_recvbuf);
445      }
446      else if(datatype == MPI_LONG)
447      {
448        delete[] static_cast<long*>(local_gather_recvbuf);
449      }
450      else if(datatype == MPI_UNSIGNED_LONG)
451      {
452        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
453      }
454      else // if(datatype == MPI_CHAR)
455      {
456        delete[] static_cast<char*>(local_gather_recvbuf);
457      }
458    }
459    else
460    {
461      delete[] recvcounts;
462      delete[] displs;
463    }
464    return 0;
465  }
466
467
468
469  int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
470                  MPI_Datatype recvtype, MPI_Comm comm)
471  {
472
473    if(!comm.is_ep && comm.mpi_comm)
474    {
475      ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs,
476                       static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
477      return 0;
478    }
479
480    if(!comm.mpi_comm) return 0;
481
482    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
483
484
485    MPI_Datatype datatype = sendtype;
486    int count = sendcount;
487
488    int ep_rank, ep_rank_loc, mpi_rank;
489    int ep_size, num_ep, mpi_size;
490
491    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
492    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
493    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
494    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
495    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
496    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
497   
498    if(ep_size == mpi_size) 
499      return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
500                              static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
501   
502
503    int recv_plus_displs[ep_size];
504    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
505
506    #pragma omp single nowait
507    {
508      assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]);
509      for(int i=1; i<num_ep-1; i++)
510      {
511        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]);
512        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]);
513      }
514      assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]);
515    }
516
517
518    ::MPI_Aint datasize, lb;
519
520    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
521
522    void *local_gather_recvbuf;
523    int buffer_size;
524
525    if(ep_rank_loc==0)
526    {
527      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
528
529      local_gather_recvbuf = new void*[datasize*buffer_size];
530    }
531
532    // local gather to master
533    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
534
535    //MPI_Gather
536    if(ep_rank_loc == 0)
537    {
538      int *mpi_recvcnt= new int[mpi_size];
539      int *mpi_displs= new int[mpi_size];
540
541      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
542      int buff_end = buffer_size;
543
544      int mpi_sendcnt = buff_end - buff_start;
545
546
547      ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
548      ::MPI_Allgather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
549
550
551      ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
552                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
553
554      delete[] mpi_recvcnt;
555      delete[] mpi_displs;
556    }
557
558    int global_min_displs = *std::min_element(displs, displs+ep_size);
559    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
560
561    MPI_Bcast_local(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
562
563    if(ep_rank_loc==0)
564    {
565      if(datatype == MPI_INT)
566      {
567        delete[] static_cast<int*>(local_gather_recvbuf);
568      }
569      else if(datatype == MPI_FLOAT)
570      {
571        delete[] static_cast<float*>(local_gather_recvbuf);
572      }
573      else if(datatype == MPI_DOUBLE)
574      {
575        delete[] static_cast<double*>(local_gather_recvbuf);
576      }
577      else if(datatype == MPI_LONG)
578      {
579        delete[] static_cast<long*>(local_gather_recvbuf);
580      }
581      else if(datatype == MPI_UNSIGNED_LONG)
582      {
583        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
584      }
585      else // if(datatype == MPI_CHAR)
586      {
587        delete[] static_cast<char*>(local_gather_recvbuf);
588      }
589    }
590  }
591
592
593}
Note: See TracBrowser for help on using the repository browser.