Ignore:
Timestamp:
10/06/17 13:56:33 (7 years ago)
Author:
yushan
Message:

EP update all

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_reduce.cpp

    r1289 r1295  
    99#include <mpi.h> 
    1010#include "ep_declaration.hpp" 
     11#include "ep_mpi.hpp" 
    1112 
    1213using namespace std; 
     
    2728  } 
    2829 
    29  
    30   int MPI_Reduce_local2(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    31   { 
    32     if(datatype == MPI_INT) 
    33     { 
    34       Debug("datatype is INT\n"); 
    35       return MPI_Reduce_local_int(sendbuf, recvbuf, count, op, comm); 
    36     } 
    37     else if(datatype == MPI_FLOAT) 
    38     { 
    39       Debug("datatype is FLOAT\n"); 
    40       return MPI_Reduce_local_float(sendbuf, recvbuf, count, op, comm); 
    41     } 
    42     else if(datatype == MPI_DOUBLE) 
    43     { 
    44       Debug("datatype is DOUBLE\n"); 
    45       return MPI_Reduce_local_double(sendbuf, recvbuf, count, op, comm); 
    46     } 
    47     else if(datatype == MPI_LONG) 
    48     { 
    49       Debug("datatype is DOUBLE\n"); 
    50       return MPI_Reduce_local_long(sendbuf, recvbuf, count, op, comm); 
    51     } 
    52     else if(datatype == MPI_UNSIGNED_LONG) 
    53     { 
    54       Debug("datatype is DOUBLE\n"); 
    55       return MPI_Reduce_local_ulong(sendbuf, recvbuf, count, op, comm); 
    56     } 
    57     else if(datatype == MPI_CHAR) 
    58     { 
    59       Debug("datatype is DOUBLE\n"); 
    60       return MPI_Reduce_local_char(sendbuf, recvbuf, count, op, comm); 
    61     } 
    62     else 
    63     { 
    64       printf("MPI_Reduce Datatype not supported!\n"); 
    65       exit(0); 
    66     } 
    67   } 
    68  
    69  
    70   int MPI_Reduce_local_int(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    71   { 
    72     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    73     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    74  
    75     int *buffer = comm.my_buffer->buf_int; 
    76     int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
    77     int *recv_buf = static_cast<int*>(const_cast<void*>(recvbuf)); 
    78  
    79     for(int j=0; j<count; j+=BUFFER_SIZE) 
    80     { 
    81       if( 0 == my_rank ) 
     30  template<typename T> 
     31  void reduce_max(const T * buffer, T* recvbuf, int count) 
     32  { 
     33    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>); 
     34  } 
     35 
     36  template<typename T> 
     37  void reduce_min(const T * buffer, T* recvbuf, int count) 
     38  { 
     39    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>); 
     40  } 
     41 
     42  template<typename T> 
     43  void reduce_sum(const T * buffer, T* recvbuf, int count) 
     44  { 
     45    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>()); 
     46  } 
     47 
     48  int MPI_Reduce_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int local_root, MPI_Comm comm) 
     49  { 
     50    assert(valid_type(datatype)); 
     51    assert(valid_op(op)); 
     52 
     53    ::MPI_Aint datasize, lb; 
     54    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
     55 
     56    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     57    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     58    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     59 
     60    #pragma omp critical (_reduce) 
     61    comm.my_buffer->void_buffer[ep_rank_loc] = const_cast< void* >(sendbuf); 
     62 
     63    MPI_Barrier_local(comm); 
     64 
     65    if(ep_rank_loc == local_root) 
     66    { 
     67 
     68      memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize * count); 
     69 
     70      if(op == MPI_MAX) 
    8271      { 
    83         #pragma omp critical (write_to_buffer) 
    84         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    85         #pragma omp flush 
     72        if(datatype == MPI_INT) 
     73        { 
     74          assert(datasize == sizeof(int)); 
     75          for(int i=1; i<num_ep; i++) 
     76            reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count); 
     77        } 
     78 
     79        else if(datatype == MPI_FLOAT) 
     80        { 
     81          assert(datasize == sizeof(float)); 
     82          for(int i=1; i<num_ep; i++) 
     83            reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count); 
     84        } 
     85 
     86        else if(datatype == MPI_DOUBLE) 
     87        { 
     88          assert(datasize == sizeof(double)); 
     89          for(int i=1; i<num_ep; i++) 
     90            reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     91        } 
     92 
     93        else if(datatype == MPI_CHAR) 
     94        { 
     95          assert(datasize == sizeof(char)); 
     96          for(int i=1; i<num_ep; i++) 
     97            reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     98        } 
     99 
     100        else if(datatype == MPI_LONG) 
     101        { 
     102          assert(datasize == sizeof(long)); 
     103          for(int i=1; i<num_ep; i++) 
     104            reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     105        } 
     106 
     107        else if(datatype == MPI_UNSIGNED_LONG) 
     108        { 
     109          assert(datasize == sizeof(unsigned long)); 
     110          for(int i=1; i<num_ep; i++) 
     111            reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count); 
     112        } 
     113 
     114        else printf("datatype Error\n"); 
     115 
    86116      } 
    87117 
    88       MPI_Barrier_local(comm); 
    89  
    90       if(my_rank !=0 ) 
     118      if(op == MPI_MIN) 
    91119      { 
    92         #pragma omp critical (write_to_buffer) 
    93         { 
    94           #pragma omp flush 
    95           if(op == MPI_SUM) 
    96           { 
    97             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<int>()); 
    98           } 
    99  
    100           else if (op == MPI_MAX) 
    101           { 
    102             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<int>); 
    103           } 
    104  
    105           else if (op == MPI_MIN) 
    106           { 
    107             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<int>); 
    108           } 
    109  
    110           else 
    111           { 
    112             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    113             exit(1); 
    114           } 
    115           #pragma omp flush 
    116         } 
     120        if(datatype ==MPI_INT) 
     121        { 
     122          assert(datasize == sizeof(int)); 
     123          for(int i=1; i<num_ep; i++) 
     124            reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count); 
     125        } 
     126 
     127        else if(datatype == MPI_FLOAT) 
     128        { 
     129          assert(datasize == sizeof(float)); 
     130          for(int i=1; i<num_ep; i++) 
     131            reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count); 
     132        } 
     133 
     134        else if(datatype == MPI_DOUBLE) 
     135        { 
     136          assert(datasize == sizeof(double)); 
     137          for(int i=1; i<num_ep; i++) 
     138            reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     139        } 
     140 
     141        else if(datatype == MPI_CHAR) 
     142        { 
     143          assert(datasize == sizeof(char)); 
     144          for(int i=1; i<num_ep; i++) 
     145            reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     146        } 
     147 
     148        else if(datatype == MPI_LONG) 
     149        { 
     150          assert(datasize == sizeof(long)); 
     151          for(int i=1; i<num_ep; i++) 
     152            reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     153        } 
     154 
     155        else if(datatype == MPI_UNSIGNED_LONG) 
     156        { 
     157          assert(datasize == sizeof(unsigned long)); 
     158          for(int i=1; i<num_ep; i++) 
     159            reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count); 
     160        } 
     161 
     162        else printf("datatype Error\n"); 
     163 
    117164      } 
    118165 
    119       MPI_Barrier_local(comm); 
    120  
    121       if(my_rank == 0) 
     166 
     167      if(op == MPI_SUM) 
    122168      { 
    123         #pragma omp flush 
    124         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     169        if(datatype==MPI_INT) 
     170        { 
     171          assert(datasize == sizeof(int)); 
     172          for(int i=1; i<num_ep; i++) 
     173            reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count); 
     174        } 
     175 
     176        else if(datatype == MPI_FLOAT) 
     177        { 
     178          assert(datasize == sizeof(float)); 
     179          for(int i=1; i<num_ep; i++) 
     180            reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count); 
     181        } 
     182 
     183        else if(datatype == MPI_DOUBLE) 
     184        { 
     185          assert(datasize == sizeof(double)); 
     186          for(int i=1; i<num_ep; i++) 
     187            reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     188        } 
     189 
     190        else if(datatype == MPI_CHAR) 
     191        { 
     192          assert(datasize == sizeof(char)); 
     193          for(int i=1; i<num_ep; i++) 
     194            reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     195        } 
     196 
     197        else if(datatype == MPI_LONG) 
     198        { 
     199          assert(datasize == sizeof(long)); 
     200          for(int i=1; i<num_ep; i++) 
     201            reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     202        } 
     203 
     204        else if(datatype ==MPI_UNSIGNED_LONG) 
     205        { 
     206          assert(datasize == sizeof(unsigned long)); 
     207          for(int i=1; i<num_ep; i++) 
     208            reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count); 
     209        } 
     210 
     211        else printf("datatype Error\n"); 
     212 
    125213      } 
    126       MPI_Barrier_local(comm); 
    127     } 
    128   } 
    129  
    130  
    131   int MPI_Reduce_local_float(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    132   { 
    133     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    134     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    135  
    136     float *buffer = comm.my_buffer->buf_float; 
    137     float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
    138     float *recv_buf = static_cast<float*>(const_cast<void*>(recvbuf)); 
    139  
    140     for(int j=0; j<count; j+=BUFFER_SIZE) 
    141     { 
    142       if( 0 == my_rank ) 
    143       { 
    144         #pragma omp critical (write_to_buffer) 
    145         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    146         #pragma omp flush 
    147       } 
    148  
    149       MPI_Barrier_local(comm); 
    150  
    151       if(my_rank !=0 ) 
    152       { 
    153         #pragma omp critical (write_to_buffer) 
    154         { 
    155           #pragma omp flush 
    156  
    157           if(op == MPI_SUM) 
    158           { 
    159             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<float>()); 
    160           } 
    161  
    162           else if (op == MPI_MAX) 
    163           { 
    164             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<float>); 
    165           } 
    166  
    167           else if (op == MPI_MIN) 
    168           { 
    169             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<float>); 
    170           } 
    171  
    172           else 
    173           { 
    174             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    175             exit(1); 
    176           } 
    177           #pragma omp flush 
    178         } 
    179       } 
    180  
    181       MPI_Barrier_local(comm); 
    182  
    183       if(my_rank == 0) 
    184       { 
    185         #pragma omp flush 
    186         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    187       } 
    188       MPI_Barrier_local(comm); 
    189     } 
    190   } 
    191  
    192   int MPI_Reduce_local_double(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    193   { 
    194     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    195     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    196  
    197     double *buffer = comm.my_buffer->buf_double; 
    198     double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
    199     double *recv_buf = static_cast<double*>(const_cast<void*>(recvbuf)); 
    200  
    201     for(int j=0; j<count; j+=BUFFER_SIZE) 
    202     { 
    203       if( 0 == my_rank ) 
    204       { 
    205         #pragma omp critical (write_to_buffer) 
    206         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    207         #pragma omp flush 
    208       } 
    209  
    210       MPI_Barrier_local(comm); 
    211  
    212       if(my_rank !=0 ) 
    213       { 
    214         #pragma omp critical (write_to_buffer) 
    215         { 
    216           #pragma omp flush 
    217  
    218  
    219           if(op == MPI_SUM) 
    220           { 
    221             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<double>()); 
    222           } 
    223  
    224           else if (op == MPI_MAX) 
    225           { 
    226             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<double>); 
    227           } 
    228  
    229  
    230           else if (op == MPI_MIN) 
    231           { 
    232             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<double>); 
    233           } 
    234  
    235           else 
    236           { 
    237             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    238             exit(1); 
    239           } 
    240           #pragma omp flush 
    241         } 
    242       } 
    243  
    244       MPI_Barrier_local(comm); 
    245  
    246       if(my_rank == 0) 
    247       { 
    248         #pragma omp flush 
    249         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    250       } 
    251       MPI_Barrier_local(comm); 
    252     } 
    253   } 
    254  
    255   int MPI_Reduce_local_long(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    256   { 
    257     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    258     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    259  
    260     long *buffer = comm.my_buffer->buf_long; 
    261     long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
    262     long *recv_buf = static_cast<long*>(const_cast<void*>(recvbuf)); 
    263  
    264     for(int j=0; j<count; j+=BUFFER_SIZE) 
    265     { 
    266       if( 0 == my_rank ) 
    267       { 
    268         #pragma omp critical (write_to_buffer) 
    269         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    270         #pragma omp flush 
    271       } 
    272  
    273       MPI_Barrier_local(comm); 
    274  
    275       if(my_rank !=0 ) 
    276       { 
    277         #pragma omp critical (write_to_buffer) 
    278         { 
    279           #pragma omp flush 
    280  
    281  
    282           if(op == MPI_SUM) 
    283           { 
    284             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<long>()); 
    285           } 
    286  
    287           else if (op == MPI_MAX) 
    288           { 
    289             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<long>); 
    290           } 
    291  
    292  
    293           else if (op == MPI_MIN) 
    294           { 
    295             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<long>); 
    296           } 
    297  
    298           else 
    299           { 
    300             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    301             exit(1); 
    302           } 
    303           #pragma omp flush 
    304         } 
    305       } 
    306  
    307       MPI_Barrier_local(comm); 
    308  
    309       if(my_rank == 0) 
    310       { 
    311         #pragma omp flush 
    312         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    313       } 
    314       MPI_Barrier_local(comm); 
    315     } 
    316   } 
    317  
    318   int MPI_Reduce_local_ulong(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    319   { 
    320     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    321     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    322  
    323     unsigned long *buffer = comm.my_buffer->buf_ulong; 
    324     unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
    325     unsigned long *recv_buf = static_cast<unsigned long*>(const_cast<void*>(recvbuf)); 
    326  
    327     for(int j=0; j<count; j+=BUFFER_SIZE) 
    328     { 
    329       if( 0 == my_rank ) 
    330       { 
    331         #pragma omp critical (write_to_buffer) 
    332         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    333         #pragma omp flush 
    334       } 
    335  
    336       MPI_Barrier_local(comm); 
    337  
    338       if(my_rank !=0 ) 
    339       { 
    340         #pragma omp critical (write_to_buffer) 
    341         { 
    342           #pragma omp flush 
    343  
    344  
    345           if(op == MPI_SUM) 
    346           { 
    347             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<unsigned long>()); 
    348           } 
    349  
    350           else if (op == MPI_MAX) 
    351           { 
    352             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<unsigned long>); 
    353           } 
    354  
    355  
    356           else if (op == MPI_MIN) 
    357           { 
    358             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<unsigned long>); 
    359           } 
    360  
    361           else 
    362           { 
    363             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    364             exit(1); 
    365           } 
    366           #pragma omp flush 
    367         } 
    368       } 
    369  
    370       MPI_Barrier_local(comm); 
    371  
    372       if(my_rank == 0) 
    373       { 
    374         #pragma omp flush 
    375         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    376       } 
    377       MPI_Barrier_local(comm); 
    378     } 
    379   } 
    380  
    381   int MPI_Reduce_local_char(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    382   { 
    383     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    384     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    385  
    386     char *buffer = comm.my_buffer->buf_char; 
    387     char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
    388     char *recv_buf = static_cast<char*>(const_cast<void*>(recvbuf)); 
    389  
    390     for(int j=0; j<count; j+=BUFFER_SIZE) 
    391     { 
    392       if( 0 == my_rank ) 
    393       { 
    394         #pragma omp critical (write_to_buffer) 
    395         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    396         #pragma omp flush 
    397       } 
    398  
    399       MPI_Barrier_local(comm); 
    400  
    401       if(my_rank !=0 ) 
    402       { 
    403         #pragma omp critical (write_to_buffer) 
    404         { 
    405           #pragma omp flush 
    406  
    407  
    408           if(op == MPI_SUM) 
    409           { 
    410             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<char>()); 
    411           } 
    412  
    413           else if (op == MPI_MAX) 
    414           { 
    415             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<char>); 
    416           } 
    417  
    418  
    419           else if (op == MPI_MIN) 
    420           { 
    421             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<char>); 
    422           } 
    423  
    424           else 
    425           { 
    426             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    427             exit(1); 
    428           } 
    429           #pragma omp flush 
    430         } 
    431       } 
    432  
    433       MPI_Barrier_local(comm); 
    434  
    435       if(my_rank == 0) 
    436       { 
    437         #pragma omp flush 
    438         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    439       } 
    440       MPI_Barrier_local(comm); 
    441     } 
     214    } 
     215 
     216    MPI_Barrier_local(comm); 
     217 
    442218  } 
    443219 
     
    445221  int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm) 
    446222  { 
     223 
    447224    if(!comm.is_ep && comm.mpi_comm) 
    448225    { 
    449       ::MPI_Reduce(sendbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), root, 
    450                    static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    451       return 0; 
    452     } 
    453  
    454  
    455     if(!comm.mpi_comm) return 0; 
     226      return ::MPI_Reduce(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root, to_mpi_comm(comm.mpi_comm)); 
     227    } 
     228 
     229 
     230 
     231    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     232    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     233    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     234    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     235    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     236    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    456237 
    457238    int root_mpi_rank = comm.rank_map->at(root).second; 
    458239    int root_ep_loc = comm.rank_map->at(root).first; 
    459240 
    460     int ep_rank, ep_rank_loc, mpi_rank; 
    461     int ep_size, num_ep, mpi_size; 
    462  
    463     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    464     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    465     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    466     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    467     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    468     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    469  
    470  
    471     ::MPI_Aint recvsize, lb; 
    472  
    473     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &recvsize); 
    474  
    475     void *local_recvbuf; 
    476     if(ep_rank_loc==0) 
    477     { 
    478       local_recvbuf = new void*[recvsize*count]; 
    479     } 
    480  
    481     MPI_Reduce_local2(sendbuf, local_recvbuf, count, datatype, op, comm); 
    482  
    483  
    484     if(ep_rank_loc==0) 
    485     { 
    486       ::MPI_Reduce(local_recvbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    487     } 
    488  
    489     if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 
    490     { 
    491       innode_memcpy(0, recvbuf, root_ep_loc, recvbuf, count, datatype, comm); 
    492     } 
    493  
    494     if(ep_rank_loc==0) 
    495     { 
    496       if(datatype == MPI_INT) delete[] static_cast<int*>(local_recvbuf); 
    497       else if(datatype == MPI_FLOAT) delete[] static_cast<float*>(local_recvbuf); 
    498       else if(datatype == MPI_DOUBLE) delete[] static_cast<double*>(local_recvbuf); 
    499       else if(datatype == MPI_LONG) delete[] static_cast<long*>(local_recvbuf); 
    500       else if(datatype == MPI_UNSIGNED_LONG) delete[] static_cast<unsigned long*>(local_recvbuf); 
    501       else delete[] static_cast<char*>(local_recvbuf); 
    502     } 
    503  
    504     Message_Check(comm); 
    505  
    506     return 0; 
    507   } 
    508  
    509  
    510  
    511  
    512   int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    513   { 
    514     if(!comm.is_ep && comm.mpi_comm) 
    515     { 
    516       ::MPI_Allreduce(sendbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), 
    517                       static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    518       return 0; 
    519     } 
    520  
    521     if(!comm.mpi_comm) return 0; 
    522  
    523  
    524     int ep_rank, ep_rank_loc, mpi_rank; 
    525     int ep_size, num_ep, mpi_size; 
    526  
    527     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    528     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    529     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    530     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    531     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    532     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    533  
    534  
    535     ::MPI_Aint recvsize, lb; 
    536  
    537     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &recvsize); 
    538  
    539     void *local_recvbuf; 
    540     if(ep_rank_loc==0) 
    541     { 
    542       local_recvbuf = new void*[recvsize*count]; 
    543     } 
    544  
    545     MPI_Reduce_local2(sendbuf, local_recvbuf, count, datatype, op, comm); 
    546  
    547  
    548     if(ep_rank_loc==0) 
    549     { 
    550       ::MPI_Allreduce(local_recvbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    551     } 
    552  
    553     MPI_Bcast_local2(recvbuf, count, datatype, comm); 
    554  
    555     if(ep_rank_loc==0) 
    556     { 
    557       if(datatype == MPI_INT) delete[] static_cast<int*>(local_recvbuf); 
    558       else if(datatype == MPI_FLOAT) delete[] static_cast<float*>(local_recvbuf); 
    559       else if(datatype == MPI_DOUBLE) delete[] static_cast<double*>(local_recvbuf); 
    560       else if(datatype == MPI_LONG) delete[] static_cast<long*>(local_recvbuf); 
    561       else if(datatype == MPI_UNSIGNED_LONG) delete[] static_cast<unsigned long*>(local_recvbuf); 
    562       else delete[] static_cast<char*>(local_recvbuf); 
    563     } 
    564  
    565     Message_Check(comm); 
    566  
    567     return 0; 
    568   } 
    569  
    570  
    571   int MPI_Reduce_scatter(const void *sendbuf, void *recvbuf, const int recvcounts[], MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    572   { 
    573  
    574     if(!comm.is_ep && comm.mpi_comm) 
    575     { 
    576       ::MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), 
    577                            static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    578       return 0; 
    579     } 
    580  
    581     if(!comm.mpi_comm) return 0; 
    582  
    583     int ep_rank, ep_rank_loc, mpi_rank; 
    584     int ep_size, num_ep, mpi_size; 
    585  
    586     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    587     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    588     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    589     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    590     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    591     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    592  
    593     void* local_buf; 
    594     void* local_buf2; 
    595     int local_buf_size = accumulate(recvcounts, recvcounts+ep_size, 0); 
    596     int local_buf2_size = accumulate(recvcounts+ep_rank-ep_rank_loc, recvcounts+ep_rank-ep_rank_loc+num_ep, 0); 
    597  
    598241    ::MPI_Aint datasize, lb; 
    599242 
    600243    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
    601244 
    602     if(ep_rank_loc == 0) 
    603     { 
    604       local_buf = new void*[local_buf_size*datasize]; 
    605       local_buf2 = new void*[local_buf2_size*datasize]; 
    606     } 
    607     MPI_Reduce_local2(sendbuf, local_buf, local_buf_size, MPI_INT, op, comm); 
    608  
    609  
    610     if(ep_rank_loc == 0) 
    611     { 
    612       int local_recvcnt[mpi_size]; 
    613       for(int i=0; i<mpi_size; i++) 
    614       { 
    615         local_recvcnt[i] = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0); 
    616       } 
    617  
    618       ::MPI_Reduce_scatter(local_buf, local_buf2, local_recvcnt, static_cast< ::MPI_Datatype>(datatype), 
    619                          static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    620     } 
    621  
    622  
    623     int displs[num_ep]; 
    624     displs[0] = 0; 
    625     for(int i=1; i<num_ep; i++) 
    626     { 
    627       displs[i] = displs[i-1] + recvcounts[ep_rank-ep_rank_loc+i-1]; 
    628     } 
    629  
    630     MPI_Scatterv_local2(local_buf2, recvcounts+ep_rank-ep_rank_loc, displs, datatype, recvbuf, comm); 
    631  
    632     if(ep_rank_loc == 0) 
    633     { 
    634       if(datatype == MPI_INT) 
    635       { 
    636         delete[] static_cast<int*>(local_buf); 
    637         delete[] static_cast<int*>(local_buf2); 
    638       } 
    639       else if(datatype == MPI_FLOAT) 
    640       { 
    641         delete[] static_cast<float*>(local_buf); 
    642         delete[] static_cast<float*>(local_buf2); 
    643       } 
    644       else if(datatype == MPI_DOUBLE) 
    645       { 
    646         delete[] static_cast<double*>(local_buf); 
    647         delete[] static_cast<double*>(local_buf2); 
    648       } 
    649       else if(datatype == MPI_LONG) 
    650       { 
    651         delete[] static_cast<long*>(local_buf); 
    652         delete[] static_cast<long*>(local_buf2); 
    653       } 
    654       else if(datatype == MPI_UNSIGNED_LONG) 
    655       { 
    656         delete[] static_cast<unsigned long*>(local_buf); 
    657         delete[] static_cast<unsigned long*>(local_buf2); 
    658       } 
    659       else // if(datatype == MPI_DOUBLE) 
    660       { 
    661         delete[] static_cast<char*>(local_buf); 
    662         delete[] static_cast<char*>(local_buf2); 
    663       } 
    664     } 
    665  
    666     Message_Check(comm); 
    667     return 0; 
    668   } 
     245    bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root; 
     246    bool is_root = ep_rank == root; 
     247 
     248    void* local_recvbuf; 
     249 
     250    if(is_master) 
     251    { 
     252      local_recvbuf = new void*[datasize * count]; 
     253    } 
     254 
     255    if(mpi_rank == root_mpi_rank) MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, root_ep_loc, comm); 
     256    else                          MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, 0, comm); 
     257 
     258 
     259 
     260    if(is_master) 
     261    { 
     262      ::MPI_Reduce(local_recvbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 
     263       
     264    } 
     265 
     266    if(is_master) 
     267    { 
     268      delete[] local_recvbuf; 
     269    } 
     270 
     271    MPI_Barrier_local(comm); 
     272  } 
     273 
     274 
    669275} 
    670276 
Note: See TracChangeset for help on using the changeset viewer.