Ignore:
Timestamp:
10/04/17 17:02:13 (7 years ago)
Author:
yushan
Message:

EP update part 2

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_exscan.cpp

    r1287 r1289  
    99#include <mpi.h> 
    1010#include "ep_declaration.hpp" 
    11 #include "ep_mpi.hpp" 
    1211 
    1312using namespace std; 
     
    2726  } 
    2827 
    29   template<typename T> 
    30   void reduce_max(const T * buffer, T* recvbuf, int count) 
    31   { 
    32     transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>); 
    33   } 
    34  
    35   template<typename T> 
    36   void reduce_min(const T * buffer, T* recvbuf, int count) 
    37   { 
    38     transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>); 
    39   } 
    40  
    41   template<typename T> 
    42   void reduce_sum(const T * buffer, T* recvbuf, int count) 
    43   { 
    44     transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>()); 
    45   } 
    46  
    47  
    48   int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    49   { 
    50     valid_op(op); 
    51  
    52     int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    53     int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    54     int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     28  int MPI_Exscan_local2(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
     29  { 
     30    if(datatype == MPI_INT) 
     31    { 
     32      return MPI_Exscan_local_int(sendbuf, recvbuf, count, op, comm); 
     33    } 
     34    else if(datatype == MPI_FLOAT) 
     35    { 
     36      return MPI_Exscan_local_float(sendbuf, recvbuf, count, op, comm); 
     37    } 
     38    else if(datatype == MPI_DOUBLE) 
     39    { 
     40      return MPI_Exscan_local_double(sendbuf, recvbuf, count, op, comm); 
     41    } 
     42    else if(datatype == MPI_LONG) 
     43    { 
     44      return MPI_Exscan_local_long(sendbuf, recvbuf, count, op, comm); 
     45    } 
     46    else if(datatype == MPI_UNSIGNED_LONG) 
     47    { 
     48      return MPI_Exscan_local_ulong(sendbuf, recvbuf, count, op, comm); 
     49    } 
     50    else if(datatype == MPI_CHAR) 
     51    { 
     52      return MPI_Exscan_local_char(sendbuf, recvbuf, count, op, comm); 
     53    } 
     54    else 
     55    { 
     56      printf("MPI_Exscan Datatype not supported!\n"); 
     57      exit(0); 
     58    } 
     59  } 
     60 
     61 
     62 
     63 
     64  int MPI_Exscan_local_int(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     65  { 
     66    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     67    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     68 
     69    int *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_int; 
     70    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
     71    int *recv_buf = static_cast<int*>(recvbuf); 
     72 
     73    for(int j=0; j<count; j+=BUFFER_SIZE) 
     74    { 
     75 
     76      if(my_rank == 0) 
     77      { 
     78 
     79        #pragma omp critical (write_to_buffer) 
     80        { 
     81          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     82          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
     83          #pragma omp flush 
     84        } 
     85      } 
     86 
     87      MPI_Barrier_local(comm); 
     88 
     89      for(int k=1; k<num_ep; k++) 
     90      { 
     91        #pragma omp critical (write_to_buffer) 
     92        { 
     93          if(my_rank == k) 
     94          { 
     95            #pragma omp flush 
     96            if(op == MPI_SUM) 
     97            { 
     98              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     99              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<int>()); 
     100 
     101            } 
     102            else if(op == MPI_MAX) 
     103            { 
     104              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     105              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<int>); 
     106            } 
     107            else if(op == MPI_MIN) 
     108            { 
     109              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     110              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<int>); 
     111            } 
     112            else 
     113            { 
     114              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     115              exit(1); 
     116            } 
     117            #pragma omp flush 
     118          } 
     119        } 
     120 
     121        MPI_Barrier_local(comm); 
     122      } 
     123    } 
     124 
     125  } 
     126 
     127  int MPI_Exscan_local_float(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     128  { 
     129    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     130    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     131 
     132    float *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_float; 
     133    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
     134    float *recv_buf = static_cast<float*>(recvbuf); 
     135 
     136    for(int j=0; j<count; j+=BUFFER_SIZE) 
     137    { 
     138      if(my_rank == 0) 
     139      { 
     140 
     141        #pragma omp critical (write_to_buffer) 
     142        { 
     143          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     144          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
     145          #pragma omp flush 
     146        } 
     147      } 
     148 
     149      MPI_Barrier_local(comm); 
     150 
     151      for(int k=1; k<num_ep; k++) 
     152      { 
     153        #pragma omp critical (write_to_buffer) 
     154        { 
     155          if(my_rank == k) 
     156          { 
     157            #pragma omp flush 
     158            if(op == MPI_SUM) 
     159            { 
     160              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     161              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<float>()); 
     162            } 
     163            else if(op == MPI_MAX) 
     164            { 
     165              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     166              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<float>); 
     167            } 
     168            else if(op == MPI_MIN) 
     169            { 
     170              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     171              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<float>); 
     172            } 
     173            else 
     174            { 
     175              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     176              exit(1); 
     177            } 
     178            #pragma omp flush 
     179          } 
     180        } 
     181 
     182        MPI_Barrier_local(comm); 
     183      } 
     184    } 
     185  } 
     186 
     187  int MPI_Exscan_local_double(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     188  { 
     189 
     190    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     191    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     192 
     193    double *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_double; 
     194    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
     195    double *recv_buf = static_cast<double*>(recvbuf); 
     196 
     197    for(int j=0; j<count; j+=BUFFER_SIZE) 
     198    { 
     199      if(my_rank == 0) 
     200      { 
     201 
     202        #pragma omp critical (write_to_buffer) 
     203        { 
     204          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     205          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
     206          #pragma omp flush 
     207        } 
     208      } 
     209 
     210      MPI_Barrier_local(comm); 
     211 
     212      for(int k=1; k<num_ep; k++) 
     213      { 
     214        #pragma omp critical (write_to_buffer) 
     215        { 
     216          if(my_rank == k) 
     217          { 
     218            #pragma omp flush 
     219            if(op == MPI_SUM) 
     220            { 
     221              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     222              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<double>()); 
     223            } 
     224            else if(op == MPI_MAX) 
     225            { 
     226              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     227              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<double>); 
     228            } 
     229            else if(op == MPI_MIN) 
     230            { 
     231              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     232              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<double>); 
     233            } 
     234            else 
     235            { 
     236              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     237              exit(1); 
     238            } 
     239            #pragma omp flush 
     240          } 
     241        } 
     242 
     243        MPI_Barrier_local(comm); 
     244      } 
     245    } 
     246  } 
     247 
     248  int MPI_Exscan_local_long(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     249  { 
     250 
     251    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     252    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     253 
     254    long *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_long; 
     255    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
     256    long *recv_buf = static_cast<long*>(recvbuf); 
     257 
     258    for(int j=0; j<count; j+=BUFFER_SIZE) 
     259    { 
     260      if(my_rank == 0) 
     261      { 
     262 
     263        #pragma omp critical (write_to_buffer) 
     264        { 
     265          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     266          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
     267          #pragma omp flush 
     268        } 
     269      } 
     270 
     271      MPI_Barrier_local(comm); 
     272 
     273      for(int k=1; k<num_ep; k++) 
     274      { 
     275        #pragma omp critical (write_to_buffer) 
     276        { 
     277          if(my_rank == k) 
     278          { 
     279            #pragma omp flush 
     280            if(op == MPI_SUM) 
     281            { 
     282              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     283              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<long>()); 
     284            } 
     285            else if(op == MPI_MAX) 
     286            { 
     287              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     288              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<long>); 
     289            } 
     290            else if(op == MPI_MIN) 
     291            { 
     292              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     293              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<long>); 
     294            } 
     295            else 
     296            { 
     297              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     298              exit(1); 
     299            } 
     300            #pragma omp flush 
     301          } 
     302        } 
     303 
     304        MPI_Barrier_local(comm); 
     305      } 
     306    } 
     307  } 
     308 
     309  int MPI_Exscan_local_ulong(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     310  { 
     311 
     312    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     313    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     314 
     315    unsigned long *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_ulong; 
     316    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
     317    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 
     318 
     319    for(int j=0; j<count; j+=BUFFER_SIZE) 
     320    { 
     321      if(my_rank == 0) 
     322      { 
     323 
     324        #pragma omp critical (write_to_buffer) 
     325        { 
     326          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     327          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
     328          #pragma omp flush 
     329        } 
     330      } 
     331 
     332      MPI_Barrier_local(comm); 
     333 
     334      for(int k=1; k<num_ep; k++) 
     335      { 
     336        #pragma omp critical (write_to_buffer) 
     337        { 
     338          if(my_rank == k) 
     339          { 
     340            #pragma omp flush 
     341            if(op == MPI_SUM) 
     342            { 
     343              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     344              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<unsigned long>()); 
     345            } 
     346            else if(op == MPI_MAX) 
     347            { 
     348              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     349              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<unsigned long>); 
     350            } 
     351            else if(op == MPI_MIN) 
     352            { 
     353              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     354              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<unsigned long>); 
     355            } 
     356            else 
     357            { 
     358              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     359              exit(1); 
     360            } 
     361            #pragma omp flush 
     362          } 
     363        } 
     364 
     365        MPI_Barrier_local(comm); 
     366      } 
     367    } 
     368  } 
     369 
     370  int MPI_Exscan_local_char(const void *sendbuf, void *recvbuf, int count, MPI_Op op, MPI_Comm comm) 
     371  { 
     372 
     373    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     374    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     375 
     376    char *buffer = comm.ep_comm_ptr->comm_list->my_buffer->buf_char; 
     377    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
     378    char *recv_buf = static_cast<char*>(recvbuf); 
     379 
     380    for(int j=0; j<count; j+=BUFFER_SIZE) 
     381    { 
     382      if(my_rank == 0) 
     383      { 
     384 
     385        #pragma omp critical (write_to_buffer) 
     386        { 
     387          copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 
     388          fill(recv_buf+j, recv_buf+j+min(BUFFER_SIZE, count-j), MPI_UNDEFINED); 
     389          #pragma omp flush 
     390        } 
     391      } 
     392 
     393      MPI_Barrier_local(comm); 
     394 
     395      for(int k=1; k<num_ep; k++) 
     396      { 
     397        #pragma omp critical (write_to_buffer) 
     398        { 
     399          if(my_rank == k) 
     400          { 
     401            #pragma omp flush 
     402            if(op == MPI_SUM) 
     403            { 
     404              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     405              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, std::plus<char>()); 
     406            } 
     407            else if(op == MPI_MAX) 
     408            { 
     409              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     410              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, max_op<char>); 
     411            } 
     412            else if(op == MPI_MIN) 
     413            { 
     414              copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     415              transform(buffer, buffer+min(BUFFER_SIZE, count-j), send_buf+j, buffer, min_op<char>); 
     416            } 
     417            else 
     418            { 
     419              printf("Supported operation: MPI_SUM, MPI_MAX, MPI_MIN\n"); 
     420              exit(1); 
     421            } 
     422            #pragma omp flush 
     423          } 
     424        } 
     425 
     426        MPI_Barrier_local(comm); 
     427      } 
     428    } 
     429  } 
     430 
     431 
     432  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
     433  { 
     434 
     435    if(!comm.is_ep) 
     436    { 
     437      ::MPI_Exscan(const_cast<void*>(sendbuf), recvbuf, count, static_cast< ::MPI_Datatype>(datatype), 
     438                   static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     439      return 0; 
     440    } 
     441    if(!comm.mpi_comm) return 0; 
     442 
     443    int ep_rank, ep_rank_loc, mpi_rank; 
     444    int ep_size, num_ep, mpi_size; 
     445 
     446    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     447    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     448    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     449    ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     450    num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     451    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     452 
     453 
     454 
     455    ::MPI_Aint datasize, lb; 
    55456     
    56  
    57     ::MPI_Aint datasize, lb; 
    58     ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
    59  
    60     if(ep_rank_loc == 0 && mpi_rank != 0) 
    61     { 
    62       comm.my_buffer->void_buffer[0] = recvbuf; 
    63     } 
    64     if(ep_rank_loc == 0 && mpi_rank == 0) 
    65     { 
    66       comm.my_buffer->void_buffer[0] = const_cast<void*>(sendbuf);   
    67     }  
    68        
    69  
    70     MPI_Barrier_local(comm); 
    71  
    72     memcpy(recvbuf, comm.my_buffer->void_buffer[0], datasize*count); 
    73  
    74     MPI_Barrier_local(comm); 
    75  
    76     comm.my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf);   
    77      
    78     MPI_Barrier_local(comm); 
    79  
    80     if(op == MPI_SUM) 
    81     { 
    82       if(datatype == MPI_INT && datasize == sizeof(int)) 
    83       { 
    84         for(int i=0; i<ep_rank_loc; i++) 
    85           reduce_sum<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
    86       } 
    87       
    88       else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
    89       { 
    90         for(int i=0; i<ep_rank_loc; i++) 
    91           reduce_sum<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
    92       } 
    93        
    94  
    95       else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
    96       { 
    97         for(int i=0; i<ep_rank_loc; i++) 
    98           reduce_sum<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
    99       } 
    100  
    101       else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
    102       { 
    103         for(int i=0; i<ep_rank_loc; i++) 
    104           reduce_sum<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
    105       } 
    106  
    107       else if(datatype == MPI_LONG && datasize == sizeof(long)) 
    108       { 
    109         for(int i=0; i<ep_rank_loc; i++) 
    110           reduce_sum<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
    111       } 
    112  
    113       else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
    114       { 
    115         for(int i=0; i<ep_rank_loc; i++) 
    116           reduce_sum<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
    117       } 
    118  
    119       else printf("datatype Error\n"); 
    120  
    121        
    122     } 
    123  
    124     else if(op == MPI_MAX) 
    125     { 
    126       if(datatype == MPI_INT && datasize == sizeof(int)) 
    127         for(int i=0; i<ep_rank_loc; i++) 
    128           reduce_max<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
    129  
    130       else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
    131         for(int i=0; i<ep_rank_loc; i++) 
    132           reduce_max<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
    133  
    134       else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
    135         for(int i=0; i<ep_rank_loc; i++) 
    136           reduce_max<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
    137  
    138       else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
    139         for(int i=0; i<ep_rank_loc; i++) 
    140           reduce_max<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
    141  
    142       else if(datatype == MPI_LONG && datasize == sizeof(long)) 
    143         for(int i=0; i<ep_rank_loc; i++) 
    144           reduce_max<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
    145  
    146       else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
    147         for(int i=0; i<ep_rank_loc; i++) 
    148           reduce_max<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
    149       
    150       else printf("datatype Error\n"); 
    151     } 
    152  
    153     else //if(op == MPI_MIN) 
    154     { 
    155       if(datatype == MPI_INT && datasize == sizeof(int)) 
    156         for(int i=0; i<ep_rank_loc; i++) 
    157           reduce_min<int>(static_cast<int*>(comm.my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);     
    158  
    159       else if(datatype == MPI_FLOAT && datasize == sizeof(float)) 
    160         for(int i=0; i<ep_rank_loc; i++) 
    161           reduce_min<float>(static_cast<float*>(comm.my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);     
    162  
    163       else if(datatype == MPI_DOUBLE && datasize == sizeof(double)) 
    164         for(int i=0; i<ep_rank_loc; i++) 
    165           reduce_min<double>(static_cast<double*>(comm.my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count); 
    166  
    167       else if(datatype == MPI_CHAR && datasize == sizeof(char)) 
    168         for(int i=0; i<ep_rank_loc; i++) 
    169           reduce_min<char>(static_cast<char*>(comm.my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count); 
    170  
    171       else if(datatype == MPI_LONG && datasize == sizeof(long)) 
    172         for(int i=0; i<ep_rank_loc; i++) 
    173           reduce_min<long>(static_cast<long*>(comm.my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count); 
    174  
    175       else if(datatype == MPI_UNSIGNED_LONG && datasize == sizeof(unsigned long)) 
    176         for(int i=0; i<ep_rank_loc; i++) 
    177           reduce_min<unsigned long>(static_cast<unsigned long*>(comm.my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);     
    178  
    179       else printf("datatype Error\n"); 
    180     } 
    181  
    182     MPI_Barrier_local(comm); 
    183  
    184   } 
    185  
    186   int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) 
    187   { 
    188     if(!comm.is_ep) 
    189     { 
    190       return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm)); 
    191     } 
    192      
    193     valid_type(datatype); 
    194  
    195     int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    196     int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    197     int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    198     int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    199     int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    200     int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    201  
    202     ::MPI_Aint datasize, lb; 
    203     ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
    204      
    205     void* tmp_sendbuf; 
    206     tmp_sendbuf = new void*[datasize * count]; 
    207  
    208     int my_src = 0; 
    209     int my_dst = ep_rank; 
    210  
    211     std::vector<int> my_map(mpi_size, 0); 
    212  
    213     for(int i=0; i<comm.rank_map->size(); i++) my_map[comm.rank_map->at(i).second]++; 
    214  
    215     for(int i=0; i<mpi_rank; i++) my_src += my_map[i]; 
    216     my_src += ep_rank_loc; 
    217  
    218       
    219     for(int i=0; i<mpi_size; i++) 
    220     { 
    221       if(my_dst < my_map[i]) 
    222       { 
    223         my_dst = get_ep_rank(comm, my_dst, i);  
    224         break; 
    225       } 
    226       else 
    227         my_dst -= my_map[i]; 
    228     } 
    229  
    230     if(ep_rank != my_dst)  
    231     { 
    232       MPI_Request request[2]; 
    233       MPI_Status status[2]; 
    234  
    235       MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]); 
    236      
    237       MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]); 
    238      
    239       MPI_Waitall(2, request, status); 
    240     } 
    241  
    242     else memcpy(tmp_sendbuf, sendbuf, datasize*count); 
    243      
    244  
    245     void* tmp_recvbuf; 
    246     tmp_recvbuf = new void*[datasize * count];     
    247  
    248     MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm); 
     457    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
     458 
     459    void* local_scan_recvbuf; 
     460    local_scan_recvbuf = new void*[datasize * count]; 
     461 
     462 
     463    // local scan 
     464    MPI_Exscan_local2(sendbuf, recvbuf, count, datatype, op, comm); 
     465 
     466//     MPI_scan 
     467    void* local_sum; 
     468    void* mpi_scan_recvbuf; 
     469 
     470 
     471    mpi_scan_recvbuf = new void*[datasize*count]; 
    249472 
    250473    if(ep_rank_loc == 0) 
    251       ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm)); 
    252  
    253     // printf(" ID=%d : %d  %d \n", ep_rank, static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]); 
    254      
    255     MPI_Exscan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm); 
    256  
    257      // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d  %d \n", ep_rank, static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[1], static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1]); 
    258  
    259  
    260  
    261     if(ep_rank != my_src)  
    262     { 
    263       MPI_Request request[2]; 
    264       MPI_Status status[2]; 
    265  
    266       MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]); 
    267      
    268       MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]); 
    269      
    270       MPI_Waitall(2, request, status); 
    271     } 
    272  
    273     else memcpy(recvbuf, tmp_recvbuf, datasize*count); 
    274      
    275  
    276  
    277  
    278     delete[] tmp_sendbuf; 
    279     delete[] tmp_recvbuf; 
    280  
    281   } 
     474    { 
     475      local_sum = new void*[datasize*count]; 
     476    } 
     477 
     478 
     479    MPI_Reduce_local2(sendbuf, local_sum, count, datatype, op, comm); 
     480 
     481    if(ep_rank_loc == 0) 
     482    { 
     483      ::MPI_Exscan(local_sum, mpi_scan_recvbuf, count, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Op>(op), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     484    } 
     485 
     486 
     487    if(mpi_rank > 0) 
     488    { 
     489      MPI_Bcast_local2(mpi_scan_recvbuf, count, datatype, comm); 
     490    } 
     491 
     492 
     493    if(datatype == MPI_DOUBLE) 
     494    { 
     495      double* sum_buf = static_cast<double*>(mpi_scan_recvbuf); 
     496      double* recv_buf = static_cast<double*>(recvbuf); 
     497 
     498      if(mpi_rank != 0) 
     499      { 
     500        if(op == MPI_SUM) 
     501        { 
     502          if(ep_rank_loc == 0) 
     503          { 
     504            copy(sum_buf, sum_buf+count, recv_buf); 
     505          } 
     506          else 
     507          { 
     508            for(int i=0; i<count; i++) 
     509            { 
     510              recv_buf[i] += sum_buf[i]; 
     511            } 
     512          } 
     513        } 
     514        else if (op == MPI_MAX) 
     515        { 
     516          if(ep_rank_loc == 0) 
     517          { 
     518            copy(sum_buf, sum_buf+count, recv_buf); 
     519          } 
     520          else 
     521          { 
     522            for(int i=0; i<count; i++) 
     523            { 
     524              recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     525            } 
     526          } 
     527        } 
     528        else if(op == MPI_MIN) 
     529        { 
     530          if(ep_rank_loc == 0) 
     531          { 
     532            copy(sum_buf, sum_buf+count, recv_buf); 
     533          } 
     534          else 
     535          { 
     536            for(int i=0; i<count; i++) 
     537            { 
     538              recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     539            } 
     540          } 
     541        } 
     542        else 
     543        { 
     544          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     545          exit(1); 
     546        } 
     547      } 
     548 
     549      delete[] static_cast<double*>(mpi_scan_recvbuf); 
     550      if(ep_rank_loc == 0) 
     551      { 
     552        delete[] static_cast<double*>(local_sum); 
     553      } 
     554    } 
     555 
     556    else if(datatype == MPI_FLOAT) 
     557    { 
     558      float* sum_buf = static_cast<float*>(mpi_scan_recvbuf); 
     559      float* recv_buf = static_cast<float*>(recvbuf); 
     560 
     561      if(mpi_rank != 0) 
     562      { 
     563        if(op == MPI_SUM) 
     564        { 
     565          if(ep_rank_loc == 0) 
     566          { 
     567            copy(sum_buf, sum_buf+count, recv_buf); 
     568          } 
     569          else 
     570          { 
     571            for(int i=0; i<count; i++) 
     572            { 
     573              recv_buf[i] += sum_buf[i]; 
     574            } 
     575          } 
     576        } 
     577        else if (op == MPI_MAX) 
     578        { 
     579          if(ep_rank_loc == 0) 
     580          { 
     581            copy(sum_buf, sum_buf+count, recv_buf); 
     582          } 
     583          else 
     584          { 
     585            for(int i=0; i<count; i++) 
     586            { 
     587              recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     588            } 
     589          } 
     590        } 
     591        else if(op == MPI_MIN) 
     592        { 
     593          if(ep_rank_loc == 0) 
     594          { 
     595            copy(sum_buf, sum_buf+count, recv_buf); 
     596          } 
     597          else 
     598          { 
     599            for(int i=0; i<count; i++) 
     600            { 
     601              recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     602            } 
     603          } 
     604        } 
     605        else 
     606        { 
     607          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     608          exit(1); 
     609        } 
     610      } 
     611 
     612      delete[] static_cast<float*>(mpi_scan_recvbuf); 
     613      if(ep_rank_loc == 0) 
     614      { 
     615        delete[] static_cast<float*>(local_sum); 
     616      } 
     617    } 
     618 
     619    else if(datatype == MPI_INT) 
     620    { 
     621      int* sum_buf = static_cast<int*>(mpi_scan_recvbuf); 
     622      int* recv_buf = static_cast<int*>(recvbuf); 
     623 
     624      if(mpi_rank != 0) 
     625      { 
     626        if(op == MPI_SUM) 
     627        { 
     628          if(ep_rank_loc == 0) 
     629          { 
     630            copy(sum_buf, sum_buf+count, recv_buf); 
     631          } 
     632          else 
     633          { 
     634            for(int i=0; i<count; i++) 
     635            { 
     636              recv_buf[i] += sum_buf[i]; 
     637            } 
     638          } 
     639        } 
     640        else if (op == MPI_MAX) 
     641        { 
     642          if(ep_rank_loc == 0) 
     643          { 
     644            copy(sum_buf, sum_buf+count, recv_buf); 
     645          } 
     646          else 
     647          { 
     648            for(int i=0; i<count; i++) 
     649            { 
     650              recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     651            } 
     652          } 
     653        } 
     654        else if(op == MPI_MIN) 
     655        { 
     656          if(ep_rank_loc == 0) 
     657          { 
     658            copy(sum_buf, sum_buf+count, recv_buf); 
     659          } 
     660          else 
     661          { 
     662            for(int i=0; i<count; i++) 
     663            { 
     664              recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     665            } 
     666          } 
     667        } 
     668        else 
     669        { 
     670          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     671          exit(1); 
     672        } 
     673      } 
     674 
     675      delete[] static_cast<int*>(mpi_scan_recvbuf); 
     676      if(ep_rank_loc == 0) 
     677      { 
     678        delete[] static_cast<int*>(local_sum); 
     679      } 
     680    } 
     681 
     682    else if(datatype == MPI_CHAR) 
     683    { 
     684      char* sum_buf = static_cast<char*>(mpi_scan_recvbuf); 
     685      char* recv_buf = static_cast<char*>(recvbuf); 
     686 
     687      if(mpi_rank != 0) 
     688      { 
     689        if(op == MPI_SUM) 
     690        { 
     691          if(ep_rank_loc == 0) 
     692          { 
     693            copy(sum_buf, sum_buf+count, recv_buf); 
     694          } 
     695          else 
     696          { 
     697            for(int i=0; i<count; i++) 
     698            { 
     699              recv_buf[i] += sum_buf[i]; 
     700            } 
     701          } 
     702        } 
     703        else if (op == MPI_MAX) 
     704        { 
     705          if(ep_rank_loc == 0) 
     706          { 
     707            copy(sum_buf, sum_buf+count, recv_buf); 
     708          } 
     709          else 
     710          { 
     711            for(int i=0; i<count; i++) 
     712            { 
     713              recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     714            } 
     715          } 
     716        } 
     717        else if(op == MPI_MIN) 
     718        { 
     719          if(ep_rank_loc == 0) 
     720          { 
     721            copy(sum_buf, sum_buf+count, recv_buf); 
     722          } 
     723          else 
     724          { 
     725            for(int i=0; i<count; i++) 
     726            { 
     727              recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     728            } 
     729          } 
     730        } 
     731        else 
     732        { 
     733          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     734          exit(1); 
     735        } 
     736      } 
     737 
     738      delete[] static_cast<char*>(mpi_scan_recvbuf); 
     739      if(ep_rank_loc == 0) 
     740      { 
     741        delete[] static_cast<char*>(local_sum); 
     742      } 
     743    } 
     744 
     745    else if(datatype == MPI_LONG) 
     746    { 
     747      long* sum_buf = static_cast<long*>(mpi_scan_recvbuf); 
     748      long* recv_buf = static_cast<long*>(recvbuf); 
     749 
     750      if(mpi_rank != 0) 
     751      { 
     752        if(op == MPI_SUM) 
     753        { 
     754          if(ep_rank_loc == 0) 
     755          { 
     756            copy(sum_buf, sum_buf+count, recv_buf); 
     757          } 
     758          else 
     759          { 
     760            for(int i=0; i<count; i++) 
     761            { 
     762              recv_buf[i] += sum_buf[i]; 
     763            } 
     764          } 
     765        } 
     766        else if (op == MPI_MAX) 
     767        { 
     768          if(ep_rank_loc == 0) 
     769          { 
     770            copy(sum_buf, sum_buf+count, recv_buf); 
     771          } 
     772          else 
     773          { 
     774            for(int i=0; i<count; i++) 
     775            { 
     776              recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     777            } 
     778          } 
     779        } 
     780        else if(op == MPI_MIN) 
     781        { 
     782          if(ep_rank_loc == 0) 
     783          { 
     784            copy(sum_buf, sum_buf+count, recv_buf); 
     785          } 
     786          else 
     787          { 
     788            for(int i=0; i<count; i++) 
     789            { 
     790              recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     791            } 
     792          } 
     793        } 
     794        else 
     795        { 
     796          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     797          exit(1); 
     798        } 
     799      } 
     800 
     801      delete[] static_cast<long*>(mpi_scan_recvbuf); 
     802      if(ep_rank_loc == 0) 
     803      { 
     804        delete[] static_cast<long*>(local_sum); 
     805      } 
     806    } 
     807 
     808    else if(datatype == MPI_UNSIGNED_LONG) 
     809    { 
     810      unsigned long* sum_buf = static_cast<unsigned long*>(mpi_scan_recvbuf); 
     811      unsigned long* recv_buf = static_cast<unsigned long*>(recvbuf); 
     812 
     813      if(mpi_rank != 0) 
     814      { 
     815        if(op == MPI_SUM) 
     816        { 
     817          if(ep_rank_loc == 0) 
     818          { 
     819            copy(sum_buf, sum_buf+count, recv_buf); 
     820          } 
     821          else 
     822          { 
     823            for(int i=0; i<count; i++) 
     824            { 
     825              recv_buf[i] += sum_buf[i]; 
     826            } 
     827          } 
     828        } 
     829        else if (op == MPI_MAX) 
     830        { 
     831          if(ep_rank_loc == 0) 
     832          { 
     833            copy(sum_buf, sum_buf+count, recv_buf); 
     834          } 
     835          else 
     836          { 
     837            for(int i=0; i<count; i++) 
     838            { 
     839              recv_buf[i] = max(recv_buf[i], sum_buf[i]); 
     840            } 
     841          } 
     842        } 
     843        else if(op == MPI_MIN) 
     844        { 
     845          if(ep_rank_loc == 0) 
     846          { 
     847            copy(sum_buf, sum_buf+count, recv_buf); 
     848          } 
     849          else 
     850          { 
     851            for(int i=0; i<count; i++) 
     852            { 
     853              recv_buf[i] = min(recv_buf[i], sum_buf[i]); 
     854            } 
     855          } 
     856        } 
     857        else 
     858        { 
     859          printf("Support operator for MPI_Scan is MPI_SUM, MPI_MAX, and MPI_MIN\n"); 
     860          exit(1); 
     861        } 
     862      } 
     863 
     864      delete[] static_cast<unsigned long*>(mpi_scan_recvbuf); 
     865      if(ep_rank_loc == 0) 
     866      { 
     867        delete[] static_cast<unsigned long*>(local_sum); 
     868      } 
     869    } 
     870 
     871 
     872  } 
     873 
     874 
    282875 
    283876} 
Note: See TracChangeset for help on using the changeset viewer.