source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gather.cpp @ 1134

Last change on this file since 1134 was 1134, checked in by yushan, 7 years ago

branch merged with trunk r1130

File size: 14.1 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gather, MPI_Allgather
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17
18  int MPI_Gather_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
19  {
20    if(datatype == MPI_INT)
21    {
22      Debug("datatype is INT\n");
23      return MPI_Gather_local_int(sendbuf, count, recvbuf, comm);
24    }
25    else if(datatype == MPI_FLOAT)
26    {
27      Debug("datatype is FLOAT\n");
28      return MPI_Gather_local_float(sendbuf, count, recvbuf, comm);
29    }
30    else if(datatype == MPI_DOUBLE)
31    {
32      Debug("datatype is DOUBLE\n");
33      return MPI_Gather_local_double(sendbuf, count, recvbuf, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      Debug("datatype is LONG\n");
38      return MPI_Gather_local_long(sendbuf, count, recvbuf, comm);
39    }
40    else if(datatype == MPI_UNSIGNED_LONG)
41    {
42      Debug("datatype is uLONG\n");
43      return MPI_Gather_local_ulong(sendbuf, count, recvbuf, comm);
44    }
45    else if(datatype == MPI_CHAR)
46    {
47      Debug("datatype is CHAR\n");
48      return MPI_Gather_local_char(sendbuf, count, recvbuf, comm);
49    }
50    else
51    {
52      printf("MPI_Gather Datatype not supported!\n");
53      exit(0);
54    }
55  }
56
57  int MPI_Gather_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
58  {
59    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
60    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    if(my_rank == 0)
67    {
68      copy(send_buf, send_buf+count, recv_buf);
69    }
70
71    for(int j=0; j<count; j+=BUFFER_SIZE)
72    {
73      for(int k=1; k<num_ep; k++)
74      {
75        if(my_rank == k)
76        {
77          #pragma omp critical (write_to_buffer)
78          {
79            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
80            #pragma omp flush
81          }
82        }
83
84        MPI_Barrier_local(comm);
85
86        if(my_rank == 0)
87        {
88          #pragma omp flush
89          #pragma omp critical (read_from_buffer)
90          {
91            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
92          }
93        }
94
95        MPI_Barrier_local(comm);
96      }
97    }
98  }
99
100  int MPI_Gather_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
101  {
102    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
103    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
104
105    float *buffer = comm.my_buffer->buf_float;
106    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
107    float *recv_buf = static_cast<float*>(recvbuf);
108
109    if(my_rank == 0)
110    {
111      copy(send_buf, send_buf+count, recv_buf);
112    }
113
114    for(int j=0; j<count; j+=BUFFER_SIZE)
115    {
116      for(int k=1; k<num_ep; k++)
117      {
118        if(my_rank == k)
119        {
120          #pragma omp critical (write_to_buffer)
121          {
122            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
123            #pragma omp flush
124          }
125        }
126
127        MPI_Barrier_local(comm);
128
129        if(my_rank == 0)
130        {
131          #pragma omp flush
132          #pragma omp critical (read_from_buffer)
133          {
134            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
135          }
136        }
137
138        MPI_Barrier_local(comm);
139      }
140    }
141  }
142
143  int MPI_Gather_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
144  {
145    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
146    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
147
148    double *buffer = comm.my_buffer->buf_double;
149    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
150    double *recv_buf = static_cast<double*>(recvbuf);
151
152    if(my_rank == 0)
153    {
154      copy(send_buf, send_buf+count, recv_buf);
155    }
156
157    for(int j=0; j<count; j+=BUFFER_SIZE)
158    {
159      for(int k=1; k<num_ep; k++)
160      {
161        if(my_rank == k)
162        {
163          #pragma omp critical (write_to_buffer)
164          {
165            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
166            #pragma omp flush
167          }
168        }
169
170        MPI_Barrier_local(comm);
171
172        if(my_rank == 0)
173        {
174          #pragma omp flush
175          #pragma omp critical (read_from_buffer)
176          {
177            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
178          }
179        }
180
181        MPI_Barrier_local(comm);
182      }
183    }
184  }
185
186  int MPI_Gather_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
187  {
188    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
189    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
190
191    long *buffer = comm.my_buffer->buf_long;
192    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
193    long *recv_buf = static_cast<long*>(recvbuf);
194
195    if(my_rank == 0)
196    {
197      copy(send_buf, send_buf+count, recv_buf);
198    }
199
200    for(int j=0; j<count; j+=BUFFER_SIZE)
201    {
202      for(int k=1; k<num_ep; k++)
203      {
204        if(my_rank == k)
205        {
206          #pragma omp critical (write_to_buffer)
207          {
208            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
209            #pragma omp flush
210          }
211        }
212
213        MPI_Barrier_local(comm);
214
215        if(my_rank == 0)
216        {
217          #pragma omp flush
218          #pragma omp critical (read_from_buffer)
219          {
220            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
221          }
222        }
223
224        MPI_Barrier_local(comm);
225      }
226    }
227  }
228
229  int MPI_Gather_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
230  {
231    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
232    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
233
234    unsigned long *buffer = comm.my_buffer->buf_ulong;
235    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
236    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
237
238    if(my_rank == 0)
239    {
240      copy(send_buf, send_buf+count, recv_buf);
241    }
242
243    for(int j=0; j<count; j+=BUFFER_SIZE)
244    {
245      for(int k=1; k<num_ep; k++)
246      {
247        if(my_rank == k)
248        {
249          #pragma omp critical (write_to_buffer)
250          {
251            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
252            #pragma omp flush
253          }
254        }
255
256        MPI_Barrier_local(comm);
257
258        if(my_rank == 0)
259        {
260          #pragma omp flush
261          #pragma omp critical (read_from_buffer)
262          {
263            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
264          }
265        }
266
267        MPI_Barrier_local(comm);
268      }
269    }
270  }
271
272
273  int MPI_Gather_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
274  {
275    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
276    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
277
278    char *buffer = comm.my_buffer->buf_char;
279    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
280    char *recv_buf = static_cast<char*>(recvbuf);
281
282    if(my_rank == 0)
283    {
284      copy(send_buf, send_buf+count, recv_buf);
285    }
286
287    for(int j=0; j<count; j+=BUFFER_SIZE)
288    {
289      for(int k=1; k<num_ep; k++)
290      {
291        if(my_rank == k)
292        {
293          #pragma omp critical (write_to_buffer)
294          {
295            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
296            #pragma omp flush
297          }
298        }
299
300        MPI_Barrier_local(comm);
301
302        if(my_rank == 0)
303        {
304          #pragma omp flush
305          #pragma omp critical (read_from_buffer)
306          {
307            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
308          }
309        }
310
311        MPI_Barrier_local(comm);
312      }
313    }
314  }
315
316
317
318  int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
319  {
320    if(!comm.is_ep && comm.mpi_comm)
321    {
322      ::MPI_Gather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
323                   root, static_cast< ::MPI_Comm>(comm.mpi_comm));
324      return 0;
325    }
326
327    if(!comm.mpi_comm) return 0;
328
329    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
330
331    MPI_Datatype datatype = sendtype;
332    int count = sendcount;
333
334    int ep_rank, ep_rank_loc, mpi_rank;
335    int ep_size, num_ep, mpi_size;
336
337    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
338    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
339    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
340    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
341    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
342    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
343
344
345    int root_mpi_rank = comm.rank_map->at(root).second;
346    int root_ep_loc = comm.rank_map->at(root).first;
347
348
349    ::MPI_Aint datasize, lb;
350
351    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
352
353    void *local_gather_recvbuf;
354
355    if(ep_rank_loc==0)
356    {
357      local_gather_recvbuf = new void*[datasize*num_ep*count];
358    }
359
360    // local gather to master
361    MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);
362
363    //MPI_Gather
364
365    if(ep_rank_loc == 0)
366    {
367      int *gatherv_recvcnt;
368      int *gatherv_displs;
369      int gatherv_cnt = count*num_ep;
370
371      gatherv_recvcnt = new int[mpi_size];
372      gatherv_displs = new int[mpi_size];
373
374
375      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
376
377      gatherv_displs[0] = 0;
378      for(int i=1; i<mpi_size; i++)
379      {
380        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
381      }
382
383
384      ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
385                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
386
387      delete[] gatherv_recvcnt;
388      delete[] gatherv_displs;
389    }
390
391
392    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
393    {
394      innode_memcpy(0, recvbuf, root_ep_loc, recvbuf, count*ep_size, datatype, comm);
395    }
396
397
398
399    if(ep_rank_loc==0)
400    {
401      if(datatype == MPI_INT)
402      {
403        delete[] static_cast<int*>(local_gather_recvbuf);
404      }
405      else if(datatype == MPI_FLOAT)
406      {
407        delete[] static_cast<float*>(local_gather_recvbuf);
408      }
409      else if(datatype == MPI_DOUBLE)
410      {
411        delete[] static_cast<double*>(local_gather_recvbuf);
412      }
413      else if(datatype == MPI_CHAR)
414      {
415        delete[] static_cast<char*>(local_gather_recvbuf);
416      }
417      else if(datatype == MPI_LONG)
418      {
419        delete[] static_cast<long*>(local_gather_recvbuf);
420      }
421      else// if(datatype == MPI_UNSIGNED_LONG)
422      {
423        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
424      }
425    }
426
427
428  }
429
430
431  int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
432  {
433    if(!comm.is_ep && comm.mpi_comm)
434    {
435      ::MPI_Allgather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
436                      static_cast< ::MPI_Comm>(comm.mpi_comm));
437      return 0;
438    }
439
440    if(!comm.mpi_comm) return 0;
441
442    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
443
444    MPI_Datatype datatype = sendtype;
445    int count = sendcount;
446
447    int ep_rank, ep_rank_loc, mpi_rank;
448    int ep_size, num_ep, mpi_size;
449
450    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
451    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
452    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
453    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
454    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
455    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
456
457
458    ::MPI_Aint datasize, lb;
459
460    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
461
462    void *local_gather_recvbuf;
463
464    if(ep_rank_loc==0)
465    {
466      local_gather_recvbuf = new void*[datasize*num_ep*count];
467    }
468
469    // local gather to master
470    MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);
471
472    //MPI_Gather
473
474    if(ep_rank_loc == 0)
475    {
476      int *gatherv_recvcnt;
477      int *gatherv_displs;
478      int gatherv_cnt = count*num_ep;
479
480      gatherv_recvcnt = new int[mpi_size];
481      gatherv_displs = new int[mpi_size];
482
483      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
484
485      gatherv_displs[0] = 0;
486      for(int i=1; i<mpi_size; i++)
487      {
488        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
489      }
490
491      ::MPI_Allgatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
492                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
493
494      delete[] gatherv_recvcnt;
495      delete[] gatherv_displs;
496    }
497
498    MPI_Bcast_local(recvbuf, count*ep_size, datatype, comm);
499
500
501    if(ep_rank_loc==0)
502    {
503      if(datatype == MPI_INT)
504      {
505        delete[] static_cast<int*>(local_gather_recvbuf);
506      }
507      else if(datatype == MPI_FLOAT)
508      {
509        delete[] static_cast<float*>(local_gather_recvbuf);
510      }
511      else if(datatype == MPI_DOUBLE)
512      {
513        delete[] static_cast<double*>(local_gather_recvbuf);
514      }
515      else if(datatype == MPI_CHAR)
516      {
517        delete[] static_cast<char*>(local_gather_recvbuf);
518      }
519      else if(datatype == MPI_LONG)
520      {
521        delete[] static_cast<long*>(local_gather_recvbuf);
522      }
523      else// if(datatype == MPI_UNSIGNED_LONG)
524      {
525        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
526      }
527    }
528  }
529
530
531}
Note: See TracBrowser for help on using the repository browser.