source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_reduce.cpp @ 1365

Last change on this file since 1365 was 1365, checked in by yushan, 3 years ago

unify type : MPI_Datatype MPI_Aint

File size: 8.2 KB
Line 
1/*!
2   \file ep_reduce.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Reduce, MPI_Allreduce
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15
16namespace ep_lib {
17
18  template<typename T>
19  T max_op(T a, T b)
20  {
21    return max(a,b);
22  }
23
24  template<typename T>
25  T min_op(T a, T b)
26  {
27    return min(a,b);
28  }
29
30  template<typename T>
31  void reduce_max(const T * buffer, T* recvbuf, int count)
32  {
33    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
34  }
35
36  template<typename T>
37  void reduce_min(const T * buffer, T* recvbuf, int count)
38  {
39    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
40  }
41
42  template<typename T>
43  void reduce_sum(const T * buffer, T* recvbuf, int count)
44  {
45    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
46  }
47
48  int MPI_Reduce_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int local_root, MPI_Comm comm)
49  {
50    assert(valid_type(datatype));
51    assert(valid_op(op));
52
53    ::MPI_Aint datasize, lb;
54    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
55
56    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
57    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
58    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
59
60    #pragma omp critical (_reduce)
61    comm.my_buffer->void_buffer[ep_rank_loc] = const_cast< void* >(sendbuf);
62
63    MPI_Barrier_local(comm);
64
65    if(ep_rank_loc == local_root)
66    {
67
68      memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize * count);
69
70      if(op == MPI_MAX)
71      {
72        if(datatype == MPI_INT)
73        {
74          assert(datasize == sizeof(int));
75          for(int i=1; i<num_ep; i++)
76            reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
77        }
78
79        else if(datatype == MPI_FLOAT)
80        {
81          assert(datasize == sizeof(float));
82          for(int i=1; i<num_ep; i++)
83            reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
84        }
85
86        else if(datatype == MPI_DOUBLE)
87        {
88          assert(datasize == sizeof(double));
89          for(int i=1; i<num_ep; i++)
90            reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
91        }
92
93        else if(datatype == MPI_CHAR)
94        {
95          assert(datasize == sizeof(char));
96          for(int i=1; i<num_ep; i++)
97            reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
98        }
99
100        else if(datatype == MPI_LONG)
101        {
102          assert(datasize == sizeof(long));
103          for(int i=1; i<num_ep; i++)
104            reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
105        }
106
107        else if(datatype == MPI_UNSIGNED_LONG)
108        {
109          assert(datasize == sizeof(unsigned long));
110          for(int i=1; i<num_ep; i++)
111            reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
112        }
113
114        else printf("datatype Error\n");
115
116      }
117
118      if(op == MPI_MIN)
119      {
120        if(datatype ==MPI_INT)
121        {
122          assert(datasize == sizeof(int));
123          for(int i=1; i<num_ep; i++)
124            reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
125        }
126
127        else if(datatype == MPI_FLOAT)
128        {
129          assert(datasize == sizeof(float));
130          for(int i=1; i<num_ep; i++)
131            reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
132        }
133
134        else if(datatype == MPI_DOUBLE)
135        {
136          assert(datasize == sizeof(double));
137          for(int i=1; i<num_ep; i++)
138            reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
139        }
140
141        else if(datatype == MPI_CHAR)
142        {
143          assert(datasize == sizeof(char));
144          for(int i=1; i<num_ep; i++)
145            reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
146        }
147
148        else if(datatype == MPI_LONG)
149        {
150          assert(datasize == sizeof(long));
151          for(int i=1; i<num_ep; i++)
152            reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
153        }
154
155        else if(datatype == MPI_UNSIGNED_LONG)
156        {
157          assert(datasize == sizeof(unsigned long));
158          for(int i=1; i<num_ep; i++)
159            reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
160        }
161
162        else printf("datatype Error\n");
163
164      }
165
166
167      if(op == MPI_SUM)
168      {
169        if(datatype==MPI_INT)
170        {
171          assert(datasize == sizeof(int));
172          for(int i=1; i<num_ep; i++)
173            reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);
174        }
175
176        else if(datatype == MPI_FLOAT)
177        {
178          assert(datasize == sizeof(float));
179          for(int i=1; i<num_ep; i++)
180            reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);
181        }
182
183        else if(datatype == MPI_DOUBLE)
184        {
185          assert(datasize == sizeof(double));
186          for(int i=1; i<num_ep; i++)
187            reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
188        }
189
190        else if(datatype == MPI_CHAR)
191        {
192          assert(datasize == sizeof(char));
193          for(int i=1; i<num_ep; i++)
194            reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
195        }
196
197        else if(datatype == MPI_LONG)
198        {
199          assert(datasize == sizeof(long));
200          for(int i=1; i<num_ep; i++)
201            reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
202        }
203
204        else if(datatype ==MPI_UNSIGNED_LONG)
205        {
206          assert(datasize == sizeof(unsigned long));
207          for(int i=1; i<num_ep; i++)
208            reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);
209        }
210
211        else printf("datatype Error\n");
212
213      }
214    }
215
216    MPI_Barrier_local(comm);
217
218  }
219
220
221  int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
222  {
223
224    if(!comm.is_ep && comm.mpi_comm)
225    {
226      return ::MPI_Reduce(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root, to_mpi_comm(comm.mpi_comm));
227    }
228
229
230
231    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
232    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
233    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
234    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
235    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
236    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
237
238    int root_mpi_rank = comm.rank_map->at(root).second;
239    int root_ep_loc = comm.rank_map->at(root).first;
240
241    ::MPI_Aint datasize, lb;
242
243    ::MPI_Type_get_extent(*(static_cast< ::MPI_Datatype*>(datatype)), &lb, &datasize);
244
245    bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root;
246    bool is_root = ep_rank == root;
247
248    void* local_recvbuf;
249
250    if(is_master)
251    {
252      local_recvbuf = new void*[datasize * count];
253    }
254
255    if(mpi_rank == root_mpi_rank) MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, root_ep_loc, comm);
256    else                          MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, 0, comm);
257
258
259
260    if(is_master)
261    {
262      ::MPI_Reduce(local_recvbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root_mpi_rank, to_mpi_comm(comm.mpi_comm));
263     
264    }
265
266    if(is_master)
267    {
268      delete[] local_recvbuf;
269    }
270
271    MPI_Barrier_local(comm);
272  }
273
274
275}
276
Note: See TracBrowser for help on using the repository browser.