Ignore:
Timestamp:
10/04/17 11:45:14 (7 years ago)
Author:
yushan
Message:

EP updated

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_reduce.cpp

    r1134 r1287  
    99#include <mpi.h> 
    1010#include "ep_declaration.hpp" 
     11#include "ep_mpi.hpp" 
    1112 
    1213using namespace std; 
     
    2728  } 
    2829 
    29  
    30   int MPI_Reduce_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    31   { 
    32     if(datatype == MPI_INT) 
    33     { 
    34       Debug("datatype is INT\n"); 
    35       return MPI_Reduce_local_int(sendbuf, recvbuf, count, op, comm); 
    36     } 
    37     else if(datatype == MPI_FLOAT) 
    38     { 
    39       Debug("datatype is FLOAT\n"); 
    40       return MPI_Reduce_local_float(sendbuf, recvbuf, count, op, comm); 
    41     } 
    42     else if(datatype == MPI_DOUBLE) 
    43     { 
    44       Debug("datatype is DOUBLE\n"); 
    45       return MPI_Reduce_local_double(sendbuf, recvbuf, count, op, comm); 
    46     } 
    47     else if(datatype == MPI_LONG) 
    48     { 
    49       Debug("datatype is DOUBLE\n"); 
    50       return MPI_Reduce_local_long(sendbuf, recvbuf, count, op, comm); 
    51     } 
    52     else if(datatype == MPI_UNSIGNED_LONG) 
    53     { 
    54       Debug("datatype is DOUBLE\n"); 
    55       return MPI_Reduce_local_ulong(sendbuf, recvbuf, count, op, comm); 
    56     } 
    57     else if(datatype == MPI_CHAR) 
    58     { 
    59       Debug("datatype is DOUBLE\n"); 
    60       return MPI_Reduce_local_char(sendbuf, recvbuf, count, op, comm); 
    61     } 
    62     else 
    63     { 
    64       printf("MPI_Reduce Datatype not supported!\n"); 
    65       exit(0); 
    66     } 
    67   } 
    68  
    69  
    70   int MPI_Reduce_local_int(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    71   { 
    72     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    73     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    74  
    75     int *buffer = comm.my_buffer->buf_int; 
    76     int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
    77     int *recv_buf = static_cast<int*>(const_cast<void*>(recvbuf)); 
    78  
    79     for(int j=0; j<count; j+=BUFFER_SIZE) 
    80     { 
    81       if( 0 == my_rank ) 
     30  template<typename T> 
     31  void reduce_max(const T * buffer, T* recvbuf, int count) 
     32  { 
     33    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>); 
     34  } 
     35 
     36  template<typename T> 
     37  void reduce_min(const T * buffer, T* recvbuf, int count) 
     38  { 
     39    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>); 
     40  } 
     41 
     42  template<typename T> 
     43  void reduce_sum(const T * buffer, T* recvbuf, int count) 
     44  { 
     45    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>()); 
     46  } 
     47 
     48  int MPI_Reduce_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int local_root, MPI_Comm comm) 
     49  { 
     50    assert(valid_type(datatype)); 
     51    assert(valid_op(op)); 
     52 
     53    ::MPI_Aint datasize, lb; 
     54    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
     55 
     56    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     57    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     58 
     59    #pragma omp critical (_reduce) 
     60    comm.my_buffer->void_buffer[ep_rank_loc] = const_cast< void* >(sendbuf); 
     61 
     62    MPI_Barrier_local(comm); 
     63 
     64    if(ep_rank_loc == local_root) 
     65    { 
     66      memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize * count); 
     67 
     68      if(op == MPI_MAX) 
    8269      { 
    83         #pragma omp critical (write_to_buffer) 
    84         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    85         #pragma omp flush 
     70        if(datasize == sizeof(int)) 
     71        { 
     72          for(int i=1; i<num_ep; i++) 
     73            reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count); 
     74        } 
     75 
     76        else if(datasize == sizeof(float)) 
     77        { 
     78          for(int i=1; i<num_ep; i++) 
     79            reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count); 
     80        } 
     81 
     82        else if(datasize == sizeof(double)) 
     83        { 
     84          for(int i=1; i<num_ep; i++) 
     85            reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     86        } 
     87 
     88        else if(datasize == sizeof(char)) 
     89        { 
     90          for(int i=1; i<num_ep; i++) 
     91            reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     92        } 
     93 
     94        else if(datasize == sizeof(long)) 
     95        { 
     96          for(int i=1; i<num_ep; i++) 
     97            reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     98        } 
     99 
     100        else if(datasize == sizeof(unsigned long)) 
     101        { 
     102          for(int i=1; i<num_ep; i++) 
     103            reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count); 
     104        } 
     105 
     106        else printf("datatype Error\n"); 
     107 
    86108      } 
    87109 
    88       MPI_Barrier_local(comm); 
    89  
    90       if(my_rank !=0 ) 
     110      if(op == MPI_MIN) 
    91111      { 
    92         #pragma omp critical (write_to_buffer) 
    93         { 
    94           #pragma omp flush 
    95           if(op == MPI_SUM) 
    96           { 
    97             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<int>()); 
    98           } 
    99  
    100           else if (op == MPI_MAX) 
    101           { 
    102             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<int>); 
    103           } 
    104  
    105           else if (op == MPI_MIN) 
    106           { 
    107             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<int>); 
    108           } 
    109  
    110           else 
    111           { 
    112             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    113             exit(1); 
    114           } 
    115           #pragma omp flush 
    116         } 
     112        if(datasize == sizeof(int)) 
     113        { 
     114          for(int i=1; i<num_ep; i++) 
     115            reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count); 
     116        } 
     117 
     118        else if(datasize == sizeof(float)) 
     119        { 
     120          for(int i=1; i<num_ep; i++) 
     121            reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count); 
     122        } 
     123 
     124        else if(datasize == sizeof(double)) 
     125        { 
     126          for(int i=1; i<num_ep; i++) 
     127            reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     128        } 
     129 
     130        else if(datasize == sizeof(char)) 
     131        { 
     132          for(int i=1; i<num_ep; i++) 
     133            reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     134        } 
     135 
     136        else if(datasize == sizeof(long)) 
     137        { 
     138          for(int i=1; i<num_ep; i++) 
     139            reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     140        } 
     141 
     142        else if(datasize == sizeof(unsigned long)) 
     143        { 
     144          for(int i=1; i<num_ep; i++) 
     145            reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count); 
     146        } 
     147 
     148        else printf("datatype Error\n"); 
     149 
    117150      } 
    118151 
    119       MPI_Barrier_local(comm); 
    120  
    121       if(my_rank == 0) 
     152 
     153      if(op == MPI_SUM) 
    122154      { 
    123         #pragma omp flush 
    124         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     155        if(datasize == sizeof(int)) 
     156        { 
     157          for(int i=1; i<num_ep; i++) 
     158            reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count); 
     159        } 
     160 
     161        else if(datasize == sizeof(float)) 
     162        { 
     163          for(int i=1; i<num_ep; i++) 
     164            reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count); 
     165        } 
     166 
     167        else if(datasize == sizeof(double)) 
     168        { 
     169          for(int i=1; i<num_ep; i++) 
     170            reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
     171        } 
     172 
     173        else if(datasize == sizeof(char)) 
     174        { 
     175          for(int i=1; i<num_ep; i++) 
     176            reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
     177        } 
     178 
     179        else if(datasize == sizeof(long)) 
     180        { 
     181          for(int i=1; i<num_ep; i++) 
     182            reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
     183        } 
     184 
     185        else if(datasize == sizeof(unsigned long)) 
     186        { 
     187          for(int i=1; i<num_ep; i++) 
     188            reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count); 
     189        } 
     190 
     191        else printf("datatype Error\n"); 
     192 
    125193      } 
    126       MPI_Barrier_local(comm); 
    127     } 
    128   } 
    129  
    130  
    131   int MPI_Reduce_local_float(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    132   { 
    133     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    134     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    135  
    136     float *buffer = comm.my_buffer->buf_float; 
    137     float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
    138     float *recv_buf = static_cast<float*>(const_cast<void*>(recvbuf)); 
    139  
    140     for(int j=0; j<count; j+=BUFFER_SIZE) 
    141     { 
    142       if( 0 == my_rank ) 
    143       { 
    144         #pragma omp critical (write_to_buffer) 
    145         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    146         #pragma omp flush 
    147       } 
    148  
    149       MPI_Barrier_local(comm); 
    150  
    151       if(my_rank !=0 ) 
    152       { 
    153         #pragma omp critical (write_to_buffer) 
    154         { 
    155           #pragma omp flush 
    156  
    157           if(op == MPI_SUM) 
    158           { 
    159             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<float>()); 
    160           } 
    161  
    162           else if (op == MPI_MAX) 
    163           { 
    164             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<float>); 
    165           } 
    166  
    167           else if (op == MPI_MIN) 
    168           { 
    169             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<float>); 
    170           } 
    171  
    172           else 
    173           { 
    174             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    175             exit(1); 
    176           } 
    177           #pragma omp flush 
    178         } 
    179       } 
    180  
    181       MPI_Barrier_local(comm); 
    182  
    183       if(my_rank == 0) 
    184       { 
    185         #pragma omp flush 
    186         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    187       } 
    188       MPI_Barrier_local(comm); 
    189     } 
    190   } 
    191  
    192   int MPI_Reduce_local_double(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    193   { 
    194     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    195     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    196  
    197     double *buffer = comm.my_buffer->buf_double; 
    198     double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
    199     double *recv_buf = static_cast<double*>(const_cast<void*>(recvbuf)); 
    200  
    201     for(int j=0; j<count; j+=BUFFER_SIZE) 
    202     { 
    203       if( 0 == my_rank ) 
    204       { 
    205         #pragma omp critical (write_to_buffer) 
    206         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    207         #pragma omp flush 
    208       } 
    209  
    210       MPI_Barrier_local(comm); 
    211  
    212       if(my_rank !=0 ) 
    213       { 
    214         #pragma omp critical (write_to_buffer) 
    215         { 
    216           #pragma omp flush 
    217  
    218  
    219           if(op == MPI_SUM) 
    220           { 
    221             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<double>()); 
    222           } 
    223  
    224           else if (op == MPI_MAX) 
    225           { 
    226             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<double>); 
    227           } 
    228  
    229  
    230           else if (op == MPI_MIN) 
    231           { 
    232             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<double>); 
    233           } 
    234  
    235           else 
    236           { 
    237             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    238             exit(1); 
    239           } 
    240           #pragma omp flush 
    241         } 
    242       } 
    243  
    244       MPI_Barrier_local(comm); 
    245  
    246       if(my_rank == 0) 
    247       { 
    248         #pragma omp flush 
    249         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    250       } 
    251       MPI_Barrier_local(comm); 
    252     } 
    253   } 
    254  
    255   int MPI_Reduce_local_long(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    256   { 
    257     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    258     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    259  
    260     long *buffer = comm.my_buffer->buf_long; 
    261     long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
    262     long *recv_buf = static_cast<long*>(const_cast<void*>(recvbuf)); 
    263  
    264     for(int j=0; j<count; j+=BUFFER_SIZE) 
    265     { 
    266       if( 0 == my_rank ) 
    267       { 
    268         #pragma omp critical (write_to_buffer) 
    269         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    270         #pragma omp flush 
    271       } 
    272  
    273       MPI_Barrier_local(comm); 
    274  
    275       if(my_rank !=0 ) 
    276       { 
    277         #pragma omp critical (write_to_buffer) 
    278         { 
    279           #pragma omp flush 
    280  
    281  
    282           if(op == MPI_SUM) 
    283           { 
    284             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<long>()); 
    285           } 
    286  
    287           else if (op == MPI_MAX) 
    288           { 
    289             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<long>); 
    290           } 
    291  
    292  
    293           else if (op == MPI_MIN) 
    294           { 
    295             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<long>); 
    296           } 
    297  
    298           else 
    299           { 
    300             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    301             exit(1); 
    302           } 
    303           #pragma omp flush 
    304         } 
    305       } 
    306  
    307       MPI_Barrier_local(comm); 
    308  
    309       if(my_rank == 0) 
    310       { 
    311         #pragma omp flush 
    312         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    313       } 
    314       MPI_Barrier_local(comm); 
    315     } 
    316   } 
    317  
    318   int MPI_Reduce_local_ulong(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    319   { 
    320     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    321     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    322  
    323     unsigned long *buffer = comm.my_buffer->buf_ulong; 
    324     unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
    325     unsigned long *recv_buf = static_cast<unsigned long*>(const_cast<void*>(recvbuf)); 
    326  
    327     for(int j=0; j<count; j+=BUFFER_SIZE) 
    328     { 
    329       if( 0 == my_rank ) 
    330       { 
    331         #pragma omp critical (write_to_buffer) 
    332         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    333         #pragma omp flush 
    334       } 
    335  
    336       MPI_Barrier_local(comm); 
    337  
    338       if(my_rank !=0 ) 
    339       { 
    340         #pragma omp critical (write_to_buffer) 
    341         { 
    342           #pragma omp flush 
    343  
    344  
    345           if(op == MPI_SUM) 
    346           { 
    347             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<unsigned long>()); 
    348           } 
    349  
    350           else if (op == MPI_MAX) 
    351           { 
    352             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<unsigned long>); 
    353           } 
    354  
    355  
    356           else if (op == MPI_MIN) 
    357           { 
    358             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<unsigned long>); 
    359           } 
    360  
    361           else 
    362           { 
    363             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    364             exit(1); 
    365           } 
    366           #pragma omp flush 
    367         } 
    368       } 
    369  
    370       MPI_Barrier_local(comm); 
    371  
    372       if(my_rank == 0) 
    373       { 
    374         #pragma omp flush 
    375         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    376       } 
    377       MPI_Barrier_local(comm); 
    378     } 
    379   } 
    380  
    381   int MPI_Reduce_local_char(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
    382   { 
    383     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    384     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    385  
    386     char *buffer = comm.my_buffer->buf_char; 
    387     char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
    388     char *recv_buf = static_cast<char*>(const_cast<void*>(recvbuf)); 
    389  
    390     for(int j=0; j<count; j+=BUFFER_SIZE) 
    391     { 
    392       if( 0 == my_rank ) 
    393       { 
    394         #pragma omp critical (write_to_buffer) 
    395         copy(send_buf+j, send_buf+j + min(BUFFER_SIZE, count-j), buffer); 
    396         #pragma omp flush 
    397       } 
    398  
    399       MPI_Barrier_local(comm); 
    400  
    401       if(my_rank !=0 ) 
    402       { 
    403         #pragma omp critical (write_to_buffer) 
    404         { 
    405           #pragma omp flush 
    406  
    407  
    408           if(op == MPI_SUM) 
    409           { 
    410             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<char>()); 
    411           } 
    412  
    413           else if (op == MPI_MAX) 
    414           { 
    415             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<char>); 
    416           } 
    417  
    418  
    419           else if (op == MPI_MIN) 
    420           { 
    421             transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<char>); 
    422           } 
    423  
    424           else 
    425           { 
    426             printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
    427             exit(1); 
    428           } 
    429           #pragma omp flush 
    430         } 
    431       } 
    432  
    433       MPI_Barrier_local(comm); 
    434  
    435       if(my_rank == 0) 
    436       { 
    437         #pragma omp flush 
    438         copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    439       } 
    440       MPI_Barrier_local(comm); 
    441     } 
     194    } 
     195 
     196    MPI_Barrier_local(comm); 
     197 
    442198  } 
    443199 
     
    445201  int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm) 
    446202  { 
     203 
    447204    if(!comm.is_ep && comm.mpi_comm) 
    448205    { 
    449       ::MPI_Reduce(sendbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), root, 
    450                    static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    451       return 0; 
    452     } 
    453  
    454  
    455     if(!comm.mpi_comm) return 0; 
     206      return ::MPI_Reduce(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root, to_mpi_comm(comm.mpi_comm)); 
     207    } 
     208 
     209 
     210 
     211    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     212    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     213    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     214    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     215    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     216    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    456217 
    457218    int root_mpi_rank = comm.rank_map->at(root).second; 
    458219    int root_ep_loc = comm.rank_map->at(root).first; 
    459220 
    460     int ep_rank, ep_rank_loc, mpi_rank; 
    461     int ep_size, num_ep, mpi_size; 
    462  
    463     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    464     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    465     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    466     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    467     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    468     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    469  
    470  
    471     ::MPI_Aint recvsize, lb; 
    472  
    473     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &recvsize); 
    474  
    475     void *local_recvbuf; 
    476     if(ep_rank_loc==0) 
    477     { 
    478       local_recvbuf = new void*[recvsize*count]; 
    479     } 
    480  
    481     MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, comm); 
    482  
    483  
    484     if(ep_rank_loc==0) 
    485     { 
    486       ::MPI_Reduce(local_recvbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    487     } 
    488  
    489     if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 
    490     { 
    491       innode_memcpy(0, recvbuf, root_ep_loc, recvbuf, count, datatype, comm); 
    492     } 
    493  
    494     if(ep_rank_loc==0) 
    495     { 
    496       if(datatype == MPI_INT) delete[] static_cast<int*>(local_recvbuf); 
    497       else if(datatype == MPI_FLOAT) delete[] static_cast<float*>(local_recvbuf); 
    498       else if(datatype == MPI_DOUBLE) delete[] static_cast<double*>(local_recvbuf); 
    499       else if(datatype == MPI_LONG) delete[] static_cast<long*>(local_recvbuf); 
    500       else if(datatype == MPI_UNSIGNED_LONG) delete[] static_cast<unsigned long*>(local_recvbuf); 
    501       else delete[] static_cast<char*>(local_recvbuf); 
    502     } 
    503  
    504     Message_Check(comm); 
    505  
    506     return 0; 
    507   } 
    508  
    509  
    510  
    511  
    512   int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    513   { 
    514     if(!comm.is_ep && comm.mpi_comm) 
    515     { 
    516       ::MPI_Allreduce(sendbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), 
    517                       static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    518       return 0; 
    519     } 
    520  
    521     if(!comm.mpi_comm) return 0; 
    522  
    523  
    524     int ep_rank, ep_rank_loc, mpi_rank; 
    525     int ep_size, num_ep, mpi_size; 
    526  
    527     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    528     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    529     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    530     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    531     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    532     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    533  
    534  
    535     ::MPI_Aint recvsize, lb; 
    536  
    537     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &recvsize); 
    538  
    539     void *local_recvbuf; 
    540     if(ep_rank_loc==0) 
    541     { 
    542       local_recvbuf = new void*[recvsize*count]; 
    543     } 
    544  
    545     MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, comm); 
    546  
    547  
    548     if(ep_rank_loc==0) 
    549     { 
    550       ::MPI_Allreduce(local_recvbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    551     } 
    552  
    553     MPI_Bcast_local(recvbuf, count, datatype, comm); 
    554  
    555     if(ep_rank_loc==0) 
    556     { 
    557       if(datatype == MPI_INT) delete[] static_cast<int*>(local_recvbuf); 
    558       else if(datatype == MPI_FLOAT) delete[] static_cast<float*>(local_recvbuf); 
    559       else if(datatype == MPI_DOUBLE) delete[] static_cast<double*>(local_recvbuf); 
    560       else if(datatype == MPI_LONG) delete[] static_cast<long*>(local_recvbuf); 
    561       else if(datatype == MPI_UNSIGNED_LONG) delete[] static_cast<unsigned long*>(local_recvbuf); 
    562       else delete[] static_cast<char*>(local_recvbuf); 
    563     } 
    564  
    565     Message_Check(comm); 
    566  
    567     return 0; 
    568   } 
    569  
    570  
    571   int MPI_Reduce_scatter(const void *sendbuf, void *recvbuf, const int recvcounts[], MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    572   { 
    573  
    574     if(!comm.is_ep && comm.mpi_comm) 
    575     { 
    576       ::MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), 
    577                            static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    578       return 0; 
    579     } 
    580  
    581     if(!comm.mpi_comm) return 0; 
    582  
    583     int ep_rank, ep_rank_loc, mpi_rank; 
    584     int ep_size, num_ep, mpi_size; 
    585  
    586     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    587     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    588     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    589     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    590     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    591     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    592  
    593     void* local_buf; 
    594     void* local_buf2; 
    595     int local_buf_size = accumulate(recvcounts, recvcounts+ep_size, 0); 
    596     int local_buf2_size = accumulate(recvcounts+ep_rank-ep_rank_loc, recvcounts+ep_rank-ep_rank_loc+num_ep, 0); 
    597  
    598221    ::MPI_Aint datasize, lb; 
    599222 
    600223    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
    601224 
    602     if(ep_rank_loc == 0) 
    603     { 
    604       local_buf = new void*[local_buf_size*datasize]; 
    605       local_buf2 = new void*[local_buf2_size*datasize]; 
    606     } 
    607     MPI_Reduce_local(sendbuf, local_buf, local_buf_size, MPI_INT, op, comm); 
    608  
    609  
    610     if(ep_rank_loc == 0) 
    611     { 
    612       int local_recvcnt[mpi_size]; 
    613       for(int i=0; i<mpi_size; i++) 
    614       { 
    615         local_recvcnt[i] = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0); 
    616       } 
    617  
    618       ::MPI_Reduce_scatter(local_buf, local_buf2, local_recvcnt, static_cast< ::MPI_Datatype>(datatype), 
    619                          static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    620     } 
    621  
    622  
    623     int displs[num_ep]; 
    624     displs[0] = 0; 
    625     for(int i=1; i<num_ep; i++) 
    626     { 
    627       displs[i] = displs[i-1] + recvcounts[ep_rank-ep_rank_loc+i-1]; 
    628     } 
    629  
    630     MPI_Scatterv_local(local_buf2, recvcounts+ep_rank-ep_rank_loc, displs, datatype, recvbuf, comm); 
    631  
    632     if(ep_rank_loc == 0) 
    633     { 
    634       if(datatype == MPI_INT) 
    635       { 
    636         delete[] static_cast<int*>(local_buf); 
    637         delete[] static_cast<int*>(local_buf2); 
    638       } 
    639       else if(datatype == MPI_FLOAT) 
    640       { 
    641         delete[] static_cast<float*>(local_buf); 
    642         delete[] static_cast<float*>(local_buf2); 
    643       } 
    644       else if(datatype == MPI_DOUBLE) 
    645       { 
    646         delete[] static_cast<double*>(local_buf); 
    647         delete[] static_cast<double*>(local_buf2); 
    648       } 
    649       else if(datatype == MPI_LONG) 
    650       { 
    651         delete[] static_cast<long*>(local_buf); 
    652         delete[] static_cast<long*>(local_buf2); 
    653       } 
    654       else if(datatype == MPI_UNSIGNED_LONG) 
    655       { 
    656         delete[] static_cast<unsigned long*>(local_buf); 
    657         delete[] static_cast<unsigned long*>(local_buf2); 
    658       } 
    659       else // if(datatype == MPI_DOUBLE) 
    660       { 
    661         delete[] static_cast<char*>(local_buf); 
    662         delete[] static_cast<char*>(local_buf2); 
    663       } 
    664     } 
    665  
    666     Message_Check(comm); 
    667     return 0; 
    668   } 
     225    bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root; 
     226    bool is_root = ep_rank == root; 
     227 
     228    void* local_recvbuf; 
     229 
     230    if(is_master) 
     231    { 
     232      local_recvbuf = new void*[datasize * count]; 
     233    } 
     234 
     235    if(mpi_rank == root_mpi_rank) MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, root_ep_loc, comm); 
     236    else                          MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, 0, comm); 
     237 
     238 
     239 
     240    if(is_master) 
     241    { 
     242      ::MPI_Reduce(local_recvbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 
     243       
     244    } 
     245 
     246    if(is_master) 
     247    { 
     248      delete[] local_recvbuf; 
     249    } 
     250 
     251    MPI_Barrier_local(comm); 
     252  } 
     253 
     254 
    669255} 
    670256 
Note: See TracChangeset for help on using the changeset viewer.