source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_exscan.cpp @ 1501

Last change on this file since 1501 was 1295, checked in by yushan, 7 years ago

EP update all

File size: 9.4 KB
RevLine 
[1134]1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Exscan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
[1295]11#include "ep_mpi.hpp"
[1134]12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
[1295]29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
[1134]31  {
[1295]32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
[1134]33  }
34
[1295]35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
37  {
38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
[1134]40
[1295]41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
45  }
[1134]46
47
[1295]48  int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
[1134]49  {
[1295]50    valid_op(op);
[1134]51
[1295]52    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
55   
[1134]56
[1295]57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
59
60    if(ep_rank_loc == 0 && mpi_rank != 0)
[1289]61    {
[1295]62      comm.my_buffer->void_buffer[0] = recvbuf;
63    }
64    if(ep_rank_loc == 0 && mpi_rank == 0)
65    {
66      comm.my_buffer->void_buffer[0] = const_cast<void*>(sendbuf); 
67    } 
68     
[1287]69
[1295]70    MPI_Barrier_local(comm);
[1289]71
[1295]72    memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize*count);
[1289]73
[1295]74    MPI_Barrier_local(comm);
[1289]75
[1295]76    comm.my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
77   
78    MPI_Barrier_local(comm);
79
80    if(op == MPI_SUM)
81    {
82      if(datatype == MPI_INT )
[1289]83      {
[1295]84        assert(datasize == sizeof(int));
85        for(int i=0; i<ep_rank_loc; i++)
86          reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
87      }
88     
89      else if(datatype == MPI_FLOAT )
90      {
91        assert(datasize == sizeof(float));
92        for(int i=0; i<ep_rank_loc; i++)
93          reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
94      }
95     
[1289]96
[1295]97      else if(datatype == MPI_DOUBLE )
98      {
99        assert(datasize == sizeof(double));
100        for(int i=0; i<ep_rank_loc; i++)
101          reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
102      }
[1289]103
[1295]104      else if(datatype == MPI_CHAR )
105      {
106        assert(datasize == sizeof(char));
107        for(int i=0; i<ep_rank_loc; i++)
108          reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1289]109      }
[1134]110
[1295]111      else if(datatype == MPI_LONG )
[1134]112      {
[1295]113        assert(datasize == sizeof(long));
114        for(int i=0; i<ep_rank_loc; i++)
115          reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1134]116      }
[1289]117
[1295]118      else if(datatype == MPI_UNSIGNED_LONG )
[1134]119      {
[1295]120        assert(datasize == sizeof(unsigned long));
121        for(int i=0; i<ep_rank_loc; i++)
122          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1134]123      }
124
[1295]125      else printf("datatype Error\n");
[1289]126
[1295]127     
128    }
[1289]129
[1295]130    else if(op == MPI_MAX)
[1289]131    {
[1295]132      if(datatype == MPI_INT )
[1134]133      {
[1295]134        assert(datasize == sizeof(int));
135        for(int i=0; i<ep_rank_loc; i++)
136          reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
137      }
[1289]138
[1295]139      else if(datatype == MPI_FLOAT )
140      {
141        assert(datasize == sizeof(float));
142        for(int i=0; i<ep_rank_loc; i++)
143          reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
[1134]144      }
145
[1295]146      else if(datatype == MPI_DOUBLE )
147      {
148        assert(datasize == sizeof(double));
149        for(int i=0; i<ep_rank_loc; i++)
150          reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
151      }
[1289]152
[1295]153      else if(datatype == MPI_CHAR )
[1134]154      {
[1295]155        assert(datasize == sizeof(char));
156        for(int i=0; i<ep_rank_loc; i++)
157          reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1134]158      }
159
[1295]160      else if(datatype == MPI_LONG )
[1134]161      {
[1295]162        assert(datasize == sizeof(long));
163        for(int i=0; i<ep_rank_loc; i++)
164          reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1134]165      }
166
[1295]167      else if(datatype == MPI_UNSIGNED_LONG )
[1134]168      {
[1295]169        assert(datasize == sizeof(unsigned long));
170        for(int i=0; i<ep_rank_loc; i++)
171          reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1134]172      }
[1295]173     
174      else printf("datatype Error\n");
[1289]175    }
[1134]176
[1295]177    else //if(op == MPI_MIN)
[1134]178    {
[1295]179      if(datatype == MPI_INT )
[1289]180      {
[1295]181        assert(datasize == sizeof(int));
182        for(int i=0; i<ep_rank_loc; i++)
183          reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
184      }
[1134]185
[1295]186      else if(datatype == MPI_FLOAT )
187      {
188        assert(datasize == sizeof(float));
189        for(int i=0; i<ep_rank_loc; i++)
190          reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
[1289]191      }
[1134]192
[1295]193      else if(datatype == MPI_DOUBLE )
194      {
195        assert(datasize == sizeof(double));
196        for(int i=0; i<ep_rank_loc; i++)
197          reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
198      }
[1134]199
[1295]200      else if(datatype == MPI_CHAR )
[1289]201      {
[1295]202        assert(datasize == sizeof(char));
203        for(int i=0; i<ep_rank_loc; i++)
204          reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
205      }
[1134]206
[1295]207      else if(datatype == MPI_LONG )
208      {
209        assert(datasize == sizeof(long));
210        for(int i=0; i<ep_rank_loc; i++)
211          reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1289]212      }
[1134]213
[1295]214      else if(datatype == MPI_UNSIGNED_LONG )
[1289]215      {
[1295]216        assert(datasize == sizeof(unsigned long));
217        for(int i=0; i<ep_rank_loc; i++)
218          reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1289]219      }
[1134]220
[1295]221      else printf("datatype Error\n");
222    }
[1134]223
[1295]224    MPI_Barrier_local(comm);
[1289]225
226  }
[1134]227
228  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
229  {
230    if(!comm.is_ep)
231    {
[1295]232      return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
[1134]233    }
[1295]234   
235    valid_type(datatype);
[1134]236
[1295]237    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
238    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
239    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
240    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
241    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
242    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
[1134]243
244    ::MPI_Aint datasize, lb;
[1295]245    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
[1134]246   
[1295]247    void* tmp_sendbuf;
248    tmp_sendbuf = new void*[datasize * count];
[1134]249
[1295]250    int my_src = 0;
251    int my_dst = ep_rank;
[1134]252
[1295]253    std::vector<int> my_map(mpi_size, 0);
[1134]254
[1295]255    for(int i=0; i<comm.rank_map->size(); i++) my_map[comm.rank_map->at(i).second]++;
[1134]256
[1295]257    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
258    my_src += ep_rank_loc;
[1134]259
[1295]260     
261    for(int i=0; i<mpi_size; i++)
[1134]262    {
[1295]263      if(my_dst < my_map[i])
264      {
265        my_dst = get_ep_rank(comm, my_dst, i); 
266        break;
267      }
268      else
269        my_dst -= my_map[i];
[1289]270    }
271
[1295]272    if(ep_rank != my_dst) 
[1289]273    {
[1295]274      MPI_Request request[2];
275      MPI_Status status[2];
[1289]276
[1295]277      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
278   
279      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
280   
281      MPI_Waitall(2, request, status);
[1289]282    }
283
[1295]284    else memcpy(tmp_sendbuf, sendbuf, datasize*count);
285   
[1289]286
[1295]287    void* tmp_recvbuf;
288    tmp_recvbuf = new void*[datasize * count];   
[1289]289
[1295]290    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
[1289]291
[1295]292    if(ep_rank_loc == 0)
293      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
[1134]294
[1295]295    // printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
296   
297    MPI_Exscan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
[1134]298
[1295]299     // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
[1289]300
[1134]301
302
[1295]303    if(ep_rank != my_src) 
[1289]304    {
[1295]305      MPI_Request request[2];
306      MPI_Status status[2];
[1134]307
[1295]308      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
309   
310      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
311   
312      MPI_Waitall(2, request, status);
[1289]313    }
[1134]314
[1295]315    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
316   
[1134]317
318
[1289]319
[1295]320    delete[] tmp_sendbuf;
321    delete[] tmp_recvbuf;
[1134]322
[1289]323  }
[1134]324
[1295]325}
Note: See TracBrowser for help on using the repository browser.