Ignore:
Timestamp:
10/04/17 11:45:14 (7 years ago)
Author:
yushan
Message:

EP updated

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scan.cpp

    r1134 r1287  
    99#include <mpi.h> 
    1010#include "ep_declaration.hpp" 
     11#include "ep_mpi.hpp" 
    1112 
    1213using namespace std; 
     
    2627  } 
    2728 
     29  template<typename T> 
     30  void reduce_max(const T * buffer, T* recvbuf, int count) 
     31  { 
     32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>); 
     33  } 
     34 
     35  template<typename T> 
     36  void reduce_min(const T * buffer, T* recvbuf, int count) 
     37  { 
     38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>); 
     39  } 
     40 
     41  template<typename T> 
     42  void reduce_sum(const T * buffer, T* recvbuf, int count) 
     43  { 
     44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>()); 
     45  } 
     46 
    2847 
    2948  int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    3049  { 
    31     if(datatype == MPI_INT) 
    32     { 
    33       return MPI_Scan_local_int(sendbuf, recvbuf, count, op, comm); 
    34     } 
    35     else if(datatype == MPI_FLOAT) 
    36     { 
    37       return MPI_Scan_local_float(sendbuf, recvbuf, count, op, comm); 
    38     } 
    39     else if(datatype == MPI_DOUBLE) 
    40     { 
    41       return MPI_Scan_local_double(sendbuf, recvbuf, count, op, comm); 
    42     } 
    43     else if(datatype == MPI_LONG) 
    44     { 
    45       return MPI_Scan_local_long(sendbuf, recvbuf, count, op, comm); 
    46     } 
    47     else if(datatype == MPI_UNSIGNED_LONG) 
    48     { 
    49       return MPI_Scan_local_ulong(sendbuf, recvbuf, count, op, comm); 
    50     } 
    51     else if(datatype == MPI_CHAR) 
    52     { 
    53       return MPI_Scan_local_char(sendbuf, recvbuf, count, op, comm); 
     50    valid_op(op); 
     51 
     52    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     53    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     54    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     55     
     56 
     57    ::MPI_Aint datasize, lb; 
     58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
     59 
     60    if(ep_rank_loc == 0 && mpi_rank != 0) 
     61    { 
     62      if(op == MPI_SUM) 
     63      { 
     64        if(datatype == MPI_INT && datasize == sizeof(int)) 
     65          reduce_sum<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);     
     66           
     67        else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
     68          reduce_sum<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);     
     69              
     70        else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
     71          reduce_sum<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count); 
     72       
     73        else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
     74          reduce_sum<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count); 
     75       
     76        else if(datatype == MPI_LONG && datasize == sizeof(long)) 
     77          reduce_sum<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count); 
     78             
     79        else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
     80          reduce_sum<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);     
     81             
     82        else printf("datatype Error\n"); 
     83      } 
     84 
     85      else if(op == MPI_MAX) 
     86      { 
     87        if(datatype == MPI_INT && datasize == sizeof(int)) 
     88          reduce_max<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);     
     89           
     90        else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
     91          reduce_max<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);     
     92              
     93        else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
     94          reduce_max<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count); 
     95       
     96        else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
     97          reduce_max<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count); 
     98       
     99        else if(datatype == MPI_LONG && datasize == sizeof(long)) 
     100          reduce_max<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count); 
     101             
     102        else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
     103          reduce_max<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);     
     104             
     105        else printf("datatype Error\n"); 
     106      } 
     107 
     108      else //(op == MPI_MIN) 
     109      { 
     110        if(datatype == MPI_INT && datasize == sizeof(int)) 
     111          reduce_min<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);     
     112           
     113        else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
     114          reduce_min<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);     
     115              
     116        else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
     117          reduce_min<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count); 
     118       
     119        else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
     120          reduce_min<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count); 
     121       
     122        else if(datatype == MPI_LONG && datasize == sizeof(long)) 
     123          reduce_min<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count); 
     124             
     125        else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
     126          reduce_min<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);     
     127             
     128        else printf("datatype Error\n"); 
     129      } 
     130 
     131      comm.my_buffer->void_buffer[0] = recvbuf; 
    54132    } 
    55133    else 
    56134    { 
    57       printf("MPI_Scan Datatype not supported!\n"); 
    58       exit(0); 
    59     } 
    60  
    61   } 
    62  
    63  
    64  
    65  
    66   int MPI_Scan_local_int(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    67   { 
    68     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    69     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    70  
    71     int *buffer = comm.my_buffer->buf_int; 
    72     int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
    73     int *recv_buf = static_cast<int*>(recvbuf); 
    74  
    75     for(int j=0; j<count; j+=BUFFER_SIZE) 
    76     { 
    77       if(my_rank == 0) 
    78       { 
    79  
    80         #pragma omp critical (write_to_buffer) 
    81         { 
    82           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    83           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    84           #pragma omp flush 
    85         } 
    86       } 
    87  
    88       MPI_Barrier_local(comm); 
    89  
    90       for(int k=1; k<num_ep; k++) 
    91       { 
    92         #pragma omp critical (write_to_buffer) 
    93         { 
    94           if(my_rank == k) 
    95           { 
    96             #pragma omp flush 
    97             if(op == MPI_SUM) 
    98             { 
    99               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<int>()); 
    100               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    101             } 
    102             else if(op == MPI_MAX) 
    103             { 
    104               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<int>); 
    105               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    106             } 
    107             else if(op == MPI_MIN) 
    108             { 
    109               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<int>); 
    110               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    111             } 
    112             else 
    113             { 
    114               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    115               exit(1); 
    116             } 
    117             #pragma omp flush 
    118           } 
    119         } 
    120  
    121         MPI_Barrier_local(comm); 
    122       } 
    123     } 
    124  
    125   } 
    126  
    127   int MPI_Scan_local_float(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    128   { 
    129     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    130     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    131  
    132     float *buffer = comm.my_buffer->buf_float; 
    133     float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
    134     float *recv_buf = static_cast<float*>(recvbuf); 
    135  
    136     for(int j=0; j<count; j+=BUFFER_SIZE) 
    137     { 
    138       if(my_rank == 0) 
    139       { 
    140  
    141         #pragma omp critical (write_to_buffer) 
    142         { 
    143           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    144           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    145           #pragma omp flush 
    146         } 
    147       } 
    148  
    149       MPI_Barrier_local(comm); 
    150  
    151       for(int k=1; k<num_ep; k++) 
    152       { 
    153         #pragma omp critical (write_to_buffer) 
    154         { 
    155           if(my_rank == k) 
    156           { 
    157             #pragma omp flush 
    158             if(op == MPI_SUM) 
    159             { 
    160               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<float>()); 
    161               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    162             } 
    163             else if(op == MPI_MAX) 
    164             { 
    165               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<float>); 
    166               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    167  
    168             } 
    169             else if(op == MPI_MIN) 
    170             { 
    171               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<float>); 
    172               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    173  
    174             } 
    175             else 
    176             { 
    177               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    178               exit(1); 
    179             } 
    180             #pragma omp flush 
    181           } 
    182         } 
    183  
    184         MPI_Barrier_local(comm); 
    185       } 
    186     } 
    187   } 
    188  
    189   int MPI_Scan_local_double(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    190   { 
    191     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    192     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    193  
    194     double *buffer = comm.my_buffer->buf_double; 
    195     double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
    196     double *recv_buf = static_cast<double*>(recvbuf); 
    197  
    198     for(int j=0; j<count; j+=BUFFER_SIZE) 
    199     { 
    200       if(my_rank == 0) 
    201       { 
    202  
    203         #pragma omp critical (write_to_buffer) 
    204         { 
    205           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    206           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    207           #pragma omp flush 
    208         } 
    209       } 
    210  
    211       MPI_Barrier_local(comm); 
    212  
    213       for(int k=1; k<num_ep; k++) 
    214       { 
    215         #pragma omp critical (write_to_buffer) 
    216         { 
    217           if(my_rank == k) 
    218           { 
    219             #pragma omp flush 
    220             if(op == MPI_SUM) 
    221             { 
    222               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<double>()); 
    223               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    224             } 
    225             else if(op == MPI_MAX) 
    226             { 
    227               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<double>); 
    228               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    229             } 
    230             else if(op == MPI_MIN) 
    231             { 
    232               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<double>); 
    233               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    234             } 
    235             else 
    236             { 
    237               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    238               exit(1); 
    239             } 
    240             #pragma omp flush 
    241           } 
    242         } 
    243  
    244         MPI_Barrier_local(comm); 
    245       } 
    246     } 
    247   } 
    248  
    249   int MPI_Scan_local_long(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    250   { 
    251     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    252     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    253  
    254     long *buffer = comm.my_buffer->buf_long; 
    255     long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
    256     long *recv_buf = static_cast<long*>(recvbuf); 
    257  
    258     for(int j=0; j<count; j+=BUFFER_SIZE) 
    259     { 
    260       if(my_rank == 0) 
    261       { 
    262  
    263         #pragma omp critical (write_to_buffer) 
    264         { 
    265           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    266           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    267           #pragma omp flush 
    268         } 
    269       } 
    270  
    271       MPI_Barrier_local(comm); 
    272  
    273       for(int k=1; k<num_ep; k++) 
    274       { 
    275         #pragma omp critical (write_to_buffer) 
    276         { 
    277           if(my_rank == k) 
    278           { 
    279             #pragma omp flush 
    280             if(op == MPI_SUM) 
    281             { 
    282               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<long>()); 
    283               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    284             } 
    285             else if(op == MPI_MAX) 
    286             { 
    287               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<long>); 
    288               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    289             } 
    290             else if(op == MPI_MIN) 
    291             { 
    292               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<long>); 
    293               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    294             } 
    295             else 
    296             { 
    297               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    298               exit(1); 
    299             } 
    300             #pragma omp flush 
    301           } 
    302         } 
    303  
    304         MPI_Barrier_local(comm); 
    305       } 
    306     } 
    307   } 
    308  
    309   int MPI_Scan_local_ulong(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    310   { 
    311     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    312     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    313  
    314     unsigned long *buffer = comm.my_buffer->buf_ulong; 
    315     unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
    316     unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 
    317  
    318     for(int j=0; j<count; j+=BUFFER_SIZE) 
    319     { 
    320       if(my_rank == 0) 
    321       { 
    322  
    323         #pragma omp critical (write_to_buffer) 
    324         { 
    325           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    326           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    327           #pragma omp flush 
    328         } 
    329       } 
    330  
    331       MPI_Barrier_local(comm); 
    332  
    333       for(int k=1; k<num_ep; k++) 
    334       { 
    335         #pragma omp critical (write_to_buffer) 
    336         { 
    337           if(my_rank == k) 
    338           { 
    339             #pragma omp flush 
    340             if(op == MPI_SUM) 
    341             { 
    342               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<unsigned long>()); 
    343               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    344             } 
    345             else if(op == MPI_MAX) 
    346             { 
    347               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<unsigned long>); 
    348               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    349             } 
    350             else if(op == MPI_MIN) 
    351             { 
    352               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<unsigned long>); 
    353               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    354             } 
    355             else 
    356             { 
    357               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    358               exit(1); 
    359             } 
    360             #pragma omp flush 
    361           } 
    362         } 
    363  
    364         MPI_Barrier_local(comm); 
    365       } 
    366     } 
    367   } 
    368  
    369   int MPI_Scan_local_char(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    370   { 
    371     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    372     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    373  
    374     char *buffer = comm.my_buffer->buf_char; 
    375     char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
    376     char *recv_buf = static_cast<char*>(recvbuf); 
    377  
    378     for(int j=0; j<count; j+=BUFFER_SIZE) 
    379     { 
    380       if(my_rank == 0) 
    381       { 
    382  
    383         #pragma omp critical (write_to_buffer) 
    384         { 
    385           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    386           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    387           #pragma omp flush 
    388         } 
    389       } 
    390  
    391       MPI_Barrier_local(comm); 
    392  
    393       for(int k=1; k<num_ep; k++) 
    394       { 
    395         #pragma omp critical (write_to_buffer) 
    396         { 
    397           if(my_rank == k) 
    398           { 
    399             #pragma omp flush 
    400             if(op == MPI_SUM) 
    401             { 
    402               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<char>()); 
    403               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    404             } 
    405             else if(op == MPI_MAX) 
    406             { 
    407               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<char>); 
    408               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    409             } 
    410             else if(op == MPI_MIN) 
    411             { 
    412               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<char>); 
    413               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    414             } 
    415             else 
    416             { 
    417               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    418               exit(1); 
    419             } 
    420             #pragma omp flush 
    421           } 
    422         } 
    423  
    424         MPI_Barrier_local(comm); 
    425       } 
    426     } 
     135      comm.my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf);   
     136      memcpy(recvbuf, sendbuf, datasize*count); 
     137    }  
     138       
     139 
     140 
     141    MPI_Barrier_local(comm); 
     142 
     143    memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize*count); 
     144 
     145 
     146    if(op == MPI_SUM) 
     147    { 
     148      if(datatype == MPI_INT && datasize == sizeof(int)) 
     149      { 
     150        for(int i=1; i<ep_rank_loc+1; i++) 
     151          reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
     152      } 
     153      
     154      else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
     155      { 
     156        for(int i=1; i<ep_rank_loc+1; i++) 
     157          reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
     158      } 
     159       
     160 
     161      else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
     162      { 
     163        for(int i=1; i<ep_rank_loc+1; i++) 
     164          reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     165      } 
     166 
     167      else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
     168      { 
     169        for(int i=1; i<ep_rank_loc+1; i++) 
     170          reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     171      } 
     172 
     173      else if(datatype == MPI_LONG && datasize == sizeof(long)) 
     174      { 
     175        for(int i=1; i<ep_rank_loc+1; i++) 
     176          reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     177      } 
     178 
     179      else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
     180      { 
     181        for(int i=1; i<ep_rank_loc+1; i++) 
     182          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
     183      } 
     184 
     185      else printf("datatype Error\n"); 
     186 
     187       
     188    } 
     189 
     190    else if(op == MPI_MAX) 
     191    { 
     192      if(datatype == MPI_INT && datasize == sizeof(int)) 
     193        for(int i=1; i<ep_rank_loc+1; i++) 
     194          reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
     195 
     196      else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
     197        for(int i=1; i<ep_rank_loc+1; i++) 
     198          reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
     199 
     200      else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
     201        for(int i=1; i<ep_rank_loc+1; i++) 
     202          reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     203 
     204      else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
     205        for(int i=1; i<ep_rank_loc+1; i++) 
     206          reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     207 
     208      else if(datatype == MPI_LONG && datasize == sizeof(long)) 
     209        for(int i=1; i<ep_rank_loc+1; i++) 
     210          reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     211 
     212      else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
     213        for(int i=1; i<ep_rank_loc+1; i++) 
     214          reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
     215      
     216      else printf("datatype Error\n"); 
     217    } 
     218 
     219    else //if(op == MPI_MIN) 
     220    { 
     221      if(datatype == MPI_INT && datasize == sizeof(int)) 
     222        for(int i=1; i<ep_rank_loc+1; i++) 
     223          reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
     224 
     225      else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
     226        for(int i=1; i<ep_rank_loc+1; i++) 
     227          reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
     228 
     229      else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
     230        for(int i=1; i<ep_rank_loc+1; i++) 
     231          reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     232 
     233      else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
     234        for(int i=1; i<ep_rank_loc+1; i++) 
     235          reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     236 
     237      else if(datatype == MPI_LONG && datasize == sizeof(long)) 
     238        for(int i=1; i<ep_rank_loc+1; i++) 
     239          reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     240 
     241      else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
     242        for(int i=1; i<ep_rank_loc+1; i++) 
     243          reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
     244 
     245      else printf("datatype Error\n"); 
     246    } 
     247 
     248    MPI_Barrier_local(comm); 
     249 
    427250  } 
    428251 
     
    432255    if(!comm.is_ep) 
    433256    { 
    434  
    435       ::MPI_Scan(sendbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), 
    436                  static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    437       return 0; 
    438     } 
    439  
    440     if(!comm.mpi_comm) return 0; 
    441  
    442     int ep_rank, ep_rank_loc, mpi_rank; 
    443     int ep_size, num_ep, mpi_size; 
    444  
    445  
    446     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    447     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    448     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    449     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    450     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    451     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    452  
     257      return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm)); 
     258    } 
     259     
     260    valid_type(datatype); 
     261 
     262    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     263    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     264    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     265    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     266    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     267    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    453268 
    454269    ::MPI_Aint datasize, lb; 
    455  
    456     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
    457  
    458     void* local_scan_recvbuf; 
    459     local_scan_recvbuf = new void*[datasize * count]; 
    460  
    461  
    462     // local scan 
    463     MPI_Scan_local(sendbuf, recvbuf, count, datatype, op, comm); 
    464  
    465 //     MPI_scan 
    466     void* local_sum; 
    467     void* mpi_scan_recvbuf; 
    468  
    469  
    470     mpi_scan_recvbuf = new void*[datasize*count]; 
     270    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
     271     
     272    void* tmp_sendbuf; 
     273    tmp_sendbuf = new void*[datasize * count]; 
     274 
     275    int my_src = 0; 
     276    int my_dst = ep_rank; 
     277 
     278    std::vector<int> my_map(mpi_size, 0); 
     279 
     280    for(int i=0; i<comm.rank_map->size(); i++) my_map[comm.rank_map->at(i).second]++; 
     281 
     282    for(int i=0; i<mpi_rank; i++) my_src += my_map[i]; 
     283    my_src += ep_rank_loc; 
     284 
     285      
     286    for(int i=0; i<mpi_size; i++) 
     287    { 
     288      if(my_dst < my_map[i]) 
     289      { 
     290        my_dst = get_ep_rank(comm, my_dst, i);  
     291        break; 
     292      } 
     293      else 
     294        my_dst -= my_map[i]; 
     295    } 
     296 
     297    //printf("ID = %d : send to %d, recv from %d\n", ep_rank, my_dst, my_src); 
     298    MPI_Barrier(comm); 
     299 
     300    if(my_dst == ep_rank && my_src == ep_rank) memcpy(tmp_sendbuf, sendbuf, datasize*count); 
     301 
     302    if(ep_rank != my_dst)  
     303    { 
     304      MPI_Request request[2]; 
     305      MPI_Status status[2]; 
     306 
     307      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]); 
     308     
     309      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]); 
     310     
     311      MPI_Waitall(2, request, status); 
     312    } 
     313     
     314 
     315    void* tmp_recvbuf; 
     316    tmp_recvbuf = new void*[datasize * count];     
     317 
     318    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm); 
    471319 
    472320    if(ep_rank_loc == 0) 
    473     { 
    474       local_sum = new void*[datasize*count]; 
    475     } 
    476  
    477  
    478     MPI_Reduce_local(sendbuf, local_sum, count, datatype, op, comm); 
    479  
    480     if(ep_rank_loc == 0) 
    481     { 
    482       ::MPI_Exscan(local_sum, mpi_scan_recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    483     } 
    484  
    485  
    486     if(mpi_rank > 0) 
    487     { 
    488       MPI_Bcast_local(mpi_scan_recvbuf, count, datatype, comm); 
    489     } 
    490  
    491  
    492     if(datatype == MPI_DOUBLE) 
    493     { 
    494       double* sum_buf = static_cast<double*>(mpi_scan_recvbuf); 
    495       double* recv_buf = static_cast<double*>(recvbuf); 
    496  
    497       if(mpi_rank != 0) 
    498       { 
    499         if(op == MPI_SUM) 
    500         { 
    501           for(int i=0; i<count; i++) 
    502           { 
    503             recv_buf[i] += sum_buf[i]; 
    504           } 
    505         } 
    506         else if (op == MPI_MAX) 
    507         { 
    508           for(int i=0; i<count; i++) 
    509           { 
    510             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    511           } 
    512         } 
    513         else if(op == MPI_MIN) 
    514         { 
    515           for(int i=0; i<count; i++) 
    516           { 
    517             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    518           } 
    519         } 
    520         else 
    521         { 
    522           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    523           exit(1); 
    524         } 
    525       } 
    526  
    527       delete[] static_cast<double*>(mpi_scan_recvbuf); 
    528       if(ep_rank_loc == 0) 
    529       { 
    530         delete[] static_cast<double*>(local_sum); 
    531       } 
    532     } 
    533  
    534     else if(datatype == MPI_FLOAT) 
    535     { 
    536       float* sum_buf = static_cast<float*>(mpi_scan_recvbuf); 
    537       float* recv_buf = static_cast<float*>(recvbuf); 
    538  
    539       if(mpi_rank != 0) 
    540       { 
    541         if(op == MPI_SUM) 
    542         { 
    543           for(int i=0; i<count; i++) 
    544           { 
    545             recv_buf[i] += sum_buf[i]; 
    546           } 
    547         } 
    548         else if (op == MPI_MAX) 
    549         { 
    550           for(int i=0; i<count; i++) 
    551           { 
    552             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    553           } 
    554         } 
    555         else if(op == MPI_MIN) 
    556         { 
    557           for(int i=0; i<count; i++) 
    558           { 
    559             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    560           } 
    561         } 
    562         else 
    563         { 
    564           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    565           exit(1); 
    566         } 
    567       } 
    568  
    569       delete[] static_cast<float*>(mpi_scan_recvbuf); 
    570       if(ep_rank_loc == 0) 
    571       { 
    572         delete[] static_cast<float*>(local_sum); 
    573       } 
    574     } 
    575  
    576     else if(datatype == MPI_INT) 
    577     { 
    578       int* sum_buf = static_cast<int*>(mpi_scan_recvbuf); 
    579       int* recv_buf = static_cast<int*>(recvbuf); 
    580  
    581       if(mpi_rank != 0) 
    582       { 
    583         if(op == MPI_SUM) 
    584         { 
    585           for(int i=0; i<count; i++) 
    586           { 
    587             recv_buf[i] += sum_buf[i]; 
    588           } 
    589         } 
    590         else if (op == MPI_MAX) 
    591         { 
    592           for(int i=0; i<count; i++) 
    593           { 
    594             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    595           } 
    596         } 
    597         else if(op == MPI_MIN) 
    598         { 
    599           for(int i=0; i<count; i++) 
    600           { 
    601             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    602           } 
    603         } 
    604         else 
    605         { 
    606           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    607           exit(1); 
    608         } 
    609       } 
    610  
    611       delete[] static_cast<int*>(mpi_scan_recvbuf); 
    612       if(ep_rank_loc == 0) 
    613       { 
    614         delete[] static_cast<int*>(local_sum); 
    615       } 
    616     } 
    617  
    618     else if(datatype == MPI_LONG) 
    619     { 
    620       long* sum_buf = static_cast<long*>(mpi_scan_recvbuf); 
    621       long* recv_buf = static_cast<long*>(recvbuf); 
    622  
    623       if(mpi_rank != 0) 
    624       { 
    625         if(op == MPI_SUM) 
    626         { 
    627           for(int i=0; i<count; i++) 
    628           { 
    629             recv_buf[i] += sum_buf[i]; 
    630           } 
    631         } 
    632         else if (op == MPI_MAX) 
    633         { 
    634           for(int i=0; i<count; i++) 
    635           { 
    636             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    637           } 
    638         } 
    639         else if(op == MPI_MIN) 
    640         { 
    641           for(int i=0; i<count; i++) 
    642           { 
    643             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    644           } 
    645         } 
    646         else 
    647         { 
    648           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    649           exit(1); 
    650         } 
    651       } 
    652  
    653       delete[] static_cast<long*>(mpi_scan_recvbuf); 
    654       if(ep_rank_loc == 0) 
    655       { 
    656         delete[] static_cast<long*>(local_sum); 
    657       } 
    658     } 
    659  
    660     else if(datatype == MPI_UNSIGNED_LONG) 
    661     { 
    662       unsigned long* sum_buf = static_cast<unsigned long*>(mpi_scan_recvbuf); 
    663       unsigned long* recv_buf = static_cast<unsigned long*>(recvbuf); 
    664  
    665       if(mpi_rank != 0) 
    666       { 
    667         if(op == MPI_SUM) 
    668         { 
    669           for(int i=0; i<count; i++) 
    670           { 
    671             recv_buf[i] += sum_buf[i]; 
    672           } 
    673         } 
    674         else if (op == MPI_MAX) 
    675         { 
    676           for(int i=0; i<count; i++) 
    677           { 
    678             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    679           } 
    680         } 
    681         else if(op == MPI_MIN) 
    682         { 
    683           for(int i=0; i<count; i++) 
    684           { 
    685             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    686           } 
    687         } 
    688         else 
    689         { 
    690           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    691           exit(1); 
    692         } 
    693       } 
    694  
    695       delete[] static_cast<unsigned long*>(mpi_scan_recvbuf); 
    696       if(ep_rank_loc == 0) 
    697       { 
    698         delete[] static_cast<unsigned long*>(local_sum); 
    699       } 
    700     } 
    701  
    702     else if(datatype == MPI_CHAR) 
    703     { 
    704       char* sum_buf = static_cast<char*>(mpi_scan_recvbuf); 
    705       char* recv_buf = static_cast<char*>(recvbuf); 
    706  
    707       if(mpi_rank != 0) 
    708       { 
    709         if(op == MPI_SUM) 
    710         { 
    711           for(int i=0; i<count; i++) 
    712           { 
    713             recv_buf[i] += sum_buf[i]; 
    714           } 
    715         } 
    716         else if (op == MPI_MAX) 
    717         { 
    718           for(int i=0; i<count; i++) 
    719           { 
    720             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    721           } 
    722         } 
    723         else if(op == MPI_MIN) 
    724         { 
    725           for(int i=0; i<count; i++) 
    726           { 
    727             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    728           } 
    729         } 
    730         else 
    731         { 
    732           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    733           exit(1); 
    734         } 
    735       } 
    736  
    737       delete[] static_cast<char*>(mpi_scan_recvbuf); 
    738       if(ep_rank_loc == 0) 
    739       { 
    740         delete[] static_cast<char*>(local_sum); 
    741       } 
    742     } 
    743  
     321      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm)); 
     322 
     323    //printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]); 
     324     
     325    MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm); 
     326 
     327    // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]); 
     328 
     329 
     330 
     331    if(ep_rank != my_src)  
     332    { 
     333      MPI_Request request[2]; 
     334      MPI_Status status[2]; 
     335 
     336      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]); 
     337     
     338      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]); 
     339     
     340      MPI_Waitall(2, request, status); 
     341    } 
     342 
     343    else memcpy(recvbuf, tmp_recvbuf, datasize*count); 
     344     
     345 
     346 
     347 
     348    delete[] tmp_sendbuf; 
     349    delete[] tmp_recvbuf; 
    744350 
    745351  } 
Note: See TracChangeset for help on using the changeset viewer.