source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_gatherv.cpp @ 1209

Last change on this file since 1209 was 1209, checked in by yushan, 7 years ago

bug corrected. happened when certain threads send 0 elements in the allgatherv call

File size: 26.4 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gatherv, MPI_Allgatherv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17
18  int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
19  {
20    if(datatype == MPI_INT)
21    {
22      Debug("datatype is INT\n");
23      return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm);
24    }
25    else if(datatype == MPI_FLOAT)
26    {
27      Debug("datatype is FLOAT\n");
28      return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm);
29    }
30    else if(datatype == MPI_DOUBLE)
31    {
32      Debug("datatype is DOUBLE\n");
33      return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      Debug("datatype is LONG\n");
38      return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm);
39    }
40    else if(datatype == MPI_UNSIGNED_LONG)
41    {
42      Debug("datatype is uLONG\n");
43      return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm);
44    }
45    else if(datatype == MPI_CHAR)
46    {
47      Debug("datatype is CHAR\n");
48      return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm);
49    }
50    else
51    {
52      printf("MPI_Gatherv Datatype not supported!\n");
53      exit(0);
54    }
55  }
56
57  int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
58  {
59    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
60    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    if(my_rank == 0)
67    {
68      assert(count == recvcounts[0]);
69      copy(send_buf, send_buf+count, recv_buf + displs[0]);
70    }
71
72    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
73    {
74      for(int k=1; k<num_ep; k++)
75      {
76        if(my_rank == k)
77        {
78          #pragma omp critical (write_to_buffer)
79          {
80            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
81            #pragma omp flush
82          }
83        }
84
85        MPI_Barrier_local(comm);
86
87        if(my_rank == 0)
88        {
89          #pragma omp flush
90          #pragma omp critical (read_from_buffer)
91          {
92            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
93          }
94        }
95
96        MPI_Barrier_local(comm);
97      }
98    }
99  }
100
101  int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
102  {
103    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
104    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
105
106    float *buffer = comm.my_buffer->buf_float;
107    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
108    float *recv_buf = static_cast<float*>(recvbuf);
109
110    if(my_rank == 0)
111    {
112      assert(count == recvcounts[0]);
113      copy(send_buf, send_buf+count, recv_buf + displs[0]);
114    }
115
116    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
117    {
118      for(int k=1; k<num_ep; k++)
119      {
120        if(my_rank == k)
121        {
122          #pragma omp critical (write_to_buffer)
123          {
124            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
125            #pragma omp flush
126          }
127        }
128
129        MPI_Barrier_local(comm);
130
131        if(my_rank == 0)
132        {
133          #pragma omp flush
134          #pragma omp critical (read_from_buffer)
135          {
136            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
137          }
138        }
139
140        MPI_Barrier_local(comm);
141      }
142    }
143  }
144
145  int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
146  {
147    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
148    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
149
150    double *buffer = comm.my_buffer->buf_double;
151    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
152    double *recv_buf = static_cast<double*>(recvbuf);
153
154    if(my_rank == 0)
155    {
156      assert(count == recvcounts[0]);
157      copy(send_buf, send_buf+count, recv_buf + displs[0]);
158    }
159
160    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
161    {
162      for(int k=1; k<num_ep; k++)
163      {
164        if(my_rank == k)
165        {
166          #pragma omp critical (write_to_buffer)
167          {
168            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
169            #pragma omp flush
170          }
171        }
172
173        MPI_Barrier_local(comm);
174
175        if(my_rank == 0)
176        {
177          #pragma omp flush
178          #pragma omp critical (read_from_buffer)
179          {
180            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
181          }
182        }
183
184        MPI_Barrier_local(comm);
185      }
186    }
187  }
188
189  int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
190  {
191    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
192    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
193
194    long *buffer = comm.my_buffer->buf_long;
195    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
196    long *recv_buf = static_cast<long*>(recvbuf);
197
198    if(my_rank == 0)
199    {
200      assert(count == recvcounts[0]);
201      copy(send_buf, send_buf+count, recv_buf + displs[0]);
202    }
203
204    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
205    {
206      for(int k=1; k<num_ep; k++)
207      {
208        if(my_rank == k)
209        {
210          #pragma omp critical (write_to_buffer)
211          {
212            if(count!=0)copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
213            #pragma omp flush
214          }
215        }
216
217        MPI_Barrier_local(comm);
218
219        if(my_rank == 0)
220        {
221          #pragma omp flush
222          #pragma omp critical (read_from_buffer)
223          {
224            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
225          }
226        }
227
228        MPI_Barrier_local(comm);
229      }
230    }
231  }
232
233  int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
234  {
235    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
236    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
237
238    unsigned long *buffer = comm.my_buffer->buf_ulong;
239    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
240    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
241
242    if(my_rank == 0)
243    {
244      assert(count == recvcounts[0]);
245      copy(send_buf, send_buf+count, recv_buf + displs[0]);
246    }
247
248    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
249    {
250      for(int k=1; k<num_ep; k++)
251      {
252        if(my_rank == k)
253        {
254          #pragma omp critical (write_to_buffer)
255          {
256            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
257            #pragma omp flush
258          }
259        }
260
261        MPI_Barrier_local(comm);
262
263        if(my_rank == 0)
264        {
265          #pragma omp flush
266          #pragma omp critical (read_from_buffer)
267          {
268            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
269          }
270        }
271
272        MPI_Barrier_local(comm);
273      }
274    }
275  }
276
277  int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
278  {
279    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
280    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
281
282    char *buffer = comm.my_buffer->buf_char;
283    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
284    char *recv_buf = static_cast<char*>(recvbuf);
285
286    if(my_rank == 0)
287    {
288      assert(count == recvcounts[0]);
289      copy(send_buf, send_buf+count, recv_buf + displs[0]);
290    }
291
292    for(int j=0; count!=0? j<count: j<count+1; j+=BUFFER_SIZE)
293    {
294      for(int k=1; k<num_ep; k++)
295      {
296        if(my_rank == k)
297        {
298          #pragma omp critical (write_to_buffer)
299          {
300            if(count!=0) copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
301            #pragma omp flush
302          }
303        }
304
305        MPI_Barrier_local(comm);
306
307        if(my_rank == 0)
308        {
309          #pragma omp flush
310          #pragma omp critical (read_from_buffer)
311          {
312            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
313          }
314        }
315
316        MPI_Barrier_local(comm);
317      }
318    }
319  }
320
321
322  int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
323                  MPI_Datatype recvtype, int root, MPI_Comm comm)
324  {
325 
326    if(!comm.is_ep && comm.mpi_comm)
327    {
328      ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs),
329                    static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
330      return 0;
331    }
332
333    if(!comm.mpi_comm) return 0;
334
335    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
336
337    MPI_Datatype datatype = sendtype;
338    int count = sendcount;
339
340    int ep_rank, ep_rank_loc, mpi_rank;
341    int ep_size, num_ep, mpi_size;
342
343    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
344    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
345    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
346    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
347    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
348    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
349   
350   
351   
352    if(ep_size == mpi_size) 
353      return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
354                              static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
355
356    if(ep_rank != root)
357    {
358      recvcounts = new int[ep_size];
359      displs = new int[ep_size];
360    }
361   
362    MPI_Bcast(const_cast< int* >(displs),     ep_size, MPI_INT, root, comm);
363    MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm);
364                             
365
366    int recv_plus_displs[ep_size];
367    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
368
369    for(int j=0; j<mpi_size; j++)
370    {
371      if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] ||
372         recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2]) 
373      { 
374        Debug("Call special implementation of mpi_gatherv. 1st condition not OK\n");
375        return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
376      }
377
378      for(int i=1; i<num_ep-1; i++)
379      {
380        if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] || 
381           recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1])
382        {
383          Debug("Call special implementation of mpi_gatherv. 2nd condition not OK\n");
384          return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
385        }
386      }
387    }
388
389
390    int root_mpi_rank = comm.rank_map->at(root).second;
391    int root_ep_loc = comm.rank_map->at(root).first;
392
393
394    ::MPI_Aint datasize, lb;
395
396    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
397
398    void *local_gather_recvbuf;
399    int buffer_size;
400    void *master_recvbuf;
401
402    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) 
403    {
404      master_recvbuf = new void*[sizeof(recvbuf)];
405      assert(root_ep_loc == 0);
406    }
407
408    if(ep_rank_loc==0)
409    {
410      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
411
412      local_gather_recvbuf = new void*[datasize*buffer_size];
413    }
414
415    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
416
417    //MPI_Gather
418    if(ep_rank_loc == 0)
419    {
420      int *mpi_recvcnt= new int[mpi_size];
421      int *mpi_displs= new int[mpi_size];
422
423      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
424      int buff_end = buffer_size;
425
426      int mpi_sendcnt = buff_end - buff_start;
427
428
429      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
430      ::MPI_Gather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
431
432      if(root_ep_loc == 0)
433      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
434                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
435      }
436      else  // gatherv to master_recvbuf
437      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, mpi_recvcnt,
438                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
439      }
440
441      delete[] mpi_recvcnt;
442      delete[] mpi_displs;
443    }
444
445    int global_min_displs = *std::min_element(displs, displs+ep_size);
446    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
447
448
449    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
450    {
451      innode_memcpy(0, master_recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
452      if(ep_rank_loc == 0) delete[] master_recvbuf;
453    }
454
455
456
457    if(ep_rank_loc==0)
458    {
459      if(datatype == MPI_INT)
460      {
461        delete[] static_cast<int*>(local_gather_recvbuf);
462      }
463      else if(datatype == MPI_FLOAT)
464      {
465        delete[] static_cast<float*>(local_gather_recvbuf);
466      }
467      else if(datatype == MPI_DOUBLE)
468      {
469        delete[] static_cast<double*>(local_gather_recvbuf);
470      }
471      else if(datatype == MPI_LONG)
472      {
473        delete[] static_cast<long*>(local_gather_recvbuf);
474      }
475      else if(datatype == MPI_UNSIGNED_LONG)
476      {
477        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
478      }
479      else // if(datatype == MPI_CHAR)
480      {
481        delete[] static_cast<char*>(local_gather_recvbuf);
482      }
483    }
484    else
485    {
486      delete[] recvcounts;
487      delete[] displs;
488    }
489    return 0;
490  }
491
492
493
494  int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
495                  MPI_Datatype recvtype, MPI_Comm comm)
496  {
497
498    if(!comm.is_ep && comm.mpi_comm)
499    {
500      ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs,
501                       static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
502      return 0;
503    }
504
505    if(!comm.mpi_comm) return 0;
506
507    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
508
509
510    MPI_Datatype datatype = sendtype;
511    int count = sendcount;
512
513    int ep_rank, ep_rank_loc, mpi_rank;
514    int ep_size, num_ep, mpi_size;
515
516    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
517    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
518    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
519    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
520    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
521    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
522   
523
524    int recv_plus_displs[ep_size];
525    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
526
527    for(int j=0; j<mpi_size; j++)
528    {
529      if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] ||
530         recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2]) 
531      { 
532        //printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size);
533        return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
534      }
535
536      for(int i=1; i<num_ep-1; i++)
537      {
538        if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] || 
539           recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1])
540        {
541          //printf("proc %d/%d Call special implementation of mpi_allgatherv.\n", ep_rank, ep_size);
542          return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
543        }
544      }
545    }
546
547    ::MPI_Aint datasize, lb;
548
549    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
550
551    void *local_gather_recvbuf;
552    int buffer_size;
553
554    if(ep_rank_loc==0)
555    {
556      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
557
558      local_gather_recvbuf = new void*[datasize*buffer_size];
559    }
560
561    // local gather to master
562    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
563
564    //MPI_Gather
565    if(ep_rank_loc == 0)
566    {
567      int *mpi_recvcnt= new int[mpi_size];
568      int *mpi_displs= new int[mpi_size];
569
570      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
571      int buff_end = buffer_size;
572
573      int mpi_sendcnt = buff_end - buff_start;
574
575
576      ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
577      ::MPI_Allgather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
578
579
580      ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
581                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
582
583      delete[] mpi_recvcnt;
584      delete[] mpi_displs;
585    }
586
587    int global_min_displs = *std::min_element(displs, displs+ep_size);
588    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
589
590    MPI_Bcast_local(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
591
592    if(ep_rank_loc==0)
593    {
594      if(datatype == MPI_INT)
595      {
596        delete[] static_cast<int*>(local_gather_recvbuf);
597      }
598      else if(datatype == MPI_FLOAT)
599      {
600        delete[] static_cast<float*>(local_gather_recvbuf);
601      }
602      else if(datatype == MPI_DOUBLE)
603      {
604        delete[] static_cast<double*>(local_gather_recvbuf);
605      }
606      else if(datatype == MPI_LONG)
607      {
608        delete[] static_cast<long*>(local_gather_recvbuf);
609      }
610      else if(datatype == MPI_UNSIGNED_LONG)
611      {
612        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
613      }
614      else // if(datatype == MPI_CHAR)
615      {
616        delete[] static_cast<char*>(local_gather_recvbuf);
617      }
618    }
619  }
620
621  int MPI_Gatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
622                          MPI_Datatype recvtype, int root, MPI_Comm comm)
623  {
624    int ep_rank, ep_rank_loc, mpi_rank;
625    int ep_size, num_ep, mpi_size;
626
627    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
628    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
629    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
630    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
631    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
632    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
633
634    int root_mpi_rank = comm.rank_map->at(root).second;
635    int root_ep_loc = comm.rank_map->at(root).first;
636
637    ::MPI_Aint datasize, lb;
638    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize);
639
640    void *local_gather_recvbuf;
641    int buffer_size;
642
643    int *local_displs = new int[num_ep];
644    int *local_rvcnts = new int[num_ep];
645    for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i];
646    local_displs[0] = 0;
647    for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1];
648
649    if(ep_rank_loc==0)
650    {
651      buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1];
652      local_gather_recvbuf = new void*[datasize*buffer_size];
653    }
654
655    // local gather to master
656    MPI_Gatherv_local(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master
657
658    int **mpi_recvcnts = new int*[num_ep];
659    int **mpi_displs   = new int*[num_ep];
660    for(int i=0; i<num_ep; i++) 
661    {
662      mpi_recvcnts[i] = new int[mpi_size];
663      mpi_displs[i]   = new int[mpi_size];
664      for(int j=0; j<mpi_size; j++)
665      {
666        mpi_recvcnts[i][j] = recvcounts[j*num_ep + i];
667        mpi_displs[i][j]   = displs[j*num_ep + i];
668      }
669    } 
670
671    void *master_recvbuf;
672    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)];
673
674    if(ep_rank_loc == 0 && root_ep_loc == 0) // master in MPI_Allgatherv loop
675      for(int i=0; i<num_ep; i++)
676      {
677        ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i],
678                    static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
679      }
680    if(ep_rank_loc == 0 && root_ep_loc != 0)
681      for(int i=0; i<num_ep; i++)
682      {
683        ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), master_recvbuf, mpi_recvcnts[i], mpi_displs[i],
684                    static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
685      }
686
687
688    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
689    {
690      for(int i=0; i<ep_size; i++)
691        innode_memcpy(0, master_recvbuf + datasize*displs[i], root_ep_loc, recvbuf + datasize*displs[i], recvcounts[i], sendtype, comm);
692
693      if(ep_rank_loc == 0) delete[] master_recvbuf;
694    }
695
696   
697    delete[] local_displs;
698    delete[] local_rvcnts;
699    for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i]; 
700                                  delete[] mpi_displs[i]; }
701    delete[] mpi_recvcnts;
702    delete[] mpi_displs;
703    if(ep_rank_loc==0)
704    {
705      if(sendtype == MPI_INT)
706      {
707        delete[] static_cast<int*>(local_gather_recvbuf);
708      }
709      else if(sendtype == MPI_FLOAT)
710      {
711        delete[] static_cast<float*>(local_gather_recvbuf);
712      }
713      else if(sendtype == MPI_DOUBLE)
714      {
715        delete[] static_cast<double*>(local_gather_recvbuf);
716      }
717      else if(sendtype == MPI_LONG)
718      {
719        delete[] static_cast<long*>(local_gather_recvbuf);
720      }
721      else if(sendtype == MPI_UNSIGNED_LONG)
722      {
723        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
724      }
725      else // if(sendtype == MPI_CHAR)
726      {
727        delete[] static_cast<char*>(local_gather_recvbuf);
728      }
729    }
730  }
731
732  int MPI_Allgatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
733                             MPI_Datatype recvtype, MPI_Comm comm)
734  {
735    int ep_rank, ep_rank_loc, mpi_rank;
736    int ep_size, num_ep, mpi_size;
737
738    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
739    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
740    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
741    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
742    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
743    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
744
745
746    ::MPI_Aint datasize, lb;
747    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize);
748
749    void *local_gather_recvbuf;
750    int buffer_size;
751
752    int *local_displs = new int[num_ep];
753    int *local_rvcnts = new int[num_ep];
754    for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i];
755    local_displs[0] = 0;
756    for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1];
757
758    if(ep_rank_loc==0)
759    {
760      buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1];
761      local_gather_recvbuf = new void*[datasize*buffer_size];
762    }
763
764    // local gather to master
765    MPI_Gatherv_local(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master
766
767    int **mpi_recvcnts = new int*[num_ep];
768    int **mpi_displs   = new int*[num_ep];
769    for(int i=0; i<num_ep; i++) 
770    {
771      mpi_recvcnts[i] = new int[mpi_size];
772      mpi_displs[i]   = new int[mpi_size];
773      for(int j=0; j<mpi_size; j++)
774      {
775        mpi_recvcnts[i][j] = recvcounts[j*num_ep + i];
776        mpi_displs[i][j]   = displs[j*num_ep + i];
777      }
778    } 
779
780    if(ep_rank_loc == 0) // master in MPI_Allgatherv loop
781    for(int i=0; i<num_ep; i++)
782    {
783      ::MPI_Allgatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i],
784                  static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
785    }
786
787    for(int i=0; i<ep_size; i++)
788      MPI_Bcast_local(recvbuf + datasize*displs[i], recvcounts[i], recvtype, comm);
789
790   
791    delete[] local_displs;
792    delete[] local_rvcnts;
793    for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i]; 
794                                  delete[] mpi_displs[i]; }
795    delete[] mpi_recvcnts;
796    delete[] mpi_displs;
797    if(ep_rank_loc==0)
798    {
799      if(sendtype == MPI_INT)
800      {
801        delete[] static_cast<int*>(local_gather_recvbuf);
802      }
803      else if(sendtype == MPI_FLOAT)
804      {
805        delete[] static_cast<float*>(local_gather_recvbuf);
806      }
807      else if(sendtype == MPI_DOUBLE)
808      {
809        delete[] static_cast<double*>(local_gather_recvbuf);
810      }
811      else if(sendtype == MPI_LONG)
812      {
813        delete[] static_cast<long*>(local_gather_recvbuf);
814      }
815      else if(sendtype == MPI_UNSIGNED_LONG)
816      {
817        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
818      }
819      else // if(sendtype == MPI_CHAR)
820      {
821        delete[] static_cast<char*>(local_gather_recvbuf);
822      }
823    }
824  }
825
826
827}
Note: See TracBrowser for help on using the repository browser.