source: XIOS/dev/branch_yushan/extern/src_ep_dev/ep_scatterv.cpp @ 1053

Last change on this file since 1053 was 1053, checked in by yushan, 6 years ago

ep_lib namespace specified when netcdf involved

File size: 12.5 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scatterv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12using namespace std;
13
14namespace ep_lib
15{
16
17  int MPI_Scatterv_local(const void *sendbuf, const int sendcounts[], const int displs[], MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
18  {
19    if(datatype == MPI_INT)
20    {
21      Debug("datatype is INT\n");
22      return MPI_Scatterv_local_int(sendbuf, sendcounts, displs, recvbuf, comm);
23    }
24    else if(datatype == MPI_FLOAT)
25    {
26      Debug("datatype is FLOAT\n");
27      return MPI_Scatterv_local_float(sendbuf, sendcounts, displs, recvbuf, comm);
28    }
29    else if(datatype == MPI_DOUBLE)
30    {
31      Debug("datatype is DOUBLE\n");
32      return MPI_Scatterv_local_double(sendbuf, sendcounts, displs, recvbuf, comm);
33    }
34    else if(datatype == MPI_LONG)
35    {
36      Debug("datatype is LONG\n");
37      return MPI_Scatterv_local_long(sendbuf, sendcounts, displs, recvbuf, comm);
38    }
39    else if(datatype == MPI_UNSIGNED_LONG)
40    {
41      Debug("datatype is uLONG\n");
42      return MPI_Scatterv_local_ulong(sendbuf, sendcounts, displs, recvbuf, comm);
43    }
44    else if(datatype == MPI_CHAR)
45    {
46      Debug("datatype is CHAR\n");
47      return MPI_Scatterv_local_char(sendbuf, sendcounts, displs, recvbuf, comm);
48    }
49    else
50    {
51      printf("MPI_scatterv Datatype not supported!\n");
52      exit(0);
53    }
54  }
55
56  int MPI_Scatterv_local_int(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
57  {
58    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
59    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
60
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    for(int k=0; k<num_ep; k++)
67    {
68      int count = sendcounts[k];
69      for(int j=0; j<count; j+=BUFFER_SIZE)
70      {
71        if(my_rank == 0)
72        {
73          #pragma omp critical (write_to_buffer)
74          {
75            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
76            #pragma omp flush
77          }
78        }
79
80        MPI_Barrier_local(comm);
81
82        if(my_rank == k)
83        {
84          #pragma omp critical (read_from_buffer)
85          {
86            #pragma omp flush
87            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
88          }
89        }
90        MPI_Barrier_local(comm);
91      }
92    }
93  }
94
95  int MPI_Scatterv_local_float(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
96  {
97    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
98    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
99
100
101    float *buffer = comm.my_buffer->buf_float;
102    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
103    float *recv_buf = static_cast<float*>(recvbuf);
104
105    for(int k=0; k<num_ep; k++)
106    {
107      int count = sendcounts[k];
108      for(int j=0; j<count; j+=BUFFER_SIZE)
109      {
110        if(my_rank == 0)
111        {
112          #pragma omp critical (write_to_buffer)
113          {
114            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
115            #pragma omp flush
116          }
117        }
118
119        MPI_Barrier_local(comm);
120
121        if(my_rank == k)
122        {
123          #pragma omp critical (read_from_buffer)
124          {
125            #pragma omp flush
126            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
127          }
128        }
129        MPI_Barrier_local(comm);
130      }
131    }
132  }
133
134  int MPI_Scatterv_local_double(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
135  {
136    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
137    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
138
139
140    double *buffer = comm.my_buffer->buf_double;
141    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
142    double *recv_buf = static_cast<double*>(recvbuf);
143
144    for(int k=0; k<num_ep; k++)
145    {
146      int count = sendcounts[k];
147      for(int j=0; j<count; j+=BUFFER_SIZE)
148      {
149        if(my_rank == 0)
150        {
151          #pragma omp critical (write_to_buffer)
152          {
153            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
154            #pragma omp flush
155          }
156        }
157
158        MPI_Barrier_local(comm);
159
160        if(my_rank == k)
161        {
162          #pragma omp critical (read_from_buffer)
163          {
164            #pragma omp flush
165            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
166          }
167        }
168        MPI_Barrier_local(comm);
169      }
170    }
171  }
172
173  int MPI_Scatterv_local_long(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
174  {
175    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
176    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
177
178
179    long *buffer = comm.my_buffer->buf_long;
180    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
181    long *recv_buf = static_cast<long*>(recvbuf);
182
183    for(int k=0; k<num_ep; k++)
184    {
185      int count = sendcounts[k];
186      for(int j=0; j<count; j+=BUFFER_SIZE)
187      {
188        if(my_rank == 0)
189        {
190          #pragma omp critical (write_to_buffer)
191          {
192            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
193            #pragma omp flush
194          }
195        }
196
197        MPI_Barrier_local(comm);
198
199        if(my_rank == k)
200        {
201          #pragma omp critical (read_from_buffer)
202          {
203            #pragma omp flush
204            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
205          }
206        }
207        MPI_Barrier_local(comm);
208      }
209    }
210  }
211
212
213  int MPI_Scatterv_local_ulong(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
214  {
215    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
216    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
217
218
219    unsigned long *buffer = comm.my_buffer->buf_ulong;
220    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
221    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
222
223    for(int k=0; k<num_ep; k++)
224    {
225      int count = sendcounts[k];
226      for(int j=0; j<count; j+=BUFFER_SIZE)
227      {
228        if(my_rank == 0)
229        {
230          #pragma omp critical (write_to_buffer)
231          {
232            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
233            #pragma omp flush
234          }
235        }
236
237        MPI_Barrier_local(comm);
238
239        if(my_rank == k)
240        {
241          #pragma omp critical (read_from_buffer)
242          {
243            #pragma omp flush
244            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
245          }
246        }
247        MPI_Barrier_local(comm);
248      }
249    }
250  }
251
252
253  int MPI_Scatterv_local_char(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
254  {
255    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
256    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
257
258
259    char *buffer = comm.my_buffer->buf_char;
260    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
261    char *recv_buf = static_cast<char*>(recvbuf);
262
263    for(int k=0; k<num_ep; k++)
264    {
265      int count = sendcounts[k];
266      for(int j=0; j<count; j+=BUFFER_SIZE)
267      {
268        if(my_rank == 0)
269        {
270          #pragma omp critical (write_to_buffer)
271          {
272            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
273            #pragma omp flush
274          }
275        }
276
277        MPI_Barrier_local(comm);
278
279        if(my_rank == k)
280        {
281          #pragma omp critical (read_from_buffer)
282          {
283            #pragma omp flush
284            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
285          }
286        }
287        MPI_Barrier_local(comm);
288      }
289    }
290  }
291
292
293  int MPI_Scatterv(const void *sendbuf, const int sendcounts[], const int displs[], MPI_Datatype sendtype, void *recvbuf, int recvcount,
294                   MPI_Datatype recvtype, int root, MPI_Comm comm)
295  {
296    if(!comm.is_ep)
297    {
298      ::MPI_Scatterv(sendbuf, sendcounts, displs, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount,
299                     static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
300      return 0;
301    }
302    if(!comm.mpi_comm) return 0;
303
304    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
305
306    MPI_Datatype datatype = sendtype;
307
308    int root_mpi_rank = comm.rank_map->at(root).second;
309    int root_ep_loc = comm.rank_map->at(root).first;
310
311    int ep_rank, ep_rank_loc, mpi_rank;
312    int ep_size, num_ep, mpi_size;
313
314    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
315    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
316    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
317    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
318    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
319    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
320   
321    MPI_Bcast(const_cast<int*>(sendcounts), ep_size, MPI_INT, root, comm);
322    MPI_Bcast(const_cast<int*>(displs), ep_size, MPI_INT, root, comm);
323
324
325    int count = recvcount;
326
327    ::MPI_Aint datasize, lb;
328
329    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
330
331    assert(accumulate(sendcounts, sendcounts+ep_size-1, 0) == displs[ep_size-1]); // Only for contunuous gather.
332
333
334    void *master_sendbuf;
335    void *local_recvbuf;
336
337    if(root_ep_loc!=0 && mpi_rank == root_mpi_rank)
338    {
339      int count_sum = accumulate(sendcounts, sendcounts+ep_size, 0);
340      if(ep_rank_loc == 0) master_sendbuf = new void*[datasize*count_sum];
341
342      innode_memcpy(root_ep_loc, sendbuf, 0, master_sendbuf, count_sum, datatype, comm);
343    }
344
345
346
347    if(ep_rank_loc == 0)
348    {
349      int mpi_sendcnt = accumulate(sendcounts+ep_rank, sendcounts+ep_rank+num_ep, 0);
350      int mpi_scatterv_sendcnt[mpi_size];
351      int mpi_displs[mpi_size];
352
353      local_recvbuf = new void*[datasize*mpi_sendcnt];
354
355      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_scatterv_sendcnt, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
356
357      mpi_displs[0] = displs[0];
358      for(int i=1; i<mpi_size; i++)
359        mpi_displs[i] = mpi_displs[i-1] + mpi_scatterv_sendcnt[i-1];
360
361
362      if(root_ep_loc!=0)
363      {
364        ::MPI_Scatterv(master_sendbuf, mpi_scatterv_sendcnt, mpi_displs, static_cast< ::MPI_Datatype>(datatype),
365                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
366      }
367      else
368      {
369        ::MPI_Scatterv(sendbuf, mpi_scatterv_sendcnt, mpi_displs, static_cast< ::MPI_Datatype>(datatype),
370                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
371      }
372    }
373
374    int local_displs[num_ep];
375    local_displs[0] = 0;
376    for(int i=1; i<num_ep; i++)
377    {
378      local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc];
379    }
380
381    MPI_Scatterv_local(local_recvbuf, sendcounts+ep_rank-ep_rank_loc, local_displs, datatype, recvbuf, comm);
382
383    if(ep_rank_loc == 0)
384    {
385      if(datatype == MPI_INT)
386      {
387        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<int*>(master_sendbuf);
388        delete[] static_cast<int*>(local_recvbuf);
389      }
390      else if(datatype == MPI_FLOAT)
391      {
392        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<float*>(master_sendbuf);
393        delete[] static_cast<float*>(local_recvbuf);
394      }
395      else  if(datatype == MPI_DOUBLE)
396      {
397        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<double*>(master_sendbuf);
398        delete[] static_cast<double*>(local_recvbuf);
399      }
400      else  if(datatype == MPI_LONG)
401      {
402        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<long*>(master_sendbuf);
403        delete[] static_cast<long*>(local_recvbuf);
404      }
405      else  if(datatype == MPI_UNSIGNED_LONG)
406      {
407        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<unsigned long*>(master_sendbuf);
408        delete[] static_cast<unsigned long*>(local_recvbuf);
409      }
410      else // if(datatype == MPI_DOUBLE)
411      {
412        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<char*>(master_sendbuf);
413        delete[] static_cast<char*>(local_recvbuf);
414      }
415    }
416
417  }
418}
Note: See TracBrowser for help on using the repository browser.