Changeset 1287 for XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gather.cpp
- Timestamp: 10/04/17 11:45:14 (7 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gather.cpp
r1164 r1287 9 9 #include <mpi.h> 10 10 #include "ep_declaration.hpp" 11 11 #include "ep_mpi.hpp" 12 12 13 13 using namespace std; … … 16 16 { 17 17 18 int MPI_Gather_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)18 int MPI_Gather_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, int local_root, MPI_Comm comm) 19 19 { 20 if(datatype == MPI_INT) 20 assert(valid_type(datatype)); 21 22 ::MPI_Aint datasize, lb; 23 ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 24 25 int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 26 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 27 28 #pragma omp critical (_gather) 29 comm.my_buffer->void_buffer[ep_rank_loc] = const_cast< void* >(sendbuf); 30 31 MPI_Barrier_local(comm); 32 33 if(ep_rank_loc == local_root) 21 34 { 22 Debug("datatype is INT\n"); 23 return MPI_Gather_local_int(sendbuf, count, recvbuf, comm); 24 } 25 else if(datatype == MPI_FLOAT) 26 { 27 Debug("datatype is FLOAT\n"); 28 return MPI_Gather_local_float(sendbuf, count, recvbuf, comm); 29 } 30 else if(datatype == MPI_DOUBLE) 31 { 32 Debug("datatype is DOUBLE\n"); 33 return MPI_Gather_local_double(sendbuf, count, recvbuf, comm); 34 } 35 else if(datatype == MPI_LONG) 36 { 37 Debug("datatype is LONG\n"); 38 return MPI_Gather_local_long(sendbuf, count, recvbuf, comm); 39 } 40 else if(datatype == MPI_UNSIGNED_LONG) 41 { 42 Debug("datatype is uLONG\n"); 43 return MPI_Gather_local_ulong(sendbuf, count, recvbuf, comm); 44 } 45 else if(datatype == MPI_CHAR) 46 { 47 Debug("datatype is CHAR\n"); 48 return MPI_Gather_local_char(sendbuf, count, recvbuf, comm); 49 } 50 else 51 { 52 printf("MPI_Gather Datatype not supported!\n"); 53 exit(0); 54 } 55 } 35 for(int i=0; i<num_ep; i++) 36 memcpy(recvbuf + datasize * i * count, comm.my_buffer->void_buffer[i], datasize * count); 56 37 57 int MPI_Gather_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 58 { 59 int my_rank 
= comm.ep_comm_ptr->size_rank_info[1].first; 60 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 61 62 int *buffer = comm.my_buffer->buf_int; 63 int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 64 int *recv_buf = static_cast<int*>(recvbuf); 65 66 if(my_rank == 0) 67 { 68 copy(send_buf, send_buf+count, recv_buf); 38 //printf("local_recvbuf = %d %d \n", static_cast<int*>(recvbuf)[0], static_cast<int*>(recvbuf)[1] ); 69 39 } 70 40 71 for(int j=0; j<count; j+=BUFFER_SIZE) 72 { 73 for(int k=1; k<num_ep; k++) 74 { 75 if(my_rank == k) 76 { 77 #pragma omp critical (write_to_buffer) 78 { 79 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer); 80 #pragma omp flush 81 } 82 } 83 84 MPI_Barrier_local(comm); 85 86 if(my_rank == 0) 87 { 88 #pragma omp flush 89 #pragma omp critical (read_from_buffer) 90 { 91 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count); 92 } 93 } 94 95 MPI_Barrier_local(comm); 96 } 97 } 41 MPI_Barrier_local(comm); 98 42 } 99 100 int MPI_Gather_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)101 {102 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;103 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;104 105 float *buffer = comm.my_buffer->buf_float;106 float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));107 float *recv_buf = static_cast<float*>(recvbuf);108 109 if(my_rank == 0)110 {111 copy(send_buf, send_buf+count, recv_buf);112 }113 114 for(int j=0; j<count; j+=BUFFER_SIZE)115 {116 for(int k=1; k<num_ep; k++)117 {118 if(my_rank == k)119 {120 #pragma omp critical (write_to_buffer)121 {122 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);123 #pragma omp flush124 }125 }126 127 MPI_Barrier_local(comm);128 129 if(my_rank == 0)130 {131 #pragma omp flush132 #pragma omp critical (read_from_buffer)133 {134 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);135 }136 }137 138 MPI_Barrier_local(comm);139 }140 }141 }142 143 int 
MPI_Gather_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)144 {145 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;146 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;147 148 double *buffer = comm.my_buffer->buf_double;149 double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));150 double *recv_buf = static_cast<double*>(recvbuf);151 152 if(my_rank == 0)153 {154 copy(send_buf, send_buf+count, recv_buf);155 }156 157 for(int j=0; j<count; j+=BUFFER_SIZE)158 {159 for(int k=1; k<num_ep; k++)160 {161 if(my_rank == k)162 {163 #pragma omp critical (write_to_buffer)164 {165 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);166 #pragma omp flush167 }168 }169 170 MPI_Barrier_local(comm);171 172 if(my_rank == 0)173 {174 #pragma omp flush175 #pragma omp critical (read_from_buffer)176 {177 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);178 }179 }180 181 MPI_Barrier_local(comm);182 }183 }184 }185 186 int MPI_Gather_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)187 {188 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;189 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;190 191 long *buffer = comm.my_buffer->buf_long;192 long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));193 long *recv_buf = static_cast<long*>(recvbuf);194 195 if(my_rank == 0)196 {197 copy(send_buf, send_buf+count, recv_buf);198 }199 200 for(int j=0; j<count; j+=BUFFER_SIZE)201 {202 for(int k=1; k<num_ep; k++)203 {204 if(my_rank == k)205 {206 #pragma omp critical (write_to_buffer)207 {208 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);209 #pragma omp flush210 }211 }212 213 MPI_Barrier_local(comm);214 215 if(my_rank == 0)216 {217 #pragma omp flush218 #pragma omp critical (read_from_buffer)219 {220 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);221 }222 }223 224 MPI_Barrier_local(comm);225 }226 }227 }228 229 int MPI_Gather_local_ulong(const 
void *sendbuf, int count, void *recvbuf, MPI_Comm comm)230 {231 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;232 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;233 234 unsigned long *buffer = comm.my_buffer->buf_ulong;235 unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));236 unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);237 238 if(my_rank == 0)239 {240 copy(send_buf, send_buf+count, recv_buf);241 }242 243 for(int j=0; j<count; j+=BUFFER_SIZE)244 {245 for(int k=1; k<num_ep; k++)246 {247 if(my_rank == k)248 {249 #pragma omp critical (write_to_buffer)250 {251 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);252 #pragma omp flush253 }254 }255 256 MPI_Barrier_local(comm);257 258 if(my_rank == 0)259 {260 #pragma omp flush261 #pragma omp critical (read_from_buffer)262 {263 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);264 }265 }266 267 MPI_Barrier_local(comm);268 }269 }270 }271 272 273 int MPI_Gather_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)274 {275 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;276 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;277 278 char *buffer = comm.my_buffer->buf_char;279 char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));280 char *recv_buf = static_cast<char*>(recvbuf);281 282 if(my_rank == 0)283 {284 copy(send_buf, send_buf+count, recv_buf);285 }286 287 for(int j=0; j<count; j+=BUFFER_SIZE)288 {289 for(int k=1; k<num_ep; k++)290 {291 if(my_rank == k)292 {293 #pragma omp critical (write_to_buffer)294 {295 copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);296 #pragma omp flush297 }298 }299 300 MPI_Barrier_local(comm);301 302 if(my_rank == 0)303 {304 #pragma omp flush305 #pragma omp critical (read_from_buffer)306 {307 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+k*count);308 }309 }310 311 MPI_Barrier_local(comm);312 }313 }314 }315 316 317 43 318 44 int 
MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm) 319 45 { 320 if(!comm.is_ep && comm.mpi_comm)46 if(!comm.is_ep) 321 47 { 322 ::MPI_Gather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype), 323 root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 324 return 0; 48 return ::MPI_Gather(const_cast<void*>(sendbuf), sendcount, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype), 49 root, to_mpi_comm(comm.mpi_comm)); 325 50 } 326 51 327 if(!comm.mpi_comm) return 0; 328 329 MPI_Bcast(&recvcount, 1, MPI_INT, root, comm); 52 assert(sendcount == recvcount && sendtype == recvtype); 330 53 331 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount); 332 333 MPI_Datatype datatype = sendtype; 334 int count = sendcount; 335 336 int ep_rank, ep_rank_loc, mpi_rank; 337 int ep_size, num_ep, mpi_size; 338 339 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 340 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 341 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 342 ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 343 num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 344 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 345 54 int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 55 int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 56 int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 57 int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 58 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 59 int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 346 60 347 61 int root_mpi_rank = comm.rank_map->at(root).second; 348 62 int root_ep_loc = comm.rank_map->at(root).first; 349 63 64 ::MPI_Aint datasize, lb; 65 ::MPI_Type_get_extent(to_mpi_type(sendtype), &lb, &datasize); 350 
66 351 ::MPI_Aint datasize, lb; 67 bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root; 68 bool is_root = ep_rank == root; 352 69 353 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);70 void* local_recvbuf; 354 71 355 void *local_gather_recvbuf; 356 void *master_recvbuf; 357 if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) 72 if(is_master) 358 73 { 359 master_recvbuf = new void*[datasize*ep_size*count];74 local_recvbuf = new void*[datasize * num_ep * sendcount]; 360 75 } 361 76 362 if(ep_rank_loc==0) 363 { 364 local_gather_recvbuf = new void*[datasize*num_ep*count]; 365 } 366 367 // local gather to master 368 MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm); 369 370 //MPI_Gather 371 372 if(ep_rank_loc == 0) 373 { 374 int *gatherv_recvcnt; 375 int *gatherv_displs; 376 int gatherv_cnt = count*num_ep; 377 378 gatherv_recvcnt = new int[mpi_size]; 379 gatherv_displs = new int[mpi_size]; 77 void* tmp_recvbuf; 78 if(is_root) tmp_recvbuf = new void*[datasize * recvcount * ep_size]; 380 79 381 80 382 ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm)); 81 if(mpi_rank == root_mpi_rank) MPI_Gather_local(sendbuf, sendcount, sendtype, local_recvbuf, root_ep_loc, comm); 82 else MPI_Gather_local(sendbuf, sendcount, sendtype, local_recvbuf, 0, comm); 383 83 384 gatherv_displs[0] = 0; 385 for(int i=1; i<mpi_size; i++) 84 std::vector<int> recvcounts(mpi_size, 0); 85 std::vector<int> displs(mpi_size, 0); 86 87 88 if(is_master) 89 { 90 for(int i=0; i<ep_size; i++) 386 91 { 387 gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];92 recvcounts[comm.rank_map->at(i).second]+=sendcount; 388 93 } 389 94 390 if(root_ep_loc != 0) // gather to root_master 95 for(int i=1; i<mpi_size; i++) 96 displs[i] = displs[i-1] + recvcounts[i-1]; 97 98 ::MPI_Gatherv(local_recvbuf, sendcount*num_ep, sendtype, tmp_recvbuf, 
recvcounts.data(), displs.data(), recvtype, root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 99 } 100 101 102 // reorder data 103 if(is_root) 104 { 105 // printf("tmp_recvbuf = %d %d %d %d %d %d %d %d\n", static_cast<int*>(tmp_recvbuf)[0], static_cast<int*>(tmp_recvbuf)[1], 106 // static_cast<int*>(tmp_recvbuf)[2], static_cast<int*>(tmp_recvbuf)[3], 107 // static_cast<int*>(tmp_recvbuf)[4], static_cast<int*>(tmp_recvbuf)[5], 108 // static_cast<int*>(tmp_recvbuf)[6], static_cast<int*>(tmp_recvbuf)[7] ); 109 110 int offset; 111 for(int i=0; i<ep_size; i++) 391 112 { 392 ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, gatherv_recvcnt, 393 gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 394 } 395 else 396 { 397 ::MPI_Gatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt, 398 gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 113 offset = displs[comm.rank_map->at(i).second] + comm.rank_map->at(i).first * sendcount; 114 memcpy(recvbuf + i*sendcount*datasize, tmp_recvbuf+offset*datasize, sendcount*datasize); 115 116 399 117 } 400 118 401 delete[] gatherv_recvcnt;402 delete[] gatherv_displs;403 119 } 404 120 405 121 406 if( root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master122 if(is_master) 407 123 { 408 innode_memcpy(0, master_recvbuf, root_ep_loc, recvbuf, count*ep_size, datatype, comm);124 delete[] local_recvbuf; 409 125 } 410 411 412 413 if(ep_rank_loc==0) 414 { 415 if(datatype == MPI_INT) 416 { 417 delete[] static_cast<int*>(local_gather_recvbuf); 418 } 419 else if(datatype == MPI_FLOAT) 420 { 421 delete[] static_cast<float*>(local_gather_recvbuf); 422 } 423 else if(datatype == MPI_DOUBLE) 424 { 425 delete[] static_cast<double*>(local_gather_recvbuf); 426 } 427 else 
if(datatype == MPI_CHAR) 428 { 429 delete[] static_cast<char*>(local_gather_recvbuf); 430 } 431 else if(datatype == MPI_LONG) 432 { 433 delete[] static_cast<long*>(local_gather_recvbuf); 434 } 435 else// if(datatype == MPI_UNSIGNED_LONG) 436 { 437 delete[] static_cast<unsigned long*>(local_gather_recvbuf); 438 } 439 440 if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) delete[] master_recvbuf; 441 } 126 if(is_root) delete[] tmp_recvbuf; 127 442 128 } 443 129 444 445 int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)446 {447 if(!comm.is_ep && comm.mpi_comm)448 {449 ::MPI_Allgather(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),450 static_cast< ::MPI_Comm>(comm.mpi_comm));451 return 0;452 }453 454 if(!comm.mpi_comm) return 0;455 456 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);457 458 MPI_Datatype datatype = sendtype;459 int count = sendcount;460 461 int ep_rank, ep_rank_loc, mpi_rank;462 int ep_size, num_ep, mpi_size;463 464 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;465 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;466 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;467 ep_size = comm.ep_comm_ptr->size_rank_info[0].second;468 num_ep = comm.ep_comm_ptr->size_rank_info[1].second;469 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;470 471 472 ::MPI_Aint datasize, lb;473 474 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);475 476 void *local_gather_recvbuf;477 478 if(ep_rank_loc==0)479 {480 local_gather_recvbuf = new void*[datasize*num_ep*count];481 }482 483 // local gather to master484 MPI_Gather_local(sendbuf, count, datatype, local_gather_recvbuf, comm);485 486 //MPI_Gather487 488 if(ep_rank_loc == 0)489 {490 int *gatherv_recvcnt;491 int *gatherv_displs;492 
int gatherv_cnt = count*num_ep;493 494 gatherv_recvcnt = new int[mpi_size];495 gatherv_displs = new int[mpi_size];496 497 ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));498 499 gatherv_displs[0] = 0;500 for(int i=1; i<mpi_size; i++)501 {502 gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];503 }504 505 ::MPI_Allgatherv(local_gather_recvbuf, count*num_ep, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,506 gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));507 508 delete[] gatherv_recvcnt;509 delete[] gatherv_displs;510 }511 512 MPI_Bcast_local(recvbuf, count*ep_size, datatype, comm);513 514 515 if(ep_rank_loc==0)516 {517 if(datatype == MPI_INT)518 {519 delete[] static_cast<int*>(local_gather_recvbuf);520 }521 else if(datatype == MPI_FLOAT)522 {523 delete[] static_cast<float*>(local_gather_recvbuf);524 }525 else if(datatype == MPI_DOUBLE)526 {527 delete[] static_cast<double*>(local_gather_recvbuf);528 }529 else if(datatype == MPI_CHAR)530 {531 delete[] static_cast<char*>(local_gather_recvbuf);532 }533 else if(datatype == MPI_LONG)534 {535 delete[] static_cast<long*>(local_gather_recvbuf);536 }537 else// if(datatype == MPI_UNSIGNED_LONG)538 {539 delete[] static_cast<unsigned long*>(local_gather_recvbuf);540 }541 }542 }543 544 545 130 }
Note: See TracChangeset
for help on using the changeset viewer.