Ignore:
Timestamp:
10/04/17 17:02:13 (7 years ago)
Author:
yushan
Message:

EP update part 2

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scan.cpp

    r1287 r1289  
    99#include <mpi.h> 
    1010#include "ep_declaration.hpp" 
    11 #include "ep_mpi.hpp" 
    1211 
    1312using namespace std; 
     
    2726  } 
    2827 
    29   template<typename T> 
    30   void reduce_max(const T * buffer, T* recvbuf, int count) 
    31   { 
    32     transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>); 
    33   } 
    34  
    35   template<typename T> 
    36   void reduce_min(const T * buffer, T* recvbuf, int count) 
    37   { 
    38     transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>); 
    39   } 
    40  
    41   template<typename T> 
    42   void reduce_sum(const T * buffer, T* recvbuf, int count) 
    43   { 
    44     transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>()); 
    45   } 
    46  
    47  
    48   int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    49   { 
    50     valid_op(op); 
    51  
    52     int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    53     int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    54     int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    55      
     28 
     29  int MPI_Scan_local2(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
     30  { 
     31    if(datatype == MPI_INT) 
     32    { 
     33      return MPI_Scan_local_int(sendbuf, recvbuf, count, op, comm); 
     34    } 
     35    else if(datatype == MPI_FLOAT) 
     36    { 
     37      return MPI_Scan_local_float(sendbuf, recvbuf, count, op, comm); 
     38    } 
     39    else if(datatype == MPI_DOUBLE) 
     40    { 
     41      return MPI_Scan_local_double(sendbuf, recvbuf, count, op, comm); 
     42    } 
     43    else if(datatype == MPI_LONG) 
     44    { 
     45      return MPI_Scan_local_long(sendbuf, recvbuf, count, op, comm); 
     46    } 
     47    else if(datatype == MPI_UNSIGNED_LONG) 
     48    { 
     49      return MPI_Scan_local_ulong(sendbuf, recvbuf, count, op, comm); 
     50    } 
     51    else if(datatype == MPI_CHAR) 
     52    { 
     53      return MPI_Scan_local_char(sendbuf, recvbuf, count, op, comm); 
     54    } 
     55    else 
     56    { 
     57      printf("MPI_Scan Datatype not supported!\n"); 
     58      exit(0); 
     59    } 
     60 
     61  } 
     62 
     63 
     64 
     65 
     66  int MPI_Scan_local_int(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     67  { 
     68    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     69    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     70 
     71    int *buffer = comm.my_buffer->buf_int; 
     72    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
     73    int *recv_buf = static_cast<int*>(recvbuf); 
     74 
     75    for(int j=0; j<count; j+=BUFFER_SIZE) 
     76    { 
     77      if(my_rank == 0) 
     78      { 
     79 
     80        #pragma omp critical (write_to_buffer) 
     81        { 
     82          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     83          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
     84          #pragma omp flush 
     85        } 
     86      } 
     87 
     88      MPI_Barrier_local(comm); 
     89 
     90      for(int k=1; k<num_ep; k++) 
     91      { 
     92        #pragma omp critical (write_to_buffer) 
     93        { 
     94          if(my_rank == k) 
     95          { 
     96            #pragma omp flush 
     97            if(op == MPI_SUM) 
     98            { 
     99              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<int>()); 
     100              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     101            } 
     102            else if(op == MPI_MAX) 
     103            { 
     104              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<int>); 
     105              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     106            } 
     107            else if(op == MPI_MIN) 
     108            { 
     109              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<int>); 
     110              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     111            } 
     112            else 
     113            { 
     114              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     115              exit(1); 
     116            } 
     117            #pragma omp flush 
     118          } 
     119        } 
     120 
     121        MPI_Barrier_local(comm); 
     122      } 
     123    } 
     124 
     125  } 
     126 
     127  int MPI_Scan_local_float(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     128  { 
     129    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     130    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     131 
     132    float *buffer = comm.my_buffer->buf_float; 
     133    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
     134    float *recv_buf = static_cast<float*>(recvbuf); 
     135 
     136    for(int j=0; j<count; j+=BUFFER_SIZE) 
     137    { 
     138      if(my_rank == 0) 
     139      { 
     140 
     141        #pragma omp critical (write_to_buffer) 
     142        { 
     143          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     144          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
     145          #pragma omp flush 
     146        } 
     147      } 
     148 
     149      MPI_Barrier_local(comm); 
     150 
     151      for(int k=1; k<num_ep; k++) 
     152      { 
     153        #pragma omp critical (write_to_buffer) 
     154        { 
     155          if(my_rank == k) 
     156          { 
     157            #pragma omp flush 
     158            if(op == MPI_SUM) 
     159            { 
     160              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<float>()); 
     161              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     162            } 
     163            else if(op == MPI_MAX) 
     164            { 
     165              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<float>); 
     166              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     167 
     168            } 
     169            else if(op == MPI_MIN) 
     170            { 
     171              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<float>); 
     172              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     173 
     174            } 
     175            else 
     176            { 
     177              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     178              exit(1); 
     179            } 
     180            #pragma omp flush 
     181          } 
     182        } 
     183 
     184        MPI_Barrier_local(comm); 
     185      } 
     186    } 
     187  } 
     188 
     189  int MPI_Scan_local_double(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     190  { 
     191    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     192    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     193 
     194    double *buffer = comm.my_buffer->buf_double; 
     195    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
     196    double *recv_buf = static_cast<double*>(recvbuf); 
     197 
     198    for(int j=0; j<count; j+=BUFFER_SIZE) 
     199    { 
     200      if(my_rank == 0) 
     201      { 
     202 
     203        #pragma omp critical (write_to_buffer) 
     204        { 
     205          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     206          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
     207          #pragma omp flush 
     208        } 
     209      } 
     210 
     211      MPI_Barrier_local(comm); 
     212 
     213      for(int k=1; k<num_ep; k++) 
     214      { 
     215        #pragma omp critical (write_to_buffer) 
     216        { 
     217          if(my_rank == k) 
     218          { 
     219            #pragma omp flush 
     220            if(op == MPI_SUM) 
     221            { 
     222              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<double>()); 
     223              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     224            } 
     225            else if(op == MPI_MAX) 
     226            { 
     227              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<double>); 
     228              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     229            } 
     230            else if(op == MPI_MIN) 
     231            { 
     232              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<double>); 
     233              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     234            } 
     235            else 
     236            { 
     237              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     238              exit(1); 
     239            } 
     240            #pragma omp flush 
     241          } 
     242        } 
     243 
     244        MPI_Barrier_local(comm); 
     245      } 
     246    } 
     247  } 
     248 
     249  int MPI_Scan_local_long(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     250  { 
     251    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     252    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     253 
     254    long *buffer = comm.my_buffer->buf_long; 
     255    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
     256    long *recv_buf = static_cast<long*>(recvbuf); 
     257 
     258    for(int j=0; j<count; j+=BUFFER_SIZE) 
     259    { 
     260      if(my_rank == 0) 
     261      { 
     262 
     263        #pragma omp critical (write_to_buffer) 
     264        { 
     265          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     266          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
     267          #pragma omp flush 
     268        } 
     269      } 
     270 
     271      MPI_Barrier_local(comm); 
     272 
     273      for(int k=1; k<num_ep; k++) 
     274      { 
     275        #pragma omp critical (write_to_buffer) 
     276        { 
     277          if(my_rank == k) 
     278          { 
     279            #pragma omp flush 
     280            if(op == MPI_SUM) 
     281            { 
     282              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<long>()); 
     283              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     284            } 
     285            else if(op == MPI_MAX) 
     286            { 
     287              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<long>); 
     288              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     289            } 
     290            else if(op == MPI_MIN) 
     291            { 
     292              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<long>); 
     293              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     294            } 
     295            else 
     296            { 
     297              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     298              exit(1); 
     299            } 
     300            #pragma omp flush 
     301          } 
     302        } 
     303 
     304        MPI_Barrier_local(comm); 
     305      } 
     306    } 
     307  } 
     308 
     309  int MPI_Scan_local_ulong(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     310  { 
     311    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     312    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     313 
     314    unsigned long *buffer = comm.my_buffer->buf_ulong; 
     315    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
     316    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 
     317 
     318    for(int j=0; j<count; j+=BUFFER_SIZE) 
     319    { 
     320      if(my_rank == 0) 
     321      { 
     322 
     323        #pragma omp critical (write_to_buffer) 
     324        { 
     325          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     326          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
     327          #pragma omp flush 
     328        } 
     329      } 
     330 
     331      MPI_Barrier_local(comm); 
     332 
     333      for(int k=1; k<num_ep; k++) 
     334      { 
     335        #pragma omp critical (write_to_buffer) 
     336        { 
     337          if(my_rank == k) 
     338          { 
     339            #pragma omp flush 
     340            if(op == MPI_SUM) 
     341            { 
     342              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<unsigned long>()); 
     343              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     344            } 
     345            else if(op == MPI_MAX) 
     346            { 
     347              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<unsigned long>); 
     348              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     349            } 
     350            else if(op == MPI_MIN) 
     351            { 
     352              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<unsigned long>); 
     353              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     354            } 
     355            else 
     356            { 
     357              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     358              exit(1); 
     359            } 
     360            #pragma omp flush 
     361          } 
     362        } 
     363 
     364        MPI_Barrier_local(comm); 
     365      } 
     366    } 
     367  } 
     368 
     369  int MPI_Scan_local_char(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     370  { 
     371    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     372    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     373 
     374    char *buffer = comm.my_buffer->buf_char; 
     375    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
     376    char *recv_buf = static_cast<char*>(recvbuf); 
     377 
     378    for(int j=0; j<count; j+=BUFFER_SIZE) 
     379    { 
     380      if(my_rank == 0) 
     381      { 
     382 
     383        #pragma omp critical (write_to_buffer) 
     384        { 
     385          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     386          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), recv_buf+j); 
     387          #pragma omp flush 
     388        } 
     389      } 
     390 
     391      MPI_Barrier_local(comm); 
     392 
     393      for(int k=1; k<num_ep; k++) 
     394      { 
     395        #pragma omp critical (write_to_buffer) 
     396        { 
     397          if(my_rank == k) 
     398          { 
     399            #pragma omp flush 
     400            if(op == MPI_SUM) 
     401            { 
     402              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<char>()); 
     403              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     404            } 
     405            else if(op == MPI_MAX) 
     406            { 
     407              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<char>); 
     408              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     409            } 
     410            else if(op == MPI_MIN) 
     411            { 
     412              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<char>); 
     413              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     414            } 
     415            else 
     416            { 
     417              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     418              exit(1); 
     419            } 
     420            #pragma omp flush 
     421          } 
     422        } 
     423 
     424        MPI_Barrier_local(comm); 
     425      } 
     426    } 
     427  } 
     428 
     429 
     430  int MPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
     431  { 
     432    if(!comm.is_ep) 
     433    { 
     434 
     435      ::MPI_Scan(sendbuf, recvbuf, count, static_cast< ::MPI_Datatype>(datatype), 
     436                 static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     437      return 0; 
     438    } 
     439 
     440    if(!comm.mpi_comm) return 0; 
     441 
     442    int ep_rank, ep_rank_loc, mpi_rank; 
     443    int ep_size, num_ep, mpi_size; 
     444 
     445 
     446    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     447    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     448    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     449    ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     450    num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     451    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     452 
    56453 
    57454    ::MPI_Aint datasize, lb; 
    58     ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
    59  
    60     if(ep_rank_loc == 0 && mpi_rank != 0) 
    61     { 
    62       if(op == MPI_SUM) 
    63       { 
    64         if(datatype == MPI_INT && datasize == sizeof(int)) 
    65           reduce_sum<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);     
    66            
    67         else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
    68           reduce_sum<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);     
    69               
    70         else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
    71           reduce_sum<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count); 
    72        
    73         else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
    74           reduce_sum<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count); 
    75        
    76         else if(datatype == MPI_LONG && datasize == sizeof(long)) 
    77           reduce_sum<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count); 
    78              
    79         else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
    80           reduce_sum<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);     
    81              
    82         else printf("datatype Error\n"); 
    83       } 
    84  
    85       else if(op == MPI_MAX) 
    86       { 
    87         if(datatype == MPI_INT && datasize == sizeof(int)) 
    88           reduce_max<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);     
    89            
    90         else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
    91           reduce_max<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);     
    92               
    93         else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
    94           reduce_max<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count); 
    95        
    96         else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
    97           reduce_max<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count); 
    98        
    99         else if(datatype == MPI_LONG && datasize == sizeof(long)) 
    100           reduce_max<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count); 
    101              
    102         else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
    103           reduce_max<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);     
    104              
    105         else printf("datatype Error\n"); 
    106       } 
    107  
    108       else //(op == MPI_MIN) 
    109       { 
    110         if(datatype == MPI_INT && datasize == sizeof(int)) 
    111           reduce_min<int>(static_cast<int*>(const_cast<void*>(sendbuf)), static_cast<int*>(recvbuf), count);     
    112            
    113         else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
    114           reduce_min<float>(static_cast<float*>(const_cast<void*>(sendbuf)), static_cast<float*>(recvbuf), count);     
    115               
    116         else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
    117           reduce_min<double>(static_cast<double*>(const_cast<void*>(sendbuf)), static_cast<double*>(recvbuf), count); 
    118        
    119         else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
    120           reduce_min<char>(static_cast<char*>(const_cast<void*>(sendbuf)), static_cast<char*>(recvbuf), count); 
    121        
    122         else if(datatype == MPI_LONG && datasize == sizeof(long)) 
    123           reduce_min<long>(static_cast<long*>(const_cast<void*>(sendbuf)), static_cast<long*>(recvbuf), count); 
    124              
    125         else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
    126           reduce_min<unsigned long>(static_cast<unsigned long*>(const_cast<void*>(sendbuf)), static_cast<unsigned long*>(recvbuf), count);     
    127              
    128         else printf("datatype Error\n"); 
    129       } 
    130  
    131       comm.my_buffer->void_buffer[0] = recvbuf; 
    132     } 
    133     else 
    134     { 
    135       comm.my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf);   
    136       memcpy(recvbuf, sendbuf, datasize*count); 
    137     }  
    138        
    139  
    140  
    141     MPI_Barrier_local(comm); 
    142  
    143     memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize*count); 
    144  
    145  
    146     if(op == MPI_SUM) 
    147     { 
    148       if(datatype == MPI_INT && datasize == sizeof(int)) 
    149       { 
    150         for(int i=1; i<ep_rank_loc+1; i++) 
    151           reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
    152       } 
    153       
    154       else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
    155       { 
    156         for(int i=1; i<ep_rank_loc+1; i++) 
    157           reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
    158       } 
    159        
    160  
    161       else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
    162       { 
    163         for(int i=1; i<ep_rank_loc+1; i++) 
    164           reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
    165       } 
    166  
    167       else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
    168       { 
    169         for(int i=1; i<ep_rank_loc+1; i++) 
    170           reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
    171       } 
    172  
    173       else if(datatype == MPI_LONG && datasize == sizeof(long)) 
    174       { 
    175         for(int i=1; i<ep_rank_loc+1; i++) 
    176           reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
    177       } 
    178  
    179       else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
    180       { 
    181         for(int i=1; i<ep_rank_loc+1; i++) 
    182           reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
    183       } 
    184  
    185       else printf("datatype Error\n"); 
    186  
    187        
    188     } 
    189  
    190     else if(op == MPI_MAX) 
    191     { 
    192       if(datatype == MPI_INT && datasize == sizeof(int)) 
    193         for(int i=1; i<ep_rank_loc+1; i++) 
    194           reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
    195  
    196       else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
    197         for(int i=1; i<ep_rank_loc+1; i++) 
    198           reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
    199  
    200       else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
    201         for(int i=1; i<ep_rank_loc+1; i++) 
    202           reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
    203  
    204       else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
    205         for(int i=1; i<ep_rank_loc+1; i++) 
    206           reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
    207  
    208       else if(datatype == MPI_LONG && datasize == sizeof(long)) 
    209         for(int i=1; i<ep_rank_loc+1; i++) 
    210           reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
    211  
    212       else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
    213         for(int i=1; i<ep_rank_loc+1; i++) 
    214           reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
    215       
    216       else printf("datatype Error\n"); 
    217     } 
    218  
    219     else //if(op == MPI_MIN) 
    220     { 
    221       if(datatype == MPI_INT && datasize == sizeof(int)) 
    222         for(int i=1; i<ep_rank_loc+1; i++) 
    223           reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
    224  
    225       else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
    226         for(int i=1; i<ep_rank_loc+1; i++) 
    227           reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
    228  
    229       else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
    230         for(int i=1; i<ep_rank_loc+1; i++) 
    231           reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
    232  
    233       else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
    234         for(int i=1; i<ep_rank_loc+1; i++) 
    235           reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
    236  
    237       else if(datatype == MPI_LONG && datasize == sizeof(long)) 
    238         for(int i=1; i<ep_rank_loc+1; i++) 
    239           reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
    240  
    241       else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
    242         for(int i=1; i<ep_rank_loc+1; i++) 
    243           reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
    244  
    245       else printf("datatype Error\n"); 
    246     } 
    247  
    248     MPI_Barrier_local(comm); 
    249  
    250   } 
    251  
    252  
    253   int MPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    254   { 
    255     if(!comm.is_ep) 
    256     { 
    257       return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm)); 
    258     } 
    259      
    260     valid_type(datatype); 
    261  
    262     int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    263     int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    264     int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    265     int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    266     int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    267     int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    268  
    269     ::MPI_Aint datasize, lb; 
    270     ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
    271      
    272     void* tmp_sendbuf; 
    273     tmp_sendbuf = new void*[datasize * count]; 
    274  
    275     int my_src = 0; 
    276     int my_dst = ep_rank; 
    277  
    278     std::vector<int> my_map(mpi_size, 0); 
    279  
    280     for(int i=0; i<comm.rank_map->size(); i++) my_map[comm.rank_map->at(i).second]++; 
    281  
    282     for(int i=0; i<mpi_rank; i++) my_src += my_map[i]; 
    283     my_src += ep_rank_loc; 
    284  
    285       
    286     for(int i=0; i<mpi_size; i++) 
    287     { 
    288       if(my_dst < my_map[i]) 
    289       { 
    290         my_dst = get_ep_rank(comm, my_dst, i);  
    291         break; 
    292       } 
    293       else 
    294         my_dst -= my_map[i]; 
    295     } 
    296  
    297     //printf("ID = %d : send to %d, recv from %d\n", ep_rank, my_dst, my_src); 
    298     MPI_Barrier(comm); 
    299  
    300     if(my_dst == ep_rank && my_src == ep_rank) memcpy(tmp_sendbuf, sendbuf, datasize*count); 
    301  
    302     if(ep_rank != my_dst)  
    303     { 
    304       MPI_Request request[2]; 
    305       MPI_Status status[2]; 
    306  
    307       MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]); 
    308      
    309       MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]); 
    310      
    311       MPI_Waitall(2, request, status); 
    312     } 
    313      
    314  
    315     void* tmp_recvbuf; 
    316     tmp_recvbuf = new void*[datasize * count];     
    317  
    318     MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm); 
     455 
     456    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
     457 
     458    void* local_scan_recvbuf; 
     459    local_scan_recvbuf = new void*[datasize * count]; 
     460 
     461 
     462    // local scan 
     463    MPI_Scan_local2(sendbuf, recvbuf, count, datatype, op, comm); 
     464 
     465//     MPI_scan 
     466    void* local_sum; 
     467    void* mpi_scan_recvbuf; 
     468 
     469 
     470    mpi_scan_recvbuf = new void*[datasize*count]; 
    319471 
    320472    if(ep_rank_loc == 0) 
    321       ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm)); 
    322  
    323     //printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]); 
    324      
    325     MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm); 
    326  
    327     // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]); 
    328  
    329  
    330  
    331     if(ep_rank != my_src)  
    332     { 
    333       MPI_Request request[2]; 
    334       MPI_Status status[2]; 
    335  
    336       MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]); 
    337      
    338       MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]); 
    339      
    340       MPI_Waitall(2, request, status); 
    341     } 
    342  
    343     else memcpy(recvbuf, tmp_recvbuf, datasize*count); 
    344      
    345  
    346  
    347  
    348     delete[] tmp_sendbuf; 
    349     delete[] tmp_recvbuf; 
     473    { 
     474      local_sum = new void*[datasize*count]; 
     475    } 
     476 
     477 
     478    MPI_Reduce_local2(sendbuf, local_sum, count, datatype, op, comm); 
     479 
     480    if(ep_rank_loc == 0) 
     481    { 
     482      ::MPI_Exscan(local_sum, mpi_scan_recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     483    } 
     484 
     485 
     486    if(mpi_rank > 0) 
     487    { 
     488      MPI_Bcast_local2(mpi_scan_recvbuf, count, datatype, comm); 
     489    } 
     490 
     491 
     492    if(datatype == MPI_DOUBLE) 
     493    { 
     494      double* sum_buf = static_cast<double*>(mpi_scan_recvbuf); 
     495      double* recv_buf = static_cast<double*>(recvbuf); 
     496 
     497      if(mpi_rank != 0) 
     498      { 
     499        if(op == MPI_SUM) 
     500        { 
     501          for(int i=0; i<count; i++) 
     502          { 
     503            recv_buf[i] += sum_buf[i]; 
     504          } 
     505        } 
     506        else if (op == MPI_MAX) 
     507        { 
     508          for(int i=0; i<count; i++) 
     509          { 
     510            recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     511          } 
     512        } 
     513        else if(op == MPI_MIN) 
     514        { 
     515          for(int i=0; i<count; i++) 
     516          { 
     517            recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     518          } 
     519        } 
     520        else 
     521        { 
     522          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     523          exit(1); 
     524        } 
     525      } 
     526 
     527      delete[] static_cast<double*>(mpi_scan_recvbuf); 
     528      if(ep_rank_loc == 0) 
     529      { 
     530        delete[] static_cast<double*>(local_sum); 
     531      } 
     532    } 
     533 
     534    else if(datatype == MPI_FLOAT) 
     535    { 
     536      float* sum_buf = static_cast<float*>(mpi_scan_recvbuf); 
     537      float* recv_buf = static_cast<float*>(recvbuf); 
     538 
     539      if(mpi_rank != 0) 
     540      { 
     541        if(op == MPI_SUM) 
     542        { 
     543          for(int i=0; i<count; i++) 
     544          { 
     545            recv_buf[i] += sum_buf[i]; 
     546          } 
     547        } 
     548        else if (op == MPI_MAX) 
     549        { 
     550          for(int i=0; i<count; i++) 
     551          { 
     552            recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     553          } 
     554        } 
     555        else if(op == MPI_MIN) 
     556        { 
     557          for(int i=0; i<count; i++) 
     558          { 
     559            recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     560          } 
     561        } 
     562        else 
     563        { 
     564          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     565          exit(1); 
     566        } 
     567      } 
     568 
     569      delete[] static_cast<float*>(mpi_scan_recvbuf); 
     570      if(ep_rank_loc == 0) 
     571      { 
     572        delete[] static_cast<float*>(local_sum); 
     573      } 
     574    } 
     575 
     576    else if(datatype == MPI_INT) 
     577    { 
     578      int* sum_buf = static_cast<int*>(mpi_scan_recvbuf); 
     579      int* recv_buf = static_cast<int*>(recvbuf); 
     580 
     581      if(mpi_rank != 0) 
     582      { 
     583        if(op == MPI_SUM) 
     584        { 
     585          for(int i=0; i<count; i++) 
     586          { 
     587            recv_buf[i] += sum_buf[i]; 
     588          } 
     589        } 
     590        else if (op == MPI_MAX) 
     591        { 
     592          for(int i=0; i<count; i++) 
     593          { 
     594            recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     595          } 
     596        } 
     597        else if(op == MPI_MIN) 
     598        { 
     599          for(int i=0; i<count; i++) 
     600          { 
     601            recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     602          } 
     603        } 
     604        else 
     605        { 
     606          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     607          exit(1); 
     608        } 
     609      } 
     610 
     611      delete[] static_cast<int*>(mpi_scan_recvbuf); 
     612      if(ep_rank_loc == 0) 
     613      { 
     614        delete[] static_cast<int*>(local_sum); 
     615      } 
     616    } 
     617 
     618    else if(datatype == MPI_LONG) 
     619    { 
     620      long* sum_buf = static_cast<long*>(mpi_scan_recvbuf); 
     621      long* recv_buf = static_cast<long*>(recvbuf); 
     622 
     623      if(mpi_rank != 0) 
     624      { 
     625        if(op == MPI_SUM) 
     626        { 
     627          for(int i=0; i<count; i++) 
     628          { 
     629            recv_buf[i] += sum_buf[i]; 
     630          } 
     631        } 
     632        else if (op == MPI_MAX) 
     633        { 
     634          for(int i=0; i<count; i++) 
     635          { 
     636            recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     637          } 
     638        } 
     639        else if(op == MPI_MIN) 
     640        { 
     641          for(int i=0; i<count; i++) 
     642          { 
     643            recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     644          } 
     645        } 
     646        else 
     647        { 
     648          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     649          exit(1); 
     650        } 
     651      } 
     652 
     653      delete[] static_cast<long*>(mpi_scan_recvbuf); 
     654      if(ep_rank_loc == 0) 
     655      { 
     656        delete[] static_cast<long*>(local_sum); 
     657      } 
     658    } 
     659 
     660    else if(datatype == MPI_UNSIGNED_LONG) 
     661    { 
     662      unsigned long* sum_buf = static_cast<unsigned long*>(mpi_scan_recvbuf); 
     663      unsigned long* recv_buf = static_cast<unsigned long*>(recvbuf); 
     664 
     665      if(mpi_rank != 0) 
     666      { 
     667        if(op == MPI_SUM) 
     668        { 
     669          for(int i=0; i<count; i++) 
     670          { 
     671            recv_buf[i] += sum_buf[i]; 
     672          } 
     673        } 
     674        else if (op == MPI_MAX) 
     675        { 
     676          for(int i=0; i<count; i++) 
     677          { 
     678            recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     679          } 
     680        } 
     681        else if(op == MPI_MIN) 
     682        { 
     683          for(int i=0; i<count; i++) 
     684          { 
     685            recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     686          } 
     687        } 
     688        else 
     689        { 
     690          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     691          exit(1); 
     692        } 
     693      } 
     694 
     695      delete[] static_cast<unsigned long*>(mpi_scan_recvbuf); 
     696      if(ep_rank_loc == 0) 
     697      { 
     698        delete[] static_cast<unsigned long*>(local_sum); 
     699      } 
     700    } 
     701 
     702    else if(datatype == MPI_CHAR) 
     703    { 
     704      char* sum_buf = static_cast<char*>(mpi_scan_recvbuf); 
     705      char* recv_buf = static_cast<char*>(recvbuf); 
     706 
     707      if(mpi_rank != 0) 
     708      { 
     709        if(op == MPI_SUM) 
     710        { 
     711          for(int i=0; i<count; i++) 
     712          { 
     713            recv_buf[i] += sum_buf[i]; 
     714          } 
     715        } 
     716        else if (op == MPI_MAX) 
     717        { 
     718          for(int i=0; i<count; i++) 
     719          { 
     720            recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     721          } 
     722        } 
     723        else if(op == MPI_MIN) 
     724        { 
     725          for(int i=0; i<count; i++) 
     726          { 
     727            recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     728          } 
     729        } 
     730        else 
     731        { 
     732          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     733          exit(1); 
     734        } 
     735      } 
     736 
     737      delete[] static_cast<char*>(mpi_scan_recvbuf); 
     738      if(ep_rank_loc == 0) 
     739      { 
     740        delete[] static_cast<char*>(local_sum); 
     741      } 
     742    } 
     743 
    350744 
    351745  } 
Note: See TracChangeset for help on using the changeset viewer.