source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_bcast.cpp @ 1289

Last change on this file since 1289 was 1289, checked in by yushan, 7 years ago

EP update part 2

File size: 8.9 KB
Line 
1/*!
2   \file ep_bcast.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Bcast
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13
14using namespace std;
15
16
17namespace ep_lib
18{
19
20  int MPI_Bcast_local(void *buffer, int count, MPI_Datatype datatype, int local_root, MPI_Comm comm)
21  {
22    assert(valid_type(datatype));
23
24    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
25
26    ::MPI_Aint datasize, lb;
27    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
28   
29
30    if(ep_rank_loc == local_root)
31    {
32      //comm.ep_comm_ptr->comm_list->collective_buffer[local_root] = buffer;
33      comm.my_buffer->void_buffer[local_root] = buffer;
34    }
35
36//    #pragma omp flush
37    MPI_Barrier_local(comm);
38//    #pragma omp flush
39
40    if(ep_rank_loc != local_root)
41    {
42      #pragma omp critical (_bcast)     
43      memcpy(buffer, comm.my_buffer->void_buffer[local_root], datasize * count);
44      //memcpy(buffer, comm.ep_comm_ptr->comm_list->collective_buffer[local_root], datasize * count);
45    }
46
47    MPI_Barrier_local(comm);
48  }
49
50  int MPI_Bcast(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm)
51  {
52
53    if(!comm.is_ep)
54    {
55      #pragma omp single nowait
56      ::MPI_Bcast(buffer, count, static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
57      return 0;
58    }
59
60
61    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
62    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
63    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
64
65    int root_mpi_rank = comm.rank_map->at(root).second;
66    int root_ep_rank_loc = comm.rank_map->at(root).first;
67
68    // printf("root_mpi_rank = %d\n", root_mpi_rank);   
69
70    if((ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root)
71    {
72      ::MPI_Bcast(buffer, count, to_mpi_type(datatype), root_mpi_rank, to_mpi_comm(comm.mpi_comm));
73    }
74
75    if(mpi_rank == root_mpi_rank) MPI_Bcast_local(buffer, count, datatype, root_ep_rank_loc, comm);
76    else                          MPI_Bcast_local(buffer, count, datatype, 0, comm);
77
78    return 0;
79  }
80
81
82
83
84
85
86  int MPI_Bcast_local2(void *buffer, int count, MPI_Datatype datatype, MPI_Comm comm)
87  {
88    if(datatype == MPI_INT)
89    {
90      return MPI_Bcast_local_int(buffer, count, comm);
91    }
92    else if(datatype == MPI_FLOAT)
93    {
94      return MPI_Bcast_local_float(buffer, count, comm);
95    }
96    else if(datatype == MPI_DOUBLE)
97    {
98      return MPI_Bcast_local_double(buffer, count, comm);
99    }
100    else if(datatype == MPI_CHAR)
101    {
102      return MPI_Bcast_local_char(buffer, count, comm);
103    }
104    else if(datatype == MPI_LONG)
105    {
106      return MPI_Bcast_local_long(buffer, count, comm);
107    }
108    else if(datatype == MPI_UNSIGNED_LONG)
109    {
110      return MPI_Bcast_local_char(buffer, count, comm);
111    }
112    else
113    {
114      printf("MPI_Bcast Datatype not supported!\n");
115      exit(0);
116    }
117  }
118
119  int MPI_Bcast_local_int(void *buf, int count, MPI_Comm comm)
120  {
121    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
122    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
123
124    int *buffer = comm.my_buffer->buf_int;
125    int *tmp = static_cast<int*>(buf);
126
127    for(int j=0; j<count; j+=BUFFER_SIZE)
128    {
129      if(my_rank == 0)
130      {
131        #pragma omp critical (write_to_buffer)
132        {
133          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
134        }
135        #pragma omp flush
136      }
137
138      MPI_Barrier_local(comm);
139
140
141
142      if(my_rank != 0)
143      {
144        #pragma omp flush
145        #pragma omp critical (read_from_buffer)
146        {
147          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
148        }
149      }
150
151      MPI_Barrier_local(comm);
152    }
153  }
154
155  int MPI_Bcast_local_float(void *buf, int count, MPI_Comm comm)
156  {
157    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
158    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
159
160    float *buffer = comm.my_buffer->buf_float;
161    float *tmp = static_cast<float*>(buf);
162
163    for(int j=0; j<count; j+=BUFFER_SIZE)
164    {
165      if(my_rank == 0)
166      {
167        #pragma omp critical (write_to_buffer)
168        {
169          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
170        }
171        #pragma omp flush
172      }
173
174      MPI_Barrier_local(comm);
175
176
177      if(my_rank != 0)
178      {
179        #pragma omp flush
180        #pragma omp critical (read_from_buffer)
181        {
182          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
183        }
184      }
185
186      MPI_Barrier_local(comm);
187    }
188  }
189
190  int MPI_Bcast_local_double(void *buf, int count, MPI_Comm comm)
191  {
192    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
193    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
194
195    double *buffer = comm.my_buffer->buf_double;
196    double *tmp = static_cast<double*>(buf);
197
198    for(int j=0; j<count; j+=BUFFER_SIZE)
199    {
200      if(my_rank == 0)
201      {
202        #pragma omp critical (write_to_buffer)
203        {
204          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
205        }
206        #pragma omp flush
207      }
208
209      MPI_Barrier_local(comm);
210
211
212      if(my_rank != 0)
213      {
214        #pragma omp flush
215        #pragma omp critical (read_from_buffer)
216        {
217          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
218        }
219      }
220
221      MPI_Barrier_local(comm);
222    }
223  }
224
225
226  int MPI_Bcast_local_char(void *buf, int count, MPI_Comm comm)
227  {
228    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
229    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
230
231    char *buffer = comm.my_buffer->buf_char;
232    char *tmp = static_cast<char*>(buf);
233
234    for(int j=0; j<count; j+=BUFFER_SIZE)
235    {
236      if(my_rank == 0)
237      {
238        #pragma omp critical (write_to_buffer)
239        {
240          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
241        }
242        #pragma omp flush
243      }
244
245      MPI_Barrier_local(comm);
246
247
248      if(my_rank != 0)
249      {
250        #pragma omp flush
251        #pragma omp critical (read_from_buffer)
252        {
253          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
254        }
255      }
256
257      MPI_Barrier_local(comm);
258    }
259  }
260
261  int MPI_Bcast_local_long(void *buf, int count, MPI_Comm comm)
262  {
263    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
264    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
265
266    long *buffer = comm.my_buffer->buf_long;
267    long *tmp = static_cast<long*>(buf);
268
269    for(int j=0; j<count; j+=BUFFER_SIZE)
270    {
271      if(my_rank == 0)
272      {
273        #pragma omp critical (write_to_buffer)
274        {
275          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
276        }
277        #pragma omp flush
278      }
279
280      MPI_Barrier_local(comm);
281
282
283      if(my_rank != 0)
284      {
285        #pragma omp flush
286        #pragma omp critical (read_from_buffer)
287        {
288          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
289        }
290      }
291
292      MPI_Barrier_local(comm);
293    }
294  }
295
296  int MPI_Bcast_local_ulong(void *buf, int count, MPI_Comm comm)
297  {
298    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
299    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
300
301    unsigned long *buffer = comm.my_buffer->buf_ulong;
302    unsigned long *tmp = static_cast<unsigned long*>(buf);
303
304    for(int j=0; j<count; j+=BUFFER_SIZE)
305    {
306      if(my_rank == 0)
307      {
308        #pragma omp critical (write_to_buffer)
309        {
310          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
311        }
312        #pragma omp flush
313      }
314
315      MPI_Barrier_local(comm);
316
317
318      if(my_rank != 0)
319      {
320        #pragma omp flush
321        #pragma omp critical (read_from_buffer)
322        {
323          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
324        }
325      }
326
327      MPI_Barrier_local(comm);
328    }
329  }
330
331
332  int MPI_Bcast2(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm)
333  {
334
335    if(!comm.is_ep)
336    {
337      ::MPI_Bcast(buffer, count, static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
338      return 0;
339    }
340
341
342    int ep_rank, ep_rank_loc, mpi_rank;
343    int ep_size, num_ep, mpi_size;
344
345    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
346    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
347    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
348    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
349    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
350    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
351
352
353
354    int root_mpi_rank = comm.rank_map->at(root).second;
355    int root_ep_rank_loc = comm.rank_map->at(root).first;
356
357
358    // if root is not master thread, send first to master
359    if(root_ep_rank_loc != 0 && mpi_rank == root_mpi_rank)
360    {
361      innode_memcpy(root_ep_rank_loc, buffer, 0, buffer, count, datatype, comm);
362    }
363
364
365    if(ep_rank_loc==0)
366    {
367      ::MPI_Bcast(buffer, count, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
368    }
369
370    MPI_Bcast_local2(buffer, count, datatype, comm);
371
372    return 0;
373  }
374
375
376}
Note: See TracBrowser for help on using the repository browser.