Ignore:
Timestamp:
06/08/17 17:31:50 (7 years ago)
Author:
yushan
Message:

Bug fixed in MPI_(All)Gatherv with displs

Location:
XIOS/dev/branch_yushan_merged/extern/src_ep_dev
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gather.cpp

    r1151 r1164  
    355355    void *local_gather_recvbuf; 
    356356    void *master_recvbuf; 
    357     if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)]; 
     357    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0)  
     358    { 
     359      master_recvbuf = new void*[datasize*ep_size*count]; 
     360    } 
    358361 
    359362    if(ep_rank_loc==0) 
     
    404407    { 
    405408      innode_memcpy(0, master_recvbuf, root_ep_loc, recvbuf, count*ep_size, datatype, comm); 
    406       if(ep_rank_loc == 0 ) delete[] master_recvbuf; 
    407409    } 
    408410 
     
    411413    if(ep_rank_loc==0) 
    412414    { 
    413  
    414415      if(datatype == MPI_INT) 
    415416      { 
     
    436437        delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
    437438      } 
    438     } 
    439  
    440  
     439       
     440      if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) delete[] master_recvbuf; 
     441    } 
    441442  } 
    442443 
  • XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gatherv.cpp

    r1151 r1164  
    366366    int recv_plus_displs[ep_size]; 
    367367    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 
    368      
    369     #pragma omp single nowait 
    370     { 
    371       assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]); 
     368 
     369    for(int j=0; j<mpi_size; j++) 
     370    { 
     371      if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] || 
     372         recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2])   
     373      {   
     374        Debug("Call special implementation of mpi_gatherv. 1st condition not OK\n"); 
     375        return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
     376      } 
     377 
    372378      for(int i=1; i<num_ep-1; i++) 
    373379      { 
    374         assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]); 
    375         assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]); 
    376       } 
    377       assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]); 
     380        if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] ||  
     381           recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1]) 
     382        { 
     383          Debug("Call special implementation of mpi_gatherv. 2nd condition not OK\n"); 
     384          return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
     385        } 
     386      } 
    378387    } 
    379388 
     
    391400    void *master_recvbuf; 
    392401 
    393     if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)]; 
     402    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0)  
     403    { 
     404      master_recvbuf = new void*[sizeof(recvbuf)]; 
     405      assert(root_ep_loc == 0); 
     406    } 
    394407 
    395408    if(ep_rank_loc==0) 
     
    507520    num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    508521    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     522 
     523    //printf("size of recvbuf = %lu\n", sizeof(recvbuf)); 
     524    //printf("size of (char*)recvbuf = %lu\n", sizeof((char*)recvbuf)); 
    509525     
    510526    if(ep_size == mpi_size)  
     
    516532    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 
    517533 
    518     #pragma omp single nowait 
    519     { 
    520       assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]); 
     534    for(int j=0; j<mpi_size; j++) 
     535    { 
     536      if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] || 
     537         recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2])   
     538      {   
     539        Debug("Call special implementation of mpi_allgatherv.\n"); 
     540        return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
     541      } 
     542 
    521543      for(int i=1; i<num_ep-1; i++) 
    522544      { 
    523         assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]); 
    524         assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]); 
    525       } 
    526       assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]); 
    527     } 
    528  
     545        if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] ||  
     546           recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1]) 
     547        { 
     548          Debug("Call special implementation of mpi_allgatherv.\n"); 
     549          return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
     550        } 
     551      } 
     552    } 
    529553 
    530554    ::MPI_Aint datasize, lb; 
     
    602626  } 
    603627 
     628  int MPI_Gatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
     629                          MPI_Datatype recvtype, int root, MPI_Comm comm) 
     630  { 
     631    int ep_rank, ep_rank_loc, mpi_rank; 
     632    int ep_size, num_ep, mpi_size; 
     633 
     634    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     635    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     636    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     637    ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     638    num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     639    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     640 
     641    int root_mpi_rank = comm.rank_map->at(root).second; 
     642    int root_ep_loc = comm.rank_map->at(root).first; 
     643 
     644    ::MPI_Aint datasize, lb; 
     645    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize); 
     646 
     647    void *local_gather_recvbuf; 
     648    int buffer_size; 
     649 
     650    int *local_displs = new int[num_ep]; 
     651    int *local_rvcnts = new int[num_ep]; 
     652    for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i]; 
     653    local_displs[0] = 0; 
     654    for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1]; 
     655 
     656    if(ep_rank_loc==0) 
     657    { 
     658      buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1]; 
     659      local_gather_recvbuf = new void*[datasize*buffer_size]; 
     660    } 
     661 
     662    // local gather to master 
     663    MPI_Gatherv_local(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master 
     664 
     665    int **mpi_recvcnts = new int*[num_ep]; 
     666    int **mpi_displs   = new int*[num_ep]; 
     667    for(int i=0; i<num_ep; i++)  
     668    { 
     669      mpi_recvcnts[i] = new int[mpi_size]; 
     670      mpi_displs[i]   = new int[mpi_size]; 
     671      for(int j=0; j<mpi_size; j++) 
     672      { 
     673        mpi_recvcnts[i][j] = recvcounts[j*num_ep + i]; 
     674        mpi_displs[i][j]   = displs[j*num_ep + i]; 
     675      } 
     676    }  
     677 
     678    void *master_recvbuf; 
     679    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)]; 
     680 
     681    if(ep_rank_loc == 0 && root_ep_loc == 0) // master in MPI_Allgatherv loop 
     682      for(int i=0; i<num_ep; i++) 
     683      { 
     684        ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i], 
     685                    static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     686      } 
     687    if(ep_rank_loc == 0 && root_ep_loc != 0) 
     688      for(int i=0; i<num_ep; i++) 
     689      { 
     690        ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), master_recvbuf, mpi_recvcnts[i], mpi_displs[i], 
     691                    static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     692      } 
     693 
     694 
     695    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 
     696    { 
     697      for(int i=0; i<ep_size; i++) 
     698        innode_memcpy(0, master_recvbuf + datasize*displs[i], root_ep_loc, recvbuf + datasize*displs[i], recvcounts[i], sendtype, comm); 
     699 
     700      if(ep_rank_loc == 0) delete[] master_recvbuf; 
     701    } 
     702 
     703     
     704    delete[] local_displs; 
     705    delete[] local_rvcnts; 
     706    for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i];  
     707                                  delete[] mpi_displs[i]; } 
     708    delete[] mpi_recvcnts; 
     709    delete[] mpi_displs; 
     710    if(ep_rank_loc==0) 
     711    { 
     712      if(sendtype == MPI_INT) 
     713      { 
     714        delete[] static_cast<int*>(local_gather_recvbuf); 
     715      } 
     716      else if(sendtype == MPI_FLOAT) 
     717      { 
     718        delete[] static_cast<float*>(local_gather_recvbuf); 
     719      } 
     720      else if(sendtype == MPI_DOUBLE) 
     721      { 
     722        delete[] static_cast<double*>(local_gather_recvbuf); 
     723      } 
     724      else if(sendtype == MPI_LONG) 
     725      { 
     726        delete[] static_cast<long*>(local_gather_recvbuf); 
     727      } 
     728      else if(sendtype == MPI_UNSIGNED_LONG) 
     729      { 
     730        delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
     731      } 
     732      else // if(sendtype == MPI_CHAR) 
     733      { 
     734        delete[] static_cast<char*>(local_gather_recvbuf); 
     735      } 
     736    } 
     737  } 
     738 
     739  int MPI_Allgatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
     740                             MPI_Datatype recvtype, MPI_Comm comm) 
     741  { 
     742    int ep_rank, ep_rank_loc, mpi_rank; 
     743    int ep_size, num_ep, mpi_size; 
     744 
     745    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     746    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     747    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     748    ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     749    num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     750    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     751 
     752 
     753    ::MPI_Aint datasize, lb; 
     754    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize); 
     755 
     756    void *local_gather_recvbuf; 
     757    int buffer_size; 
     758 
     759    int *local_displs = new int[num_ep]; 
     760    int *local_rvcnts = new int[num_ep]; 
     761    for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i]; 
     762    local_displs[0] = 0; 
     763    for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1]; 
     764 
     765    if(ep_rank_loc==0) 
     766    { 
     767      buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1]; 
     768      local_gather_recvbuf = new void*[datasize*buffer_size]; 
     769    } 
     770 
     771    // local gather to master 
     772    MPI_Gatherv_local(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master 
     773 
     774    int **mpi_recvcnts = new int*[num_ep]; 
     775    int **mpi_displs   = new int*[num_ep]; 
     776    for(int i=0; i<num_ep; i++)  
     777    { 
     778      mpi_recvcnts[i] = new int[mpi_size]; 
     779      mpi_displs[i]   = new int[mpi_size]; 
     780      for(int j=0; j<mpi_size; j++) 
     781      { 
     782        mpi_recvcnts[i][j] = recvcounts[j*num_ep + i]; 
     783        mpi_displs[i][j]   = displs[j*num_ep + i]; 
     784      } 
     785    }  
     786 
     787    if(ep_rank_loc == 0) // master in MPI_Allgatherv loop 
     788    for(int i=0; i<num_ep; i++) 
     789    { 
     790      ::MPI_Allgatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i], 
     791                  static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     792    } 
     793 
     794    for(int i=0; i<ep_size; i++) 
     795      MPI_Bcast_local(recvbuf + datasize*displs[i], recvcounts[i], recvtype, comm); 
     796 
     797     
     798    delete[] local_displs; 
     799    delete[] local_rvcnts; 
     800    for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i];  
     801                                  delete[] mpi_displs[i]; } 
     802    delete[] mpi_recvcnts; 
     803    delete[] mpi_displs; 
     804    if(ep_rank_loc==0) 
     805    { 
     806      if(sendtype == MPI_INT) 
     807      { 
     808        delete[] static_cast<int*>(local_gather_recvbuf); 
     809      } 
     810      else if(sendtype == MPI_FLOAT) 
     811      { 
     812        delete[] static_cast<float*>(local_gather_recvbuf); 
     813      } 
     814      else if(sendtype == MPI_DOUBLE) 
     815      { 
     816        delete[] static_cast<double*>(local_gather_recvbuf); 
     817      } 
     818      else if(sendtype == MPI_LONG) 
     819      { 
     820        delete[] static_cast<long*>(local_gather_recvbuf); 
     821      } 
     822      else if(sendtype == MPI_UNSIGNED_LONG) 
     823      { 
     824        delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
     825      } 
     826      else // if(sendtype == MPI_CHAR) 
     827      { 
     828        delete[] static_cast<char*>(local_gather_recvbuf); 
     829      } 
     830    } 
     831  } 
     832 
    604833 
    605834} 
  • XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_lib_collective.hpp

    r1134 r1164  
    3131  int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
    3232                  MPI_Datatype recvtype, int root, MPI_Comm comm); 
     33  int MPI_Gatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
     34                          MPI_Datatype recvtype, int root, MPI_Comm comm); 
    3335  int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
    34                   MPI_Datatype recvtype, MPI_Comm comm); 
     36                     MPI_Datatype recvtype, MPI_Comm comm); 
     37  int MPI_Allgatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
     38                             MPI_Datatype recvtype, MPI_Comm comm); 
     39 
    3540 
    3641  int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm); 
  • XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_wait.cpp

    r1153 r1164  
    2222    if(request->type == 1) 
    2323    { 
    24       ::MPI_Request mpi_request = static_cast< ::MPI_Request >(request->mpi_request); 
     24      ::MPI_Request *mpi_request = static_cast< ::MPI_Request* >(&(request->mpi_request)); 
    2525      ::MPI_Status mpi_status; 
    26       ::MPI_Wait(&mpi_request, &mpi_status); 
     26      ::MPI_Errhandler_set(MPI_COMM_WORLD_STD, MPI_ERRORS_RETURN); 
     27      int error_code = ::MPI_Wait(mpi_request, &mpi_status); 
     28      if (error_code != MPI_SUCCESS) { 
     29       
     30         char error_string[BUFSIZ]; 
     31         int length_of_error_string, error_class; 
     32       
     33         ::MPI_Error_class(error_code, &error_class); 
     34         ::MPI_Error_string(error_class, error_string, &length_of_error_string); 
     35         printf("%s\n", error_string); 
     36      } 
    2737 
    2838      status->mpi_status = &mpi_status; 
Note: See TracChangeset for help on using the changeset viewer.