source: XIOS/trunk/extern/src_ep/ep_scatter.cpp @ 1034

Last change on this file since 1034 was 1034, checked in by yushan, 4 years ago

adding src_ep into extern folder

File size: 11.7 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scatter
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12using namespace std;
13
14namespace ep_lib
15{
16
17  int MPI_Scatter_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
18  {
19    if(datatype == MPI_INT)
20    {
21      Debug("datatype is INT\n");
22      return MPI_Scatter_local_int(sendbuf, count, recvbuf, comm);
23    }
24    else if(datatype == MPI_FLOAT)
25    {
26      Debug("datatype is FLOAT\n");
27      return MPI_Scatter_local_float(sendbuf, count, recvbuf, comm);
28    }
29    else if(datatype == MPI_DOUBLE)
30    {
31      Debug("datatype is DOUBLE\n");
32      return MPI_Scatter_local_double(sendbuf, count, recvbuf, comm);
33    }
34    else if(datatype == MPI_LONG)
35    {
36      Debug("datatype is LONG\n");
37      return MPI_Scatter_local_long(sendbuf, count, recvbuf, comm);
38    }
39    else if(datatype == MPI_UNSIGNED_LONG)
40    {
41      Debug("datatype is uLONG\n");
42      return MPI_Scatter_local_ulong(sendbuf, count, recvbuf, comm);
43    }
44    else if(datatype == MPI_CHAR)
45    {
46      Debug("datatype is CHAR\n");
47      return MPI_Scatter_local_char(sendbuf, count, recvbuf, comm);
48    }
49    else
50    {
51      printf("MPI_Scatter Datatype not supported!\n");
52      exit(0);
53    }
54  }
55
56  int MPI_Scatter_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
57  {
58    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
59    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
60
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    for(int k=0; k<num_ep; k++)
67    {
68      for(int j=0; j<count; j+=BUFFER_SIZE)
69      {
70        if(my_rank == 0)
71        {
72          #pragma omp critical (write_to_buffer)
73          {
74            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
75            #pragma omp flush
76          }
77        }
78
79        MPI_Barrier_local(comm);
80
81        if(my_rank == k)
82        {
83          #pragma omp critical (read_from_buffer)
84          {
85            #pragma omp flush
86            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
87          }
88        }
89        MPI_Barrier_local(comm);
90      }
91    }
92  }
93
94  int MPI_Scatter_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
95  {
96    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
97    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
98
99    float *buffer = comm.my_buffer->buf_float;
100    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
101    float *recv_buf = static_cast<float*>(recvbuf);
102
103    for(int k=0; k<num_ep; k++)
104    {
105      for(int j=0; j<count; j+=BUFFER_SIZE)
106      {
107        if(my_rank == 0)
108        {
109          #pragma omp critical (write_to_buffer)
110          {
111            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
112            #pragma omp flush
113          }
114        }
115
116        MPI_Barrier_local(comm);
117
118        if(my_rank == k)
119        {
120          #pragma omp critical (read_from_buffer)
121          {
122            #pragma omp flush
123            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
124          }
125        }
126        MPI_Barrier_local(comm);
127      }
128    }
129  }
130
131  int MPI_Scatter_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
132  {
133    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
134    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
135
136    double *buffer = comm.my_buffer->buf_double;
137    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
138    double *recv_buf = static_cast<double*>(recvbuf);
139
140    for(int k=0; k<num_ep; k++)
141    {
142      for(int j=0; j<count; j+=BUFFER_SIZE)
143      {
144        if(my_rank == 0)
145        {
146          #pragma omp critical (write_to_buffer)
147          {
148            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
149            #pragma omp flush
150          }
151        }
152
153        MPI_Barrier_local(comm);
154
155        if(my_rank == k)
156        {
157          #pragma omp critical (read_from_buffer)
158          {
159            #pragma omp flush
160            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
161          }
162        }
163        MPI_Barrier_local(comm);
164      }
165    }
166  }
167
168  int MPI_Scatter_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
169  {
170    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
171    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
172
173    long *buffer = comm.my_buffer->buf_long;
174    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
175    long *recv_buf = static_cast<long*>(recvbuf);
176
177    for(int k=0; k<num_ep; k++)
178    {
179      for(int j=0; j<count; j+=BUFFER_SIZE)
180      {
181        if(my_rank == 0)
182        {
183          #pragma omp critical (write_to_buffer)
184          {
185            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
186            #pragma omp flush
187          }
188        }
189
190        MPI_Barrier_local(comm);
191
192        if(my_rank == k)
193        {
194          #pragma omp critical (read_from_buffer)
195          {
196            #pragma omp flush
197            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
198          }
199        }
200        MPI_Barrier_local(comm);
201      }
202    }
203  }
204
205
206  int MPI_Scatter_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
207  {
208    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
209    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
210
211    unsigned long *buffer = comm.my_buffer->buf_ulong;
212    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
213    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
214
215    for(int k=0; k<num_ep; k++)
216    {
217      for(int j=0; j<count; j+=BUFFER_SIZE)
218      {
219        if(my_rank == 0)
220        {
221          #pragma omp critical (write_to_buffer)
222          {
223            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
224            #pragma omp flush
225          }
226        }
227
228        MPI_Barrier_local(comm);
229
230        if(my_rank == k)
231        {
232          #pragma omp critical (read_from_buffer)
233          {
234            #pragma omp flush
235            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
236          }
237        }
238        MPI_Barrier_local(comm);
239      }
240    }
241  }
242
243
244  int MPI_Scatter_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
245  {
246    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
247    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
248
249    char *buffer = comm.my_buffer->buf_char;
250    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
251    char *recv_buf = static_cast<char*>(recvbuf);
252
253    for(int k=0; k<num_ep; k++)
254    {
255      for(int j=0; j<count; j+=BUFFER_SIZE)
256      {
257        if(my_rank == 0)
258        {
259          #pragma omp critical (write_to_buffer)
260          {
261            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
262            #pragma omp flush
263          }
264        }
265
266        MPI_Barrier_local(comm);
267
268        if(my_rank == k)
269        {
270          #pragma omp critical (read_from_buffer)
271          {
272            #pragma omp flush
273            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
274          }
275        }
276        MPI_Barrier_local(comm);
277      }
278    }
279  }
280
281
282
283
284
285  int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
286  {
287    if(!comm.is_ep)
288    {
289      #ifdef _serialized
290      #pragma omp critical (_mpi_call)
291      #endif // _serialized
292      ::MPI_Scatter(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
293                    root, static_cast< ::MPI_Comm>(comm.mpi_comm));
294      return 0;
295    }
296
297    if(!comm.mpi_comm) return 0;
298
299    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
300
301    int root_mpi_rank = comm.rank_map->at(root).second;
302    int root_ep_loc = comm.rank_map->at(root).first;
303
304    int ep_rank, ep_rank_loc, mpi_rank;
305    int ep_size, num_ep, mpi_size;
306
307    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
308    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
309    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
310    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
311    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
312    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
313
314
315    MPI_Datatype datatype = sendtype;
316    int count = sendcount;
317
318    ::MPI_Aint datasize, lb;
319    #ifdef _serialized
320    #pragma omp critical (_mpi_call)
321    #endif // _serialized
322    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
323
324
325    void *master_sendbuf;
326    void *local_recvbuf;
327
328    if(root_ep_loc!=0 && mpi_rank == root_mpi_rank)
329    {
330      if(ep_rank_loc == 0) master_sendbuf = new void*[datasize*count*ep_size];
331
332      innode_memcpy(root_ep_loc, sendbuf, 0, master_sendbuf, count*ep_size, datatype, comm);
333    }
334
335
336
337    if(ep_rank_loc == 0)
338    {
339      int mpi_sendcnt = count*num_ep;
340      int mpi_scatterv_sendcnt[mpi_size];
341      int displs[mpi_size];
342
343      local_recvbuf = new void*[datasize*mpi_sendcnt];
344
345      #ifdef _serialized
346      #pragma omp critical (_mpi_call)
347      #endif // _serialized
348      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_scatterv_sendcnt, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
349
350      displs[0] = 0;
351      for(int i=1; i<mpi_size; i++)
352        displs[i] = displs[i-1] + mpi_scatterv_sendcnt[i-1];
353
354
355      if(root_ep_loc!=0)
356      {
357        #ifdef _serialized
358        #pragma omp critical (_mpi_call)
359        #endif // _serialized
360        ::MPI_Scatterv(master_sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype),
361                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
362      }
363      else
364      {
365        #ifdef _serialized
366        #pragma omp critical (_mpi_call)
367        #endif // _serialized
368        ::MPI_Scatterv(sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype),
369                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
370      }
371    }
372
373    MPI_Scatter_local(local_recvbuf, count, datatype, recvbuf, comm);
374
375    if(ep_rank_loc == 0)
376    {
377      if(datatype == MPI_INT)
378      {
379        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<int*>(master_sendbuf);
380        delete[] static_cast<int*>(local_recvbuf);
381      }
382      else if(datatype == MPI_FLOAT)
383      {
384        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<float*>(master_sendbuf);
385        delete[] static_cast<float*>(local_recvbuf);
386      }
387      else if(datatype == MPI_DOUBLE)
388      {
389        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<double*>(master_sendbuf);
390        delete[] static_cast<double*>(local_recvbuf);
391      }
392      else if(datatype == MPI_LONG)
393      {
394        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<long*>(master_sendbuf);
395        delete[] static_cast<long*>(local_recvbuf);
396      }
397      else if(datatype == MPI_UNSIGNED_LONG)
398      {
399        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<unsigned long*>(master_sendbuf);
400        delete[] static_cast<unsigned long*>(local_recvbuf);
401      }
402      else //if(datatype == MPI_DOUBLE)
403      {
404        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<char*>(master_sendbuf);
405        delete[] static_cast<char*>(local_recvbuf);
406      }
407    }
408
409  }
410}
Note: See TracBrowser for help on using the repository browser.