source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_exscan.cpp @ 1287

Last change on this file since 1287 was 1287, checked in by yushan, 7 years ago

EP updated

File size: 9.0 KB
Line 
1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Exscan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
31  {
32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
33  }
34
35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
37  {
38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
40
41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
45  }
46
47
48  int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
49  {
50    valid_op(op);
51
52    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
55   
56
57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
59
60    if(ep_rank_loc == 0 && mpi_rank != 0)
61    {
62      comm.my_buffer->void_buffer[0] = recvbuf;
63    }
64    if(ep_rank_loc == 0 && mpi_rank == 0)
65    {
66      comm.my_buffer->void_buffer[0] = const_cast<void*>(sendbuf); 
67    } 
68     
69
70    MPI_Barrier_local(comm);
71
72    memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize*count);
73
74    MPI_Barrier_local(comm);
75
76    comm.my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
77   
78    MPI_Barrier_local(comm);
79
80    if(op == MPI_SUM)
81    {
82      if(datatype == MPI_INT && datasize == sizeof(int))
83      {
84        for(int i=0; i<ep_rank_loc; i++)
85          reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
86      }
87     
88      else if(datatype == MPI_FLOAT && datasize == sizeof(float))
89      {
90        for(int i=0; i<ep_rank_loc; i++)
91          reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
92      }
93     
94
95      else if(datatype == MPI_DOUBLE && datasize == sizeof(double))
96      {
97        for(int i=0; i<ep_rank_loc; i++)
98          reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
99      }
100
101      else if(datatype == MPI_CHAR && datasize == sizeof(char))
102      {
103        for(int i=0; i<ep_rank_loc; i++)
104          reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
105      }
106
107      else if(datatype == MPI_LONG && datasize == sizeof(long))
108      {
109        for(int i=0; i<ep_rank_loc; i++)
110          reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
111      }
112
113      else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long))
114      {
115        for(int i=0; i<ep_rank_loc; i++)
116          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
117      }
118
119      else printf("datatype Error\n");
120
121     
122    }
123
124    else if(op == MPI_MAX)
125    {
126      if(datatype == MPI_INT && datasize == sizeof(int))
127        for(int i=0; i<ep_rank_loc; i++)
128          reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
129
130      else if(datatype == MPI_FLOAT && datasize == sizeof(float))
131        for(int i=0; i<ep_rank_loc; i++)
132          reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
133
134      else if(datatype == MPI_DOUBLE && datasize == sizeof(double))
135        for(int i=0; i<ep_rank_loc; i++)
136          reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
137
138      else if(datatype == MPI_CHAR && datasize == sizeof(char))
139        for(int i=0; i<ep_rank_loc; i++)
140          reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
141
142      else if(datatype == MPI_LONG && datasize == sizeof(long))
143        for(int i=0; i<ep_rank_loc; i++)
144          reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
145
146      else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long))
147        for(int i=0; i<ep_rank_loc; i++)
148          reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
149     
150      else printf("datatype Error\n");
151    }
152
153    else //if(op == MPI_MIN)
154    {
155      if(datatype == MPI_INT && datasize == sizeof(int))
156        for(int i=0; i<ep_rank_loc; i++)
157          reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
158
159      else if(datatype == MPI_FLOAT && datasize == sizeof(float))
160        for(int i=0; i<ep_rank_loc; i++)
161          reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
162
163      else if(datatype == MPI_DOUBLE && datasize == sizeof(double))
164        for(int i=0; i<ep_rank_loc; i++)
165          reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
166
167      else if(datatype == MPI_CHAR && datasize == sizeof(char))
168        for(int i=0; i<ep_rank_loc; i++)
169          reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
170
171      else if(datatype == MPI_LONG && datasize == sizeof(long))
172        for(int i=0; i<ep_rank_loc; i++)
173          reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
174
175      else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long))
176        for(int i=0; i<ep_rank_loc; i++)
177          reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
178
179      else printf("datatype Error\n");
180    }
181
182    MPI_Barrier_local(comm);
183
184  }
185
186  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
187  {
188    if(!comm.is_ep)
189    {
190      return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
191    }
192   
193    valid_type(datatype);
194
195    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
196    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
197    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
198    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
199    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
200    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
201
202    ::MPI_Aint datasize, lb;
203    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
204   
205    void* tmp_sendbuf;
206    tmp_sendbuf = new void*[datasize * count];
207
208    int my_src = 0;
209    int my_dst = ep_rank;
210
211    std::vector<int> my_map(mpi_size, 0);
212
213    for(int i=0; i<comm.rank_map->size(); i++) my_map[comm.rank_map->at(i).second]++;
214
215    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
216    my_src += ep_rank_loc;
217
218     
219    for(int i=0; i<mpi_size; i++)
220    {
221      if(my_dst < my_map[i])
222      {
223        my_dst = get_ep_rank(comm, my_dst, i); 
224        break;
225      }
226      else
227        my_dst -= my_map[i];
228    }
229
230    if(ep_rank != my_dst) 
231    {
232      MPI_Request request[2];
233      MPI_Status status[2];
234
235      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
236   
237      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
238   
239      MPI_Waitall(2, request, status);
240    }
241
242    else memcpy(tmp_sendbuf, sendbuf, datasize*count);
243   
244
245    void* tmp_recvbuf;
246    tmp_recvbuf = new void*[datasize * count];   
247
248    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
249
250    if(ep_rank_loc == 0)
251      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
252
253    // printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
254   
255    MPI_Exscan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
256
257     // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
258
259
260
261    if(ep_rank != my_src) 
262    {
263      MPI_Request request[2];
264      MPI_Status status[2];
265
266      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
267   
268      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
269   
270      MPI_Waitall(2, request, status);
271    }
272
273    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
274   
275
276
277
278    delete[] tmp_sendbuf;
279    delete[] tmp_recvbuf;
280
281  }
282
283}
Note: See TracBrowser for help on using the repository browser.