Ignore:
Timestamp:
10/04/17 17:02:13 (7 years ago)
Author:
yushan
Message:

EP update part 2

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gatherv.cpp

    r1287 r1289  
    1515namespace ep_lib 
    1616{ 
    17  
    18   int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], int local_root, MPI_Comm comm) 
     17   int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], int local_root, MPI_Comm comm) 
    1918  { 
    2019    assert(valid_type(datatype)); 
     
    186185  } 
    187186 
     187  // int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype, MPI_Comm comm) 
     188  // { 
     189 
     190  //   if(!comm.is_ep && comm.mpi_comm) 
     191  //   { 
     192  //     ::MPI_Allgatherv(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcounts, displs, to_mpi_type(recvtype), to_mpi_comm(comm.mpi_comm)); 
     193  //     return 0; 
     194  //   } 
     195 
     196  //   if(!comm.mpi_comm) return 0; 
     197 
     198 
     199 
     200 
     201  //   assert(valid_type(sendtype) && valid_type(recvtype)); 
     202 
     203  //   MPI_Datatype datatype = sendtype; 
     204  //   int count = sendcount; 
     205 
     206  //   ::MPI_Aint datasize, lb; 
     207 
     208  //   ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
     209 
     210 
     211  //   int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     212  //   int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     213  //   int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     214  //   int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     215  //   int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     216  //   int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     217 
     218 
     219  //   assert(sendcount == recvcounts[ep_rank]); 
     220 
     221  //   bool is_master = ep_rank_loc==0; 
     222 
     223  //   void* local_recvbuf; 
     224  //   void* tmp_recvbuf; 
     225 
     226  //   int recvbuf_size = 0; 
     227  //   for(int i=0; i<ep_size; i++) 
     228  //     recvbuf_size = max(recvbuf_size, displs[i]+recvcounts[i]); 
     229 
     230 
     231  //   vector<int>local_recvcounts(num_ep, 0); 
     232  //   vector<int>local_displs(num_ep, 0); 
     233 
     234  //   MPI_Gather_local(&sendcount, 1, MPI_INT, local_recvcounts.data(), 0, comm); 
     235  //   for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_recvcounts[i-1];  
     236 
     237 
     238  //   if(is_master) 
     239  //   { 
     240  //     local_recvbuf = new void*[datasize * std::accumulate(local_recvcounts.begin(), local_recvcounts.begin()+num_ep, 0)]; 
     241  //     tmp_recvbuf = new void*[datasize * std::accumulate(recvcounts, recvcounts+ep_size, 0)]; 
     242  //   } 
     243 
     244  //   MPI_Gatherv_local(sendbuf, count, datatype, local_recvbuf, local_recvcounts.data(), local_displs.data(), 0, comm); 
     245 
     246 
     247  //   if(is_master) 
     248  //   { 
     249  //     std::vector<int>mpi_recvcounts(mpi_size, 0); 
     250  //     std::vector<int>mpi_displs(mpi_size, 0); 
     251 
     252  //     int local_sendcount = std::accumulate(local_recvcounts.begin(), local_recvcounts.begin()+num_ep, 0); 
     253  //     MPI_Allgather(&local_sendcount, 1, MPI_INT, mpi_recvcounts.data(), 1, MPI_INT, to_mpi_comm(comm.mpi_comm)); 
     254 
     255  //     for(int i=1; i<mpi_size; i++) 
     256  //       mpi_displs[i] = mpi_displs[i-1] + mpi_recvcounts[i-1]; 
     257 
     258 
     259  //     ::MPI_Allgatherv(local_recvbuf, local_sendcount, to_mpi_type(datatype), tmp_recvbuf, mpi_recvcounts.data(), mpi_displs.data(), to_mpi_type(datatype), to_mpi_comm(comm.mpi_comm)); 
     260 
     261 
     262 
     263  //     // reorder  
     264  //     int offset; 
     265  //     for(int i=0; i<ep_size; i++) 
     266  //     { 
     267  //       int extra = 0; 
     268  //       for(int j=0, k=0; j<ep_size, k<comm.rank_map->at(i).first; j++) 
     269  //         if(comm.rank_map->at(i).second == comm.rank_map->at(j).second) 
     270  //         { 
     271  //           extra += recvcounts[j]; 
     272  //           k++; 
     273  //         }   
     274 
     275  //       offset = mpi_displs[comm.rank_map->at(i).second] +  extra; 
     276 
     277  //       memcpy(recvbuf+displs[i]*datasize, tmp_recvbuf+offset*datasize, recvcounts[i]*datasize); 
     278         
     279  //     } 
     280 
     281  //   } 
     282 
     283  //   MPI_Bcast_local(recvbuf, recvbuf_size, datatype, 0, comm); 
     284 
     285  //   if(is_master) 
     286  //   { 
     287  //     delete[] local_recvbuf; 
     288  //     delete[] tmp_recvbuf; 
     289  //   } 
     290 
     291  // } 
     292 
     293 
     294  int MPI_Gatherv_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
     295  { 
     296    if(datatype == MPI_INT) 
     297    { 
     298      Debug("datatype is INT\n"); 
     299      return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm); 
     300    } 
     301    else if(datatype == MPI_FLOAT) 
     302    { 
     303      Debug("datatype is FLOAT\n"); 
     304      return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm); 
     305    } 
     306    else if(datatype == MPI_DOUBLE) 
     307    { 
     308      Debug("datatype is DOUBLE\n"); 
     309      return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm); 
     310    } 
     311    else if(datatype == MPI_LONG) 
     312    { 
     313      Debug("datatype is LONG\n"); 
     314      return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm); 
     315    } 
     316    else if(datatype == MPI_UNSIGNED_LONG) 
     317    { 
     318      Debug("datatype is uLONG\n"); 
     319      return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm); 
     320    } 
     321    else if(datatype == MPI_CHAR) 
     322    { 
     323      Debug("datatype is CHAR\n"); 
     324      return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm); 
     325    } 
     326    else 
     327    { 
     328      printf("MPI_Gatherv Datatype not supported!\n"); 
     329      exit(0); 
     330    } 
     331  } 
     332 
     333  int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
     334  { 
     335    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     336    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     337 
     338    int *buffer = comm.my_buffer->buf_int; 
     339    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
     340    int *recv_buf = static_cast<int*>(recvbuf); 
     341 
     342    if(my_rank == 0) 
     343    { 
     344      assert(count == recvcounts[0]); 
     345      copy(send_buf, send_buf+count, recv_buf + displs[0]); 
     346    } 
     347 
     348    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
     349    { 
     350      for(int k=1; k<num_ep; k++) 
     351      { 
     352        if(my_rank == k) 
     353        { 
     354          #pragma omp critical (write_to_buffer) 
     355          { 
     356            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
     357            #pragma omp flush 
     358          } 
     359        } 
     360 
     361        MPI_Barrier_local(comm); 
     362 
     363        if(my_rank == 0) 
     364        { 
     365          #pragma omp flush 
     366          #pragma omp critical (read_from_buffer) 
     367          { 
     368            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
     369          } 
     370        } 
     371 
     372        MPI_Barrier_local(comm); 
     373      } 
     374    } 
     375  } 
     376 
     377  int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
     378  { 
     379    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     380    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     381 
     382    float *buffer = comm.my_buffer->buf_float; 
     383    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
     384    float *recv_buf = static_cast<float*>(recvbuf); 
     385 
     386    if(my_rank == 0) 
     387    { 
     388      assert(count == recvcounts[0]); 
     389      copy(send_buf, send_buf+count, recv_buf + displs[0]); 
     390    } 
     391 
     392    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
     393    { 
     394      for(int k=1; k<num_ep; k++) 
     395      { 
     396        if(my_rank == k) 
     397        { 
     398          #pragma omp critical (write_to_buffer) 
     399          { 
     400            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
     401            #pragma omp flush 
     402          } 
     403        } 
     404 
     405        MPI_Barrier_local(comm); 
     406 
     407        if(my_rank == 0) 
     408        { 
     409          #pragma omp flush 
     410          #pragma omp critical (read_from_buffer) 
     411          { 
     412            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
     413          } 
     414        } 
     415 
     416        MPI_Barrier_local(comm); 
     417      } 
     418    } 
     419  } 
     420 
     421  int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
     422  { 
     423    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     424    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     425 
     426    double *buffer = comm.my_buffer->buf_double; 
     427    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
     428    double *recv_buf = static_cast<double*>(recvbuf); 
     429 
     430    if(my_rank == 0) 
     431    { 
     432      assert(count == recvcounts[0]); 
     433      copy(send_buf, send_buf+count, recv_buf + displs[0]); 
     434    } 
     435 
     436    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
     437    { 
     438      for(int k=1; k<num_ep; k++) 
     439      { 
     440        if(my_rank == k) 
     441        { 
     442          #pragma omp critical (write_to_buffer) 
     443          { 
     444            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
     445            #pragma omp flush 
     446          } 
     447        } 
     448 
     449        MPI_Barrier_local(comm); 
     450 
     451        if(my_rank == 0) 
     452        { 
     453          #pragma omp flush 
     454          #pragma omp critical (read_from_buffer) 
     455          { 
     456            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
     457          } 
     458        } 
     459 
     460        MPI_Barrier_local(comm); 
     461      } 
     462    } 
     463  } 
     464 
     465  int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
     466  { 
     467    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     468    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     469 
     470    long *buffer = comm.my_buffer->buf_long; 
     471    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
     472    long *recv_buf = static_cast<long*>(recvbuf); 
     473 
     474    if(my_rank == 0) 
     475    { 
     476      assert(count == recvcounts[0]); 
     477      copy(send_buf, send_buf+count, recv_buf + displs[0]); 
     478    } 
     479 
     480    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
     481    { 
     482      for(int k=1; k<num_ep; k++) 
     483      { 
     484        if(my_rank == k) 
     485        { 
     486          #pragma omp critical (write_to_buffer) 
     487          { 
     488            if(count!=0)copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
     489            #pragma omp flush 
     490          } 
     491        } 
     492 
     493        MPI_Barrier_local(comm); 
     494 
     495        if(my_rank == 0) 
     496        { 
     497          #pragma omp flush 
     498          #pragma omp critical (read_from_buffer) 
     499          { 
     500            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
     501          } 
     502        } 
     503 
     504        MPI_Barrier_local(comm); 
     505      } 
     506    } 
     507  } 
     508 
     509  int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
     510  { 
     511    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     512    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     513 
     514    unsigned long *buffer = comm.my_buffer->buf_ulong; 
     515    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
     516    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 
     517 
     518    if(my_rank == 0) 
     519    { 
     520      assert(count == recvcounts[0]); 
     521      copy(send_buf, send_buf+count, recv_buf + displs[0]); 
     522    } 
     523 
     524    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
     525    { 
     526      for(int k=1; k<num_ep; k++) 
     527      { 
     528        if(my_rank == k) 
     529        { 
     530          #pragma omp critical (write_to_buffer) 
     531          { 
     532            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
     533            #pragma omp flush 
     534          } 
     535        } 
     536 
     537        MPI_Barrier_local(comm); 
     538 
     539        if(my_rank == 0) 
     540        { 
     541          #pragma omp flush 
     542          #pragma omp critical (read_from_buffer) 
     543          { 
     544            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
     545          } 
     546        } 
     547 
     548        MPI_Barrier_local(comm); 
     549      } 
     550    } 
     551  } 
     552 
     553  int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
     554  { 
     555    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
     556    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     557 
     558    char *buffer = comm.my_buffer->buf_char; 
     559    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
     560    char *recv_buf = static_cast<char*>(recvbuf); 
     561 
     562    if(my_rank == 0) 
     563    { 
     564      assert(count == recvcounts[0]); 
     565      copy(send_buf, send_buf+count, recv_buf + displs[0]); 
     566    } 
     567 
     568    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
     569    { 
     570      for(int k=1; k<num_ep; k++) 
     571      { 
     572        if(my_rank == k) 
     573        { 
     574          #pragma omp critical (write_to_buffer) 
     575          { 
     576            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
     577            #pragma omp flush 
     578          } 
     579        } 
     580 
     581        MPI_Barrier_local(comm); 
     582 
     583        if(my_rank == 0) 
     584        { 
     585          #pragma omp flush 
     586          #pragma omp critical (read_from_buffer) 
     587          { 
     588            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
     589          } 
     590        } 
     591 
     592        MPI_Barrier_local(comm); 
     593      } 
     594    } 
     595  } 
     596 
     597 
     598  int MPI_Gatherv2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
     599                  MPI_Datatype recvtype, int root, MPI_Comm comm) 
     600  { 
     601   
     602    if(!comm.is_ep && comm.mpi_comm) 
     603    { 
     604      ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs), 
     605                    static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     606      return 0; 
     607    } 
     608 
     609    if(!comm.mpi_comm) return 0; 
     610 
     611    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype)); 
     612 
     613    MPI_Datatype datatype = sendtype; 
     614    int count = sendcount; 
     615 
     616    int ep_rank, ep_rank_loc, mpi_rank; 
     617    int ep_size, num_ep, mpi_size; 
     618 
     619    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     620    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     621    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     622    ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     623    num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     624    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     625     
     626     
     627     
     628    if(ep_size == mpi_size)  
     629      return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs, 
     630                              static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     631 
     632    if(ep_rank != root) 
     633    { 
     634      recvcounts = new int[ep_size]; 
     635      displs = new int[ep_size]; 
     636    } 
     637     
     638    MPI_Bcast(const_cast< int* >(displs),     ep_size, MPI_INT, root, comm); 
     639    MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm); 
     640                               
     641 
     642    int recv_plus_displs[ep_size]; 
     643    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 
     644 
     645    for(int j=0; j<mpi_size; j++) 
     646    { 
     647      if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] || 
     648         recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2])   
     649      {   
     650        Debug("Call special implementation of mpi_gatherv. 1st condition not OK\n"); 
     651        return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
     652      } 
     653 
     654      for(int i=1; i<num_ep-1; i++) 
     655      { 
     656        if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] ||  
     657           recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1]) 
     658        { 
     659          Debug("Call special implementation of mpi_gatherv. 2nd condition not OK\n"); 
     660          return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
     661        } 
     662      } 
     663    } 
     664 
     665 
     666    int root_mpi_rank = comm.rank_map->at(root).second; 
     667    int root_ep_loc = comm.rank_map->at(root).first; 
     668 
     669 
     670    ::MPI_Aint datasize, lb; 
     671 
     672    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
     673 
     674    void *local_gather_recvbuf; 
     675    int buffer_size; 
     676    void *master_recvbuf; 
     677 
     678    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0)  
     679    { 
     680      master_recvbuf = new void*[sizeof(recvbuf)]; 
     681      assert(root_ep_loc == 0); 
     682    } 
     683 
     684    if(ep_rank_loc==0) 
     685    { 
     686      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep); 
     687 
     688      local_gather_recvbuf = new void*[datasize*buffer_size]; 
     689    } 
     690 
     691    MPI_Gatherv_local2(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm); 
     692 
     693    //MPI_Gather 
     694    if(ep_rank_loc == 0) 
     695    { 
     696      int *mpi_recvcnt= new int[mpi_size]; 
     697      int *mpi_displs= new int[mpi_size]; 
     698 
     699      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);; 
     700      int buff_end = buffer_size; 
     701 
     702      int mpi_sendcnt = buff_end - buff_start; 
     703 
     704 
     705      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_recvcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     706      ::MPI_Gather(&buff_start,  1, MPI_INT, mpi_displs,  1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     707 
     708      if(root_ep_loc == 0) 
     709      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt, 
     710                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     711      } 
     712      else  // gatherv to master_recvbuf 
     713      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, mpi_recvcnt, 
     714                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     715      } 
     716 
     717      delete[] mpi_recvcnt; 
     718      delete[] mpi_displs; 
     719    } 
     720 
     721    int global_min_displs = *std::min_element(displs, displs+ep_size); 
     722    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size); 
     723 
     724 
     725    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 
     726    { 
     727      innode_memcpy(0, master_recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm); 
     728      if(ep_rank_loc == 0) delete[] master_recvbuf; 
     729    } 
     730 
     731 
     732 
     733    if(ep_rank_loc==0) 
     734    { 
     735      if(datatype == MPI_INT) 
     736      { 
     737        delete[] static_cast<int*>(local_gather_recvbuf); 
     738      } 
     739      else if(datatype == MPI_FLOAT) 
     740      { 
     741        delete[] static_cast<float*>(local_gather_recvbuf); 
     742      } 
     743      else if(datatype == MPI_DOUBLE) 
     744      { 
     745        delete[] static_cast<double*>(local_gather_recvbuf); 
     746      } 
     747      else if(datatype == MPI_LONG) 
     748      { 
     749        delete[] static_cast<long*>(local_gather_recvbuf); 
     750      } 
     751      else if(datatype == MPI_UNSIGNED_LONG) 
     752      { 
     753        delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
     754      } 
     755      else // if(datatype == MPI_CHAR) 
     756      { 
     757        delete[] static_cast<char*>(local_gather_recvbuf); 
     758      } 
     759    } 
     760    else 
     761    { 
     762      delete[] recvcounts; 
     763      delete[] displs; 
     764    } 
     765    return 0; 
     766  } 
     767 
     768 
     769 
     770  int MPI_Allgatherv2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
     771                  MPI_Datatype recvtype, MPI_Comm comm) 
     772  { 
     773 
     774    if(!comm.is_ep && comm.mpi_comm) 
     775    { 
     776      ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs, 
     777                       static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     778      return 0; 
     779    } 
     780 
     781    if(!comm.mpi_comm) return 0; 
     782 
     783    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype)); 
     784 
     785 
     786    MPI_Datatype datatype = sendtype; 
     787    int count = sendcount; 
     788 
     789    int ep_rank, ep_rank_loc, mpi_rank; 
     790    int ep_size, num_ep, mpi_size; 
     791 
     792    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     793    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     794    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     795    ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     796    num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     797    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     798     
     799    if(ep_size == mpi_size)  // needed by servers 
     800      return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs, 
     801                              static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     802 
     803    int recv_plus_displs[ep_size]; 
     804    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 
     805 
     806 
     807    for(int j=0; j<mpi_size; j++) 
     808    { 
     809      if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] || 
     810         recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2])   
     811      {   
     812        printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size); 
     813        for(int k=0; k<ep_size; k++) 
     814          printf("recv_plus_displs[%d] = %d\t displs[%d] = %d\n", k, recv_plus_displs[k], k, displs[k]); 
     815 
     816        return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
     817      } 
     818 
     819      for(int i=1; i<num_ep-1; i++) 
     820      { 
     821        if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] ||  
     822           recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1]) 
     823        { 
     824          printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size); 
     825          return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
     826        } 
     827      } 
     828    } 
     829 
     830    ::MPI_Aint datasize, lb; 
     831 
     832    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
     833 
     834    void *local_gather_recvbuf; 
     835    int buffer_size; 
     836 
     837    if(ep_rank_loc==0) 
     838    { 
     839      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep); 
     840 
     841      local_gather_recvbuf = new void*[datasize*buffer_size]; 
     842    } 
     843 
     844    // local gather to master 
     845    MPI_Gatherv_local2(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm); 
     846 
     847    //MPI_Gather 
     848    if(ep_rank_loc == 0) 
     849    { 
     850      int *mpi_recvcnt= new int[mpi_size]; 
     851      int *mpi_displs= new int[mpi_size]; 
     852 
     853      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);; 
     854      int buff_end = buffer_size; 
     855 
     856      int mpi_sendcnt = buff_end - buff_start; 
     857 
     858 
     859      ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT, mpi_recvcnt, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     860      ::MPI_Allgather(&buff_start,  1, MPI_INT, mpi_displs,  1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     861 
     862 
     863      ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt, 
     864                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     865 
     866      delete[] mpi_recvcnt; 
     867      delete[] mpi_displs; 
     868    } 
     869 
     870    int global_min_displs = *std::min_element(displs, displs+ep_size); 
     871    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size); 
     872 
     873    MPI_Bcast_local2(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm); 
     874 
     875    if(ep_rank_loc==0) 
     876    { 
     877      if(datatype == MPI_INT) 
     878      { 
     879        delete[] static_cast<int*>(local_gather_recvbuf); 
     880      } 
     881      else if(datatype == MPI_FLOAT) 
     882      { 
     883        delete[] static_cast<float*>(local_gather_recvbuf); 
     884      } 
     885      else if(datatype == MPI_DOUBLE) 
     886      { 
     887        delete[] static_cast<double*>(local_gather_recvbuf); 
     888      } 
     889      else if(datatype == MPI_LONG) 
     890      { 
     891        delete[] static_cast<long*>(local_gather_recvbuf); 
     892      } 
     893      else if(datatype == MPI_UNSIGNED_LONG) 
     894      { 
     895        delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
     896      } 
     897      else // if(datatype == MPI_CHAR) 
     898      { 
     899        delete[] static_cast<char*>(local_gather_recvbuf); 
     900      } 
     901    } 
     902  } 
     903 
     904  int MPI_Gatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
     905                          MPI_Datatype recvtype, int root, MPI_Comm comm) 
     906  { 
     907    int ep_rank, ep_rank_loc, mpi_rank; 
     908    int ep_size, num_ep, mpi_size; 
     909 
     910    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     911    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     912    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     913    ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     914    num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     915    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     916 
     917    int root_mpi_rank = comm.rank_map->at(root).second; 
     918    int root_ep_loc = comm.rank_map->at(root).first; 
     919 
     920    ::MPI_Aint datasize, lb; 
     921    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize); 
     922 
     923    void *local_gather_recvbuf; 
     924    int buffer_size; 
     925 
     926    int *local_displs = new int[num_ep]; 
     927    int *local_rvcnts = new int[num_ep]; 
     928    for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i]; 
     929    local_displs[0] = 0; 
     930    for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1]; 
     931 
     932    if(ep_rank_loc==0) 
     933    { 
     934      buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1]; 
     935      local_gather_recvbuf = new void*[datasize*buffer_size]; 
     936    } 
     937 
     938    // local gather to master 
     939    MPI_Gatherv_local2(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master 
     940 
     941    int **mpi_recvcnts = new int*[num_ep]; 
     942    int **mpi_displs   = new int*[num_ep]; 
     943    for(int i=0; i<num_ep; i++)  
     944    { 
     945      mpi_recvcnts[i] = new int[mpi_size]; 
     946      mpi_displs[i]   = new int[mpi_size]; 
     947      for(int j=0; j<mpi_size; j++) 
     948      { 
     949        mpi_recvcnts[i][j] = recvcounts[j*num_ep + i]; 
     950        mpi_displs[i][j]   = displs[j*num_ep + i]; 
     951      } 
     952    }  
     953 
     954    void *master_recvbuf; 
     955    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)]; 
     956 
     957    if(ep_rank_loc == 0 && root_ep_loc == 0) // master in MPI_Allgatherv loop 
     958      for(int i=0; i<num_ep; i++) 
     959      { 
     960        ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i], 
     961                    static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     962      } 
     963    if(ep_rank_loc == 0 && root_ep_loc != 0) 
     964      for(int i=0; i<num_ep; i++) 
     965      { 
     966        ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), master_recvbuf, mpi_recvcnts[i], mpi_displs[i], 
     967                    static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     968      } 
     969 
     970 
     971    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 
     972    { 
     973      for(int i=0; i<ep_size; i++) 
     974        innode_memcpy(0, master_recvbuf + datasize*displs[i], root_ep_loc, recvbuf + datasize*displs[i], recvcounts[i], sendtype, comm); 
     975 
     976      if(ep_rank_loc == 0) delete[] master_recvbuf; 
     977    } 
     978 
     979     
     980    delete[] local_displs; 
     981    delete[] local_rvcnts; 
     982    for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i];  
     983                                  delete[] mpi_displs[i]; } 
     984    delete[] mpi_recvcnts; 
     985    delete[] mpi_displs; 
     986    if(ep_rank_loc==0) 
     987    { 
     988      if(sendtype == MPI_INT) 
     989      { 
     990        delete[] static_cast<int*>(local_gather_recvbuf); 
     991      } 
     992      else if(sendtype == MPI_FLOAT) 
     993      { 
     994        delete[] static_cast<float*>(local_gather_recvbuf); 
     995      } 
     996      else if(sendtype == MPI_DOUBLE) 
     997      { 
     998        delete[] static_cast<double*>(local_gather_recvbuf); 
     999      } 
     1000      else if(sendtype == MPI_LONG) 
     1001      { 
     1002        delete[] static_cast<long*>(local_gather_recvbuf); 
     1003      } 
     1004      else if(sendtype == MPI_UNSIGNED_LONG) 
     1005      { 
     1006        delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
     1007      } 
     1008      else // if(sendtype == MPI_CHAR) 
     1009      { 
     1010        delete[] static_cast<char*>(local_gather_recvbuf); 
     1011      } 
     1012    } 
     1013  } 
     1014 
     1015  int MPI_Allgatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
     1016                             MPI_Datatype recvtype, MPI_Comm comm) 
     1017  { 
     1018    int ep_rank, ep_rank_loc, mpi_rank; 
     1019    int ep_size, num_ep, mpi_size; 
     1020 
     1021    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     1022    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     1023    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     1024    ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     1025    num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     1026    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
     1027 
     1028 
     1029    ::MPI_Aint datasize, lb; 
     1030    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize); 
     1031 
     1032    void *local_gather_recvbuf; 
     1033    int buffer_size; 
     1034 
     1035    int *local_displs = new int[num_ep]; 
     1036    int *local_rvcnts = new int[num_ep]; 
     1037    for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i]; 
     1038    local_displs[0] = 0; 
     1039    for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1]; 
     1040 
     1041    if(ep_rank_loc==0) 
     1042    { 
     1043      buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1]; 
     1044      local_gather_recvbuf = new void*[datasize*buffer_size]; 
     1045    } 
     1046 
     1047    // local gather to master 
     1048    MPI_Gatherv_local2(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master 
     1049 
     1050    int **mpi_recvcnts = new int*[num_ep]; 
     1051    int **mpi_displs   = new int*[num_ep]; 
     1052    for(int i=0; i<num_ep; i++)  
     1053    { 
     1054      mpi_recvcnts[i] = new int[mpi_size]; 
     1055      mpi_displs[i]   = new int[mpi_size]; 
     1056      for(int j=0; j<mpi_size; j++) 
     1057      { 
     1058        mpi_recvcnts[i][j] = recvcounts[j*num_ep + i]; 
     1059        mpi_displs[i][j]   = displs[j*num_ep + i]; 
     1060      } 
     1061    }  
     1062 
     1063    if(ep_rank_loc == 0) // master in MPI_Allgatherv loop 
     1064    for(int i=0; i<num_ep; i++) 
     1065    { 
     1066      ::MPI_Allgatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i], 
     1067                  static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
     1068    } 
     1069 
     1070    for(int i=0; i<ep_size; i++) 
     1071      MPI_Bcast_local2(recvbuf + datasize*displs[i], recvcounts[i], recvtype, comm); 
     1072 
     1073     
     1074    delete[] local_displs; 
     1075    delete[] local_rvcnts; 
     1076    for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i];  
     1077                                  delete[] mpi_displs[i]; } 
     1078    delete[] mpi_recvcnts; 
     1079    delete[] mpi_displs; 
     1080    if(ep_rank_loc==0) 
     1081    { 
     1082      if(sendtype == MPI_INT) 
     1083      { 
     1084        delete[] static_cast<int*>(local_gather_recvbuf); 
     1085      } 
     1086      else if(sendtype == MPI_FLOAT) 
     1087      { 
     1088        delete[] static_cast<float*>(local_gather_recvbuf); 
     1089      } 
     1090      else if(sendtype == MPI_DOUBLE) 
     1091      { 
     1092        delete[] static_cast<double*>(local_gather_recvbuf); 
     1093      } 
     1094      else if(sendtype == MPI_LONG) 
     1095      { 
     1096        delete[] static_cast<long*>(local_gather_recvbuf); 
     1097      } 
     1098      else if(sendtype == MPI_UNSIGNED_LONG) 
     1099      { 
     1100        delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
     1101      } 
     1102      else // if(sendtype == MPI_CHAR) 
     1103      { 
     1104        delete[] static_cast<char*>(local_gather_recvbuf); 
     1105      } 
     1106    } 
     1107  } 
     1108 
     1109 
    1881110} 
Note: See TracChangeset for help on using the changeset viewer.