source: XIOS/dev/dev_trunk_omp/extern/ep_dev/ep_scan.cpp @ 1604

Last change on this file since 1604 was 1604, checked in by yushan, 5 years ago

branch_openmp merged with trunk r1597

File size: 17.0 KB
Line 
1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
31  {
32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
33  }
34
35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
37  {
38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
40
41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
45  }
46
47
48  int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
49  {
50    valid_op(op);
51
52    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
55   
56
57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
59
60    if(ep_rank_loc == 0 && mpi_rank != 0)
61    {
62      if(op == MPI_SUM)
63      {
64        if(datatype == MPI_INT)
65        {
66          assert(datasize == sizeof(int));
67          reduce_sum<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
68        }
69         
70        else if(datatype == MPI_FLOAT)
71        {
72          assert( datasize == sizeof(float));
73          reduce_sum<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
74        } 
75             
76        else if(datatype == MPI_DOUBLE )
77        {
78          assert( datasize == sizeof(double));
79          reduce_sum<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
80        }
81     
82        else if(datatype == MPI_CHAR)
83        {
84          assert( datasize == sizeof(char));
85          reduce_sum<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
86        } 
87         
88        else if(datatype == MPI_LONG)
89        {
90          assert( datasize == sizeof(long));
91          reduce_sum<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
92        } 
93         
94           
95        else if(datatype == MPI_UNSIGNED_LONG)
96        {
97          assert(datasize == sizeof(unsigned long));
98          reduce_sum<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
99        }
100       
101        else if(datatype == MPI_LONG_LONG_INT)
102        {
103          assert(datasize == sizeof(long long int));
104          reduce_sum<long long int>(static_cast<long long int*>(const_cast<void*>(sendbuf)), static_cast<long long int*>(recvbuf), count);   
105        }
106           
107        else 
108        {
109          printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
110          MPI_Abort(comm, 0);
111        }
112      }
113
114      else if(op == MPI_MAX)
115      {
116        if(datatype == MPI_INT)
117        {
118          assert( datasize == sizeof(int));
119          reduce_max<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
120        } 
121         
122        else if(datatype == MPI_FLOAT )
123        {
124          assert( datasize == sizeof(float));
125          reduce_max<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
126        }
127
128        else if(datatype == MPI_DOUBLE )
129        {
130          assert( datasize == sizeof(double));
131          reduce_max<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
132        }
133     
134        else if(datatype == MPI_CHAR )
135        {
136          assert(datasize == sizeof(char));
137          reduce_max<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
138        }
139     
140        else if(datatype == MPI_LONG)
141        {
142          assert( datasize == sizeof(long));
143          reduce_max<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
144        } 
145           
146        else if(datatype == MPI_UNSIGNED_LONG)
147        {
148          assert( datasize == sizeof(unsigned long));
149          reduce_max<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
150        } 
151           
152        else if(datatype == MPI_LONG_LONG_INT)
153        {
154          assert(datasize == sizeof(long long int));
155          reduce_max<long long int>(static_cast<long long int*>(const_cast<void*>(sendbuf)), static_cast<long long int*>(recvbuf), count);   
156        }
157           
158        else 
159        {
160          printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
161          MPI_Abort(comm, 0);
162        }
163      }
164
165      else if(op == MPI_MIN)
166      {
167        if(datatype == MPI_INT )
168        {
169          assert (datasize == sizeof(int));
170          reduce_min<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
171        }
172         
173        else if(datatype == MPI_FLOAT )
174        {
175          assert( datasize == sizeof(float));
176          reduce_min<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
177        }
178             
179        else if(datatype == MPI_DOUBLE )
180        {
181          assert( datasize == sizeof(double));
182          reduce_min<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
183        }
184     
185        else if(datatype == MPI_CHAR )
186        {
187          assert( datasize == sizeof(char));
188          reduce_min<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
189        }
190     
191        else if(datatype == MPI_LONG )
192        { 
193          assert( datasize == sizeof(long));
194          reduce_min<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
195        }
196           
197        else if(datatype == MPI_UNSIGNED_LONG )
198        {
199          assert( datasize == sizeof(unsigned long));
200          reduce_min<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
201        }
202           
203        else if(datatype == MPI_LONG_LONG_INT)
204        {
205          assert(datasize == sizeof(long long int));
206          reduce_min<long long int>(static_cast<long long int*>(const_cast<void*>(sendbuf)), static_cast<long long int*>(recvbuf), count);   
207        }
208           
209        else 
210        {
211          printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
212          MPI_Abort(comm, 0);
213        }
214      }
215     
216      else
217      {
218        printf("op type Error in ep_scan : MPI_MAX, MPI_MIN, MPI_SUM\n");
219        MPI_Abort(comm, 0);
220      }
221
222      comm->my_buffer->void_buffer[0] = recvbuf;
223    }
224    else
225    {
226      comm->my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
227      memcpy(recvbuf, sendbuf, datasize*count);
228    } 
229     
230
231
232    MPI_Barrier_local(comm);
233
234    memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize*count);
235
236
237    if(op == MPI_SUM)
238    {
239      if(datatype == MPI_INT )
240      {
241        assert (datasize == sizeof(int));
242        for(int i=1; i<ep_rank_loc+1; i++)
243          reduce_sum<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
244      }
245     
246      else if(datatype == MPI_FLOAT )
247      {
248        assert(datasize == sizeof(float));
249        for(int i=1; i<ep_rank_loc+1; i++)
250          reduce_sum<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
251      }
252     
253
254      else if(datatype == MPI_DOUBLE )
255      {
256        assert(datasize == sizeof(double));
257        for(int i=1; i<ep_rank_loc+1; i++)
258          reduce_sum<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
259      }
260
261      else if(datatype == MPI_CHAR )
262      {
263        assert(datasize == sizeof(char));
264        for(int i=1; i<ep_rank_loc+1; i++)
265          reduce_sum<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
266      }
267
268      else if(datatype == MPI_LONG )
269      {
270        assert(datasize == sizeof(long));
271        for(int i=1; i<ep_rank_loc+1; i++)
272          reduce_sum<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
273      }
274
275      else if(datatype == MPI_UNSIGNED_LONG )
276      {
277        assert(datasize == sizeof(unsigned long));
278        for(int i=1; i<ep_rank_loc+1; i++)
279          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
280      }
281     
282      else if(datatype == MPI_LONG_LONG_INT )
283      {
284        assert(datasize == sizeof(long long int));
285        for(int i=1; i<ep_rank_loc+1; i++)
286          reduce_sum<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
287      }
288
289      else 
290      {
291        printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
292        MPI_Abort(comm, 0);
293      }
294
295     
296    }
297
298    else if(op == MPI_MAX)
299    {
300      if(datatype == MPI_INT)
301      {
302        assert(datasize == sizeof(int));
303        for(int i=1; i<ep_rank_loc+1; i++)
304          reduce_max<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
305      }
306
307      else if(datatype == MPI_FLOAT )
308      {
309        assert(datasize == sizeof(float));
310        for(int i=1; i<ep_rank_loc+1; i++)
311          reduce_max<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
312      }
313
314      else if(datatype == MPI_DOUBLE )
315      {
316        assert(datasize == sizeof(double));
317        for(int i=1; i<ep_rank_loc+1; i++)
318          reduce_max<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
319      }
320
321      else if(datatype == MPI_CHAR )
322      {
323        assert(datasize == sizeof(char));
324        for(int i=1; i<ep_rank_loc+1; i++)
325          reduce_max<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
326      }
327
328      else if(datatype == MPI_LONG )
329      {
330        assert(datasize == sizeof(long));
331        for(int i=1; i<ep_rank_loc+1; i++)
332          reduce_max<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
333      }
334
335      else if(datatype == MPI_UNSIGNED_LONG )
336      {
337        assert(datasize == sizeof(unsigned long));
338        for(int i=1; i<ep_rank_loc+1; i++)
339          reduce_max<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
340      }
341     
342      else if(datatype == MPI_LONG_LONG_INT )
343      {
344        assert(datasize == sizeof(long long int));
345        for(int i=1; i<ep_rank_loc+1; i++)
346          reduce_max<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
347      }
348
349      else 
350      {
351        printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
352        MPI_Abort(comm, 0);
353      }
354
355    }
356
357    else if(op == MPI_MIN)
358    {
359      if(datatype == MPI_INT )
360      {
361        assert(datasize == sizeof(int));
362        for(int i=1; i<ep_rank_loc+1; i++)
363          reduce_min<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
364      }
365
366      else if(datatype == MPI_FLOAT )
367      {
368        assert(datasize == sizeof(float));
369        for(int i=1; i<ep_rank_loc+1; i++)
370          reduce_min<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
371      }
372
373      else if(datatype == MPI_DOUBLE )
374      {
375        assert(datasize == sizeof(double));
376        for(int i=1; i<ep_rank_loc+1; i++)
377          reduce_min<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
378      }
379
380      else if(datatype == MPI_CHAR )
381      {
382        assert(datasize == sizeof(char));
383        for(int i=1; i<ep_rank_loc+1; i++)
384          reduce_min<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
385      }
386
387      else if(datatype == MPI_LONG )
388      {
389        assert(datasize == sizeof(long));
390        for(int i=1; i<ep_rank_loc+1; i++)
391          reduce_min<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
392      }
393
394      else if(datatype == MPI_UNSIGNED_LONG )
395      {
396        assert(datasize == sizeof(unsigned long));
397        for(int i=1; i<ep_rank_loc+1; i++)
398          reduce_min<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
399      }
400
401      else if(datatype == MPI_LONG_LONG_INT )
402      {
403        assert(datasize == sizeof(long long int));
404        for(int i=1; i<ep_rank_loc+1; i++)
405          reduce_min<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
406      }
407
408      else 
409      {
410        printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
411        MPI_Abort(comm, 0);
412      }
413
414    }
415   
416    else
417    {
418      printf("op type Error in ep_scan : MPI_MAX, MPI_MIN, MPI_SUM\n");
419      MPI_Abort(comm, 0);
420    }
421
422    MPI_Barrier_local(comm);
423
424  }
425
426
427  int MPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
428  {
429    if(!comm->is_ep) return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
430    if(comm->is_intercomm) return MPI_Scan_intercomm(sendbuf, recvbuf, count, datatype, op, comm);
431   
432    valid_type(datatype);
433
434    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
435    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
436    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
437    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
438    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
439    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
440
441    ::MPI_Aint datasize, lb;
442    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
443   
444    void* tmp_sendbuf;
445    tmp_sendbuf = new void*[datasize * count];
446
447    int my_src = 0;
448    int my_dst = ep_rank;
449
450    std::vector<int> my_map(mpi_size, 0);
451
452    for(int i=0; i<comm->ep_rank_map->size(); i++) my_map[comm->ep_rank_map->at(i).second]++;
453
454    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
455    my_src += ep_rank_loc;
456
457     
458    for(int i=0; i<mpi_size; i++)
459    {
460      if(my_dst < my_map[i])
461      {
462        my_dst = get_ep_rank(comm, my_dst, i); 
463        break;
464      }
465      else
466        my_dst -= my_map[i];
467    }
468
469    //printf("ID = %d : send to %d, recv from %d\n", ep_rank, my_dst, my_src);
470    MPI_Barrier(comm);
471
472    if(my_dst == ep_rank && my_src == ep_rank) memcpy(tmp_sendbuf, sendbuf, datasize*count);
473
474    if(ep_rank != my_dst) 
475    {
476      MPI_Request request[2];
477      MPI_Status status[2];
478
479      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
480   
481      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
482   
483      MPI_Waitall(2, request, status);
484    }
485   
486
487    void* tmp_recvbuf;
488    tmp_recvbuf = new void*[datasize * count];   
489
490    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
491
492    if(ep_rank_loc == 0)
493      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
494
495    //printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
496   
497    MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
498
499    // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
500
501
502
503    if(ep_rank != my_src) 
504    {
505      MPI_Request request[2];
506      MPI_Status status[2];
507
508      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
509   
510      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
511   
512      MPI_Waitall(2, request, status);
513    }
514
515    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
516   
517
518    delete[] tmp_sendbuf;
519    delete[] tmp_recvbuf;
520
521  }
522
523  int MPI_Scan_intercomm(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
524  {
525    printf("MPI_Scan_intercomm not yet implemented\n");
526    MPI_Abort(comm, 0);
527  }
528
529}
Note: See TracBrowser for help on using the repository browser.