source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scatterv.cpp @ 1289

Last change on this file since 1289 was 1289, checked in by yushan, 7 years ago

EP update part 2

File size: 12.6 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scatterv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12using namespace std;
13
14namespace ep_lib
15{
16
17  int MPI_Scatterv_local2(const void *sendbuf, const int sendcounts[], const int displs[], MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
18  {
19    if(datatype == MPI_INT)
20    {
21      Debug("datatype is INT\n");
22      return MPI_Scatterv_local_int(sendbuf, sendcounts, displs, recvbuf, comm);
23    }
24    else if(datatype == MPI_FLOAT)
25    {
26      Debug("datatype is FLOAT\n");
27      return MPI_Scatterv_local_float(sendbuf, sendcounts, displs, recvbuf, comm);
28    }
29    else if(datatype == MPI_DOUBLE)
30    {
31      Debug("datatype is DOUBLE\n");
32      return MPI_Scatterv_local_double(sendbuf, sendcounts, displs, recvbuf, comm);
33    }
34    else if(datatype == MPI_LONG)
35    {
36      Debug("datatype is LONG\n");
37      return MPI_Scatterv_local_long(sendbuf, sendcounts, displs, recvbuf, comm);
38    }
39    else if(datatype == MPI_UNSIGNED_LONG)
40    {
41      Debug("datatype is uLONG\n");
42      return MPI_Scatterv_local_ulong(sendbuf, sendcounts, displs, recvbuf, comm);
43    }
44    else if(datatype == MPI_CHAR)
45    {
46      Debug("datatype is CHAR\n");
47      return MPI_Scatterv_local_char(sendbuf, sendcounts, displs, recvbuf, comm);
48    }
49    else
50    {
51      printf("MPI_scatterv Datatype not supported!\n");
52      exit(0);
53    }
54  }
55
56  int MPI_Scatterv_local_int(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
57  {
58    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
59    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
60
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    for(int k=0; k<num_ep; k++)
67    {
68      int count = sendcounts[k];
69      for(int j=0; j<count; j+=BUFFER_SIZE)
70      {
71        if(my_rank == 0)
72        {
73          #pragma omp critical (write_to_buffer)
74          {
75            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
76            #pragma omp flush
77          }
78        }
79
80        MPI_Barrier_local(comm);
81
82        if(my_rank == k)
83        {
84          #pragma omp critical (read_from_buffer)
85          {
86            #pragma omp flush
87            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
88          }
89        }
90        MPI_Barrier_local(comm);
91      }
92    }
93  }
94
95  int MPI_Scatterv_local_float(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
96  {
97    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
98    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
99
100
101    float *buffer = comm.my_buffer->buf_float;
102    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
103    float *recv_buf = static_cast<float*>(recvbuf);
104
105    for(int k=0; k<num_ep; k++)
106    {
107      int count = sendcounts[k];
108      for(int j=0; j<count; j+=BUFFER_SIZE)
109      {
110        if(my_rank == 0)
111        {
112          #pragma omp critical (write_to_buffer)
113          {
114            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
115            #pragma omp flush
116          }
117        }
118
119        MPI_Barrier_local(comm);
120
121        if(my_rank == k)
122        {
123          #pragma omp critical (read_from_buffer)
124          {
125            #pragma omp flush
126            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
127          }
128        }
129        MPI_Barrier_local(comm);
130      }
131    }
132  }
133
134  int MPI_Scatterv_local_double(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
135  {
136    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
137    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
138
139
140    double *buffer = comm.my_buffer->buf_double;
141    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
142    double *recv_buf = static_cast<double*>(recvbuf);
143
144    for(int k=0; k<num_ep; k++)
145    {
146      int count = sendcounts[k];
147      for(int j=0; j<count; j+=BUFFER_SIZE)
148      {
149        if(my_rank == 0)
150        {
151          #pragma omp critical (write_to_buffer)
152          {
153            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
154            #pragma omp flush
155          }
156        }
157
158        MPI_Barrier_local(comm);
159
160        if(my_rank == k)
161        {
162          #pragma omp critical (read_from_buffer)
163          {
164            #pragma omp flush
165            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
166          }
167        }
168        MPI_Barrier_local(comm);
169      }
170    }
171  }
172
173  int MPI_Scatterv_local_long(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
174  {
175    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
176    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
177
178
179    long *buffer = comm.my_buffer->buf_long;
180    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
181    long *recv_buf = static_cast<long*>(recvbuf);
182
183    for(int k=0; k<num_ep; k++)
184    {
185      int count = sendcounts[k];
186      for(int j=0; j<count; j+=BUFFER_SIZE)
187      {
188        if(my_rank == 0)
189        {
190          #pragma omp critical (write_to_buffer)
191          {
192            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
193            #pragma omp flush
194          }
195        }
196
197        MPI_Barrier_local(comm);
198
199        if(my_rank == k)
200        {
201          #pragma omp critical (read_from_buffer)
202          {
203            #pragma omp flush
204            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
205          }
206        }
207        MPI_Barrier_local(comm);
208      }
209    }
210  }
211
212
213  int MPI_Scatterv_local_ulong(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
214  {
215    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
216    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
217
218
219    unsigned long *buffer = comm.my_buffer->buf_ulong;
220    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
221    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
222
223    for(int k=0; k<num_ep; k++)
224    {
225      int count = sendcounts[k];
226      for(int j=0; j<count; j+=BUFFER_SIZE)
227      {
228        if(my_rank == 0)
229        {
230          #pragma omp critical (write_to_buffer)
231          {
232            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
233            #pragma omp flush
234          }
235        }
236
237        MPI_Barrier_local(comm);
238
239        if(my_rank == k)
240        {
241          #pragma omp critical (read_from_buffer)
242          {
243            #pragma omp flush
244            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
245          }
246        }
247        MPI_Barrier_local(comm);
248      }
249    }
250  }
251
252
253  int MPI_Scatterv_local_char(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
254  {
255    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
256    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
257
258
259    char *buffer = comm.my_buffer->buf_char;
260    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
261    char *recv_buf = static_cast<char*>(recvbuf);
262
263    for(int k=0; k<num_ep; k++)
264    {
265      int count = sendcounts[k];
266      for(int j=0; j<count; j+=BUFFER_SIZE)
267      {
268        if(my_rank == 0)
269        {
270          #pragma omp critical (write_to_buffer)
271          {
272            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
273            #pragma omp flush
274          }
275        }
276
277        MPI_Barrier_local(comm);
278
279        if(my_rank == k)
280        {
281          #pragma omp critical (read_from_buffer)
282          {
283            #pragma omp flush
284            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
285          }
286        }
287        MPI_Barrier_local(comm);
288      }
289    }
290  }
291
292
293  int MPI_Scatterv(const void *sendbuf, const int sendcounts[], const int displs[], MPI_Datatype sendtype, void *recvbuf, int recvcount,
294                   MPI_Datatype recvtype, int root, MPI_Comm comm)
295  {
296    if(!comm.is_ep)
297    {
298      ::MPI_Scatterv(sendbuf, sendcounts, displs, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount,
299                     static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
300      return 0;
301    }
302    if(!comm.mpi_comm) return 0;
303
304    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
305
306    MPI_Datatype datatype = sendtype;
307
308    int root_mpi_rank = comm.rank_map->at(root).second;
309    int root_ep_loc = comm.rank_map->at(root).first;
310
311    int ep_rank, ep_rank_loc, mpi_rank;
312    int ep_size, num_ep, mpi_size;
313
314    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
315    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
316    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
317    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
318    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
319    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
320
321    if(ep_rank != root)
322    {
323      sendcounts = new int[ep_size];
324      displs = new int[ep_size];
325    }
326   
327    MPI_Bcast(const_cast<int*>(sendcounts), ep_size, MPI_INT, root, comm);
328    MPI_Bcast(const_cast<int*>(displs), ep_size, MPI_INT, root, comm);
329
330
331    int count = recvcount;
332
333    ::MPI_Aint datasize, lb;
334
335    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
336
337    assert(accumulate(sendcounts, sendcounts+ep_size-1, 0) == displs[ep_size-1]); // Only for contunuous gather.
338
339
340    void *master_sendbuf;
341    void *local_recvbuf;
342
343    if(root_ep_loc!=0 && mpi_rank == root_mpi_rank)
344    {
345      int count_sum = accumulate(sendcounts, sendcounts+ep_size, 0);
346      if(ep_rank_loc == 0) master_sendbuf = new void*[datasize*count_sum];
347
348      innode_memcpy(root_ep_loc, sendbuf, 0, master_sendbuf, count_sum, datatype, comm);
349    }
350
351
352
353    if(ep_rank_loc == 0)
354    {
355      int mpi_sendcnt = accumulate(sendcounts+ep_rank, sendcounts+ep_rank+num_ep, 0);
356      int mpi_scatterv_sendcnt[mpi_size];
357      int mpi_displs[mpi_size];
358
359      local_recvbuf = new void*[datasize*mpi_sendcnt];
360
361      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_scatterv_sendcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
362
363      mpi_displs[0] = displs[0];
364      for(int i=1; i<mpi_size; i++)
365        mpi_displs[i] = mpi_displs[i-1] + mpi_scatterv_sendcnt[i-1];
366
367
368      if(root_ep_loc!=0)
369      {
370        ::MPI_Scatterv(master_sendbuf, mpi_scatterv_sendcnt, mpi_displs, static_cast< ::MPI_Datatype>(datatype),
371                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
372      }
373      else
374      {
375        ::MPI_Scatterv(sendbuf, mpi_scatterv_sendcnt, mpi_displs, static_cast< ::MPI_Datatype>(datatype),
376                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
377      }
378    }
379
380    int local_displs[num_ep];
381    local_displs[0] = 0;
382    for(int i=1; i<num_ep; i++)
383    {
384      local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc];
385    }
386
387    MPI_Scatterv_local2(local_recvbuf, sendcounts+ep_rank-ep_rank_loc, local_displs, datatype, recvbuf, comm);
388
389    if(ep_rank_loc == 0)
390    {
391      if(datatype == MPI_INT)
392      {
393        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<int*>(master_sendbuf);
394        delete[] static_cast<int*>(local_recvbuf);
395      }
396      else if(datatype == MPI_FLOAT)
397      {
398        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<float*>(master_sendbuf);
399        delete[] static_cast<float*>(local_recvbuf);
400      }
401      else  if(datatype == MPI_DOUBLE)
402      {
403        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<double*>(master_sendbuf);
404        delete[] static_cast<double*>(local_recvbuf);
405      }
406      else  if(datatype == MPI_LONG)
407      {
408        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<long*>(master_sendbuf);
409        delete[] static_cast<long*>(local_recvbuf);
410      }
411      else  if(datatype == MPI_UNSIGNED_LONG)
412      {
413        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<unsigned long*>(master_sendbuf);
414        delete[] static_cast<unsigned long*>(local_recvbuf);
415      }
416      else // if(datatype == MPI_DOUBLE)
417      {
418        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<char*>(master_sendbuf);
419        delete[] static_cast<char*>(local_recvbuf);
420      }
421    }
422    else
423    {
424      delete[] sendcounts;
425      delete[] displs;
426    }
427
428  }
429}
Note: See TracBrowser for help on using the repository browser.