source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scan.cpp @ 1287

Last change on this file since 1287 was 1287, checked in by yushan, 7 years ago

EP updated

File size: 12.8 KB
Line 
1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
31  {
32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
33  }
34
35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
37  {
38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
40
41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
45  }
46
47
48  int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
49  {
50    valid_op(op);
51
52    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
55   
56
57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
59
60    if(ep_rank_loc == 0 && mpi_rank != 0)
61    {
62      if(op == MPI_SUM)
63      {
64        if(datatype == MPI_INT && datasize == sizeof(int))
65          reduce_sum<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
66         
67        else if(datatype == MPI_FLOAT && datasize == sizeof(float))
68          reduce_sum<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
69             
70        else if(datatype == MPI_DOUBLE && datasize == sizeof(double))
71          reduce_sum<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
72     
73        else if(datatype == MPI_CHAR && datasize == sizeof(char))
74          reduce_sum<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
75     
76        else if(datatype == MPI_LONG && datasize == sizeof(long))
77          reduce_sum<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
78           
79        else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long))
80          reduce_sum<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
81           
82        else printf("datatype Error\n");
83      }
84
85      else if(op == MPI_MAX)
86      {
87        if(datatype == MPI_INT && datasize == sizeof(int))
88          reduce_max<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
89         
90        else if(datatype == MPI_FLOAT && datasize == sizeof(float))
91          reduce_max<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
92             
93        else if(datatype == MPI_DOUBLE && datasize == sizeof(double))
94          reduce_max<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
95     
96        else if(datatype == MPI_CHAR && datasize == sizeof(char))
97          reduce_max<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
98     
99        else if(datatype == MPI_LONG && datasize == sizeof(long))
100          reduce_max<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
101           
102        else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long))
103          reduce_max<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
104           
105        else printf("datatype Error\n");
106      }
107
108      else //(op == MPI_MIN)
109      {
110        if(datatype == MPI_INT && datasize == sizeof(int))
111          reduce_min<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
112         
113        else if(datatype == MPI_FLOAT && datasize == sizeof(float))
114          reduce_min<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
115             
116        else if(datatype == MPI_DOUBLE && datasize == sizeof(double))
117          reduce_min<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
118     
119        else if(datatype == MPI_CHAR && datasize == sizeof(char))
120          reduce_min<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
121     
122        else if(datatype == MPI_LONG && datasize == sizeof(long))
123          reduce_min<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
124           
125        else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long))
126          reduce_min<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
127           
128        else printf("datatype Error\n");
129      }
130
131      comm.my_buffer->void_buffer[0] = recvbuf;
132    }
133    else
134    {
135      comm.my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
136      memcpy(recvbuf, sendbuf, datasize*count);
137    } 
138     
139
140
141    MPI_Barrier_local(comm);
142
143    memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize*count);
144
145
146    if(op == MPI_SUM)
147    {
148      if(datatype == MPI_INT && datasize == sizeof(int))
149      {
150        for(int i=1; i<ep_rank_loc+1; i++)
151          reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
152      }
153     
154      else if(datatype == MPI_FLOAT && datasize == sizeof(float))
155      {
156        for(int i=1; i<ep_rank_loc+1; i++)
157          reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
158      }
159     
160
161      else if(datatype == MPI_DOUBLE && datasize == sizeof(double))
162      {
163        for(int i=1; i<ep_rank_loc+1; i++)
164          reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
165      }
166
167      else if(datatype == MPI_CHAR && datasize == sizeof(char))
168      {
169        for(int i=1; i<ep_rank_loc+1; i++)
170          reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
171      }
172
173      else if(datatype == MPI_LONG && datasize == sizeof(long))
174      {
175        for(int i=1; i<ep_rank_loc+1; i++)
176          reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
177      }
178
179      else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long))
180      {
181        for(int i=1; i<ep_rank_loc+1; i++)
182          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
183      }
184
185      else printf("datatype Error\n");
186
187     
188    }
189
190    else if(op == MPI_MAX)
191    {
192      if(datatype == MPI_INT && datasize == sizeof(int))
193        for(int i=1; i<ep_rank_loc+1; i++)
194          reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
195
196      else if(datatype == MPI_FLOAT && datasize == sizeof(float))
197        for(int i=1; i<ep_rank_loc+1; i++)
198          reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
199
200      else if(datatype == MPI_DOUBLE && datasize == sizeof(double))
201        for(int i=1; i<ep_rank_loc+1; i++)
202          reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
203
204      else if(datatype == MPI_CHAR && datasize == sizeof(char))
205        for(int i=1; i<ep_rank_loc+1; i++)
206          reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
207
208      else if(datatype == MPI_LONG && datasize == sizeof(long))
209        for(int i=1; i<ep_rank_loc+1; i++)
210          reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
211
212      else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long))
213        for(int i=1; i<ep_rank_loc+1; i++)
214          reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
215     
216      else printf("datatype Error\n");
217    }
218
219    else //if(op == MPI_MIN)
220    {
221      if(datatype == MPI_INT && datasize == sizeof(int))
222        for(int i=1; i<ep_rank_loc+1; i++)
223          reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
224
225      else if(datatype == MPI_FLOAT && datasize == sizeof(float))
226        for(int i=1; i<ep_rank_loc+1; i++)
227          reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
228
229      else if(datatype == MPI_DOUBLE && datasize == sizeof(double))
230        for(int i=1; i<ep_rank_loc+1; i++)
231          reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
232
233      else if(datatype == MPI_CHAR && datasize == sizeof(char))
234        for(int i=1; i<ep_rank_loc+1; i++)
235          reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
236
237      else if(datatype == MPI_LONG && datasize == sizeof(long))
238        for(int i=1; i<ep_rank_loc+1; i++)
239          reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
240
241      else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long))
242        for(int i=1; i<ep_rank_loc+1; i++)
243          reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
244
245      else printf("datatype Error\n");
246    }
247
248    MPI_Barrier_local(comm);
249
250  }
251
252
253  int MPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
254  {
255    if(!comm.is_ep)
256    {
257      return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
258    }
259   
260    valid_type(datatype);
261
262    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
263    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
264    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
265    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
266    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
267    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
268
269    ::MPI_Aint datasize, lb;
270    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
271   
272    void* tmp_sendbuf;
273    tmp_sendbuf = new void*[datasize * count];
274
275    int my_src = 0;
276    int my_dst = ep_rank;
277
278    std::vector<int> my_map(mpi_size, 0);
279
280    for(int i=0; i<comm.rank_map->size(); i++) my_map[comm.rank_map->at(i).second]++;
281
282    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
283    my_src += ep_rank_loc;
284
285     
286    for(int i=0; i<mpi_size; i++)
287    {
288      if(my_dst < my_map[i])
289      {
290        my_dst = get_ep_rank(comm, my_dst, i); 
291        break;
292      }
293      else
294        my_dst -= my_map[i];
295    }
296
297    //printf("ID = %d : send to %d, recv from %d\n", ep_rank, my_dst, my_src);
298    MPI_Barrier(comm);
299
300    if(my_dst == ep_rank && my_src == ep_rank) memcpy(tmp_sendbuf, sendbuf, datasize*count);
301
302    if(ep_rank != my_dst) 
303    {
304      MPI_Request request[2];
305      MPI_Status status[2];
306
307      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
308   
309      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
310   
311      MPI_Waitall(2, request, status);
312    }
313   
314
315    void* tmp_recvbuf;
316    tmp_recvbuf = new void*[datasize * count];   
317
318    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
319
320    if(ep_rank_loc == 0)
321      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
322
323    //printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
324   
325    MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
326
327    // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
328
329
330
331    if(ep_rank != my_src) 
332    {
333      MPI_Request request[2];
334      MPI_Status status[2];
335
336      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
337   
338      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
339   
340      MPI_Waitall(2, request, status);
341    }
342
343    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
344   
345
346
347
348    delete[] tmp_sendbuf;
349    delete[] tmp_recvbuf;
350
351  }
352
353}
Note: See TracBrowser for help on using the repository browser.