source: XIOS/dev/dev_trunk_omp/extern/src_ep_dev/ep_scan.cpp @ 1646

Last change on this file since 1646 was 1646, checked in by yushan, 5 years ago

branch merged with trunk @1645. arch file (ep&mpi) added for ADA

File size: 17.0 KB
Line 
1#ifdef _usingEP
2/*!
3   \file ep_scan.cpp
4   \since 2 may 2016
5
6   \brief Definitions of MPI collective function: MPI_Scan
7 */
8
9#include "ep_lib.hpp"
10#include <mpi.h>
11#include "ep_declaration.hpp"
12#include "ep_mpi.hpp"
13
14using namespace std;
15
16namespace ep_lib
17{
18  template<typename T>
19  T max_op(T a, T b)
20  {
21    return max(a,b);
22  }
23
24  template<typename T>
25  T min_op(T a, T b)
26  {
27    return min(a,b);
28  }
29
30  template<typename T>
31  void reduce_max(const T * buffer, T* recvbuf, int count)
32  {
33    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
34  }
35
36  template<typename T>
37  void reduce_min(const T * buffer, T* recvbuf, int count)
38  {
39    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
40  }
41
42  template<typename T>
43  void reduce_sum(const T * buffer, T* recvbuf, int count)
44  {
45    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
46  }
47
48
49  int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
50  {
51    valid_op(op);
52
53    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
54    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
55    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
56   
57
58    ::MPI_Aint datasize, lb;
59    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
60
61    if(ep_rank_loc == 0 && mpi_rank != 0)
62    {
63      if(op == MPI_SUM)
64      {
65        if(datatype == MPI_INT)
66        {
67          assert(datasize == sizeof(int));
68          reduce_sum<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
69        }
70         
71        else if(datatype == MPI_FLOAT)
72        {
73          assert( datasize == sizeof(float));
74          reduce_sum<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
75        } 
76             
77        else if(datatype == MPI_DOUBLE )
78        {
79          assert( datasize == sizeof(double));
80          reduce_sum<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
81        }
82     
83        else if(datatype == MPI_CHAR)
84        {
85          assert( datasize == sizeof(char));
86          reduce_sum<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
87        } 
88         
89        else if(datatype == MPI_LONG)
90        {
91          assert( datasize == sizeof(long));
92          reduce_sum<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
93        } 
94         
95           
96        else if(datatype == MPI_UNSIGNED_LONG)
97        {
98          assert(datasize == sizeof(unsigned long));
99          reduce_sum<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
100        }
101       
102        else if(datatype == MPI_LONG_LONG_INT)
103        {
104          assert(datasize == sizeof(long long int));
105          reduce_sum<long long int>(static_cast<long long int*>(const_cast<void*>(sendbuf)), static_cast<long long int*>(recvbuf), count);   
106        }
107           
108        else 
109        {
110          printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
111          MPI_Abort(comm, 0);
112        }
113      }
114
115      else if(op == MPI_MAX)
116      {
117        if(datatype == MPI_INT)
118        {
119          assert( datasize == sizeof(int));
120          reduce_max<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
121        } 
122         
123        else if(datatype == MPI_FLOAT )
124        {
125          assert( datasize == sizeof(float));
126          reduce_max<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
127        }
128
129        else if(datatype == MPI_DOUBLE )
130        {
131          assert( datasize == sizeof(double));
132          reduce_max<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
133        }
134     
135        else if(datatype == MPI_CHAR )
136        {
137          assert(datasize == sizeof(char));
138          reduce_max<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
139        }
140     
141        else if(datatype == MPI_LONG)
142        {
143          assert( datasize == sizeof(long));
144          reduce_max<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
145        } 
146           
147        else if(datatype == MPI_UNSIGNED_LONG)
148        {
149          assert( datasize == sizeof(unsigned long));
150          reduce_max<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
151        } 
152           
153        else if(datatype == MPI_LONG_LONG_INT)
154        {
155          assert(datasize == sizeof(long long int));
156          reduce_max<long long int>(static_cast<long long int*>(const_cast<void*>(sendbuf)), static_cast<long long int*>(recvbuf), count);   
157        }
158           
159        else 
160        {
161          printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
162          MPI_Abort(comm, 0);
163        }
164      }
165
166      else if(op == MPI_MIN)
167      {
168        if(datatype == MPI_INT )
169        {
170          assert (datasize == sizeof(int));
171          reduce_min<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);   
172        }
173         
174        else if(datatype == MPI_FLOAT )
175        {
176          assert( datasize == sizeof(float));
177          reduce_min<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);   
178        }
179             
180        else if(datatype == MPI_DOUBLE )
181        {
182          assert( datasize == sizeof(double));
183          reduce_min<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count);
184        }
185     
186        else if(datatype == MPI_CHAR )
187        {
188          assert( datasize == sizeof(char));
189          reduce_min<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count);
190        }
191     
192        else if(datatype == MPI_LONG )
193        { 
194          assert( datasize == sizeof(long));
195          reduce_min<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count);
196        }
197           
198        else if(datatype == MPI_UNSIGNED_LONG )
199        {
200          assert( datasize == sizeof(unsigned long));
201          reduce_min<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);   
202        }
203           
204        else if(datatype == MPI_LONG_LONG_INT)
205        {
206          assert(datasize == sizeof(long long int));
207          reduce_min<long long int>(static_cast<long long int*>(const_cast<void*>(sendbuf)), static_cast<long long int*>(recvbuf), count);   
208        }
209           
210        else 
211        {
212          printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
213          MPI_Abort(comm, 0);
214        }
215      }
216     
217      else
218      {
219        printf("op type Error in ep_scan : MPI_MAX, MPI_MIN, MPI_SUM\n");
220        MPI_Abort(comm, 0);
221      }
222
223      comm->my_buffer->void_buffer[0] = recvbuf;
224    }
225    else
226    {
227      comm->my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
228      memcpy(recvbuf, sendbuf, datasize*count);
229    } 
230     
231
232
233    MPI_Barrier_local(comm);
234
235    memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize*count);
236
237
238    if(op == MPI_SUM)
239    {
240      if(datatype == MPI_INT )
241      {
242        assert (datasize == sizeof(int));
243        for(int i=1; i<ep_rank_loc+1; i++)
244          reduce_sum<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
245      }
246     
247      else if(datatype == MPI_FLOAT )
248      {
249        assert(datasize == sizeof(float));
250        for(int i=1; i<ep_rank_loc+1; i++)
251          reduce_sum<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
252      }
253     
254
255      else if(datatype == MPI_DOUBLE )
256      {
257        assert(datasize == sizeof(double));
258        for(int i=1; i<ep_rank_loc+1; i++)
259          reduce_sum<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
260      }
261
262      else if(datatype == MPI_CHAR )
263      {
264        assert(datasize == sizeof(char));
265        for(int i=1; i<ep_rank_loc+1; i++)
266          reduce_sum<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
267      }
268
269      else if(datatype == MPI_LONG )
270      {
271        assert(datasize == sizeof(long));
272        for(int i=1; i<ep_rank_loc+1; i++)
273          reduce_sum<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
274      }
275
276      else if(datatype == MPI_UNSIGNED_LONG )
277      {
278        assert(datasize == sizeof(unsigned long));
279        for(int i=1; i<ep_rank_loc+1; i++)
280          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
281      }
282     
283      else if(datatype == MPI_LONG_LONG_INT )
284      {
285        assert(datasize == sizeof(long long int));
286        for(int i=1; i<ep_rank_loc+1; i++)
287          reduce_sum<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
288      }
289
290      else 
291      {
292        printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
293        MPI_Abort(comm, 0);
294      }
295
296     
297    }
298
299    else if(op == MPI_MAX)
300    {
301      if(datatype == MPI_INT)
302      {
303        assert(datasize == sizeof(int));
304        for(int i=1; i<ep_rank_loc+1; i++)
305          reduce_max<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
306      }
307
308      else if(datatype == MPI_FLOAT )
309      {
310        assert(datasize == sizeof(float));
311        for(int i=1; i<ep_rank_loc+1; i++)
312          reduce_max<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
313      }
314
315      else if(datatype == MPI_DOUBLE )
316      {
317        assert(datasize == sizeof(double));
318        for(int i=1; i<ep_rank_loc+1; i++)
319          reduce_max<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
320      }
321
322      else if(datatype == MPI_CHAR )
323      {
324        assert(datasize == sizeof(char));
325        for(int i=1; i<ep_rank_loc+1; i++)
326          reduce_max<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
327      }
328
329      else if(datatype == MPI_LONG )
330      {
331        assert(datasize == sizeof(long));
332        for(int i=1; i<ep_rank_loc+1; i++)
333          reduce_max<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
334      }
335
336      else if(datatype == MPI_UNSIGNED_LONG )
337      {
338        assert(datasize == sizeof(unsigned long));
339        for(int i=1; i<ep_rank_loc+1; i++)
340          reduce_max<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
341      }
342     
343      else if(datatype == MPI_LONG_LONG_INT )
344      {
345        assert(datasize == sizeof(long long int));
346        for(int i=1; i<ep_rank_loc+1; i++)
347          reduce_max<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
348      }
349
350      else 
351      {
352        printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
353        MPI_Abort(comm, 0);
354      }
355
356    }
357
358    else if(op == MPI_MIN)
359    {
360      if(datatype == MPI_INT )
361      {
362        assert(datasize == sizeof(int));
363        for(int i=1; i<ep_rank_loc+1; i++)
364          reduce_min<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
365      }
366
367      else if(datatype == MPI_FLOAT )
368      {
369        assert(datasize == sizeof(float));
370        for(int i=1; i<ep_rank_loc+1; i++)
371          reduce_min<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
372      }
373
374      else if(datatype == MPI_DOUBLE )
375      {
376        assert(datasize == sizeof(double));
377        for(int i=1; i<ep_rank_loc+1; i++)
378          reduce_min<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
379      }
380
381      else if(datatype == MPI_CHAR )
382      {
383        assert(datasize == sizeof(char));
384        for(int i=1; i<ep_rank_loc+1; i++)
385          reduce_min<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
386      }
387
388      else if(datatype == MPI_LONG )
389      {
390        assert(datasize == sizeof(long));
391        for(int i=1; i<ep_rank_loc+1; i++)
392          reduce_min<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
393      }
394
395      else if(datatype == MPI_UNSIGNED_LONG )
396      {
397        assert(datasize == sizeof(unsigned long));
398        for(int i=1; i<ep_rank_loc+1; i++)
399          reduce_min<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
400      }
401
402      else if(datatype == MPI_LONG_LONG_INT )
403      {
404        assert(datasize == sizeof(long long int));
405        for(int i=1; i<ep_rank_loc+1; i++)
406          reduce_min<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
407      }
408
409      else 
410      {
411        printf("datatype Error in ep_scan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
412        MPI_Abort(comm, 0);
413      }
414
415    }
416   
417    else
418    {
419      printf("op type Error in ep_scan : MPI_MAX, MPI_MIN, MPI_SUM\n");
420      MPI_Abort(comm, 0);
421    }
422
423    MPI_Barrier_local(comm);
424
425  }
426
427
428  int MPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
429  {
430    if(!comm->is_ep) return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
431    if(comm->is_intercomm) return MPI_Scan_intercomm(sendbuf, recvbuf, count, datatype, op, comm);
432   
433    valid_type(datatype);
434
435    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
436    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
437    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
438    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
439    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
440    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
441
442    ::MPI_Aint datasize, lb;
443    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
444   
445    void* tmp_sendbuf;
446    tmp_sendbuf = new void*[datasize * count];
447
448    int my_src = 0;
449    int my_dst = ep_rank;
450
451    std::vector<int> my_map(mpi_size, 0);
452
453    for(int i=0; i<comm->ep_rank_map->size(); i++) my_map[comm->ep_rank_map->at(i).second]++;
454
455    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
456    my_src += ep_rank_loc;
457
458     
459    for(int i=0; i<mpi_size; i++)
460    {
461      if(my_dst < my_map[i])
462      {
463        my_dst = get_ep_rank(comm, my_dst, i); 
464        break;
465      }
466      else
467        my_dst -= my_map[i];
468    }
469
470    //printf("ID = %d : send to %d, recv from %d\n", ep_rank, my_dst, my_src);
471    MPI_Barrier(comm);
472
473    if(my_dst == ep_rank && my_src == ep_rank) memcpy(tmp_sendbuf, sendbuf, datasize*count);
474
475    if(ep_rank != my_dst) 
476    {
477      MPI_Request request[2];
478      MPI_Status status[2];
479
480      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
481   
482      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
483   
484      MPI_Waitall(2, request, status);
485    }
486   
487
488    void* tmp_recvbuf;
489    tmp_recvbuf = new void*[datasize * count];   
490
491    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
492
493    if(ep_rank_loc == 0)
494      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
495
496    //printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
497   
498    MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
499
500    // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]);
501
502
503
504    if(ep_rank != my_src) 
505    {
506      MPI_Request request[2];
507      MPI_Status status[2];
508
509      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
510   
511      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
512   
513      MPI_Waitall(2, request, status);
514    }
515
516    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
517   
518
519    delete[] tmp_sendbuf;
520    delete[] tmp_recvbuf;
521
522  }
523
524  int MPI_Scan_intercomm(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
525  {
526    printf("MPI_Scan_intercomm not yet implemented\n");
527    MPI_Abort(comm, 0);
528  }
529
530}
531#endif
Note: See TracBrowser for help on using the repository browser.