source: XIOS/dev/branch_yushan/extern/src_ep_dev/ep_exscan.cpp @ 1037

Last change on this file since 1037 was 1037, checked in by yushan, 7 years ago

initialize the branch

File size: 22.8 KB
Line 
1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Exscan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12using namespace std;
13
14namespace ep_lib
15{
16  template<typename T>
17  T max_op(T a, T b)
18  {
19    return max(a,b);
20  }
21
22  template<typename T>
23  T min_op(T a, T b)
24  {
25    return min(a,b);
26  }
27
28  int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
29  {
30    if(datatype == MPI_INT)
31    {
32      return MPI_Exscan_local_int(sendbuf, recvbuf, count, op, comm);
33    }
34    else if(datatype == MPI_FLOAT)
35    {
36      return MPI_Exscan_local_float(sendbuf, recvbuf, count, op, comm);
37    }
38    else if(datatype == MPI_DOUBLE)
39    {
40      return MPI_Exscan_local_double(sendbuf, recvbuf, count, op, comm);
41    }
42    else if(datatype == MPI_LONG)
43    {
44      return MPI_Exscan_local_long(sendbuf, recvbuf, count, op, comm);
45    }
46    else if(datatype == MPI_UNSIGNED_LONG)
47    {
48      return MPI_Exscan_local_ulong(sendbuf, recvbuf, count, op, comm);
49    }
50    else if(datatype == MPI_CHAR)
51    {
52      return MPI_Exscan_local_char(sendbuf, recvbuf, count, op, comm);
53    }
54    else
55    {
56      printf("MPI_Exscan Datatype not supported!\n");
57      exit(0);
58    }
59  }
60
61
62
63
64  int MPI_Exscan_local_int(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
65  {
66    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
67    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
68
69    int *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_int;
70    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
71    int *recv_buf = static_cast<int*>(recvbuf);
72
73    for(int j=0; j<count; j+=BUFFER_SIZE)
74    {
75
76      if(my_rank == 0)
77      {
78
79        #pragma omp critical (write_to_buffer)
80        {
81          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
82          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
83          #pragma omp flush
84        }
85      }
86
87      MPI_Barrier_local(comm);
88
89      for(int k=1; k<num_ep; k++)
90      {
91        #pragma omp critical (write_to_buffer)
92        {
93          if(my_rank == k)
94          {
95            #pragma omp flush
96            if(op == MPI_SUM)
97            {
98              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
99              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<int>());
100
101            }
102            else if(op == MPI_MAX)
103            {
104              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
105              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<int>);
106            }
107            else if(op == MPI_MIN)
108            {
109              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
110              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<int>);
111            }
112            else
113            {
114              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
115              exit(1);
116            }
117            #pragma omp flush
118          }
119        }
120
121        MPI_Barrier_local(comm);
122      }
123    }
124
125  }
126
127  int MPI_Exscan_local_float(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
128  {
129    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
130    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
131
132    float *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_float;
133    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
134    float *recv_buf = static_cast<float*>(recvbuf);
135
136    for(int j=0; j<count; j+=BUFFER_SIZE)
137    {
138      if(my_rank == 0)
139      {
140
141        #pragma omp critical (write_to_buffer)
142        {
143          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
144          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
145          #pragma omp flush
146        }
147      }
148
149      MPI_Barrier_local(comm);
150
151      for(int k=1; k<num_ep; k++)
152      {
153        #pragma omp critical (write_to_buffer)
154        {
155          if(my_rank == k)
156          {
157            #pragma omp flush
158            if(op == MPI_SUM)
159            {
160              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
161              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<float>());
162            }
163            else if(op == MPI_MAX)
164            {
165              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
166              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<float>);
167            }
168            else if(op == MPI_MIN)
169            {
170              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
171              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<float>);
172            }
173            else
174            {
175              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
176              exit(1);
177            }
178            #pragma omp flush
179          }
180        }
181
182        MPI_Barrier_local(comm);
183      }
184    }
185  }
186
187  int MPI_Exscan_local_double(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
188  {
189
190    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
191    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
192
193    double *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_double;
194    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
195    double *recv_buf = static_cast<double*>(recvbuf);
196
197    for(int j=0; j<count; j+=BUFFER_SIZE)
198    {
199      if(my_rank == 0)
200      {
201
202        #pragma omp critical (write_to_buffer)
203        {
204          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
205          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
206          #pragma omp flush
207        }
208      }
209
210      MPI_Barrier_local(comm);
211
212      for(int k=1; k<num_ep; k++)
213      {
214        #pragma omp critical (write_to_buffer)
215        {
216          if(my_rank == k)
217          {
218            #pragma omp flush
219            if(op == MPI_SUM)
220            {
221              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
222              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<double>());
223            }
224            else if(op == MPI_MAX)
225            {
226              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
227              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<double>);
228            }
229            else if(op == MPI_MIN)
230            {
231              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
232              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<double>);
233            }
234            else
235            {
236              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
237              exit(1);
238            }
239            #pragma omp flush
240          }
241        }
242
243        MPI_Barrier_local(comm);
244      }
245    }
246  }
247
248  int MPI_Exscan_local_long(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
249  {
250
251    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
252    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
253
254    long *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_long;
255    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
256    long *recv_buf = static_cast<long*>(recvbuf);
257
258    for(int j=0; j<count; j+=BUFFER_SIZE)
259    {
260      if(my_rank == 0)
261      {
262
263        #pragma omp critical (write_to_buffer)
264        {
265          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
266          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
267          #pragma omp flush
268        }
269      }
270
271      MPI_Barrier_local(comm);
272
273      for(int k=1; k<num_ep; k++)
274      {
275        #pragma omp critical (write_to_buffer)
276        {
277          if(my_rank == k)
278          {
279            #pragma omp flush
280            if(op == MPI_SUM)
281            {
282              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
283              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<long>());
284            }
285            else if(op == MPI_MAX)
286            {
287              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
288              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<long>);
289            }
290            else if(op == MPI_MIN)
291            {
292              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
293              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<long>);
294            }
295            else
296            {
297              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
298              exit(1);
299            }
300            #pragma omp flush
301          }
302        }
303
304        MPI_Barrier_local(comm);
305      }
306    }
307  }
308
309  int MPI_Exscan_local_ulong(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
310  {
311
312    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
313    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
314
315    unsigned long *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_ulong;
316    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
317    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
318
319    for(int j=0; j<count; j+=BUFFER_SIZE)
320    {
321      if(my_rank == 0)
322      {
323
324        #pragma omp critical (write_to_buffer)
325        {
326          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
327          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
328          #pragma omp flush
329        }
330      }
331
332      MPI_Barrier_local(comm);
333
334      for(int k=1; k<num_ep; k++)
335      {
336        #pragma omp critical (write_to_buffer)
337        {
338          if(my_rank == k)
339          {
340            #pragma omp flush
341            if(op == MPI_SUM)
342            {
343              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
344              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<unsigned long>());
345            }
346            else if(op == MPI_MAX)
347            {
348              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
349              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<unsigned long>);
350            }
351            else if(op == MPI_MIN)
352            {
353              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
354              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<unsigned long>);
355            }
356            else
357            {
358              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
359              exit(1);
360            }
361            #pragma omp flush
362          }
363        }
364
365        MPI_Barrier_local(comm);
366      }
367    }
368  }
369
370  int MPI_Exscan_local_char(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
371  {
372
373    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
374    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
375
376    char *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_char;
377    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
378    char *recv_buf = static_cast<char*>(recvbuf);
379
380    for(int j=0; j<count; j+=BUFFER_SIZE)
381    {
382      if(my_rank == 0)
383      {
384
385        #pragma omp critical (write_to_buffer)
386        {
387          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
388          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
389          #pragma omp flush
390        }
391      }
392
393      MPI_Barrier_local(comm);
394
395      for(int k=1; k<num_ep; k++)
396      {
397        #pragma omp critical (write_to_buffer)
398        {
399          if(my_rank == k)
400          {
401            #pragma omp flush
402            if(op == MPI_SUM)
403            {
404              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
405              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<char>());
406            }
407            else if(op == MPI_MAX)
408            {
409              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
410              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<char>);
411            }
412            else if(op == MPI_MIN)
413            {
414              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
415              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<char>);
416            }
417            else
418            {
419              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
420              exit(1);
421            }
422            #pragma omp flush
423          }
424        }
425
426        MPI_Barrier_local(comm);
427      }
428    }
429  }
430
431
432  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
433  {
434
435    if(!comm.is_ep)
436    {
437      ::MPI_Exscan(sendbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype),
438                   static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm));
439      return 0;
440    }
441    if(!comm.mpi_comm) return 0;
442
443    int ep_rank, ep_rank_loc, mpi_rank;
444    int ep_size, num_ep, mpi_size;
445
446    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
447    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
448    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
449    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
450    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
451    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
452
453
454
455    ::MPI_Aint datasize, lb;
456   
457    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
458
459    void* local_scan_recvbuf;
460    local_scan_recvbuf = new void*[datasize * count];
461
462
463    // local scan
464    MPI_Exscan_local(sendbuf, recvbuf, count, datatype, op, comm);
465
466//     MPI_scan
467    void* local_sum;
468    void* mpi_scan_recvbuf;
469
470
471    mpi_scan_recvbuf = new void*[datasize*count];
472
473    if(ep_rank_loc == 0)
474    {
475      local_sum = new void*[datasize*count];
476    }
477
478
479    MPI_Reduce_local(sendbuf, local_sum, count, datatype, op, comm);
480
481    if(ep_rank_loc == 0)
482    {
483      #ifdef _serialized
484      #pragma omp critical (_mpi_call)
485      #endif // _serialized
486      ::MPI_Exscan(local_sum, mpi_scan_recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm));
487    }
488
489
490    if(mpi_rank > 0)
491    {
492      MPI_Bcast_local(mpi_scan_recvbuf, count, datatype, comm);
493    }
494
495
496    if(datatype == MPI_DOUBLE)
497    {
498      double* sum_buf = static_cast<double*>(mpi_scan_recvbuf);
499      double* recv_buf = static_cast<double*>(recvbuf);
500
501      if(mpi_rank != 0)
502      {
503        if(op == MPI_SUM)
504        {
505          if(ep_rank_loc == 0)
506          {
507            copy(sum_buf, sum_buf+count, recv_buf);
508          }
509          else
510          {
511            for(int i=0; i<count; i++)
512            {
513              recv_buf[i] += sum_buf[i];
514            }
515          }
516        }
517        else if (op == MPI_MAX)
518        {
519          if(ep_rank_loc == 0)
520          {
521            copy(sum_buf, sum_buf+count, recv_buf);
522          }
523          else
524          {
525            for(int i=0; i<count; i++)
526            {
527              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
528            }
529          }
530        }
531        else if(op == MPI_MIN)
532        {
533          if(ep_rank_loc == 0)
534          {
535            copy(sum_buf, sum_buf+count, recv_buf);
536          }
537          else
538          {
539            for(int i=0; i<count; i++)
540            {
541              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
542            }
543          }
544        }
545        else
546        {
547          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
548          exit(1);
549        }
550      }
551
552      delete[] static_cast<double*>(mpi_scan_recvbuf);
553      if(ep_rank_loc == 0)
554      {
555        delete[] static_cast<double*>(local_sum);
556      }
557    }
558
559    else if(datatype == MPI_FLOAT)
560    {
561      float* sum_buf = static_cast<float*>(mpi_scan_recvbuf);
562      float* recv_buf = static_cast<float*>(recvbuf);
563
564      if(mpi_rank != 0)
565      {
566        if(op == MPI_SUM)
567        {
568          if(ep_rank_loc == 0)
569          {
570            copy(sum_buf, sum_buf+count, recv_buf);
571          }
572          else
573          {
574            for(int i=0; i<count; i++)
575            {
576              recv_buf[i] += sum_buf[i];
577            }
578          }
579        }
580        else if (op == MPI_MAX)
581        {
582          if(ep_rank_loc == 0)
583          {
584            copy(sum_buf, sum_buf+count, recv_buf);
585          }
586          else
587          {
588            for(int i=0; i<count; i++)
589            {
590              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
591            }
592          }
593        }
594        else if(op == MPI_MIN)
595        {
596          if(ep_rank_loc == 0)
597          {
598            copy(sum_buf, sum_buf+count, recv_buf);
599          }
600          else
601          {
602            for(int i=0; i<count; i++)
603            {
604              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
605            }
606          }
607        }
608        else
609        {
610          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
611          exit(1);
612        }
613      }
614
615      delete[] static_cast<float*>(mpi_scan_recvbuf);
616      if(ep_rank_loc == 0)
617      {
618        delete[] static_cast<float*>(local_sum);
619      }
620    }
621
622    else if(datatype == MPI_INT)
623    {
624      int* sum_buf = static_cast<int*>(mpi_scan_recvbuf);
625      int* recv_buf = static_cast<int*>(recvbuf);
626
627      if(mpi_rank != 0)
628      {
629        if(op == MPI_SUM)
630        {
631          if(ep_rank_loc == 0)
632          {
633            copy(sum_buf, sum_buf+count, recv_buf);
634          }
635          else
636          {
637            for(int i=0; i<count; i++)
638            {
639              recv_buf[i] += sum_buf[i];
640            }
641          }
642        }
643        else if (op == MPI_MAX)
644        {
645          if(ep_rank_loc == 0)
646          {
647            copy(sum_buf, sum_buf+count, recv_buf);
648          }
649          else
650          {
651            for(int i=0; i<count; i++)
652            {
653              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
654            }
655          }
656        }
657        else if(op == MPI_MIN)
658        {
659          if(ep_rank_loc == 0)
660          {
661            copy(sum_buf, sum_buf+count, recv_buf);
662          }
663          else
664          {
665            for(int i=0; i<count; i++)
666            {
667              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
668            }
669          }
670        }
671        else
672        {
673          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
674          exit(1);
675        }
676      }
677
678      delete[] static_cast<int*>(mpi_scan_recvbuf);
679      if(ep_rank_loc == 0)
680      {
681        delete[] static_cast<int*>(local_sum);
682      }
683    }
684
685    else if(datatype == MPI_CHAR)
686    {
687      char* sum_buf = static_cast<char*>(mpi_scan_recvbuf);
688      char* recv_buf = static_cast<char*>(recvbuf);
689
690      if(mpi_rank != 0)
691      {
692        if(op == MPI_SUM)
693        {
694          if(ep_rank_loc == 0)
695          {
696            copy(sum_buf, sum_buf+count, recv_buf);
697          }
698          else
699          {
700            for(int i=0; i<count; i++)
701            {
702              recv_buf[i] += sum_buf[i];
703            }
704          }
705        }
706        else if (op == MPI_MAX)
707        {
708          if(ep_rank_loc == 0)
709          {
710            copy(sum_buf, sum_buf+count, recv_buf);
711          }
712          else
713          {
714            for(int i=0; i<count; i++)
715            {
716              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
717            }
718          }
719        }
720        else if(op == MPI_MIN)
721        {
722          if(ep_rank_loc == 0)
723          {
724            copy(sum_buf, sum_buf+count, recv_buf);
725          }
726          else
727          {
728            for(int i=0; i<count; i++)
729            {
730              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
731            }
732          }
733        }
734        else
735        {
736          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
737          exit(1);
738        }
739      }
740
741      delete[] static_cast<char*>(mpi_scan_recvbuf);
742      if(ep_rank_loc == 0)
743      {
744        delete[] static_cast<char*>(local_sum);
745      }
746    }
747
748    else if(datatype == MPI_LONG)
749    {
750      long* sum_buf = static_cast<long*>(mpi_scan_recvbuf);
751      long* recv_buf = static_cast<long*>(recvbuf);
752
753      if(mpi_rank != 0)
754      {
755        if(op == MPI_SUM)
756        {
757          if(ep_rank_loc == 0)
758          {
759            copy(sum_buf, sum_buf+count, recv_buf);
760          }
761          else
762          {
763            for(int i=0; i<count; i++)
764            {
765              recv_buf[i] += sum_buf[i];
766            }
767          }
768        }
769        else if (op == MPI_MAX)
770        {
771          if(ep_rank_loc == 0)
772          {
773            copy(sum_buf, sum_buf+count, recv_buf);
774          }
775          else
776          {
777            for(int i=0; i<count; i++)
778            {
779              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
780            }
781          }
782        }
783        else if(op == MPI_MIN)
784        {
785          if(ep_rank_loc == 0)
786          {
787            copy(sum_buf, sum_buf+count, recv_buf);
788          }
789          else
790          {
791            for(int i=0; i<count; i++)
792            {
793              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
794            }
795          }
796        }
797        else
798        {
799          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
800          exit(1);
801        }
802      }
803
804      delete[] static_cast<long*>(mpi_scan_recvbuf);
805      if(ep_rank_loc == 0)
806      {
807        delete[] static_cast<long*>(local_sum);
808      }
809    }
810
811    else if(datatype == MPI_UNSIGNED_LONG)
812    {
813      unsigned long* sum_buf = static_cast<unsigned long*>(mpi_scan_recvbuf);
814      unsigned long* recv_buf = static_cast<unsigned long*>(recvbuf);
815
816      if(mpi_rank != 0)
817      {
818        if(op == MPI_SUM)
819        {
820          if(ep_rank_loc == 0)
821          {
822            copy(sum_buf, sum_buf+count, recv_buf);
823          }
824          else
825          {
826            for(int i=0; i<count; i++)
827            {
828              recv_buf[i] += sum_buf[i];
829            }
830          }
831        }
832        else if (op == MPI_MAX)
833        {
834          if(ep_rank_loc == 0)
835          {
836            copy(sum_buf, sum_buf+count, recv_buf);
837          }
838          else
839          {
840            for(int i=0; i<count; i++)
841            {
842              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
843            }
844          }
845        }
846        else if(op == MPI_MIN)
847        {
848          if(ep_rank_loc == 0)
849          {
850            copy(sum_buf, sum_buf+count, recv_buf);
851          }
852          else
853          {
854            for(int i=0; i<count; i++)
855            {
856              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
857            }
858          }
859        }
860        else
861        {
862          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
863          exit(1);
864        }
865      }
866
867      delete[] static_cast<unsigned long*>(mpi_scan_recvbuf);
868      if(ep_rank_loc == 0)
869      {
870        delete[] static_cast<unsigned long*>(local_sum);
871      }
872    }
873
874
875  }
876
877
878
879}
Note: See TracBrowser for help on using the repository browser.