source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_bcast.cpp @ 1134

Last change on this file since 1134 was 1134, checked in by yushan, 7 years ago

branch merged with trunk r1130

File size: 7.0 KB
Line 
1/*!
2   \file ep_bcast.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Bcast
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12using namespace std;
13
14
15namespace ep_lib
16{
17  int MPI_Bcast_local(void *buffer, int count, MPI_Datatype datatype, MPI_Comm comm)
18  {
19    if(datatype == MPI_INT)
20    {
21      return MPI_Bcast_local_int(buffer, count, comm);
22    }
23    else if(datatype == MPI_FLOAT)
24    {
25      return MPI_Bcast_local_float(buffer, count, comm);
26    }
27    else if(datatype == MPI_DOUBLE)
28    {
29      return MPI_Bcast_local_double(buffer, count, comm);
30    }
31    else if(datatype == MPI_CHAR)
32    {
33      return MPI_Bcast_local_char(buffer, count, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      return MPI_Bcast_local_long(buffer, count, comm);
38    }
39    else if(datatype == MPI_UNSIGNED_LONG)
40    {
41      return MPI_Bcast_local_char(buffer, count, comm);
42    }
43    else
44    {
45      printf("MPI_Bcast Datatype not supported!\n");
46      exit(0);
47    }
48  }
49
50  int MPI_Bcast_local_int(void *buf, int count, MPI_Comm comm)
51  {
52    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
53    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
54
55    int *buffer = comm.my_buffer->buf_int;
56    int *tmp = static_cast<int*>(buf);
57
58    for(int j=0; j<count; j+=BUFFER_SIZE)
59    {
60      if(my_rank == 0)
61      {
62        #pragma omp critical (write_to_buffer)
63        {
64          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
65        }
66        #pragma omp flush
67      }
68
69      MPI_Barrier_local(comm);
70
71
72
73      if(my_rank != 0)
74      {
75        #pragma omp flush
76        #pragma omp critical (read_from_buffer)
77        {
78          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
79        }
80      }
81
82      MPI_Barrier_local(comm);
83    }
84  }
85
86  int MPI_Bcast_local_float(void *buf, int count, MPI_Comm comm)
87  {
88    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
89    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
90
91    float *buffer = comm.my_buffer->buf_float;
92    float *tmp = static_cast<float*>(buf);
93
94    for(int j=0; j<count; j+=BUFFER_SIZE)
95    {
96      if(my_rank == 0)
97      {
98        #pragma omp critical (write_to_buffer)
99        {
100          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
101        }
102        #pragma omp flush
103      }
104
105      MPI_Barrier_local(comm);
106
107
108      if(my_rank != 0)
109      {
110        #pragma omp flush
111        #pragma omp critical (read_from_buffer)
112        {
113          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
114        }
115      }
116
117      MPI_Barrier_local(comm);
118    }
119  }
120
121  int MPI_Bcast_local_double(void *buf, int count, MPI_Comm comm)
122  {
123    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
124    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
125
126    double *buffer = comm.my_buffer->buf_double;
127    double *tmp = static_cast<double*>(buf);
128
129    for(int j=0; j<count; j+=BUFFER_SIZE)
130    {
131      if(my_rank == 0)
132      {
133        #pragma omp critical (write_to_buffer)
134        {
135          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
136        }
137        #pragma omp flush
138      }
139
140      MPI_Barrier_local(comm);
141
142
143      if(my_rank != 0)
144      {
145        #pragma omp flush
146        #pragma omp critical (read_from_buffer)
147        {
148          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
149        }
150      }
151
152      MPI_Barrier_local(comm);
153    }
154  }
155
156
157  int MPI_Bcast_local_char(void *buf, int count, MPI_Comm comm)
158  {
159    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
160    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
161
162    char *buffer = comm.my_buffer->buf_char;
163    char *tmp = static_cast<char*>(buf);
164
165    for(int j=0; j<count; j+=BUFFER_SIZE)
166    {
167      if(my_rank == 0)
168      {
169        #pragma omp critical (write_to_buffer)
170        {
171          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
172        }
173        #pragma omp flush
174      }
175
176      MPI_Barrier_local(comm);
177
178
179      if(my_rank != 0)
180      {
181        #pragma omp flush
182        #pragma omp critical (read_from_buffer)
183        {
184          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
185        }
186      }
187
188      MPI_Barrier_local(comm);
189    }
190  }
191
192  int MPI_Bcast_local_long(void *buf, int count, MPI_Comm comm)
193  {
194    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
195    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
196
197    long *buffer = comm.my_buffer->buf_long;
198    long *tmp = static_cast<long*>(buf);
199
200    for(int j=0; j<count; j+=BUFFER_SIZE)
201    {
202      if(my_rank == 0)
203      {
204        #pragma omp critical (write_to_buffer)
205        {
206          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
207        }
208        #pragma omp flush
209      }
210
211      MPI_Barrier_local(comm);
212
213
214      if(my_rank != 0)
215      {
216        #pragma omp flush
217        #pragma omp critical (read_from_buffer)
218        {
219          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
220        }
221      }
222
223      MPI_Barrier_local(comm);
224    }
225  }
226
227  int MPI_Bcast_local_ulong(void *buf, int count, MPI_Comm comm)
228  {
229    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
230    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
231
232    unsigned long *buffer = comm.my_buffer->buf_ulong;
233    unsigned long *tmp = static_cast<unsigned long*>(buf);
234
235    for(int j=0; j<count; j+=BUFFER_SIZE)
236    {
237      if(my_rank == 0)
238      {
239        #pragma omp critical (write_to_buffer)
240        {
241          copy(tmp+j, tmp+j+min(BUFFER_SIZE, count-j), buffer);
242        }
243        #pragma omp flush
244      }
245
246      MPI_Barrier_local(comm);
247
248
249      if(my_rank != 0)
250      {
251        #pragma omp flush
252        #pragma omp critical (read_from_buffer)
253        {
254          copy(buffer, buffer+min(BUFFER_SIZE, count-j), tmp+j);
255        }
256      }
257
258      MPI_Barrier_local(comm);
259    }
260  }
261
262
263  int MPI_Bcast(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm)
264  {
265
266    if(!comm.is_ep)
267    {
268      ::MPI_Bcast(buffer, count, static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
269      return 0;
270    }
271
272
273    int ep_rank, ep_rank_loc, mpi_rank;
274    int ep_size, num_ep, mpi_size;
275
276    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
277    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
278    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
279    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
280    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
281    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
282
283
284
285    int root_mpi_rank = comm.rank_map->at(root).second;
286    int root_ep_rank_loc = comm.rank_map->at(root).first;
287
288
289    // if root is not master thread, send first to master
290    if(root_ep_rank_loc != 0 && mpi_rank == root_mpi_rank)
291    {
292      innode_memcpy(root_ep_rank_loc, buffer, 0, buffer, count, datatype, comm);
293    }
294
295
296    if(ep_rank_loc==0)
297    {
298      ::MPI_Bcast(buffer, count, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
299    }
300
301    MPI_Bcast_local(buffer, count, datatype, comm);
302
303    return 0;
304  }
305
306
307}
Note: See TracBrowser for help on using the repository browser.