Changeset 1289 for XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gather.cpp
- Timestamp:
- 10/04/17 17:02:13 (6 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gather.cpp
r1287 r1289 15 15 namespace ep_lib 16 16 { 17 18 17 int MPI_Gather_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, int local_root, MPI_Comm comm) 19 18 { … … 128 127 } 129 128 129 // int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm) 130 // { 131 132 // if(!comm.is_ep && comm.mpi_comm) 133 // { 134 // ::MPI_Allgather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype), 135 // static_cast< ::MPI_Comm>(comm.mpi_comm)); 136 // return 0; 137 // } 138 139 // if(!comm.mpi_comm) return 0; 140 141 // assert(sendcount == recvcount); 142 143 // assert(valid_type(sendtype) && valid_type(recvtype)); 144 145 // MPI_Datatype datatype = sendtype; 146 // int count = sendcount; 147 148 // ::MPI_Aint datasize, lb; 149 150 // ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 151 152 153 // int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 154 // int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 155 // int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 156 // int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 157 // int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 158 // int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 159 160 // bool is_master = ep_rank_loc==0; 161 162 // void* local_recvbuf; 163 // void* tmp_recvbuf; 164 165 166 // if(is_master) 167 // { 168 // local_recvbuf = new void*[datasize * num_ep * count]; 169 // tmp_recvbuf = new void*[datasize * count * ep_size]; 170 // } 171 172 // MPI_Gather_local(sendbuf, count, datatype, local_recvbuf, 0, comm); 173 174 175 // int* mpi_recvcounts; 176 // int *mpi_displs; 177 178 // if(is_master) 179 // { 180 181 // mpi_recvcounts = new int[mpi_size]; 182 // mpi_displs = new int[mpi_size]; 183 184 // int local_sendcount = num_ep * count; 185 186 // ::MPI_Allgather(&local_sendcount, 1, to_mpi_type(MPI_INT), mpi_recvcounts, 1, to_mpi_type(MPI_INT), to_mpi_comm(comm.mpi_comm)); 187 188 // mpi_displs[0] = 0; 189 // for(int i=1; i<mpi_size; i++) 190 // { 191 // mpi_displs[i] = mpi_displs[i-1] + mpi_recvcounts[i-1]; 192 // } 193 194 195 // ::MPI_Allgatherv(local_recvbuf, num_ep * count, to_mpi_type(datatype), tmp_recvbuf, mpi_recvcounts, mpi_displs, to_mpi_type(datatype), to_mpi_comm(comm.mpi_comm)); 196 197 198 // // reorder 199 // int offset; 200 // for(int i=0; i<ep_size; i++) 201 // { 202 // offset = mpi_displs[comm.rank_map->at(i).second] + comm.rank_map->at(i).first * sendcount; 203 // memcpy(recvbuf + i*sendcount*datasize, tmp_recvbuf+offset*datasize, sendcount*datasize); 204 // } 205 206 // delete[] mpi_recvcounts; 207 // delete[] mpi_displs; 208 // } 209 210 // MPI_Bcast_local(recvbuf, count*ep_size, datatype, 0, comm); 211 212 // MPI_Barrier(comm); 213 214 215 // if(is_master) 216 // { 217 // delete[] local_recvbuf; 218 // delete[] tmp_recvbuf; 219 220 // } 221 222 // } 223 224 int MPI_Gather_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm) 225 { 226 if(datatype == MPI_INT) 227 { 228 Debug("datatype is INT\n"); 229 return MPI_Gather_local_int(sendbuf, count, recvbuf, comm); 230 } 231 else if(datatype == MPI_FLOAT) 232 { 233 Debug("datatype is FLOAT\n"); 234 return MPI_Gather_local_float(sendbuf, count, recvbuf, comm); 235 } 236 else if(datatype == MPI_DOUBLE) 237 { 238 Debug("datatype is DOUBLE\n"); 239 return MPI_Gather_local_double(sendbuf, count, recvbuf, comm); 240 } 241 else if(datatype == MPI_LONG) 242 { 243 Debug("datatype is LONG\n"); 244 return MPI_Gather_local_long(sendbuf, count, recvbuf, comm); 245 } 246 else if(datatype == MPI_UNSIGNED_LONG) 247 { 248 Debug("datatype is uLONG\n"); 249 return MPI_Gather_local_ulong(sendbuf, count, recvbuf, comm); 250 } 251 else if(datatype == MPI_CHAR) 252 { 253 Debug("datatype is CHAR\n"); 254 return MPI_Gather_local_char(sendbuf, count, recvbuf, comm); 255 } 256 else 257 { 258 printf("MPI_Gather Datatype not supported!\n"); 259 exit(0); 260 } 261 } 262 263 int MPI_Gather_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 264 { 265 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 266 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 267 268 int *buffer = comm.my_buffer->buf_int; 269 int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 270 int *recv_buf = static_cast<int*>(recvbuf); 271 272 if(my_rank == 0) 273 { 274 copy(send_buf, send_buf+count, recv_buf); 275 } 276 277 for(int j=0; j<count; j+=BUFFER_SIZE) 278 { 279 for(int k=1; k<num_ep; k++) 280 { 281 if(my_rank == k) 282 { 283 #pragma omp critical (write_to_buffer) 284 { 285 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 286 #pragma omp flush 287 } 288 } 289 290 MPI_Barrier_local(comm); 291 292 if(my_rank == 0) 293 { 294 #pragma omp flush 295 #pragma omp critical (read_from_buffer) 296 { 297 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count); 298 } 299 } 300 301 MPI_Barrier_local(comm); 302 } 303 } 304 } 305 306 int MPI_Gather_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 307 { 308 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 309 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 310 311 float *buffer = comm.my_buffer->buf_float; 312 float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 313 float *recv_buf = static_cast<float*>(recvbuf); 314 315 if(my_rank == 0) 316 { 317 copy(send_buf, send_buf+count, recv_buf); 318 } 319 320 for(int j=0; j<count; j+=BUFFER_SIZE) 321 { 322 for(int k=1; k<num_ep; k++) 323 { 324 if(my_rank == k) 325 { 326 #pragma omp critical (write_to_buffer) 327 { 328 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 329 #pragma omp flush 330 } 331 } 332 333 MPI_Barrier_local(comm); 334 335 if(my_rank == 0) 336 { 337 #pragma omp flush 338 #pragma omp critical (read_from_buffer) 339 { 340 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count); 341 } 342 } 343 344 MPI_Barrier_local(comm); 345 } 346 } 347 } 348 349 int MPI_Gather_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 350 { 351 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 352 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 353 354 double *buffer = comm.my_buffer->buf_double; 355 double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 356 double *recv_buf = static_cast<double*>(recvbuf); 357 358 if(my_rank == 0) 359 { 360 copy(send_buf, send_buf+count, recv_buf); 361 } 362 363 for(int j=0; j<count; j+=BUFFER_SIZE) 364 { 365 for(int k=1; k<num_ep; k++) 366 { 367 if(my_rank == k) 368 { 369 #pragma omp critical (write_to_buffer) 370 { 371 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 372 #pragma omp flush 373 } 374 } 375 376 MPI_Barrier_local(comm); 377 378 if(my_rank == 0) 379 { 380 #pragma omp flush 381 #pragma omp critical (read_from_buffer) 382 { 383 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count); 384 } 385 } 386 387 MPI_Barrier_local(comm); 388 } 389 } 390 } 391 392 int MPI_Gather_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 393 { 394 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 395 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 396 397 long *buffer = comm.my_buffer->buf_long; 398 long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 399 long *recv_buf = static_cast<long*>(recvbuf); 400 401 if(my_rank == 0) 402 { 403 copy(send_buf, send_buf+count, recv_buf); 404 } 405 406 for(int j=0; j<count; j+=BUFFER_SIZE) 407 { 408 for(int k=1; k<num_ep; k++) 409 { 410 if(my_rank == k) 411 { 412 #pragma omp critical (write_to_buffer) 413 { 414 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 415 #pragma omp flush 416 } 417 } 418 419 MPI_Barrier_local(comm); 420 421 if(my_rank == 0) 422 { 423 #pragma omp flush 424 #pragma omp critical (read_from_buffer) 425 { 426 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count); 427 } 428 } 429 430 MPI_Barrier_local(comm); 431 } 432 } 433 } 434 435 int MPI_Gather_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 436 { 437 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 438 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 439 440 unsigned long *buffer = comm.my_buffer->buf_ulong; 441 unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 442 unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 443 444 if(my_rank == 0) 445 { 446 copy(send_buf, send_buf+count, recv_buf); 447 } 448 449 for(int j=0; j<count; j+=BUFFER_SIZE) 450 { 451 for(int k=1; k<num_ep; k++) 452 { 453 if(my_rank == k) 454 { 455 #pragma omp critical (write_to_buffer) 456 { 457 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 458 #pragma omp flush 459 } 460 } 461 462 MPI_Barrier_local(comm); 463 464 if(my_rank == 0) 465 { 466 #pragma omp flush 467 #pragma omp critical (read_from_buffer) 468 { 469 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count); 470 } 471 } 472 473 MPI_Barrier_local(comm); 474 } 475 } 476 } 477 478 479 int MPI_Gather_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 480 { 481 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 482 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 483 484 char *buffer = comm.my_buffer->buf_char; 485 char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 486 char *recv_buf = static_cast<char*>(recvbuf); 487 488 if(my_rank == 0) 489 { 490 copy(send_buf, send_buf+count, recv_buf); 491 } 492 493 for(int j=0; j<count; j+=BUFFER_SIZE) 494 { 495 for(int k=1; k<num_ep; k++) 496 { 497 if(my_rank == k) 498 { 499 #pragma omp critical (write_to_buffer) 500 { 501 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 502 #pragma omp flush 503 } 504 } 505 506 MPI_Barrier_local(comm); 507 508 if(my_rank == 0) 509 { 510 #pragma omp flush 511 #pragma omp critical (read_from_buffer) 512 { 513 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count); 514 } 515 } 516 517 MPI_Barrier_local(comm); 518 } 519 } 520 } 521 522 523 524 int MPI_Gather2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm) 525 { 526 if(!comm.is_ep && comm.mpi_comm) 527 { 528 ::MPI_Gather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype), 529 root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 530 return 0; 531 } 532 533 if(!comm.mpi_comm) return 0; 534 535 MPI_Bcast(&recvcount, 1, MPI_INT, root, comm); 536 537 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount); 538 539 MPI_Datatype datatype = sendtype; 540 int count = sendcount; 541 542 int ep_rank, ep_rank_loc, mpi_rank; 543 int ep_size, num_ep, mpi_size; 544 545 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 546 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 547 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 548 ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 549 num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 550 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 551 552 553 int root_mpi_rank = comm.rank_map->at(root).second; 554 int root_ep_loc = comm.rank_map->at(root).first; 555 556 557 ::MPI_Aint datasize, lb; 558 559 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 560 561 void *local_gather_recvbuf; 562 void *master_recvbuf; 563 if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) 564 { 565 master_recvbuf = new void*[datasize*ep_size*count]; 566 } 567 568 if(ep_rank_loc==0) 569 { 570 local_gather_recvbuf = new void*[datasize*num_ep*count]; 571 } 572 573 // local gather to master 574 MPI_Gather_local2(sendbuf, count, datatype, local_gather_recvbuf, comm); 575 576 //MPI_Gather 577 578 if(ep_rank_loc == 0) 579 { 580 int *gatherv_recvcnt; 581 int *gatherv_displs; 582 int gatherv_cnt = count*num_ep; 583 584 gatherv_recvcnt = new int[mpi_size]; 585 gatherv_displs = new int[mpi_size]; 586 587 588 ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT, gatherv_recvcnt, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm)); 589 590 gatherv_displs[0] = 0; 591 for(int i=1; i<mpi_size; i++) 592 { 593 gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1]; 594 } 595 596 if(root_ep_loc != 0) // gather to root_master 597 { 598 ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, gatherv_recvcnt, 599 gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 600 } 601 else 602 { 603 ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt, 604 gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 605 } 606 607 delete[] gatherv_recvcnt; 608 delete[] gatherv_displs; 609 } 610 611 612 if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master 613 { 614 innode_memcpy(0, master_recvbuf, root_ep_loc, recvbuf, count*ep_size, datatype, comm); 615 } 616 617 618 619 if(ep_rank_loc==0) 620 { 621 if(datatype == MPI_INT) 622 { 623 delete[] static_cast<int*>(local_gather_recvbuf); 624 } 625 else if(datatype == MPI_FLOAT) 626 { 627 delete[] static_cast<float*>(local_gather_recvbuf); 628 } 629 else if(datatype == MPI_DOUBLE) 630 { 631 delete[] static_cast<double*>(local_gather_recvbuf); 632 } 633 else if(datatype == MPI_CHAR) 634 { 635 delete[] static_cast<char*>(local_gather_recvbuf); 636 } 637 else if(datatype == MPI_LONG) 638 { 639 delete[] static_cast<long*>(local_gather_recvbuf); 640 } 641 else// if(datatype == MPI_UNSIGNED_LONG) 642 { 643 delete[] static_cast<unsigned long*>(local_gather_recvbuf); 644 } 645 646 if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) delete[] master_recvbuf; 647 } 648 } 649 650 651 int MPI_Allgather2(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm) 652 { 653 if(!comm.is_ep && comm.mpi_comm) 654 { 655 ::MPI_Allgather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype), 656 static_cast< ::MPI_Comm>(comm.mpi_comm)); 657 return 0; 658 } 659 660 if(!comm.mpi_comm) return 0; 661 662 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount); 663 664 MPI_Datatype datatype = sendtype; 665 int count = sendcount; 666 667 int ep_rank, ep_rank_loc, mpi_rank; 668 int ep_size, num_ep, mpi_size; 669 670 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 671 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 672 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 673 ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 674 num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 675 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 676 677 678 ::MPI_Aint datasize, lb; 679 680 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 681 682 void *local_gather_recvbuf; 683 684 if(ep_rank_loc==0) 685 { 686 local_gather_recvbuf = new void*[datasize*num_ep*count]; 687 } 688 689 // local gather to master 690 MPI_Gather_local2(sendbuf, count, datatype, local_gather_recvbuf, comm); 691 692 //MPI_Gather 693 694 if(ep_rank_loc == 0) 695 { 696 int *gatherv_recvcnt; 697 int *gatherv_displs; 698 int gatherv_cnt = count*num_ep; 699 700 gatherv_recvcnt = new int[mpi_size]; 701 gatherv_displs = new int[mpi_size]; 702 703 ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT, gatherv_recvcnt, 1, MPI_INT, static_cast< ::MPI_Comm>(comm.mpi_comm)); 704 705 gatherv_displs[0] = 0; 706 for(int i=1; i<mpi_size; i++) 707 { 708 gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1]; 709 } 710 711 ::MPI_Allgatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt, 712 gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm)); 713 714 delete[] gatherv_recvcnt; 715 delete[] gatherv_displs; 716 } 717 718 MPI_Bcast_local2(recvbuf, count*ep_size, datatype, comm); 719 720 721 if(ep_rank_loc==0) 722 { 723 if(datatype == MPI_INT) 724 { 725 delete[] static_cast<int*>(local_gather_recvbuf); 726 } 727 else if(datatype == MPI_FLOAT) 728 { 729 delete[] static_cast<float*>(local_gather_recvbuf); 730 } 731 else if(datatype == MPI_DOUBLE) 732 { 733 delete[] static_cast<double*>(local_gather_recvbuf); 734 } 735 else if(datatype == MPI_CHAR) 736 { 737 delete[] static_cast<char*>(local_gather_recvbuf); 738 } 739 else if(datatype == MPI_LONG) 740 { 741 delete[] static_cast<long*>(local_gather_recvbuf); 742 } 743 else// if(datatype == MPI_UNSIGNED_LONG) 744 { 745 delete[] static_cast<unsigned long*>(local_gather_recvbuf); 746 } 747 } 748 } 749 750 130 751 }
Note: See TracChangeset
for help on using the changeset viewer.