- Timestamp:
- 05/24/17 13:09:23 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gatherv.cpp
r1138 r1145 15 15 namespace ep_lib 16 16 { 17 17 18 int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 18 19 { … … 347 348 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 348 349 350 if(ep_size == mpi_size) 351 return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs, 352 static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 353 354 int recv_plus_displs[ep_size]; 355 for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 356 357 #pragma omp single nowait 358 { 359 assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]); 360 for(int i=1; i<num_ep-1; i++) 361 { 362 assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]); 363 assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]); 364 } 365 assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]); 366 } 367 349 368 if(ep_rank != root) 350 369 { … … 366 385 367 386 void *local_gather_recvbuf; 387 int buffer_size; 368 388 369 389 if(ep_rank_loc==0) 370 390 { 371 int buffer_size = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0); 391 buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep); 392 372 393 local_gather_recvbuf = new void*[datasize*buffer_size]; 373 394 } 374 395 375 // local gather to master 376 int local_displs[num_ep]; 377 local_displs[0] = 0; 378 for(int i=1; i<num_ep; i++) 379 { 380 local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc]; 381 } 382 MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, local_displs, comm); 396 MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm); 383 397 384 
398 //MPI_Gather 385 399 if(ep_rank_loc == 0) 386 400 { 387 388 int gatherv_recvcnt[mpi_size]; 389 int gatherv_displs[mpi_size]; 390 int gatherv_cnt = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0); 391 392 //gatherv_recvcnt = new int[mpi_size]; 393 //gatherv_displs = new int[mpi_size]; 394 395 396 ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm)); 397 398 gatherv_displs[0] = 0; 399 for(int i=1; i<mpi_size; i++) 400 { 401 gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1]; 402 } 403 404 405 ::MPI_Gatherv(local_gather_recvbuf, gatherv_cnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt, 406 gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 407 408 //delete[] gatherv_recvcnt; 409 //delete[] gatherv_displs; 410 } 401 int *mpi_recvcnt= new int[mpi_size]; 402 int *mpi_displs= new int[mpi_size]; 403 404 int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);; 405 int buff_end = buffer_size; 406 407 int mpi_sendcnt = buff_end - buff_start; 408 409 410 ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 411 ::MPI_Gather(&buff_start, 1, MPI_INT_STD, mpi_displs, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 412 413 414 ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt, 415 mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 416 417 delete[] mpi_recvcnt; 418 delete[] mpi_displs; 419 } 420 421 int global_min_displs = *std::min_element(displs, displs+ep_size); 422 int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size); 411 423 412 424 413 425 if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, 
master send to root and root receive from master 414 426 { 415 innode_memcpy(0, recvbuf , root_ep_loc, recvbuf, accumulate(recvcounts, recvcounts+ep_size, 0), datatype, comm);427 innode_memcpy(0, recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm); 416 428 } 417 429 … … 487 499 return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs, 488 500 static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 489 490 491 assert(accumulate(recvcounts, recvcounts+ep_size-1, 0) >= displs[ep_size-1]); // Only for continuous gather. 501 502 503 int recv_plus_displs[ep_size]; 504 for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 505 506 #pragma omp single nowait 507 { 508 assert(recv_plus_displs[ep_rank-ep_rank_loc] >= displs[ep_rank-ep_rank_loc+1]); 509 for(int i=1; i<num_ep-1; i++) 510 { 511 assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i+1]); 512 assert(recv_plus_displs[ep_rank-ep_rank_loc+i] >= displs[ep_rank-ep_rank_loc+i-1]); 513 } 514 assert(recv_plus_displs[ep_rank-ep_rank_loc+num_ep-1] >= displs[ep_rank-ep_rank_loc+num_ep-2]); 515 } 492 516 493 517 … … 497 521 498 522 void *local_gather_recvbuf; 523 int buffer_size; 499 524 500 525 if(ep_rank_loc==0) 501 526 { 502 int buffer_size = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0); 527 buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep); 528 503 529 local_gather_recvbuf = new void*[datasize*buffer_size]; 504 530 } 505 531 506 532 // local gather to master 507 int local_displs[num_ep]; 508 local_displs[0] = 0; 509 for(int i=1; i<num_ep; i++) 510 { 511 local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc]; 512 } 513 MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, local_displs, comm); 533 
MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm); 514 534 515 535 //MPI_Gather 516 536 if(ep_rank_loc == 0) 517 537 { 518 int *gatherv_recvcnt; 519 int *gatherv_displs; 520 int gatherv_cnt = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0); 521 522 gatherv_recvcnt = new int[mpi_size]; 523 gatherv_displs = new int[mpi_size]; 524 525 ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm)); 526 gatherv_displs[0] = displs[0]; 527 for(int i=1; i<mpi_size; i++) 528 { 529 gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1]; 530 } 531 532 ::MPI_Allgatherv(local_gather_recvbuf, gatherv_cnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt, 533 gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 534 535 delete[] gatherv_recvcnt; 536 delete[] gatherv_displs; 537 } 538 539 MPI_Bcast_local(recvbuf, accumulate(recvcounts, recvcounts+ep_size, 0), datatype, comm); 538 int *mpi_recvcnt= new int[mpi_size]; 539 int *mpi_displs= new int[mpi_size]; 540 541 int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);; 542 int buff_end = buffer_size; 543 544 int mpi_sendcnt = buff_end - buff_start; 545 546 547 ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm)); 548 ::MPI_Allgather(&buff_start, 1, MPI_INT_STD, mpi_displs, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm)); 549 550 551 ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt, 552 mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 553 554 delete[] mpi_recvcnt; 555 delete[] mpi_displs; 556 } 557 558 int global_min_displs = *std::min_element(displs, displs+ep_size); 559 int global_recvcnt 
= *std::max_element(recv_plus_displs, recv_plus_displs+ep_size); 560 561 MPI_Bcast_local(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm); 540 562 541 563 if(ep_rank_loc==0)
Note: See TracChangeset for help on using the changeset viewer.