source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scan.cpp @ 1354

Last change on this file since 1354 was 1295, checked in by yushan, 7 years ago

EP update all

File size: 13.9 KB
RevLine 
[1134]1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
[1295]11#include "ep_mpi.hpp"
[1134]12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
[1295]29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
31  {
32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
33  }
[1134]34
[1295]35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
[1134]37  {
[1295]38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
[1134]40
[1295]41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
[1134]45  }
46
47
[1295]48  int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
49  {
50    valid_op(op);
[1289]51
[1295]52    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
55   
[1289]56
[1295]57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
[1134]59
[1295]60    if(ep_rank_loc == 0 && mpi_rank != 0)
[1134]61    {
[1295]62      if(op == MPI_SUM)
[1134]63      {
[1295]64        if(datatype == MPI_INT)
[1289]65        {
[1295]66          assert(datasize == sizeof(int));
67          reduce_sum<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
[1289]68        }
[1295]69         
70        else if(datatype == MPI_FLOAT)
71        {
72          assert( datasize == sizeof(float));
73          reduce_sum<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
74        } 
75             
76        else if(datatype == MPI_DOUBLE )
77        {
78          assert( datasize == sizeof(double));
79          reduce_sum<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
80        }
81     
82        else if(datatype == MPI_CHAR)
83        {
84          assert( datasize == sizeof(char));
85          reduce_sum<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
86        } 
87         
88        else if(datatype == MPI_LONG)
89        {
90          assert( datasize == sizeof(long));
91          reduce_sum<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
92        } 
93         
94           
95        else if(datatype == MPI_UNSIGNED_LONG)
96        {
97          assert(datasize == sizeof(unsigned long));
98          reduce_sum<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
99        }
100           
101        else printf("datatype Error\n");
[1287]102      }
[1134]103
[1295]104      else if(op == MPI_MAX)
[1287]105      {
[1295]106        if(datatype == MPI_INT)
[1289]107        {
[1295]108          assert( datasize == sizeof(int));
109          reduce_max<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
110        } 
111         
112        else if(datatype == MPI_FLOAT )
113        {
114          assert( datasize == sizeof(float));
115          reduce_max<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
[1289]116        }
[1134]117
[1295]118        else if(datatype == MPI_DOUBLE )
119        {
120          assert( datasize == sizeof(double));
121          reduce_max<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
122        }
123     
124        else if(datatype == MPI_CHAR )
125        {
126          assert(datasize == sizeof(char));
127          reduce_max<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
128        }
129     
130        else if(datatype == MPI_LONG)
131        {
132          assert( datasize == sizeof(long));
133          reduce_max<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
134        } 
135           
136        else if(datatype == MPI_UNSIGNED_LONG)
137        {
138          assert( datasize == sizeof(unsigned long));
139          reduce_max<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
140        } 
141           
142        else printf("datatype Error\n");
[1287]143      }
[1134]144
[1295]145      else //(op == MPI_MIN)
[1134]146      {
[1295]147        if(datatype == MPI_INT )
[1289]148        {
[1295]149          assert (datasize == sizeof(int));
150          reduce_min<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
[1289]151        }
[1295]152         
153        else if(datatype == MPI_FLOAT )
154        {
155          assert( datasize == sizeof(float));
156          reduce_min<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
157        }
158             
159        else if(datatype == MPI_DOUBLE )
160        {
161          assert( datasize == sizeof(double));
162          reduce_min<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
163        }
164     
165        else if(datatype == MPI_CHAR )
166        {
167          assert( datasize == sizeof(char));
168          reduce_min<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
169        }
170     
171        else if(datatype == MPI_LONG )
172        { 
173          assert( datasize == sizeof(long));
174          reduce_min<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
175        }
176           
177        else if(datatype == MPI_UNSIGNED_LONG )
178        {
179          assert( datasize == sizeof(unsigned long));
180          reduce_min<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
181        }
182           
183        else printf("datatype Error\n");
[1134]184      }
[1289]185
[1295]186      comm.my_buffer->void_buffer[0] = recvbuf;
187    }
188    else
189    {
190      comm.my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
191      memcpy(recvbuf, sendbuf, datasize*count);
192    } 
193     
[1289]194
195
[1295]196    MPI_Barrier_local(comm);
[1289]197
[1295]198    memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize*count);
[1289]199
[1134]200
[1295]201    if(op == MPI_SUM)
[1289]202    {
[1295]203      if(datatype == MPI_INT )
[1134]204      {
[1295]205        assert (datasize == sizeof(int));
206        for(int i=1; i<ep_rank_loc+1; i++)
207          reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
208      }
209     
210      else if(datatype == MPI_FLOAT )
211      {
212        assert(datasize == sizeof(float));
213        for(int i=1; i<ep_rank_loc+1; i++)
214          reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
215      }
216     
[1289]217
[1295]218      else if(datatype == MPI_DOUBLE )
219      {
220        assert(datasize == sizeof(double));
221        for(int i=1; i<ep_rank_loc+1; i++)
222          reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
[1134]223      }
224
[1295]225      else if(datatype == MPI_CHAR )
226      {
227        assert(datasize == sizeof(char));
228        for(int i=1; i<ep_rank_loc+1; i++)
229          reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
230      }
[1289]231
[1295]232      else if(datatype == MPI_LONG )
[1134]233      {
[1295]234        assert(datasize == sizeof(long));
235        for(int i=1; i<ep_rank_loc+1; i++)
236          reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
237      }
[1289]238
[1295]239      else if(datatype == MPI_UNSIGNED_LONG )
240      {
241        assert(datasize == sizeof(unsigned long));
242        for(int i=1; i<ep_rank_loc+1; i++)
243          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1134]244      }
245
[1295]246      else printf("datatype Error\n");
[1289]247
[1295]248     
249    }
[1289]250
[1295]251    else if(op == MPI_MAX)
[1289]252    {
[1295]253      if(datatype == MPI_INT)
[1134]254      {
[1295]255        assert(datasize == sizeof(int));
256        for(int i=1; i<ep_rank_loc+1; i++)
257          reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
258      }
[1289]259
[1295]260      else if(datatype == MPI_FLOAT )
261      {
262        assert(datasize == sizeof(float));
263        for(int i=1; i<ep_rank_loc+1; i++)
264          reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
[1134]265      }
266
[1295]267      else if(datatype == MPI_DOUBLE )
268      {
269        assert(datasize == sizeof(double));
270        for(int i=1; i<ep_rank_loc+1; i++)
271          reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
272      }
[1289]273
[1295]274      else if(datatype == MPI_CHAR )
[1134]275      {
[1295]276        assert(datasize == sizeof(char));
277        for(int i=1; i<ep_rank_loc+1; i++)
278          reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
279      }
[1289]280
[1295]281      else if(datatype == MPI_LONG )
282      {
283        assert(datasize == sizeof(long));
284        for(int i=1; i<ep_rank_loc+1; i++)
285          reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1134]286      }
[1295]287
288      else if(datatype == MPI_UNSIGNED_LONG )
289      {
290        assert(datasize == sizeof(unsigned long));
291        for(int i=1; i<ep_rank_loc+1; i++)
292          reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
293      }
294     
295      else printf("datatype Error\n");
[1289]296    }
[1134]297
[1295]298    else //if(op == MPI_MIN)
[1134]299    {
[1295]300      if(datatype == MPI_INT )
[1289]301      {
[1295]302        assert(datasize == sizeof(int));
303        for(int i=1; i<ep_rank_loc+1; i++)
304          reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
305      }
[1134]306
[1295]307      else if(datatype == MPI_FLOAT )
308      {
309        assert(datasize == sizeof(float));
310        for(int i=1; i<ep_rank_loc+1; i++)
311          reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
[1289]312      }
[1134]313
[1295]314      else if(datatype == MPI_DOUBLE )
315      {
316        assert(datasize == sizeof(double));
317        for(int i=1; i<ep_rank_loc+1; i++)
318          reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
319      }
[1134]320
[1295]321      else if(datatype == MPI_CHAR )
[1289]322      {
[1295]323        assert(datasize == sizeof(char));
324        for(int i=1; i<ep_rank_loc+1; i++)
325          reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
326      }
[1134]327
[1295]328      else if(datatype == MPI_LONG )
329      {
330        assert(datasize == sizeof(long));
331        for(int i=1; i<ep_rank_loc+1; i++)
332          reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1289]333      }
[1134]334
[1295]335      else if(datatype == MPI_UNSIGNED_LONG )
[1289]336      {
[1295]337        assert(datasize == sizeof(unsigned long));
338        for(int i=1; i<ep_rank_loc+1; i++)
339          reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1289]340      }
[1134]341
[1295]342      else printf("datatype Error\n");
343    }
[1134]344
[1295]345    MPI_Barrier_local(comm);
[1134]346
347  }
348
349
350  int MPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
351  {
352    if(!comm.is_ep)
353    {
[1295]354      return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
[1134]355    }
[1295]356   
357    valid_type(datatype);
[1134]358
[1295]359    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
360    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
361    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
362    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
363    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
364    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
[1134]365
366    ::MPI_Aint datasize, lb;
[1295]367    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
368   
369    void* tmp_sendbuf;
370    tmp_sendbuf = new void*[datasize * count];
[1134]371
[1295]372    int my_src = 0;
373    int my_dst = ep_rank;
[1134]374
[1295]375    std::vector<int> my_map(mpi_size, 0);
[1134]376
[1295]377    for(int i=0; i<comm.rank_map->size(); i++) my_map[comm.rank_map->at(i).second]++;
[1134]378
[1295]379    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
380    my_src += ep_rank_loc;
[1134]381
[1295]382     
383    for(int i=0; i<mpi_size; i++)
[1134]384    {
[1295]385      if(my_dst < my_map[i])
386      {
387        my_dst = get_ep_rank(comm, my_dst, i); 
388        break;
389      }
390      else
391        my_dst -= my_map[i];
[1134]392    }
393
[1295]394    //printf("ID = %d : send to %d, recv from %d\n", ep_rank, my_dst, my_src);
395    MPI_Barrier(comm);
[1134]396
[1295]397    if(my_dst == ep_rank && my_src == ep_rank) memcpy(tmp_sendbuf, sendbuf, datasize*count);
[1134]398
[1295]399    if(ep_rank != my_dst) 
[1134]400    {
[1295]401      MPI_Request request[2];
402      MPI_Status status[2];
[1134]403
[1295]404      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
405   
406      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
407   
408      MPI_Waitall(2, request, status);
[1134]409    }
[1295]410   
[1134]411
[1295]412    void* tmp_recvbuf;
413    tmp_recvbuf = new void*[datasize * count];   
[1134]414
[1295]415    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
[1134]416
[1295]417    if(ep_rank_loc == 0)
418      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
[1134]419
[1295]420    //printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
421   
422    MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
[1134]423
[1295]424    // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
[1134]425
426
427
[1295]428    if(ep_rank != my_src) 
[1134]429    {
[1295]430      MPI_Request request[2];
431      MPI_Status status[2];
[1134]432
[1295]433      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
434   
435      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
436   
437      MPI_Waitall(2, request, status);
[1134]438    }
439
[1295]440    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
441   
[1134]442
[1295]443    delete[] tmp_sendbuf;
444    delete[] tmp_recvbuf;
[1134]445
446  }
447
[1295]448}
Note: See TracBrowser for help on using the repository browser.