source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_reduce.cpp @ 1365

Last change on this file since 1365 was 1365, checked in by yushan, 6 years ago

unify type : MPI_Datatype MPI_Aint

File size: 8.2 KB
RevLine 
[1134]1/*!
2   \file ep_reduce.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Reduce, MPI_Allreduce
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
[1295]11#include "ep_mpi.hpp"
[1134]12
13using namespace std;
14
15
16namespace ep_lib {
17
18  template<typename T>
19  T max_op(T a, T b)
20  {
21    return max(a,b);
22  }
23
24  template<typename T>
25  T min_op(T a, T b)
26  {
27    return min(a,b);
28  }
29
[1295]30  template<typename T>
31  void reduce_max(const T * buffer, T* recvbuf, int count)
32  {
33    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
34  }
[1134]35
[1295]36  template<typename T>
37  void reduce_min(const T * buffer, T* recvbuf, int count)
[1134]38  {
[1295]39    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
[1134]40  }
41
[1295]42  template<typename T>
43  void reduce_sum(const T * buffer, T* recvbuf, int count)
44  {
45    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
46  }
[1134]47
[1295]48  int MPI_Reduce_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int local_root, MPI_Comm comm)
[1134]49  {
[1295]50    assert(valid_type(datatype));
51    assert(valid_op(op));
[1134]52
[1295]53    ::MPI_Aint datasize, lb;
54    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
[1134]55
[1295]56    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
57    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
58    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
[1134]59
[1295]60    #pragma omp critical (_reduce)
61    comm.my_buffer->void_buffer[ep_rank_loc] = const_cast< void* >(sendbuf);
[1134]62
[1295]63    MPI_Barrier_local(comm);
[1134]64
[1295]65    if(ep_rank_loc == local_root)
[1134]66    {
67
[1295]68      memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize * count);
[1289]69
[1295]70      if(op == MPI_MAX)
[1134]71      {
[1295]72        if(datatype == MPI_INT)
[1134]73        {
[1295]74          assert(datasize == sizeof(int));
75          for(int i=1; i<num_ep; i++)
76            reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
77        }
[1134]78
[1295]79        else if(datatype == MPI_FLOAT)
80        {
81          assert(datasize == sizeof(float));
82          for(int i=1; i<num_ep; i++)
83            reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
84        }
[1134]85
[1295]86        else if(datatype == MPI_DOUBLE)
87        {
88          assert(datasize == sizeof(double));
89          for(int i=1; i<num_ep; i++)
90            reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
91        }
[1134]92
[1295]93        else if(datatype == MPI_CHAR)
94        {
95          assert(datasize == sizeof(char));
96          for(int i=1; i<num_ep; i++)
97            reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
98        }
[1134]99
[1295]100        else if(datatype == MPI_LONG)
101        {
102          assert(datasize == sizeof(long));
103          for(int i=1; i<num_ep; i++)
104            reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1134]105        }
106
[1295]107        else if(datatype == MPI_UNSIGNED_LONG)
108        {
109          assert(datasize == sizeof(unsigned long));
110          for(int i=1; i<num_ep; i++)
111            reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
112        }
[1134]113
[1295]114        else printf("datatype Error\n");
[1134]115
116      }
117
[1295]118      if(op == MPI_MIN)
[1134]119      {
[1295]120        if(datatype ==MPI_INT)
[1134]121        {
[1295]122          assert(datasize == sizeof(int));
123          for(int i=1; i<num_ep; i++)
124            reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
125        }
[1134]126
[1295]127        else if(datatype == MPI_FLOAT)
128        {
129          assert(datasize == sizeof(float));
130          for(int i=1; i<num_ep; i++)
131            reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
132        }
[1134]133
[1295]134        else if(datatype == MPI_DOUBLE)
135        {
136          assert(datasize == sizeof(double));
137          for(int i=1; i<num_ep; i++)
138            reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
139        }
[1134]140
[1295]141        else if(datatype == MPI_CHAR)
142        {
143          assert(datasize == sizeof(char));
144          for(int i=1; i<num_ep; i++)
145            reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
146        }
[1134]147
[1295]148        else if(datatype == MPI_LONG)
149        {
150          assert(datasize == sizeof(long));
151          for(int i=1; i<num_ep; i++)
152            reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
153        }
[1134]154
[1295]155        else if(datatype == MPI_UNSIGNED_LONG)
156        {
157          assert(datasize == sizeof(unsigned long));
158          for(int i=1; i<num_ep; i++)
159            reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
[1134]160        }
161
[1295]162        else printf("datatype Error\n");
[1134]163
164      }
165
166
[1295]167      if(op == MPI_SUM)
[1134]168      {
[1295]169        if(datatype==MPI_INT)
170        {
171          assert(datasize == sizeof(int));
172          for(int i=1; i<num_ep; i++)
173            reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
174        }
[1289]175
[1295]176        else if(datatype == MPI_FLOAT)
177        {
178          assert(datasize == sizeof(float));
179          for(int i=1; i<num_ep; i++)
180            reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
181        }
[1289]182
[1295]183        else if(datatype == MPI_DOUBLE)
[1287]184        {
[1295]185          assert(datasize == sizeof(double));
186          for(int i=1; i<num_ep; i++)
187            reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
188        }
[1134]189
[1295]190        else if(datatype == MPI_CHAR)
191        {
192          assert(datasize == sizeof(char));
193          for(int i=1; i<num_ep; i++)
194            reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
195        }
[1289]196
[1295]197        else if(datatype == MPI_LONG)
198        {
199          assert(datasize == sizeof(long));
200          for(int i=1; i<num_ep; i++)
201            reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1287]202        }
[1134]203
[1295]204        else if(datatype ==MPI_UNSIGNED_LONG)
[1134]205        {
[1295]206          assert(datasize == sizeof(unsigned long));
207          for(int i=1; i<num_ep; i++)
208            reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
[1287]209        }
[1134]210
[1295]211        else printf("datatype Error\n");
[1289]212
213      }
214    }
215
[1295]216    MPI_Barrier_local(comm);
[1289]217
[1134]218  }
219
220
221  int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
222  {
[1295]223
[1134]224    if(!comm.is_ep && comm.mpi_comm)
225    {
[1295]226      return ::MPI_Reduce(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root, to_mpi_comm(comm.mpi_comm));
[1134]227    }
228
229
230
[1295]231    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
232    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
233    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
234    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
235    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
236    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
237
[1134]238    int root_mpi_rank = comm.rank_map->at(root).second;
239    int root_ep_loc = comm.rank_map->at(root).first;
240
[1295]241    ::MPI_Aint datasize, lb;
[1134]242
[1365]243    ::MPI_Type_get_extent(*(static_cast< ::MPI_Datatype*>(datatype)), &lb, &datasize);
[1134]244
[1295]245    bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root;
246    bool is_root = ep_rank == root;
[1134]247
[1295]248    void* local_recvbuf;
[1134]249
[1295]250    if(is_master)
[1134]251    {
[1295]252      local_recvbuf = new void*[datasize * count];
[1134]253    }
254
[1295]255    if(mpi_rank == root_mpi_rank) MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, root_ep_loc, comm);
256    else                          MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, 0, comm);
[1134]257
258
259
[1295]260    if(is_master)
[1134]261    {
[1295]262      ::MPI_Reduce(local_recvbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root_mpi_rank, to_mpi_comm(comm.mpi_comm));
263     
[1134]264    }
265
[1295]266    if(is_master)
[1134]267    {
[1295]268      delete[] local_recvbuf;
[1134]269    }
270
[1295]271    MPI_Barrier_local(comm);
[1134]272  }
273
274
275}
276
Note: See TracBrowser for help on using the repository browser.