Changeset 1295 for XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gatherv.cpp
- Timestamp: 10/06/17 13:56:33 (7 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gatherv.cpp
r1289 r1295 15 15 namespace ep_lib 16 16 { 17 int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], int local_root, MPI_Comm comm) 17 18 int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], int local_root, MPI_Comm comm) 18 19 { 19 20 assert(valid_type(datatype)); … … 132 133 } 133 134 134 135 136 135 for(int i=1; i<mpi_size; i++) 137 136 mpi_displs[i] = mpi_displs[i-1] + mpi_recvcounts[i-1]; 138 139 137 140 138 … … 146 144 if(is_root) 147 145 { 148 // printf("tmp_recvbuf =\n");149 // for(int i=0; i<ep_size*sendcount; i++) printf("%d\t", static_cast<int*>(tmp_recvbuf)[i]);150 // printf("\n");151 152 146 int offset; 153 147 for(int i=0; i<ep_size; i++) … … 164 158 165 159 memcpy(recvbuf+displs[i]*datasize, tmp_recvbuf+offset*datasize, recvcounts[i]*datasize); 166 167 //printf("recvbuf[%d] = tmp_recvbuf[%d] \n", i, offset);168 160 169 161 } 170 171 // printf("recvbuf =\n");172 // for(int i=0; i<ep_size*sendcount; i++) printf("%d\t", static_cast<int*>(recvbuf)[i]);173 // printf("\n");174 162 175 163 } … … 185 173 } 186 174 187 // int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype, MPI_Comm comm)188 // {189 190 // if(!comm.is_ep && comm.mpi_comm)191 // {192 // ::MPI_Allgatherv(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcounts, displs, to_mpi_type(recvtype), to_mpi_comm(comm.mpi_comm));193 // return 0;194 // }195 196 // if(!comm.mpi_comm) return 0;197 198 199 200 201 // assert(valid_type(sendtype) && valid_type(recvtype));202 203 // MPI_Datatype datatype = sendtype;204 // int count = sendcount;205 206 // ::MPI_Aint datasize, lb;207 208 // ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);209 210 211 // int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;212 // int 
ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;213 // int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;214 // int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;215 // int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;216 // int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;217 218 219 // assert(sendcount == recvcounts[ep_rank]);220 221 // bool is_master = ep_rank_loc==0;222 223 // void* local_recvbuf;224 // void* tmp_recvbuf;225 226 // int recvbuf_size = 0;227 // for(int i=0; i<ep_size; i++)228 // recvbuf_size = max(recvbuf_size, displs[i]+recvcounts[i]);229 230 231 // vector<int>local_recvcounts(num_ep, 0);232 // vector<int>local_displs(num_ep, 0);233 234 // MPI_Gather_local(&sendcount, 1, MPI_INT, local_recvcounts.data(), 0, comm);235 // for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_recvcounts[i-1];236 237 238 // if(is_master)239 // {240 // local_recvbuf = new void*[datasize * std::accumulate(local_recvcounts.begin(), local_recvcounts.begin()+num_ep, 0)];241 // tmp_recvbuf = new void*[datasize * std::accumulate(recvcounts, recvcounts+ep_size, 0)];242 // }243 244 // MPI_Gatherv_local(sendbuf, count, datatype, local_recvbuf, local_recvcounts.data(), local_displs.data(), 0, comm);245 246 247 // if(is_master)248 // {249 // std::vector<int>mpi_recvcounts(mpi_size, 0);250 // std::vector<int>mpi_displs(mpi_size, 0);251 252 // int local_sendcount = std::accumulate(local_recvcounts.begin(), local_recvcounts.begin()+num_ep, 0);253 // MPI_Allgather(&local_sendcount, 1, MPI_INT, mpi_recvcounts.data(), 1, MPI_INT, to_mpi_comm(comm.mpi_comm));254 255 // for(int i=1; i<mpi_size; i++)256 // mpi_displs[i] = mpi_displs[i-1] + mpi_recvcounts[i-1];257 258 259 // ::MPI_Allgatherv(local_recvbuf, local_sendcount, to_mpi_type(datatype), tmp_recvbuf, mpi_recvcounts.data(), mpi_displs.data(), to_mpi_type(datatype), to_mpi_comm(comm.mpi_comm));260 261 262 263 // // reorder264 // int offset;265 // for(int i=0; i<ep_size; 
i++)266 // {267 // int extra = 0;268 // for(int j=0, k=0; j<ep_size, k<comm.rank_map->at(i).first; j++)269 // if(comm.rank_map->at(i).second == comm.rank_map->at(j).second)270 // {271 // extra += recvcounts[j];272 // k++;273 // }274 275 // offset = mpi_displs[comm.rank_map->at(i).second] + extra;276 277 // memcpy(recvbuf+displs[i]*datasize, tmp_recvbuf+offset*datasize, recvcounts[i]*datasize);278 279 // }280 281 // }282 283 // MPI_Bcast_local(recvbuf, recvbuf_size, datatype, 0, comm);284 285 // if(is_master)286 // {287 // delete[] local_recvbuf;288 // delete[] tmp_recvbuf;289 // }290 291 // }292 293 294 int MPI_Gatherv_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)295 {296 if(datatype == MPI_INT)297 {298 Debug("datatype is INT\n");299 return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm);300 }301 else if(datatype == MPI_FLOAT)302 {303 Debug("datatype is FLOAT\n");304 return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm);305 }306 else if(datatype == MPI_DOUBLE)307 {308 Debug("datatype is DOUBLE\n");309 return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm);310 }311 else if(datatype == MPI_LONG)312 {313 Debug("datatype is LONG\n");314 return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm);315 }316 else if(datatype == MPI_UNSIGNED_LONG)317 {318 Debug("datatype is uLONG\n");319 return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm);320 }321 else if(datatype == MPI_CHAR)322 {323 Debug("datatype is CHAR\n");324 return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm);325 }326 else327 {328 printf("MPI_Gatherv Datatype not supported!\n");329 exit(0);330 }331 }332 333 int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)334 {335 int my_rank = 
comm.ep_comm_ptr->size_rank_info[1].first;336 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;337 338 int *buffer = comm.my_buffer->buf_int;339 int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));340 int *recv_buf = static_cast<int*>(recvbuf);341 342 if(my_rank == 0)343 {344 assert(count == recvcounts[0]);345 copy(send_buf, send_buf+count, recv_buf + displs[0]);346 }347 348 for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)349 {350 for(int k=1; k<num_ep; k++)351 {352 if(my_rank == k)353 {354 #pragma omp critical (write_to_buffer)355 {356 if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);357 #pragma omp flush358 }359 }360 361 MPI_Barrier_local(comm);362 363 if(my_rank == 0)364 {365 #pragma omp flush366 #pragma omp critical (read_from_buffer)367 {368 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);369 }370 }371 372 MPI_Barrier_local(comm);373 }374 }375 }376 377 int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)378 {379 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;380 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;381 382 float *buffer = comm.my_buffer->buf_float;383 float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));384 float *recv_buf = static_cast<float*>(recvbuf);385 386 if(my_rank == 0)387 {388 assert(count == recvcounts[0]);389 copy(send_buf, send_buf+count, recv_buf + displs[0]);390 }391 392 for(int j=0; count!=0? 
j<count: j<count+1; j+=BUFFER_SIZE)393 {394 for(int k=1; k<num_ep; k++)395 {396 if(my_rank == k)397 {398 #pragma omp critical (write_to_buffer)399 {400 if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);401 #pragma omp flush402 }403 }404 405 MPI_Barrier_local(comm);406 407 if(my_rank == 0)408 {409 #pragma omp flush410 #pragma omp critical (read_from_buffer)411 {412 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);413 }414 }415 416 MPI_Barrier_local(comm);417 }418 }419 }420 421 int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)422 {423 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;424 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;425 426 double *buffer = comm.my_buffer->buf_double;427 double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));428 double *recv_buf = static_cast<double*>(recvbuf);429 430 if(my_rank == 0)431 {432 assert(count == recvcounts[0]);433 copy(send_buf, send_buf+count, recv_buf + displs[0]);434 }435 436 for(int j=0; count!=0? 
j<count: j<count+1; j+=BUFFER_SIZE)437 {438 for(int k=1; k<num_ep; k++)439 {440 if(my_rank == k)441 {442 #pragma omp critical (write_to_buffer)443 {444 if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);445 #pragma omp flush446 }447 }448 449 MPI_Barrier_local(comm);450 451 if(my_rank == 0)452 {453 #pragma omp flush454 #pragma omp critical (read_from_buffer)455 {456 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);457 }458 }459 460 MPI_Barrier_local(comm);461 }462 }463 }464 465 int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)466 {467 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;468 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;469 470 long *buffer = comm.my_buffer->buf_long;471 long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));472 long *recv_buf = static_cast<long*>(recvbuf);473 474 if(my_rank == 0)475 {476 assert(count == recvcounts[0]);477 copy(send_buf, send_buf+count, recv_buf + displs[0]);478 }479 480 for(int j=0; count!=0? 
j<count: j<count+1; j+=BUFFER_SIZE)481 {482 for(int k=1; k<num_ep; k++)483 {484 if(my_rank == k)485 {486 #pragma omp critical (write_to_buffer)487 {488 if(count!=0)copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);489 #pragma omp flush490 }491 }492 493 MPI_Barrier_local(comm);494 495 if(my_rank == 0)496 {497 #pragma omp flush498 #pragma omp critical (read_from_buffer)499 {500 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);501 }502 }503 504 MPI_Barrier_local(comm);505 }506 }507 }508 509 int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)510 {511 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;512 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;513 514 unsigned long *buffer = comm.my_buffer->buf_ulong;515 unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));516 unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);517 518 if(my_rank == 0)519 {520 assert(count == recvcounts[0]);521 copy(send_buf, send_buf+count, recv_buf + displs[0]);522 }523 524 for(int j=0; count!=0? 
j<count: j<count+1; j+=BUFFER_SIZE)525 {526 for(int k=1; k<num_ep; k++)527 {528 if(my_rank == k)529 {530 #pragma omp critical (write_to_buffer)531 {532 if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);533 #pragma omp flush534 }535 }536 537 MPI_Barrier_local(comm);538 539 if(my_rank == 0)540 {541 #pragma omp flush542 #pragma omp critical (read_from_buffer)543 {544 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);545 }546 }547 548 MPI_Barrier_local(comm);549 }550 }551 }552 553 int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)554 {555 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;556 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;557 558 char *buffer = comm.my_buffer->buf_char;559 char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));560 char *recv_buf = static_cast<char*>(recvbuf);561 562 if(my_rank == 0)563 {564 assert(count == recvcounts[0]);565 copy(send_buf, send_buf+count, recv_buf + displs[0]);566 }567 568 for(int j=0; count!=0? 
j<count: j<count+1; j+=BUFFER_SIZE)569 {570 for(int k=1; k<num_ep; k++)571 {572 if(my_rank == k)573 {574 #pragma omp critical (write_to_buffer)575 {576 if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);577 #pragma omp flush578 }579 }580 581 MPI_Barrier_local(comm);582 583 if(my_rank == 0)584 {585 #pragma omp flush586 #pragma omp critical (read_from_buffer)587 {588 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);589 }590 }591 592 MPI_Barrier_local(comm);593 }594 }595 }596 597 598 int MPI_Gatherv2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],599 MPI_Datatype recvtype, int root, MPI_Comm comm)600 {601 602 if(!comm.is_ep && comm.mpi_comm)603 {604 ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs),605 static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));606 return 0;607 }608 609 if(!comm.mpi_comm) return 0;610 611 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));612 613 MPI_Datatype datatype = sendtype;614 int count = sendcount;615 616 int ep_rank, ep_rank_loc, mpi_rank;617 int ep_size, num_ep, mpi_size;618 619 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;620 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;621 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;622 ep_size = comm.ep_comm_ptr->size_rank_info[0].second;623 num_ep = comm.ep_comm_ptr->size_rank_info[1].second;624 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;625 626 627 628 if(ep_size == mpi_size)629 return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,630 static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));631 632 if(ep_rank != root)633 {634 recvcounts = new int[ep_size];635 displs = new 
int[ep_size];636 }637 638 MPI_Bcast(const_cast< int* >(displs), ep_size, MPI_INT, root, comm);639 MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm);640 641 642 int recv_plus_displs[ep_size];643 for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];644 645 for(int j=0; j<mpi_size; j++)646 {647 if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] ||648 recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2])649 {650 Debug("Call special implementation of mpi_gatherv. 1st condition not OK\n");651 return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);652 }653 654 for(int i=1; i<num_ep-1; i++)655 {656 if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] ||657 recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1])658 {659 Debug("Call special implementation of mpi_gatherv. 2nd condition not OK\n");660 return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);661 }662 }663 }664 665 666 int root_mpi_rank = comm.rank_map->at(root).second;667 int root_ep_loc = comm.rank_map->at(root).first;668 669 670 ::MPI_Aint datasize, lb;671 672 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);673 674 void *local_gather_recvbuf;675 int buffer_size;676 void *master_recvbuf;677 678 if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0)679 {680 master_recvbuf = new void*[sizeof(recvbuf)];681 assert(root_ep_loc == 0);682 }683 684 if(ep_rank_loc==0)685 {686 buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);687 688 local_gather_recvbuf = new void*[datasize*buffer_size];689 }690 691 MPI_Gatherv_local2(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);692 693 //MPI_Gather694 if(ep_rank_loc == 0)695 {696 int *mpi_recvcnt= new int[mpi_size];697 int *mpi_displs= new int[mpi_size];698 699 int buff_start 
= *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;700 int buff_end = buffer_size;701 702 int mpi_sendcnt = buff_end - buff_start;703 704 705 ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_recvcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));706 ::MPI_Gather(&buff_start, 1, MPI_INT, mpi_displs, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));707 708 if(root_ep_loc == 0)709 { ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,710 mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));711 }712 else // gatherv to master_recvbuf713 { ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, mpi_recvcnt,714 mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));715 }716 717 delete[] mpi_recvcnt;718 delete[] mpi_displs;719 }720 721 int global_min_displs = *std::min_element(displs, displs+ep_size);722 int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);723 724 725 if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master726 {727 innode_memcpy(0, master_recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);728 if(ep_rank_loc == 0) delete[] master_recvbuf;729 }730 731 732 733 if(ep_rank_loc==0)734 {735 if(datatype == MPI_INT)736 {737 delete[] static_cast<int*>(local_gather_recvbuf);738 }739 else if(datatype == MPI_FLOAT)740 {741 delete[] static_cast<float*>(local_gather_recvbuf);742 }743 else if(datatype == MPI_DOUBLE)744 {745 delete[] static_cast<double*>(local_gather_recvbuf);746 }747 else if(datatype == MPI_LONG)748 {749 delete[] static_cast<long*>(local_gather_recvbuf);750 }751 else if(datatype == 
MPI_UNSIGNED_LONG)752 {753 delete[] static_cast<unsigned long*>(local_gather_recvbuf);754 }755 else // if(datatype == MPI_CHAR)756 {757 delete[] static_cast<char*>(local_gather_recvbuf);758 }759 }760 else761 {762 delete[] recvcounts;763 delete[] displs;764 }765 return 0;766 }767 768 769 770 int MPI_Allgatherv2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],771 MPI_Datatype recvtype, MPI_Comm comm)772 {773 774 if(!comm.is_ep && comm.mpi_comm)775 {776 ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs,777 static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));778 return 0;779 }780 781 if(!comm.mpi_comm) return 0;782 783 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));784 785 786 MPI_Datatype datatype = sendtype;787 int count = sendcount;788 789 int ep_rank, ep_rank_loc, mpi_rank;790 int ep_size, num_ep, mpi_size;791 792 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;793 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;794 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;795 ep_size = comm.ep_comm_ptr->size_rank_info[0].second;796 num_ep = comm.ep_comm_ptr->size_rank_info[1].second;797 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;798 799 if(ep_size == mpi_size) // needed by servers800 return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,801 static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));802 803 int recv_plus_displs[ep_size];804 for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];805 806 807 for(int j=0; j<mpi_size; j++)808 {809 if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] ||810 recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2])811 {812 printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, 
ep_size);813 for(int k=0; k<ep_size; k++)814 printf("recv_plus_displs[%d] = %d\t displs[%d] = %d\n", k, recv_plus_displs[k], k, displs[k]);815 816 return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);817 }818 819 for(int i=1; i<num_ep-1; i++)820 {821 if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] ||822 recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1])823 {824 printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size);825 return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);826 }827 }828 }829 830 ::MPI_Aint datasize, lb;831 832 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);833 834 void *local_gather_recvbuf;835 int buffer_size;836 837 if(ep_rank_loc==0)838 {839 buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);840 841 local_gather_recvbuf = new void*[datasize*buffer_size];842 }843 844 // local gather to master845 MPI_Gatherv_local2(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);846 847 //MPI_Gather848 if(ep_rank_loc == 0)849 {850 int *mpi_recvcnt= new int[mpi_size];851 int *mpi_displs= new int[mpi_size];852 853 int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;854 int buff_end = buffer_size;855 856 int mpi_sendcnt = buff_end - buff_start;857 858 859 ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT, mpi_recvcnt, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm));860 ::MPI_Allgather(&buff_start, 1, MPI_INT, mpi_displs, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm));861 862 863 ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,864 mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));865 866 delete[] mpi_recvcnt;867 delete[] 
mpi_displs;868 }869 870 int global_min_displs = *std::min_element(displs, displs+ep_size);871 int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);872 873 MPI_Bcast_local2(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);874 875 if(ep_rank_loc==0)876 {877 if(datatype == MPI_INT)878 {879 delete[] static_cast<int*>(local_gather_recvbuf);880 }881 else if(datatype == MPI_FLOAT)882 {883 delete[] static_cast<float*>(local_gather_recvbuf);884 }885 else if(datatype == MPI_DOUBLE)886 {887 delete[] static_cast<double*>(local_gather_recvbuf);888 }889 else if(datatype == MPI_LONG)890 {891 delete[] static_cast<long*>(local_gather_recvbuf);892 }893 else if(datatype == MPI_UNSIGNED_LONG)894 {895 delete[] static_cast<unsigned long*>(local_gather_recvbuf);896 }897 else // if(datatype == MPI_CHAR)898 {899 delete[] static_cast<char*>(local_gather_recvbuf);900 }901 }902 }903 904 int MPI_Gatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],905 MPI_Datatype recvtype, int root, MPI_Comm comm)906 {907 int ep_rank, ep_rank_loc, mpi_rank;908 int ep_size, num_ep, mpi_size;909 910 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;911 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;912 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;913 ep_size = comm.ep_comm_ptr->size_rank_info[0].second;914 num_ep = comm.ep_comm_ptr->size_rank_info[1].second;915 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;916 917 int root_mpi_rank = comm.rank_map->at(root).second;918 int root_ep_loc = comm.rank_map->at(root).first;919 920 ::MPI_Aint datasize, lb;921 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize);922 923 void *local_gather_recvbuf;924 int buffer_size;925 926 int *local_displs = new int[num_ep];927 int *local_rvcnts = new int[num_ep];928 for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i];929 
local_displs[0] = 0;930 for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1];931 932 if(ep_rank_loc==0)933 {934 buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1];935 local_gather_recvbuf = new void*[datasize*buffer_size];936 }937 938 // local gather to master939 MPI_Gatherv_local2(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master940 941 int **mpi_recvcnts = new int*[num_ep];942 int **mpi_displs = new int*[num_ep];943 for(int i=0; i<num_ep; i++)944 {945 mpi_recvcnts[i] = new int[mpi_size];946 mpi_displs[i] = new int[mpi_size];947 for(int j=0; j<mpi_size; j++)948 {949 mpi_recvcnts[i][j] = recvcounts[j*num_ep + i];950 mpi_displs[i][j] = displs[j*num_ep + i];951 }952 }953 954 void *master_recvbuf;955 if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)];956 957 if(ep_rank_loc == 0 && root_ep_loc == 0) // master in MPI_Allgatherv loop958 for(int i=0; i<num_ep; i++)959 {960 ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i],961 static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));962 }963 if(ep_rank_loc == 0 && root_ep_loc != 0)964 for(int i=0; i<num_ep; i++)965 {966 ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), master_recvbuf, mpi_recvcnts[i], mpi_displs[i],967 static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));968 }969 970 971 if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master972 {973 for(int i=0; i<ep_size; i++)974 innode_memcpy(0, master_recvbuf + datasize*displs[i], root_ep_loc, recvbuf + datasize*displs[i], recvcounts[i], sendtype, comm);975 
976 if(ep_rank_loc == 0) delete[] master_recvbuf;977 }978 979 980 delete[] local_displs;981 delete[] local_rvcnts;982 for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i];983 delete[] mpi_displs[i]; }984 delete[] mpi_recvcnts;985 delete[] mpi_displs;986 if(ep_rank_loc==0)987 {988 if(sendtype == MPI_INT)989 {990 delete[] static_cast<int*>(local_gather_recvbuf);991 }992 else if(sendtype == MPI_FLOAT)993 {994 delete[] static_cast<float*>(local_gather_recvbuf);995 }996 else if(sendtype == MPI_DOUBLE)997 {998 delete[] static_cast<double*>(local_gather_recvbuf);999 }1000 else if(sendtype == MPI_LONG)1001 {1002 delete[] static_cast<long*>(local_gather_recvbuf);1003 }1004 else if(sendtype == MPI_UNSIGNED_LONG)1005 {1006 delete[] static_cast<unsigned long*>(local_gather_recvbuf);1007 }1008 else // if(sendtype == MPI_CHAR)1009 {1010 delete[] static_cast<char*>(local_gather_recvbuf);1011 }1012 }1013 }1014 1015 int MPI_Allgatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],1016 MPI_Datatype recvtype, MPI_Comm comm)1017 {1018 int ep_rank, ep_rank_loc, mpi_rank;1019 int ep_size, num_ep, mpi_size;1020 1021 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;1022 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;1023 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;1024 ep_size = comm.ep_comm_ptr->size_rank_info[0].second;1025 num_ep = comm.ep_comm_ptr->size_rank_info[1].second;1026 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;1027 1028 1029 ::MPI_Aint datasize, lb;1030 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize);1031 1032 void *local_gather_recvbuf;1033 int buffer_size;1034 1035 int *local_displs = new int[num_ep];1036 int *local_rvcnts = new int[num_ep];1037 for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i];1038 local_displs[0] = 0;1039 for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + 
local_rvcnts[i-1];1040 1041 if(ep_rank_loc==0)1042 {1043 buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1];1044 local_gather_recvbuf = new void*[datasize*buffer_size];1045 }1046 1047 // local gather to master1048 MPI_Gatherv_local2(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master1049 1050 int **mpi_recvcnts = new int*[num_ep];1051 int **mpi_displs = new int*[num_ep];1052 for(int i=0; i<num_ep; i++)1053 {1054 mpi_recvcnts[i] = new int[mpi_size];1055 mpi_displs[i] = new int[mpi_size];1056 for(int j=0; j<mpi_size; j++)1057 {1058 mpi_recvcnts[i][j] = recvcounts[j*num_ep + i];1059 mpi_displs[i][j] = displs[j*num_ep + i];1060 }1061 }1062 1063 if(ep_rank_loc == 0) // master in MPI_Allgatherv loop1064 for(int i=0; i<num_ep; i++)1065 {1066 ::MPI_Allgatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i],1067 static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));1068 }1069 1070 for(int i=0; i<ep_size; i++)1071 MPI_Bcast_local2(recvbuf + datasize*displs[i], recvcounts[i], recvtype, comm);1072 1073 1074 delete[] local_displs;1075 delete[] local_rvcnts;1076 for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i];1077 delete[] mpi_displs[i]; }1078 delete[] mpi_recvcnts;1079 delete[] mpi_displs;1080 if(ep_rank_loc==0)1081 {1082 if(sendtype == MPI_INT)1083 {1084 delete[] static_cast<int*>(local_gather_recvbuf);1085 }1086 else if(sendtype == MPI_FLOAT)1087 {1088 delete[] static_cast<float*>(local_gather_recvbuf);1089 }1090 else if(sendtype == MPI_DOUBLE)1091 {1092 delete[] static_cast<double*>(local_gather_recvbuf);1093 }1094 else if(sendtype == MPI_LONG)1095 {1096 delete[] static_cast<long*>(local_gather_recvbuf);1097 }1098 else if(sendtype == MPI_UNSIGNED_LONG)1099 {1100 delete[] static_cast<unsigned long*>(local_gather_recvbuf);1101 }1102 else // if(sendtype == 
MPI_CHAR)1103 {1104 delete[] static_cast<char*>(local_gather_recvbuf);1105 }1106 }1107 }1108 1109 1110 175 }
Note: See TracChangeset for help on using the changeset viewer.