source: XIOS/dev/branch_yushan/extern/src_ep_dev/ep_gather.cpp @ 1037

Last change on this file since 1037 was 1037, checked in by yushan, 7 years ago

initialize the branch

File size: 14.7 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gather, MPI_Allgather
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17
18  int MPI_Gather_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
19  {
20    if(datatype == MPI_INT)
21    {
22      Debug("datatype is INT\n");
23      return MPI_Gather_local_int(sendbuf, count, recvbuf, comm);
24    }
25    else if(datatype == MPI_FLOAT)
26    {
27      Debug("datatype is FLOAT\n");
28      return MPI_Gather_local_float(sendbuf, count, recvbuf, comm);
29    }
30    else if(datatype == MPI_DOUBLE)
31    {
32      Debug("datatype is DOUBLE\n");
33      return MPI_Gather_local_double(sendbuf, count, recvbuf, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      Debug("datatype is LONG\n");
38      return MPI_Gather_local_long(sendbuf, count, recvbuf, comm);
39    }
40    else if(datatype == MPI_UNSIGNED_LONG)
41    {
42      Debug("datatype is uLONG\n");
43      return MPI_Gather_local_ulong(sendbuf, count, recvbuf, comm);
44    }
45    else if(datatype == MPI_CHAR)
46    {
47      Debug("datatype is CHAR\n");
48      return MPI_Gather_local_char(sendbuf, count, recvbuf, comm);
49    }
50    else
51    {
52      printf("MPI_Gather Datatype not supported!\n");
53      exit(0);
54    }
55  }
56
57  int MPI_Gather_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
58  {
59    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
60    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    if(my_rank == 0)
67    {
68      copy(send_buf, send_buf+count, recv_buf);
69    }
70
71    for(int j=0; j<count; j+=BUFFER_SIZE)
72    {
73      for(int k=1; k<num_ep; k++)
74      {
75        if(my_rank == k)
76        {
77          #pragma omp critical (write_to_buffer)
78          {
79            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
80            #pragma omp flush
81          }
82        }
83
84        MPI_Barrier_local(comm);
85
86        if(my_rank == 0)
87        {
88          #pragma omp flush
89          #pragma omp critical (read_from_buffer)
90          {
91            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
92          }
93        }
94
95        MPI_Barrier_local(comm);
96      }
97    }
98  }
99
100  int MPI_Gather_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
101  {
102    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
103    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
104
105    float *buffer = comm.my_buffer->buf_float;
106    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
107    float *recv_buf = static_cast<float*>(recvbuf);
108
109    if(my_rank == 0)
110    {
111      copy(send_buf, send_buf+count, recv_buf);
112    }
113
114    for(int j=0; j<count; j+=BUFFER_SIZE)
115    {
116      for(int k=1; k<num_ep; k++)
117      {
118        if(my_rank == k)
119        {
120          #pragma omp critical (write_to_buffer)
121          {
122            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
123            #pragma omp flush
124          }
125        }
126
127        MPI_Barrier_local(comm);
128
129        if(my_rank == 0)
130        {
131          #pragma omp flush
132          #pragma omp critical (read_from_buffer)
133          {
134            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
135          }
136        }
137
138        MPI_Barrier_local(comm);
139      }
140    }
141  }
142
143  int MPI_Gather_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
144  {
145    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
146    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
147
148    double *buffer = comm.my_buffer->buf_double;
149    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
150    double *recv_buf = static_cast<double*>(recvbuf);
151
152    if(my_rank == 0)
153    {
154      copy(send_buf, send_buf+count, recv_buf);
155    }
156
157    for(int j=0; j<count; j+=BUFFER_SIZE)
158    {
159      for(int k=1; k<num_ep; k++)
160      {
161        if(my_rank == k)
162        {
163          #pragma omp critical (write_to_buffer)
164          {
165            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
166            #pragma omp flush
167          }
168        }
169
170        MPI_Barrier_local(comm);
171
172        if(my_rank == 0)
173        {
174          #pragma omp flush
175          #pragma omp critical (read_from_buffer)
176          {
177            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
178          }
179        }
180
181        MPI_Barrier_local(comm);
182      }
183    }
184  }
185
186  int MPI_Gather_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
187  {
188    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
189    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
190
191    long *buffer = comm.my_buffer->buf_long;
192    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
193    long *recv_buf = static_cast<long*>(recvbuf);
194
195    if(my_rank == 0)
196    {
197      copy(send_buf, send_buf+count, recv_buf);
198    }
199
200    for(int j=0; j<count; j+=BUFFER_SIZE)
201    {
202      for(int k=1; k<num_ep; k++)
203      {
204        if(my_rank == k)
205        {
206          #pragma omp critical (write_to_buffer)
207          {
208            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
209            #pragma omp flush
210          }
211        }
212
213        MPI_Barrier_local(comm);
214
215        if(my_rank == 0)
216        {
217          #pragma omp flush
218          #pragma omp critical (read_from_buffer)
219          {
220            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
221          }
222        }
223
224        MPI_Barrier_local(comm);
225      }
226    }
227  }
228
229  int MPI_Gather_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
230  {
231    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
232    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
233
234    unsigned long *buffer = comm.my_buffer->buf_ulong;
235    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
236    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
237
238    if(my_rank == 0)
239    {
240      copy(send_buf, send_buf+count, recv_buf);
241    }
242
243    for(int j=0; j<count; j+=BUFFER_SIZE)
244    {
245      for(int k=1; k<num_ep; k++)
246      {
247        if(my_rank == k)
248        {
249          #pragma omp critical (write_to_buffer)
250          {
251            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
252            #pragma omp flush
253          }
254        }
255
256        MPI_Barrier_local(comm);
257
258        if(my_rank == 0)
259        {
260          #pragma omp flush
261          #pragma omp critical (read_from_buffer)
262          {
263            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
264          }
265        }
266
267        MPI_Barrier_local(comm);
268      }
269    }
270  }
271
272
273  int MPI_Gather_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
274  {
275    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
276    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
277
278    char *buffer = comm.my_buffer->buf_char;
279    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
280    char *recv_buf = static_cast<char*>(recvbuf);
281
282    if(my_rank == 0)
283    {
284      copy(send_buf, send_buf+count, recv_buf);
285    }
286
287    for(int j=0; j<count; j+=BUFFER_SIZE)
288    {
289      for(int k=1; k<num_ep; k++)
290      {
291        if(my_rank == k)
292        {
293          #pragma omp critical (write_to_buffer)
294          {
295            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
296            #pragma omp flush
297          }
298        }
299
300        MPI_Barrier_local(comm);
301
302        if(my_rank == 0)
303        {
304          #pragma omp flush
305          #pragma omp critical (read_from_buffer)
306          {
307            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);
308          }
309        }
310
311        MPI_Barrier_local(comm);
312      }
313    }
314  }
315
316
317
318  int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
319  {
320    if(!comm.is_ep && comm.mpi_comm)
321    {
322      #ifdef _serialized
323      #pragma omp critical (_mpi_call)
324      #endif // _serialized
325      ::MPI_Gather(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
326                   root, static_cast< ::MPI_Comm>(comm.mpi_comm));
327      return 0;
328    }
329
330    if(!comm.mpi_comm) return 0;
331
332    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
333
334    MPI_Datatype datatype = sendtype;
335    int count = sendcount;
336
337    int ep_rank, ep_rank_loc, mpi_rank;
338    int ep_size, num_ep, mpi_size;
339
340    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
341    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
342    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
343    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
344    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
345    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
346
347
348    int root_mpi_rank = comm.rank_map->at(root).second;
349    int root_ep_loc = comm.rank_map->at(root).first;
350
351
352    ::MPI_Aint datasize, lb;
353    #ifdef _serialized
354    #pragma omp critical (_mpi_call)
355    #endif // _serialized
356    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
357
358    void *local_gather_recvbuf;
359
360    if(ep_rank_loc==0)
361    {
362      local_gather_recvbuf = new void*[datasize*num_ep*count];
363    }
364
365    // local gather to master
366    MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);
367
368    //MPI_Gather
369
370    if(ep_rank_loc == 0)
371    {
372      int *gatherv_recvcnt;
373      int *gatherv_displs;
374      int gatherv_cnt = count*num_ep;
375
376      gatherv_recvcnt = new int[mpi_size];
377      gatherv_displs = new int[mpi_size];
378
379      #ifdef _serialized
380      #pragma omp critical (_mpi_call)
381      #endif // _serialized
382      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
383
384      gatherv_displs[0] = 0;
385      for(int i=1; i<mpi_size; i++)
386      {
387        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
388      }
389
390      #ifdef _serialized
391      #pragma omp critical (_mpi_call)
392      #endif // _serialized
393      ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
394                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
395
396      delete[] gatherv_recvcnt;
397      delete[] gatherv_displs;
398    }
399
400
401    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
402    {
403      innode_memcpy(0, recvbuf, root_ep_loc, recvbuf, count*ep_size, datatype, comm);
404    }
405
406
407
408    if(ep_rank_loc==0)
409    {
410      if(datatype == MPI_INT)
411      {
412        delete[] static_cast<int*>(local_gather_recvbuf);
413      }
414      else if(datatype == MPI_FLOAT)
415      {
416        delete[] static_cast<float*>(local_gather_recvbuf);
417      }
418      else if(datatype == MPI_DOUBLE)
419      {
420        delete[] static_cast<double*>(local_gather_recvbuf);
421      }
422      else if(datatype == MPI_CHAR)
423      {
424        delete[] static_cast<char*>(local_gather_recvbuf);
425      }
426      else if(datatype == MPI_LONG)
427      {
428        delete[] static_cast<long*>(local_gather_recvbuf);
429      }
430      else// if(datatype == MPI_UNSIGNED_LONG)
431      {
432        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
433      }
434    }
435
436
437  }
438
439
440  int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
441  {
442    if(!comm.is_ep && comm.mpi_comm)
443    {
444      #ifdef _serialized
445      #pragma omp critical (_mpi_call)
446      #endif // _serialized
447      ::MPI_Allgather(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
448                      static_cast< ::MPI_Comm>(comm.mpi_comm));
449      return 0;
450    }
451
452    if(!comm.mpi_comm) return 0;
453
454    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
455
456    MPI_Datatype datatype = sendtype;
457    int count = sendcount;
458
459    int ep_rank, ep_rank_loc, mpi_rank;
460    int ep_size, num_ep, mpi_size;
461
462    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
463    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
464    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
465    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
466    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
467    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
468
469
470    ::MPI_Aint datasize, lb;
471    #ifdef _serialized
472    #pragma omp critical (_mpi_call)
473    #endif // _serialized
474    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
475
476    void *local_gather_recvbuf;
477
478    if(ep_rank_loc==0)
479    {
480      local_gather_recvbuf = new void*[datasize*num_ep*count];
481    }
482
483    // local gather to master
484    MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);
485
486    //MPI_Gather
487
488    if(ep_rank_loc == 0)
489    {
490      int *gatherv_recvcnt;
491      int *gatherv_displs;
492      int gatherv_cnt = count*num_ep;
493
494      gatherv_recvcnt = new int[mpi_size];
495      gatherv_displs = new int[mpi_size];
496
497      #ifdef _serialized
498      #pragma omp critical (_mpi_call)
499      #endif // _serialized
500      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
501
502      gatherv_displs[0] = 0;
503      for(int i=1; i<mpi_size; i++)
504      {
505        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
506      }
507
508      #ifdef _serialized
509      #pragma omp critical (_mpi_call)
510      #endif // _serialized
511      ::MPI_Allgatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
512                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
513
514      delete[] gatherv_recvcnt;
515      delete[] gatherv_displs;
516    }
517
518    MPI_Bcast_local(recvbuf, count*ep_size, datatype, comm);
519
520
521    if(ep_rank_loc==0)
522    {
523      if(datatype == MPI_INT)
524      {
525        delete[] static_cast<int*>(local_gather_recvbuf);
526      }
527      else if(datatype == MPI_FLOAT)
528      {
529        delete[] static_cast<float*>(local_gather_recvbuf);
530      }
531      else if(datatype == MPI_DOUBLE)
532      {
533        delete[] static_cast<double*>(local_gather_recvbuf);
534      }
535      else if(datatype == MPI_CHAR)
536      {
537        delete[] static_cast<char*>(local_gather_recvbuf);
538      }
539      else if(datatype == MPI_LONG)
540      {
541        delete[] static_cast<long*>(local_gather_recvbuf);
542      }
543      else// if(datatype == MPI_UNSIGNED_LONG)
544      {
545        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
546      }
547    }
548  }
549
550
551}
Note: See TracBrowser for help on using the repository browser.