source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gather.cpp @ 1147

Last change on this file since 1147 was 1147, checked in by yushan, 7 years ago

bug fixed in MPI_Gather(v)

File size: 14.6 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gather, MPI_Allgather
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17
18  int MPI_Gather_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
19  {
20    if(datatype == MPI_INT)
21    {
22      Debug("datatype is INT\n");
23      return MPI_Gather_local_int(sendbuf, count, recvbuf, comm);
24    }
25    else if(datatype == MPI_FLOAT)
26    {
27      Debug("datatype is FLOAT\n");
28      return MPI_Gather_local_float(sendbuf, count, recvbuf, comm);
29    }
30    else if(datatype == MPI_DOUBLE)
31    {
32      Debug("datatype is DOUBLE\n");
33      return MPI_Gather_local_double(sendbuf, count, recvbuf, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      Debug("datatype is LONG\n");
38      return MPI_Gather_local_long(sendbuf, count, recvbuf, comm);
39    }
40    else if(datatype == MPI_UNSIGNED_LONG)
41    {
42      Debug("datatype is uLONG\n");
43      return MPI_Gather_local_ulong(sendbuf, count, recvbuf, comm);
44    }
45    else if(datatype == MPI_CHAR)
46    {
47      Debug("datatype is CHAR\n");
48      return MPI_Gather_local_char(sendbuf, count, recvbuf, comm);
49    }
50    else
51    {
52      printf("MPI_Gather Datatype not supported!\n");
53      exit(0);
54    }
55  }
56
57  int MPI_Gather_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
58  {
59    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
60    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    if(my_rank == 0)
67    {
68      copy(send_buf, send_buf+count, recv_buf);
69    }
70
71    for(int j=0; j<count; j+=BUFFER_SIZE)
72    {
73      for(int k=1; k<num_ep; k++)
74      {
75        if(my_rank == k)
76        {
77          #pragma omp critical (write_to_buffer)
78          {
79            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
80            #pragma omp flush
81          }
82        }
83
84        MPI_Barrier_local(comm);
85
86        if(my_rank == 0)
87        {
88          #pragma omp flush
89          #pragma omp critical (read_from_buffer)
90          {
91            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
92          }
93        }
94
95        MPI_Barrier_local(comm);
96      }
97    }
98  }
99
100  int MPI_Gather_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
101  {
102    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
103    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
104
105    float *buffer = comm.my_buffer->buf_float;
106    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
107    float *recv_buf = static_cast<float*>(recvbuf);
108
109    if(my_rank == 0)
110    {
111      copy(send_buf, send_buf+count, recv_buf);
112    }
113
114    for(int j=0; j<count; j+=BUFFER_SIZE)
115    {
116      for(int k=1; k<num_ep; k++)
117      {
118        if(my_rank == k)
119        {
120          #pragma omp critical (write_to_buffer)
121          {
122            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
123            #pragma omp flush
124          }
125        }
126
127        MPI_Barrier_local(comm);
128
129        if(my_rank == 0)
130        {
131          #pragma omp flush
132          #pragma omp critical (read_from_buffer)
133          {
134            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
135          }
136        }
137
138        MPI_Barrier_local(comm);
139      }
140    }
141  }
142
143  int MPI_Gather_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
144  {
145    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
146    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
147
148    double *buffer = comm.my_buffer->buf_double;
149    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
150    double *recv_buf = static_cast<double*>(recvbuf);
151
152    if(my_rank == 0)
153    {
154      copy(send_buf, send_buf+count, recv_buf);
155    }
156
157    for(int j=0; j<count; j+=BUFFER_SIZE)
158    {
159      for(int k=1; k<num_ep; k++)
160      {
161        if(my_rank == k)
162        {
163          #pragma omp critical (write_to_buffer)
164          {
165            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
166            #pragma omp flush
167          }
168        }
169
170        MPI_Barrier_local(comm);
171
172        if(my_rank == 0)
173        {
174          #pragma omp flush
175          #pragma omp critical (read_from_buffer)
176          {
177            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
178          }
179        }
180
181        MPI_Barrier_local(comm);
182      }
183    }
184  }
185
186  int MPI_Gather_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
187  {
188    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
189    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
190
191    long *buffer = comm.my_buffer->buf_long;
192    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
193    long *recv_buf = static_cast<long*>(recvbuf);
194
195    if(my_rank == 0)
196    {
197      copy(send_buf, send_buf+count, recv_buf);
198    }
199
200    for(int j=0; j<count; j+=BUFFER_SIZE)
201    {
202      for(int k=1; k<num_ep; k++)
203      {
204        if(my_rank == k)
205        {
206          #pragma omp critical (write_to_buffer)
207          {
208            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
209            #pragma omp flush
210          }
211        }
212
213        MPI_Barrier_local(comm);
214
215        if(my_rank == 0)
216        {
217          #pragma omp flush
218          #pragma omp critical (read_from_buffer)
219          {
220            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
221          }
222        }
223
224        MPI_Barrier_local(comm);
225      }
226    }
227  }
228
229  int MPI_Gather_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
230  {
231    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
232    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
233
234    unsigned long *buffer = comm.my_buffer->buf_ulong;
235    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
236    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
237
238    if(my_rank == 0)
239    {
240      copy(send_buf, send_buf+count, recv_buf);
241    }
242
243    for(int j=0; j<count; j+=BUFFER_SIZE)
244    {
245      for(int k=1; k<num_ep; k++)
246      {
247        if(my_rank == k)
248        {
249          #pragma omp critical (write_to_buffer)
250          {
251            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
252            #pragma omp flush
253          }
254        }
255
256        MPI_Barrier_local(comm);
257
258        if(my_rank == 0)
259        {
260          #pragma omp flush
261          #pragma omp critical (read_from_buffer)
262          {
263            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
264          }
265        }
266
267        MPI_Barrier_local(comm);
268      }
269    }
270  }
271
272
273  int MPI_Gather_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
274  {
275    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
276    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
277
278    char *buffer = comm.my_buffer->buf_char;
279    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
280    char *recv_buf = static_cast<char*>(recvbuf);
281
282    if(my_rank == 0)
283    {
284      copy(send_buf, send_buf+count, recv_buf);
285    }
286
287    for(int j=0; j<count; j+=BUFFER_SIZE)
288    {
289      for(int k=1; k<num_ep; k++)
290      {
291        if(my_rank == k)
292        {
293          #pragma omp critical (write_to_buffer)
294          {
295            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
296            #pragma omp flush
297          }
298        }
299
300        MPI_Barrier_local(comm);
301
302        if(my_rank == 0)
303        {
304          #pragma omp flush
305          #pragma omp critical (read_from_buffer)
306          {
307            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
308          }
309        }
310
311        MPI_Barrier_local(comm);
312      }
313    }
314  }
315
316
317
318  int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
319  {
320    if(!comm.is_ep && comm.mpi_comm)
321    {
322      ::MPI_Gather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
323                   root, static_cast< ::MPI_Comm>(comm.mpi_comm));
324      return 0;
325    }
326
327    if(!comm.mpi_comm) return 0;
328
329    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
330
331    MPI_Datatype datatype = sendtype;
332    int count = sendcount;
333
334    int ep_rank, ep_rank_loc, mpi_rank;
335    int ep_size, num_ep, mpi_size;
336
337    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
338    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
339    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
340    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
341    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
342    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
343
344
345    int root_mpi_rank = comm.rank_map->at(root).second;
346    int root_ep_loc = comm.rank_map->at(root).first;
347
348
349    ::MPI_Aint datasize, lb;
350
351    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
352
353    void *local_gather_recvbuf;
354    void *master_recvbuf;
355    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)];
356
357    if(ep_rank_loc==0)
358    {
359      local_gather_recvbuf = new void*[datasize*num_ep*count];
360    }
361
362    // local gather to master
363    MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);
364
365    //MPI_Gather
366
367    if(ep_rank_loc == 0)
368    {
369      int *gatherv_recvcnt;
370      int *gatherv_displs;
371      int gatherv_cnt = count*num_ep;
372
373      gatherv_recvcnt = new int[mpi_size];
374      gatherv_displs = new int[mpi_size];
375
376
377      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
378
379      gatherv_displs[0] = 0;
380      for(int i=1; i<mpi_size; i++)
381      {
382        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
383      }
384
385      if(root_ep_loc != 0) // gather to root_master
386      {
387        ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, gatherv_recvcnt,
388                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
389      }
390      else
391      {
392        ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
393                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
394      }
395
396      delete[] gatherv_recvcnt;
397      delete[] gatherv_displs;
398    }
399
400
401    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
402    {
403      innode_memcpy(0, master_recvbuf, root_ep_loc, recvbuf, count*ep_size, datatype, comm);
404      if(ep_rank_loc == 0 ) delete[] master_recvbuf;
405    }
406
407
408
409    if(ep_rank_loc==0)
410    {
411
412      if(datatype == MPI_INT)
413      {
414        delete[] static_cast<int*>(local_gather_recvbuf);
415      }
416      else if(datatype == MPI_FLOAT)
417      {
418        delete[] static_cast<float*>(local_gather_recvbuf);
419      }
420      else if(datatype == MPI_DOUBLE)
421      {
422        delete[] static_cast<double*>(local_gather_recvbuf);
423      }
424      else if(datatype == MPI_CHAR)
425      {
426        delete[] static_cast<char*>(local_gather_recvbuf);
427      }
428      else if(datatype == MPI_LONG)
429      {
430        delete[] static_cast<long*>(local_gather_recvbuf);
431      }
432      else// if(datatype == MPI_UNSIGNED_LONG)
433      {
434        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
435      }
436    }
437
438
439  }
440
441
442  int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
443  {
444    if(!comm.is_ep && comm.mpi_comm)
445    {
446      ::MPI_Allgather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
447                      static_cast< ::MPI_Comm>(comm.mpi_comm));
448      return 0;
449    }
450
451    if(!comm.mpi_comm) return 0;
452
453    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
454
455    MPI_Datatype datatype = sendtype;
456    int count = sendcount;
457
458    int ep_rank, ep_rank_loc, mpi_rank;
459    int ep_size, num_ep, mpi_size;
460
461    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
462    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
463    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
464    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
465    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
466    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
467
468
469    ::MPI_Aint datasize, lb;
470
471    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
472
473    void *local_gather_recvbuf;
474
475    if(ep_rank_loc==0)
476    {
477      local_gather_recvbuf = new void*[datasize*num_ep*count];
478    }
479
480    // local gather to master
481    MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);
482
483    //MPI_Gather
484
485    if(ep_rank_loc == 0)
486    {
487      int *gatherv_recvcnt;
488      int *gatherv_displs;
489      int gatherv_cnt = count*num_ep;
490
491      gatherv_recvcnt = new int[mpi_size];
492      gatherv_displs = new int[mpi_size];
493
494      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
495
496      gatherv_displs[0] = 0;
497      for(int i=1; i<mpi_size; i++)
498      {
499        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
500      }
501
502      ::MPI_Allgatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
503                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
504
505      delete[] gatherv_recvcnt;
506      delete[] gatherv_displs;
507    }
508
509    MPI_Bcast_local(recvbuf, count*ep_size, datatype, comm);
510
511
512    if(ep_rank_loc==0)
513    {
514      if(datatype == MPI_INT)
515      {
516        delete[] static_cast<int*>(local_gather_recvbuf);
517      }
518      else if(datatype == MPI_FLOAT)
519      {
520        delete[] static_cast<float*>(local_gather_recvbuf);
521      }
522      else if(datatype == MPI_DOUBLE)
523      {
524        delete[] static_cast<double*>(local_gather_recvbuf);
525      }
526      else if(datatype == MPI_CHAR)
527      {
528        delete[] static_cast<char*>(local_gather_recvbuf);
529      }
530      else if(datatype == MPI_LONG)
531      {
532        delete[] static_cast<long*>(local_gather_recvbuf);
533      }
534      else// if(datatype == MPI_UNSIGNED_LONG)
535      {
536        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
537      }
538    }
539  }
540
541
542}
Note: See TracBrowser for help on using the repository browser.