source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_exscan.cpp @ 1134

Last change on this file since 1134 was 1134, checked in by yushan, 7 years ago

branch merged with trunk r1130

File size: 22.7 KB
Line 
1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Exscan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12using namespace std;
13
14namespace ep_lib
15{
16  template<typename T>
17  T max_op(T a, T b)
18  {
19    return max(a,b);
20  }
21
22  template<typename T>
23  T min_op(T a, T b)
24  {
25    return min(a,b);
26  }
27
28  int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
29  {
30    if(datatype == MPI_INT)
31    {
32      return MPI_Exscan_local_int(sendbuf, recvbuf, count, op, comm);
33    }
34    else if(datatype == MPI_FLOAT)
35    {
36      return MPI_Exscan_local_float(sendbuf, recvbuf, count, op, comm);
37    }
38    else if(datatype == MPI_DOUBLE)
39    {
40      return MPI_Exscan_local_double(sendbuf, recvbuf, count, op, comm);
41    }
42    else if(datatype == MPI_LONG)
43    {
44      return MPI_Exscan_local_long(sendbuf, recvbuf, count, op, comm);
45    }
46    else if(datatype == MPI_UNSIGNED_LONG)
47    {
48      return MPI_Exscan_local_ulong(sendbuf, recvbuf, count, op, comm);
49    }
50    else if(datatype == MPI_CHAR)
51    {
52      return MPI_Exscan_local_char(sendbuf, recvbuf, count, op, comm);
53    }
54    else
55    {
56      printf("MPI_Exscan Datatype not supported!\n");
57      exit(0);
58    }
59  }
60
61
62
63
64  int MPI_Exscan_local_int(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
65  {
66    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
67    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
68
69    int *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_int;
70    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
71    int *recv_buf = static_cast<int*>(recvbuf);
72
73    for(int j=0; j<count; j+=BUFFER_SIZE)
74    {
75
76      if(my_rank == 0)
77      {
78
79        #pragma omp critical (write_to_buffer)
80        {
81          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
82          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
83          #pragma omp flush
84        }
85      }
86
87      MPI_Barrier_local(comm);
88
89      for(int k=1; k<num_ep; k++)
90      {
91        #pragma omp critical (write_to_buffer)
92        {
93          if(my_rank == k)
94          {
95            #pragma omp flush
96            if(op == MPI_SUM)
97            {
98              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
99              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<int>());
100
101            }
102            else if(op == MPI_MAX)
103            {
104              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
105              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<int>);
106            }
107            else if(op == MPI_MIN)
108            {
109              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
110              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<int>);
111            }
112            else
113            {
114              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
115              exit(1);
116            }
117            #pragma omp flush
118          }
119        }
120
121        MPI_Barrier_local(comm);
122      }
123    }
124
125  }
126
127  int MPI_Exscan_local_float(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
128  {
129    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
130    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
131
132    float *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_float;
133    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
134    float *recv_buf = static_cast<float*>(recvbuf);
135
136    for(int j=0; j<count; j+=BUFFER_SIZE)
137    {
138      if(my_rank == 0)
139      {
140
141        #pragma omp critical (write_to_buffer)
142        {
143          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
144          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
145          #pragma omp flush
146        }
147      }
148
149      MPI_Barrier_local(comm);
150
151      for(int k=1; k<num_ep; k++)
152      {
153        #pragma omp critical (write_to_buffer)
154        {
155          if(my_rank == k)
156          {
157            #pragma omp flush
158            if(op == MPI_SUM)
159            {
160              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
161              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<float>());
162            }
163            else if(op == MPI_MAX)
164            {
165              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
166              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<float>);
167            }
168            else if(op == MPI_MIN)
169            {
170              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
171              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<float>);
172            }
173            else
174            {
175              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
176              exit(1);
177            }
178            #pragma omp flush
179          }
180        }
181
182        MPI_Barrier_local(comm);
183      }
184    }
185  }
186
187  int MPI_Exscan_local_double(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
188  {
189
190    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
191    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
192
193    double *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_double;
194    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
195    double *recv_buf = static_cast<double*>(recvbuf);
196
197    for(int j=0; j<count; j+=BUFFER_SIZE)
198    {
199      if(my_rank == 0)
200      {
201
202        #pragma omp critical (write_to_buffer)
203        {
204          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
205          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
206          #pragma omp flush
207        }
208      }
209
210      MPI_Barrier_local(comm);
211
212      for(int k=1; k<num_ep; k++)
213      {
214        #pragma omp critical (write_to_buffer)
215        {
216          if(my_rank == k)
217          {
218            #pragma omp flush
219            if(op == MPI_SUM)
220            {
221              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
222              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<double>());
223            }
224            else if(op == MPI_MAX)
225            {
226              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
227              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<double>);
228            }
229            else if(op == MPI_MIN)
230            {
231              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
232              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<double>);
233            }
234            else
235            {
236              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
237              exit(1);
238            }
239            #pragma omp flush
240          }
241        }
242
243        MPI_Barrier_local(comm);
244      }
245    }
246  }
247
248  int MPI_Exscan_local_long(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
249  {
250
251    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
252    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
253
254    long *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_long;
255    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
256    long *recv_buf = static_cast<long*>(recvbuf);
257
258    for(int j=0; j<count; j+=BUFFER_SIZE)
259    {
260      if(my_rank == 0)
261      {
262
263        #pragma omp critical (write_to_buffer)
264        {
265          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
266          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
267          #pragma omp flush
268        }
269      }
270
271      MPI_Barrier_local(comm);
272
273      for(int k=1; k<num_ep; k++)
274      {
275        #pragma omp critical (write_to_buffer)
276        {
277          if(my_rank == k)
278          {
279            #pragma omp flush
280            if(op == MPI_SUM)
281            {
282              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
283              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<long>());
284            }
285            else if(op == MPI_MAX)
286            {
287              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
288              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<long>);
289            }
290            else if(op == MPI_MIN)
291            {
292              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
293              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<long>);
294            }
295            else
296            {
297              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
298              exit(1);
299            }
300            #pragma omp flush
301          }
302        }
303
304        MPI_Barrier_local(comm);
305      }
306    }
307  }
308
309  int MPI_Exscan_local_ulong(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
310  {
311
312    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
313    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
314
315    unsigned long *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_ulong;
316    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
317    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
318
319    for(int j=0; j<count; j+=BUFFER_SIZE)
320    {
321      if(my_rank == 0)
322      {
323
324        #pragma omp critical (write_to_buffer)
325        {
326          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
327          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
328          #pragma omp flush
329        }
330      }
331
332      MPI_Barrier_local(comm);
333
334      for(int k=1; k<num_ep; k++)
335      {
336        #pragma omp critical (write_to_buffer)
337        {
338          if(my_rank == k)
339          {
340            #pragma omp flush
341            if(op == MPI_SUM)
342            {
343              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
344              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<unsigned long>());
345            }
346            else if(op == MPI_MAX)
347            {
348              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
349              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<unsigned long>);
350            }
351            else if(op == MPI_MIN)
352            {
353              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
354              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<unsigned long>);
355            }
356            else
357            {
358              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
359              exit(1);
360            }
361            #pragma omp flush
362          }
363        }
364
365        MPI_Barrier_local(comm);
366      }
367    }
368  }
369
370  int MPI_Exscan_local_char(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm)
371  {
372
373    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
374    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
375
376    char *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_char;
377    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
378    char *recv_buf = static_cast<char*>(recvbuf);
379
380    for(int j=0; j<count; j+=BUFFER_SIZE)
381    {
382      if(my_rank == 0)
383      {
384
385        #pragma omp critical (write_to_buffer)
386        {
387          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
388          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED);
389          #pragma omp flush
390        }
391      }
392
393      MPI_Barrier_local(comm);
394
395      for(int k=1; k<num_ep; k++)
396      {
397        #pragma omp critical (write_to_buffer)
398        {
399          if(my_rank == k)
400          {
401            #pragma omp flush
402            if(op == MPI_SUM)
403            {
404              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
405              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<char>());
406            }
407            else if(op == MPI_MAX)
408            {
409              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
410              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<char>);
411            }
412            else if(op == MPI_MIN)
413            {
414              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
415              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<char>);
416            }
417            else
418            {
419              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n");
420              exit(1);
421            }
422            #pragma omp flush
423          }
424        }
425
426        MPI_Barrier_local(comm);
427      }
428    }
429  }
430
431
432  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
433  {
434
435    if(!comm.is_ep)
436    {
437      ::MPI_Exscan(const_cast<void*>(sendbuf), recvbuf, count, static_cast< ::MPI_Datatype>(datatype),
438                   static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm));
439      return 0;
440    }
441    if(!comm.mpi_comm) return 0;
442
443    int ep_rank, ep_rank_loc, mpi_rank;
444    int ep_size, num_ep, mpi_size;
445
446    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
447    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
448    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
449    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
450    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
451    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
452
453
454
455    ::MPI_Aint datasize, lb;
456   
457    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
458
459    void* local_scan_recvbuf;
460    local_scan_recvbuf = new void*[datasize * count];
461
462
463    // local scan
464    MPI_Exscan_local(sendbuf, recvbuf, count, datatype, op, comm);
465
466//     MPI_scan
467    void* local_sum;
468    void* mpi_scan_recvbuf;
469
470
471    mpi_scan_recvbuf = new void*[datasize*count];
472
473    if(ep_rank_loc == 0)
474    {
475      local_sum = new void*[datasize*count];
476    }
477
478
479    MPI_Reduce_local(sendbuf, local_sum, count, datatype, op, comm);
480
481    if(ep_rank_loc == 0)
482    {
483      ::MPI_Exscan(local_sum, mpi_scan_recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm));
484    }
485
486
487    if(mpi_rank > 0)
488    {
489      MPI_Bcast_local(mpi_scan_recvbuf, count, datatype, comm);
490    }
491
492
493    if(datatype == MPI_DOUBLE)
494    {
495      double* sum_buf = static_cast<double*>(mpi_scan_recvbuf);
496      double* recv_buf = static_cast<double*>(recvbuf);
497
498      if(mpi_rank != 0)
499      {
500        if(op == MPI_SUM)
501        {
502          if(ep_rank_loc == 0)
503          {
504            copy(sum_buf, sum_buf+count, recv_buf);
505          }
506          else
507          {
508            for(int i=0; i<count; i++)
509            {
510              recv_buf[i] += sum_buf[i];
511            }
512          }
513        }
514        else if (op == MPI_MAX)
515        {
516          if(ep_rank_loc == 0)
517          {
518            copy(sum_buf, sum_buf+count, recv_buf);
519          }
520          else
521          {
522            for(int i=0; i<count; i++)
523            {
524              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
525            }
526          }
527        }
528        else if(op == MPI_MIN)
529        {
530          if(ep_rank_loc == 0)
531          {
532            copy(sum_buf, sum_buf+count, recv_buf);
533          }
534          else
535          {
536            for(int i=0; i<count; i++)
537            {
538              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
539            }
540          }
541        }
542        else
543        {
544          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
545          exit(1);
546        }
547      }
548
549      delete[] static_cast<double*>(mpi_scan_recvbuf);
550      if(ep_rank_loc == 0)
551      {
552        delete[] static_cast<double*>(local_sum);
553      }
554    }
555
556    else if(datatype == MPI_FLOAT)
557    {
558      float* sum_buf = static_cast<float*>(mpi_scan_recvbuf);
559      float* recv_buf = static_cast<float*>(recvbuf);
560
561      if(mpi_rank != 0)
562      {
563        if(op == MPI_SUM)
564        {
565          if(ep_rank_loc == 0)
566          {
567            copy(sum_buf, sum_buf+count, recv_buf);
568          }
569          else
570          {
571            for(int i=0; i<count; i++)
572            {
573              recv_buf[i] += sum_buf[i];
574            }
575          }
576        }
577        else if (op == MPI_MAX)
578        {
579          if(ep_rank_loc == 0)
580          {
581            copy(sum_buf, sum_buf+count, recv_buf);
582          }
583          else
584          {
585            for(int i=0; i<count; i++)
586            {
587              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
588            }
589          }
590        }
591        else if(op == MPI_MIN)
592        {
593          if(ep_rank_loc == 0)
594          {
595            copy(sum_buf, sum_buf+count, recv_buf);
596          }
597          else
598          {
599            for(int i=0; i<count; i++)
600            {
601              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
602            }
603          }
604        }
605        else
606        {
607          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
608          exit(1);
609        }
610      }
611
612      delete[] static_cast<float*>(mpi_scan_recvbuf);
613      if(ep_rank_loc == 0)
614      {
615        delete[] static_cast<float*>(local_sum);
616      }
617    }
618
619    else if(datatype == MPI_INT)
620    {
621      int* sum_buf = static_cast<int*>(mpi_scan_recvbuf);
622      int* recv_buf = static_cast<int*>(recvbuf);
623
624      if(mpi_rank != 0)
625      {
626        if(op == MPI_SUM)
627        {
628          if(ep_rank_loc == 0)
629          {
630            copy(sum_buf, sum_buf+count, recv_buf);
631          }
632          else
633          {
634            for(int i=0; i<count; i++)
635            {
636              recv_buf[i] += sum_buf[i];
637            }
638          }
639        }
640        else if (op == MPI_MAX)
641        {
642          if(ep_rank_loc == 0)
643          {
644            copy(sum_buf, sum_buf+count, recv_buf);
645          }
646          else
647          {
648            for(int i=0; i<count; i++)
649            {
650              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
651            }
652          }
653        }
654        else if(op == MPI_MIN)
655        {
656          if(ep_rank_loc == 0)
657          {
658            copy(sum_buf, sum_buf+count, recv_buf);
659          }
660          else
661          {
662            for(int i=0; i<count; i++)
663            {
664              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
665            }
666          }
667        }
668        else
669        {
670          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
671          exit(1);
672        }
673      }
674
675      delete[] static_cast<int*>(mpi_scan_recvbuf);
676      if(ep_rank_loc == 0)
677      {
678        delete[] static_cast<int*>(local_sum);
679      }
680    }
681
682    else if(datatype == MPI_CHAR)
683    {
684      char* sum_buf = static_cast<char*>(mpi_scan_recvbuf);
685      char* recv_buf = static_cast<char*>(recvbuf);
686
687      if(mpi_rank != 0)
688      {
689        if(op == MPI_SUM)
690        {
691          if(ep_rank_loc == 0)
692          {
693            copy(sum_buf, sum_buf+count, recv_buf);
694          }
695          else
696          {
697            for(int i=0; i<count; i++)
698            {
699              recv_buf[i] += sum_buf[i];
700            }
701          }
702        }
703        else if (op == MPI_MAX)
704        {
705          if(ep_rank_loc == 0)
706          {
707            copy(sum_buf, sum_buf+count, recv_buf);
708          }
709          else
710          {
711            for(int i=0; i<count; i++)
712            {
713              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
714            }
715          }
716        }
717        else if(op == MPI_MIN)
718        {
719          if(ep_rank_loc == 0)
720          {
721            copy(sum_buf, sum_buf+count, recv_buf);
722          }
723          else
724          {
725            for(int i=0; i<count; i++)
726            {
727              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
728            }
729          }
730        }
731        else
732        {
733          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
734          exit(1);
735        }
736      }
737
738      delete[] static_cast<char*>(mpi_scan_recvbuf);
739      if(ep_rank_loc == 0)
740      {
741        delete[] static_cast<char*>(local_sum);
742      }
743    }
744
745    else if(datatype == MPI_LONG)
746    {
747      long* sum_buf = static_cast<long*>(mpi_scan_recvbuf);
748      long* recv_buf = static_cast<long*>(recvbuf);
749
750      if(mpi_rank != 0)
751      {
752        if(op == MPI_SUM)
753        {
754          if(ep_rank_loc == 0)
755          {
756            copy(sum_buf, sum_buf+count, recv_buf);
757          }
758          else
759          {
760            for(int i=0; i<count; i++)
761            {
762              recv_buf[i] += sum_buf[i];
763            }
764          }
765        }
766        else if (op == MPI_MAX)
767        {
768          if(ep_rank_loc == 0)
769          {
770            copy(sum_buf, sum_buf+count, recv_buf);
771          }
772          else
773          {
774            for(int i=0; i<count; i++)
775            {
776              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
777            }
778          }
779        }
780        else if(op == MPI_MIN)
781        {
782          if(ep_rank_loc == 0)
783          {
784            copy(sum_buf, sum_buf+count, recv_buf);
785          }
786          else
787          {
788            for(int i=0; i<count; i++)
789            {
790              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
791            }
792          }
793        }
794        else
795        {
796          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
797          exit(1);
798        }
799      }
800
801      delete[] static_cast<long*>(mpi_scan_recvbuf);
802      if(ep_rank_loc == 0)
803      {
804        delete[] static_cast<long*>(local_sum);
805      }
806    }
807
808    else if(datatype == MPI_UNSIGNED_LONG)
809    {
810      unsigned long* sum_buf = static_cast<unsigned long*>(mpi_scan_recvbuf);
811      unsigned long* recv_buf = static_cast<unsigned long*>(recvbuf);
812
813      if(mpi_rank != 0)
814      {
815        if(op == MPI_SUM)
816        {
817          if(ep_rank_loc == 0)
818          {
819            copy(sum_buf, sum_buf+count, recv_buf);
820          }
821          else
822          {
823            for(int i=0; i<count; i++)
824            {
825              recv_buf[i] += sum_buf[i];
826            }
827          }
828        }
829        else if (op == MPI_MAX)
830        {
831          if(ep_rank_loc == 0)
832          {
833            copy(sum_buf, sum_buf+count, recv_buf);
834          }
835          else
836          {
837            for(int i=0; i<count; i++)
838            {
839              recv_buf[i] = max(recv_buf[i], sum_buf[i]);
840            }
841          }
842        }
843        else if(op == MPI_MIN)
844        {
845          if(ep_rank_loc == 0)
846          {
847            copy(sum_buf, sum_buf+count, recv_buf);
848          }
849          else
850          {
851            for(int i=0; i<count; i++)
852            {
853              recv_buf[i] = min(recv_buf[i], sum_buf[i]);
854            }
855          }
856        }
857        else
858        {
859          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n");
860          exit(1);
861        }
862      }
863
864      delete[] static_cast<unsigned long*>(mpi_scan_recvbuf);
865      if(ep_rank_loc == 0)
866      {
867        delete[] static_cast<unsigned long*>(local_sum);
868      }
869    }
870
871
872  }
873
874
875
876}
Note: See TracBrowser for help on using the repository browser.