source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gather.cpp @ 1151

Last change on this file since 1151 was 1151, checked in by yushan, 7 years ago

bug fixed in MPI_Gather(v)

File size: 14.7 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gather, MPI_Allgather
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17
18  int MPI_Gather_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
19  {
20    if(datatype == MPI_INT)
21    {
22      Debug("datatype is INT\n");
23      return MPI_Gather_local_int(sendbuf, count, recvbuf, comm);
24    }
25    else if(datatype == MPI_FLOAT)
26    {
27      Debug("datatype is FLOAT\n");
28      return MPI_Gather_local_float(sendbuf, count, recvbuf, comm);
29    }
30    else if(datatype == MPI_DOUBLE)
31    {
32      Debug("datatype is DOUBLE\n");
33      return MPI_Gather_local_double(sendbuf, count, recvbuf, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      Debug("datatype is LONG\n");
38      return MPI_Gather_local_long(sendbuf, count, recvbuf, comm);
39    }
40    else if(datatype == MPI_UNSIGNED_LONG)
41    {
42      Debug("datatype is uLONG\n");
43      return MPI_Gather_local_ulong(sendbuf, count, recvbuf, comm);
44    }
45    else if(datatype == MPI_CHAR)
46    {
47      Debug("datatype is CHAR\n");
48      return MPI_Gather_local_char(sendbuf, count, recvbuf, comm);
49    }
50    else
51    {
52      printf("MPI_Gather Datatype not supported!\n");
53      exit(0);
54    }
55  }
56
57  int MPI_Gather_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
58  {
59    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
60    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    if(my_rank == 0)
67    {
68      copy(send_buf, send_buf+count, recv_buf);
69    }
70
71    for(int j=0; j<count; j+=BUFFER_SIZE)
72    {
73      for(int k=1; k<num_ep; k++)
74      {
75        if(my_rank == k)
76        {
77          #pragma omp critical (write_to_buffer)
78          {
79            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
80            #pragma omp flush
81          }
82        }
83
84        MPI_Barrier_local(comm);
85
86        if(my_rank == 0)
87        {
88          #pragma omp flush
89          #pragma omp critical (read_from_buffer)
90          {
91            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
92          }
93        }
94
95        MPI_Barrier_local(comm);
96      }
97    }
98  }
99
100  int MPI_Gather_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
101  {
102    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
103    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
104
105    float *buffer = comm.my_buffer->buf_float;
106    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
107    float *recv_buf = static_cast<float*>(recvbuf);
108
109    if(my_rank == 0)
110    {
111      copy(send_buf, send_buf+count, recv_buf);
112    }
113
114    for(int j=0; j<count; j+=BUFFER_SIZE)
115    {
116      for(int k=1; k<num_ep; k++)
117      {
118        if(my_rank == k)
119        {
120          #pragma omp critical (write_to_buffer)
121          {
122            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
123            #pragma omp flush
124          }
125        }
126
127        MPI_Barrier_local(comm);
128
129        if(my_rank == 0)
130        {
131          #pragma omp flush
132          #pragma omp critical (read_from_buffer)
133          {
134            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
135          }
136        }
137
138        MPI_Barrier_local(comm);
139      }
140    }
141  }
142
143  int MPI_Gather_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
144  {
145    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
146    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
147
148    double *buffer = comm.my_buffer->buf_double;
149    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
150    double *recv_buf = static_cast<double*>(recvbuf);
151
152    if(my_rank == 0)
153    {
154      copy(send_buf, send_buf+count, recv_buf);
155    }
156
157    for(int j=0; j<count; j+=BUFFER_SIZE)
158    {
159      for(int k=1; k<num_ep; k++)
160      {
161        if(my_rank == k)
162        {
163          #pragma omp critical (write_to_buffer)
164          {
165            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
166            #pragma omp flush
167          }
168        }
169
170        MPI_Barrier_local(comm);
171
172        if(my_rank == 0)
173        {
174          #pragma omp flush
175          #pragma omp critical (read_from_buffer)
176          {
177            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
178          }
179        }
180
181        MPI_Barrier_local(comm);
182      }
183    }
184  }
185
186  int MPI_Gather_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
187  {
188    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
189    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
190
191    long *buffer = comm.my_buffer->buf_long;
192    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
193    long *recv_buf = static_cast<long*>(recvbuf);
194
195    if(my_rank == 0)
196    {
197      copy(send_buf, send_buf+count, recv_buf);
198    }
199
200    for(int j=0; j<count; j+=BUFFER_SIZE)
201    {
202      for(int k=1; k<num_ep; k++)
203      {
204        if(my_rank == k)
205        {
206          #pragma omp critical (write_to_buffer)
207          {
208            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
209            #pragma omp flush
210          }
211        }
212
213        MPI_Barrier_local(comm);
214
215        if(my_rank == 0)
216        {
217          #pragma omp flush
218          #pragma omp critical (read_from_buffer)
219          {
220            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
221          }
222        }
223
224        MPI_Barrier_local(comm);
225      }
226    }
227  }
228
229  int MPI_Gather_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
230  {
231    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
232    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
233
234    unsigned long *buffer = comm.my_buffer->buf_ulong;
235    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
236    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
237
238    if(my_rank == 0)
239    {
240      copy(send_buf, send_buf+count, recv_buf);
241    }
242
243    for(int j=0; j<count; j+=BUFFER_SIZE)
244    {
245      for(int k=1; k<num_ep; k++)
246      {
247        if(my_rank == k)
248        {
249          #pragma omp critical (write_to_buffer)
250          {
251            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
252            #pragma omp flush
253          }
254        }
255
256        MPI_Barrier_local(comm);
257
258        if(my_rank == 0)
259        {
260          #pragma omp flush
261          #pragma omp critical (read_from_buffer)
262          {
263            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
264          }
265        }
266
267        MPI_Barrier_local(comm);
268      }
269    }
270  }
271
272
273  int MPI_Gather_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
274  {
275    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
276    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
277
278    char *buffer = comm.my_buffer->buf_char;
279    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
280    char *recv_buf = static_cast<char*>(recvbuf);
281
282    if(my_rank == 0)
283    {
284      copy(send_buf, send_buf+count, recv_buf);
285    }
286
287    for(int j=0; j<count; j+=BUFFER_SIZE)
288    {
289      for(int k=1; k<num_ep; k++)
290      {
291        if(my_rank == k)
292        {
293          #pragma omp critical (write_to_buffer)
294          {
295            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
296            #pragma omp flush
297          }
298        }
299
300        MPI_Barrier_local(comm);
301
302        if(my_rank == 0)
303        {
304          #pragma omp flush
305          #pragma omp critical (read_from_buffer)
306          {
307            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
308          }
309        }
310
311        MPI_Barrier_local(comm);
312      }
313    }
314  }
315
316
317
318  int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
319  {
320    if(!comm.is_ep && comm.mpi_comm)
321    {
322      ::MPI_Gather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
323                   root, static_cast< ::MPI_Comm>(comm.mpi_comm));
324      return 0;
325    }
326
327    if(!comm.mpi_comm) return 0;
328   
329    MPI_Bcast(&recvcount, 1, MPI_INT, root, comm);
330
331    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
332
333    MPI_Datatype datatype = sendtype;
334    int count = sendcount;
335
336    int ep_rank, ep_rank_loc, mpi_rank;
337    int ep_size, num_ep, mpi_size;
338
339    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
340    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
341    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
342    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
343    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
344    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
345
346
347    int root_mpi_rank = comm.rank_map->at(root).second;
348    int root_ep_loc = comm.rank_map->at(root).first;
349
350
351    ::MPI_Aint datasize, lb;
352
353    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
354
355    void *local_gather_recvbuf;
356    void *master_recvbuf;
357    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)];
358
359    if(ep_rank_loc==0)
360    {
361      local_gather_recvbuf = new void*[datasize*num_ep*count];
362    }
363
364    // local gather to master
365    MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);
366
367    //MPI_Gather
368
369    if(ep_rank_loc == 0)
370    {
371      int *gatherv_recvcnt;
372      int *gatherv_displs;
373      int gatherv_cnt = count*num_ep;
374
375      gatherv_recvcnt = new int[mpi_size];
376      gatherv_displs = new int[mpi_size];
377
378
379      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
380
381      gatherv_displs[0] = 0;
382      for(int i=1; i<mpi_size; i++)
383      {
384        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
385      }
386
387      if(root_ep_loc != 0) // gather to root_master
388      {
389        ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, gatherv_recvcnt,
390                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
391      }
392      else
393      {
394        ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
395                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
396      }
397
398      delete[] gatherv_recvcnt;
399      delete[] gatherv_displs;
400    }
401
402
403    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
404    {
405      innode_memcpy(0, master_recvbuf, root_ep_loc, recvbuf, count*ep_size, datatype, comm);
406      if(ep_rank_loc == 0 ) delete[] master_recvbuf;
407    }
408
409
410
411    if(ep_rank_loc==0)
412    {
413
414      if(datatype == MPI_INT)
415      {
416        delete[] static_cast<int*>(local_gather_recvbuf);
417      }
418      else if(datatype == MPI_FLOAT)
419      {
420        delete[] static_cast<float*>(local_gather_recvbuf);
421      }
422      else if(datatype == MPI_DOUBLE)
423      {
424        delete[] static_cast<double*>(local_gather_recvbuf);
425      }
426      else if(datatype == MPI_CHAR)
427      {
428        delete[] static_cast<char*>(local_gather_recvbuf);
429      }
430      else if(datatype == MPI_LONG)
431      {
432        delete[] static_cast<long*>(local_gather_recvbuf);
433      }
434      else// if(datatype == MPI_UNSIGNED_LONG)
435      {
436        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
437      }
438    }
439
440
441  }
442
443
444  int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
445  {
446    if(!comm.is_ep && comm.mpi_comm)
447    {
448      ::MPI_Allgather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
449                      static_cast< ::MPI_Comm>(comm.mpi_comm));
450      return 0;
451    }
452
453    if(!comm.mpi_comm) return 0;
454
455    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
456
457    MPI_Datatype datatype = sendtype;
458    int count = sendcount;
459
460    int ep_rank, ep_rank_loc, mpi_rank;
461    int ep_size, num_ep, mpi_size;
462
463    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
464    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
465    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
466    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
467    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
468    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
469
470
471    ::MPI_Aint datasize, lb;
472
473    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
474
475    void *local_gather_recvbuf;
476
477    if(ep_rank_loc==0)
478    {
479      local_gather_recvbuf = new void*[datasize*num_ep*count];
480    }
481
482    // local gather to master
483    MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);
484
485    //MPI_Gather
486
487    if(ep_rank_loc == 0)
488    {
489      int *gatherv_recvcnt;
490      int *gatherv_displs;
491      int gatherv_cnt = count*num_ep;
492
493      gatherv_recvcnt = new int[mpi_size];
494      gatherv_displs = new int[mpi_size];
495
496      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
497
498      gatherv_displs[0] = 0;
499      for(int i=1; i<mpi_size; i++)
500      {
501        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
502      }
503
504      ::MPI_Allgatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
505                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
506
507      delete[] gatherv_recvcnt;
508      delete[] gatherv_displs;
509    }
510
511    MPI_Bcast_local(recvbuf, count*ep_size, datatype, comm);
512
513
514    if(ep_rank_loc==0)
515    {
516      if(datatype == MPI_INT)
517      {
518        delete[] static_cast<int*>(local_gather_recvbuf);
519      }
520      else if(datatype == MPI_FLOAT)
521      {
522        delete[] static_cast<float*>(local_gather_recvbuf);
523      }
524      else if(datatype == MPI_DOUBLE)
525      {
526        delete[] static_cast<double*>(local_gather_recvbuf);
527      }
528      else if(datatype == MPI_CHAR)
529      {
530        delete[] static_cast<char*>(local_gather_recvbuf);
531      }
532      else if(datatype == MPI_LONG)
533      {
534        delete[] static_cast<long*>(local_gather_recvbuf);
535      }
536      else// if(datatype == MPI_UNSIGNED_LONG)
537      {
538        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
539      }
540    }
541  }
542
543
544}
Note: See TracBrowser for help on using the repository browser.