source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_alltoall.cpp @ 1339

Last change on this file since 1339 was 1339, checked in by yushan, 6 years ago

dev_omp OK

File size: 1.9 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_mpi.hpp"
4
5
6namespace ep_lib
7{
8
9  int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
10  {
11    if(!comm.is_ep)
12    {
13      return ::MPI_Alltoall(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype), to_mpi_comm(comm.mpi_comm));
14    }
15
16
17    assert(valid_type(sendtype) && valid_type(recvtype));
18    assert(sendcount == recvcount);
19
20    ::MPI_Aint datasize, llb;
21    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &llb, &datasize);
22
23    int count = sendcount;
24   
25    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
26    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
27    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
28    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
29    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
30    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
31
32    void* tmp_recvbuf;
33    if(ep_rank == 0) tmp_recvbuf = new void*[count * ep_size * ep_size * datasize];
34
35    MPI_Gather(sendbuf, count*ep_size, sendtype, tmp_recvbuf, count*ep_size, recvtype, 0, comm);
36
37   
38   
39    // reorder tmp_buf
40    void* tmp_sendbuf;
41    if(ep_rank == 0) tmp_sendbuf = new void*[count * ep_size * ep_size * datasize];
42
43    if(ep_rank == 0)
44    for(int i=0; i<ep_size; i++)
45    {
46      for(int j=0; j<ep_size; j++)
47      {
48        //printf("tmp_recv[%d] = tmp_send[%d]\n", i*ep_size*count + j*count, j*ep_size*count + i*count);
49
50        memcpy(tmp_sendbuf + j*ep_size*count*datasize + i*count*datasize, tmp_recvbuf + i*ep_size*count*datasize + j*count*datasize, count*datasize);
51      }
52    }
53
54    MPI_Scatter(tmp_sendbuf, ep_size*count, sendtype, recvbuf, ep_size*recvcount, recvtype, 0, comm);
55
56    if(ep_rank == 0)
57    {
58      delete[] tmp_recvbuf;
59      delete[] tmp_sendbuf;
60    }
61   
62    MPI_Barrier(comm);
63
64    return 0;
65  }
66
67}
68
69
Note: See TracBrowser for help on using the repository browser.