source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_lib.cpp @ 1220

Last change on this file since 1220 was 1220, checked in by yushan, 4 years ago

test_remap_omp tested on ADA except two fields

File size: 10.7 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4#include <iostream>
5#include <fstream>
6
7using namespace std;
8
9std::list< ep_lib::MPI_Request* > * EP_PendingRequests = 0;
10#pragma omp threadprivate(EP_PendingRequests)
11
12namespace ep_lib
13{ 
14
15  int tag_combine(int real_tag, int src, int dest)
16  {
17    int a = real_tag << 16;
18    int b = src << 8;
19    int c = dest;
20
21    return a+b+c;
22  }
23
24  int get_ep_rank(MPI_Comm comm, int ep_rank_loc, int mpi_rank)
25  {
26    for(int i=0; i<comm.rank_map->size(); i++)
27    {
28      if(   ( comm.rank_map->at(i).first  == ep_rank_loc )
29         && ( comm.rank_map->at(i).second == mpi_rank ) )
30      {
31        return i;
32      }
33    }
34    printf("rank not find\n");
35  }
36 
37  int get_ep_rank_intercomm(MPI_Comm comm, int ep_rank_loc, int mpi_rank)
38  {
39    // intercomm
40    int inter_rank;
41    for(int i=0; i<comm.ep_comm_ptr->intercomm->intercomm_rank_map->size(); i++)
42    {
43      if(   ( comm.ep_comm_ptr->intercomm->intercomm_rank_map->at(i).first  == ep_rank_loc )
44         && ( comm.ep_comm_ptr->intercomm->intercomm_rank_map->at(i).second == mpi_rank ) )
45      {
46        inter_rank =  i;
47        break;
48      }
49    }
50
51    for(int i=0; i<comm.ep_comm_ptr->intercomm->remote_rank_map->size(); i++)
52    {
53      if(  comm.ep_comm_ptr->intercomm->remote_rank_map->at(i).first  == inter_rank  )
54      {
55        //printf("get_ep_rank for intercomm, ep_rank_loc = %d, mpi_rank = %d => ep_src = %d\n", ep_rank_loc, mpi_rank, i);
56        return i;
57      }
58    }
59
60    printf("rank not find\n");
61   
62  }
63
64
65  int innode_memcpy(int sender, const void* sendbuf, int receiver, void* recvbuf, int count, MPI_Datatype datatype, MPI_Comm comm)
66  {
67    int ep_rank, ep_rank_loc, mpi_rank;
68    int ep_size, num_ep, mpi_size;
69
70    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
71    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
72    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
73    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
74    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
75    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
76
77
78
79    if(datatype == MPI_INT)
80    {
81
82      int* send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
83      int* recv_buf = static_cast<int*>(recvbuf);
84      int* buffer = comm.my_buffer->buf_int;
85
86      for(int j=0; j<count; j+=BUFFER_SIZE)
87      {
88        if(ep_rank_loc == sender)
89        {
90          #pragma omp critical (write_to_buffer)
91          {
92            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
93          }
94          #pragma omp flush
95        }
96
97        MPI_Barrier_local(comm);
98
99
100        if(ep_rank_loc == receiver)
101        {
102          #pragma omp flush
103          #pragma omp critical (read_from_buffer)
104          {
105            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
106          }
107        }
108
109        MPI_Barrier_local(comm);
110      }
111    }
112    else if(datatype == MPI_FLOAT)
113    {
114
115      float* send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
116      float* recv_buf = static_cast<float*>(recvbuf);
117      float* buffer = comm.my_buffer->buf_float;
118
119      for(int j=0; j<count; j+=BUFFER_SIZE)
120      {
121        if(ep_rank_loc == sender)
122        {
123          #pragma omp critical (write_to_buffer)
124          {
125            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
126          }
127          #pragma omp flush
128        }
129
130        MPI_Barrier_local(comm);
131
132
133        if(ep_rank_loc == receiver)
134        {
135          #pragma omp flush
136          #pragma omp critical (read_from_buffer)
137          {
138            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
139          }
140        }
141
142        MPI_Barrier_local(comm);
143      }
144    }
145    else if(datatype == MPI_DOUBLE)
146    {
147
148
149      double* send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
150      double* recv_buf = static_cast<double*>(recvbuf);
151      double* buffer = comm.my_buffer->buf_double;
152
153      for(int j=0; j<count; j+=BUFFER_SIZE)
154      {
155        if(ep_rank_loc == sender)
156        {
157          #pragma omp critical (write_to_buffer)
158          {
159            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
160          }
161          #pragma omp flush
162        }
163
164        MPI_Barrier_local(comm);
165
166
167        if(ep_rank_loc == receiver)
168        {
169          #pragma omp flush
170          #pragma omp critical (read_from_buffer)
171          {
172            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
173          }
174        }
175
176        MPI_Barrier_local(comm);
177      }
178    }
179    else if(datatype == MPI_LONG)
180    {
181      long* send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
182      long* recv_buf = static_cast<long*>(recvbuf);
183      long* buffer = comm.my_buffer->buf_long;
184
185      for(int j=0; j<count; j+=BUFFER_SIZE)
186      {
187        if(ep_rank_loc == sender)
188        {
189          #pragma omp critical (write_to_buffer)
190          {
191            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
192          }
193          #pragma omp flush
194        }
195
196        MPI_Barrier_local(comm);
197
198
199        if(ep_rank_loc == receiver)
200        {
201          #pragma omp flush
202          #pragma omp critical (read_from_buffer)
203          {
204            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
205          }
206        }
207
208        MPI_Barrier_local(comm);
209      }
210    }
211    else if(datatype == MPI_UNSIGNED_LONG)
212    {
213      unsigned long* send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
214      unsigned long* recv_buf = static_cast<unsigned long*>(recvbuf);
215      unsigned long* buffer = comm.my_buffer->buf_ulong;
216
217      for(int j=0; j<count; j+=BUFFER_SIZE)
218      {
219        if(ep_rank_loc == sender)
220        {
221          #pragma omp critical (write_to_buffer)
222          {
223            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
224          }
225          #pragma omp flush
226        }
227
228        MPI_Barrier_local(comm);
229
230
231        if(ep_rank_loc == receiver)
232        {
233          #pragma omp flush
234          #pragma omp critical (read_from_buffer)
235          {
236            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
237          }
238        }
239
240        MPI_Barrier_local(comm);
241      }
242    }
243    else if(datatype == MPI_CHAR)
244    {
245      char* send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
246      char* recv_buf = static_cast<char*>(recvbuf);
247      char* buffer = comm.my_buffer->buf_char;
248
249      for(int j=0; j<count; j+=BUFFER_SIZE)
250      {
251        if(ep_rank_loc == sender)
252        {
253          #pragma omp critical (write_to_buffer)
254          {
255            copy(send_buf+j, send_buf+j+min(BUFFER_SIZE, count-j), buffer);
256          }
257          #pragma omp flush
258        }
259
260        MPI_Barrier_local(comm);
261
262
263        if(ep_rank_loc == receiver)
264        {
265          #pragma omp flush
266          #pragma omp critical (read_from_buffer)
267          {
268            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);
269          }
270        }
271
272        MPI_Barrier_local(comm);
273      }
274    }
275    else
276    {
277      printf("datatype not supported!!\n");
278      exit(1);
279    }
280    return 0;
281  }
282
283
284  int MPI_Get_count(const MPI_Status *status, MPI_Datatype datatype, int *count)
285  {
286/*
287    ::MPI_Aint datasize, char_size, lb;
288
289    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
290    ::MPI_Type_get_extent(MPI_CHAR, &lb, &char_size);
291
292    *count = status->char_count / ( datasize/ char_size);
293
294    //printf("MPI_Get_count, status_count  = %d\n", *count);
295    return 0;
296*/
297    ::MPI_Status *mpi_status = static_cast< ::MPI_Status* >(status->mpi_status);
298    ::MPI_Datatype mpi_datatype = static_cast< ::MPI_Datatype >(datatype);
299
300    ::MPI_Get_count(mpi_status, mpi_datatype, count);
301  }
302
303  double MPI_Wtime()
304  {
305    return ::MPI_Wtime();
306
307  }
308
309  void check_sum_send(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, int type)
310  {
311    int src_rank;
312    int int_count;
313    ::MPI_Aint datasize, intsize, charsize, lb;
314   
315    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
316    ::MPI_Type_get_extent(MPI_CHAR_STD, &lb, &intsize);
317
318    int_count = count * datasize / intsize ;
319
320    char *buffer = static_cast< char* >(const_cast< void*> (buf));
321   
322    unsigned long sum = 0;
323    for(int i = 0; i<int_count; i++)
324      sum += *(buffer+i); 
325
326
327    MPI_Comm_rank(comm, &src_rank);
328   
329    ofstream myfile;
330    myfile.open ("send_log.txt", ios::app);
331    if (myfile.is_open())
332    {
333      myfile << "type = " << type << " src = "<< src_rank<< " dest = "<< dest <<" tag = "<< tag << "  count = "<< count << " sum = "<< sum << "\n";
334      myfile.close(); 
335    }
336    else printf("Unable to open file\n");
337
338  }
339
340
341  void check_sum_recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI_Comm comm, int type)
342  {
343    int dest_rank;
344    int int_count;
345    ::MPI_Aint datasize, intsize, charsize, lb;
346   
347    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
348    ::MPI_Type_get_extent(MPI_CHAR_STD, &lb, &intsize);
349
350    int_count = count * datasize / intsize ;
351
352    char *buffer = static_cast< char* >(buf);
353   
354    unsigned long sum = 0;
355    for(int i = 0; i<int_count; i++)
356      sum += *(buffer+i); 
357
358
359    MPI_Comm_rank(comm, &dest_rank);
360   
361    ofstream myfile;
362    myfile.open ("recv_log.txt", ios::app);
363    if (myfile.is_open())
364    {
365      myfile << "type = " << type << " src = "<< src << " dest = "<< dest_rank <<" tag = "<< tag << "  count = "<< count << " sum = "<< sum << "\n";
366      myfile.close(); 
367    }
368    else printf("Unable to open file\n");
369
370  }
371
372  int test_sendrecv(MPI_Comm comm)
373  {
374    int myRank;
375    MPI_Comm_rank(comm, &myRank);
376    bool amClient = false;
377    bool amServer = false;
378    if(myRank<=3) amClient = true;
379    else amServer = true;
380
381    if(amServer)
382    {
383      int send_buf[4];
384      MPI_Request send_request[8];
385      MPI_Status send_status[8];
386
387     
388     
389      for(int j=0; j<4; j++)  // 4 buffers
390      {
391        for(int i=0; i<2; i++)
392        {
393          send_buf[j] = (myRank+1)*100 + j;
394          MPI_Isend(&send_buf[j], 1, MPI_INT, i*2, 9999, comm, &send_request[i*4+j]);
395        }
396      }
397     
398
399      MPI_Waitall(8, send_request, send_status);
400    }
401
402
403    if(amClient&&myRank%2==0) // Clients leaders
404    {
405      int recv_buf[8];
406      MPI_Request recv_request[8];
407      MPI_Status recv_status[8];
408
409      for(int i=0; i<2; i++)  // 2 servers
410      {
411        for(int j=0; j<4; j++)
412        {
413          MPI_Irecv(&recv_buf[i*4+j], 1, MPI_INT, i+4, 9999, comm, &recv_request[i*4+j]);
414        }
415      }
416
417      MPI_Waitall(8, recv_request, recv_status);
418      printf("============ client %d, recv_buf = %d, %d, %d, %d, %d, %d, %d, %d ================\n", 
419              myRank, recv_buf[0], recv_buf[1], recv_buf[2], recv_buf[3], recv_buf[4], recv_buf[5], recv_buf[6], recv_buf[7]);
420    }
421
422    MPI_Barrier(comm);
423
424  }
425
426}
427
428
429
430
431
432
433
Note: See TracBrowser for help on using the repository browser.