source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_reduce.cpp @ 1642

Last change on this file since 1642 was 1642, checked in by yushan, 5 years ago

dev on ADA. add flag switch _usingEP/_usingMPI

File size: 10.5 KB
RevLine 
[1134]1/*!
2   \file ep_reduce.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Reduce, MPI_Allreduce
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
[1295]11#include "ep_mpi.hpp"
[1134]12
13using namespace std;
14
15
16namespace ep_lib {
17
18  template<typename T>
19  T max_op(T a, T b)
20  {
21    return max(a,b);
22  }
23
24  template<typename T>
25  T min_op(T a, T b)
26  {
27    return min(a,b);
28  }
29
[1295]30  template<typename T>
[1460]31  T lor_op(T a, T b)
32  {
33    return a||b;
34  }
35
36  template<typename T>
[1295]37  void reduce_max(const T * buffer, T* recvbuf, int count)
38  {
39    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
40  }
[1134]41
[1295]42  template<typename T>
43  void reduce_min(const T * buffer, T* recvbuf, int count)
[1134]44  {
[1295]45    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
[1134]46  }
47
[1295]48  template<typename T>
49  void reduce_sum(const T * buffer, T* recvbuf, int count)
50  {
51    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
52  }
[1134]53
[1460]54  template<typename T>
55  void reduce_lor(const T * buffer, T* recvbuf, int count)
56  {
57    transform(buffer, buffer+count, recvbuf, recvbuf, lor_op<T>);
58  }
59
[1295]60  int MPI_Reduce_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int local_root, MPI_Comm comm)
[1134]61  {
[1295]62    assert(valid_type(datatype));
63    assert(valid_op(op));
[1134]64
[1295]65    ::MPI_Aint datasize, lb;
66    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
[1134]67
[1520]68    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
69    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
70    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
[1134]71
[1295]72    #pragma omp critical (_reduce)
[1520]73    comm->my_buffer->void_buffer[ep_rank_loc] = const_cast< void* >(sendbuf);
[1134]74
[1295]75    MPI_Barrier_local(comm);
[1134]76
[1295]77    if(ep_rank_loc == local_root)
[1134]78    {
79
[1520]80      memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize * count);
[1289]81
[1642]82      if(op == EP_MAX)
[1134]83      {
[1642]84        if(datatype == EP_INT)
[1134]85        {
[1295]86          assert(datasize == sizeof(int));
87          for(int i=1; i<num_ep; i++)
[1520]88            reduce_max<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
[1295]89        }
[1134]90
[1295]91        else if(datatype == MPI_FLOAT)
92        {
93          assert(datasize == sizeof(float));
94          for(int i=1; i<num_ep; i++)
[1520]95            reduce_max<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
[1295]96        }
[1134]97
[1642]98        else if(datatype == EP_DOUBLE)
[1295]99        {
100          assert(datasize == sizeof(double));
101          for(int i=1; i<num_ep; i++)
[1520]102            reduce_max<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
[1295]103        }
[1134]104
[1642]105        else if(datatype == EP_CHAR)
[1295]106        {
107          assert(datasize == sizeof(char));
108          for(int i=1; i<num_ep; i++)
[1520]109            reduce_max<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1295]110        }
[1134]111
[1642]112        else if(datatype == EP_LONG)
[1295]113        {
114          assert(datasize == sizeof(long));
115          for(int i=1; i<num_ep; i++)
[1520]116            reduce_max<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1134]117        }
118
[1642]119        else if(datatype == EP_UNSIGNED_LONG)
[1295]120        {
121          assert(datasize == sizeof(unsigned long));
122          for(int i=1; i<num_ep; i++)
[1520]123            reduce_max<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
[1295]124        }
[1134]125
[1642]126        else if(datatype == EP_LONG_LONG_INT)
[1482]127        {
128          assert(datasize == sizeof(long long));
129          for(int i=1; i<num_ep; i++)
[1520]130            reduce_max<long long>(static_cast<long long*>(comm->my_buffer->void_buffer[i]), static_cast<long long*>(recvbuf), count);
[1482]131        }
[1460]132
[1540]133        else 
134        {
135          printf("datatype Error in ep_reduce : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
136          MPI_Abort(comm, 0);
137        }
[1134]138
139      }
140
[1642]141      else if(op == EP_MIN)
[1134]142      {
[1642]143        if(datatype ==EP_INT)
[1134]144        {
[1295]145          assert(datasize == sizeof(int));
146          for(int i=1; i<num_ep; i++)
[1520]147            reduce_min<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
[1295]148        }
[1134]149
[1295]150        else if(datatype == MPI_FLOAT)
151        {
152          assert(datasize == sizeof(float));
153          for(int i=1; i<num_ep; i++)
[1520]154            reduce_min<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
[1295]155        }
[1134]156
[1642]157        else if(datatype == EP_DOUBLE)
[1295]158        {
159          assert(datasize == sizeof(double));
160          for(int i=1; i<num_ep; i++)
[1520]161            reduce_min<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
[1295]162        }
[1134]163
[1642]164        else if(datatype == EP_CHAR)
[1295]165        {
166          assert(datasize == sizeof(char));
167          for(int i=1; i<num_ep; i++)
[1520]168            reduce_min<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1295]169        }
[1134]170
[1642]171        else if(datatype == EP_LONG)
[1295]172        {
173          assert(datasize == sizeof(long));
174          for(int i=1; i<num_ep; i++)
[1520]175            reduce_min<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1295]176        }
[1134]177
[1642]178        else if(datatype == EP_UNSIGNED_LONG)
[1295]179        {
180          assert(datasize == sizeof(unsigned long));
181          for(int i=1; i<num_ep; i++)
[1520]182            reduce_min<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
[1134]183        }
184
[1642]185        else if(datatype == EP_LONG_LONG_INT)
[1482]186        {
187          assert(datasize == sizeof(long long));
188          for(int i=1; i<num_ep; i++)
[1520]189            reduce_min<long long>(static_cast<long long*>(comm->my_buffer->void_buffer[i]), static_cast<long long*>(recvbuf), count);
[1482]190        }
[1460]191
[1540]192        else 
193        {
194          printf("datatype Error in ep_reduce : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
195          MPI_Abort(comm, 0);
196        }
[1134]197
198      }
199
200
[1642]201      else if(op == EP_SUM)
[1134]202      {
[1642]203        if(datatype==EP_INT)
[1295]204        {
205          assert(datasize == sizeof(int));
206          for(int i=1; i<num_ep; i++)
[1520]207            reduce_sum<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
[1295]208        }
[1289]209
[1295]210        else if(datatype == MPI_FLOAT)
211        {
212          assert(datasize == sizeof(float));
213          for(int i=1; i<num_ep; i++)
[1520]214            reduce_sum<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
[1295]215        }
[1289]216
[1642]217        else if(datatype == EP_DOUBLE)
[1287]218        {
[1295]219          assert(datasize == sizeof(double));
220          for(int i=1; i<num_ep; i++)
[1520]221            reduce_sum<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
[1295]222        }
[1134]223
[1642]224        else if(datatype == EP_CHAR)
[1295]225        {
226          assert(datasize == sizeof(char));
227          for(int i=1; i<num_ep; i++)
[1520]228            reduce_sum<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1295]229        }
[1289]230
[1642]231        else if(datatype == EP_LONG)
[1295]232        {
233          assert(datasize == sizeof(long));
234          for(int i=1; i<num_ep; i++)
[1520]235            reduce_sum<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1287]236        }
[1134]237
[1642]238        else if(datatype ==EP_UNSIGNED_LONG)
[1134]239        {
[1295]240          assert(datasize == sizeof(unsigned long));
241          for(int i=1; i<num_ep; i++)
[1520]242            reduce_sum<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
[1287]243        }
[1482]244       
[1642]245        else if(datatype ==EP_LONG_LONG_INT)
[1482]246        {
247          assert(datasize == sizeof(long long));
248          for(int i=1; i<num_ep; i++)
[1520]249            reduce_sum<long long>(static_cast<long long*>(comm->my_buffer->void_buffer[i]), static_cast<long long*>(recvbuf), count);
[1482]250        }
[1460]251
[1540]252        else 
253        {
254          printf("datatype Error in ep_reduce : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
255          MPI_Abort(comm, 0);
256        }
[1289]257
258      }
[1460]259
[1642]260      else if(op == EP_LOR)
[1460]261      {
[1642]262        if(datatype != EP_INT)
[1460]263          printf("datatype Error, must be MPI_INT\n");
264        else
265        {
266          assert(datasize == sizeof(int));
267          for(int i=1; i<num_ep; i++)
[1520]268            reduce_lor<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
[1460]269        }
270      }
[1540]271     
272      else
273      {
274        printf("op type Error in ep_reduce : MPI_MAX, MPI_MIN, MPI_SUM, MPI_LOR\n");
275        MPI_Abort(comm, 0);
276      }
[1289]277    }
278
[1295]279    MPI_Barrier_local(comm);
[1289]280
[1134]281  }
282
283
284  int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
285  {
[1295]286
[1539]287    if(!comm->is_ep) return ::MPI_Reduce(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root, to_mpi_comm(comm->mpi_comm));
288    if(comm->is_intercomm) return MPI_Reduce_intercomm(sendbuf, recvbuf, count, datatype, op, root, comm);
[1134]289
290
291
[1520]292    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
293    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
294    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
295    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
296    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
297    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
[1295]298
[1520]299    int root_mpi_rank = comm->ep_rank_map->at(root).second;
300    int root_ep_loc = comm->ep_rank_map->at(root).first;
[1134]301
[1295]302    ::MPI_Aint datasize, lb;
[1134]303
[1365]304    ::MPI_Type_get_extent(*(static_cast< ::MPI_Datatype*>(datatype)), &lb, &datasize);
[1134]305
[1295]306    bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root;
307    bool is_root = ep_rank == root;
[1134]308
[1295]309    void* local_recvbuf;
[1134]310
[1295]311    if(is_master)
[1134]312    {
[1295]313      local_recvbuf = new void*[datasize * count];
[1134]314    }
315
[1295]316    if(mpi_rank == root_mpi_rank) MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, root_ep_loc, comm);
317    else                          MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, 0, comm);
[1134]318
319
320
[1295]321    if(is_master)
[1134]322    {
[1520]323      ::MPI_Reduce(local_recvbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root_mpi_rank, to_mpi_comm(comm->mpi_comm));
[1295]324     
[1134]325    }
326
[1295]327    if(is_master)
[1134]328    {
[1295]329      delete[] local_recvbuf;
[1134]330    }
331
[1295]332    MPI_Barrier_local(comm);
[1134]333  }
334
335
[1539]336  int MPI_Reduce_intercomm(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
337  {
338    printf("MPI_Reduce_intercomm not yet implemented\n");
339    MPI_Abort(comm, 0);
340  }
[1134]341}
342
Note: See TracBrowser for help on using the repository browser.