source: XIOS/dev/branch_openmp/extern/ep_dev/ep_scan.cpp @ 1527

Last change on this file since 1527 was 1527, checked in by yushan, 3 years ago

save dev

File size: 14.2 KB
Line 
1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
31  {
32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
33  }
34
35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
37  {
38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
40
41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
45  }
46
47
48  int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
49  {
50    valid_op(op);
51
52    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
55   
56
57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
59
60    if(ep_rank_loc == 0 && mpi_rank != 0)
61    {
62      if(op == MPI_SUM)
63      {
64        if(datatype == MPI_INT)
65        {
66          assert(datasize == sizeof(int));
67          reduce_sum<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
68        }
69         
70        else if(datatype == MPI_FLOAT)
71        {
72          assert( datasize == sizeof(float));
73          reduce_sum<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
74        } 
75             
76        else if(datatype == MPI_DOUBLE )
77        {
78          assert( datasize == sizeof(double));
79          reduce_sum<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
80        }
81     
82        else if(datatype == MPI_CHAR)
83        {
84          assert( datasize == sizeof(char));
85          reduce_sum<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
86        } 
87         
88        else if(datatype == MPI_LONG)
89        {
90          assert( datasize == sizeof(long));
91          reduce_sum<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
92        } 
93         
94           
95        else if(datatype == MPI_UNSIGNED_LONG)
96        {
97          assert(datasize == sizeof(unsigned long));
98          reduce_sum<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
99        }
100           
101        else printf("datatype Error\n");
102      }
103
104      else if(op == MPI_MAX)
105      {
106        if(datatype == MPI_INT)
107        {
108          assert( datasize == sizeof(int));
109          reduce_max<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
110        } 
111         
112        else if(datatype == MPI_FLOAT )
113        {
114          assert( datasize == sizeof(float));
115          reduce_max<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
116        }
117
118        else if(datatype == MPI_DOUBLE )
119        {
120          assert( datasize == sizeof(double));
121          reduce_max<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
122        }
123     
124        else if(datatype == MPI_CHAR )
125        {
126          assert(datasize == sizeof(char));
127          reduce_max<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
128        }
129     
130        else if(datatype == MPI_LONG)
131        {
132          assert( datasize == sizeof(long));
133          reduce_max<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
134        } 
135           
136        else if(datatype == MPI_UNSIGNED_LONG)
137        {
138          assert( datasize == sizeof(unsigned long));
139          reduce_max<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
140        } 
141           
142        else printf("datatype Error\n");
143      }
144
145      else //(op == MPI_MIN)
146      {
147        if(datatype == MPI_INT )
148        {
149          assert (datasize == sizeof(int));
150          reduce_min<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
151        }
152         
153        else if(datatype == MPI_FLOAT )
154        {
155          assert( datasize == sizeof(float));
156          reduce_min<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
157        }
158             
159        else if(datatype == MPI_DOUBLE )
160        {
161          assert( datasize == sizeof(double));
162          reduce_min<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
163        }
164     
165        else if(datatype == MPI_CHAR )
166        {
167          assert( datasize == sizeof(char));
168          reduce_min<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
169        }
170     
171        else if(datatype == MPI_LONG )
172        { 
173          assert( datasize == sizeof(long));
174          reduce_min<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
175        }
176           
177        else if(datatype == MPI_UNSIGNED_LONG )
178        {
179          assert( datasize == sizeof(unsigned long));
180          reduce_min<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
181        }
182           
183        else printf("datatype Error\n");
184      }
185
186      comm->my_buffer->void_buffer[0] = recvbuf;
187    }
188    else
189    {
190      comm->my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
191      memcpy(recvbuf, sendbuf, datasize*count);
192    } 
193     
194
195
196    MPI_Barrier_local(comm);
197
198    memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize*count);
199
200
201    if(op == MPI_SUM)
202    {
203      if(datatype == MPI_INT )
204      {
205        assert (datasize == sizeof(int));
206        for(int i=1; i<ep_rank_loc+1; i++)
207          reduce_sum<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
208      }
209     
210      else if(datatype == MPI_FLOAT )
211      {
212        assert(datasize == sizeof(float));
213        for(int i=1; i<ep_rank_loc+1; i++)
214          reduce_sum<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
215      }
216     
217
218      else if(datatype == MPI_DOUBLE )
219      {
220        assert(datasize == sizeof(double));
221        for(int i=1; i<ep_rank_loc+1; i++)
222          reduce_sum<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
223      }
224
225      else if(datatype == MPI_CHAR )
226      {
227        assert(datasize == sizeof(char));
228        for(int i=1; i<ep_rank_loc+1; i++)
229          reduce_sum<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
230      }
231
232      else if(datatype == MPI_LONG )
233      {
234        assert(datasize == sizeof(long));
235        for(int i=1; i<ep_rank_loc+1; i++)
236          reduce_sum<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
237      }
238
239      else if(datatype == MPI_UNSIGNED_LONG )
240      {
241        assert(datasize == sizeof(unsigned long));
242        for(int i=1; i<ep_rank_loc+1; i++)
243          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
244      }
245
246      else printf("datatype Error\n");
247
248     
249    }
250
251    else if(op == MPI_MAX)
252    {
253      if(datatype == MPI_INT)
254      {
255        assert(datasize == sizeof(int));
256        for(int i=1; i<ep_rank_loc+1; i++)
257          reduce_max<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
258      }
259
260      else if(datatype == MPI_FLOAT )
261      {
262        assert(datasize == sizeof(float));
263        for(int i=1; i<ep_rank_loc+1; i++)
264          reduce_max<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
265      }
266
267      else if(datatype == MPI_DOUBLE )
268      {
269        assert(datasize == sizeof(double));
270        for(int i=1; i<ep_rank_loc+1; i++)
271          reduce_max<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
272      }
273
274      else if(datatype == MPI_CHAR )
275      {
276        assert(datasize == sizeof(char));
277        for(int i=1; i<ep_rank_loc+1; i++)
278          reduce_max<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
279      }
280
281      else if(datatype == MPI_LONG )
282      {
283        assert(datasize == sizeof(long));
284        for(int i=1; i<ep_rank_loc+1; i++)
285          reduce_max<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
286      }
287
288      else if(datatype == MPI_UNSIGNED_LONG )
289      {
290        assert(datasize == sizeof(unsigned long));
291        for(int i=1; i<ep_rank_loc+1; i++)
292          reduce_max<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
293      }
294     
295      else printf("datatype Error\n");
296    }
297
298    else //if(op == MPI_MIN)
299    {
300      if(datatype == MPI_INT )
301      {
302        assert(datasize == sizeof(int));
303        for(int i=1; i<ep_rank_loc+1; i++)
304          reduce_min<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
305      }
306
307      else if(datatype == MPI_FLOAT )
308      {
309        assert(datasize == sizeof(float));
310        for(int i=1; i<ep_rank_loc+1; i++)
311          reduce_min<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
312      }
313
314      else if(datatype == MPI_DOUBLE )
315      {
316        assert(datasize == sizeof(double));
317        for(int i=1; i<ep_rank_loc+1; i++)
318          reduce_min<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
319      }
320
321      else if(datatype == MPI_CHAR )
322      {
323        assert(datasize == sizeof(char));
324        for(int i=1; i<ep_rank_loc+1; i++)
325          reduce_min<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
326      }
327
328      else if(datatype == MPI_LONG )
329      {
330        assert(datasize == sizeof(long));
331        for(int i=1; i<ep_rank_loc+1; i++)
332          reduce_min<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
333      }
334
335      else if(datatype == MPI_UNSIGNED_LONG )
336      {
337        assert(datasize == sizeof(unsigned long));
338        for(int i=1; i<ep_rank_loc+1; i++)
339          reduce_min<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
340      }
341
342      else printf("datatype Error\n");
343    }
344
345    MPI_Barrier_local(comm);
346
347  }
348
349
350  int MPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
351  {
352    if(!comm->is_ep) return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
353    if(comm->is_intercomm) return MPI_Scan_intercomm(sendbuf, recvbuf, count, datatype, op, comm);
354   
355    valid_type(datatype);
356
357    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
358    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
359    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
360    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
361    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
362    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
363
364    ::MPI_Aint datasize, lb;
365    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
366   
367    void* tmp_sendbuf;
368    tmp_sendbuf = new void*[datasize * count];
369
370    int my_src = 0;
371    int my_dst = ep_rank;
372
373    std::vector<int> my_map(mpi_size, 0);
374
375    for(int i=0; i<comm->ep_rank_map->size(); i++) my_map[comm->ep_rank_map->at(i).second]++;
376
377    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
378    my_src += ep_rank_loc;
379
380     
381    for(int i=0; i<mpi_size; i++)
382    {
383      if(my_dst < my_map[i])
384      {
385        my_dst = get_ep_rank(comm, my_dst, i); 
386        break;
387      }
388      else
389        my_dst -= my_map[i];
390    }
391
392    //printf("ID = %d : send to %d, recv from %d\n", ep_rank, my_dst, my_src);
393    MPI_Barrier(comm);
394
395    if(my_dst == ep_rank && my_src == ep_rank) memcpy(tmp_sendbuf, sendbuf, datasize*count);
396
397    if(ep_rank != my_dst) 
398    {
399      MPI_Request request[2];
400      MPI_Status status[2];
401
402      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
403   
404      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
405   
406      MPI_Waitall(2, request, status);
407    }
408   
409
410    void* tmp_recvbuf;
411    tmp_recvbuf = new void*[datasize * count];   
412
413    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
414
415    if(ep_rank_loc == 0)
416      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
417
418    //printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
419   
420    MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
421
422    // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
423
424
425
426    if(ep_rank != my_src) 
427    {
428      MPI_Request request[2];
429      MPI_Status status[2];
430
431      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
432   
433      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
434   
435      MPI_Waitall(2, request, status);
436    }
437
438    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
439   
440
441    delete[] tmp_sendbuf;
442    delete[] tmp_recvbuf;
443
444  }
445
446  int MPI_Scan_intercomm(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
447  {
448    printf("MPI_Scan_intercomm not yet implemented\n");
449    MPI_Abort(comm, 0);
450  }
451
452}
Note: See TracBrowser for help on using the repository browser.