source: XIOS/dev/branch_yushan_merged/extern/src_ep_dev/ep_gatherv.cpp @ 1164

Last change on this file since 1164 was 1164, checked in by yushan, 5 years ago

Bug fixed in MPI_(All)Gatherv with displs

File size: 26.5 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gatherv, MPI_Allgatherv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17
18  int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
19  {
20    if(datatype == MPI_INT)
21    {
22      Debug("datatype is INT\n");
23      return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm);
24    }
25    else if(datatype == MPI_FLOAT)
26    {
27      Debug("datatype is FLOAT\n");
28      return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm);
29    }
30    else if(datatype == MPI_DOUBLE)
31    {
32      Debug("datatype is DOUBLE\n");
33      return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm);
34    }
35    else if(datatype == MPI_LONG)
36    {
37      Debug("datatype is LONG\n");
38      return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm);
39    }
40    else if(datatype == MPI_UNSIGNED_LONG)
41    {
42      Debug("datatype is uLONG\n");
43      return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm);
44    }
45    else if(datatype == MPI_CHAR)
46    {
47      Debug("datatype is CHAR\n");
48      return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm);
49    }
50    else
51    {
52      printf("MPI_Gatherv Datatype not supported!\n");
53      exit(0);
54    }
55  }
56
57  int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
58  {
59    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
60    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
61
62    int *buffer = comm.my_buffer->buf_int;
63    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
64    int *recv_buf = static_cast<int*>(recvbuf);
65
66    if(my_rank == 0)
67    {
68      assert(count == recvcounts[0]);
69      copy(send_buf, send_buf+count, recv_buf + displs[0]);
70    }
71
72    for(int j=0; j<count; j+=BUFFER_SIZE)
73    {
74      for(int k=1; k<num_ep; k++)
75      {
76        if(my_rank == k)
77        {
78          #pragma omp critical (write_to_buffer)
79          {
80            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
81            #pragma omp flush
82          }
83        }
84
85        MPI_Barrier_local(comm);
86
87        if(my_rank == 0)
88        {
89          #pragma omp flush
90          #pragma omp critical (read_from_buffer)
91          {
92            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
93          }
94        }
95
96        MPI_Barrier_local(comm);
97      }
98    }
99  }
100
101  int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
102  {
103    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
104    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
105
106    float *buffer = comm.my_buffer->buf_float;
107    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
108    float *recv_buf = static_cast<float*>(recvbuf);
109
110    if(my_rank == 0)
111    {
112      assert(count == recvcounts[0]);
113      copy(send_buf, send_buf+count, recv_buf + displs[0]);
114    }
115
116    for(int j=0; j<count; j+=BUFFER_SIZE)
117    {
118      for(int k=1; k<num_ep; k++)
119      {
120        if(my_rank == k)
121        {
122          #pragma omp critical (write_to_buffer)
123          {
124            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
125            #pragma omp flush
126          }
127        }
128
129        MPI_Barrier_local(comm);
130
131        if(my_rank == 0)
132        {
133          #pragma omp flush
134          #pragma omp critical (read_from_buffer)
135          {
136            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
137          }
138        }
139
140        MPI_Barrier_local(comm);
141      }
142    }
143  }
144
145  int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
146  {
147    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
148    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
149
150    double *buffer = comm.my_buffer->buf_double;
151    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
152    double *recv_buf = static_cast<double*>(recvbuf);
153
154    if(my_rank == 0)
155    {
156      assert(count == recvcounts[0]);
157      copy(send_buf, send_buf+count, recv_buf + displs[0]);
158    }
159
160    for(int j=0; j<count; j+=BUFFER_SIZE)
161    {
162      for(int k=1; k<num_ep; k++)
163      {
164        if(my_rank == k)
165        {
166          #pragma omp critical (write_to_buffer)
167          {
168            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
169            #pragma omp flush
170          }
171        }
172
173        MPI_Barrier_local(comm);
174
175        if(my_rank == 0)
176        {
177          #pragma omp flush
178          #pragma omp critical (read_from_buffer)
179          {
180            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
181          }
182        }
183
184        MPI_Barrier_local(comm);
185      }
186    }
187  }
188
189  int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
190  {
191    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
192    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
193
194    long *buffer = comm.my_buffer->buf_long;
195    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
196    long *recv_buf = static_cast<long*>(recvbuf);
197
198    if(my_rank == 0)
199    {
200      assert(count == recvcounts[0]);
201      copy(send_buf, send_buf+count, recv_buf + displs[0]);
202    }
203
204    for(int j=0; j<count; j+=BUFFER_SIZE)
205    {
206      for(int k=1; k<num_ep; k++)
207      {
208        if(my_rank == k)
209        {
210          #pragma omp critical (write_to_buffer)
211          {
212            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
213            #pragma omp flush
214          }
215        }
216
217        MPI_Barrier_local(comm);
218
219        if(my_rank == 0)
220        {
221          #pragma omp flush
222          #pragma omp critical (read_from_buffer)
223          {
224            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
225          }
226        }
227
228        MPI_Barrier_local(comm);
229      }
230    }
231  }
232
233  int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
234  {
235    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
236    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
237
238    unsigned long *buffer = comm.my_buffer->buf_ulong;
239    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
240    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
241
242    if(my_rank == 0)
243    {
244      assert(count == recvcounts[0]);
245      copy(send_buf, send_buf+count, recv_buf + displs[0]);
246    }
247
248    for(int j=0; j<count; j+=BUFFER_SIZE)
249    {
250      for(int k=1; k<num_ep; k++)
251      {
252        if(my_rank == k)
253        {
254          #pragma omp critical (write_to_buffer)
255          {
256            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
257            #pragma omp flush
258          }
259        }
260
261        MPI_Barrier_local(comm);
262
263        if(my_rank == 0)
264        {
265          #pragma omp flush
266          #pragma omp critical (read_from_buffer)
267          {
268            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
269          }
270        }
271
272        MPI_Barrier_local(comm);
273      }
274    }
275  }
276
277  int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
278  {
279    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
280    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
281
282    char *buffer = comm.my_buffer->buf_char;
283    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
284    char *recv_buf = static_cast<char*>(recvbuf);
285
286    if(my_rank == 0)
287    {
288      assert(count == recvcounts[0]);
289      copy(send_buf, send_buf+count, recv_buf + displs[0]);
290    }
291
292    for(int j=0; j<count; j+=BUFFER_SIZE)
293    {
294      for(int k=1; k<num_ep; k++)
295      {
296        if(my_rank == k)
297        {
298          #pragma omp critical (write_to_buffer)
299          {
300            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
301            #pragma omp flush
302          }
303        }
304
305        MPI_Barrier_local(comm);
306
307        if(my_rank == 0)
308        {
309          #pragma omp flush
310          #pragma omp critical (read_from_buffer)
311          {
312            copy(buffer, buffer+min(BUFFER_SIZE, recvcounts[k]-j), recv_buf+j+displs[k]);
313          }
314        }
315
316        MPI_Barrier_local(comm);
317      }
318    }
319  }
320
321
322  int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
323                  MPI_Datatype recvtype, int root, MPI_Comm comm)
324  {
325 
326    if(!comm.is_ep && comm.mpi_comm)
327    {
328      ::MPI_Gatherv(const_cast<void*>(sendbuf), sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, const_cast<int*>(recvcounts), const_cast<int*>(displs),
329                    static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
330      return 0;
331    }
332
333    if(!comm.mpi_comm) return 0;
334
335    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
336
337    MPI_Datatype datatype = sendtype;
338    int count = sendcount;
339
340    int ep_rank, ep_rank_loc, mpi_rank;
341    int ep_size, num_ep, mpi_size;
342
343    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
344    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
345    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
346    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
347    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
348    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
349   
350   
351   
352    if(ep_size == mpi_size) 
353      return ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
354                              static_cast< ::MPI_Datatype>(datatype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
355
356    if(ep_rank != root)
357    {
358      recvcounts = new int[ep_size];
359      displs = new int[ep_size];
360    }
361   
362    MPI_Bcast(const_cast< int* >(displs),     ep_size, MPI_INT, root, comm);
363    MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm);
364                             
365
366    int recv_plus_displs[ep_size];
367    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
368
369    for(int j=0; j<mpi_size; j++)
370    {
371      if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] ||
372         recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2]) 
373      { 
374        Debug("Call special implementation of mpi_gatherv. 1st condition not OK\n");
375        return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
376      }
377
378      for(int i=1; i<num_ep-1; i++)
379      {
380        if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] || 
381           recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1])
382        {
383          Debug("Call special implementation of mpi_gatherv. 2nd condition not OK\n");
384          return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
385        }
386      }
387    }
388
389
390    int root_mpi_rank = comm.rank_map->at(root).second;
391    int root_ep_loc = comm.rank_map->at(root).first;
392
393
394    ::MPI_Aint datasize, lb;
395
396    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
397
398    void *local_gather_recvbuf;
399    int buffer_size;
400    void *master_recvbuf;
401
402    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) 
403    {
404      master_recvbuf = new void*[sizeof(recvbuf)];
405      assert(root_ep_loc == 0);
406    }
407
408    if(ep_rank_loc==0)
409    {
410      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
411
412      local_gather_recvbuf = new void*[datasize*buffer_size];
413    }
414
415    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
416
417    //MPI_Gather
418    if(ep_rank_loc == 0)
419    {
420      int *mpi_recvcnt= new int[mpi_size];
421      int *mpi_displs= new int[mpi_size];
422
423      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
424      int buff_end = buffer_size;
425
426      int mpi_sendcnt = buff_end - buff_start;
427
428
429      ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
430      ::MPI_Gather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
431
432      if(root_ep_loc == 0)
433      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
434                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
435      }
436      else  // gatherv to master_recvbuf
437      {  ::MPI_Gatherv(local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), master_recvbuf, mpi_recvcnt,
438                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
439      }
440
441      delete[] mpi_recvcnt;
442      delete[] mpi_displs;
443    }
444
445    int global_min_displs = *std::min_element(displs, displs+ep_size);
446    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
447
448
449    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
450    {
451      innode_memcpy(0, master_recvbuf+datasize*global_min_displs, root_ep_loc, recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
452      if(ep_rank_loc == 0) delete[] master_recvbuf;
453    }
454
455
456
457    if(ep_rank_loc==0)
458    {
459      if(datatype == MPI_INT)
460      {
461        delete[] static_cast<int*>(local_gather_recvbuf);
462      }
463      else if(datatype == MPI_FLOAT)
464      {
465        delete[] static_cast<float*>(local_gather_recvbuf);
466      }
467      else if(datatype == MPI_DOUBLE)
468      {
469        delete[] static_cast<double*>(local_gather_recvbuf);
470      }
471      else if(datatype == MPI_LONG)
472      {
473        delete[] static_cast<long*>(local_gather_recvbuf);
474      }
475      else if(datatype == MPI_UNSIGNED_LONG)
476      {
477        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
478      }
479      else // if(datatype == MPI_CHAR)
480      {
481        delete[] static_cast<char*>(local_gather_recvbuf);
482      }
483    }
484    else
485    {
486      delete[] recvcounts;
487      delete[] displs;
488    }
489    return 0;
490  }
491
492
493
494  int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
495                  MPI_Datatype recvtype, MPI_Comm comm)
496  {
497
498    if(!comm.is_ep && comm.mpi_comm)
499    {
500      ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs,
501                       static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
502      return 0;
503    }
504
505    if(!comm.mpi_comm) return 0;
506
507    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
508
509
510    MPI_Datatype datatype = sendtype;
511    int count = sendcount;
512
513    int ep_rank, ep_rank_loc, mpi_rank;
514    int ep_size, num_ep, mpi_size;
515
516    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
517    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
518    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
519    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
520    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
521    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
522
523    //printf("size of recvbuf = %lu\n", sizeof(recvbuf));
524    //printf("size of (char*)recvbuf = %lu\n", sizeof((char*)recvbuf));
525   
526    if(ep_size == mpi_size) 
527      return ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(datatype), recvbuf, recvcounts, displs,
528                              static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
529   
530
531    int recv_plus_displs[ep_size];
532    for(int i=0; i<ep_size; i++) recv_plus_displs[i] = recvcounts[i] + displs[i];
533
534    for(int j=0; j<mpi_size; j++)
535    {
536      if(recv_plus_displs[j*num_ep] < displs[j*num_ep+1] ||
537         recv_plus_displs[j*num_ep + num_ep -1] < displs[j*num_ep + num_ep -2]) 
538      { 
539        Debug("Call special implementation of mpi_allgatherv.\n");
540        return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
541      }
542
543      for(int i=1; i<num_ep-1; i++)
544      {
545        if(recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i+1] || 
546           recv_plus_displs[j*num_ep+i] < displs[j*num_ep+i-1])
547        {
548          Debug("Call special implementation of mpi_allgatherv.\n");
549          return MPI_Allgatherv_special(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
550        }
551      }
552    }
553
554    ::MPI_Aint datasize, lb;
555
556    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
557
558    void *local_gather_recvbuf;
559    int buffer_size;
560
561    if(ep_rank_loc==0)
562    {
563      buffer_size = *std::max_element(recv_plus_displs+ep_rank, recv_plus_displs+ep_rank+num_ep);
564
565      local_gather_recvbuf = new void*[datasize*buffer_size];
566    }
567
568    // local gather to master
569    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, displs+ep_rank-ep_rank_loc, comm);
570
571    //MPI_Gather
572    if(ep_rank_loc == 0)
573    {
574      int *mpi_recvcnt= new int[mpi_size];
575      int *mpi_displs= new int[mpi_size];
576
577      int buff_start = *std::min_element(displs+ep_rank, displs+ep_rank+num_ep);;
578      int buff_end = buffer_size;
579
580      int mpi_sendcnt = buff_end - buff_start;
581
582
583      ::MPI_Allgather(&mpi_sendcnt, 1, MPI_INT_STD, mpi_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
584      ::MPI_Allgather(&buff_start,  1, MPI_INT_STD, mpi_displs,  1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
585
586
587      ::MPI_Allgatherv((char*)local_gather_recvbuf + datasize*buff_start, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, mpi_recvcnt,
588                       mpi_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
589
590      delete[] mpi_recvcnt;
591      delete[] mpi_displs;
592    }
593
594    int global_min_displs = *std::min_element(displs, displs+ep_size);
595    int global_recvcnt = *std::max_element(recv_plus_displs, recv_plus_displs+ep_size);
596
597    MPI_Bcast_local(recvbuf+datasize*global_min_displs, global_recvcnt, datatype, comm);
598
599    if(ep_rank_loc==0)
600    {
601      if(datatype == MPI_INT)
602      {
603        delete[] static_cast<int*>(local_gather_recvbuf);
604      }
605      else if(datatype == MPI_FLOAT)
606      {
607        delete[] static_cast<float*>(local_gather_recvbuf);
608      }
609      else if(datatype == MPI_DOUBLE)
610      {
611        delete[] static_cast<double*>(local_gather_recvbuf);
612      }
613      else if(datatype == MPI_LONG)
614      {
615        delete[] static_cast<long*>(local_gather_recvbuf);
616      }
617      else if(datatype == MPI_UNSIGNED_LONG)
618      {
619        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
620      }
621      else // if(datatype == MPI_CHAR)
622      {
623        delete[] static_cast<char*>(local_gather_recvbuf);
624      }
625    }
626  }
627
628  int MPI_Gatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
629                          MPI_Datatype recvtype, int root, MPI_Comm comm)
630  {
631    int ep_rank, ep_rank_loc, mpi_rank;
632    int ep_size, num_ep, mpi_size;
633
634    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
635    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
636    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
637    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
638    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
639    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
640
641    int root_mpi_rank = comm.rank_map->at(root).second;
642    int root_ep_loc = comm.rank_map->at(root).first;
643
644    ::MPI_Aint datasize, lb;
645    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize);
646
647    void *local_gather_recvbuf;
648    int buffer_size;
649
650    int *local_displs = new int[num_ep];
651    int *local_rvcnts = new int[num_ep];
652    for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i];
653    local_displs[0] = 0;
654    for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1];
655
656    if(ep_rank_loc==0)
657    {
658      buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1];
659      local_gather_recvbuf = new void*[datasize*buffer_size];
660    }
661
662    // local gather to master
663    MPI_Gatherv_local(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master
664
665    int **mpi_recvcnts = new int*[num_ep];
666    int **mpi_displs   = new int*[num_ep];
667    for(int i=0; i<num_ep; i++) 
668    {
669      mpi_recvcnts[i] = new int[mpi_size];
670      mpi_displs[i]   = new int[mpi_size];
671      for(int j=0; j<mpi_size; j++)
672      {
673        mpi_recvcnts[i][j] = recvcounts[j*num_ep + i];
674        mpi_displs[i][j]   = displs[j*num_ep + i];
675      }
676    } 
677
678    void *master_recvbuf;
679    if(ep_rank_loc == 0 && mpi_rank == root_mpi_rank && root_ep_loc != 0) master_recvbuf = new void*[sizeof(recvbuf)];
680
681    if(ep_rank_loc == 0 && root_ep_loc == 0) // master in MPI_Allgatherv loop
682      for(int i=0; i<num_ep; i++)
683      {
684        ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i],
685                    static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
686      }
687    if(ep_rank_loc == 0 && root_ep_loc != 0)
688      for(int i=0; i<num_ep; i++)
689      {
690        ::MPI_Gatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), master_recvbuf, mpi_recvcnts[i], mpi_displs[i],
691                    static_cast< ::MPI_Datatype>(recvtype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
692      }
693
694
695    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
696    {
697      for(int i=0; i<ep_size; i++)
698        innode_memcpy(0, master_recvbuf + datasize*displs[i], root_ep_loc, recvbuf + datasize*displs[i], recvcounts[i], sendtype, comm);
699
700      if(ep_rank_loc == 0) delete[] master_recvbuf;
701    }
702
703   
704    delete[] local_displs;
705    delete[] local_rvcnts;
706    for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i]; 
707                                  delete[] mpi_displs[i]; }
708    delete[] mpi_recvcnts;
709    delete[] mpi_displs;
710    if(ep_rank_loc==0)
711    {
712      if(sendtype == MPI_INT)
713      {
714        delete[] static_cast<int*>(local_gather_recvbuf);
715      }
716      else if(sendtype == MPI_FLOAT)
717      {
718        delete[] static_cast<float*>(local_gather_recvbuf);
719      }
720      else if(sendtype == MPI_DOUBLE)
721      {
722        delete[] static_cast<double*>(local_gather_recvbuf);
723      }
724      else if(sendtype == MPI_LONG)
725      {
726        delete[] static_cast<long*>(local_gather_recvbuf);
727      }
728      else if(sendtype == MPI_UNSIGNED_LONG)
729      {
730        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
731      }
732      else // if(sendtype == MPI_CHAR)
733      {
734        delete[] static_cast<char*>(local_gather_recvbuf);
735      }
736    }
737  }
738
739  int MPI_Allgatherv_special(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
740                             MPI_Datatype recvtype, MPI_Comm comm)
741  {
742    int ep_rank, ep_rank_loc, mpi_rank;
743    int ep_size, num_ep, mpi_size;
744
745    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
746    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
747    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
748    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
749    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
750    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
751
752
753    ::MPI_Aint datasize, lb;
754    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &lb, &datasize);
755
756    void *local_gather_recvbuf;
757    int buffer_size;
758
759    int *local_displs = new int[num_ep];
760    int *local_rvcnts = new int[num_ep];
761    for(int i=0; i<num_ep; i++) local_rvcnts[i] = recvcounts[ep_rank-ep_rank_loc + i];
762    local_displs[0] = 0;
763    for(int i=1; i<num_ep; i++) local_displs[i] = local_displs[i-1] + local_rvcnts[i-1];
764
765    if(ep_rank_loc==0)
766    {
767      buffer_size = local_displs[num_ep-1] + recvcounts[ep_rank+num_ep-1];
768      local_gather_recvbuf = new void*[datasize*buffer_size];
769    }
770
771    // local gather to master
772    MPI_Gatherv_local(sendbuf, sendcount, sendtype, local_gather_recvbuf, local_rvcnts, local_displs, comm); // all sendbuf gathered to master
773
774    int **mpi_recvcnts = new int*[num_ep];
775    int **mpi_displs   = new int*[num_ep];
776    for(int i=0; i<num_ep; i++) 
777    {
778      mpi_recvcnts[i] = new int[mpi_size];
779      mpi_displs[i]   = new int[mpi_size];
780      for(int j=0; j<mpi_size; j++)
781      {
782        mpi_recvcnts[i][j] = recvcounts[j*num_ep + i];
783        mpi_displs[i][j]   = displs[j*num_ep + i];
784      }
785    } 
786
787    if(ep_rank_loc == 0) // master in MPI_Allgatherv loop
788    for(int i=0; i<num_ep; i++)
789    {
790      ::MPI_Allgatherv(local_gather_recvbuf + datasize*local_displs[i], recvcounts[ep_rank+i], static_cast< ::MPI_Datatype>(sendtype), recvbuf, mpi_recvcnts[i], mpi_displs[i],
791                  static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
792    }
793
794    for(int i=0; i<ep_size; i++)
795      MPI_Bcast_local(recvbuf + datasize*displs[i], recvcounts[i], recvtype, comm);
796
797   
798    delete[] local_displs;
799    delete[] local_rvcnts;
800    for(int i=0; i<num_ep; i++) { delete[] mpi_recvcnts[i]; 
801                                  delete[] mpi_displs[i]; }
802    delete[] mpi_recvcnts;
803    delete[] mpi_displs;
804    if(ep_rank_loc==0)
805    {
806      if(sendtype == MPI_INT)
807      {
808        delete[] static_cast<int*>(local_gather_recvbuf);
809      }
810      else if(sendtype == MPI_FLOAT)
811      {
812        delete[] static_cast<float*>(local_gather_recvbuf);
813      }
814      else if(sendtype == MPI_DOUBLE)
815      {
816        delete[] static_cast<double*>(local_gather_recvbuf);
817      }
818      else if(sendtype == MPI_LONG)
819      {
820        delete[] static_cast<long*>(local_gather_recvbuf);
821      }
822      else if(sendtype == MPI_UNSIGNED_LONG)
823      {
824        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
825      }
826      else // if(sendtype == MPI_CHAR)
827      {
828        delete[] static_cast<char*>(local_gather_recvbuf);
829      }
830    }
831  }
832
833
834}
Note: See TracBrowser for help on using the repository browser.