Ignore:
Timestamp:
05/24/17 13:09:23 (7 years ago)
Author:
yushan
Message:

bug fixed in MPI_(All)Gatherv

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gatherv.cpp

    r1138 r1145  
    1515namespace ep_lib 
    1616{ 
     17 
    1718  int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
    1819  { 
     
    347348    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    348349     
     350    if(ep_size == mpi_size)  
     351      return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs, 
     352                              static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     353 
     354    int recv_plus_displs[ep_size]; 
     355    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 
     356     
     357    #pragma omp single nowait 
     358    { 
     359      assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]); 
     360      for(int i=1; i<num_ep-1; i++) 
     361      { 
     362        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]); 
     363        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]); 
     364      } 
     365      assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]); 
     366    } 
     367 
    349368    if(ep_rank != root) 
    350369    { 
     
    366385 
    367386    void *local_gather_recvbuf; 
     387    int buffer_size; 
    368388 
    369389    if(ep_rank_loc==0) 
    370390    { 
    371       int buffer_size = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0); 
     391      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep); 
     392 
    372393      local_gather_recvbuf = new void*[datasize*buffer_size]; 
    373394    } 
    374395 
    375     // local gather to master 
    376     int local_displs[num_ep]; 
    377     local_displs[0] = 0; 
    378     for(int i=1; i<num_ep; i++) 
    379     { 
    380       local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc]; 
    381     } 
    382     MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, local_displs, comm); 
     396    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm); 
    383397 
    384398    //MPI_Gather 
    385399    if(ep_rank_loc == 0) 
    386400    { 
    387  
    388       int gatherv_recvcnt[mpi_size]; 
    389       int gatherv_displs[mpi_size]; 
    390       int gatherv_cnt = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0); 
    391  
    392       //gatherv_recvcnt = new int[mpi_size]; 
    393       //gatherv_displs = new int[mpi_size]; 
    394  
    395  
    396       ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    397  
    398       gatherv_displs[0] = 0; 
    399       for(int i=1; i<mpi_size; i++) 
    400       { 
    401         gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1]; 
    402       } 
    403  
    404  
    405       ::MPI_Gatherv(local_gather_recvbuf, gatherv_cnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt, 
    406                     gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    407  
    408       //delete[] gatherv_recvcnt; 
    409       //delete[] gatherv_displs; 
    410     } 
     401      int *mpi_recvcnt= new int[mpi_size]; 
     402      int *mpi_displs= new int[mpi_size]; 
     403 
     404      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);; 
     405      int buff_end = buffer_size; 
     406 
     407      int mpi_sendcnt = buff_end - buff_start; 
     408 
     409 
     410      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     411      ::MPI_Gather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     412 
     413 
     414      ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt, 
     415                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     416 
     417      delete[] mpi_recvcnt; 
     418      delete[] mpi_displs; 
     419    } 
     420 
     421    int global_min_displs = *std::min_element(displs, displs+ep_size); 
     422    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size); 
    411423 
    412424 
    413425    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 
    414426    { 
    415       innode_memcpy(0, recvbuf, root_ep_loc, recvbuf, accumulate(recvcounts, recvcounts+ep_size, 0), datatype, comm); 
     427      innode_memcpy(0, recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm); 
    416428    } 
    417429 
     
    487499      return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs, 
    488500                              static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    489      
    490  
    491     assert(accumulate(recvcounts, recvcounts+ep_size-1, 0) >= displs[ep_size-1]); // Only for continuous gather. 
     501    
     502 
     503    int recv_plus_displs[ep_size]; 
     504    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 
     505 
     506    #pragma omp single nowait 
     507    { 
     508      assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]); 
     509      for(int i=1; i<num_ep-1; i++) 
     510      { 
     511        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]); 
     512        assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]); 
     513      } 
     514      assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]); 
     515    } 
    492516 
    493517 
     
    497521 
    498522    void *local_gather_recvbuf; 
     523    int buffer_size; 
    499524 
    500525    if(ep_rank_loc==0) 
    501526    { 
    502       int buffer_size = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0); 
     527      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep); 
     528 
    503529      local_gather_recvbuf = new void*[datasize*buffer_size]; 
    504530    } 
    505531 
    506532    // local gather to master 
    507     int local_displs[num_ep]; 
    508     local_displs[0] = 0; 
    509     for(int i=1; i<num_ep; i++) 
    510     { 
    511       local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc]; 
    512     } 
    513     MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, local_displs, comm); 
     533    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm); 
    514534 
    515535    //MPI_Gather 
    516536    if(ep_rank_loc == 0) 
    517537    { 
    518       int *gatherv_recvcnt; 
    519       int *gatherv_displs; 
    520       int gatherv_cnt = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0); 
    521  
    522       gatherv_recvcnt = new int[mpi_size]; 
    523       gatherv_displs = new int[mpi_size]; 
    524  
    525       ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    526       gatherv_displs[0] = displs[0]; 
    527       for(int i=1; i<mpi_size; i++) 
    528       { 
    529         gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1]; 
    530       } 
    531  
    532       ::MPI_Allgatherv(local_gather_recvbuf, gatherv_cnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt, 
    533                     gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    534  
    535       delete[] gatherv_recvcnt; 
    536       delete[] gatherv_displs; 
    537     } 
    538  
    539     MPI_Bcast_local(recvbuf, accumulate(recvcounts, recvcounts+ep_size, 0), datatype, comm); 
     538      int *mpi_recvcnt= new int[mpi_size]; 
     539      int *mpi_displs= new int[mpi_size]; 
     540 
     541      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);; 
     542      int buff_end = buffer_size; 
     543 
     544      int mpi_sendcnt = buff_end - buff_start; 
     545 
     546 
     547      ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     548      ::MPI_Allgather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     549 
     550 
     551      ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt, 
     552                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     553 
     554      delete[] mpi_recvcnt; 
     555      delete[] mpi_displs; 
     556    } 
     557 
     558    int global_min_displs = *std::min_element(displs, displs+ep_size); 
     559    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size); 
     560 
     561    MPI_Bcast_local(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm); 
    540562 
    541563    if(ep_rank_loc==0) 
Note: See TracChangeset for help on using the changeset viewer.