source: XIOS/dev/branch_openmp/extern/ep_dev/ep_exscan.cpp @ 1527

Last change on this file since 1527 was 1527, checked in by yushan, 3 years ago

save dev

File size: 9.4 KB
Line 
1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Exscan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
31  {
32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
33  }
34
35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
37  {
38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
40
41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
45  }
46
47
48  int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
49  {
50    valid_op(op);
51
52    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
55   
56
57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
59
60    if(ep_rank_loc == 0 && mpi_rank != 0)
61    {
62      comm->my_buffer->void_buffer[0] = recvbuf;
63    }
64    if(ep_rank_loc == 0 && mpi_rank == 0)
65    {
66      comm->my_buffer->void_buffer[0] = const_cast<void*>(sendbuf); 
67    } 
68     
69
70    MPI_Barrier_local(comm);
71
72    memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize*count);
73
74    MPI_Barrier_local(comm);
75
76    comm->my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
77   
78    MPI_Barrier_local(comm);
79
80    if(op == MPI_SUM)
81    {
82      if(datatype == MPI_INT )
83      {
84        assert(datasize == sizeof(int));
85        for(int i=0; i<ep_rank_loc; i++)
86          reduce_sum<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
87      }
88     
89      else if(datatype == MPI_FLOAT )
90      {
91        assert(datasize == sizeof(float));
92        for(int i=0; i<ep_rank_loc; i++)
93          reduce_sum<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
94      }
95     
96
97      else if(datatype == MPI_DOUBLE )
98      {
99        assert(datasize == sizeof(double));
100        for(int i=0; i<ep_rank_loc; i++)
101          reduce_sum<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
102      }
103
104      else if(datatype == MPI_CHAR )
105      {
106        assert(datasize == sizeof(char));
107        for(int i=0; i<ep_rank_loc; i++)
108          reduce_sum<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
109      }
110
111      else if(datatype == MPI_LONG )
112      {
113        assert(datasize == sizeof(long));
114        for(int i=0; i<ep_rank_loc; i++)
115          reduce_sum<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
116      }
117
118      else if(datatype == MPI_UNSIGNED_LONG )
119      {
120        assert(datasize == sizeof(unsigned long));
121        for(int i=0; i<ep_rank_loc; i++)
122          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
123      }
124
125      else printf("datatype Error\n");
126
127     
128    }
129
130    else if(op == MPI_MAX)
131    {
132      if(datatype == MPI_INT )
133      {
134        assert(datasize == sizeof(int));
135        for(int i=0; i<ep_rank_loc; i++)
136          reduce_max<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
137      }
138
139      else if(datatype == MPI_FLOAT )
140      {
141        assert(datasize == sizeof(float));
142        for(int i=0; i<ep_rank_loc; i++)
143          reduce_max<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
144      }
145
146      else if(datatype == MPI_DOUBLE )
147      {
148        assert(datasize == sizeof(double));
149        for(int i=0; i<ep_rank_loc; i++)
150          reduce_max<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
151      }
152
153      else if(datatype == MPI_CHAR )
154      {
155        assert(datasize == sizeof(char));
156        for(int i=0; i<ep_rank_loc; i++)
157          reduce_max<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
158      }
159
160      else if(datatype == MPI_LONG )
161      {
162        assert(datasize == sizeof(long));
163        for(int i=0; i<ep_rank_loc; i++)
164          reduce_max<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
165      }
166
167      else if(datatype == MPI_UNSIGNED_LONG )
168      {
169        assert(datasize == sizeof(unsigned long));
170        for(int i=0; i<ep_rank_loc; i++)
171          reduce_max<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
172      }
173     
174      else printf("datatype Error\n");
175    }
176
177    else //if(op == MPI_MIN)
178    {
179      if(datatype == MPI_INT )
180      {
181        assert(datasize == sizeof(int));
182        for(int i=0; i<ep_rank_loc; i++)
183          reduce_min<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
184      }
185
186      else if(datatype == MPI_FLOAT )
187      {
188        assert(datasize == sizeof(float));
189        for(int i=0; i<ep_rank_loc; i++)
190          reduce_min<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
191      }
192
193      else if(datatype == MPI_DOUBLE )
194      {
195        assert(datasize == sizeof(double));
196        for(int i=0; i<ep_rank_loc; i++)
197          reduce_min<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
198      }
199
200      else if(datatype == MPI_CHAR )
201      {
202        assert(datasize == sizeof(char));
203        for(int i=0; i<ep_rank_loc; i++)
204          reduce_min<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
205      }
206
207      else if(datatype == MPI_LONG )
208      {
209        assert(datasize == sizeof(long));
210        for(int i=0; i<ep_rank_loc; i++)
211          reduce_min<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
212      }
213
214      else if(datatype == MPI_UNSIGNED_LONG )
215      {
216        assert(datasize == sizeof(unsigned long));
217        for(int i=0; i<ep_rank_loc; i++)
218          reduce_min<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
219      }
220
221      else printf("datatype Error\n");
222    }
223
224    MPI_Barrier_local(comm);
225
226  }
227
228  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
229  {
230    if(!comm->is_ep) return ::MPI_Exscan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
231    if(comm->is_intercomm) return MPI_Exscan_intercomm(sendbuf, recvbuf, count, datatype, op, comm);
232   
233    valid_type(datatype);
234
235    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
236    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
237    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
238    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
239    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
240    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
241
242    ::MPI_Aint datasize, lb;
243    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
244   
245    void* tmp_sendbuf;
246    tmp_sendbuf = new void*[datasize * count];
247
248    int my_src = 0;
249    int my_dst = ep_rank;
250
251    std::vector<int> my_map(mpi_size, 0);
252
253    for(int i=0; i<comm->ep_rank_map->size(); i++) my_map[comm->ep_rank_map->at(i).second]++;
254
255    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
256    my_src += ep_rank_loc;
257
258     
259    for(int i=0; i<mpi_size; i++)
260    {
261      if(my_dst < my_map[i])
262      {
263        my_dst = get_ep_rank(comm, my_dst, i); 
264        break;
265      }
266      else
267        my_dst -= my_map[i];
268    }
269
270    if(ep_rank != my_dst) 
271    {
272      MPI_Request request[2];
273      MPI_Status status[2];
274
275      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
276   
277      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
278   
279      MPI_Waitall(2, request, status);
280    }
281
282    else memcpy(tmp_sendbuf, sendbuf, datasize*count);
283   
284
285    void* tmp_recvbuf;
286    tmp_recvbuf = new void*[datasize * count];   
287
288    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
289
290    if(ep_rank_loc == 0)
291    {
292      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
293    }
294   
295    MPI_Exscan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
296
297
298    if(ep_rank != my_src) 
299    {
300      MPI_Request request[2];
301      MPI_Status status[2];
302
303      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
304   
305      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
306   
307      MPI_Waitall(2, request, status);
308    }
309
310    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
311
312    delete[] tmp_sendbuf;
313    delete[] tmp_recvbuf;
314
315  }
316
317
318  int MPI_Exscan_intercomm(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
319  {
320    printf("MPI_Exscan_intercomm not yet implemented\n");
321    MPI_Abort(comm, 0);
322  }
323
324}
Note: See TracBrowser for help on using the repository browser.