source: XIOS/dev/branch_openmp/extern/ep_dev/ep_scan.cpp @ 1381

Last change on this file since 1381 was 1381, checked in by yushan, 5 years ago

add folder for MPI EP-RMA development. Current: MPI_Win, MPI_win_create, MPI_win_fence, MPI_win_free

File size: 13.9 KB
Line 
1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
31  {
32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
33  }
34
35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
37  {
38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
40
41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
45  }
46
47
48  int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
49  {
50    valid_op(op);
51
52    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
55   
56
57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
59
60    if(ep_rank_loc == 0 && mpi_rank != 0)
61    {
62      if(op == MPI_SUM)
63      {
64        if(datatype == MPI_INT)
65        {
66          assert(datasize == sizeof(int));
67          reduce_sum<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
68        }
69         
70        else if(datatype == MPI_FLOAT)
71        {
72          assert( datasize == sizeof(float));
73          reduce_sum<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
74        } 
75             
76        else if(datatype == MPI_DOUBLE )
77        {
78          assert( datasize == sizeof(double));
79          reduce_sum<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
80        }
81     
82        else if(datatype == MPI_CHAR)
83        {
84          assert( datasize == sizeof(char));
85          reduce_sum<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
86        } 
87         
88        else if(datatype == MPI_LONG)
89        {
90          assert( datasize == sizeof(long));
91          reduce_sum<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
92        } 
93         
94           
95        else if(datatype == MPI_UNSIGNED_LONG)
96        {
97          assert(datasize == sizeof(unsigned long));
98          reduce_sum<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
99        }
100           
101        else printf("datatype Error\n");
102      }
103
104      else if(op == MPI_MAX)
105      {
106        if(datatype == MPI_INT)
107        {
108          assert( datasize == sizeof(int));
109          reduce_max<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
110        } 
111         
112        else if(datatype == MPI_FLOAT )
113        {
114          assert( datasize == sizeof(float));
115          reduce_max<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
116        }
117
118        else if(datatype == MPI_DOUBLE )
119        {
120          assert( datasize == sizeof(double));
121          reduce_max<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
122        }
123     
124        else if(datatype == MPI_CHAR )
125        {
126          assert(datasize == sizeof(char));
127          reduce_max<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
128        }
129     
130        else if(datatype == MPI_LONG)
131        {
132          assert( datasize == sizeof(long));
133          reduce_max<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
134        } 
135           
136        else if(datatype == MPI_UNSIGNED_LONG)
137        {
138          assert( datasize == sizeof(unsigned long));
139          reduce_max<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
140        } 
141           
142        else printf("datatype Error\n");
143      }
144
145      else //(op == MPI_MIN)
146      {
147        if(datatype == MPI_INT )
148        {
149          assert (datasize == sizeof(int));
150          reduce_min<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
151        }
152         
153        else if(datatype == MPI_FLOAT )
154        {
155          assert( datasize == sizeof(float));
156          reduce_min<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
157        }
158             
159        else if(datatype == MPI_DOUBLE )
160        {
161          assert( datasize == sizeof(double));
162          reduce_min<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
163        }
164     
165        else if(datatype == MPI_CHAR )
166        {
167          assert( datasize == sizeof(char));
168          reduce_min<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
169        }
170     
171        else if(datatype == MPI_LONG )
172        { 
173          assert( datasize == sizeof(long));
174          reduce_min<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
175        }
176           
177        else if(datatype == MPI_UNSIGNED_LONG )
178        {
179          assert( datasize == sizeof(unsigned long));
180          reduce_min<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
181        }
182           
183        else printf("datatype Error\n");
184      }
185
186      comm.my_buffer->void_buffer[0] = recvbuf;
187    }
188    else
189    {
190      comm.my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
191      memcpy(recvbuf, sendbuf, datasize*count);
192    } 
193     
194
195
196    MPI_Barrier_local(comm);
197
198    memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize*count);
199
200
201    if(op == MPI_SUM)
202    {
203      if(datatype == MPI_INT )
204      {
205        assert (datasize == sizeof(int));
206        for(int i=1; i<ep_rank_loc+1; i++)
207          reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
208      }
209     
210      else if(datatype == MPI_FLOAT )
211      {
212        assert(datasize == sizeof(float));
213        for(int i=1; i<ep_rank_loc+1; i++)
214          reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
215      }
216     
217
218      else if(datatype == MPI_DOUBLE )
219      {
220        assert(datasize == sizeof(double));
221        for(int i=1; i<ep_rank_loc+1; i++)
222          reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
223      }
224
225      else if(datatype == MPI_CHAR )
226      {
227        assert(datasize == sizeof(char));
228        for(int i=1; i<ep_rank_loc+1; i++)
229          reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
230      }
231
232      else if(datatype == MPI_LONG )
233      {
234        assert(datasize == sizeof(long));
235        for(int i=1; i<ep_rank_loc+1; i++)
236          reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
237      }
238
239      else if(datatype == MPI_UNSIGNED_LONG )
240      {
241        assert(datasize == sizeof(unsigned long));
242        for(int i=1; i<ep_rank_loc+1; i++)
243          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
244      }
245
246      else printf("datatype Error\n");
247
248     
249    }
250
251    else if(op == MPI_MAX)
252    {
253      if(datatype == MPI_INT)
254      {
255        assert(datasize == sizeof(int));
256        for(int i=1; i<ep_rank_loc+1; i++)
257          reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
258      }
259
260      else if(datatype == MPI_FLOAT )
261      {
262        assert(datasize == sizeof(float));
263        for(int i=1; i<ep_rank_loc+1; i++)
264          reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
265      }
266
267      else if(datatype == MPI_DOUBLE )
268      {
269        assert(datasize == sizeof(double));
270        for(int i=1; i<ep_rank_loc+1; i++)
271          reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
272      }
273
274      else if(datatype == MPI_CHAR )
275      {
276        assert(datasize == sizeof(char));
277        for(int i=1; i<ep_rank_loc+1; i++)
278          reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
279      }
280
281      else if(datatype == MPI_LONG )
282      {
283        assert(datasize == sizeof(long));
284        for(int i=1; i<ep_rank_loc+1; i++)
285          reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
286      }
287
288      else if(datatype == MPI_UNSIGNED_LONG )
289      {
290        assert(datasize == sizeof(unsigned long));
291        for(int i=1; i<ep_rank_loc+1; i++)
292          reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
293      }
294     
295      else printf("datatype Error\n");
296    }
297
298    else //if(op == MPI_MIN)
299    {
300      if(datatype == MPI_INT )
301      {
302        assert(datasize == sizeof(int));
303        for(int i=1; i<ep_rank_loc+1; i++)
304          reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
305      }
306
307      else if(datatype == MPI_FLOAT )
308      {
309        assert(datasize == sizeof(float));
310        for(int i=1; i<ep_rank_loc+1; i++)
311          reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
312      }
313
314      else if(datatype == MPI_DOUBLE )
315      {
316        assert(datasize == sizeof(double));
317        for(int i=1; i<ep_rank_loc+1; i++)
318          reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
319      }
320
321      else if(datatype == MPI_CHAR )
322      {
323        assert(datasize == sizeof(char));
324        for(int i=1; i<ep_rank_loc+1; i++)
325          reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
326      }
327
328      else if(datatype == MPI_LONG )
329      {
330        assert(datasize == sizeof(long));
331        for(int i=1; i<ep_rank_loc+1; i++)
332          reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
333      }
334
335      else if(datatype == MPI_UNSIGNED_LONG )
336      {
337        assert(datasize == sizeof(unsigned long));
338        for(int i=1; i<ep_rank_loc+1; i++)
339          reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
340      }
341
342      else printf("datatype Error\n");
343    }
344
345    MPI_Barrier_local(comm);
346
347  }
348
349
350  int MPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
351  {
352    if(!comm.is_ep)
353    {
354      return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
355    }
356   
357    valid_type(datatype);
358
359    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
360    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
361    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
362    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
363    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
364    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
365
366    ::MPI_Aint datasize, lb;
367    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
368   
369    void* tmp_sendbuf;
370    tmp_sendbuf = new void*[datasize * count];
371
372    int my_src = 0;
373    int my_dst = ep_rank;
374
375    std::vector<int> my_map(mpi_size, 0);
376
377    for(int i=0; i<comm.rank_map->size(); i++) my_map[comm.rank_map->at(i).second]++;
378
379    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
380    my_src += ep_rank_loc;
381
382     
383    for(int i=0; i<mpi_size; i++)
384    {
385      if(my_dst < my_map[i])
386      {
387        my_dst = get_ep_rank(comm, my_dst, i); 
388        break;
389      }
390      else
391        my_dst -= my_map[i];
392    }
393
394    //printf("ID = %d : send to %d, recv from %d\n", ep_rank, my_dst, my_src);
395    MPI_Barrier(comm);
396
397    if(my_dst == ep_rank && my_src == ep_rank) memcpy(tmp_sendbuf, sendbuf, datasize*count);
398
399    if(ep_rank != my_dst) 
400    {
401      MPI_Request request[2];
402      MPI_Status status[2];
403
404      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
405   
406      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
407   
408      MPI_Waitall(2, request, status);
409    }
410   
411
412    void* tmp_recvbuf;
413    tmp_recvbuf = new void*[datasize * count];   
414
415    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
416
417    if(ep_rank_loc == 0)
418      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
419
420    //printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
421   
422    MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
423
424    // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
425
426
427
428    if(ep_rank != my_src) 
429    {
430      MPI_Request request[2];
431      MPI_Status status[2];
432
433      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
434   
435      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
436   
437      MPI_Waitall(2, request, status);
438    }
439
440    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
441   
442
443    delete[] tmp_sendbuf;
444    delete[] tmp_recvbuf;
445
446  }
447
448}
Note: See TracBrowser for help on using the repository browser.