Ignore:
Timestamp:
10/06/17 13:56:33 (7 years ago)
Author:
yushan
Message:

EP update all

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_exscan.cpp

    r1289 r1295  
    99#include <mpi.h> 
    1010#include "ep_declaration.hpp" 
     11#include "ep_mpi.hpp" 
    1112 
    1213using namespace std; 
     
    2627  } 
    2728 
    28   int MPI_Exscan_local2(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    29   { 
    30     if(datatype == MPI_INT) 
    31     { 
    32       return MPI_Exscan_local_int(sendbuf, recvbuf, count, op, comm); 
    33     } 
    34     else if(datatype == MPI_FLOAT) 
    35     { 
    36       return MPI_Exscan_local_float(sendbuf, recvbuf, count, op, comm); 
    37     } 
    38     else if(datatype == MPI_DOUBLE) 
    39     { 
    40       return MPI_Exscan_local_double(sendbuf, recvbuf, count, op, comm); 
    41     } 
    42     else if(datatype == MPI_LONG) 
    43     { 
    44       return MPI_Exscan_local_long(sendbuf, recvbuf, count, op, comm); 
    45     } 
    46     else if(datatype == MPI_UNSIGNED_LONG) 
    47     { 
    48       return MPI_Exscan_local_ulong(sendbuf, recvbuf, count, op, comm); 
    49     } 
    50     else if(datatype == MPI_CHAR) 
    51     { 
    52       return MPI_Exscan_local_char(sendbuf, recvbuf, count, op, comm); 
    53     } 
    54     else 
    55     { 
    56       printf("MPI_Exscan Datatype not supported!\n"); 
    57       exit(0); 
    58     } 
    59   } 
    60  
    61  
    62  
    63  
    64   int MPI_Exscan_local_int(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    65   { 
    66     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    67     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    68  
    69     int *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_int; 
    70     int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
    71     int *recv_buf = static_cast<int*>(recvbuf); 
    72  
    73     for(int j=0; j<count; j+=BUFFER_SIZE) 
    74     { 
    75  
    76       if(my_rank == 0) 
    77       { 
    78  
    79         #pragma omp critical (write_to_buffer) 
    80         { 
    81           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    82           fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
    83           #pragma omp flush 
    84         } 
    85       } 
    86  
    87       MPI_Barrier_local(comm); 
    88  
    89       for(int k=1; k<num_ep; k++) 
    90       { 
    91         #pragma omp critical (write_to_buffer) 
    92         { 
    93           if(my_rank == k) 
    94           { 
    95             #pragma omp flush 
    96             if(op == MPI_SUM) 
    97             { 
    98               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    99               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<int>()); 
    100  
    101             } 
    102             else if(op == MPI_MAX) 
    103             { 
    104               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    105               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<int>); 
    106             } 
    107             else if(op == MPI_MIN) 
    108             { 
    109               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    110               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<int>); 
    111             } 
    112             else 
    113             { 
    114               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    115               exit(1); 
    116             } 
    117             #pragma omp flush 
    118           } 
    119         } 
    120  
    121         MPI_Barrier_local(comm); 
    122       } 
    123     } 
    124  
    125   } 
    126  
    127   int MPI_Exscan_local_float(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    128   { 
    129     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    130     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    131  
    132     float *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_float; 
    133     float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
    134     float *recv_buf = static_cast<float*>(recvbuf); 
    135  
    136     for(int j=0; j<count; j+=BUFFER_SIZE) 
    137     { 
    138       if(my_rank == 0) 
    139       { 
    140  
    141         #pragma omp critical (write_to_buffer) 
    142         { 
    143           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    144           fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
    145           #pragma omp flush 
    146         } 
    147       } 
    148  
    149       MPI_Barrier_local(comm); 
    150  
    151       for(int k=1; k<num_ep; k++) 
    152       { 
    153         #pragma omp critical (write_to_buffer) 
    154         { 
    155           if(my_rank == k) 
    156           { 
    157             #pragma omp flush 
    158             if(op == MPI_SUM) 
    159             { 
    160               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    161               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<float>()); 
    162             } 
    163             else if(op == MPI_MAX) 
    164             { 
    165               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    166               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<float>); 
    167             } 
    168             else if(op == MPI_MIN) 
    169             { 
    170               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    171               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<float>); 
    172             } 
    173             else 
    174             { 
    175               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    176               exit(1); 
    177             } 
    178             #pragma omp flush 
    179           } 
    180         } 
    181  
    182         MPI_Barrier_local(comm); 
    183       } 
    184     } 
    185   } 
    186  
    187   int MPI_Exscan_local_double(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    188   { 
    189  
    190     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    191     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    192  
    193     double *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_double; 
    194     double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
    195     double *recv_buf = static_cast<double*>(recvbuf); 
    196  
    197     for(int j=0; j<count; j+=BUFFER_SIZE) 
    198     { 
    199       if(my_rank == 0) 
    200       { 
    201  
    202         #pragma omp critical (write_to_buffer) 
    203         { 
    204           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    205           fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
    206           #pragma omp flush 
    207         } 
    208       } 
    209  
    210       MPI_Barrier_local(comm); 
    211  
    212       for(int k=1; k<num_ep; k++) 
    213       { 
    214         #pragma omp critical (write_to_buffer) 
    215         { 
    216           if(my_rank == k) 
    217           { 
    218             #pragma omp flush 
    219             if(op == MPI_SUM) 
    220             { 
    221               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    222               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<double>()); 
    223             } 
    224             else if(op == MPI_MAX) 
    225             { 
    226               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    227               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<double>); 
    228             } 
    229             else if(op == MPI_MIN) 
    230             { 
    231               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    232               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<double>); 
    233             } 
    234             else 
    235             { 
    236               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    237               exit(1); 
    238             } 
    239             #pragma omp flush 
    240           } 
    241         } 
    242  
    243         MPI_Barrier_local(comm); 
    244       } 
    245     } 
    246   } 
    247  
    248   int MPI_Exscan_local_long(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    249   { 
    250  
    251     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    252     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    253  
    254     long *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_long; 
    255     long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
    256     long *recv_buf = static_cast<long*>(recvbuf); 
    257  
    258     for(int j=0; j<count; j+=BUFFER_SIZE) 
    259     { 
    260       if(my_rank == 0) 
    261       { 
    262  
    263         #pragma omp critical (write_to_buffer) 
    264         { 
    265           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    266           fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
    267           #pragma omp flush 
    268         } 
    269       } 
    270  
    271       MPI_Barrier_local(comm); 
    272  
    273       for(int k=1; k<num_ep; k++) 
    274       { 
    275         #pragma omp critical (write_to_buffer) 
    276         { 
    277           if(my_rank == k) 
    278           { 
    279             #pragma omp flush 
    280             if(op == MPI_SUM) 
    281             { 
    282               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    283               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<long>()); 
    284             } 
    285             else if(op == MPI_MAX) 
    286             { 
    287               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    288               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<long>); 
    289             } 
    290             else if(op == MPI_MIN) 
    291             { 
    292               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    293               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<long>); 
    294             } 
    295             else 
    296             { 
    297               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    298               exit(1); 
    299             } 
    300             #pragma omp flush 
    301           } 
    302         } 
    303  
    304         MPI_Barrier_local(comm); 
    305       } 
    306     } 
    307   } 
    308  
    309   int MPI_Exscan_local_ulong(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    310   { 
    311  
    312     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    313     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    314  
    315     unsigned long *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_ulong; 
    316     unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
    317     unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 
    318  
    319     for(int j=0; j<count; j+=BUFFER_SIZE) 
    320     { 
    321       if(my_rank == 0) 
    322       { 
    323  
    324         #pragma omp critical (write_to_buffer) 
    325         { 
    326           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    327           fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
    328           #pragma omp flush 
    329         } 
    330       } 
    331  
    332       MPI_Barrier_local(comm); 
    333  
    334       for(int k=1; k<num_ep; k++) 
    335       { 
    336         #pragma omp critical (write_to_buffer) 
    337         { 
    338           if(my_rank == k) 
    339           { 
    340             #pragma omp flush 
    341             if(op == MPI_SUM) 
    342             { 
    343               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    344               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<unsigned long>()); 
    345             } 
    346             else if(op == MPI_MAX) 
    347             { 
    348               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    349               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<unsigned long>); 
    350             } 
    351             else if(op == MPI_MIN) 
    352             { 
    353               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    354               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<unsigned long>); 
    355             } 
    356             else 
    357             { 
    358               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    359               exit(1); 
    360             } 
    361             #pragma omp flush 
    362           } 
    363         } 
    364  
    365         MPI_Barrier_local(comm); 
    366       } 
    367     } 
    368   } 
    369  
    370   int MPI_Exscan_local_char(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    371   { 
    372  
    373     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    374     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    375  
    376     char *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_char; 
    377     char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
    378     char *recv_buf = static_cast<char*>(recvbuf); 
    379  
    380     for(int j=0; j<count; j+=BUFFER_SIZE) 
    381     { 
    382       if(my_rank == 0) 
    383       { 
    384  
    385         #pragma omp critical (write_to_buffer) 
    386         { 
    387           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    388           fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
    389           #pragma omp flush 
    390         } 
    391       } 
    392  
    393       MPI_Barrier_local(comm); 
    394  
    395       for(int k=1; k<num_ep; k++) 
    396       { 
    397         #pragma omp critical (write_to_buffer) 
    398         { 
    399           if(my_rank == k) 
    400           { 
    401             #pragma omp flush 
    402             if(op == MPI_SUM) 
    403             { 
    404               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    405               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<char>()); 
    406             } 
    407             else if(op == MPI_MAX) 
    408             { 
    409               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    410               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<char>); 
    411             } 
    412             else if(op == MPI_MIN) 
    413             { 
    414               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    415               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<char>); 
    416             } 
    417             else 
    418             { 
    419               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    420               exit(1); 
    421             } 
    422             #pragma omp flush 
    423           } 
    424         } 
    425  
    426         MPI_Barrier_local(comm); 
    427       } 
    428     } 
    429   } 
    430  
     29  template<typename T> 
     30  void reduce_max(const T * buffer, T* recvbuf, int count) 
     31  { 
     32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>); 
     33  } 
     34 
     35  template<typename T> 
     36  void reduce_min(const T * buffer, T* recvbuf, int count) 
     37  { 
     38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>); 
     39  } 
     40 
     41  template<typename T> 
     42  void reduce_sum(const T * buffer, T* recvbuf, int count) 
     43  { 
     44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>()); 
     45  } 
     46 
     47 
     48  int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
     49  { 
     50    valid_op(op); 
     51 
     52    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     53    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     54    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     55     
     56 
     57    ::MPI_Aint datasize, lb; 
     58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
     59 
     60    if(ep_rank_loc == 0 && mpi_rank != 0) 
     61    { 
     62      comm.my_buffer->void_buffer[0] = recvbuf; 
     63    } 
     64    if(ep_rank_loc == 0 && mpi_rank == 0) 
     65    { 
     66      comm.my_buffer->void_buffer[0] = const_cast<void*>(sendbuf);   
     67    }  
     68       
     69 
     70    MPI_Barrier_local(comm); 
     71 
     72    memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize*count); 
     73 
     74    MPI_Barrier_local(comm); 
     75 
     76    comm.my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf);   
     77     
     78    MPI_Barrier_local(comm); 
     79 
     80    if(op == MPI_SUM) 
     81    { 
     82      if(datatype == MPI_INT ) 
     83      { 
     84        assert(datasize == sizeof(int)); 
     85        for(int i=0; i<ep_rank_loc; i++) 
     86          reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
     87      } 
     88      
     89      else if(datatype == MPI_FLOAT ) 
     90      { 
     91        assert(datasize == sizeof(float)); 
     92        for(int i=0; i<ep_rank_loc; i++) 
     93          reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
     94      } 
     95       
     96 
     97      else if(datatype == MPI_DOUBLE ) 
     98      { 
     99        assert(datasize == sizeof(double)); 
     100        for(int i=0; i<ep_rank_loc; i++) 
     101          reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     102      } 
     103 
     104      else if(datatype == MPI_CHAR ) 
     105      { 
     106        assert(datasize == sizeof(char)); 
     107        for(int i=0; i<ep_rank_loc; i++) 
     108          reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     109      } 
     110 
     111      else if(datatype == MPI_LONG ) 
     112      { 
     113        assert(datasize == sizeof(long)); 
     114        for(int i=0; i<ep_rank_loc; i++) 
     115          reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     116      } 
     117 
     118      else if(datatype == MPI_UNSIGNED_LONG ) 
     119      { 
     120        assert(datasize == sizeof(unsigned long)); 
     121        for(int i=0; i<ep_rank_loc; i++) 
     122          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
     123      } 
     124 
     125      else printf("datatype Error\n"); 
     126 
     127       
     128    } 
     129 
     130    else if(op == MPI_MAX) 
     131    { 
     132      if(datatype == MPI_INT ) 
     133      { 
     134        assert(datasize == sizeof(int)); 
     135        for(int i=0; i<ep_rank_loc; i++) 
     136          reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
     137      } 
     138 
     139      else if(datatype == MPI_FLOAT ) 
     140      { 
     141        assert(datasize == sizeof(float)); 
     142        for(int i=0; i<ep_rank_loc; i++) 
     143          reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
     144      } 
     145 
     146      else if(datatype == MPI_DOUBLE ) 
     147      { 
     148        assert(datasize == sizeof(double)); 
     149        for(int i=0; i<ep_rank_loc; i++) 
     150          reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     151      } 
     152 
     153      else if(datatype == MPI_CHAR ) 
     154      { 
     155        assert(datasize == sizeof(char)); 
     156        for(int i=0; i<ep_rank_loc; i++) 
     157          reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     158      } 
     159 
     160      else if(datatype == MPI_LONG ) 
     161      { 
     162        assert(datasize == sizeof(long)); 
     163        for(int i=0; i<ep_rank_loc; i++) 
     164          reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     165      } 
     166 
     167      else if(datatype == MPI_UNSIGNED_LONG ) 
     168      { 
     169        assert(datasize == sizeof(unsigned long)); 
     170        for(int i=0; i<ep_rank_loc; i++) 
     171          reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
     172      } 
     173      
     174      else printf("datatype Error\n"); 
     175    } 
     176 
     177    else //if(op == MPI_MIN) 
     178    { 
     179      if(datatype == MPI_INT ) 
     180      { 
     181        assert(datasize == sizeof(int)); 
     182        for(int i=0; i<ep_rank_loc; i++) 
     183          reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
     184      } 
     185 
     186      else if(datatype == MPI_FLOAT ) 
     187      { 
     188        assert(datasize == sizeof(float)); 
     189        for(int i=0; i<ep_rank_loc; i++) 
     190          reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
     191      } 
     192 
     193      else if(datatype == MPI_DOUBLE ) 
     194      { 
     195        assert(datasize == sizeof(double)); 
     196        for(int i=0; i<ep_rank_loc; i++) 
     197          reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     198      } 
     199 
     200      else if(datatype == MPI_CHAR ) 
     201      { 
     202        assert(datasize == sizeof(char)); 
     203        for(int i=0; i<ep_rank_loc; i++) 
     204          reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     205      } 
     206 
     207      else if(datatype == MPI_LONG ) 
     208      { 
     209        assert(datasize == sizeof(long)); 
     210        for(int i=0; i<ep_rank_loc; i++) 
     211          reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     212      } 
     213 
     214      else if(datatype == MPI_UNSIGNED_LONG ) 
     215      { 
     216        assert(datasize == sizeof(unsigned long)); 
     217        for(int i=0; i<ep_rank_loc; i++) 
     218          reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
     219      } 
     220 
     221      else printf("datatype Error\n"); 
     222    } 
     223 
     224    MPI_Barrier_local(comm); 
     225 
     226  } 
    431227 
    432228  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    433229  { 
    434  
    435230    if(!comm.is_ep) 
    436231    { 
    437       ::MPI_Exscan(const_cast<void*>(sendbuf), recvbuf, count, static_cast< ::MPI_Datatype>(datatype), 
    438                    static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    439       return 0; 
    440     } 
    441     if(!comm.mpi_comm) return 0; 
    442  
    443     int ep_rank, ep_rank_loc, mpi_rank; 
    444     int ep_size, num_ep, mpi_size; 
    445  
    446     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    447     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    448     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    449     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    450     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    451     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    452  
    453  
     232      return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm)); 
     233    } 
     234     
     235    valid_type(datatype); 
     236 
     237    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     238    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     239    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     240    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     241    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     242    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    454243 
    455244    ::MPI_Aint datasize, lb; 
    456      
    457     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
    458  
    459     void* local_scan_recvbuf; 
    460     local_scan_recvbuf = new void*[datasize * count]; 
    461  
    462  
    463     // local scan 
    464     MPI_Exscan_local2(sendbuf, recvbuf, count, datatype, op, comm); 
    465  
    466 //     MPI_scan 
    467     void* local_sum; 
    468     void* mpi_scan_recvbuf; 
    469  
    470  
    471     mpi_scan_recvbuf = new void*[datasize*count]; 
     245    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
     246     
     247    void* tmp_sendbuf; 
     248    tmp_sendbuf = new void*[datasize * count]; 
     249 
     250    int my_src = 0; 
     251    int my_dst = ep_rank; 
     252 
     253    std::vector<int> my_map(mpi_size, 0); 
     254 
     255    for(int i=0; i<comm.rank_map->size(); i++) my_map[comm.rank_map->at(i).second]++; 
     256 
     257    for(int i=0; i<mpi_rank; i++) my_src += my_map[i]; 
     258    my_src += ep_rank_loc; 
     259 
     260      
     261    for(int i=0; i<mpi_size; i++) 
     262    { 
     263      if(my_dst < my_map[i]) 
     264      { 
     265        my_dst = get_ep_rank(comm, my_dst, i);  
     266        break; 
     267      } 
     268      else 
     269        my_dst -= my_map[i]; 
     270    } 
     271 
     272    if(ep_rank != my_dst)  
     273    { 
     274      MPI_Request request[2]; 
     275      MPI_Status status[2]; 
     276 
     277      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]); 
     278     
     279      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]); 
     280     
     281      MPI_Waitall(2, request, status); 
     282    } 
     283 
     284    else memcpy(tmp_sendbuf, sendbuf, datasize*count); 
     285     
     286 
     287    void* tmp_recvbuf; 
     288    tmp_recvbuf = new void*[datasize * count];     
     289 
     290    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm); 
    472291 
    473292    if(ep_rank_loc == 0) 
    474     { 
    475       local_sum = new void*[datasize*count]; 
    476     } 
    477  
    478  
    479     MPI_Reduce_local2(sendbuf, local_sum, count, datatype, op, comm); 
    480  
    481     if(ep_rank_loc == 0) 
    482     { 
    483       ::MPI_Exscan(local_sum, mpi_scan_recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    484     } 
    485  
    486  
    487     if(mpi_rank > 0) 
    488     { 
    489       MPI_Bcast_local2(mpi_scan_recvbuf, count, datatype, comm); 
    490     } 
    491  
    492  
    493     if(datatype == MPI_DOUBLE) 
    494     { 
    495       double* sum_buf = static_cast<double*>(mpi_scan_recvbuf); 
    496       double* recv_buf = static_cast<double*>(recvbuf); 
    497  
    498       if(mpi_rank != 0) 
    499       { 
    500         if(op == MPI_SUM) 
    501         { 
    502           if(ep_rank_loc == 0) 
    503           { 
    504             copy(sum_buf, sum_buf+count, recv_buf); 
    505           } 
    506           else 
    507           { 
    508             for(int i=0; i<count; i++) 
    509             { 
    510               recv_buf[i] += sum_buf[i]; 
    511             } 
    512           } 
    513         } 
    514         else if (op == MPI_MAX) 
    515         { 
    516           if(ep_rank_loc == 0) 
    517           { 
    518             copy(sum_buf, sum_buf+count, recv_buf); 
    519           } 
    520           else 
    521           { 
    522             for(int i=0; i<count; i++) 
    523             { 
    524               recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    525             } 
    526           } 
    527         } 
    528         else if(op == MPI_MIN) 
    529         { 
    530           if(ep_rank_loc == 0) 
    531           { 
    532             copy(sum_buf, sum_buf+count, recv_buf); 
    533           } 
    534           else 
    535           { 
    536             for(int i=0; i<count; i++) 
    537             { 
    538               recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    539             } 
    540           } 
    541         } 
    542         else 
    543         { 
    544           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    545           exit(1); 
    546         } 
    547       } 
    548  
    549       delete[] static_cast<double*>(mpi_scan_recvbuf); 
    550       if(ep_rank_loc == 0) 
    551       { 
    552         delete[] static_cast<double*>(local_sum); 
    553       } 
    554     } 
    555  
    556     else if(datatype == MPI_FLOAT) 
    557     { 
    558       float* sum_buf = static_cast<float*>(mpi_scan_recvbuf); 
    559       float* recv_buf = static_cast<float*>(recvbuf); 
    560  
    561       if(mpi_rank != 0) 
    562       { 
    563         if(op == MPI_SUM) 
    564         { 
    565           if(ep_rank_loc == 0) 
    566           { 
    567             copy(sum_buf, sum_buf+count, recv_buf); 
    568           } 
    569           else 
    570           { 
    571             for(int i=0; i<count; i++) 
    572             { 
    573               recv_buf[i] += sum_buf[i]; 
    574             } 
    575           } 
    576         } 
    577         else if (op == MPI_MAX) 
    578         { 
    579           if(ep_rank_loc == 0) 
    580           { 
    581             copy(sum_buf, sum_buf+count, recv_buf); 
    582           } 
    583           else 
    584           { 
    585             for(int i=0; i<count; i++) 
    586             { 
    587               recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    588             } 
    589           } 
    590         } 
    591         else if(op == MPI_MIN) 
    592         { 
    593           if(ep_rank_loc == 0) 
    594           { 
    595             copy(sum_buf, sum_buf+count, recv_buf); 
    596           } 
    597           else 
    598           { 
    599             for(int i=0; i<count; i++) 
    600             { 
    601               recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    602             } 
    603           } 
    604         } 
    605         else 
    606         { 
    607           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    608           exit(1); 
    609         } 
    610       } 
    611  
    612       delete[] static_cast<float*>(mpi_scan_recvbuf); 
    613       if(ep_rank_loc == 0) 
    614       { 
    615         delete[] static_cast<float*>(local_sum); 
    616       } 
    617     } 
    618  
    619     else if(datatype == MPI_INT) 
    620     { 
    621       int* sum_buf = static_cast<int*>(mpi_scan_recvbuf); 
    622       int* recv_buf = static_cast<int*>(recvbuf); 
    623  
    624       if(mpi_rank != 0) 
    625       { 
    626         if(op == MPI_SUM) 
    627         { 
    628           if(ep_rank_loc == 0) 
    629           { 
    630             copy(sum_buf, sum_buf+count, recv_buf); 
    631           } 
    632           else 
    633           { 
    634             for(int i=0; i<count; i++) 
    635             { 
    636               recv_buf[i] += sum_buf[i]; 
    637             } 
    638           } 
    639         } 
    640         else if (op == MPI_MAX) 
    641         { 
    642           if(ep_rank_loc == 0) 
    643           { 
    644             copy(sum_buf, sum_buf+count, recv_buf); 
    645           } 
    646           else 
    647           { 
    648             for(int i=0; i<count; i++) 
    649             { 
    650               recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    651             } 
    652           } 
    653         } 
    654         else if(op == MPI_MIN) 
    655         { 
    656           if(ep_rank_loc == 0) 
    657           { 
    658             copy(sum_buf, sum_buf+count, recv_buf); 
    659           } 
    660           else 
    661           { 
    662             for(int i=0; i<count; i++) 
    663             { 
    664               recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    665             } 
    666           } 
    667         } 
    668         else 
    669         { 
    670           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    671           exit(1); 
    672         } 
    673       } 
    674  
    675       delete[] static_cast<int*>(mpi_scan_recvbuf); 
    676       if(ep_rank_loc == 0) 
    677       { 
    678         delete[] static_cast<int*>(local_sum); 
    679       } 
    680     } 
    681  
    682     else if(datatype == MPI_CHAR) 
    683     { 
    684       char* sum_buf = static_cast<char*>(mpi_scan_recvbuf); 
    685       char* recv_buf = static_cast<char*>(recvbuf); 
    686  
    687       if(mpi_rank != 0) 
    688       { 
    689         if(op == MPI_SUM) 
    690         { 
    691           if(ep_rank_loc == 0) 
    692           { 
    693             copy(sum_buf, sum_buf+count, recv_buf); 
    694           } 
    695           else 
    696           { 
    697             for(int i=0; i<count; i++) 
    698             { 
    699               recv_buf[i] += sum_buf[i]; 
    700             } 
    701           } 
    702         } 
    703         else if (op == MPI_MAX) 
    704         { 
    705           if(ep_rank_loc == 0) 
    706           { 
    707             copy(sum_buf, sum_buf+count, recv_buf); 
    708           } 
    709           else 
    710           { 
    711             for(int i=0; i<count; i++) 
    712             { 
    713               recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    714             } 
    715           } 
    716         } 
    717         else if(op == MPI_MIN) 
    718         { 
    719           if(ep_rank_loc == 0) 
    720           { 
    721             copy(sum_buf, sum_buf+count, recv_buf); 
    722           } 
    723           else 
    724           { 
    725             for(int i=0; i<count; i++) 
    726             { 
    727               recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    728             } 
    729           } 
    730         } 
    731         else 
    732         { 
    733           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    734           exit(1); 
    735         } 
    736       } 
    737  
    738       delete[] static_cast<char*>(mpi_scan_recvbuf); 
    739       if(ep_rank_loc == 0) 
    740       { 
    741         delete[] static_cast<char*>(local_sum); 
    742       } 
    743     } 
    744  
    745     else if(datatype == MPI_LONG) 
    746     { 
    747       long* sum_buf = static_cast<long*>(mpi_scan_recvbuf); 
    748       long* recv_buf = static_cast<long*>(recvbuf); 
    749  
    750       if(mpi_rank != 0) 
    751       { 
    752         if(op == MPI_SUM) 
    753         { 
    754           if(ep_rank_loc == 0) 
    755           { 
    756             copy(sum_buf, sum_buf+count, recv_buf); 
    757           } 
    758           else 
    759           { 
    760             for(int i=0; i<count; i++) 
    761             { 
    762               recv_buf[i] += sum_buf[i]; 
    763             } 
    764           } 
    765         } 
    766         else if (op == MPI_MAX) 
    767         { 
    768           if(ep_rank_loc == 0) 
    769           { 
    770             copy(sum_buf, sum_buf+count, recv_buf); 
    771           } 
    772           else 
    773           { 
    774             for(int i=0; i<count; i++) 
    775             { 
    776               recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    777             } 
    778           } 
    779         } 
    780         else if(op == MPI_MIN) 
    781         { 
    782           if(ep_rank_loc == 0) 
    783           { 
    784             copy(sum_buf, sum_buf+count, recv_buf); 
    785           } 
    786           else 
    787           { 
    788             for(int i=0; i<count; i++) 
    789             { 
    790               recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    791             } 
    792           } 
    793         } 
    794         else 
    795         { 
    796           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    797           exit(1); 
    798         } 
    799       } 
    800  
    801       delete[] static_cast<long*>(mpi_scan_recvbuf); 
    802       if(ep_rank_loc == 0) 
    803       { 
    804         delete[] static_cast<long*>(local_sum); 
    805       } 
    806     } 
    807  
    808     else if(datatype == MPI_UNSIGNED_LONG) 
    809     { 
    810       unsigned long* sum_buf = static_cast<unsigned long*>(mpi_scan_recvbuf); 
    811       unsigned long* recv_buf = static_cast<unsigned long*>(recvbuf); 
    812  
    813       if(mpi_rank != 0) 
    814       { 
    815         if(op == MPI_SUM) 
    816         { 
    817           if(ep_rank_loc == 0) 
    818           { 
    819             copy(sum_buf, sum_buf+count, recv_buf); 
    820           } 
    821           else 
    822           { 
    823             for(int i=0; i<count; i++) 
    824             { 
    825               recv_buf[i] += sum_buf[i]; 
    826             } 
    827           } 
    828         } 
    829         else if (op == MPI_MAX) 
    830         { 
    831           if(ep_rank_loc == 0) 
    832           { 
    833             copy(sum_buf, sum_buf+count, recv_buf); 
    834           } 
    835           else 
    836           { 
    837             for(int i=0; i<count; i++) 
    838             { 
    839               recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    840             } 
    841           } 
    842         } 
    843         else if(op == MPI_MIN) 
    844         { 
    845           if(ep_rank_loc == 0) 
    846           { 
    847             copy(sum_buf, sum_buf+count, recv_buf); 
    848           } 
    849           else 
    850           { 
    851             for(int i=0; i<count; i++) 
    852             { 
    853               recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    854             } 
    855           } 
    856         } 
    857         else 
    858         { 
    859           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    860           exit(1); 
    861         } 
    862       } 
    863  
    864       delete[] static_cast<unsigned long*>(mpi_scan_recvbuf); 
    865       if(ep_rank_loc == 0) 
    866       { 
    867         delete[] static_cast<unsigned long*>(local_sum); 
    868       } 
    869     } 
    870  
    871  
    872   } 
    873  
    874  
     293      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm)); 
     294 
     295    // printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]); 
     296     
     297    MPI_Exscan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm); 
     298 
     299     // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]); 
     300 
     301 
     302 
     303    if(ep_rank != my_src)  
     304    { 
     305      MPI_Request request[2]; 
     306      MPI_Status status[2]; 
     307 
     308      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]); 
     309     
     310      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]); 
     311     
     312      MPI_Waitall(2, request, status); 
     313    } 
     314 
     315    else memcpy(recvbuf, tmp_recvbuf, datasize*count); 
     316     
     317 
     318 
     319 
     320    delete[] tmp_sendbuf; 
     321    delete[] tmp_recvbuf; 
     322 
     323  } 
    875324 
    876325} 
Note: See TracChangeset for help on using the changeset viewer.