Ignore:
Timestamp:
10/06/17 13:56:33 (7 years ago)
Author:
yushan
Message:

EP update all

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scan.cpp

    r1289 r1295  
    99#include <mpi.h> 
    1010#include "ep_declaration.hpp" 
     11#include "ep_mpi.hpp" 
    1112 
    1213using namespace std; 
     
    2627  } 
    2728 
    28  
    29   int MPI_Scan_local2(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    30   { 
    31     if(datatype == MPI_INT) 
    32     { 
    33       return MPI_Scan_local_int(sendbuf, recvbuf, count, op, comm); 
    34     } 
    35     else if(datatype == MPI_FLOAT) 
    36     { 
    37       return MPI_Scan_local_float(sendbuf, recvbuf, count, op, comm); 
    38     } 
    39     else if(datatype == MPI_DOUBLE) 
    40     { 
    41       return MPI_Scan_local_double(sendbuf, recvbuf, count, op, comm); 
    42     } 
    43     else if(datatype == MPI_LONG) 
    44     { 
    45       return MPI_Scan_local_long(sendbuf, recvbuf, count, op, comm); 
    46     } 
    47     else if(datatype == MPI_UNSIGNED_LONG) 
    48     { 
    49       return MPI_Scan_local_ulong(sendbuf, recvbuf, count, op, comm); 
    50     } 
    51     else if(datatype == MPI_CHAR) 
    52     { 
    53       return MPI_Scan_local_char(sendbuf, recvbuf, count, op, comm); 
     29  template<typename T> 
     30  void reduce_max(const T * buffer, T* recvbuf, int count) 
     31  { 
     32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>); 
     33  } 
     34 
     35  template<typename T> 
     36  void reduce_min(const T * buffer, T* recvbuf, int count) 
     37  { 
     38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>); 
     39  } 
     40 
     41  template<typename T> 
     42  void reduce_sum(const T * buffer, T* recvbuf, int count) 
     43  { 
     44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>()); 
     45  } 
     46 
     47 
     48  int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
     49  { 
     50    valid_op(op); 
     51 
     52    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     53    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     54    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     55     
     56 
     57    ::MPI_Aint datasize, lb; 
     58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
     59 
     60    if(ep_rank_loc == 0 && mpi_rank != 0) 
     61    { 
     62      if(op == MPI_SUM) 
     63      { 
     64        if(datatype == MPI_INT) 
     65        { 
     66          assert(datasize == sizeof(int)); 
     67          reduce_sum<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);     
     68        } 
     69           
     70        else if(datatype == MPI_FLOAT) 
     71        { 
     72          assert( datasize == sizeof(float)); 
     73          reduce_sum<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);     
     74        }  
     75              
     76        else if(datatype == MPI_DOUBLE ) 
     77        { 
     78          assert( datasize == sizeof(double)); 
     79          reduce_sum<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count); 
     80        } 
     81       
     82        else if(datatype == MPI_CHAR) 
     83        { 
     84          assert( datasize == sizeof(char)); 
     85          reduce_sum<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count); 
     86        }  
     87           
     88        else if(datatype == MPI_LONG) 
     89        { 
     90          assert( datasize == sizeof(long)); 
     91          reduce_sum<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count); 
     92        }  
     93           
     94             
     95        else if(datatype == MPI_UNSIGNED_LONG) 
     96        { 
     97          assert(datasize == sizeof(unsigned long)); 
     98          reduce_sum<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);     
     99        } 
     100             
     101        else printf("datatype Error\n"); 
     102      } 
     103 
     104      else if(op == MPI_MAX) 
     105      { 
     106        if(datatype == MPI_INT) 
     107        { 
     108          assert( datasize == sizeof(int)); 
     109          reduce_max<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);     
     110        }  
     111           
     112        else if(datatype == MPI_FLOAT ) 
     113        { 
     114          assert( datasize == sizeof(float)); 
     115          reduce_max<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);     
     116        } 
     117 
     118        else if(datatype == MPI_DOUBLE ) 
     119        { 
     120          assert( datasize == sizeof(double)); 
     121          reduce_max<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count); 
     122        } 
     123       
     124        else if(datatype == MPI_CHAR ) 
     125        { 
     126          assert(datasize == sizeof(char)); 
     127          reduce_max<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count); 
     128        } 
     129       
     130        else if(datatype == MPI_LONG) 
     131        { 
     132          assert( datasize == sizeof(long)); 
     133          reduce_max<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count); 
     134        }  
     135             
     136        else if(datatype == MPI_UNSIGNED_LONG) 
     137        { 
     138          assert( datasize == sizeof(unsigned long)); 
     139          reduce_max<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);     
     140        }  
     141             
     142        else printf("datatype Error\n"); 
     143      } 
     144 
     145      else //(op == MPI_MIN) 
     146      { 
     147        if(datatype == MPI_INT ) 
     148        { 
     149          assert (datasize == sizeof(int)); 
     150          reduce_min<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);     
     151        } 
     152           
     153        else if(datatype == MPI_FLOAT ) 
     154        { 
     155          assert( datasize == sizeof(float)); 
     156          reduce_min<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);     
     157        } 
     158              
     159        else if(datatype == MPI_DOUBLE ) 
     160        { 
     161          assert( datasize == sizeof(double)); 
     162          reduce_min<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count); 
     163        } 
     164       
     165        else if(datatype == MPI_CHAR ) 
     166        { 
     167          assert( datasize == sizeof(char)); 
     168          reduce_min<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count); 
     169        } 
     170       
     171        else if(datatype == MPI_LONG ) 
     172        {  
     173          assert( datasize == sizeof(long)); 
     174          reduce_min<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count); 
     175        } 
     176             
     177        else if(datatype == MPI_UNSIGNED_LONG ) 
     178        { 
     179          assert( datasize == sizeof(unsigned long)); 
     180          reduce_min<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);     
     181        } 
     182             
     183        else printf("datatype Error\n"); 
     184      } 
     185 
     186      comm.my_buffer->void_buffer[0] = recvbuf; 
    54187    } 
    55188    else 
    56189    { 
    57       printf("MPI_Scan Datatype not supported!\n"); 
    58       exit(0); 
    59     } 
    60  
    61   } 
    62  
    63  
    64  
    65  
    66   int MPI_Scan_local_int(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    67   { 
    68     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    69     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    70  
    71     int *buffer = comm.my_buffer->buf_int; 
    72     int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
    73     int *recv_buf = static_cast<int*>(recvbuf); 
    74  
    75     for(int j=0; j<count; j+=BUFFER_SIZE) 
    76     { 
    77       if(my_rank == 0) 
    78       { 
    79  
    80         #pragma omp critical (write_to_buffer) 
    81         { 
    82           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    83           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    84           #pragma omp flush 
    85         } 
    86       } 
    87  
    88       MPI_Barrier_local(comm); 
    89  
    90       for(int k=1; k<num_ep; k++) 
    91       { 
    92         #pragma omp critical (write_to_buffer) 
    93         { 
    94           if(my_rank == k) 
    95           { 
    96             #pragma omp flush 
    97             if(op == MPI_SUM) 
    98             { 
    99               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<int>()); 
    100               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    101             } 
    102             else if(op == MPI_MAX) 
    103             { 
    104               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<int>); 
    105               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    106             } 
    107             else if(op == MPI_MIN) 
    108             { 
    109               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<int>); 
    110               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    111             } 
    112             else 
    113             { 
    114               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    115               exit(1); 
    116             } 
    117             #pragma omp flush 
    118           } 
    119         } 
    120  
    121         MPI_Barrier_local(comm); 
    122       } 
    123     } 
    124  
    125   } 
    126  
    127   int MPI_Scan_local_float(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    128   { 
    129     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    130     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    131  
    132     float *buffer = comm.my_buffer->buf_float; 
    133     float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
    134     float *recv_buf = static_cast<float*>(recvbuf); 
    135  
    136     for(int j=0; j<count; j+=BUFFER_SIZE) 
    137     { 
    138       if(my_rank == 0) 
    139       { 
    140  
    141         #pragma omp critical (write_to_buffer) 
    142         { 
    143           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    144           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    145           #pragma omp flush 
    146         } 
    147       } 
    148  
    149       MPI_Barrier_local(comm); 
    150  
    151       for(int k=1; k<num_ep; k++) 
    152       { 
    153         #pragma omp critical (write_to_buffer) 
    154         { 
    155           if(my_rank == k) 
    156           { 
    157             #pragma omp flush 
    158             if(op == MPI_SUM) 
    159             { 
    160               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<float>()); 
    161               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    162             } 
    163             else if(op == MPI_MAX) 
    164             { 
    165               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<float>); 
    166               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    167  
    168             } 
    169             else if(op == MPI_MIN) 
    170             { 
    171               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<float>); 
    172               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    173  
    174             } 
    175             else 
    176             { 
    177               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    178               exit(1); 
    179             } 
    180             #pragma omp flush 
    181           } 
    182         } 
    183  
    184         MPI_Barrier_local(comm); 
    185       } 
    186     } 
    187   } 
    188  
    189   int MPI_Scan_local_double(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    190   { 
    191     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    192     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    193  
    194     double *buffer = comm.my_buffer->buf_double; 
    195     double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
    196     double *recv_buf = static_cast<double*>(recvbuf); 
    197  
    198     for(int j=0; j<count; j+=BUFFER_SIZE) 
    199     { 
    200       if(my_rank == 0) 
    201       { 
    202  
    203         #pragma omp critical (write_to_buffer) 
    204         { 
    205           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    206           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    207           #pragma omp flush 
    208         } 
    209       } 
    210  
    211       MPI_Barrier_local(comm); 
    212  
    213       for(int k=1; k<num_ep; k++) 
    214       { 
    215         #pragma omp critical (write_to_buffer) 
    216         { 
    217           if(my_rank == k) 
    218           { 
    219             #pragma omp flush 
    220             if(op == MPI_SUM) 
    221             { 
    222               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<double>()); 
    223               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    224             } 
    225             else if(op == MPI_MAX) 
    226             { 
    227               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<double>); 
    228               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    229             } 
    230             else if(op == MPI_MIN) 
    231             { 
    232               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<double>); 
    233               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    234             } 
    235             else 
    236             { 
    237               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    238               exit(1); 
    239             } 
    240             #pragma omp flush 
    241           } 
    242         } 
    243  
    244         MPI_Barrier_local(comm); 
    245       } 
    246     } 
    247   } 
    248  
    249   int MPI_Scan_local_long(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    250   { 
    251     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    252     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    253  
    254     long *buffer = comm.my_buffer->buf_long; 
    255     long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
    256     long *recv_buf = static_cast<long*>(recvbuf); 
    257  
    258     for(int j=0; j<count; j+=BUFFER_SIZE) 
    259     { 
    260       if(my_rank == 0) 
    261       { 
    262  
    263         #pragma omp critical (write_to_buffer) 
    264         { 
    265           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    266           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    267           #pragma omp flush 
    268         } 
    269       } 
    270  
    271       MPI_Barrier_local(comm); 
    272  
    273       for(int k=1; k<num_ep; k++) 
    274       { 
    275         #pragma omp critical (write_to_buffer) 
    276         { 
    277           if(my_rank == k) 
    278           { 
    279             #pragma omp flush 
    280             if(op == MPI_SUM) 
    281             { 
    282               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<long>()); 
    283               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    284             } 
    285             else if(op == MPI_MAX) 
    286             { 
    287               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<long>); 
    288               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    289             } 
    290             else if(op == MPI_MIN) 
    291             { 
    292               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<long>); 
    293               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    294             } 
    295             else 
    296             { 
    297               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    298               exit(1); 
    299             } 
    300             #pragma omp flush 
    301           } 
    302         } 
    303  
    304         MPI_Barrier_local(comm); 
    305       } 
    306     } 
    307   } 
    308  
    309   int MPI_Scan_local_ulong(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    310   { 
    311     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    312     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    313  
    314     unsigned long *buffer = comm.my_buffer->buf_ulong; 
    315     unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
    316     unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 
    317  
    318     for(int j=0; j<count; j+=BUFFER_SIZE) 
    319     { 
    320       if(my_rank == 0) 
    321       { 
    322  
    323         #pragma omp critical (write_to_buffer) 
    324         { 
    325           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    326           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    327           #pragma omp flush 
    328         } 
    329       } 
    330  
    331       MPI_Barrier_local(comm); 
    332  
    333       for(int k=1; k<num_ep; k++) 
    334       { 
    335         #pragma omp critical (write_to_buffer) 
    336         { 
    337           if(my_rank == k) 
    338           { 
    339             #pragma omp flush 
    340             if(op == MPI_SUM) 
    341             { 
    342               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<unsigned long>()); 
    343               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    344             } 
    345             else if(op == MPI_MAX) 
    346             { 
    347               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<unsigned long>); 
    348               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    349             } 
    350             else if(op == MPI_MIN) 
    351             { 
    352               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<unsigned long>); 
    353               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    354             } 
    355             else 
    356             { 
    357               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    358               exit(1); 
    359             } 
    360             #pragma omp flush 
    361           } 
    362         } 
    363  
    364         MPI_Barrier_local(comm); 
    365       } 
    366     } 
    367   } 
    368  
    369   int MPI_Scan_local_char(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    370   { 
    371     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    372     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    373  
    374     char *buffer = comm.my_buffer->buf_char; 
    375     char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
    376     char *recv_buf = static_cast<char*>(recvbuf); 
    377  
    378     for(int j=0; j<count; j+=BUFFER_SIZE) 
    379     { 
    380       if(my_rank == 0) 
    381       { 
    382  
    383         #pragma omp critical (write_to_buffer) 
    384         { 
    385           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
    386           copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
    387           #pragma omp flush 
    388         } 
    389       } 
    390  
    391       MPI_Barrier_local(comm); 
    392  
    393       for(int k=1; k<num_ep; k++) 
    394       { 
    395         #pragma omp critical (write_to_buffer) 
    396         { 
    397           if(my_rank == k) 
    398           { 
    399             #pragma omp flush 
    400             if(op == MPI_SUM) 
    401             { 
    402               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<char>()); 
    403               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    404             } 
    405             else if(op == MPI_MAX) 
    406             { 
    407               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<char>); 
    408               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    409             } 
    410             else if(op == MPI_MIN) 
    411             { 
    412               transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<char>); 
    413               copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    414             } 
    415             else 
    416             { 
    417               printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    418               exit(1); 
    419             } 
    420             #pragma omp flush 
    421           } 
    422         } 
    423  
    424         MPI_Barrier_local(comm); 
    425       } 
    426     } 
     190      comm.my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf);   
     191      memcpy(recvbuf, sendbuf, datasize*count); 
     192    }  
     193       
     194 
     195 
     196    MPI_Barrier_local(comm); 
     197 
     198    memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize*count); 
     199 
     200 
     201    if(op == MPI_SUM) 
     202    { 
     203      if(datatype == MPI_INT ) 
     204      { 
     205        assert (datasize == sizeof(int)); 
     206        for(int i=1; i<ep_rank_loc+1; i++) 
     207          reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
     208      } 
     209      
     210      else if(datatype == MPI_FLOAT ) 
     211      { 
     212        assert(datasize == sizeof(float)); 
     213        for(int i=1; i<ep_rank_loc+1; i++) 
     214          reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
     215      } 
     216       
     217 
     218      else if(datatype == MPI_DOUBLE ) 
     219      { 
     220        assert(datasize == sizeof(double)); 
     221        for(int i=1; i<ep_rank_loc+1; i++) 
     222          reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     223      } 
     224 
     225      else if(datatype == MPI_CHAR ) 
     226      { 
     227        assert(datasize == sizeof(char)); 
     228        for(int i=1; i<ep_rank_loc+1; i++) 
     229          reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     230      } 
     231 
     232      else if(datatype == MPI_LONG ) 
     233      { 
     234        assert(datasize == sizeof(long)); 
     235        for(int i=1; i<ep_rank_loc+1; i++) 
     236          reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     237      } 
     238 
     239      else if(datatype == MPI_UNSIGNED_LONG ) 
     240      { 
     241        assert(datasize == sizeof(unsigned long)); 
     242        for(int i=1; i<ep_rank_loc+1; i++) 
     243          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
     244      } 
     245 
     246      else printf("datatype Error\n"); 
     247 
     248       
     249    } 
     250 
     251    else if(op == MPI_MAX) 
     252    { 
     253      if(datatype == MPI_INT) 
     254      { 
     255        assert(datasize == sizeof(int)); 
     256        for(int i=1; i<ep_rank_loc+1; i++) 
     257          reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
     258      } 
     259 
     260      else if(datatype == MPI_FLOAT ) 
     261      { 
     262        assert(datasize == sizeof(float)); 
     263        for(int i=1; i<ep_rank_loc+1; i++) 
     264          reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
     265      } 
     266 
     267      else if(datatype == MPI_DOUBLE ) 
     268      { 
     269        assert(datasize == sizeof(double)); 
     270        for(int i=1; i<ep_rank_loc+1; i++) 
     271          reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     272      } 
     273 
     274      else if(datatype == MPI_CHAR ) 
     275      { 
     276        assert(datasize == sizeof(char)); 
     277        for(int i=1; i<ep_rank_loc+1; i++) 
     278          reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     279      } 
     280 
     281      else if(datatype == MPI_LONG ) 
     282      { 
     283        assert(datasize == sizeof(long)); 
     284        for(int i=1; i<ep_rank_loc+1; i++) 
     285          reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     286      } 
     287 
     288      else if(datatype == MPI_UNSIGNED_LONG ) 
     289      { 
     290        assert(datasize == sizeof(unsigned long)); 
     291        for(int i=1; i<ep_rank_loc+1; i++) 
     292          reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
     293      } 
     294      
     295      else printf("datatype Error\n"); 
     296    } 
     297 
     298    else //if(op == MPI_MIN) 
     299    { 
     300      if(datatype == MPI_INT ) 
     301      { 
     302        assert(datasize == sizeof(int)); 
     303        for(int i=1; i<ep_rank_loc+1; i++) 
     304          reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
     305      } 
     306 
     307      else if(datatype == MPI_FLOAT ) 
     308      { 
     309        assert(datasize == sizeof(float)); 
     310        for(int i=1; i<ep_rank_loc+1; i++) 
     311          reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
     312      } 
     313 
     314      else if(datatype == MPI_DOUBLE ) 
     315      { 
     316        assert(datasize == sizeof(double)); 
     317        for(int i=1; i<ep_rank_loc+1; i++) 
     318          reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     319      } 
     320 
     321      else if(datatype == MPI_CHAR ) 
     322      { 
     323        assert(datasize == sizeof(char)); 
     324        for(int i=1; i<ep_rank_loc+1; i++) 
     325          reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     326      } 
     327 
     328      else if(datatype == MPI_LONG ) 
     329      { 
     330        assert(datasize == sizeof(long)); 
     331        for(int i=1; i<ep_rank_loc+1; i++) 
     332          reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     333      } 
     334 
     335      else if(datatype == MPI_UNSIGNED_LONG ) 
     336      { 
     337        assert(datasize == sizeof(unsigned long)); 
     338        for(int i=1; i<ep_rank_loc+1; i++) 
     339          reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
     340      } 
     341 
     342      else printf("datatype Error\n"); 
     343    } 
     344 
     345    MPI_Barrier_local(comm); 
     346 
    427347  } 
    428348 
     
    432352    if(!comm.is_ep) 
    433353    { 
    434  
    435       ::MPI_Scan(sendbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), 
    436                  static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    437       return 0; 
    438     } 
    439  
    440     if(!comm.mpi_comm) return 0; 
    441  
    442     int ep_rank, ep_rank_loc, mpi_rank; 
    443     int ep_size, num_ep, mpi_size; 
    444  
    445  
    446     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    447     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    448     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    449     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    450     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    451     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    452  
     354      return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm)); 
     355    } 
     356     
     357    valid_type(datatype); 
     358 
     359    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     360    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     361    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     362    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     363    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     364    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    453365 
    454366    ::MPI_Aint datasize, lb; 
    455  
    456     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
    457  
    458     void* local_scan_recvbuf; 
    459     local_scan_recvbuf = new void*[datasize * count]; 
    460  
    461  
    462     // local scan 
    463     MPI_Scan_local2(sendbuf, recvbuf, count, datatype, op, comm); 
    464  
    465 //     MPI_scan 
    466     void* local_sum; 
    467     void* mpi_scan_recvbuf; 
    468  
    469  
    470     mpi_scan_recvbuf = new void*[datasize*count]; 
     367    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
     368     
     369    void* tmp_sendbuf; 
     370    tmp_sendbuf = new void*[datasize * count]; 
     371 
     372    int my_src = 0; 
     373    int my_dst = ep_rank; 
     374 
     375    std::vector<int> my_map(mpi_size, 0); 
     376 
     377    for(int i=0; i<comm.rank_map->size(); i++) my_map[comm.rank_map->at(i).second]++; 
     378 
     379    for(int i=0; i<mpi_rank; i++) my_src += my_map[i]; 
     380    my_src += ep_rank_loc; 
     381 
     382      
     383    for(int i=0; i<mpi_size; i++) 
     384    { 
     385      if(my_dst < my_map[i]) 
     386      { 
     387        my_dst = get_ep_rank(comm, my_dst, i);  
     388        break; 
     389      } 
     390      else 
     391        my_dst -= my_map[i]; 
     392    } 
     393 
     394    //printf("ID = %d : send to %d, recv from %d\n", ep_rank, my_dst, my_src); 
     395    MPI_Barrier(comm); 
     396 
     397    if(my_dst == ep_rank && my_src == ep_rank) memcpy(tmp_sendbuf, sendbuf, datasize*count); 
     398 
     399    if(ep_rank != my_dst)  
     400    { 
     401      MPI_Request request[2]; 
     402      MPI_Status status[2]; 
     403 
     404      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]); 
     405     
     406      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]); 
     407     
     408      MPI_Waitall(2, request, status); 
     409    } 
     410     
     411 
     412    void* tmp_recvbuf; 
     413    tmp_recvbuf = new void*[datasize * count];     
     414 
     415    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm); 
    471416 
    472417    if(ep_rank_loc == 0) 
    473     { 
    474       local_sum = new void*[datasize*count]; 
    475     } 
    476  
    477  
    478     MPI_Reduce_local2(sendbuf, local_sum, count, datatype, op, comm); 
    479  
    480     if(ep_rank_loc == 0) 
    481     { 
    482       ::MPI_Exscan(local_sum, mpi_scan_recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    483     } 
    484  
    485  
    486     if(mpi_rank > 0) 
    487     { 
    488       MPI_Bcast_local2(mpi_scan_recvbuf, count, datatype, comm); 
    489     } 
    490  
    491  
    492     if(datatype == MPI_DOUBLE) 
    493     { 
    494       double* sum_buf = static_cast<double*>(mpi_scan_recvbuf); 
    495       double* recv_buf = static_cast<double*>(recvbuf); 
    496  
    497       if(mpi_rank != 0) 
    498       { 
    499         if(op == MPI_SUM) 
    500         { 
    501           for(int i=0; i<count; i++) 
    502           { 
    503             recv_buf[i] += sum_buf[i]; 
    504           } 
    505         } 
    506         else if (op == MPI_MAX) 
    507         { 
    508           for(int i=0; i<count; i++) 
    509           { 
    510             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    511           } 
    512         } 
    513         else if(op == MPI_MIN) 
    514         { 
    515           for(int i=0; i<count; i++) 
    516           { 
    517             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    518           } 
    519         } 
    520         else 
    521         { 
    522           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    523           exit(1); 
    524         } 
    525       } 
    526  
    527       delete[] static_cast<double*>(mpi_scan_recvbuf); 
    528       if(ep_rank_loc == 0) 
    529       { 
    530         delete[] static_cast<double*>(local_sum); 
    531       } 
    532     } 
    533  
    534     else if(datatype == MPI_FLOAT) 
    535     { 
    536       float* sum_buf = static_cast<float*>(mpi_scan_recvbuf); 
    537       float* recv_buf = static_cast<float*>(recvbuf); 
    538  
    539       if(mpi_rank != 0) 
    540       { 
    541         if(op == MPI_SUM) 
    542         { 
    543           for(int i=0; i<count; i++) 
    544           { 
    545             recv_buf[i] += sum_buf[i]; 
    546           } 
    547         } 
    548         else if (op == MPI_MAX) 
    549         { 
    550           for(int i=0; i<count; i++) 
    551           { 
    552             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    553           } 
    554         } 
    555         else if(op == MPI_MIN) 
    556         { 
    557           for(int i=0; i<count; i++) 
    558           { 
    559             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    560           } 
    561         } 
    562         else 
    563         { 
    564           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    565           exit(1); 
    566         } 
    567       } 
    568  
    569       delete[] static_cast<float*>(mpi_scan_recvbuf); 
    570       if(ep_rank_loc == 0) 
    571       { 
    572         delete[] static_cast<float*>(local_sum); 
    573       } 
    574     } 
    575  
    576     else if(datatype == MPI_INT) 
    577     { 
    578       int* sum_buf = static_cast<int*>(mpi_scan_recvbuf); 
    579       int* recv_buf = static_cast<int*>(recvbuf); 
    580  
    581       if(mpi_rank != 0) 
    582       { 
    583         if(op == MPI_SUM) 
    584         { 
    585           for(int i=0; i<count; i++) 
    586           { 
    587             recv_buf[i] += sum_buf[i]; 
    588           } 
    589         } 
    590         else if (op == MPI_MAX) 
    591         { 
    592           for(int i=0; i<count; i++) 
    593           { 
    594             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    595           } 
    596         } 
    597         else if(op == MPI_MIN) 
    598         { 
    599           for(int i=0; i<count; i++) 
    600           { 
    601             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    602           } 
    603         } 
    604         else 
    605         { 
    606           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    607           exit(1); 
    608         } 
    609       } 
    610  
    611       delete[] static_cast<int*>(mpi_scan_recvbuf); 
    612       if(ep_rank_loc == 0) 
    613       { 
    614         delete[] static_cast<int*>(local_sum); 
    615       } 
    616     } 
    617  
    618     else if(datatype == MPI_LONG) 
    619     { 
    620       long* sum_buf = static_cast<long*>(mpi_scan_recvbuf); 
    621       long* recv_buf = static_cast<long*>(recvbuf); 
    622  
    623       if(mpi_rank != 0) 
    624       { 
    625         if(op == MPI_SUM) 
    626         { 
    627           for(int i=0; i<count; i++) 
    628           { 
    629             recv_buf[i] += sum_buf[i]; 
    630           } 
    631         } 
    632         else if (op == MPI_MAX) 
    633         { 
    634           for(int i=0; i<count; i++) 
    635           { 
    636             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    637           } 
    638         } 
    639         else if(op == MPI_MIN) 
    640         { 
    641           for(int i=0; i<count; i++) 
    642           { 
    643             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    644           } 
    645         } 
    646         else 
    647         { 
    648           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    649           exit(1); 
    650         } 
    651       } 
    652  
    653       delete[] static_cast<long*>(mpi_scan_recvbuf); 
    654       if(ep_rank_loc == 0) 
    655       { 
    656         delete[] static_cast<long*>(local_sum); 
    657       } 
    658     } 
    659  
    660     else if(datatype == MPI_UNSIGNED_LONG) 
    661     { 
    662       unsigned long* sum_buf = static_cast<unsigned long*>(mpi_scan_recvbuf); 
    663       unsigned long* recv_buf = static_cast<unsigned long*>(recvbuf); 
    664  
    665       if(mpi_rank != 0) 
    666       { 
    667         if(op == MPI_SUM) 
    668         { 
    669           for(int i=0; i<count; i++) 
    670           { 
    671             recv_buf[i] += sum_buf[i]; 
    672           } 
    673         } 
    674         else if (op == MPI_MAX) 
    675         { 
    676           for(int i=0; i<count; i++) 
    677           { 
    678             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    679           } 
    680         } 
    681         else if(op == MPI_MIN) 
    682         { 
    683           for(int i=0; i<count; i++) 
    684           { 
    685             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    686           } 
    687         } 
    688         else 
    689         { 
    690           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    691           exit(1); 
    692         } 
    693       } 
    694  
    695       delete[] static_cast<unsigned long*>(mpi_scan_recvbuf); 
    696       if(ep_rank_loc == 0) 
    697       { 
    698         delete[] static_cast<unsigned long*>(local_sum); 
    699       } 
    700     } 
    701  
    702     else if(datatype == MPI_CHAR) 
    703     { 
    704       char* sum_buf = static_cast<char*>(mpi_scan_recvbuf); 
    705       char* recv_buf = static_cast<char*>(recvbuf); 
    706  
    707       if(mpi_rank != 0) 
    708       { 
    709         if(op == MPI_SUM) 
    710         { 
    711           for(int i=0; i<count; i++) 
    712           { 
    713             recv_buf[i] += sum_buf[i]; 
    714           } 
    715         } 
    716         else if (op == MPI_MAX) 
    717         { 
    718           for(int i=0; i<count; i++) 
    719           { 
    720             recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
    721           } 
    722         } 
    723         else if(op == MPI_MIN) 
    724         { 
    725           for(int i=0; i<count; i++) 
    726           { 
    727             recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
    728           } 
    729         } 
    730         else 
    731         { 
    732           printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
    733           exit(1); 
    734         } 
    735       } 
    736  
    737       delete[] static_cast<char*>(mpi_scan_recvbuf); 
    738       if(ep_rank_loc == 0) 
    739       { 
    740         delete[] static_cast<char*>(local_sum); 
    741       } 
    742     } 
    743  
     418      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm)); 
     419 
     420    //printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]); 
     421     
     422    MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm); 
     423 
     424    // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]); 
     425 
     426 
     427 
     428    if(ep_rank != my_src)  
     429    { 
     430      MPI_Request request[2]; 
     431      MPI_Status status[2]; 
     432 
     433      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]); 
     434     
     435      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]); 
     436     
     437      MPI_Waitall(2, request, status); 
     438    } 
     439 
     440    else memcpy(recvbuf, tmp_recvbuf, datasize*count); 
     441     
     442 
     443    delete[] tmp_sendbuf; 
     444    delete[] tmp_recvbuf; 
    744445 
    745446  } 
Note: See TracChangeset for help on using the changeset viewer.