Ignore:
Timestamp:
10/04/17 17:02:13 (7 years ago)
Author:
yushan
Message:

EP update part 2

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scatter.cpp

    r1287 r1289  
    99#include <mpi.h> 
    1010#include "ep_declaration.hpp" 
    11 #include "ep_mpi.hpp" 
    1211 
    1312using namespace std; 
     
    1615{ 
    1716 
    18   int MPI_Scatter_local(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int local_root, MPI_Comm comm) 
    19   { 
    20     assert(valid_type(sendtype) && valid_type(recvtype)); 
    21     assert(recvcount == sendcount); 
    22  
    23     ::MPI_Aint datasize, lb; 
    24     ::MPI_Type_get_extent(to_mpi_type(sendtype), &lb, &datasize); 
    25  
    26     int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    27     int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    28  
    29  
    30     if(ep_rank_loc == local_root) 
    31       comm.my_buffer->void_buffer[local_root] = const_cast<void*>(sendbuf); 
    32  
    33     MPI_Barrier_local(comm); 
    34  
    35     #pragma omp critical (_scatter)       
    36     memcpy(recvbuf, comm.my_buffer->void_buffer[local_root]+datasize*ep_rank_loc*sendcount, datasize * recvcount); 
    37      
    38  
    39     MPI_Barrier_local(comm); 
    40   } 
     17  int MPI_Scatter_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm) 
     18  { 
     19    if(datatype == MPI_INT) 
     20    { 
     21      Debug("datatype is INT\n"); 
     22      return MPI_Scatter_local_int(sendbuf, count, recvbuf, comm); 
     23    } 
     24    else if(datatype == MPI_FLOAT) 
     25    { 
     26      Debug("datatype is FLOAT\n"); 
     27      return MPI_Scatter_local_float(sendbuf, count, recvbuf, comm); 
     28    } 
     29    else if(datatype == MPI_DOUBLE) 
     30    { 
     31      Debug("datatype is DOUBLE\n"); 
     32      return MPI_Scatter_local_double(sendbuf, count, recvbuf, comm); 
     33    } 
     34    else if(datatype == MPI_LONG) 
     35    { 
     36      Debug("datatype is LONG\n"); 
     37      return MPI_Scatter_local_long(sendbuf, count, recvbuf, comm); 
     38    } 
     39    else if(datatype == MPI_UNSIGNED_LONG) 
     40    { 
     41      Debug("datatype is uLONG\n"); 
     42      return MPI_Scatter_local_ulong(sendbuf, count, recvbuf, comm); 
     43    } 
     44    else if(datatype == MPI_CHAR) 
     45    { 
     46      Debug("datatype is CHAR\n"); 
     47      return MPI_Scatter_local_char(sendbuf, count, recvbuf, comm); 
     48    } 
     49    else 
     50    { 
     51      printf("MPI_Scatter Datatype not supported!\n"); 
     52      exit(0); 
     53    } 
     54  } 
     55 
     56  int MPI_Scatter_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
     57  { 
     58    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     59    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     60 
     61 
     62    int *buffer = comm.my_buffer->buf_int; 
     63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
     64    int *recv_buf = static_cast<int*>(recvbuf); 
     65 
     66    for(int k=0; k<num_ep; k++) 
     67    { 
     68      for(int j=0; j<count; j+=BUFFER_SIZE) 
     69      { 
     70        if(my_rank == 0) 
     71        { 
     72          #pragma omp critical (write_to_buffer) 
     73          { 
     74            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
     75            #pragma omp flush 
     76          } 
     77        } 
     78 
     79        MPI_Barrier_local(comm); 
     80 
     81        if(my_rank == k) 
     82        { 
     83          #pragma omp critical (read_from_buffer) 
     84          { 
     85            #pragma omp flush 
     86            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     87          } 
     88        } 
     89        MPI_Barrier_local(comm); 
     90      } 
     91    } 
     92  } 
     93 
     94  int MPI_Scatter_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
     95  { 
     96    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     97    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     98 
     99    float *buffer = comm.my_buffer->buf_float; 
     100    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
     101    float *recv_buf = static_cast<float*>(recvbuf); 
     102 
     103    for(int k=0; k<num_ep; k++) 
     104    { 
     105      for(int j=0; j<count; j+=BUFFER_SIZE) 
     106      { 
     107        if(my_rank == 0) 
     108        { 
     109          #pragma omp critical (write_to_buffer) 
     110          { 
     111            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
     112            #pragma omp flush 
     113          } 
     114        } 
     115 
     116        MPI_Barrier_local(comm); 
     117 
     118        if(my_rank == k) 
     119        { 
     120          #pragma omp critical (read_from_buffer) 
     121          { 
     122            #pragma omp flush 
     123            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     124          } 
     125        } 
     126        MPI_Barrier_local(comm); 
     127      } 
     128    } 
     129  } 
     130 
     131  int MPI_Scatter_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
     132  { 
     133    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     134    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     135 
     136    double *buffer = comm.my_buffer->buf_double; 
     137    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
     138    double *recv_buf = static_cast<double*>(recvbuf); 
     139 
     140    for(int k=0; k<num_ep; k++) 
     141    { 
     142      for(int j=0; j<count; j+=BUFFER_SIZE) 
     143      { 
     144        if(my_rank == 0) 
     145        { 
     146          #pragma omp critical (write_to_buffer) 
     147          { 
     148            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
     149            #pragma omp flush 
     150          } 
     151        } 
     152 
     153        MPI_Barrier_local(comm); 
     154 
     155        if(my_rank == k) 
     156        { 
     157          #pragma omp critical (read_from_buffer) 
     158          { 
     159            #pragma omp flush 
     160            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     161          } 
     162        } 
     163        MPI_Barrier_local(comm); 
     164      } 
     165    } 
     166  } 
     167 
     168  int MPI_Scatter_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
     169  { 
     170    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     171    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     172 
     173    long *buffer = comm.my_buffer->buf_long; 
     174    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
     175    long *recv_buf = static_cast<long*>(recvbuf); 
     176 
     177    for(int k=0; k<num_ep; k++) 
     178    { 
     179      for(int j=0; j<count; j+=BUFFER_SIZE) 
     180      { 
     181        if(my_rank == 0) 
     182        { 
     183          #pragma omp critical (write_to_buffer) 
     184          { 
     185            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
     186            #pragma omp flush 
     187          } 
     188        } 
     189 
     190        MPI_Barrier_local(comm); 
     191 
     192        if(my_rank == k) 
     193        { 
     194          #pragma omp critical (read_from_buffer) 
     195          { 
     196            #pragma omp flush 
     197            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     198          } 
     199        } 
     200        MPI_Barrier_local(comm); 
     201      } 
     202    } 
     203  } 
     204 
     205 
     206  int MPI_Scatter_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
     207  { 
     208    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     209    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     210 
     211    unsigned long *buffer = comm.my_buffer->buf_ulong; 
     212    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
     213    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 
     214 
     215    for(int k=0; k<num_ep; k++) 
     216    { 
     217      for(int j=0; j<count; j+=BUFFER_SIZE) 
     218      { 
     219        if(my_rank == 0) 
     220        { 
     221          #pragma omp critical (write_to_buffer) 
     222          { 
     223            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
     224            #pragma omp flush 
     225          } 
     226        } 
     227 
     228        MPI_Barrier_local(comm); 
     229 
     230        if(my_rank == k) 
     231        { 
     232          #pragma omp critical (read_from_buffer) 
     233          { 
     234            #pragma omp flush 
     235            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     236          } 
     237        } 
     238        MPI_Barrier_local(comm); 
     239      } 
     240    } 
     241  } 
     242 
     243 
     244  int MPI_Scatter_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
     245  { 
     246    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     247    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     248 
     249    char *buffer = comm.my_buffer->buf_char; 
     250    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
     251    char *recv_buf = static_cast<char*>(recvbuf); 
     252 
     253    for(int k=0; k<num_ep; k++) 
     254    { 
     255      for(int j=0; j<count; j+=BUFFER_SIZE) 
     256      { 
     257        if(my_rank == 0) 
     258        { 
     259          #pragma omp critical (write_to_buffer) 
     260          { 
     261            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
     262            #pragma omp flush 
     263          } 
     264        } 
     265 
     266        MPI_Barrier_local(comm); 
     267 
     268        if(my_rank == k) 
     269        { 
     270          #pragma omp critical (read_from_buffer) 
     271          { 
     272            #pragma omp flush 
     273            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
     274          } 
     275        } 
     276        MPI_Barrier_local(comm); 
     277      } 
     278    } 
     279  } 
     280 
     281 
     282 
     283 
    41284 
    42285  int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm) 
     
    44287    if(!comm.is_ep) 
    45288    { 
    46       return ::MPI_Scatter(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype), root, to_mpi_comm(comm.mpi_comm)); 
    47     } 
    48     
    49     assert(sendcount == recvcount); 
    50  
    51     int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    52     int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    53     int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    54     int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    55     int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    56     int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     289      ::MPI_Scatter(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype), 
     290                    root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     291      return 0; 
     292    } 
     293 
     294    if(!comm.mpi_comm) return 0; 
     295 
     296    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount); 
    57297 
    58298    int root_mpi_rank = comm.rank_map->at(root).second; 
    59299    int root_ep_loc = comm.rank_map->at(root).first; 
    60300 
    61     bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root; 
    62     bool is_root = ep_rank == root; 
     301    int ep_rank, ep_rank_loc, mpi_rank; 
     302    int ep_size, num_ep, mpi_size; 
     303 
     304    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     305    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     306    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     307    ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     308    num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     309    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     310 
    63311 
    64312    MPI_Datatype datatype = sendtype; 
     
    66314 
    67315    ::MPI_Aint datasize, lb; 
    68     ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
    69      
    70     void *tmp_sendbuf; 
    71     if(is_root) tmp_sendbuf = new void*[ep_size * count * datasize]; 
    72  
    73     // reorder tmp_sendbuf 
    74     std::vector<int>local_ranks(num_ep); 
    75     std::vector<int>ranks(ep_size); 
    76  
    77     if(mpi_rank == root_mpi_rank) MPI_Gather_local(&ep_rank, 1, MPI_INT, local_ranks.data(), root_ep_loc, comm); 
    78     else                          MPI_Gather_local(&ep_rank, 1, MPI_INT, local_ranks.data(), 0, comm); 
    79  
    80  
    81     std::vector<int> recvcounts(mpi_size, 0); 
    82     std::vector<int> displs(mpi_size, 0); 
    83  
    84  
    85     if(is_master) 
    86     { 
    87       for(int i=0; i<ep_size; i++) 
    88       { 
    89         recvcounts[comm.rank_map->at(i).second]++; 
    90       } 
    91  
     316 
     317    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
     318 
     319 
     320    void *master_sendbuf; 
     321    void *local_recvbuf; 
     322 
     323    if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) 
     324    { 
     325      if(ep_rank_loc == 0) master_sendbuf = new void*[datasize*count*ep_size]; 
     326 
     327      innode_memcpy(root_ep_loc, sendbuf, 0, master_sendbuf, count*ep_size, datatype, comm); 
     328    } 
     329 
     330 
     331 
     332    if(ep_rank_loc == 0) 
     333    { 
     334      int mpi_sendcnt = count*num_ep; 
     335      int mpi_scatterv_sendcnt[mpi_size]; 
     336      int displs[mpi_size]; 
     337 
     338      local_recvbuf = new void*[datasize*mpi_sendcnt]; 
     339 
     340      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_scatterv_sendcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     341 
     342      displs[0] = 0; 
    92343      for(int i=1; i<mpi_size; i++) 
    93         displs[i] = displs[i-1] + recvcounts[i-1]; 
    94  
    95       ::MPI_Gatherv(local_ranks.data(), num_ep, MPI_INT, ranks.data(), recvcounts.data(), displs.data(), MPI_INT, root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 
    96     } 
    97  
    98  
    99  
    100     // if(is_root) printf("\nranks = %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d\n", ranks[0], ranks[1], ranks[2], ranks[3], ranks[4], ranks[5], ranks[6], ranks[7],  
    101     //                                                                                   ranks[8], ranks[9], ranks[10], ranks[11], ranks[12], ranks[13], ranks[14], ranks[15]); 
    102  
    103     if(is_root) 
    104     for(int i=0; i<ep_size; i++) 
    105     { 
    106       memcpy(tmp_sendbuf + i*datasize*count, sendbuf + ranks[i]*datasize*count, count*datasize); 
    107     } 
    108  
    109     // if(is_root) printf("\ntmp_sendbuf = %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d\n", static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[2], static_cast<int*>(tmp_sendbuf)[4], static_cast<int*>(tmp_sendbuf)[6],  
    110     //                                                                           static_cast<int*>(tmp_sendbuf)[8], static_cast<int*>(tmp_sendbuf)[10], static_cast<int*>(tmp_sendbuf)[12], static_cast<int*>(tmp_sendbuf)[14],  
    111     //                                                                           static_cast<int*>(tmp_sendbuf)[16], static_cast<int*>(tmp_sendbuf)[18], static_cast<int*>(tmp_sendbuf)[20], static_cast<int*>(tmp_sendbuf)[22],  
    112     //                                                                           static_cast<int*>(tmp_sendbuf)[24], static_cast<int*>(tmp_sendbuf)[26], static_cast<int*>(tmp_sendbuf)[28], static_cast<int*>(tmp_sendbuf)[30] ); 
    113  
    114  
    115     // MPI_Scatterv from root to masters 
    116     void* local_recvbuf; 
    117     if(is_master) local_recvbuf = new void*[datasize * num_ep * count]; 
    118  
    119  
    120     if(is_master) 
    121     { 
    122       int local_sendcount = num_ep * count; 
    123       ::MPI_Gather(&local_sendcount, 1, to_mpi_type(MPI_INT), recvcounts.data(), 1, to_mpi_type(MPI_INT), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 
    124        
    125       if(is_root) for(int i=1; i<mpi_size; i++) displs[i] = displs[i-1] + recvcounts[i-1]; 
    126  
    127       ::MPI_Scatterv(tmp_sendbuf, recvcounts.data(), displs.data(), to_mpi_type(sendtype), local_recvbuf, num_ep*count, to_mpi_type(recvtype), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 
    128  
    129       // printf("local_recvbuf = %d %d %d %d\n", static_cast<int*>(local_recvbuf)[0], static_cast<int*>(local_recvbuf)[1], static_cast<int*>(local_recvbuf)[2], static_cast<int*>(local_recvbuf)[3]); 
    130                                                           // static_cast<int*>(local_recvbuf)[4], static_cast<int*>(local_recvbuf)[5], static_cast<int*>(local_recvbuf)[6], static_cast<int*>(local_recvbuf)[7]); 
    131     } 
    132  
    133     if(mpi_rank == root_mpi_rank) MPI_Scatter_local(local_recvbuf, count, sendtype, recvbuf, recvcount, recvtype, root_ep_loc, comm); 
    134     else                          MPI_Scatter_local(local_recvbuf, count, sendtype, recvbuf, recvcount, recvtype, 0, comm); 
    135  
    136     if(is_root)   delete[] tmp_sendbuf; 
    137     if(is_master) delete[] local_recvbuf; 
    138   } 
    139  
     344        displs[i] = displs[i-1] + mpi_scatterv_sendcnt[i-1]; 
     345 
     346 
     347      if(root_ep_loc!=0) 
     348      { 
     349        ::MPI_Scatterv(master_sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype), 
     350                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     351      } 
     352      else 
     353      { 
     354        ::MPI_Scatterv(sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype), 
     355                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     356      } 
     357    } 
     358 
     359    MPI_Scatter_local2(local_recvbuf, count, datatype, recvbuf, comm); 
     360 
     361    if(ep_rank_loc == 0) 
     362    { 
     363      if(datatype == MPI_INT) 
     364      { 
     365        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<int*>(master_sendbuf); 
     366        delete[] static_cast<int*>(local_recvbuf); 
     367      } 
     368      else if(datatype == MPI_FLOAT) 
     369      { 
     370        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<float*>(master_sendbuf); 
     371        delete[] static_cast<float*>(local_recvbuf); 
     372      } 
     373      else if(datatype == MPI_DOUBLE) 
     374      { 
     375        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<double*>(master_sendbuf); 
     376        delete[] static_cast<double*>(local_recvbuf); 
     377      } 
     378      else if(datatype == MPI_LONG) 
     379      { 
     380        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<long*>(master_sendbuf); 
     381        delete[] static_cast<long*>(local_recvbuf); 
     382      } 
     383      else if(datatype == MPI_UNSIGNED_LONG) 
     384      { 
     385        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<unsigned long*>(master_sendbuf); 
     386        delete[] static_cast<unsigned long*>(local_recvbuf); 
     387      } 
     388      else //if(datatype == MPI_DOUBLE) 
     389      { 
     390        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<char*>(master_sendbuf); 
     391        delete[] static_cast<char*>(local_recvbuf); 
     392      } 
     393    } 
     394 
     395  } 
    140396} 
Note: See TracChangeset for help on using the changeset viewer.