source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_reduce.cpp @ 1482

Last change on this file since 1482 was 1482, checked in by yushan, 3 years ago

Branch EP merged with Dev_cmip6 @r1481

File size: 10.4 KB
Line 
1/*!
2   \file ep_reduce.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Reduce, MPI_Allreduce
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15
16namespace ep_lib {
17
18  template<typename T>
19  T max_op(T a, T b)
20  {
21    return max(a,b);
22  }
23
24  template<typename T>
25  T min_op(T a, T b)
26  {
27    return min(a,b);
28  }
29
30  template<typename T>
31  T lor_op(T a, T b)
32  {
33    return a||b;
34  }
35
36  template<typename T>
37  void reduce_max(const T * buffer, T* recvbuf, int count)
38  {
39    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
40  }
41
42  template<typename T>
43  void reduce_min(const T * buffer, T* recvbuf, int count)
44  {
45    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
46  }
47
48  template<typename T>
49  void reduce_sum(const T * buffer, T* recvbuf, int count)
50  {
51    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
52  }
53
54  template<typename T>
55  void reduce_lor(const T * buffer, T* recvbuf, int count)
56  {
57    transform(buffer, buffer+count, recvbuf, recvbuf, lor_op<T>);
58  }
59
60  int MPI_Reduce_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int local_root, MPI_Comm comm)
61  {
62    assert(valid_type(datatype));
63    assert(valid_op(op));
64
65    ::MPI_Aint datasize, lb;
66    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
67
68    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
69    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
70    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
71
72    #pragma omp critical (_reduce)
73    comm.my_buffer->void_buffer[ep_rank_loc] = const_cast< void* >(sendbuf);
74
75    MPI_Barrier_local(comm);
76
77    if(ep_rank_loc == local_root)
78    {
79
80      memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize * count);
81
82      if(op == MPI_MAX)
83      {
84        if(datatype == MPI_INT)
85        {
86          assert(datasize == sizeof(int));
87          for(int i=1; i<num_ep; i++)
88            reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
89        }
90
91        else if(datatype == MPI_FLOAT)
92        {
93          assert(datasize == sizeof(float));
94          for(int i=1; i<num_ep; i++)
95            reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
96        }
97
98        else if(datatype == MPI_DOUBLE)
99        {
100          assert(datasize == sizeof(double));
101          for(int i=1; i<num_ep; i++)
102            reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
103        }
104
105        else if(datatype == MPI_CHAR)
106        {
107          assert(datasize == sizeof(char));
108          for(int i=1; i<num_ep; i++)
109            reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
110        }
111
112        else if(datatype == MPI_LONG)
113        {
114          assert(datasize == sizeof(long));
115          for(int i=1; i<num_ep; i++)
116            reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
117        }
118
119        else if(datatype == MPI_UNSIGNED_LONG)
120        {
121          assert(datasize == sizeof(unsigned long));
122          for(int i=1; i<num_ep; i++)
123            reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
124        }
125
126        else if(datatype == MPI_UINT64_T)
127        {
128          assert(datasize == sizeof(uint64_t));
129          for(int i=1; i<num_ep; i++)
130            reduce_max<uint64_t>(static_cast<uint64_t*>(comm.my_buffer->void_buffer[i]), static_cast<uint64_t*>(recvbuf), count);
131        }
132       
133        else if(datatype == MPI_LONG_LONG_INT)
134        {
135          assert(datasize == sizeof(long long));
136          for(int i=1; i<num_ep; i++)
137            reduce_max<long long>(static_cast<long long*>(comm.my_buffer->void_buffer[i]), static_cast<long long*>(recvbuf), count);
138        }
139
140        else printf("datatype Error\n");
141
142      }
143
144      if(op == MPI_MIN)
145      {
146        if(datatype ==MPI_INT)
147        {
148          assert(datasize == sizeof(int));
149          for(int i=1; i<num_ep; i++)
150            reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
151        }
152
153        else if(datatype == MPI_FLOAT)
154        {
155          assert(datasize == sizeof(float));
156          for(int i=1; i<num_ep; i++)
157            reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
158        }
159
160        else if(datatype == MPI_DOUBLE)
161        {
162          assert(datasize == sizeof(double));
163          for(int i=1; i<num_ep; i++)
164            reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
165        }
166
167        else if(datatype == MPI_CHAR)
168        {
169          assert(datasize == sizeof(char));
170          for(int i=1; i<num_ep; i++)
171            reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
172        }
173
174        else if(datatype == MPI_LONG)
175        {
176          assert(datasize == sizeof(long));
177          for(int i=1; i<num_ep; i++)
178            reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
179        }
180
181        else if(datatype == MPI_UNSIGNED_LONG)
182        {
183          assert(datasize == sizeof(unsigned long));
184          for(int i=1; i<num_ep; i++)
185            reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
186        }
187
188        else if(datatype == MPI_UINT64_T)
189        {
190          assert(datasize == sizeof(uint64_t));
191          for(int i=1; i<num_ep; i++)
192            reduce_min<uint64_t>(static_cast<uint64_t*>(comm.my_buffer->void_buffer[i]), static_cast<uint64_t*>(recvbuf), count);
193        }
194       
195        else if(datatype == MPI_LONG_LONG_INT)
196        {
197          assert(datasize == sizeof(long long));
198          for(int i=1; i<num_ep; i++)
199            reduce_min<long long>(static_cast<long long*>(comm.my_buffer->void_buffer[i]), static_cast<long long*>(recvbuf), count);
200        }
201
202        else printf("datatype Error\n");
203
204      }
205
206
207      if(op == MPI_SUM)
208      {
209        if(datatype==MPI_INT)
210        {
211          assert(datasize == sizeof(int));
212          for(int i=1; i<num_ep; i++)
213            reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
214        }
215
216        else if(datatype == MPI_FLOAT)
217        {
218          assert(datasize == sizeof(float));
219          for(int i=1; i<num_ep; i++)
220            reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
221        }
222
223        else if(datatype == MPI_DOUBLE)
224        {
225          assert(datasize == sizeof(double));
226          for(int i=1; i<num_ep; i++)
227            reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
228        }
229
230        else if(datatype == MPI_CHAR)
231        {
232          assert(datasize == sizeof(char));
233          for(int i=1; i<num_ep; i++)
234            reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
235        }
236
237        else if(datatype == MPI_LONG)
238        {
239          assert(datasize == sizeof(long));
240          for(int i=1; i<num_ep; i++)
241            reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
242        }
243
244        else if(datatype ==MPI_UNSIGNED_LONG)
245        {
246          assert(datasize == sizeof(unsigned long));
247          for(int i=1; i<num_ep; i++)
248            reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
249        }
250
251        else if(datatype ==MPI_UINT64_T)
252        {
253          assert(datasize == sizeof(uint64_t));
254          for(int i=1; i<num_ep; i++)
255            reduce_sum<uint64_t>(static_cast<uint64_t*>(comm.my_buffer->void_buffer[i]), static_cast<uint64_t*>(recvbuf), count);
256        }
257       
258        else if(datatype ==MPI_LONG_LONG_INT)
259        {
260          assert(datasize == sizeof(long long));
261          for(int i=1; i<num_ep; i++)
262            reduce_sum<long long>(static_cast<long long*>(comm.my_buffer->void_buffer[i]), static_cast<long long*>(recvbuf), count);
263        }
264
265        else printf("datatype Error\n");
266
267      }
268
269      if(op == MPI_LOR)
270      {
271        if(datatype != MPI_INT)
272          printf("datatype Error, must be MPI_INT\n");
273        else
274        {
275          assert(datasize == sizeof(int));
276          for(int i=1; i<num_ep; i++)
277            reduce_lor<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
278        }
279      }
280    }
281
282    MPI_Barrier_local(comm);
283
284  }
285
286
287  int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
288  {
289
290    if(!comm.is_ep && comm.mpi_comm)
291    {
292      return ::MPI_Reduce(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root, to_mpi_comm(comm.mpi_comm));
293    }
294
295
296
297    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
298    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
299    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
300    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
301    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
302    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
303
304    int root_mpi_rank = comm.rank_map->at(root).second;
305    int root_ep_loc = comm.rank_map->at(root).first;
306
307    ::MPI_Aint datasize, lb;
308
309    ::MPI_Type_get_extent(*(static_cast< ::MPI_Datatype*>(datatype)), &lb, &datasize);
310
311    bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root;
312    bool is_root = ep_rank == root;
313
314    void* local_recvbuf;
315
316    if(is_master)
317    {
318      local_recvbuf = new void*[datasize * count];
319    }
320
321    if(mpi_rank == root_mpi_rank) MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, root_ep_loc, comm);
322    else                          MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, 0, comm);
323
324
325
326    if(is_master)
327    {
328      ::MPI_Reduce(local_recvbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root_mpi_rank, to_mpi_comm(comm.mpi_comm));
329     
330    }
331
332    if(is_master)
333    {
334      delete[] local_recvbuf;
335    }
336
337    MPI_Barrier_local(comm);
338  }
339
340
341}
342
Note: See TracBrowser for help on using the repository browser.