source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gatherv.cpp @ 1138

Last change on this file since 1138 was 1138, checked in by yushan, 4 years ago

test_remap back to work. No thread for now

File size: 16.4 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gatherv, MPI_Allgatherv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17  int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
18  {
19    if(datatype == MPI_INT)
20    {
21      Debug("datatype is INT\n");
22      return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm);
23    }
24    else if(datatype == MPI_FLOAT)
25    {
26      Debug("datatype is FLOAT\n");
27      return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm);
28    }
29    else if(datatype == MPI_DOUBLE)
30    {
31      Debug("datatype is DOUBLE\n");
32      return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm);
33    }
34    else if(datatype == MPI_LONG)
35    {
36      Debug("datatype is LONG\n");
37      return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm);
38    }
39    else if(datatype == MPI_UNSIGNED_LONG)
40    {
41      Debug("datatype is uLONG\n");
42      return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm);
43    }
44    else if(datatype == MPI_CHAR)
45    {
46      Debug("datatype is CHAR\n");
47      return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm);
48    }
49    else
50    {
51      printf("MPI_Gatherv Datatype not supported!\n");
52      exit(0);
53    }
54  }
55
56  int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
57  {
58    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
59    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
60
61    int *buffer = comm.my_buffer->buf_int;
62    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
63    int *recv_buf = static_cast<int*>(recvbuf);
64
65    if(my_rank == 0)
66    {
67      assert(count == recvcounts[0]);
68      copy(send_buf, send_buf+count, recv_buf + displs[0]);
69    }
70
71    for(int j=0; j<count; j+=BUFFER_SIZE)
72    {
73      for(int k=1; k<num_ep; k++)
74      {
75        if(my_rank == k)
76        {
77          #pragma omp critical (write_to_buffer)
78          {
79            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
80            #pragma omp flush
81          }
82        }
83
84        MPI_Barrier_local(comm);
85
86        if(my_rank == 0)
87        {
88          #pragma omp flush
89          #pragma omp critical (read_from_buffer)
90          {
91            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
92          }
93        }
94
95        MPI_Barrier_local(comm);
96      }
97    }
98  }
99
100  int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
101  {
102    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
103    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
104
105    float *buffer = comm.my_buffer->buf_float;
106    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
107    float *recv_buf = static_cast<float*>(recvbuf);
108
109    if(my_rank == 0)
110    {
111      assert(count == recvcounts[0]);
112      copy(send_buf, send_buf+count, recv_buf + displs[0]);
113    }
114
115    for(int j=0; j<count; j+=BUFFER_SIZE)
116    {
117      for(int k=1; k<num_ep; k++)
118      {
119        if(my_rank == k)
120        {
121          #pragma omp critical (write_to_buffer)
122          {
123            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
124            #pragma omp flush
125          }
126        }
127
128        MPI_Barrier_local(comm);
129
130        if(my_rank == 0)
131        {
132          #pragma omp flush
133          #pragma omp critical (read_from_buffer)
134          {
135            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
136          }
137        }
138
139        MPI_Barrier_local(comm);
140      }
141    }
142  }
143
144  int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
145  {
146    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
147    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
148
149    double *buffer = comm.my_buffer->buf_double;
150    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
151    double *recv_buf = static_cast<double*>(recvbuf);
152
153    if(my_rank == 0)
154    {
155      assert(count == recvcounts[0]);
156      copy(send_buf, send_buf+count, recv_buf + displs[0]);
157    }
158
159    for(int j=0; j<count; j+=BUFFER_SIZE)
160    {
161      for(int k=1; k<num_ep; k++)
162      {
163        if(my_rank == k)
164        {
165          #pragma omp critical (write_to_buffer)
166          {
167            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
168            #pragma omp flush
169          }
170        }
171
172        MPI_Barrier_local(comm);
173
174        if(my_rank == 0)
175        {
176          #pragma omp flush
177          #pragma omp critical (read_from_buffer)
178          {
179            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
180          }
181        }
182
183        MPI_Barrier_local(comm);
184      }
185    }
186  }
187
188  int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
189  {
190    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
191    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
192
193    long *buffer = comm.my_buffer->buf_long;
194    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
195    long *recv_buf = static_cast<long*>(recvbuf);
196
197    if(my_rank == 0)
198    {
199      assert(count == recvcounts[0]);
200      copy(send_buf, send_buf+count, recv_buf + displs[0]);
201    }
202
203    for(int j=0; j<count; j+=BUFFER_SIZE)
204    {
205      for(int k=1; k<num_ep; k++)
206      {
207        if(my_rank == k)
208        {
209          #pragma omp critical (write_to_buffer)
210          {
211            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
212            #pragma omp flush
213          }
214        }
215
216        MPI_Barrier_local(comm);
217
218        if(my_rank == 0)
219        {
220          #pragma omp flush
221          #pragma omp critical (read_from_buffer)
222          {
223            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
224          }
225        }
226
227        MPI_Barrier_local(comm);
228      }
229    }
230  }
231
232  int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
233  {
234    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
235    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
236
237    unsigned long *buffer = comm.my_buffer->buf_ulong;
238    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
239    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
240
241    if(my_rank == 0)
242    {
243      assert(count == recvcounts[0]);
244      copy(send_buf, send_buf+count, recv_buf + displs[0]);
245    }
246
247    for(int j=0; j<count; j+=BUFFER_SIZE)
248    {
249      for(int k=1; k<num_ep; k++)
250      {
251        if(my_rank == k)
252        {
253          #pragma omp critical (write_to_buffer)
254          {
255            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
256            #pragma omp flush
257          }
258        }
259
260        MPI_Barrier_local(comm);
261
262        if(my_rank == 0)
263        {
264          #pragma omp flush
265          #pragma omp critical (read_from_buffer)
266          {
267            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
268          }
269        }
270
271        MPI_Barrier_local(comm);
272      }
273    }
274  }
275
276  int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
277  {
278    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
279    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
280
281    char *buffer = comm.my_buffer->buf_char;
282    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
283    char *recv_buf = static_cast<char*>(recvbuf);
284
285    if(my_rank == 0)
286    {
287      assert(count == recvcounts[0]);
288      copy(send_buf, send_buf+count, recv_buf + displs[0]);
289    }
290
291    for(int j=0; j<count; j+=BUFFER_SIZE)
292    {
293      for(int k=1; k<num_ep; k++)
294      {
295        if(my_rank == k)
296        {
297          #pragma omp critical (write_to_buffer)
298          {
299            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
300            #pragma omp flush
301          }
302        }
303
304        MPI_Barrier_local(comm);
305
306        if(my_rank == 0)
307        {
308          #pragma omp flush
309          #pragma omp critical (read_from_buffer)
310          {
311            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
312          }
313        }
314
315        MPI_Barrier_local(comm);
316      }
317    }
318  }
319
320
321  int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
322                  MPI_Datatype recvtype, int root, MPI_Comm comm)
323  {
324 
325    if(!comm.is_ep && comm.mpi_comm)
326    {
327      ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs),
328                    static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
329      return 0;
330    }
331
332    if(!comm.mpi_comm) return 0;
333
334    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
335
336    MPI_Datatype datatype = sendtype;
337    int count = sendcount;
338
339    int ep_rank, ep_rank_loc, mpi_rank;
340    int ep_size, num_ep, mpi_size;
341
342    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
343    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
344    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
345    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
346    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
347    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
348   
349    if(ep_rank != root)
350    {
351      recvcounts = new int[ep_size];
352      displs = new int[ep_size];
353    }
354   
355    MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm);
356    MPI_Bcast(const_cast< int* >(displs), ep_size, MPI_INT, root, comm);
357
358
359    int root_mpi_rank = comm.rank_map->at(root).second;
360    int root_ep_loc = comm.rank_map->at(root).first;
361
362
363    ::MPI_Aint datasize, lb;
364
365    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
366
367    void *local_gather_recvbuf;
368
369    if(ep_rank_loc==0)
370    {
371      int buffer_size = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0);
372      local_gather_recvbuf = new void*[datasize*buffer_size];
373    }
374
375    // local gather to master
376    int local_displs[num_ep];
377    local_displs[0] = 0;
378    for(int i=1; i<num_ep; i++)
379    {
380      local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc];
381    }
382    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, local_displs, comm);
383
384    //MPI_Gather
385    if(ep_rank_loc == 0)
386    {
387
388      int gatherv_recvcnt[mpi_size];
389      int gatherv_displs[mpi_size];
390      int gatherv_cnt = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0);
391
392      //gatherv_recvcnt = new int[mpi_size];
393      //gatherv_displs = new int[mpi_size];
394
395
396      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
397
398      gatherv_displs[0] = 0;
399      for(int i=1; i<mpi_size; i++)
400      {
401        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
402      }
403
404
405      ::MPI_Gatherv(local_gather_recvbuf, gatherv_cnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
406                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
407
408      //delete[] gatherv_recvcnt;
409      //delete[] gatherv_displs;
410    }
411
412
413    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
414    {
415      innode_memcpy(0, recvbuf, root_ep_loc, recvbuf, accumulate(recvcounts, recvcounts+ep_size, 0), datatype, comm);
416    }
417
418
419
420    if(ep_rank_loc==0)
421    {
422      if(datatype == MPI_INT)
423      {
424        delete[] static_cast<int*>(local_gather_recvbuf);
425      }
426      else if(datatype == MPI_FLOAT)
427      {
428        delete[] static_cast<float*>(local_gather_recvbuf);
429      }
430      else if(datatype == MPI_DOUBLE)
431      {
432        delete[] static_cast<double*>(local_gather_recvbuf);
433      }
434      else if(datatype == MPI_LONG)
435      {
436        delete[] static_cast<long*>(local_gather_recvbuf);
437      }
438      else if(datatype == MPI_UNSIGNED_LONG)
439      {
440        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
441      }
442      else // if(datatype == MPI_CHAR)
443      {
444        delete[] static_cast<char*>(local_gather_recvbuf);
445      }
446    }
447    else
448    {
449      delete[] recvcounts;
450      delete[] displs;
451    }
452    return 0;
453  }
454
455
456
457  int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
458                  MPI_Datatype recvtype, MPI_Comm comm)
459  {
460
461    if(!comm.is_ep && comm.mpi_comm)
462    {
463      ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs,
464                       static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
465      return 0;
466    }
467
468    if(!comm.mpi_comm) return 0;
469
470    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
471
472
473    MPI_Datatype datatype = sendtype;
474    int count = sendcount;
475
476    int ep_rank, ep_rank_loc, mpi_rank;
477    int ep_size, num_ep, mpi_size;
478
479    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
480    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
481    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
482    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
483    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
484    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
485   
486    if(ep_size == mpi_size) 
487      return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
488                              static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
489   
490
491    assert(accumulate(recvcounts, recvcounts+ep_size-1, 0) >= displs[ep_size-1]); // Only for continuous gather.
492
493
494    ::MPI_Aint datasize, lb;
495
496    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
497
498    void *local_gather_recvbuf;
499
500    if(ep_rank_loc==0)
501    {
502      int buffer_size = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0);
503      local_gather_recvbuf = new void*[datasize*buffer_size];
504    }
505
506    // local gather to master
507    int local_displs[num_ep];
508    local_displs[0] = 0;
509    for(int i=1; i<num_ep; i++)
510    {
511      local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc];
512    }
513    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, local_displs, comm);
514
515    //MPI_Gather
516    if(ep_rank_loc == 0)
517    {
518      int *gatherv_recvcnt;
519      int *gatherv_displs;
520      int gatherv_cnt = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0);
521
522      gatherv_recvcnt = new int[mpi_size];
523      gatherv_displs = new int[mpi_size];
524
525      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
526      gatherv_displs[0] = displs[0];
527      for(int i=1; i<mpi_size; i++)
528      {
529        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
530      }
531
532      ::MPI_Allgatherv(local_gather_recvbuf, gatherv_cnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
533                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
534
535      delete[] gatherv_recvcnt;
536      delete[] gatherv_displs;
537    }
538
539    MPI_Bcast_local(recvbuf, accumulate(recvcounts, recvcounts+ep_size, 0), datatype, comm);
540
541    if(ep_rank_loc==0)
542    {
543      if(datatype == MPI_INT)
544      {
545        delete[] static_cast<int*>(local_gather_recvbuf);
546      }
547      else if(datatype == MPI_FLOAT)
548      {
549        delete[] static_cast<float*>(local_gather_recvbuf);
550      }
551      else if(datatype == MPI_DOUBLE)
552      {
553        delete[] static_cast<double*>(local_gather_recvbuf);
554      }
555      else if(datatype == MPI_LONG)
556      {
557        delete[] static_cast<long*>(local_gather_recvbuf);
558      }
559      else if(datatype == MPI_UNSIGNED_LONG)
560      {
561        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
562      }
563      else // if(datatype == MPI_CHAR)
564      {
565        delete[] static_cast<char*>(local_gather_recvbuf);
566      }
567    }
568  }
569
570
571}
Note: See TracBrowser for help on using the repository browser.