source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scan.cpp @ 2022

Last change on this file since 2022 was 1642, checked in by yushan, 5 years ago

dev on ADA. add flag switch _usingEP/_usingMPI

File size: 16.9 KB
RevLine 
[1134]1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
[1295]11#include "ep_mpi.hpp"
[1134]12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
[1295]29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
31  {
32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
33  }
[1134]34
[1295]35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
[1134]37  {
[1295]38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
[1134]40
[1295]41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
[1134]45  }
46
47
[1295]48  int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
49  {
50    valid_op(op);
[1289]51
[1520]52    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
[1295]55   
[1289]56
[1295]57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
[1134]59
[1295]60    if(ep_rank_loc == 0 && mpi_rank != 0)
[1134]61    {
[1642]62      if(op == EP_SUM)
[1134]63      {
[1642]64        if(datatype == EP_INT)
[1289]65        {
[1295]66          assert(datasize == sizeof(int));
67          reduce_sum<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
[1289]68        }
[1295]69         
[1642]70        else if(datatype == EP_FLOAT)
[1295]71        {
72          assert( datasize == sizeof(float));
73          reduce_sum<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
74        } 
75             
[1642]76        else if(datatype == EP_DOUBLE )
[1295]77        {
78          assert( datasize == sizeof(double));
79          reduce_sum<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
80        }
81     
[1642]82        else if(datatype == EP_CHAR)
[1295]83        {
84          assert( datasize == sizeof(char));
85          reduce_sum<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
86        } 
87         
[1642]88        else if(datatype == EP_LONG)
[1295]89        {
90          assert( datasize == sizeof(long));
91          reduce_sum<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
92        } 
93         
94           
[1642]95        else if(datatype == EP_UNSIGNED_LONG)
[1295]96        {
97          assert(datasize == sizeof(unsigned long));
98          reduce_sum<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
99        }
[1540]100       
[1642]101        else if(datatype == EP_LONG_LONG_INT)
[1540]102        {
103          assert(datasize == sizeof(long long int));
104          reduce_sum<long long int>(static_cast<long long int*>(const_cast<void*>(sendbuf)), static_cast<long long int*>(recvbuf), count);   
105        }
[1295]106           
[1540]107        else 
108        {
109          printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
110          MPI_Abort(comm, 0);
111        }
[1287]112      }
[1134]113
[1642]114      else if(op == EP_MAX)
[1287]115      {
[1642]116        if(datatype == EP_INT)
[1289]117        {
[1295]118          assert( datasize == sizeof(int));
119          reduce_max<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
120        } 
121         
[1642]122        else if(datatype == EP_FLOAT )
[1295]123        {
124          assert( datasize == sizeof(float));
125          reduce_max<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
[1289]126        }
[1134]127
[1642]128        else if(datatype == EP_DOUBLE )
[1295]129        {
130          assert( datasize == sizeof(double));
131          reduce_max<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
132        }
133     
[1642]134        else if(datatype == EP_CHAR )
[1295]135        {
136          assert(datasize == sizeof(char));
137          reduce_max<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
138        }
139     
[1642]140        else if(datatype == EP_LONG)
[1295]141        {
142          assert( datasize == sizeof(long));
143          reduce_max<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
144        } 
145           
[1642]146        else if(datatype == EP_UNSIGNED_LONG)
[1295]147        {
148          assert( datasize == sizeof(unsigned long));
149          reduce_max<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
150        } 
151           
[1642]152        else if(datatype == EP_LONG_LONG_INT)
[1540]153        {
154          assert(datasize == sizeof(long long int));
155          reduce_max<long long int>(static_cast<long long int*>(const_cast<void*>(sendbuf)), static_cast<long long int*>(recvbuf), count);   
156        }
157           
158        else 
159        {
160          printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
161          MPI_Abort(comm, 0);
162        }
[1287]163      }
[1134]164
[1642]165      else if(op == EP_MIN)
[1134]166      {
[1642]167        if(datatype == EP_INT )
[1289]168        {
[1295]169          assert (datasize == sizeof(int));
170          reduce_min<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
[1289]171        }
[1295]172         
[1642]173        else if(datatype == EP_FLOAT )
[1295]174        {
175          assert( datasize == sizeof(float));
176          reduce_min<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
177        }
178             
[1642]179        else if(datatype == EP_DOUBLE )
[1295]180        {
181          assert( datasize == sizeof(double));
182          reduce_min<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
183        }
184     
[1642]185        else if(datatype == EP_CHAR )
[1295]186        {
187          assert( datasize == sizeof(char));
188          reduce_min<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
189        }
190     
[1642]191        else if(datatype == EP_LONG )
[1295]192        { 
193          assert( datasize == sizeof(long));
194          reduce_min<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
195        }
196           
[1642]197        else if(datatype == EP_UNSIGNED_LONG )
[1295]198        {
199          assert( datasize == sizeof(unsigned long));
200          reduce_min<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
201        }
202           
[1642]203        else if(datatype == EP_LONG_LONG_INT)
[1540]204        {
205          assert(datasize == sizeof(long long int));
206          reduce_min<long long int>(static_cast<long long int*>(const_cast<void*>(sendbuf)), static_cast<long long int*>(recvbuf), count);   
207        }
208           
209        else 
210        {
211          printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
212          MPI_Abort(comm, 0);
213        }
[1134]214      }
[1540]215     
216      else
217      {
218        printf("op type Error in ep_scan : MPI_MAX, MPI_MIN, MPI_SUM\n");
219        MPI_Abort(comm, 0);
220      }
[1289]221
[1520]222      comm->my_buffer->void_buffer[0] = recvbuf;
[1295]223    }
224    else
225    {
[1520]226      comm->my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
[1295]227      memcpy(recvbuf, sendbuf, datasize*count);
228    } 
229     
[1289]230
231
[1295]232    MPI_Barrier_local(comm);
[1289]233
[1520]234    memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize*count);
[1289]235
[1134]236
[1642]237    if(op == EP_SUM)
[1289]238    {
[1642]239      if(datatype == EP_INT )
[1134]240      {
[1295]241        assert (datasize == sizeof(int));
242        for(int i=1; i<ep_rank_loc+1; i++)
[1520]243          reduce_sum<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
[1295]244      }
245     
[1642]246      else if(datatype == EP_FLOAT )
[1295]247      {
248        assert(datasize == sizeof(float));
249        for(int i=1; i<ep_rank_loc+1; i++)
[1520]250          reduce_sum<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
[1295]251      }
252     
[1289]253
[1642]254      else if(datatype == EP_DOUBLE )
[1295]255      {
256        assert(datasize == sizeof(double));
257        for(int i=1; i<ep_rank_loc+1; i++)
[1520]258          reduce_sum<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
[1134]259      }
260
[1642]261      else if(datatype == EP_CHAR )
[1295]262      {
263        assert(datasize == sizeof(char));
264        for(int i=1; i<ep_rank_loc+1; i++)
[1520]265          reduce_sum<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1295]266      }
[1289]267
[1642]268      else if(datatype == EP_LONG )
[1134]269      {
[1295]270        assert(datasize == sizeof(long));
271        for(int i=1; i<ep_rank_loc+1; i++)
[1520]272          reduce_sum<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1295]273      }
[1289]274
[1642]275      else if(datatype == EP_UNSIGNED_LONG )
[1295]276      {
277        assert(datasize == sizeof(unsigned long));
278        for(int i=1; i<ep_rank_loc+1; i++)
[1520]279          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1134]280      }
[1540]281     
[1642]282      else if(datatype == EP_LONG_LONG_INT )
[1540]283      {
284        assert(datasize == sizeof(long long int));
285        for(int i=1; i<ep_rank_loc+1; i++)
286          reduce_sum<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
287      }
[1134]288
[1540]289      else 
290      {
291        printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
292        MPI_Abort(comm, 0);
293      }
[1289]294
[1295]295     
296    }
[1289]297
[1642]298    else if(op == EP_MAX)
[1289]299    {
[1642]300      if(datatype == EP_INT)
[1134]301      {
[1295]302        assert(datasize == sizeof(int));
303        for(int i=1; i<ep_rank_loc+1; i++)
[1520]304          reduce_max<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
[1295]305      }
[1289]306
[1642]307      else if(datatype == EP_FLOAT )
[1295]308      {
309        assert(datasize == sizeof(float));
310        for(int i=1; i<ep_rank_loc+1; i++)
[1520]311          reduce_max<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
[1134]312      }
313
[1642]314      else if(datatype == EP_DOUBLE )
[1295]315      {
316        assert(datasize == sizeof(double));
317        for(int i=1; i<ep_rank_loc+1; i++)
[1520]318          reduce_max<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
[1295]319      }
[1289]320
[1642]321      else if(datatype == EP_CHAR )
[1134]322      {
[1295]323        assert(datasize == sizeof(char));
324        for(int i=1; i<ep_rank_loc+1; i++)
[1520]325          reduce_max<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1295]326      }
[1289]327
[1642]328      else if(datatype == EP_LONG )
[1295]329      {
330        assert(datasize == sizeof(long));
331        for(int i=1; i<ep_rank_loc+1; i++)
[1520]332          reduce_max<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1134]333      }
[1295]334
[1642]335      else if(datatype == EP_UNSIGNED_LONG )
[1295]336      {
337        assert(datasize == sizeof(unsigned long));
338        for(int i=1; i<ep_rank_loc+1; i++)
[1520]339          reduce_max<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1295]340      }
341     
[1642]342      else if(datatype == EP_LONG_LONG_INT )
[1540]343      {
344        assert(datasize == sizeof(long long int));
345        for(int i=1; i<ep_rank_loc+1; i++)
346          reduce_max<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
347      }
348
349      else 
350      {
351        printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
352        MPI_Abort(comm, 0);
353      }
354
[1289]355    }
[1134]356
[1642]357    else if(op == EP_MIN)
[1134]358    {
[1642]359      if(datatype == EP_INT )
[1289]360      {
[1295]361        assert(datasize == sizeof(int));
362        for(int i=1; i<ep_rank_loc+1; i++)
[1520]363          reduce_min<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
[1295]364      }
[1134]365
[1642]366      else if(datatype == EP_FLOAT )
[1295]367      {
368        assert(datasize == sizeof(float));
369        for(int i=1; i<ep_rank_loc+1; i++)
[1520]370          reduce_min<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
[1289]371      }
[1134]372
[1642]373      else if(datatype == EP_DOUBLE )
[1295]374      {
375        assert(datasize == sizeof(double));
376        for(int i=1; i<ep_rank_loc+1; i++)
[1520]377          reduce_min<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
[1295]378      }
[1134]379
[1642]380      else if(datatype == EP_CHAR )
[1289]381      {
[1295]382        assert(datasize == sizeof(char));
383        for(int i=1; i<ep_rank_loc+1; i++)
[1520]384          reduce_min<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1295]385      }
[1134]386
[1642]387      else if(datatype == EP_LONG )
[1295]388      {
389        assert(datasize == sizeof(long));
390        for(int i=1; i<ep_rank_loc+1; i++)
[1520]391          reduce_min<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1289]392      }
[1134]393
[1642]394      else if(datatype == EP_UNSIGNED_LONG )
[1289]395      {
[1295]396        assert(datasize == sizeof(unsigned long));
397        for(int i=1; i<ep_rank_loc+1; i++)
[1520]398          reduce_min<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1289]399      }
[1134]400
[1642]401      else if(datatype == EP_LONG_LONG_INT )
[1540]402      {
403        assert(datasize == sizeof(long long int));
404        for(int i=1; i<ep_rank_loc+1; i++)
405          reduce_min<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
406      }
407
408      else 
409      {
410        printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
411        MPI_Abort(comm, 0);
412      }
413
[1295]414    }
[1540]415   
416    else
417    {
418      printf("op type Error in ep_scan : MPI_MAX, MPI_MIN, MPI_SUM\n");
419      MPI_Abort(comm, 0);
420    }
[1134]421
[1295]422    MPI_Barrier_local(comm);
[1134]423
424  }
425
426
427  int MPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
428  {
[1539]429    if(!comm->is_ep) return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
430    if(comm->is_intercomm) return MPI_Scan_intercomm(sendbuf, recvbuf, count, datatype, op, comm);
[1295]431   
432    valid_type(datatype);
[1134]433
[1520]434    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
435    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
436    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
437    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
438    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
439    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
[1134]440
441    ::MPI_Aint datasize, lb;
[1295]442    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
443   
444    void* tmp_sendbuf;
445    tmp_sendbuf = new void*[datasize * count];
[1134]446
[1295]447    int my_src = 0;
448    int my_dst = ep_rank;
[1134]449
[1295]450    std::vector<int> my_map(mpi_size, 0);
[1134]451
[1520]452    for(int i=0; i<comm->ep_rank_map->size(); i++) my_map[comm->ep_rank_map->at(i).second]++;
[1134]453
[1295]454    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
455    my_src += ep_rank_loc;
[1134]456
[1295]457     
458    for(int i=0; i<mpi_size; i++)
[1134]459    {
[1295]460      if(my_dst < my_map[i])
461      {
462        my_dst = get_ep_rank(comm, my_dst, i); 
463        break;
464      }
465      else
466        my_dst -= my_map[i];
[1134]467    }
468
[1295]469    //printf("ID = %d : send to %d, recv from %d\n", ep_rank, my_dst, my_src);
470    MPI_Barrier(comm);
[1134]471
[1295]472    if(my_dst == ep_rank && my_src == ep_rank) memcpy(tmp_sendbuf, sendbuf, datasize*count);
[1134]473
[1295]474    if(ep_rank != my_dst) 
[1134]475    {
[1295]476      MPI_Request request[2];
477      MPI_Status status[2];
[1134]478
[1295]479      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
480   
481      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
482   
483      MPI_Waitall(2, request, status);
[1134]484    }
[1295]485   
[1134]486
[1295]487    void* tmp_recvbuf;
488    tmp_recvbuf = new void*[datasize * count];   
[1134]489
[1295]490    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
[1134]491
[1295]492    if(ep_rank_loc == 0)
[1520]493      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
[1134]494
[1295]495    //printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
496   
497    MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
[1134]498
[1295]499    // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
[1134]500
501
502
[1295]503    if(ep_rank != my_src) 
[1134]504    {
[1295]505      MPI_Request request[2];
506      MPI_Status status[2];
[1134]507
[1295]508      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
509   
510      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
511   
512      MPI_Waitall(2, request, status);
[1134]513    }
514
[1295]515    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
516   
[1134]517
[1295]518    delete[] tmp_sendbuf;
519    delete[] tmp_recvbuf;
[1134]520
521  }
522
[1539]523  int MPI_Scan_intercomm(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
524  {
525    printf("MPI_Scan_intercomm not yet implemented\n");
526    MPI_Abort(comm, 0);
527  }
528
[1520]529}
Note: See TracBrowser for help on using the repository browser.