source: XIOS/dev/dev_trunk_omp/extern/ep_dev/ep_exscan.cpp @ 1604

Last change on this file since 1604 was 1604, checked in by yushan, 5 years ago

branch_openmp merged with trunk r1597

File size: 10.9 KB
Line 
1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Exscan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
31  {
32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
33  }
34
35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
37  {
38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
40
41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
45  }
46
47
48  int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
49  {
50    valid_op(op);
51
52    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
55   
56
57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
59
60    if(ep_rank_loc == 0 && mpi_rank != 0)
61    {
62      comm->my_buffer->void_buffer[0] = recvbuf;
63    }
64    if(ep_rank_loc == 0 && mpi_rank == 0)
65    {
66      comm->my_buffer->void_buffer[0] = const_cast<void*>(sendbuf); 
67    } 
68     
69
70    MPI_Barrier_local(comm);
71
72    memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize*count);
73
74    MPI_Barrier_local(comm);
75
76    comm->my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
77   
78    MPI_Barrier_local(comm);
79
80    if(op == MPI_SUM)
81    {
82      if(datatype == MPI_INT )
83      {
84        assert(datasize == sizeof(int));
85        for(int i=0; i<ep_rank_loc; i++)
86          reduce_sum<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
87      }
88     
89      else if(datatype == MPI_FLOAT )
90      {
91        assert(datasize == sizeof(float));
92        for(int i=0; i<ep_rank_loc; i++)
93          reduce_sum<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
94      }
95     
96
97      else if(datatype == MPI_DOUBLE )
98      {
99        assert(datasize == sizeof(double));
100        for(int i=0; i<ep_rank_loc; i++)
101          reduce_sum<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
102      }
103
104      else if(datatype == MPI_CHAR )
105      {
106        assert(datasize == sizeof(char));
107        for(int i=0; i<ep_rank_loc; i++)
108          reduce_sum<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
109      }
110
111      else if(datatype == MPI_LONG )
112      {
113        assert(datasize == sizeof(long));
114        for(int i=0; i<ep_rank_loc; i++)
115          reduce_sum<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
116      }
117
118      else if(datatype == MPI_UNSIGNED_LONG )
119      {
120        assert(datasize == sizeof(unsigned long));
121        for(int i=0; i<ep_rank_loc; i++)
122          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
123      }
124     
125      else if(datatype == MPI_LONG_LONG_INT )
126      {
127        assert(datasize == sizeof(long long int));
128        for(int i=0; i<ep_rank_loc; i++)
129          reduce_sum<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
130      }
131
132      else 
133      {
134        printf("datatype Error in ep_exscan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
135        MPI_Abort(comm, 0);
136      }
137
138     
139    }
140
141    else if(op == MPI_MAX)
142    {
143      if(datatype == MPI_INT )
144      {
145        assert(datasize == sizeof(int));
146        for(int i=0; i<ep_rank_loc; i++)
147          reduce_max<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
148      }
149
150      else if(datatype == MPI_FLOAT )
151      {
152        assert(datasize == sizeof(float));
153        for(int i=0; i<ep_rank_loc; i++)
154          reduce_max<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
155      }
156
157      else if(datatype == MPI_DOUBLE )
158      {
159        assert(datasize == sizeof(double));
160        for(int i=0; i<ep_rank_loc; i++)
161          reduce_max<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
162      }
163
164      else if(datatype == MPI_CHAR )
165      {
166        assert(datasize == sizeof(char));
167        for(int i=0; i<ep_rank_loc; i++)
168          reduce_max<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
169      }
170
171      else if(datatype == MPI_LONG )
172      {
173        assert(datasize == sizeof(long));
174        for(int i=0; i<ep_rank_loc; i++)
175          reduce_max<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
176      }
177
178      else if(datatype == MPI_UNSIGNED_LONG )
179      {
180        assert(datasize == sizeof(unsigned long));
181        for(int i=0; i<ep_rank_loc; i++)
182          reduce_max<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
183      }
184     
185      else if(datatype == MPI_LONG_LONG_INT )
186      {
187        assert(datasize == sizeof(long long int));
188        for(int i=0; i<ep_rank_loc; i++)
189          reduce_max<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
190      }
191
192      else 
193      {
194        printf("datatype Error in ep_exscan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
195        MPI_Abort(comm, 0);
196      }
197    }
198
199    else if(op == MPI_MIN)
200    {
201      if(datatype == MPI_INT )
202      {
203        assert(datasize == sizeof(int));
204        for(int i=0; i<ep_rank_loc; i++)
205          reduce_min<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
206      }
207
208      else if(datatype == MPI_FLOAT )
209      {
210        assert(datasize == sizeof(float));
211        for(int i=0; i<ep_rank_loc; i++)
212          reduce_min<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
213      }
214
215      else if(datatype == MPI_DOUBLE )
216      {
217        assert(datasize == sizeof(double));
218        for(int i=0; i<ep_rank_loc; i++)
219          reduce_min<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
220      }
221
222      else if(datatype == MPI_CHAR )
223      {
224        assert(datasize == sizeof(char));
225        for(int i=0; i<ep_rank_loc; i++)
226          reduce_min<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
227      }
228
229      else if(datatype == MPI_LONG )
230      {
231        assert(datasize == sizeof(long));
232        for(int i=0; i<ep_rank_loc; i++)
233          reduce_min<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
234      }
235
236      else if(datatype == MPI_UNSIGNED_LONG )
237      {
238        assert(datasize == sizeof(unsigned long));
239        for(int i=0; i<ep_rank_loc; i++)
240          reduce_min<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
241      }
242
243      else if(datatype == MPI_LONG_LONG_INT )
244      {
245        assert(datasize == sizeof(long long int));
246        for(int i=0; i<ep_rank_loc; i++)
247          reduce_min<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
248      }
249
250      else 
251      {
252        printf("datatype Error in ep_exscan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
253        MPI_Abort(comm, 0);
254      }
255    }
256   
257    else
258    {
259      printf("op type Error in ep_exscan : MPI_MAX, MPI_MIN, MPI_SUM\n");
260      MPI_Abort(comm, 0);
261    }
262
263    MPI_Barrier_local(comm);
264
265  }
266
267  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
268  {
269    if(!comm->is_ep) return ::MPI_Exscan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
270    if(comm->is_intercomm) return MPI_Exscan_intercomm(sendbuf, recvbuf, count, datatype, op, comm);
271   
272    assert(valid_type(datatype));
273    assert(valid_op(op));
274
275    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
276    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
277    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
278    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
279    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
280    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
281
282    ::MPI_Aint datasize, lb;
283    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
284   
285    void* tmp_sendbuf;
286    tmp_sendbuf = new void*[datasize * count];
287
288    int my_src = 0;
289    int my_dst = ep_rank;
290
291    std::vector<int> my_map(mpi_size, 0);
292
293    for(int i=0; i<comm->ep_rank_map->size(); i++) my_map[comm->ep_rank_map->at(i).second]++;
294
295    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
296    my_src += ep_rank_loc;
297
298     
299    for(int i=0; i<mpi_size; i++)
300    {
301      if(my_dst < my_map[i])
302      {
303        my_dst = get_ep_rank(comm, my_dst, i); 
304        break;
305      }
306      else
307        my_dst -= my_map[i];
308    }
309
310    if(ep_rank != my_dst) 
311    {
312      MPI_Request request[2];
313      MPI_Status status[2];
314
315      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
316   
317      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
318   
319      MPI_Waitall(2, request, status);
320    }
321
322    else memcpy(tmp_sendbuf, sendbuf, datasize*count);
323   
324
325    void* tmp_recvbuf;
326    tmp_recvbuf = new void*[datasize * count];   
327
328    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
329
330    if(ep_rank_loc == 0)
331    {
332      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
333    }
334   
335    MPI_Exscan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
336
337
338    if(ep_rank != my_src) 
339    {
340      MPI_Request request[2];
341      MPI_Status status[2];
342
343      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
344   
345      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
346   
347      MPI_Waitall(2, request, status);
348    }
349
350    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
351
352    delete[] tmp_sendbuf;
353    delete[] tmp_recvbuf;
354
355  }
356
357
358  int MPI_Exscan_intercomm(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
359  {
360    printf("MPI_Exscan_intercomm not yet implemented\n");
361    MPI_Abort(comm, 0);
362  }
363
364}
Note: See TracBrowser for help on using the repository browser.