source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gather.cpp @ 1164

Last change on this file since 1164 was 1164, checked in by yushan, 7 years ago

Bug fixed in MPI_(All)Gatherv with displs

File size: 14.7 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gather, MPI_Allgather
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17
18  int MPI_Gather_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
19  {
20    if(datatype == MPI_INT)
21    {
22      Debug("datatype is INT\n");
23      return MPI_Gather_local_int(sendbuf, count, recvbuf, comm);
24    }
25    else if(datatype == MPI_FLOAT)
26    {
27      Debug("datatype is FLOAT\n");
28      return MPI_Gather_local_float(sendbuf, count, recvbuf, comm);
29    }
30    else if(datatype == MPI_DOUBLE)
31    {
32      Debug("datatype is DOUBLE\n");
33      return MPI_Gather_local_double(sendbuf, count, recvbuf, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      Debug("datatype is LONG\n");
38      return MPI_Gather_local_long(sendbuf, count, recvbuf, comm);
39    }
40    else if(datatype == MPI_UNSIGNED_LONG)
41    {
42      Debug("datatype is uLONG\n");
43      return MPI_Gather_local_ulong(sendbuf, count, recvbuf, comm);
44    }
45    else if(datatype == MPI_CHAR)
46    {
47      Debug("datatype is CHAR\n");
48      return MPI_Gather_local_char(sendbuf, count, recvbuf, comm);
49    }
50    else
51    {
52      printf("MPI_Gather Datatype not supported!\n");
53      exit(0);
54    }
55  }
56
57  int MPI_Gather_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
58  {
59    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
60    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    if(my_rank == 0)
67    {
68      copy(send_buf, send_buf+count, recv_buf);
69    }
70
71    for(int j=0; j<count; j+=BUFFER_SIZE)
72    {
73      for(int k=1; k<num_ep; k++)
74      {
75        if(my_rank == k)
76        {
77          #pragma omp critical (write_to_buffer)
78          {
79            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
80            #pragma omp flush
81          }
82        }
83
84        MPI_Barrier_local(comm);
85
86        if(my_rank == 0)
87        {
88          #pragma omp flush
89          #pragma omp critical (read_from_buffer)
90          {
91            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
92          }
93        }
94
95        MPI_Barrier_local(comm);
96      }
97    }
98  }
99
100  int MPI_Gather_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
101  {
102    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
103    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
104
105    float *buffer = comm.my_buffer->buf_float;
106    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
107    float *recv_buf = static_cast<float*>(recvbuf);
108
109    if(my_rank == 0)
110    {
111      copy(send_buf, send_buf+count, recv_buf);
112    }
113
114    for(int j=0; j<count; j+=BUFFER_SIZE)
115    {
116      for(int k=1; k<num_ep; k++)
117      {
118        if(my_rank == k)
119        {
120          #pragma omp critical (write_to_buffer)
121          {
122            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
123            #pragma omp flush
124          }
125        }
126
127        MPI_Barrier_local(comm);
128
129        if(my_rank == 0)
130        {
131          #pragma omp flush
132          #pragma omp critical (read_from_buffer)
133          {
134            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
135          }
136        }
137
138        MPI_Barrier_local(comm);
139      }
140    }
141  }
142
143  int MPI_Gather_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
144  {
145    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
146    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
147
148    double *buffer = comm.my_buffer->buf_double;
149    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
150    double *recv_buf = static_cast<double*>(recvbuf);
151
152    if(my_rank == 0)
153    {
154      copy(send_buf, send_buf+count, recv_buf);
155    }
156
157    for(int j=0; j<count; j+=BUFFER_SIZE)
158    {
159      for(int k=1; k<num_ep; k++)
160      {
161        if(my_rank == k)
162        {
163          #pragma omp critical (write_to_buffer)
164          {
165            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
166            #pragma omp flush
167          }
168        }
169
170        MPI_Barrier_local(comm);
171
172        if(my_rank == 0)
173        {
174          #pragma omp flush
175          #pragma omp critical (read_from_buffer)
176          {
177            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
178          }
179        }
180
181        MPI_Barrier_local(comm);
182      }
183    }
184  }
185
186  int MPI_Gather_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
187  {
188    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
189    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
190
191    long *buffer = comm.my_buffer->buf_long;
192    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
193    long *recv_buf = static_cast<long*>(recvbuf);
194
195    if(my_rank == 0)
196    {
197      copy(send_buf, send_buf+count, recv_buf);
198    }
199
200    for(int j=0; j<count; j+=BUFFER_SIZE)
201    {
202      for(int k=1; k<num_ep; k++)
203      {
204        if(my_rank == k)
205        {
206          #pragma omp critical (write_to_buffer)
207          {
208            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
209            #pragma omp flush
210          }
211        }
212
213        MPI_Barrier_local(comm);
214
215        if(my_rank == 0)
216        {
217          #pragma omp flush
218          #pragma omp critical (read_from_buffer)
219          {
220            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
221          }
222        }
223
224        MPI_Barrier_local(comm);
225      }
226    }
227  }
228
229  int MPI_Gather_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
230  {
231    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
232    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
233
234    unsigned long *buffer = comm.my_buffer->buf_ulong;
235    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
236    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
237
238    if(my_rank == 0)
239    {
240      copy(send_buf, send_buf+count, recv_buf);
241    }
242
243    for(int j=0; j<count; j+=BUFFER_SIZE)
244    {
245      for(int k=1; k<num_ep; k++)
246      {
247        if(my_rank == k)
248        {
249          #pragma omp critical (write_to_buffer)
250          {
251            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
252            #pragma omp flush
253          }
254        }
255
256        MPI_Barrier_local(comm);
257
258        if(my_rank == 0)
259        {
260          #pragma omp flush
261          #pragma omp critical (read_from_buffer)
262          {
263            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
264          }
265        }
266
267        MPI_Barrier_local(comm);
268      }
269    }
270  }
271
272
273  int MPI_Gather_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
274  {
275    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
276    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
277
278    char *buffer = comm.my_buffer->buf_char;
279    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
280    char *recv_buf = static_cast<char*>(recvbuf);
281
282    if(my_rank == 0)
283    {
284      copy(send_buf, send_buf+count, recv_buf);
285    }
286
287    for(int j=0; j<count; j+=BUFFER_SIZE)
288    {
289      for(int k=1; k<num_ep; k++)
290      {
291        if(my_rank == k)
292        {
293          #pragma omp critical (write_to_buffer)
294          {
295            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
296            #pragma omp flush
297          }
298        }
299
300        MPI_Barrier_local(comm);
301
302        if(my_rank == 0)
303        {
304          #pragma omp flush
305          #pragma omp critical (read_from_buffer)
306          {
307            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
308          }
309        }
310
311        MPI_Barrier_local(comm);
312      }
313    }
314  }
315
316
317
318  int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
319  {
320    if(!comm.is_ep && comm.mpi_comm)
321    {
322      ::MPI_Gather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
323                   root, static_cast< ::MPI_Comm>(comm.mpi_comm));
324      return 0;
325    }
326
327    if(!comm.mpi_comm) return 0;
328   
329    MPI_Bcast(&recvcount, 1, MPI_INT, root, comm);
330
331    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
332
333    MPI_Datatype datatype = sendtype;
334    int count = sendcount;
335
336    int ep_rank, ep_rank_loc, mpi_rank;
337    int ep_size, num_ep, mpi_size;
338
339    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
340    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
341    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
342    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
343    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
344    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
345
346
347    int root_mpi_rank = comm.rank_map->at(root).second;
348    int root_ep_loc = comm.rank_map->at(root).first;
349
350
351    ::MPI_Aint datasize, lb;
352
353    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
354
355    void *local_gather_recvbuf;
356    void *master_recvbuf;
357    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) 
358    {
359      master_recvbuf = new void*[datasize*ep_size*count];
360    }
361
362    if(ep_rank_loc==0)
363    {
364      local_gather_recvbuf = new void*[datasize*num_ep*count];
365    }
366
367    // local gather to master
368    MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);
369
370    //MPI_Gather
371
372    if(ep_rank_loc == 0)
373    {
374      int *gatherv_recvcnt;
375      int *gatherv_displs;
376      int gatherv_cnt = count*num_ep;
377
378      gatherv_recvcnt = new int[mpi_size];
379      gatherv_displs = new int[mpi_size];
380
381
382      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
383
384      gatherv_displs[0] = 0;
385      for(int i=1; i<mpi_size; i++)
386      {
387        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
388      }
389
390      if(root_ep_loc != 0) // gather to root_master
391      {
392        ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, gatherv_recvcnt,
393                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
394      }
395      else
396      {
397        ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
398                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
399      }
400
401      delete[] gatherv_recvcnt;
402      delete[] gatherv_displs;
403    }
404
405
406    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
407    {
408      innode_memcpy(0, master_recvbuf, root_ep_loc, recvbuf, count*ep_size, datatype, comm);
409    }
410
411
412
413    if(ep_rank_loc==0)
414    {
415      if(datatype == MPI_INT)
416      {
417        delete[] static_cast<int*>(local_gather_recvbuf);
418      }
419      else if(datatype == MPI_FLOAT)
420      {
421        delete[] static_cast<float*>(local_gather_recvbuf);
422      }
423      else if(datatype == MPI_DOUBLE)
424      {
425        delete[] static_cast<double*>(local_gather_recvbuf);
426      }
427      else if(datatype == MPI_CHAR)
428      {
429        delete[] static_cast<char*>(local_gather_recvbuf);
430      }
431      else if(datatype == MPI_LONG)
432      {
433        delete[] static_cast<long*>(local_gather_recvbuf);
434      }
435      else// if(datatype == MPI_UNSIGNED_LONG)
436      {
437        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
438      }
439     
440      if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) delete[] master_recvbuf;
441    }
442  }
443
444
445  int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
446  {
447    if(!comm.is_ep && comm.mpi_comm)
448    {
449      ::MPI_Allgather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
450                      static_cast< ::MPI_Comm>(comm.mpi_comm));
451      return 0;
452    }
453
454    if(!comm.mpi_comm) return 0;
455
456    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
457
458    MPI_Datatype datatype = sendtype;
459    int count = sendcount;
460
461    int ep_rank, ep_rank_loc, mpi_rank;
462    int ep_size, num_ep, mpi_size;
463
464    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
465    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
466    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
467    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
468    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
469    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
470
471
472    ::MPI_Aint datasize, lb;
473
474    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
475
476    void *local_gather_recvbuf;
477
478    if(ep_rank_loc==0)
479    {
480      local_gather_recvbuf = new void*[datasize*num_ep*count];
481    }
482
483    // local gather to master
484    MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);
485
486    //MPI_Gather
487
488    if(ep_rank_loc == 0)
489    {
490      int *gatherv_recvcnt;
491      int *gatherv_displs;
492      int gatherv_cnt = count*num_ep;
493
494      gatherv_recvcnt = new int[mpi_size];
495      gatherv_displs = new int[mpi_size];
496
497      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
498
499      gatherv_displs[0] = 0;
500      for(int i=1; i<mpi_size; i++)
501      {
502        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
503      }
504
505      ::MPI_Allgatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
506                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
507
508      delete[] gatherv_recvcnt;
509      delete[] gatherv_displs;
510    }
511
512    MPI_Bcast_local(recvbuf, count*ep_size, datatype, comm);
513
514
515    if(ep_rank_loc==0)
516    {
517      if(datatype == MPI_INT)
518      {
519        delete[] static_cast<int*>(local_gather_recvbuf);
520      }
521      else if(datatype == MPI_FLOAT)
522      {
523        delete[] static_cast<float*>(local_gather_recvbuf);
524      }
525      else if(datatype == MPI_DOUBLE)
526      {
527        delete[] static_cast<double*>(local_gather_recvbuf);
528      }
529      else if(datatype == MPI_CHAR)
530      {
531        delete[] static_cast<char*>(local_gather_recvbuf);
532      }
533      else if(datatype == MPI_LONG)
534      {
535        delete[] static_cast<long*>(local_gather_recvbuf);
536      }
537      else// if(datatype == MPI_UNSIGNED_LONG)
538      {
539        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
540      }
541    }
542  }
543
544
545}
Note: See TracBrowser for help on using the repository browser.