Changeset 1289 for XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scatter.cpp
- Timestamp:
- 10/04/17 17:02:13 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scatter.cpp
r1287 r1289 9 9 #include <mpi.h> 10 10 #include "ep_declaration.hpp" 11 #include "ep_mpi.hpp"12 11 13 12 using namespace std; … … 16 15 { 17 16 18 int MPI_Scatter_local(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int local_root, MPI_Comm comm) 19 { 20 assert(valid_type(sendtype) && valid_type(recvtype)); 21 assert(recvcount == sendcount); 22 23 ::MPI_Aint datasize, lb; 24 ::MPI_Type_get_extent(to_mpi_type(sendtype), &lb, &datasize); 25 26 int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 27 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 28 29 30 if(ep_rank_loc == local_root) 31 comm.my_buffer->void_buffer[local_root] = const_cast<void*>(sendbuf); 32 33 MPI_Barrier_local(comm); 34 35 #pragma omp critical (_scatter) 36 memcpy(recvbuf, comm.my_buffer->void_buffer[local_root]+datasize*ep_rank_loc*sendcount, datasize * recvcount); 37 38 39 MPI_Barrier_local(comm); 40 } 17 int MPI_Scatter_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm) 18 { 19 if(datatype == MPI_INT) 20 { 21 Debug("datatype is INT\n"); 22 return MPI_Scatter_local_int(sendbuf, count, recvbuf, comm); 23 } 24 else if(datatype == MPI_FLOAT) 25 { 26 Debug("datatype is FLOAT\n"); 27 return MPI_Scatter_local_float(sendbuf, count, recvbuf, comm); 28 } 29 else if(datatype == MPI_DOUBLE) 30 { 31 Debug("datatype is DOUBLE\n"); 32 return MPI_Scatter_local_double(sendbuf, count, recvbuf, comm); 33 } 34 else if(datatype == MPI_LONG) 35 { 36 Debug("datatype is LONG\n"); 37 return MPI_Scatter_local_long(sendbuf, count, recvbuf, comm); 38 } 39 else if(datatype == MPI_UNSIGNED_LONG) 40 { 41 Debug("datatype is uLONG\n"); 42 return MPI_Scatter_local_ulong(sendbuf, count, recvbuf, comm); 43 } 44 else if(datatype == MPI_CHAR) 45 { 46 Debug("datatype is CHAR\n"); 47 return MPI_Scatter_local_char(sendbuf, count, recvbuf, comm); 48 } 49 else 50 { 51 printf("MPI_Scatter Datatype not supported!\n"); 52 exit(0); 53 } 54 } 55 56 int MPI_Scatter_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 57 { 58 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 59 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 60 61 62 int *buffer = comm.my_buffer->buf_int; 63 int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 64 int *recv_buf = static_cast<int*>(recvbuf); 65 66 for(int k=0; k<num_ep; k++) 67 { 68 for(int j=0; j<count; j+=BUFFER_SIZE) 69 { 70 if(my_rank == 0) 71 { 72 #pragma omp critical (write_to_buffer) 73 { 74 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 75 #pragma omp flush 76 } 77 } 78 79 MPI_Barrier_local(comm); 80 81 if(my_rank == k) 82 { 83 #pragma omp critical (read_from_buffer) 84 { 85 #pragma omp flush 86 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 87 } 88 } 89 MPI_Barrier_local(comm); 90 } 91 } 92 } 93 94 int MPI_Scatter_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 95 { 96 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 97 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 98 99 float *buffer = comm.my_buffer->buf_float; 100 float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 101 float *recv_buf = static_cast<float*>(recvbuf); 102 103 for(int k=0; k<num_ep; k++) 104 { 105 for(int j=0; j<count; j+=BUFFER_SIZE) 106 { 107 if(my_rank == 0) 108 { 109 #pragma omp critical (write_to_buffer) 110 { 111 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 112 #pragma omp flush 113 } 114 } 115 116 MPI_Barrier_local(comm); 117 118 if(my_rank == k) 119 { 120 #pragma omp critical (read_from_buffer) 121 { 122 #pragma omp flush 123 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 124 } 125 } 126 MPI_Barrier_local(comm); 127 } 128 } 129 } 130 131 int MPI_Scatter_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 132 { 133 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 134 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 135 136 double *buffer = comm.my_buffer->buf_double; 137 double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 138 double *recv_buf = static_cast<double*>(recvbuf); 139 140 for(int k=0; k<num_ep; k++) 141 { 142 for(int j=0; j<count; j+=BUFFER_SIZE) 143 { 144 if(my_rank == 0) 145 { 146 #pragma omp critical (write_to_buffer) 147 { 148 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 149 #pragma omp flush 150 } 151 } 152 153 MPI_Barrier_local(comm); 154 155 if(my_rank == k) 156 { 157 #pragma omp critical (read_from_buffer) 158 { 159 #pragma omp flush 160 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 161 } 162 } 163 MPI_Barrier_local(comm); 164 } 165 } 166 } 167 168 int MPI_Scatter_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 169 { 170 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 171 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 172 173 long *buffer = comm.my_buffer->buf_long; 174 long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 175 long *recv_buf = static_cast<long*>(recvbuf); 176 177 for(int k=0; k<num_ep; k++) 178 { 179 for(int j=0; j<count; j+=BUFFER_SIZE) 180 { 181 if(my_rank == 0) 182 { 183 #pragma omp critical (write_to_buffer) 184 { 185 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 186 #pragma omp flush 187 } 188 } 189 190 MPI_Barrier_local(comm); 191 192 if(my_rank == k) 193 { 194 #pragma omp critical (read_from_buffer) 195 { 196 #pragma omp flush 197 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 198 } 199 } 200 MPI_Barrier_local(comm); 201 } 202 } 203 } 204 205 206 int MPI_Scatter_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 207 { 208 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 209 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 210 211 unsigned long *buffer = comm.my_buffer->buf_ulong; 212 unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 213 unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 214 215 for(int k=0; k<num_ep; k++) 216 { 217 for(int j=0; j<count; j+=BUFFER_SIZE) 218 { 219 if(my_rank == 0) 220 { 221 #pragma omp critical (write_to_buffer) 222 { 223 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 224 #pragma omp flush 225 } 226 } 227 228 MPI_Barrier_local(comm); 229 230 if(my_rank == k) 231 { 232 #pragma omp critical (read_from_buffer) 233 { 234 #pragma omp flush 235 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 236 } 237 } 238 MPI_Barrier_local(comm); 239 } 240 } 241 } 242 243 244 int MPI_Scatter_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 245 { 246 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 247 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 248 249 char *buffer = comm.my_buffer->buf_char; 250 char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 251 char *recv_buf = static_cast<char*>(recvbuf); 252 253 for(int k=0; k<num_ep; k++) 254 { 255 for(int j=0; j<count; j+=BUFFER_SIZE) 256 { 257 if(my_rank == 0) 258 { 259 #pragma omp critical (write_to_buffer) 260 { 261 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 262 #pragma omp flush 263 } 264 } 265 266 MPI_Barrier_local(comm); 267 268 if(my_rank == k) 269 { 270 #pragma omp critical (read_from_buffer) 271 { 272 #pragma omp flush 273 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 274 } 275 } 276 MPI_Barrier_local(comm); 277 } 278 } 279 } 280 281 282 283 41 284 42 285 int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm) … … 44 287 if(!comm.is_ep) 45 288 { 46 return ::MPI_Scatter(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype), root, to_mpi_comm(comm.mpi_comm)); 47 } 48 49 assert(sendcount == recvcount); 50 51 int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 52 int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 53 int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 54 int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 55 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 56 int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 289 ::MPI_Scatter(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype), 290 root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 291 return 0; 292 } 293 294 if(!comm.mpi_comm) return 0; 295 296 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount); 57 297 58 298 int root_mpi_rank = comm.rank_map->at(root).second; 59 299 int root_ep_loc = comm.rank_map->at(root).first; 60 300 61 bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root; 62 bool is_root = ep_rank == root; 301 int ep_rank, ep_rank_loc, mpi_rank; 302 int ep_size, num_ep, mpi_size; 303 304 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 305 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 306 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 307 ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 308 num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 309 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 310 63 311 64 312 MPI_Datatype datatype = sendtype; … … 66 314 67 315 ::MPI_Aint datasize, lb; 68 ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 69 70 void *tmp_sendbuf; 71 if(is_root) tmp_sendbuf = new void*[ep_size * count * datasize]; 72 73 // reorder tmp_sendbuf 74 std::vector<int>local_ranks(num_ep); 75 std::vector<int>ranks(ep_size); 76 77 if(mpi_rank == root_mpi_rank) MPI_Gather_local(&ep_rank, 1, MPI_INT, local_ranks.data(), root_ep_loc, comm); 78 else MPI_Gather_local(&ep_rank, 1, MPI_INT, local_ranks.data(), 0, comm); 79 80 81 std::vector<int> recvcounts(mpi_size, 0); 82 std::vector<int> displs(mpi_size, 0); 83 84 85 if(is_master) 86 { 87 for(int i=0; i<ep_size; i++) 88 { 89 recvcounts[comm.rank_map->at(i).second]++; 90 } 91 316 317 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 318 319 320 void *master_sendbuf; 321 void *local_recvbuf; 322 323 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) 324 { 325 if(ep_rank_loc == 0) master_sendbuf = new void*[datasize*count*ep_size]; 326 327 innode_memcpy(root_ep_loc, sendbuf, 0, master_sendbuf, count*ep_size, datatype, comm); 328 } 329 330 331 332 if(ep_rank_loc == 0) 333 { 334 int mpi_sendcnt = count*num_ep; 335 int mpi_scatterv_sendcnt[mpi_size]; 336 int displs[mpi_size]; 337 338 local_recvbuf = new void*[datasize*mpi_sendcnt]; 339 340 ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_scatterv_sendcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 341 342 displs[0] = 0; 92 343 for(int i=1; i<mpi_size; i++) 93 displs[i] = displs[i-1] + recvcounts[i-1]; 94 95 ::MPI_Gatherv(local_ranks.data(), num_ep, MPI_INT, ranks.data(), recvcounts.data(), displs.data(), MPI_INT, root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 96 } 97 98 99 100 // if(is_root) printf("\nranks = %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d\n", ranks[0], ranks[1], ranks[2], ranks[3], ranks[4], ranks[5], ranks[6], ranks[7], 101 // ranks[8], ranks[9], ranks[10], ranks[11], ranks[12], ranks[13], ranks[14], ranks[15]); 102 103 if(is_root) 104 for(int i=0; i<ep_size; i++) 105 { 106 memcpy(tmp_sendbuf + i*datasize*count, sendbuf + ranks[i]*datasize*count, count*datasize); 107 } 108 109 // if(is_root) printf("\ntmp_sendbuf = %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d\n", static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[2], static_cast<int*>(tmp_sendbuf)[4], static_cast<int*>(tmp_sendbuf)[6], 110 // static_cast<int*>(tmp_sendbuf)[8], static_cast<int*>(tmp_sendbuf)[10], static_cast<int*>(tmp_sendbuf)[12], static_cast<int*>(tmp_sendbuf)[14], 111 // static_cast<int*>(tmp_sendbuf)[16], static_cast<int*>(tmp_sendbuf)[18], static_cast<int*>(tmp_sendbuf)[20], static_cast<int*>(tmp_sendbuf)[22], 112 // static_cast<int*>(tmp_sendbuf)[24], static_cast<int*>(tmp_sendbuf)[26], static_cast<int*>(tmp_sendbuf)[28], static_cast<int*>(tmp_sendbuf)[30] ); 113 114 115 // MPI_Scatterv from root to masters 116 void* local_recvbuf; 117 if(is_master) local_recvbuf = new void*[datasize * num_ep * count]; 118 119 120 if(is_master) 121 { 122 int local_sendcount = num_ep * count; 123 ::MPI_Gather(&local_sendcount, 1, to_mpi_type(MPI_INT), recvcounts.data(), 1, to_mpi_type(MPI_INT), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 124 125 if(is_root) for(int i=1; i<mpi_size; i++) displs[i] = displs[i-1] + recvcounts[i-1]; 126 127 ::MPI_Scatterv(tmp_sendbuf, recvcounts.data(), displs.data(), to_mpi_type(sendtype), local_recvbuf, num_ep*count, to_mpi_type(recvtype), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 128 129 // printf("local_recvbuf = %d %d %d %d\n", static_cast<int*>(local_recvbuf)[0], static_cast<int*>(local_recvbuf)[1], static_cast<int*>(local_recvbuf)[2], static_cast<int*>(local_recvbuf)[3]); 130 // static_cast<int*>(local_recvbuf)[4], static_cast<int*>(local_recvbuf)[5], static_cast<int*>(local_recvbuf)[6], static_cast<int*>(local_recvbuf)[7]); 131 } 132 133 if(mpi_rank == root_mpi_rank) MPI_Scatter_local(local_recvbuf, count, sendtype, recvbuf, recvcount, recvtype, root_ep_loc, comm); 134 else MPI_Scatter_local(local_recvbuf, count, sendtype, recvbuf, recvcount, recvtype, 0, comm); 135 136 if(is_root) delete[] tmp_sendbuf; 137 if(is_master) delete[] local_recvbuf; 138 } 139 344 displs[i] = displs[i-1] + mpi_scatterv_sendcnt[i-1]; 345 346 347 if(root_ep_loc!=0) 348 { 349 ::MPI_Scatterv(master_sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype), 350 local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 351 } 352 else 353 { 354 ::MPI_Scatterv(sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype), 355 local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 356 } 357 } 358 359 MPI_Scatter_local2(local_recvbuf, count, datatype, recvbuf, comm); 360 361 if(ep_rank_loc == 0) 362 { 363 if(datatype == MPI_INT) 364 { 365 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<int*>(master_sendbuf); 366 delete[] static_cast<int*>(local_recvbuf); 367 } 368 else if(datatype == MPI_FLOAT) 369 { 370 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<float*>(master_sendbuf); 371 delete[] static_cast<float*>(local_recvbuf); 372 } 373 else if(datatype == MPI_DOUBLE) 374 { 375 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<double*>(master_sendbuf); 376 delete[] static_cast<double*>(local_recvbuf); 377 } 378 else if(datatype == MPI_LONG) 379 { 380 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<long*>(master_sendbuf); 381 delete[] static_cast<long*>(local_recvbuf); 382 } 383 else if(datatype == MPI_UNSIGNED_LONG) 384 { 385 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<unsigned long*>(master_sendbuf); 386 delete[] static_cast<unsigned long*>(local_recvbuf); 387 } 388 else //if(datatype == MPI_DOUBLE) 389 { 390 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<char*>(master_sendbuf); 391 delete[] static_cast<char*>(local_recvbuf); 392 } 393 } 394 395 } 140 396 }
Note: See TracChangeset
for help on using the changeset viewer.