source: XIOS/dev/branch_openmp/extern/ep_dev/ep_exscan.cpp @ 1501

Last change on this file since 1501 was 1500, checked in by yushan, 6 years ago

save dev

File size: 9.5 KB
Line 
1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Exscan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
31  {
32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
33  }
34
35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
37  {
38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
40
41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
45  }
46
47
48  int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
49  {
50    valid_op(op);
51
52    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
55   
56
57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
59
60    if(ep_rank_loc == 0 && mpi_rank != 0)
61    {
62      comm->my_buffer->void_buffer[0] = recvbuf;
63    }
64    if(ep_rank_loc == 0 && mpi_rank == 0)
65    {
66      comm->my_buffer->void_buffer[0] = const_cast<void*>(sendbuf); 
67    } 
68     
69
70    MPI_Barrier_local(comm);
71
72    memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize*count);
73
74    MPI_Barrier_local(comm);
75
76    comm->my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
77   
78    MPI_Barrier_local(comm);
79
80    if(op == MPI_SUM)
81    {
82      if(datatype == MPI_INT )
83      {
84        assert(datasize == sizeof(int));
85        for(int i=0; i<ep_rank_loc; i++)
86          reduce_sum<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
87      }
88     
89      else if(datatype == MPI_FLOAT )
90      {
91        assert(datasize == sizeof(float));
92        for(int i=0; i<ep_rank_loc; i++)
93          reduce_sum<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
94      }
95     
96
97      else if(datatype == MPI_DOUBLE )
98      {
99        assert(datasize == sizeof(double));
100        for(int i=0; i<ep_rank_loc; i++)
101          reduce_sum<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
102      }
103
104      else if(datatype == MPI_CHAR )
105      {
106        assert(datasize == sizeof(char));
107        for(int i=0; i<ep_rank_loc; i++)
108          reduce_sum<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
109      }
110
111      else if(datatype == MPI_LONG )
112      {
113        assert(datasize == sizeof(long));
114        for(int i=0; i<ep_rank_loc; i++)
115          reduce_sum<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
116      }
117
118      else if(datatype == MPI_UNSIGNED_LONG )
119      {
120        assert(datasize == sizeof(unsigned long));
121        for(int i=0; i<ep_rank_loc; i++)
122          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
123      }
124
125      else printf("datatype Error\n");
126
127     
128    }
129
130    else if(op == MPI_MAX)
131    {
132      if(datatype == MPI_INT )
133      {
134        assert(datasize == sizeof(int));
135        for(int i=0; i<ep_rank_loc; i++)
136          reduce_max<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
137      }
138
139      else if(datatype == MPI_FLOAT )
140      {
141        assert(datasize == sizeof(float));
142        for(int i=0; i<ep_rank_loc; i++)
143          reduce_max<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
144      }
145
146      else if(datatype == MPI_DOUBLE )
147      {
148        assert(datasize == sizeof(double));
149        for(int i=0; i<ep_rank_loc; i++)
150          reduce_max<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
151      }
152
153      else if(datatype == MPI_CHAR )
154      {
155        assert(datasize == sizeof(char));
156        for(int i=0; i<ep_rank_loc; i++)
157          reduce_max<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
158      }
159
160      else if(datatype == MPI_LONG )
161      {
162        assert(datasize == sizeof(long));
163        for(int i=0; i<ep_rank_loc; i++)
164          reduce_max<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
165      }
166
167      else if(datatype == MPI_UNSIGNED_LONG )
168      {
169        assert(datasize == sizeof(unsigned long));
170        for(int i=0; i<ep_rank_loc; i++)
171          reduce_max<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
172      }
173     
174      else printf("datatype Error\n");
175    }
176
177    else //if(op == MPI_MIN)
178    {
179      if(datatype == MPI_INT )
180      {
181        assert(datasize == sizeof(int));
182        for(int i=0; i<ep_rank_loc; i++)
183          reduce_min<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
184      }
185
186      else if(datatype == MPI_FLOAT )
187      {
188        assert(datasize == sizeof(float));
189        for(int i=0; i<ep_rank_loc; i++)
190          reduce_min<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
191      }
192
193      else if(datatype == MPI_DOUBLE )
194      {
195        assert(datasize == sizeof(double));
196        for(int i=0; i<ep_rank_loc; i++)
197          reduce_min<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
198      }
199
200      else if(datatype == MPI_CHAR )
201      {
202        assert(datasize == sizeof(char));
203        for(int i=0; i<ep_rank_loc; i++)
204          reduce_min<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
205      }
206
207      else if(datatype == MPI_LONG )
208      {
209        assert(datasize == sizeof(long));
210        for(int i=0; i<ep_rank_loc; i++)
211          reduce_min<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
212      }
213
214      else if(datatype == MPI_UNSIGNED_LONG )
215      {
216        assert(datasize == sizeof(unsigned long));
217        for(int i=0; i<ep_rank_loc; i++)
218          reduce_min<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
219      }
220
221      else printf("datatype Error\n");
222    }
223
224    MPI_Barrier_local(comm);
225
226  }
227
228  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
229  {
230    if(!comm->is_ep)
231    {
232      return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
233    }
234   
235    valid_type(datatype);
236
237    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
238    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
239    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
240    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
241    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
242    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
243
244    ::MPI_Aint datasize, lb;
245    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
246   
247    void* tmp_sendbuf;
248    tmp_sendbuf = new void*[datasize * count];
249
250    int my_src = 0;
251    int my_dst = ep_rank;
252
253    std::vector<int> my_map(mpi_size, 0);
254
255    for(int i=0; i<comm->rank_map->size(); i++) my_map[comm->rank_map->at(i).second]++;
256
257    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
258    my_src += ep_rank_loc;
259
260     
261    for(int i=0; i<mpi_size; i++)
262    {
263      if(my_dst < my_map[i])
264      {
265        my_dst = get_ep_rank(comm, my_dst, i); 
266        break;
267      }
268      else
269        my_dst -= my_map[i];
270    }
271
272    if(ep_rank != my_dst) 
273    {
274      MPI_Request request[2];
275      MPI_Status status[2];
276
277      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
278   
279      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
280   
281      MPI_Waitall(2, request, status);
282    }
283
284    else memcpy(tmp_sendbuf, sendbuf, datasize*count);
285   
286
287    void* tmp_recvbuf;
288    tmp_recvbuf = new void*[datasize * count];   
289
290    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
291
292    if(ep_rank_loc == 0)
293      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
294
295    // printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
296   
297    MPI_Exscan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
298
299     // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
300
301
302
303    if(ep_rank != my_src) 
304    {
305      MPI_Request request[2];
306      MPI_Status status[2];
307
308      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
309   
310      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
311   
312      MPI_Waitall(2, request, status);
313    }
314
315    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
316   
317
318
319
320    delete[] tmp_sendbuf;
321    delete[] tmp_recvbuf;
322
323  }
324
325}
Note: See TracBrowser for help on using the repository browser.