source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scatter.cpp @ 1289

Last change on this file since 1289 was 1289, checked in by yushan, 5 years ago

EP update part 2

File size: 11.3 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Scatter
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12using namespace std;
13
14namespace ep_lib
15{
16
17  int MPI_Scatter_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)
18  {
19    if(datatype == MPI_INT)
20    {
21      Debug("datatype is INT\n");
22      return MPI_Scatter_local_int(sendbuf, count, recvbuf, comm);
23    }
24    else if(datatype == MPI_FLOAT)
25    {
26      Debug("datatype is FLOAT\n");
27      return MPI_Scatter_local_float(sendbuf, count, recvbuf, comm);
28    }
29    else if(datatype == MPI_DOUBLE)
30    {
31      Debug("datatype is DOUBLE\n");
32      return MPI_Scatter_local_double(sendbuf, count, recvbuf, comm);
33    }
34    else if(datatype == MPI_LONG)
35    {
36      Debug("datatype is LONG\n");
37      return MPI_Scatter_local_long(sendbuf, count, recvbuf, comm);
38    }
39    else if(datatype == MPI_UNSIGNED_LONG)
40    {
41      Debug("datatype is uLONG\n");
42      return MPI_Scatter_local_ulong(sendbuf, count, recvbuf, comm);
43    }
44    else if(datatype == MPI_CHAR)
45    {
46      Debug("datatype is CHAR\n");
47      return MPI_Scatter_local_char(sendbuf, count, recvbuf, comm);
48    }
49    else
50    {
51      printf("MPI_Scatter Datatype not supported!\n");
52      exit(0);
53    }
54  }
55
56  int MPI_Scatter_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
57  {
58    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
59    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
60
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    for(int k=0; k<num_ep; k++)
67    {
68      for(int j=0; j<count; j+=BUFFER_SIZE)
69      {
70        if(my_rank == 0)
71        {
72          #pragma omp critical (write_to_buffer)
73          {
74            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
75            #pragma omp flush
76          }
77        }
78
79        MPI_Barrier_local(comm);
80
81        if(my_rank == k)
82        {
83          #pragma omp critical (read_from_buffer)
84          {
85            #pragma omp flush
86            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
87          }
88        }
89        MPI_Barrier_local(comm);
90      }
91    }
92  }
93
94  int MPI_Scatter_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
95  {
96    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
97    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
98
99    float *buffer = comm.my_buffer->buf_float;
100    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
101    float *recv_buf = static_cast<float*>(recvbuf);
102
103    for(int k=0; k<num_ep; k++)
104    {
105      for(int j=0; j<count; j+=BUFFER_SIZE)
106      {
107        if(my_rank == 0)
108        {
109          #pragma omp critical (write_to_buffer)
110          {
111            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
112            #pragma omp flush
113          }
114        }
115
116        MPI_Barrier_local(comm);
117
118        if(my_rank == k)
119        {
120          #pragma omp critical (read_from_buffer)
121          {
122            #pragma omp flush
123            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
124          }
125        }
126        MPI_Barrier_local(comm);
127      }
128    }
129  }
130
131  int MPI_Scatter_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
132  {
133    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
134    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
135
136    double *buffer = comm.my_buffer->buf_double;
137    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
138    double *recv_buf = static_cast<double*>(recvbuf);
139
140    for(int k=0; k<num_ep; k++)
141    {
142      for(int j=0; j<count; j+=BUFFER_SIZE)
143      {
144        if(my_rank == 0)
145        {
146          #pragma omp critical (write_to_buffer)
147          {
148            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
149            #pragma omp flush
150          }
151        }
152
153        MPI_Barrier_local(comm);
154
155        if(my_rank == k)
156        {
157          #pragma omp critical (read_from_buffer)
158          {
159            #pragma omp flush
160            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
161          }
162        }
163        MPI_Barrier_local(comm);
164      }
165    }
166  }
167
168  int MPI_Scatter_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
169  {
170    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
171    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
172
173    long *buffer = comm.my_buffer->buf_long;
174    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
175    long *recv_buf = static_cast<long*>(recvbuf);
176
177    for(int k=0; k<num_ep; k++)
178    {
179      for(int j=0; j<count; j+=BUFFER_SIZE)
180      {
181        if(my_rank == 0)
182        {
183          #pragma omp critical (write_to_buffer)
184          {
185            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
186            #pragma omp flush
187          }
188        }
189
190        MPI_Barrier_local(comm);
191
192        if(my_rank == k)
193        {
194          #pragma omp critical (read_from_buffer)
195          {
196            #pragma omp flush
197            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
198          }
199        }
200        MPI_Barrier_local(comm);
201      }
202    }
203  }
204
205
206  int MPI_Scatter_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
207  {
208    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
209    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
210
211    unsigned long *buffer = comm.my_buffer->buf_ulong;
212    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
213    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
214
215    for(int k=0; k<num_ep; k++)
216    {
217      for(int j=0; j<count; j+=BUFFER_SIZE)
218      {
219        if(my_rank == 0)
220        {
221          #pragma omp critical (write_to_buffer)
222          {
223            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
224            #pragma omp flush
225          }
226        }
227
228        MPI_Barrier_local(comm);
229
230        if(my_rank == k)
231        {
232          #pragma omp critical (read_from_buffer)
233          {
234            #pragma omp flush
235            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
236          }
237        }
238        MPI_Barrier_local(comm);
239      }
240    }
241  }
242
243
244  int MPI_Scatter_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)
245  {
246    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
247    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
248
249    char *buffer = comm.my_buffer->buf_char;
250    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
251    char *recv_buf = static_cast<char*>(recvbuf);
252
253    for(int k=0; k<num_ep; k++)
254    {
255      for(int j=0; j<count; j+=BUFFER_SIZE)
256      {
257        if(my_rank == 0)
258        {
259          #pragma omp critical (write_to_buffer)
260          {
261            copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);
262            #pragma omp flush
263          }
264        }
265
266        MPI_Barrier_local(comm);
267
268        if(my_rank == k)
269        {
270          #pragma omp critical (read_from_buffer)
271          {
272            #pragma omp flush
273            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
274          }
275        }
276        MPI_Barrier_local(comm);
277      }
278    }
279  }
280
281
282
283
284
285  int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
286  {
287    if(!comm.is_ep)
288    {
289      ::MPI_Scatter(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype),
290                    root, static_cast< ::MPI_Comm>(comm.mpi_comm));
291      return 0;
292    }
293
294    if(!comm.mpi_comm) return 0;
295
296    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount);
297
298    int root_mpi_rank = comm.rank_map->at(root).second;
299    int root_ep_loc = comm.rank_map->at(root).first;
300
301    int ep_rank, ep_rank_loc, mpi_rank;
302    int ep_size, num_ep, mpi_size;
303
304    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
305    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
306    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
307    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
308    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
309    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
310
311
312    MPI_Datatype datatype = sendtype;
313    int count = sendcount;
314
315    ::MPI_Aint datasize, lb;
316
317    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
318
319
320    void *master_sendbuf;
321    void *local_recvbuf;
322
323    if(root_ep_loc!=0 && mpi_rank == root_mpi_rank)
324    {
325      if(ep_rank_loc == 0) master_sendbuf = new void*[datasize*count*ep_size];
326
327      innode_memcpy(root_ep_loc, sendbuf, 0, master_sendbuf, count*ep_size, datatype, comm);
328    }
329
330
331
332    if(ep_rank_loc == 0)
333    {
334      int mpi_sendcnt = count*num_ep;
335      int mpi_scatterv_sendcnt[mpi_size];
336      int displs[mpi_size];
337
338      local_recvbuf = new void*[datasize*mpi_sendcnt];
339
340      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_scatterv_sendcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
341
342      displs[0] = 0;
343      for(int i=1; i<mpi_size; i++)
344        displs[i] = displs[i-1] + mpi_scatterv_sendcnt[i-1];
345
346
347      if(root_ep_loc!=0)
348      {
349        ::MPI_Scatterv(master_sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype),
350                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
351      }
352      else
353      {
354        ::MPI_Scatterv(sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype),
355                     local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
356      }
357    }
358
359    MPI_Scatter_local2(local_recvbuf, count, datatype, recvbuf, comm);
360
361    if(ep_rank_loc == 0)
362    {
363      if(datatype == MPI_INT)
364      {
365        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<int*>(master_sendbuf);
366        delete[] static_cast<int*>(local_recvbuf);
367      }
368      else if(datatype == MPI_FLOAT)
369      {
370        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<float*>(master_sendbuf);
371        delete[] static_cast<float*>(local_recvbuf);
372      }
373      else if(datatype == MPI_DOUBLE)
374      {
375        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<double*>(master_sendbuf);
376        delete[] static_cast<double*>(local_recvbuf);
377      }
378      else if(datatype == MPI_LONG)
379      {
380        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<long*>(master_sendbuf);
381        delete[] static_cast<long*>(local_recvbuf);
382      }
383      else if(datatype == MPI_UNSIGNED_LONG)
384      {
385        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<unsigned long*>(master_sendbuf);
386        delete[] static_cast<unsigned long*>(local_recvbuf);
387      }
388      else //if(datatype == MPI_DOUBLE)
389      {
390        if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<char*>(master_sendbuf);
391        delete[] static_cast<char*>(local_recvbuf);
392      }
393    }
394
395  }
396}
Note: See TracBrowser for help on using the repository browser.