source: XIOS/trunk/extern/src_ep/ep_scatterv.cpp @ 1034

Last change on this file since 1034 was 1034, checked in by yushan, 7 years ago

adding src_ep into extern folder

File size: 12.9 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scatterv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12using namespace std;
13
14namespace ep_lib
15{
16
17  int MPI_Scatterv_local(const void *sendbuf, const int sendcounts[], const int displs[], MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
18  {
19    if(datatype == MPI_INT)
20    {
21      Debug("datatype is INT\n");
22      return MPI_Scatterv_local_int(sendbuf, sendcounts, displs, recvbuf, comm);
23    }
24    else if(datatype == MPI_FLOAT)
25    {
26      Debug("datatype is FLOAT\n");
27      return MPI_Scatterv_local_float(sendbuf, sendcounts, displs, recvbuf, comm);
28    }
29    else if(datatype == MPI_DOUBLE)
30    {
31      Debug("datatype is DOUBLE\n");
32      return MPI_Scatterv_local_double(sendbuf, sendcounts, displs, recvbuf, comm);
33    }
34    else if(datatype == MPI_LONG)
35    {
36      Debug("datatype is LONG\n");
37      return MPI_Scatterv_local_long(sendbuf, sendcounts, displs, recvbuf, comm);
38    }
39    else if(datatype == MPI_UNSIGNED_LONG)
40    {
41      Debug("datatype is uLONG\n");
42      return MPI_Scatterv_local_ulong(sendbuf, sendcounts, displs, recvbuf, comm);
43    }
44    else if(datatype == MPI_CHAR)
45    {
46      Debug("datatype is CHAR\n");
47      return MPI_Scatterv_local_char(sendbuf, sendcounts, displs, recvbuf, comm);
48    }
49    else
50    {
51      printf("MPI_scatterv Datatype not supported!\n");
52      exit(0);
53    }
54  }
55
56  int MPI_Scatterv_local_int(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
57  {
58    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
59    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
60
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    for(int k=0; k<num_ep; k++)
67    {
68      int count = sendcounts[k];
69      for(int j=0; j<count; j+=BUFFER_SIZE)
70      {
71        if(my_rank == 0)
72        {
73          #pragma omp critical (write_to_buffer)
74          {
75            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
76            #pragma omp flush
77          }
78        }
79
80        MPI_Barrier_local(comm);
81
82        if(my_rank == k)
83        {
84          #pragma omp critical (read_from_buffer)
85          {
86            #pragma omp flush
87            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
88          }
89        }
90        MPI_Barrier_local(comm);
91      }
92    }
93  }
94
95  int MPI_Scatterv_local_float(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
96  {
97    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
98    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
99
100
101    float *buffer = comm.my_buffer->buf_float;
102    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
103    float *recv_buf = static_cast<float*>(recvbuf);
104
105    for(int k=0; k<num_ep; k++)
106    {
107      int count = sendcounts[k];
108      for(int j=0; j<count; j+=BUFFER_SIZE)
109      {
110        if(my_rank == 0)
111        {
112          #pragma omp critical (write_to_buffer)
113          {
114            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
115            #pragma omp flush
116          }
117        }
118
119        MPI_Barrier_local(comm);
120
121        if(my_rank == k)
122        {
123          #pragma omp critical (read_from_buffer)
124          {
125            #pragma omp flush
126            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
127          }
128        }
129        MPI_Barrier_local(comm);
130      }
131    }
132  }
133
134  int MPI_Scatterv_local_double(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
135  {
136    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
137    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
138
139
140    double *buffer = comm.my_buffer->buf_double;
141    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
142    double *recv_buf = static_cast<double*>(recvbuf);
143
144    for(int k=0; k<num_ep; k++)
145    {
146      int count = sendcounts[k];
147      for(int j=0; j<count; j+=BUFFER_SIZE)
148      {
149        if(my_rank == 0)
150        {
151          #pragma omp critical (write_to_buffer)
152          {
153            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
154            #pragma omp flush
155          }
156        }
157
158        MPI_Barrier_local(comm);
159
160        if(my_rank == k)
161        {
162          #pragma omp critical (read_from_buffer)
163          {
164            #pragma omp flush
165            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
166          }
167        }
168        MPI_Barrier_local(comm);
169      }
170    }
171  }
172
173  int MPI_Scatterv_local_long(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
174  {
175    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
176    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
177
178
179    long *buffer = comm.my_buffer->buf_long;
180    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
181    long *recv_buf = static_cast<long*>(recvbuf);
182
183    for(int k=0; k<num_ep; k++)
184    {
185      int count = sendcounts[k];
186      for(int j=0; j<count; j+=BUFFER_SIZE)
187      {
188        if(my_rank == 0)
189        {
190          #pragma omp critical (write_to_buffer)
191          {
192            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
193            #pragma omp flush
194          }
195        }
196
197        MPI_Barrier_local(comm);
198
199        if(my_rank == k)
200        {
201          #pragma omp critical (read_from_buffer)
202          {
203            #pragma omp flush
204            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
205          }
206        }
207        MPI_Barrier_local(comm);
208      }
209    }
210  }
211
212
213  int MPI_Scatterv_local_ulong(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
214  {
215    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
216    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
217
218
219    unsigned long *buffer = comm.my_buffer->buf_ulong;
220    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
221    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
222
223    for(int k=0; k<num_ep; k++)
224    {
225      int count = sendcounts[k];
226      for(int j=0; j<count; j+=BUFFER_SIZE)
227      {
228        if(my_rank == 0)
229        {
230          #pragma omp critical (write_to_buffer)
231          {
232            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
233            #pragma omp flush
234          }
235        }
236
237        MPI_Barrier_local(comm);
238
239        if(my_rank == k)
240        {
241          #pragma omp critical (read_from_buffer)
242          {
243            #pragma omp flush
244            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
245          }
246        }
247        MPI_Barrier_local(comm);
248      }
249    }
250  }
251
252
253  int MPI_Scatterv_local_char(const void *sendbuf, const int sendcounts[], const int displs[], void *recvbuf, MPI_Comm comm)
254  {
255    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
256    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
257
258
259    char *buffer = comm.my_buffer->buf_char;
260    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
261    char *recv_buf = static_cast<char*>(recvbuf);
262
263    for(int k=0; k<num_ep; k++)
264    {
265      int count = sendcounts[k];
266      for(int j=0; j<count; j+=BUFFER_SIZE)
267      {
268        if(my_rank == 0)
269        {
270          #pragma omp critical (write_to_buffer)
271          {
272            copy(send_buf+displs[k]+j, send_buf+displs[k]+j+min(BUFFER_SIZE, count-j), buffer);
273            #pragma omp flush
274          }
275        }
276
277        MPI_Barrier_local(comm);
278
279        if(my_rank == k)
280        {
281          #pragma omp critical (read_from_buffer)
282          {
283            #pragma omp flush
284            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
285          }
286        }
287        MPI_Barrier_local(comm);
288      }
289    }
290  }
291
292
293  int MPI_Scatterv(const void *sendbuf, const int sendcounts[], const int displs[], MPI_Datatype sendtype, void *recvbuf, int recvcount,
294                   MPI_Datatype recvtype, int root, MPI_Comm comm)
295  {
296    if(!comm.is_ep)
297    {
298      #ifdef _serialized
299      #pragma omp critical (_mpi_call)
300      #endif // _serialized
301      ::MPI_Scatterv(sendbuf, sendcounts, displs, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount,
302                     static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
303      return 0;
304    }
305    if(!comm.mpi_comm) return 0;
306
307    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
308
309    MPI_Datatype datatype = sendtype;
310
311    int root_mpi_rank = comm.rank_map->at(root).second;
312    int root_ep_loc = comm.rank_map->at(root).first;
313
314    int ep_rank, ep_rank_loc, mpi_rank;
315    int ep_size, num_ep, mpi_size;
316
317    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
318    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
319    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
320    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
321    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
322    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
323   
324    MPI_Bcast(const_cast<int*>(sendcounts), ep_size, MPI_INT, root, comm);
325    MPI_Bcast(const_cast<int*>(displs), ep_size, MPI_INT, root, comm);
326
327
328    int count = recvcount;
329
330    ::MPI_Aint datasize, lb;
331    #ifdef _serialized
332    #pragma omp critical (_mpi_call)
333    #endif // _serialized
334    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
335
336    assert(accumulate(sendcounts, sendcounts+ep_size-1, 0) == displs[ep_size-1]); // Only for contunuous gather.
337
338
339    void *master_sendbuf;
340    void *local_recvbuf;
341
342    if(root_ep_loc!=0 && mpi_rank == root_mpi_rank)
343    {
344      int count_sum = accumulate(sendcounts, sendcounts+ep_size, 0);
345      if(ep_rank_loc == 0) master_sendbuf = new void*[datasize*count_sum];
346
347      innode_memcpy(root_ep_loc, sendbuf, 0, master_sendbuf, count_sum, datatype, comm);
348    }
349
350
351
352    if(ep_rank_loc == 0)
353    {
354      int mpi_sendcnt = accumulate(sendcounts+ep_rank, sendcounts+ep_rank+num_ep, 0);
355      int mpi_scatterv_sendcnt[mpi_size];
356      int mpi_displs[mpi_size];
357
358      local_recvbuf = new void*[datasize*mpi_sendcnt];
359
360      #ifdef _serialized
361      #pragma omp critical (_mpi_call)
362      #endif // _serialized
363      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_scatterv_sendcnt, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
364
365      mpi_displs[0] = displs[0];
366      for(int i=1; i<mpi_size; i++)
367        mpi_displs[i] = mpi_displs[i-1] + mpi_scatterv_sendcnt[i-1];
368
369
370      if(root_ep_loc!=0)
371      {
372        #ifdef _serialized
373        #pragma omp critical (_mpi_call)
374        #endif // _serialized
375        ::MPI_Scatterv(master_sendbuf, mpi_scatterv_sendcnt, mpi_displs, static_cast< ::MPI_Datatype>(datatype),
376                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
377      }
378      else
379      {
380        #ifdef _serialized
381        #pragma omp critical (_mpi_call)
382        #endif // _serialized
383        ::MPI_Scatterv(sendbuf, mpi_scatterv_sendcnt, mpi_displs, static_cast< ::MPI_Datatype>(datatype),
384                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
385      }
386    }
387
388    int local_displs[num_ep];
389    local_displs[0] = 0;
390    for(int i=1; i<num_ep; i++)
391    {
392      local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc];
393    }
394
395    MPI_Scatterv_local(local_recvbuf, sendcounts+ep_rank-ep_rank_loc, local_displs, datatype, recvbuf, comm);
396
397    if(ep_rank_loc == 0)
398    {
399      if(datatype == MPI_INT)
400      {
401        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<int*>(master_sendbuf);
402        delete[] static_cast<int*>(local_recvbuf);
403      }
404      else if(datatype == MPI_FLOAT)
405      {
406        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<float*>(master_sendbuf);
407        delete[] static_cast<float*>(local_recvbuf);
408      }
409      else  if(datatype == MPI_DOUBLE)
410      {
411        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<double*>(master_sendbuf);
412        delete[] static_cast<double*>(local_recvbuf);
413      }
414      else  if(datatype == MPI_LONG)
415      {
416        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<long*>(master_sendbuf);
417        delete[] static_cast<long*>(local_recvbuf);
418      }
419      else  if(datatype == MPI_UNSIGNED_LONG)
420      {
421        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<unsigned long*>(master_sendbuf);
422        delete[] static_cast<unsigned long*>(local_recvbuf);
423      }
424      else // if(datatype == MPI_DOUBLE)
425      {
426        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<char*>(master_sendbuf);
427        delete[] static_cast<char*>(local_recvbuf);
428      }
429    }
430
431  }
432}
Note: See TracBrowser for help on using the repository browser.