Ignore:
Timestamp:
10/06/17 13:56:33 (7 years ago)
Author:
yushan
Message:

EP update all

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gatherv.cpp

    r1289 r1295  
    1515namespace ep_lib 
    1616{ 
    17    int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], int local_root, MPI_Comm comm) 
     17 
     18  int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], int local_root, MPI_Comm comm) 
    1819  { 
    1920    assert(valid_type(datatype)); 
     
    132133      } 
    133134 
    134  
    135  
    136135      for(int i=1; i<mpi_size; i++) 
    137136        mpi_displs[i] = mpi_displs[i-1] + mpi_recvcounts[i-1]; 
    138  
    139137 
    140138 
     
    146144    if(is_root) 
    147145    { 
    148       // printf("tmp_recvbuf =\n"); 
    149       // for(int i=0; i<ep_size*sendcount; i++) printf("%d\t", static_cast<int*>(tmp_recvbuf)[i]); 
    150       // printf("\n"); 
    151  
    152146      int offset; 
    153147      for(int i=0; i<ep_size; i++) 
     
    164158 
    165159        memcpy(recvbuf+displs[i]*datasize, tmp_recvbuf+offset*datasize, recvcounts[i]*datasize); 
    166  
    167         //printf("recvbuf[%d] = tmp_recvbuf[%d] \n", i, offset); 
    168160         
    169161      } 
    170  
    171       // printf("recvbuf =\n"); 
    172       // for(int i=0; i<ep_size*sendcount; i++) printf("%d\t", static_cast<int*>(recvbuf)[i]); 
    173       // printf("\n"); 
    174162 
    175163    } 
     
    185173  } 
    186174 
    187   // int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype, MPI_Comm comm) 
    188   // { 
    189  
    190   //   if(!comm.is_ep && comm.mpi_comm) 
    191   //   { 
    192   //     ::MPI_Allgatherv(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcounts, displs, to_mpi_type(recvtype), to_mpi_comm(comm.mpi_comm)); 
    193   //     return 0; 
    194   //   } 
    195  
    196   //   if(!comm.mpi_comm) return 0; 
    197  
    198  
    199  
    200  
    201   //   assert(valid_type(sendtype) && valid_type(recvtype)); 
    202  
    203   //   MPI_Datatype datatype = sendtype; 
    204   //   int count = sendcount; 
    205  
    206   //   ::MPI_Aint datasize, lb; 
    207  
    208   //   ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
    209  
    210  
    211   //   int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    212   //   int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    213   //   int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    214   //   int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    215   //   int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    216   //   int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    217  
    218  
    219   //   assert(sendcount == recvcounts[ep_rank]); 
    220  
    221   //   bool is_master = ep_rank_loc==0; 
    222  
    223   //   void* local_recvbuf; 
    224   //   void* tmp_recvbuf; 
    225  
    226   //   int recvbuf_size = 0; 
    227   //   for(int i=0; i<ep_size; i++) 
    228   //     recvbuf_size = max(recvbuf_size, displs[i]+recvcounts[i]); 
    229  
    230  
    231   //   vector<int>local_recvcounts(num_ep, 0); 
    232   //   vector<int>local_displs(num_ep, 0); 
    233  
    234   //   MPI_Gather_local(&sendcount, 1, MPI_INT, local_recvcounts.data(), 0, comm); 
    235   //   for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_recvcounts[i-1];  
    236  
    237  
    238   //   if(is_master) 
    239   //   { 
    240   //     local_recvbuf = new void*[datasize * std::accumulate(local_recvcounts.begin(), local_recvcounts.begin()+num_ep, 0)]; 
    241   //     tmp_recvbuf = new void*[datasize * std::accumulate(recvcounts, recvcounts+ep_size, 0)]; 
    242   //   } 
    243  
    244   //   MPI_Gatherv_local(sendbuf, count, datatype, local_recvbuf, local_recvcounts.data(), local_displs.data(), 0, comm); 
    245  
    246  
    247   //   if(is_master) 
    248   //   { 
    249   //     std::vector<int>mpi_recvcounts(mpi_size, 0); 
    250   //     std::vector<int>mpi_displs(mpi_size, 0); 
    251  
    252   //     int local_sendcount = std::accumulate(local_recvcounts.begin(), local_recvcounts.begin()+num_ep, 0); 
    253   //     MPI_Allgather(&local_sendcount, 1, MPI_INT, mpi_recvcounts.data(), 1, MPI_INT, to_mpi_comm(comm.mpi_comm)); 
    254  
    255   //     for(int i=1; i<mpi_size; i++) 
    256   //       mpi_displs[i] = mpi_displs[i-1] + mpi_recvcounts[i-1]; 
    257  
    258  
    259   //     ::MPI_Allgatherv(local_recvbuf, local_sendcount, to_mpi_type(datatype), tmp_recvbuf, mpi_recvcounts.data(), mpi_displs.data(), to_mpi_type(datatype), to_mpi_comm(comm.mpi_comm)); 
    260  
    261  
    262  
    263   //     // reorder  
    264   //     int offset; 
    265   //     for(int i=0; i<ep_size; i++) 
    266   //     { 
    267   //       int extra = 0; 
    268   //       for(int j=0, k=0; j<ep_size, k<comm.rank_map->at(i).first; j++) 
    269   //         if(comm.rank_map->at(i).second == comm.rank_map->at(j).second) 
    270   //         { 
    271   //           extra += recvcounts[j]; 
    272   //           k++; 
    273   //         }   
    274  
    275   //       offset = mpi_displs[comm.rank_map->at(i).second] +  extra; 
    276  
    277   //       memcpy(recvbuf+displs[i]*datasize, tmp_recvbuf+offset*datasize, recvcounts[i]*datasize); 
    278          
    279   //     } 
    280  
    281   //   } 
    282  
    283   //   MPI_Bcast_local(recvbuf, recvbuf_size, datatype, 0, comm); 
    284  
    285   //   if(is_master) 
    286   //   { 
    287   //     delete[] local_recvbuf; 
    288   //     delete[] tmp_recvbuf; 
    289   //   } 
    290  
    291   // } 
    292  
    293  
    294   int MPI_Gatherv_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
    295   { 
    296     if(datatype == MPI_INT) 
    297     { 
    298       Debug("datatype is INT\n"); 
    299       return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm); 
    300     } 
    301     else if(datatype == MPI_FLOAT) 
    302     { 
    303       Debug("datatype is FLOAT\n"); 
    304       return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm); 
    305     } 
    306     else if(datatype == MPI_DOUBLE) 
    307     { 
    308       Debug("datatype is DOUBLE\n"); 
    309       return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm); 
    310     } 
    311     else if(datatype == MPI_LONG) 
    312     { 
    313       Debug("datatype is LONG\n"); 
    314       return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm); 
    315     } 
    316     else if(datatype == MPI_UNSIGNED_LONG) 
    317     { 
    318       Debug("datatype is uLONG\n"); 
    319       return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm); 
    320     } 
    321     else if(datatype == MPI_CHAR) 
    322     { 
    323       Debug("datatype is CHAR\n"); 
    324       return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm); 
    325     } 
    326     else 
    327     { 
    328       printf("MPI_Gatherv Datatype not supported!\n"); 
    329       exit(0); 
    330     } 
    331   } 
    332  
    333   int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
    334   { 
    335     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    336     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    337  
    338     int *buffer = comm.my_buffer->buf_int; 
    339     int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
    340     int *recv_buf = static_cast<int*>(recvbuf); 
    341  
    342     if(my_rank == 0) 
    343     { 
    344       assert(count == recvcounts[0]); 
    345       copy(send_buf, send_buf+count, recv_buf + displs[0]); 
    346     } 
    347  
    348     for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
    349     { 
    350       for(int k=1; k<num_ep; k++) 
    351       { 
    352         if(my_rank == k) 
    353         { 
    354           #pragma omp critical (write_to_buffer) 
    355           { 
    356             if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
    357             #pragma omp flush 
    358           } 
    359         } 
    360  
    361         MPI_Barrier_local(comm); 
    362  
    363         if(my_rank == 0) 
    364         { 
    365           #pragma omp flush 
    366           #pragma omp critical (read_from_buffer) 
    367           { 
    368             copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
    369           } 
    370         } 
    371  
    372         MPI_Barrier_local(comm); 
    373       } 
    374     } 
    375   } 
    376  
    377   int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
    378   { 
    379     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    380     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    381  
    382     float *buffer = comm.my_buffer->buf_float; 
    383     float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
    384     float *recv_buf = static_cast<float*>(recvbuf); 
    385  
    386     if(my_rank == 0) 
    387     { 
    388       assert(count == recvcounts[0]); 
    389       copy(send_buf, send_buf+count, recv_buf + displs[0]); 
    390     } 
    391  
    392     for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
    393     { 
    394       for(int k=1; k<num_ep; k++) 
    395       { 
    396         if(my_rank == k) 
    397         { 
    398           #pragma omp critical (write_to_buffer) 
    399           { 
    400             if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
    401             #pragma omp flush 
    402           } 
    403         } 
    404  
    405         MPI_Barrier_local(comm); 
    406  
    407         if(my_rank == 0) 
    408         { 
    409           #pragma omp flush 
    410           #pragma omp critical (read_from_buffer) 
    411           { 
    412             copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
    413           } 
    414         } 
    415  
    416         MPI_Barrier_local(comm); 
    417       } 
    418     } 
    419   } 
    420  
    421   int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
    422   { 
    423     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    424     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    425  
    426     double *buffer = comm.my_buffer->buf_double; 
    427     double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
    428     double *recv_buf = static_cast<double*>(recvbuf); 
    429  
    430     if(my_rank == 0) 
    431     { 
    432       assert(count == recvcounts[0]); 
    433       copy(send_buf, send_buf+count, recv_buf + displs[0]); 
    434     } 
    435  
    436     for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
    437     { 
    438       for(int k=1; k<num_ep; k++) 
    439       { 
    440         if(my_rank == k) 
    441         { 
    442           #pragma omp critical (write_to_buffer) 
    443           { 
    444             if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
    445             #pragma omp flush 
    446           } 
    447         } 
    448  
    449         MPI_Barrier_local(comm); 
    450  
    451         if(my_rank == 0) 
    452         { 
    453           #pragma omp flush 
    454           #pragma omp critical (read_from_buffer) 
    455           { 
    456             copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
    457           } 
    458         } 
    459  
    460         MPI_Barrier_local(comm); 
    461       } 
    462     } 
    463   } 
    464  
    465   int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
    466   { 
    467     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    468     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    469  
    470     long *buffer = comm.my_buffer->buf_long; 
    471     long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
    472     long *recv_buf = static_cast<long*>(recvbuf); 
    473  
    474     if(my_rank == 0) 
    475     { 
    476       assert(count == recvcounts[0]); 
    477       copy(send_buf, send_buf+count, recv_buf + displs[0]); 
    478     } 
    479  
    480     for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
    481     { 
    482       for(int k=1; k<num_ep; k++) 
    483       { 
    484         if(my_rank == k) 
    485         { 
    486           #pragma omp critical (write_to_buffer) 
    487           { 
    488             if(count!=0)copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
    489             #pragma omp flush 
    490           } 
    491         } 
    492  
    493         MPI_Barrier_local(comm); 
    494  
    495         if(my_rank == 0) 
    496         { 
    497           #pragma omp flush 
    498           #pragma omp critical (read_from_buffer) 
    499           { 
    500             copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
    501           } 
    502         } 
    503  
    504         MPI_Barrier_local(comm); 
    505       } 
    506     } 
    507   } 
    508  
    509   int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
    510   { 
    511     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    512     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    513  
    514     unsigned long *buffer = comm.my_buffer->buf_ulong; 
    515     unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
    516     unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 
    517  
    518     if(my_rank == 0) 
    519     { 
    520       assert(count == recvcounts[0]); 
    521       copy(send_buf, send_buf+count, recv_buf + displs[0]); 
    522     } 
    523  
    524     for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
    525     { 
    526       for(int k=1; k<num_ep; k++) 
    527       { 
    528         if(my_rank == k) 
    529         { 
    530           #pragma omp critical (write_to_buffer) 
    531           { 
    532             if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
    533             #pragma omp flush 
    534           } 
    535         } 
    536  
    537         MPI_Barrier_local(comm); 
    538  
    539         if(my_rank == 0) 
    540         { 
    541           #pragma omp flush 
    542           #pragma omp critical (read_from_buffer) 
    543           { 
    544             copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
    545           } 
    546         } 
    547  
    548         MPI_Barrier_local(comm); 
    549       } 
    550     } 
    551   } 
    552  
    553   int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 
    554   { 
    555     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    556     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    557  
    558     char *buffer = comm.my_buffer->buf_char; 
    559     char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
    560     char *recv_buf = static_cast<char*>(recvbuf); 
    561  
    562     if(my_rank == 0) 
    563     { 
    564       assert(count == recvcounts[0]); 
    565       copy(send_buf, send_buf+count, recv_buf + displs[0]); 
    566     } 
    567  
    568     for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 
    569     { 
    570       for(int k=1; k<num_ep; k++) 
    571       { 
    572         if(my_rank == k) 
    573         { 
    574           #pragma omp critical (write_to_buffer) 
    575           { 
    576             if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 
    577             #pragma omp flush 
    578           } 
    579         } 
    580  
    581         MPI_Barrier_local(comm); 
    582  
    583         if(my_rank == 0) 
    584         { 
    585           #pragma omp flush 
    586           #pragma omp critical (read_from_buffer) 
    587           { 
    588             copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 
    589           } 
    590         } 
    591  
    592         MPI_Barrier_local(comm); 
    593       } 
    594     } 
    595   } 
    596  
    597  
    598   int MPI_Gatherv2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
    599                   MPI_Datatype recvtype, int root, MPI_Comm comm) 
    600   { 
    601    
    602     if(!comm.is_ep && comm.mpi_comm) 
    603     { 
    604       ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs), 
    605                     static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    606       return 0; 
    607     } 
    608  
    609     if(!comm.mpi_comm) return 0; 
    610  
    611     assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype)); 
    612  
    613     MPI_Datatype datatype = sendtype; 
    614     int count = sendcount; 
    615  
    616     int ep_rank, ep_rank_loc, mpi_rank; 
    617     int ep_size, num_ep, mpi_size; 
    618  
    619     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    620     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    621     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    622     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    623     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    624     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    625      
    626      
    627      
    628     if(ep_size == mpi_size)  
    629       return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs, 
    630                               static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    631  
    632     if(ep_rank != root) 
    633     { 
    634       recvcounts = new int[ep_size]; 
    635       displs = new int[ep_size]; 
    636     } 
    637      
    638     MPI_Bcast(const_cast< int* >(displs),     ep_size, MPI_INT, root, comm); 
    639     MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm); 
    640                                
    641  
    642     int recv_plus_displs[ep_size]; 
    643     for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 
    644  
    645     for(int j=0; j<mpi_size; j++) 
    646     { 
    647       if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] || 
    648          recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2])   
    649       {   
    650         Debug("Call special implementation of mpi_gatherv. 1st condition not OK\n"); 
    651         return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
    652       } 
    653  
    654       for(int i=1; i<num_ep-1; i++) 
    655       { 
    656         if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] ||  
    657            recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1]) 
    658         { 
    659           Debug("Call special implementation of mpi_gatherv. 2nd condition not OK\n"); 
    660           return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
    661         } 
    662       } 
    663     } 
    664  
    665  
    666     int root_mpi_rank = comm.rank_map->at(root).second; 
    667     int root_ep_loc = comm.rank_map->at(root).first; 
    668  
    669  
    670     ::MPI_Aint datasize, lb; 
    671  
    672     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
    673  
    674     void *local_gather_recvbuf; 
    675     int buffer_size; 
    676     void *master_recvbuf; 
    677  
    678     if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0)  
    679     { 
    680       master_recvbuf = new void*[sizeof(recvbuf)]; 
    681       assert(root_ep_loc == 0); 
    682     } 
    683  
    684     if(ep_rank_loc==0) 
    685     { 
    686       buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep); 
    687  
    688       local_gather_recvbuf = new void*[datasize*buffer_size]; 
    689     } 
    690  
    691     MPI_Gatherv_local2(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm); 
    692  
    693     //MPI_Gather 
    694     if(ep_rank_loc == 0) 
    695     { 
    696       int *mpi_recvcnt= new int[mpi_size]; 
    697       int *mpi_displs= new int[mpi_size]; 
    698  
    699       int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);; 
    700       int buff_end = buffer_size; 
    701  
    702       int mpi_sendcnt = buff_end - buff_start; 
    703  
    704  
    705       ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_recvcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    706       ::MPI_Gather(&buff_start,  1, MPI_INT, mpi_displs,  1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    707  
    708       if(root_ep_loc == 0) 
    709       {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt, 
    710                        mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    711       } 
    712       else  // gatherv to master_recvbuf 
    713       {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, mpi_recvcnt, 
    714                        mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    715       } 
    716  
    717       delete[] mpi_recvcnt; 
    718       delete[] mpi_displs; 
    719     } 
    720  
    721     int global_min_displs = *std::min_element(displs, displs+ep_size); 
    722     int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size); 
    723  
    724  
    725     if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 
    726     { 
    727       innode_memcpy(0, master_recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm); 
    728       if(ep_rank_loc == 0) delete[] master_recvbuf; 
    729     } 
    730  
    731  
    732  
    733     if(ep_rank_loc==0) 
    734     { 
    735       if(datatype == MPI_INT) 
    736       { 
    737         delete[] static_cast<int*>(local_gather_recvbuf); 
    738       } 
    739       else if(datatype == MPI_FLOAT) 
    740       { 
    741         delete[] static_cast<float*>(local_gather_recvbuf); 
    742       } 
    743       else if(datatype == MPI_DOUBLE) 
    744       { 
    745         delete[] static_cast<double*>(local_gather_recvbuf); 
    746       } 
    747       else if(datatype == MPI_LONG) 
    748       { 
    749         delete[] static_cast<long*>(local_gather_recvbuf); 
    750       } 
    751       else if(datatype == MPI_UNSIGNED_LONG) 
    752       { 
    753         delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
    754       } 
    755       else // if(datatype == MPI_CHAR) 
    756       { 
    757         delete[] static_cast<char*>(local_gather_recvbuf); 
    758       } 
    759     } 
    760     else 
    761     { 
    762       delete[] recvcounts; 
    763       delete[] displs; 
    764     } 
    765     return 0; 
    766   } 
    767  
    768  
    769  
    770   int MPI_Allgatherv2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
    771                   MPI_Datatype recvtype, MPI_Comm comm) 
    772   { 
    773  
    774     if(!comm.is_ep && comm.mpi_comm) 
    775     { 
    776       ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs, 
    777                        static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    778       return 0; 
    779     } 
    780  
    781     if(!comm.mpi_comm) return 0; 
    782  
    783     assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype)); 
    784  
    785  
    786     MPI_Datatype datatype = sendtype; 
    787     int count = sendcount; 
    788  
    789     int ep_rank, ep_rank_loc, mpi_rank; 
    790     int ep_size, num_ep, mpi_size; 
    791  
    792     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    793     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    794     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    795     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    796     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    797     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    798      
    799     if(ep_size == mpi_size)  // needed by servers 
    800       return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs, 
    801                               static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    802  
    803     int recv_plus_displs[ep_size]; 
    804     for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 
    805  
    806  
    807     for(int j=0; j<mpi_size; j++) 
    808     { 
    809       if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] || 
    810          recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2])   
    811       {   
    812         printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size); 
    813         for(int k=0; k<ep_size; k++) 
    814           printf("recv_plus_displs[%d] = %d\t displs[%d] = %d\n", k, recv_plus_displs[k], k, displs[k]); 
    815  
    816         return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
    817       } 
    818  
    819       for(int i=1; i<num_ep-1; i++) 
    820       { 
    821         if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] ||  
    822            recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1]) 
    823         { 
    824           printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size); 
    825           return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 
    826         } 
    827       } 
    828     } 
    829  
    830     ::MPI_Aint datasize, lb; 
    831  
    832     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
    833  
    834     void *local_gather_recvbuf; 
    835     int buffer_size; 
    836  
    837     if(ep_rank_loc==0) 
    838     { 
    839       buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep); 
    840  
    841       local_gather_recvbuf = new void*[datasize*buffer_size]; 
    842     } 
    843  
    844     // local gather to master 
    845     MPI_Gatherv_local2(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm); 
    846  
    847     //MPI_Gather 
    848     if(ep_rank_loc == 0) 
    849     { 
    850       int *mpi_recvcnt= new int[mpi_size]; 
    851       int *mpi_displs= new int[mpi_size]; 
    852  
    853       int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);; 
    854       int buff_end = buffer_size; 
    855  
    856       int mpi_sendcnt = buff_end - buff_start; 
    857  
    858  
    859       ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT, mpi_recvcnt, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    860       ::MPI_Allgather(&buff_start,  1, MPI_INT, mpi_displs,  1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    861  
    862  
    863       ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt, 
    864                        mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    865  
    866       delete[] mpi_recvcnt; 
    867       delete[] mpi_displs; 
    868     } 
    869  
    870     int global_min_displs = *std::min_element(displs, displs+ep_size); 
    871     int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size); 
    872  
    873     MPI_Bcast_local2(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm); 
    874  
    875     if(ep_rank_loc==0) 
    876     { 
    877       if(datatype == MPI_INT) 
    878       { 
    879         delete[] static_cast<int*>(local_gather_recvbuf); 
    880       } 
    881       else if(datatype == MPI_FLOAT) 
    882       { 
    883         delete[] static_cast<float*>(local_gather_recvbuf); 
    884       } 
    885       else if(datatype == MPI_DOUBLE) 
    886       { 
    887         delete[] static_cast<double*>(local_gather_recvbuf); 
    888       } 
    889       else if(datatype == MPI_LONG) 
    890       { 
    891         delete[] static_cast<long*>(local_gather_recvbuf); 
    892       } 
    893       else if(datatype == MPI_UNSIGNED_LONG) 
    894       { 
    895         delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
    896       } 
    897       else // if(datatype == MPI_CHAR) 
    898       { 
    899         delete[] static_cast<char*>(local_gather_recvbuf); 
    900       } 
    901     } 
    902   } 
    903  
    904   int MPI_Gatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
    905                           MPI_Datatype recvtype, int root, MPI_Comm comm) 
    906   { 
    907     int ep_rank, ep_rank_loc, mpi_rank; 
    908     int ep_size, num_ep, mpi_size; 
    909  
    910     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    911     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    912     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    913     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    914     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    915     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    916  
    917     int root_mpi_rank = comm.rank_map->at(root).second; 
    918     int root_ep_loc = comm.rank_map->at(root).first; 
    919  
    920     ::MPI_Aint datasize, lb; 
    921     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize); 
    922  
    923     void *local_gather_recvbuf; 
    924     int buffer_size; 
    925  
    926     int *local_displs = new int[num_ep]; 
    927     int *local_rvcnts = new int[num_ep]; 
    928     for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i]; 
    929     local_displs[0] = 0; 
    930     for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1]; 
    931  
    932     if(ep_rank_loc==0) 
    933     { 
    934       buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1]; 
    935       local_gather_recvbuf = new void*[datasize*buffer_size]; 
    936     } 
    937  
    938     // local gather to master 
    939     MPI_Gatherv_local2(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master 
    940  
    941     int **mpi_recvcnts = new int*[num_ep]; 
    942     int **mpi_displs   = new int*[num_ep]; 
    943     for(int i=0; i<num_ep; i++)  
    944     { 
    945       mpi_recvcnts[i] = new int[mpi_size]; 
    946       mpi_displs[i]   = new int[mpi_size]; 
    947       for(int j=0; j<mpi_size; j++) 
    948       { 
    949         mpi_recvcnts[i][j] = recvcounts[j*num_ep + i]; 
    950         mpi_displs[i][j]   = displs[j*num_ep + i]; 
    951       } 
    952     }  
    953  
    954     void *master_recvbuf; 
    955     if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)]; 
    956  
    957     if(ep_rank_loc == 0 && root_ep_loc == 0) // master in MPI_Allgatherv loop 
    958       for(int i=0; i<num_ep; i++) 
    959       { 
    960         ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i], 
    961                     static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    962       } 
    963     if(ep_rank_loc == 0 && root_ep_loc != 0) 
    964       for(int i=0; i<num_ep; i++) 
    965       { 
    966         ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), master_recvbuf, mpi_recvcnts[i], mpi_displs[i], 
    967                     static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    968       } 
    969  
    970  
    971     if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 
    972     { 
    973       for(int i=0; i<ep_size; i++) 
    974         innode_memcpy(0, master_recvbuf + datasize*displs[i], root_ep_loc, recvbuf + datasize*displs[i], recvcounts[i], sendtype, comm); 
    975  
    976       if(ep_rank_loc == 0) delete[] master_recvbuf; 
    977     } 
    978  
    979      
    980     delete[] local_displs; 
    981     delete[] local_rvcnts; 
    982     for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i];  
    983                                   delete[] mpi_displs[i]; } 
    984     delete[] mpi_recvcnts; 
    985     delete[] mpi_displs; 
    986     if(ep_rank_loc==0) 
    987     { 
    988       if(sendtype == MPI_INT) 
    989       { 
    990         delete[] static_cast<int*>(local_gather_recvbuf); 
    991       } 
    992       else if(sendtype == MPI_FLOAT) 
    993       { 
    994         delete[] static_cast<float*>(local_gather_recvbuf); 
    995       } 
    996       else if(sendtype == MPI_DOUBLE) 
    997       { 
    998         delete[] static_cast<double*>(local_gather_recvbuf); 
    999       } 
    1000       else if(sendtype == MPI_LONG) 
    1001       { 
    1002         delete[] static_cast<long*>(local_gather_recvbuf); 
    1003       } 
    1004       else if(sendtype == MPI_UNSIGNED_LONG) 
    1005       { 
    1006         delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
    1007       } 
    1008       else // if(sendtype == MPI_CHAR) 
    1009       { 
    1010         delete[] static_cast<char*>(local_gather_recvbuf); 
    1011       } 
    1012     } 
    1013   } 
    1014  
    1015   int MPI_Allgatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 
    1016                              MPI_Datatype recvtype, MPI_Comm comm) 
    1017   { 
    1018     int ep_rank, ep_rank_loc, mpi_rank; 
    1019     int ep_size, num_ep, mpi_size; 
    1020  
    1021     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    1022     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    1023     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    1024     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    1025     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    1026     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    1027  
    1028  
    1029     ::MPI_Aint datasize, lb; 
    1030     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize); 
    1031  
    1032     void *local_gather_recvbuf; 
    1033     int buffer_size; 
    1034  
    1035     int *local_displs = new int[num_ep]; 
    1036     int *local_rvcnts = new int[num_ep]; 
    1037     for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i]; 
    1038     local_displs[0] = 0; 
    1039     for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1]; 
    1040  
    1041     if(ep_rank_loc==0) 
    1042     { 
    1043       buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1]; 
    1044       local_gather_recvbuf = new void*[datasize*buffer_size]; 
    1045     } 
    1046  
    1047     // local gather to master 
    1048     MPI_Gatherv_local2(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master 
    1049  
    1050     int **mpi_recvcnts = new int*[num_ep]; 
    1051     int **mpi_displs   = new int*[num_ep]; 
    1052     for(int i=0; i<num_ep; i++)  
    1053     { 
    1054       mpi_recvcnts[i] = new int[mpi_size]; 
    1055       mpi_displs[i]   = new int[mpi_size]; 
    1056       for(int j=0; j<mpi_size; j++) 
    1057       { 
    1058         mpi_recvcnts[i][j] = recvcounts[j*num_ep + i]; 
    1059         mpi_displs[i][j]   = displs[j*num_ep + i]; 
    1060       } 
    1061     }  
    1062  
    1063     if(ep_rank_loc == 0) // master in MPI_Allgatherv loop 
    1064     for(int i=0; i<num_ep; i++) 
    1065     { 
    1066       ::MPI_Allgatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i], 
    1067                   static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    1068     } 
    1069  
    1070     for(int i=0; i<ep_size; i++) 
    1071       MPI_Bcast_local2(recvbuf + datasize*displs[i], recvcounts[i], recvtype, comm); 
    1072  
    1073      
    1074     delete[] local_displs; 
    1075     delete[] local_rvcnts; 
    1076     for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i];  
    1077                                   delete[] mpi_displs[i]; } 
    1078     delete[] mpi_recvcnts; 
    1079     delete[] mpi_displs; 
    1080     if(ep_rank_loc==0) 
    1081     { 
    1082       if(sendtype == MPI_INT) 
    1083       { 
    1084         delete[] static_cast<int*>(local_gather_recvbuf); 
    1085       } 
    1086       else if(sendtype == MPI_FLOAT) 
    1087       { 
    1088         delete[] static_cast<float*>(local_gather_recvbuf); 
    1089       } 
    1090       else if(sendtype == MPI_DOUBLE) 
    1091       { 
    1092         delete[] static_cast<double*>(local_gather_recvbuf); 
    1093       } 
    1094       else if(sendtype == MPI_LONG) 
    1095       { 
    1096         delete[] static_cast<long*>(local_gather_recvbuf); 
    1097       } 
    1098       else if(sendtype == MPI_UNSIGNED_LONG) 
    1099       { 
    1100         delete[] static_cast<unsigned long*>(local_gather_recvbuf); 
    1101       } 
    1102       else // if(sendtype == MPI_CHAR) 
    1103       { 
    1104         delete[] static_cast<char*>(local_gather_recvbuf); 
    1105       } 
    1106     } 
    1107   } 
    1108  
    1109  
    1110175} 
Note: See TracChangeset for help on using the changeset viewer.