Changeset 1289 for XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gatherv.cpp
- Timestamp:
- 10/04/17 17:02:13 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gatherv.cpp
r1287 r1289 15 15 namespace ep_lib 16 16 { 17 18 int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], int local_root, MPI_Comm comm) 17 int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], int local_root, MPI_Comm comm) 19 18 { 20 19 assert(valid_type(datatype)); … … 186 185 } 187 186 187 // int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype, MPI_Comm comm) 188 // { 189 190 // if(!comm.is_ep && comm.mpi_comm) 191 // { 192 // ::MPI_Allgatherv(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcounts, displs, to_mpi_type(recvtype), to_mpi_comm(comm.mpi_comm)); 193 // return 0; 194 // } 195 196 // if(!comm.mpi_comm) return 0; 197 198 199 200 201 // assert(valid_type(sendtype) && valid_type(recvtype)); 202 203 // MPI_Datatype datatype = sendtype; 204 // int count = sendcount; 205 206 // ::MPI_Aint datasize, lb; 207 208 // ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 209 210 211 // int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 212 // int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 213 // int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 214 // int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 215 // int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 216 // int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 217 218 219 // assert(sendcount == recvcounts[ep_rank]); 220 221 // bool is_master = ep_rank_loc==0; 222 223 // void* local_recvbuf; 224 // void* tmp_recvbuf; 225 226 // int recvbuf_size = 0; 227 // for(int i=0; i<ep_size; i++) 228 // recvbuf_size = max(recvbuf_size, displs[i]+recvcounts[i]); 229 230 231 // vector<int>local_recvcounts(num_ep, 0); 232 // vector<int>local_displs(num_ep, 0); 233 234 // MPI_Gather_local(&sendcount, 1, MPI_INT, local_recvcounts.data(), 0, comm); 235 // for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_recvcounts[i-1]; 236 237 238 // if(is_master) 239 // { 240 // local_recvbuf = new void*[datasize * std::accumulate(local_recvcounts.begin(), local_recvcounts.begin()+num_ep, 0)]; 241 // tmp_recvbuf = new void*[datasize * std::accumulate(recvcounts, recvcounts+ep_size, 0)]; 242 // } 243 244 // MPI_Gatherv_local(sendbuf, count, datatype, local_recvbuf, local_recvcounts.data(), local_displs.data(), 0, comm); 245 246 247 // if(is_master) 248 // { 249 // std::vector<int>mpi_recvcounts(mpi_size, 0); 250 // std::vector<int>mpi_displs(mpi_size, 0); 251 252 // int local_sendcount = std::accumulate(local_recvcounts.begin(), local_recvcounts.begin()+num_ep, 0); 253 // MPI_Allgather(&local_sendcount, 1, MPI_INT, mpi_recvcounts.data(), 1, MPI_INT, to_mpi_comm(comm.mpi_comm)); 254 255 // for(int i=1; i<mpi_size; i++) 256 // mpi_displs[i] = mpi_displs[i-1] + mpi_recvcounts[i-1]; 257 258 259 // ::MPI_Allgatherv(local_recvbuf, local_sendcount, to_mpi_type(datatype), tmp_recvbuf, mpi_recvcounts.data(), mpi_displs.data(), to_mpi_type(datatype), to_mpi_comm(comm.mpi_comm)); 260 261 262 263 // // reorder 264 // int offset; 265 // for(int i=0; i<ep_size; i++) 266 // { 267 // int extra = 0; 268 // for(int j=0, k=0; j<ep_size, k<comm.rank_map->at(i).first; j++) 269 // if(comm.rank_map->at(i).second == comm.rank_map->at(j).second) 270 // { 271 // extra += recvcounts[j]; 272 // k++; 273 // } 274 275 // offset = mpi_displs[comm.rank_map->at(i).second] + extra; 276 277 // memcpy(recvbuf+displs[i]*datasize, tmp_recvbuf+offset*datasize, recvcounts[i]*datasize); 278 279 // } 280 281 // } 282 283 // MPI_Bcast_local(recvbuf, recvbuf_size, datatype, 0, comm); 284 285 // if(is_master) 286 // { 287 // delete[] local_recvbuf; 288 // delete[] tmp_recvbuf; 289 // } 290 291 // } 292 293 294 int MPI_Gatherv_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 295 { 296 if(datatype == MPI_INT) 297 { 298 Debug("datatype is INT\n"); 299 return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm); 300 } 301 else if(datatype == MPI_FLOAT) 302 { 303 Debug("datatype is FLOAT\n"); 304 return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm); 305 } 306 else if(datatype == MPI_DOUBLE) 307 { 308 Debug("datatype is DOUBLE\n"); 309 return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm); 310 } 311 else if(datatype == MPI_LONG) 312 { 313 Debug("datatype is LONG\n"); 314 return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm); 315 } 316 else if(datatype == MPI_UNSIGNED_LONG) 317 { 318 Debug("datatype is uLONG\n"); 319 return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm); 320 } 321 else if(datatype == MPI_CHAR) 322 { 323 Debug("datatype is CHAR\n"); 324 return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm); 325 } 326 else 327 { 328 printf("MPI_Gatherv Datatype not supported!\n"); 329 exit(0); 330 } 331 } 332 333 int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 334 { 335 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 336 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 337 338 int *buffer = comm.my_buffer->buf_int; 339 int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 340 int *recv_buf = static_cast<int*>(recvbuf); 341 342 if(my_rank == 0) 343 { 344 assert(count == recvcounts[0]); 345 copy(send_buf, send_buf+count, recv_buf + displs[0]); 346 } 347 348 for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 349 { 350 for(int k=1; k<num_ep; k++) 351 { 352 if(my_rank == k) 353 { 354 #pragma omp critical (write_to_buffer) 355 { 356 if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 357 #pragma omp flush 358 } 359 } 360 361 MPI_Barrier_local(comm); 362 363 if(my_rank == 0) 364 { 365 #pragma omp flush 366 #pragma omp critical (read_from_buffer) 367 { 368 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 369 } 370 } 371 372 MPI_Barrier_local(comm); 373 } 374 } 375 } 376 377 int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 378 { 379 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 380 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 381 382 float *buffer = comm.my_buffer->buf_float; 383 float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 384 float *recv_buf = static_cast<float*>(recvbuf); 385 386 if(my_rank == 0) 387 { 388 assert(count == recvcounts[0]); 389 copy(send_buf, send_buf+count, recv_buf + displs[0]); 390 } 391 392 for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 393 { 394 for(int k=1; k<num_ep; k++) 395 { 396 if(my_rank == k) 397 { 398 #pragma omp critical (write_to_buffer) 399 { 400 if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 401 #pragma omp flush 402 } 403 } 404 405 MPI_Barrier_local(comm); 406 407 if(my_rank == 0) 408 { 409 #pragma omp flush 410 #pragma omp critical (read_from_buffer) 411 { 412 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 413 } 414 } 415 416 MPI_Barrier_local(comm); 417 } 418 } 419 } 420 421 int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 422 { 423 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 424 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 425 426 double *buffer = comm.my_buffer->buf_double; 427 double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 428 double *recv_buf = static_cast<double*>(recvbuf); 429 430 if(my_rank == 0) 431 { 432 assert(count == recvcounts[0]); 433 copy(send_buf, send_buf+count, recv_buf + displs[0]); 434 } 435 436 for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 437 { 438 for(int k=1; k<num_ep; k++) 439 { 440 if(my_rank == k) 441 { 442 #pragma omp critical (write_to_buffer) 443 { 444 if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 445 #pragma omp flush 446 } 447 } 448 449 MPI_Barrier_local(comm); 450 451 if(my_rank == 0) 452 { 453 #pragma omp flush 454 #pragma omp critical (read_from_buffer) 455 { 456 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 457 } 458 } 459 460 MPI_Barrier_local(comm); 461 } 462 } 463 } 464 465 int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 466 { 467 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 468 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 469 470 long *buffer = comm.my_buffer->buf_long; 471 long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 472 long *recv_buf = static_cast<long*>(recvbuf); 473 474 if(my_rank == 0) 475 { 476 assert(count == recvcounts[0]); 477 copy(send_buf, send_buf+count, recv_buf + displs[0]); 478 } 479 480 for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 481 { 482 for(int k=1; k<num_ep; k++) 483 { 484 if(my_rank == k) 485 { 486 #pragma omp critical (write_to_buffer) 487 { 488 if(count!=0)copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 489 #pragma omp flush 490 } 491 } 492 493 MPI_Barrier_local(comm); 494 495 if(my_rank == 0) 496 { 497 #pragma omp flush 498 #pragma omp critical (read_from_buffer) 499 { 500 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 501 } 502 } 503 504 MPI_Barrier_local(comm); 505 } 506 } 507 } 508 509 int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 510 { 511 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 512 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 513 514 unsigned long *buffer = comm.my_buffer->buf_ulong; 515 unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 516 unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 517 518 if(my_rank == 0) 519 { 520 assert(count == recvcounts[0]); 521 copy(send_buf, send_buf+count, recv_buf + displs[0]); 522 } 523 524 for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 525 { 526 for(int k=1; k<num_ep; k++) 527 { 528 if(my_rank == k) 529 { 530 #pragma omp critical (write_to_buffer) 531 { 532 if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 533 #pragma omp flush 534 } 535 } 536 537 MPI_Barrier_local(comm); 538 539 if(my_rank == 0) 540 { 541 #pragma omp flush 542 #pragma omp critical (read_from_buffer) 543 { 544 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 545 } 546 } 547 548 MPI_Barrier_local(comm); 549 } 550 } 551 } 552 553 int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm) 554 { 555 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 556 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 557 558 char *buffer = comm.my_buffer->buf_char; 559 char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 560 char *recv_buf = static_cast<char*>(recvbuf); 561 562 if(my_rank == 0) 563 { 564 assert(count == recvcounts[0]); 565 copy(send_buf, send_buf+count, recv_buf + displs[0]); 566 } 567 568 for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE) 569 { 570 for(int k=1; k<num_ep; k++) 571 { 572 if(my_rank == k) 573 { 574 #pragma omp critical (write_to_buffer) 575 { 576 if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer); 577 #pragma omp flush 578 } 579 } 580 581 MPI_Barrier_local(comm); 582 583 if(my_rank == 0) 584 { 585 #pragma omp flush 586 #pragma omp critical (read_from_buffer) 587 { 588 copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]); 589 } 590 } 591 592 MPI_Barrier_local(comm); 593 } 594 } 595 } 596 597 598 int MPI_Gatherv2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 599 MPI_Datatype recvtype, int root, MPI_Comm comm) 600 { 601 602 if(!comm.is_ep && comm.mpi_comm) 603 { 604 ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs), 605 static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 606 return 0; 607 } 608 609 if(!comm.mpi_comm) return 0; 610 611 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype)); 612 613 MPI_Datatype datatype = sendtype; 614 int count = sendcount; 615 616 int ep_rank, ep_rank_loc, mpi_rank; 617 int ep_size, num_ep, mpi_size; 618 619 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 620 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 621 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 622 ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 623 num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 624 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 625 626 627 628 if(ep_size == mpi_size) 629 return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs, 630 static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 631 632 if(ep_rank != root) 633 { 634 recvcounts = new int[ep_size]; 635 displs = new int[ep_size]; 636 } 637 638 MPI_Bcast(const_cast< int* >(displs), ep_size, MPI_INT, root, comm); 639 MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm); 640 641 642 int recv_plus_displs[ep_size]; 643 for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 644 645 for(int j=0; j<mpi_size; j++) 646 { 647 if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] || 648 recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2]) 649 { 650 Debug("Call special implementation of mpi_gatherv. 1st condition not OK\n"); 651 return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 652 } 653 654 for(int i=1; i<num_ep-1; i++) 655 { 656 if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] || 657 recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1]) 658 { 659 Debug("Call special implementation of mpi_gatherv. 2nd condition not OK\n"); 660 return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 661 } 662 } 663 } 664 665 666 int root_mpi_rank = comm.rank_map->at(root).second; 667 int root_ep_loc = comm.rank_map->at(root).first; 668 669 670 ::MPI_Aint datasize, lb; 671 672 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 673 674 void *local_gather_recvbuf; 675 int buffer_size; 676 void *master_recvbuf; 677 678 if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) 679 { 680 master_recvbuf = new void*[sizeof(recvbuf)]; 681 assert(root_ep_loc == 0); 682 } 683 684 if(ep_rank_loc==0) 685 { 686 buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep); 687 688 local_gather_recvbuf = new void*[datasize*buffer_size]; 689 } 690 691 MPI_Gatherv_local2(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm); 692 693 //MPI_Gather 694 if(ep_rank_loc == 0) 695 { 696 int *mpi_recvcnt= new int[mpi_size]; 697 int *mpi_displs= new int[mpi_size]; 698 699 int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);; 700 int buff_end = buffer_size; 701 702 int mpi_sendcnt = buff_end - buff_start; 703 704 705 ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_recvcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 706 ::MPI_Gather(&buff_start, 1, MPI_INT, mpi_displs, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 707 708 if(root_ep_loc == 0) 709 { ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt, 710 mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 711 } 712 else // gatherv to master_recvbuf 713 { ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, mpi_recvcnt, 714 mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 715 } 716 717 delete[] mpi_recvcnt; 718 delete[] mpi_displs; 719 } 720 721 int global_min_displs = *std::min_element(displs, displs+ep_size); 722 int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size); 723 724 725 if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 726 { 727 innode_memcpy(0, master_recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm); 728 if(ep_rank_loc == 0) delete[] master_recvbuf; 729 } 730 731 732 733 if(ep_rank_loc==0) 734 { 735 if(datatype == MPI_INT) 736 { 737 delete[] static_cast<int*>(local_gather_recvbuf); 738 } 739 else if(datatype == MPI_FLOAT) 740 { 741 delete[] static_cast<float*>(local_gather_recvbuf); 742 } 743 else if(datatype == MPI_DOUBLE) 744 { 745 delete[] static_cast<double*>(local_gather_recvbuf); 746 } 747 else if(datatype == MPI_LONG) 748 { 749 delete[] static_cast<long*>(local_gather_recvbuf); 750 } 751 else if(datatype == MPI_UNSIGNED_LONG) 752 { 753 delete[] static_cast<unsigned long*>(local_gather_recvbuf); 754 } 755 else // if(datatype == MPI_CHAR) 756 { 757 delete[] static_cast<char*>(local_gather_recvbuf); 758 } 759 } 760 else 761 { 762 delete[] recvcounts; 763 delete[] displs; 764 } 765 return 0; 766 } 767 768 769 770 int MPI_Allgatherv2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 771 MPI_Datatype recvtype, MPI_Comm comm) 772 { 773 774 if(!comm.is_ep && comm.mpi_comm) 775 { 776 ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs, 777 static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 778 return 0; 779 } 780 781 if(!comm.mpi_comm) return 0; 782 783 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype)); 784 785 786 MPI_Datatype datatype = sendtype; 787 int count = sendcount; 788 789 int ep_rank, ep_rank_loc, mpi_rank; 790 int ep_size, num_ep, mpi_size; 791 792 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 793 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 794 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 795 ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 796 num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 797 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 798 799 if(ep_size == mpi_size) // needed by servers 800 return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs, 801 static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 802 803 int recv_plus_displs[ep_size]; 804 for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i]; 805 806 807 for(int j=0; j<mpi_size; j++) 808 { 809 if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] || 810 recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2]) 811 { 812 printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size); 813 for(int k=0; k<ep_size; k++) 814 printf("recv_plus_displs[%d] = %d\t displs[%d] = %d\n", k, recv_plus_displs[k], k, displs[k]); 815 816 return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 817 } 818 819 for(int i=1; i<num_ep-1; i++) 820 { 821 if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] || 822 recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1]) 823 { 824 printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size); 825 return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); 826 } 827 } 828 } 829 830 ::MPI_Aint datasize, lb; 831 832 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 833 834 void *local_gather_recvbuf; 835 int buffer_size; 836 837 if(ep_rank_loc==0) 838 { 839 buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep); 840 841 local_gather_recvbuf = new void*[datasize*buffer_size]; 842 } 843 844 // local gather to master 845 MPI_Gatherv_local2(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm); 846 847 //MPI_Gather 848 if(ep_rank_loc == 0) 849 { 850 int *mpi_recvcnt= new int[mpi_size]; 851 int *mpi_displs= new int[mpi_size]; 852 853 int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);; 854 int buff_end = buffer_size; 855 856 int mpi_sendcnt = buff_end - buff_start; 857 858 859 ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT, mpi_recvcnt, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm)); 860 ::MPI_Allgather(&buff_start, 1, MPI_INT, mpi_displs, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm)); 861 862 863 ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt, 864 mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 865 866 delete[] mpi_recvcnt; 867 delete[] mpi_displs; 868 } 869 870 int global_min_displs = *std::min_element(displs, displs+ep_size); 871 int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size); 872 873 MPI_Bcast_local2(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm); 874 875 if(ep_rank_loc==0) 876 { 877 if(datatype == MPI_INT) 878 { 879 delete[] static_cast<int*>(local_gather_recvbuf); 880 } 881 else if(datatype == MPI_FLOAT) 882 { 883 delete[] static_cast<float*>(local_gather_recvbuf); 884 } 885 else if(datatype == MPI_DOUBLE) 886 { 887 delete[] static_cast<double*>(local_gather_recvbuf); 888 } 889 else if(datatype == MPI_LONG) 890 { 891 delete[] static_cast<long*>(local_gather_recvbuf); 892 } 893 else if(datatype == MPI_UNSIGNED_LONG) 894 { 895 delete[] static_cast<unsigned long*>(local_gather_recvbuf); 896 } 897 else // if(datatype == MPI_CHAR) 898 { 899 delete[] static_cast<char*>(local_gather_recvbuf); 900 } 901 } 902 } 903 904 int MPI_Gatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 905 MPI_Datatype recvtype, int root, MPI_Comm comm) 906 { 907 int ep_rank, ep_rank_loc, mpi_rank; 908 int ep_size, num_ep, mpi_size; 909 910 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 911 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 912 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 913 ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 914 num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 915 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 916 917 int root_mpi_rank = comm.rank_map->at(root).second; 918 int root_ep_loc = comm.rank_map->at(root).first; 919 920 ::MPI_Aint datasize, lb; 921 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize); 922 923 void *local_gather_recvbuf; 924 int buffer_size; 925 926 int *local_displs = new int[num_ep]; 927 int *local_rvcnts = new int[num_ep]; 928 for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i]; 929 local_displs[0] = 0; 930 for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1]; 931 932 if(ep_rank_loc==0) 933 { 934 buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1]; 935 local_gather_recvbuf = new void*[datasize*buffer_size]; 936 } 937 938 // local gather to master 939 MPI_Gatherv_local2(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master 940 941 int **mpi_recvcnts = new int*[num_ep]; 942 int **mpi_displs = new int*[num_ep]; 943 for(int i=0; i<num_ep; i++) 944 { 945 mpi_recvcnts[i] = new int[mpi_size]; 946 mpi_displs[i] = new int[mpi_size]; 947 for(int j=0; j<mpi_size; j++) 948 { 949 mpi_recvcnts[i][j] = recvcounts[j*num_ep + i]; 950 mpi_displs[i][j] = displs[j*num_ep + i]; 951 } 952 } 953 954 void *master_recvbuf; 955 if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)]; 956 957 if(ep_rank_loc == 0 && root_ep_loc == 0) // master in MPI_Allgatherv loop 958 for(int i=0; i<num_ep; i++) 959 { 960 ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i], 961 static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 962 } 963 if(ep_rank_loc == 0 && root_ep_loc != 0) 964 for(int i=0; i<num_ep; i++) 965 { 966 ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), master_recvbuf, mpi_recvcnts[i], mpi_displs[i], 967 static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 968 } 969 970 971 if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 972 { 973 for(int i=0; i<ep_size; i++) 974 innode_memcpy(0, master_recvbuf + datasize*displs[i], root_ep_loc, recvbuf + datasize*displs[i], recvcounts[i], sendtype, comm); 975 976 if(ep_rank_loc == 0) delete[] master_recvbuf; 977 } 978 979 980 delete[] local_displs; 981 delete[] local_rvcnts; 982 for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i]; 983 delete[] mpi_displs[i]; } 984 delete[] mpi_recvcnts; 985 delete[] mpi_displs; 986 if(ep_rank_loc==0) 987 { 988 if(sendtype == MPI_INT) 989 { 990 delete[] static_cast<int*>(local_gather_recvbuf); 991 } 992 else if(sendtype == MPI_FLOAT) 993 { 994 delete[] static_cast<float*>(local_gather_recvbuf); 995 } 996 else if(sendtype == MPI_DOUBLE) 997 { 998 delete[] static_cast<double*>(local_gather_recvbuf); 999 } 1000 else if(sendtype == MPI_LONG) 1001 { 1002 delete[] static_cast<long*>(local_gather_recvbuf); 1003 } 1004 else if(sendtype == MPI_UNSIGNED_LONG) 1005 { 1006 delete[] static_cast<unsigned long*>(local_gather_recvbuf); 1007 } 1008 else // if(sendtype == MPI_CHAR) 1009 { 1010 delete[] static_cast<char*>(local_gather_recvbuf); 1011 } 1012 } 1013 } 1014 1015 int MPI_Allgatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], 1016 MPI_Datatype recvtype, MPI_Comm comm) 1017 { 1018 int ep_rank, ep_rank_loc, mpi_rank; 1019 int ep_size, num_ep, mpi_size; 1020 1021 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 1022 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 1023 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 1024 ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 1025 num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 1026 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 1027 1028 1029 ::MPI_Aint datasize, lb; 1030 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize); 1031 1032 void *local_gather_recvbuf; 1033 int buffer_size; 1034 1035 int *local_displs = new int[num_ep]; 1036 int *local_rvcnts = new int[num_ep]; 1037 for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i]; 1038 local_displs[0] = 0; 1039 for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1]; 1040 1041 if(ep_rank_loc==0) 1042 { 1043 buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1]; 1044 local_gather_recvbuf = new void*[datasize*buffer_size]; 1045 } 1046 1047 // local gather to master 1048 MPI_Gatherv_local2(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master 1049 1050 int **mpi_recvcnts = new int*[num_ep]; 1051 int **mpi_displs = new int*[num_ep]; 1052 for(int i=0; i<num_ep; i++) 1053 { 1054 mpi_recvcnts[i] = new int[mpi_size]; 1055 mpi_displs[i] = new int[mpi_size]; 1056 for(int j=0; j<mpi_size; j++) 1057 { 1058 mpi_recvcnts[i][j] = recvcounts[j*num_ep + i]; 1059 mpi_displs[i][j] = displs[j*num_ep + i]; 1060 } 1061 } 1062 1063 if(ep_rank_loc == 0) // master in MPI_Allgatherv loop 1064 for(int i=0; i<num_ep; i++) 1065 { 1066 ::MPI_Allgatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i], 1067 static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 1068 } 1069 1070 for(int i=0; i<ep_size; i++) 1071 MPI_Bcast_local2(recvbuf + datasize*displs[i], recvcounts[i], recvtype, comm); 1072 1073 1074 delete[] local_displs; 1075 delete[] local_rvcnts; 1076 for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i]; 1077 delete[] mpi_displs[i]; } 1078 delete[] mpi_recvcnts; 1079 delete[] mpi_displs; 1080 if(ep_rank_loc==0) 1081 { 1082 if(sendtype == MPI_INT) 1083 { 1084 delete[] static_cast<int*>(local_gather_recvbuf); 1085 } 1086 else if(sendtype == MPI_FLOAT) 1087 { 1088 delete[] static_cast<float*>(local_gather_recvbuf); 1089 } 1090 else if(sendtype == MPI_DOUBLE) 1091 { 1092 delete[] static_cast<double*>(local_gather_recvbuf); 1093 } 1094 else if(sendtype == MPI_LONG) 1095 { 1096 delete[] static_cast<long*>(local_gather_recvbuf); 1097 } 1098 else if(sendtype == MPI_UNSIGNED_LONG) 1099 { 1100 delete[] static_cast<unsigned long*>(local_gather_recvbuf); 1101 } 1102 else // if(sendtype == MPI_CHAR) 1103 { 1104 delete[] static_cast<char*>(local_gather_recvbuf); 1105 } 1106 } 1107 } 1108 1109 188 1110 }
Note: See TracChangeset
for help on using the changeset viewer.