source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_reduce.cpp @ 1134

Last change on this file since 1134 was 1134, checked in by yushan, 7 years ago

branch merged with trunk r1130

File size: 18.4 KB
Line 
1/*!
2   \file ep_reduce.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Reduce, MPI_Allreduce
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12using namespace std;
13
14
15namespace ep_lib {
16
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
29
30  int MPI_Reduce_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
31  {
32    if(datatype == MPI_INT)
33    {
34      Debug("datatype is INT\n");
35      return MPI_Reduce_local_int(sendbuf, recvbuf, count, op, comm);
36    }
37    else if(datatype == MPI_FLOAT)
38    {
39      Debug("datatype is FLOAT\n");
40      return MPI_Reduce_local_float(sendbuf, recvbuf, count, op, comm);
41    }
42    else if(datatype == MPI_DOUBLE)
43    {
44      Debug("datatype is DOUBLE\n");
45      return MPI_Reduce_local_double(sendbuf, recvbuf, count, op, comm);
46    }
47    else if(datatype == MPI_LONG)
48    {
49      Debug("datatype is DOUBLE\n");
50      return MPI_Reduce_local_long(sendbuf, recvbuf, count, op, comm);
51    }
52    else if(datatype == MPI_UNSIGNED_LONG)
53    {
54      Debug("datatype is DOUBLE\n");
55      return MPI_Reduce_local_ulong(sendbuf, recvbuf, count, op, comm);
56    }
57    else if(datatype == MPI_CHAR)
58    {
59      Debug("datatype is DOUBLE\n");
60      return MPI_Reduce_local_char(sendbuf, recvbuf, count, op, comm);
61    }
62    else
63    {
64      printf("MPI_Reduce Datatype not supported!\n");
65      exit(0);
66    }
67  }
68
69
70  int MPI_Reduce_local_int(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
71  {
72    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
73    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
74
75    int *buffer = comm.my_buffer->buf_int;
76    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
77    int *recv_buf = static_cast<int*>(const_cast<void*>(recvbuf));
78
79    for(int j=0; j<count; j+=BUFFER_SIZE)
80    {
81      if( 0 == my_rank )
82      {
83        #pragma omp critical (write_to_buffer)
84        copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer);
85        #pragma omp flush
86      }
87
88      MPI_Barrier_local(comm);
89
90      if(my_rank !=0 )
91      {
92        #pragma omp critical (write_to_buffer)
93        {
94          #pragma omp flush
95          if(op == MPI_SUM)
96          {
97            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<int>());
98          }
99
100          else if (op == MPI_MAX)
101          {
102            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<int>);
103          }
104
105          else if (op == MPI_MIN)
106          {
107            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<int>);
108          }
109
110          else
111          {
112            printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
113            exit(1);
114          }
115          #pragma omp flush
116        }
117      }
118
119      MPI_Barrier_local(comm);
120
121      if(my_rank == 0)
122      {
123        #pragma omp flush
124        copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
125      }
126      MPI_Barrier_local(comm);
127    }
128  }
129
130
131  int MPI_Reduce_local_float(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
132  {
133    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
134    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
135
136    float *buffer = comm.my_buffer->buf_float;
137    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
138    float *recv_buf = static_cast<float*>(const_cast<void*>(recvbuf));
139
140    for(int j=0; j<count; j+=BUFFER_SIZE)
141    {
142      if( 0 == my_rank )
143      {
144        #pragma omp critical (write_to_buffer)
145        copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer);
146        #pragma omp flush
147      }
148
149      MPI_Barrier_local(comm);
150
151      if(my_rank !=0 )
152      {
153        #pragma omp critical (write_to_buffer)
154        {
155          #pragma omp flush
156
157          if(op == MPI_SUM)
158          {
159            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<float>());
160          }
161
162          else if (op == MPI_MAX)
163          {
164            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<float>);
165          }
166
167          else if (op == MPI_MIN)
168          {
169            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<float>);
170          }
171
172          else
173          {
174            printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
175            exit(1);
176          }
177          #pragma omp flush
178        }
179      }
180
181      MPI_Barrier_local(comm);
182
183      if(my_rank == 0)
184      {
185        #pragma omp flush
186        copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
187      }
188      MPI_Barrier_local(comm);
189    }
190  }
191
192  int MPI_Reduce_local_double(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
193  {
194    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
195    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
196
197    double *buffer = comm.my_buffer->buf_double;
198    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
199    double *recv_buf = static_cast<double*>(const_cast<void*>(recvbuf));
200
201    for(int j=0; j<count; j+=BUFFER_SIZE)
202    {
203      if( 0 == my_rank )
204      {
205        #pragma omp critical (write_to_buffer)
206        copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer);
207        #pragma omp flush
208      }
209
210      MPI_Barrier_local(comm);
211
212      if(my_rank !=0 )
213      {
214        #pragma omp critical (write_to_buffer)
215        {
216          #pragma omp flush
217
218
219          if(op == MPI_SUM)
220          {
221            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<double>());
222          }
223
224          else if (op == MPI_MAX)
225          {
226            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<double>);
227          }
228
229
230          else if (op == MPI_MIN)
231          {
232            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<double>);
233          }
234
235          else
236          {
237            printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
238            exit(1);
239          }
240          #pragma omp flush
241        }
242      }
243
244      MPI_Barrier_local(comm);
245
246      if(my_rank == 0)
247      {
248        #pragma omp flush
249        copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
250      }
251      MPI_Barrier_local(comm);
252    }
253  }
254
255  int MPI_Reduce_local_long(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
256  {
257    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
258    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
259
260    long *buffer = comm.my_buffer->buf_long;
261    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
262    long *recv_buf = static_cast<long*>(const_cast<void*>(recvbuf));
263
264    for(int j=0; j<count; j+=BUFFER_SIZE)
265    {
266      if( 0 == my_rank )
267      {
268        #pragma omp critical (write_to_buffer)
269        copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer);
270        #pragma omp flush
271      }
272
273      MPI_Barrier_local(comm);
274
275      if(my_rank !=0 )
276      {
277        #pragma omp critical (write_to_buffer)
278        {
279          #pragma omp flush
280
281
282          if(op == MPI_SUM)
283          {
284            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<long>());
285          }
286
287          else if (op == MPI_MAX)
288          {
289            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<long>);
290          }
291
292
293          else if (op == MPI_MIN)
294          {
295            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<long>);
296          }
297
298          else
299          {
300            printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
301            exit(1);
302          }
303          #pragma omp flush
304        }
305      }
306
307      MPI_Barrier_local(comm);
308
309      if(my_rank == 0)
310      {
311        #pragma omp flush
312        copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
313      }
314      MPI_Barrier_local(comm);
315    }
316  }
317
318  int MPI_Reduce_local_ulong(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
319  {
320    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
321    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
322
323    unsigned long *buffer = comm.my_buffer->buf_ulong;
324    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
325    unsigned long *recv_buf = static_cast<unsigned long*>(const_cast<void*>(recvbuf));
326
327    for(int j=0; j<count; j+=BUFFER_SIZE)
328    {
329      if( 0 == my_rank )
330      {
331        #pragma omp critical (write_to_buffer)
332        copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer);
333        #pragma omp flush
334      }
335
336      MPI_Barrier_local(comm);
337
338      if(my_rank !=0 )
339      {
340        #pragma omp critical (write_to_buffer)
341        {
342          #pragma omp flush
343
344
345          if(op == MPI_SUM)
346          {
347            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<unsigned long>());
348          }
349
350          else if (op == MPI_MAX)
351          {
352            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<unsigned long>);
353          }
354
355
356          else if (op == MPI_MIN)
357          {
358            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<unsigned long>);
359          }
360
361          else
362          {
363            printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
364            exit(1);
365          }
366          #pragma omp flush
367        }
368      }
369
370      MPI_Barrier_local(comm);
371
372      if(my_rank == 0)
373      {
374        #pragma omp flush
375        copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
376      }
377      MPI_Barrier_local(comm);
378    }
379  }
380
381  int MPI_Reduce_local_char(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
382  {
383    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
384    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
385
386    char *buffer = comm.my_buffer->buf_char;
387    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
388    char *recv_buf = static_cast<char*>(const_cast<void*>(recvbuf));
389
390    for(int j=0; j<count; j+=BUFFER_SIZE)
391    {
392      if( 0 == my_rank )
393      {
394        #pragma omp critical (write_to_buffer)
395        copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer);
396        #pragma omp flush
397      }
398
399      MPI_Barrier_local(comm);
400
401      if(my_rank !=0 )
402      {
403        #pragma omp critical (write_to_buffer)
404        {
405          #pragma omp flush
406
407
408          if(op == MPI_SUM)
409          {
410            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<char>());
411          }
412
413          else if (op == MPI_MAX)
414          {
415            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<char>);
416          }
417
418
419          else if (op == MPI_MIN)
420          {
421            transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<char>);
422          }
423
424          else
425          {
426            printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
427            exit(1);
428          }
429          #pragma omp flush
430        }
431      }
432
433      MPI_Barrier_local(comm);
434
435      if(my_rank == 0)
436      {
437        #pragma omp flush
438        copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
439      }
440      MPI_Barrier_local(comm);
441    }
442  }
443
444
445  int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
446  {
447    if(!comm.is_ep && comm.mpi_comm)
448    {
449      ::MPI_Reduce(sendbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), root,
450                   static_cast< ::MPI_Comm>(comm.mpi_comm));
451      return 0;
452    }
453
454
455    if(!comm.mpi_comm) return 0;
456
457    int root_mpi_rank = comm.rank_map->at(root).second;
458    int root_ep_loc = comm.rank_map->at(root).first;
459
460    int ep_rank, ep_rank_loc, mpi_rank;
461    int ep_size, num_ep, mpi_size;
462
463    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
464    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
465    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
466    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
467    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
468    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
469
470
471    ::MPI_Aint recvsize, lb;
472
473    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &recvsize);
474
475    void *local_recvbuf;
476    if(ep_rank_loc==0)
477    {
478      local_recvbuf = new void*[recvsize*count];
479    }
480
481    MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, comm);
482
483
484    if(ep_rank_loc==0)
485    {
486      ::MPI_Reduce(local_recvbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
487    }
488
489    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
490    {
491      innode_memcpy(0, recvbuf, root_ep_loc, recvbuf, count, datatype, comm);
492    }
493
494    if(ep_rank_loc==0)
495    {
496      if(datatype == MPI_INT) delete[] static_cast<int*>(local_recvbuf);
497      else if(datatype == MPI_FLOAT) delete[] static_cast<float*>(local_recvbuf);
498      else if(datatype == MPI_DOUBLE) delete[] static_cast<double*>(local_recvbuf);
499      else if(datatype == MPI_LONG) delete[] static_cast<long*>(local_recvbuf);
500      else if(datatype == MPI_UNSIGNED_LONG) delete[] static_cast<unsigned long*>(local_recvbuf);
501      else delete[] static_cast<char*>(local_recvbuf);
502    }
503
504    Message_Check(comm);
505
506    return 0;
507  }
508
509
510
511
512  int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
513  {
514    if(!comm.is_ep && comm.mpi_comm)
515    {
516      ::MPI_Allreduce(sendbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op),
517                      static_cast< ::MPI_Comm>(comm.mpi_comm));
518      return 0;
519    }
520
521    if(!comm.mpi_comm) return 0;
522
523
524    int ep_rank, ep_rank_loc, mpi_rank;
525    int ep_size, num_ep, mpi_size;
526
527    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
528    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
529    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
530    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
531    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
532    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
533
534
535    ::MPI_Aint recvsize, lb;
536
537    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &recvsize);
538
539    void *local_recvbuf;
540    if(ep_rank_loc==0)
541    {
542      local_recvbuf = new void*[recvsize*count];
543    }
544
545    MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, comm);
546
547
548    if(ep_rank_loc==0)
549    {
550      ::MPI_Allreduce(local_recvbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm));
551    }
552
553    MPI_Bcast_local(recvbuf, count, datatype, comm);
554
555    if(ep_rank_loc==0)
556    {
557      if(datatype == MPI_INT) delete[] static_cast<int*>(local_recvbuf);
558      else if(datatype == MPI_FLOAT) delete[] static_cast<float*>(local_recvbuf);
559      else if(datatype == MPI_DOUBLE) delete[] static_cast<double*>(local_recvbuf);
560      else if(datatype == MPI_LONG) delete[] static_cast<long*>(local_recvbuf);
561      else if(datatype == MPI_UNSIGNED_LONG) delete[] static_cast<unsigned long*>(local_recvbuf);
562      else delete[] static_cast<char*>(local_recvbuf);
563    }
564
565    Message_Check(comm);
566
567    return 0;
568  }
569
570
571  int MPI_Reduce_scatter(const void *sendbuf, void *recvbuf, const int recvcounts[], MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
572  {
573
574    if(!comm.is_ep && comm.mpi_comm)
575    {
576      ::MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op),
577                           static_cast< ::MPI_Comm>(comm.mpi_comm));
578      return 0;
579    }
580
581    if(!comm.mpi_comm) return 0;
582
583    int ep_rank, ep_rank_loc, mpi_rank;
584    int ep_size, num_ep, mpi_size;
585
586    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
587    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
588    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
589    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
590    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
591    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
592
593    void* local_buf;
594    void* local_buf2;
595    int local_buf_size = accumulate(recvcounts, recvcounts+ep_size, 0);
596    int local_buf2_size = accumulate(recvcounts+ep_rank-ep_rank_loc, recvcounts+ep_rank-ep_rank_loc+num_ep, 0);
597
598    ::MPI_Aint datasize, lb;
599
600    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
601
602    if(ep_rank_loc == 0)
603    {
604      local_buf = new void*[local_buf_size*datasize];
605      local_buf2 = new void*[local_buf2_size*datasize];
606    }
607    MPI_Reduce_local(sendbuf, local_buf, local_buf_size, MPI_INT, op, comm);
608
609
610    if(ep_rank_loc == 0)
611    {
612      int local_recvcnt[mpi_size];
613      for(int i=0; i<mpi_size; i++)
614      {
615        local_recvcnt[i] = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0);
616      }
617
618      ::MPI_Reduce_scatter(local_buf, local_buf2, local_recvcnt, static_cast< ::MPI_Datatype>(datatype),
619                         static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm));
620    }
621
622
623    int displs[num_ep];
624    displs[0] = 0;
625    for(int i=1; i<num_ep; i++)
626    {
627      displs[i] = displs[i-1] + recvcounts[ep_rank-ep_rank_loc+i-1];
628    }
629
630    MPI_Scatterv_local(local_buf2, recvcounts+ep_rank-ep_rank_loc, displs, datatype, recvbuf, comm);
631
632    if(ep_rank_loc == 0)
633    {
634      if(datatype == MPI_INT)
635      {
636        delete[] static_cast<int*>(local_buf);
637        delete[] static_cast<int*>(local_buf2);
638      }
639      else if(datatype == MPI_FLOAT)
640      {
641        delete[] static_cast<float*>(local_buf);
642        delete[] static_cast<float*>(local_buf2);
643      }
644      else if(datatype == MPI_DOUBLE)
645      {
646        delete[] static_cast<double*>(local_buf);
647        delete[] static_cast<double*>(local_buf2);
648      }
649      else if(datatype == MPI_LONG)
650      {
651        delete[] static_cast<long*>(local_buf);
652        delete[] static_cast<long*>(local_buf2);
653      }
654      else if(datatype == MPI_UNSIGNED_LONG)
655      {
656        delete[] static_cast<unsigned long*>(local_buf);
657        delete[] static_cast<unsigned long*>(local_buf2);
658      }
659      else // if(datatype == MPI_DOUBLE)
660      {
661        delete[] static_cast<char*>(local_buf);
662        delete[] static_cast<char*>(local_buf2);
663      }
664    }
665
666    Message_Check(comm);
667    return 0;
668  }
669}
670
Note: See TracBrowser for help on using the repository browser.