Changeset 1295 for XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scatterv.cpp
- Timestamp:
- 10/06/17 13:56:33 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scatterv.cpp
r1289 r1295 9 9 #include <mpi.h> 10 10 #include "ep_declaration.hpp" 11 #include "ep_mpi.hpp" 11 12 12 13 using namespace std; … … 15 16 { 16 17 17 int MPI_Scatterv_local2(const void *sendbuf, const int sendcounts[], const int displs[], MPI_Datatype datatype, void *recvbuf, MPI_Comm comm) 18 int MPI_Scatterv_local(const void *sendbuf, const int sendcounts[], const int displs[], MPI_Datatype sendtype, void *recvbuf, int recvcount, 19 MPI_Datatype recvtype, int local_root, MPI_Comm comm) 18 20 { 19 if(datatype == MPI_INT) 20 { 21 Debug("datatype is INT\n"); 22 return MPI_Scatterv_local_int(sendbuf, sendcounts, displs, recvbuf, comm); 23 } 24 else if(datatype == MPI_FLOAT) 25 { 26 Debug("datatype is FLOAT\n"); 27 return MPI_Scatterv_local_float(sendbuf, sendcounts, displs, recvbuf, comm); 28 } 29 else if(datatype == MPI_DOUBLE) 30 { 31 Debug("datatype is DOUBLE\n"); 32 return MPI_Scatterv_local_double(sendbuf, sendcounts, displs, recvbuf, comm); 33 } 34 else if(datatype == MPI_LONG) 35 { 36 Debug("datatype is LONG\n"); 37 return MPI_Scatterv_local_long(sendbuf, sendcounts, displs, recvbuf, comm); 38 } 39 else if(datatype == MPI_UNSIGNED_LONG) 40 { 41 Debug("datatype is uLONG\n"); 42 return MPI_Scatterv_local_ulong(sendbuf, sendcounts, displs, recvbuf, comm); 43 } 44 else if(datatype == MPI_CHAR) 45 { 46 Debug("datatype is CHAR\n"); 47 return MPI_Scatterv_local_char(sendbuf, sendcounts, displs, recvbuf, comm); 48 } 49 else 50 { 51 printf("MPI_scatterv Datatype not supported!\n"); 52 exit(0); 53 } 21 22 assert(valid_type(sendtype) && valid_type(recvtype)); 23 24 ::MPI_Aint datasize, lb; 25 ::MPI_Type_get_extent(to_mpi_type(sendtype), &lb, &datasize); 26 27 int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 28 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 29 30 assert(recvcount == sendcounts[ep_rank_loc]); 31 32 if(ep_rank_loc == local_root) 33 comm.my_buffer->void_buffer[local_root] = const_cast<void*>(sendbuf); 34 35 MPI_Barrier_local(comm); 36 37 #pragma omp critical (_scatterv) 38 memcpy(recvbuf, comm.my_buffer->void_buffer[local_root]+datasize*displs[ep_rank_loc], datasize * recvcount); 39 40 41 MPI_Barrier_local(comm); 54 42 } 55 56 int MPI_Scatterv_local_int(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)57 {58 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;59 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;60 61 62 int *buffer = comm.my_buffer->buf_int;63 int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));64 int *recv_buf = static_cast<int*>(recvbuf);65 66 for(int k=0; k<num_ep; k++)67 {68 int count = sendcounts[k];69 for(int j=0; j<count; j+=BUFFER_SIZE)70 {71 if(my_rank == 0)72 {73 #pragma omp critical (write_to_buffer)74 {75 copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);76 #pragma omp flush77 }78 }79 80 MPI_Barrier_local(comm);81 82 if(my_rank == k)83 {84 #pragma omp critical (read_from_buffer)85 {86 #pragma omp flush87 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);88 }89 }90 MPI_Barrier_local(comm);91 }92 }93 }94 95 int MPI_Scatterv_local_float(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)96 {97 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;98 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;99 100 101 float *buffer = comm.my_buffer->buf_float;102 float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));103 float *recv_buf = static_cast<float*>(recvbuf);104 105 for(int k=0; k<num_ep; k++)106 {107 int count = sendcounts[k];108 for(int j=0; j<count; j+=BUFFER_SIZE)109 {110 if(my_rank == 0)111 {112 #pragma omp critical (write_to_buffer)113 {114 copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);115 #pragma omp flush116 }117 }118 119 MPI_Barrier_local(comm);120 121 if(my_rank == k)122 {123 #pragma omp critical (read_from_buffer)124 {125 #pragma omp flush126 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);127 }128 }129 MPI_Barrier_local(comm);130 }131 }132 }133 134 int MPI_Scatterv_local_double(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)135 {136 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;137 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;138 139 140 double *buffer = comm.my_buffer->buf_double;141 double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));142 double *recv_buf = static_cast<double*>(recvbuf);143 144 for(int k=0; k<num_ep; k++)145 {146 int count = sendcounts[k];147 for(int j=0; j<count; j+=BUFFER_SIZE)148 {149 if(my_rank == 0)150 {151 #pragma omp critical (write_to_buffer)152 {153 copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);154 #pragma omp flush155 }156 }157 158 MPI_Barrier_local(comm);159 160 if(my_rank == k)161 {162 #pragma omp critical (read_from_buffer)163 {164 #pragma omp flush165 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);166 }167 }168 MPI_Barrier_local(comm);169 }170 }171 }172 173 int MPI_Scatterv_local_long(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)174 {175 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;176 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;177 178 179 long *buffer = comm.my_buffer->buf_long;180 long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));181 long *recv_buf = static_cast<long*>(recvbuf);182 183 for(int k=0; k<num_ep; k++)184 {185 int count = sendcounts[k];186 for(int j=0; j<count; j+=BUFFER_SIZE)187 {188 if(my_rank == 0)189 {190 #pragma omp critical (write_to_buffer)191 {192 copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);193 #pragma omp flush194 }195 }196 197 MPI_Barrier_local(comm);198 199 if(my_rank == k)200 {201 #pragma omp critical (read_from_buffer)202 {203 #pragma omp flush204 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);205 }206 }207 MPI_Barrier_local(comm);208 }209 }210 }211 212 213 int MPI_Scatterv_local_ulong(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)214 {215 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;216 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;217 218 219 unsigned long *buffer = comm.my_buffer->buf_ulong;220 unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));221 unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);222 223 for(int k=0; k<num_ep; k++)224 {225 int count = sendcounts[k];226 for(int j=0; j<count; j+=BUFFER_SIZE)227 {228 if(my_rank == 0)229 {230 #pragma omp critical (write_to_buffer)231 {232 copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);233 #pragma omp flush234 }235 }236 237 MPI_Barrier_local(comm);238 239 if(my_rank == k)240 {241 #pragma omp critical (read_from_buffer)242 {243 #pragma omp flush244 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);245 }246 }247 MPI_Barrier_local(comm);248 }249 }250 }251 252 253 int MPI_Scatterv_local_char(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)254 {255 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;256 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;257 258 259 char *buffer = comm.my_buffer->buf_char;260 char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));261 char *recv_buf = static_cast<char*>(recvbuf);262 263 for(int k=0; k<num_ep; k++)264 {265 int count = sendcounts[k];266 for(int j=0; j<count; j+=BUFFER_SIZE)267 {268 if(my_rank == 0)269 {270 #pragma omp critical (write_to_buffer)271 {272 copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);273 #pragma omp flush274 }275 }276 277 MPI_Barrier_local(comm);278 279 if(my_rank == k)280 {281 #pragma omp critical (read_from_buffer)282 {283 #pragma omp flush284 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);285 }286 }287 MPI_Barrier_local(comm);288 }289 }290 }291 292 43 293 44 int MPI_Scatterv(const void *sendbuf, const int sendcounts[], const int displs[], MPI_Datatype sendtype, void *recvbuf, int recvcount, … … 296 47 if(!comm.is_ep) 297 48 { 298 ::MPI_Scatterv(sendbuf, sendcounts, displs, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, 299 static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 300 return 0; 49 return ::MPI_Scatterv(sendbuf, sendcounts, displs, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype), root, to_mpi_comm(comm.mpi_comm)); 301 50 } 302 if(!comm.mpi_comm) return 0; 51 52 assert(sendtype == recvtype); 303 53 304 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype)); 305 306 MPI_Datatype datatype = sendtype; 54 int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 55 int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 56 int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 57 int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 58 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 59 int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 307 60 308 61 int root_mpi_rank = comm.rank_map->at(root).second; 309 62 int root_ep_loc = comm.rank_map->at(root).first; 310 63 311 int ep_rank, ep_rank_loc, mpi_rank;312 int ep_size, num_ep, mpi_size;64 bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root; 65 bool is_root = ep_rank == root; 313 66 314 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 315 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 316 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 317 ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 318 num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 319 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 320 321 if(ep_rank != root) 322 { 323 sendcounts = new int[ep_size]; 324 displs = new int[ep_size]; 325 } 326 327 MPI_Bcast(const_cast<int*>(sendcounts), ep_size, MPI_INT, root, comm); 328 MPI_Bcast(const_cast<int*>(displs), ep_size, MPI_INT, root, comm); 329 330 67 MPI_Datatype datatype = sendtype; 331 68 int count = recvcount; 332 69 333 70 ::MPI_Aint datasize, lb; 71 ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 72 73 void *tmp_sendbuf; 74 if(is_root) tmp_sendbuf = new void*[ep_size * count * datasize]; 334 75 335 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 76 // reorder tmp_sendbuf 77 std::vector<int>local_ranks(num_ep); 78 std::vector<int>ranks(ep_size); 336 79 337 assert(accumulate(sendcounts, sendcounts+ep_size-1, 0) == displs[ep_size-1]); // Only for contunuous gather. 80 if(mpi_rank == root_mpi_rank) MPI_Gather_local(&ep_rank, 1, MPI_INT, local_ranks.data(), root_ep_loc, comm); 81 else MPI_Gather_local(&ep_rank, 1, MPI_INT, local_ranks.data(), 0, comm); 338 82 339 83 340 void *master_sendbuf;341 void *local_recvbuf;84 std::vector<int> recvcounts(mpi_size, 0); 85 std::vector<int> my_displs(mpi_size, 0); 342 86 343 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) 87 88 if(is_master) 344 89 { 345 int count_sum = accumulate(sendcounts, sendcounts+ep_size, 0); 346 if(ep_rank_loc == 0) master_sendbuf = new void*[datasize*count_sum]; 90 for(int i=0; i<ep_size; i++) 91 { 92 recvcounts[comm.rank_map->at(i).second]++; 93 } 347 94 348 innode_memcpy(root_ep_loc, sendbuf, 0, master_sendbuf, count_sum, datatype, comm); 95 for(int i=1; i<mpi_size; i++) 96 my_displs[i] = my_displs[i-1] + recvcounts[i-1]; 97 98 ::MPI_Gatherv(local_ranks.data(), num_ep, MPI_INT, ranks.data(), recvcounts.data(), my_displs.data(), MPI_INT, root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 349 99 } 350 100 351 101 102 if(is_root) 103 { 104 int local_displs = 0; 105 for(int i=0; i<ep_size; i++) 106 { 107 //printf("i=%d : start from %d, src displs = %d, count = %d\n ", i, local_displs/datasize, displs[ranks[i]], sendcounts[ranks[i]]); 108 memcpy(tmp_sendbuf+local_displs, sendbuf + displs[ranks[i]]*datasize, sendcounts[ranks[i]]*datasize); 109 local_displs += sendcounts[ranks[i]]*datasize; 110 } 111 112 //for(int i=0; i<ep_size*2; i++) printf("%d\t", static_cast<int*>(const_cast<void*>(tmp_sendbuf))[i]); 113 } 352 114 353 if(ep_rank_loc == 0) 115 // MPI_Scatterv from root to masters 116 117 void* local_sendbuf; 118 int local_sendcount; 119 if(mpi_rank == root_mpi_rank) MPI_Reduce_local(&recvcount, &local_sendcount, 1, MPI_INT, MPI_SUM, root_ep_loc, comm); 120 else MPI_Reduce_local(&recvcount, &local_sendcount, 1, MPI_INT, MPI_SUM, 0, comm); 121 122 if(is_master) 354 123 { 355 int mpi_sendcnt = accumulate(sendcounts+ep_rank, sendcounts+ep_rank+num_ep, 0);356 int mpi_scatterv_sendcnt[mpi_size];357 int mpi_displs[mpi_size];124 local_sendbuf = new void*[datasize * local_sendcount]; 125 126 ::MPI_Gather(&local_sendcount, 1, to_mpi_type(MPI_INT), recvcounts.data(), 1, to_mpi_type(MPI_INT), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 358 127 359 local_recvbuf = new void*[datasize*mpi_sendcnt];128 if(is_root) for(int i=1; i<mpi_size; i++) my_displs[i] = my_displs[i-1] + recvcounts[i-1]; 360 129 361 ::MPI_ Gather(&mpi_sendcnt, 1, MPI_INT, mpi_scatterv_sendcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));130 ::MPI_Scatterv(tmp_sendbuf, recvcounts.data(), my_displs.data(), to_mpi_type(sendtype), local_sendbuf, num_ep*count, to_mpi_type(recvtype), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 362 131 363 mpi_displs[0] = displs[0]; 364 for(int i=1; i<mpi_size; i++) 365 mpi_displs[i] = mpi_displs[i-1] + mpi_scatterv_sendcnt[i-1]; 132 // printf("my_displs = %d %d %d %d\n", my_displs[0], my_displs[1], my_displs[2], my_displs[3] ); 133 134 // printf("%d %d %d %d %d %d %d %d\n", static_cast<int*>(local_sendbuf)[0], static_cast<int*>(local_sendbuf)[1], static_cast<int*>(local_sendbuf)[2], static_cast<int*>(local_sendbuf)[3], 135 // static_cast<int*>(local_sendbuf)[4], static_cast<int*>(local_sendbuf)[5], static_cast<int*>(local_sendbuf)[6], static_cast<int*>(local_sendbuf)[7]); 136 } 137 138 std::vector<int>local_sendcounts(num_ep, 0); 139 std::vector<int>local_displs(num_ep, 0); 140 141 MPI_Gather_local(&recvcount, 1, MPI_INT, local_sendcounts.data(), 0, comm); 142 MPI_Bcast_local(local_sendcounts.data(), num_ep, MPI_INT, 0, comm); 143 for(int i=1; i<num_ep; i++) 144 local_displs[i] = local_displs[i-1] + local_sendcounts[i-1]; 145 146 147 if(mpi_rank == root_mpi_rank) MPI_Scatterv_local(local_sendbuf, local_sendcounts.data(), local_displs.data(), sendtype, recvbuf, recvcount, recvtype, root_ep_loc, comm); 148 else MPI_Scatterv_local(local_sendbuf, local_sendcounts.data(), local_displs.data(), sendtype, recvbuf, recvcount, recvtype, 0, comm); 366 149 367 150 368 if(root_ep_loc!=0) 369 { 370 ::MPI_Scatterv(master_sendbuf, mpi_scatterv_sendcnt, mpi_displs, static_cast< ::MPI_Datatype>(datatype), 371 local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 372 } 373 else 374 { 375 ::MPI_Scatterv(sendbuf, mpi_scatterv_sendcnt, mpi_displs, static_cast< ::MPI_Datatype>(datatype), 376 local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 377 } 378 } 151 if(is_root) delete[] tmp_sendbuf; 152 if(is_master) delete[] local_sendbuf; 153 } 379 154 380 int local_displs[num_ep];381 local_displs[0] = 0;382 for(int i=1; i<num_ep; i++)383 {384 local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc];385 }386 155 387 MPI_Scatterv_local2(local_recvbuf, sendcounts+ep_rank-ep_rank_loc, local_displs, datatype, recvbuf, comm);388 389 if(ep_rank_loc == 0)390 {391 if(datatype == MPI_INT)392 {393 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<int*>(master_sendbuf);394 delete[] static_cast<int*>(local_recvbuf);395 }396 else if(datatype == MPI_FLOAT)397 {398 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<float*>(master_sendbuf);399 delete[] static_cast<float*>(local_recvbuf);400 }401 else if(datatype == MPI_DOUBLE)402 {403 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<double*>(master_sendbuf);404 delete[] static_cast<double*>(local_recvbuf);405 }406 else if(datatype == MPI_LONG)407 {408 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<long*>(master_sendbuf);409 delete[] static_cast<long*>(local_recvbuf);410 }411 else if(datatype == MPI_UNSIGNED_LONG)412 {413 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<unsigned long*>(master_sendbuf);414 delete[] static_cast<unsigned long*>(local_recvbuf);415 }416 else // if(datatype == MPI_DOUBLE)417 {418 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<char*>(master_sendbuf);419 delete[] static_cast<char*>(local_recvbuf);420 }421 }422 else423 {424 delete[] sendcounts;425 delete[] displs;426 }427 428 }429 156 }
Note: See TracChangeset
for help on using the changeset viewer.