source: XIOS/dev/dev_trunk_omp/extern/src_ep_dev/ep_reduce_scatter.cpp @ 1646

Last change on this file since 1646 was 1646, checked in by yushan, 5 years ago

branch merged with trunk @1645. arch file (ep&mpi) added for ADA

File size: 2.6 KB
Line 
1#ifdef _usingEP
2/*!
3   \file ep_reduce.cpp
4   \since 2 may 2016
5
6   \brief Definitions of MPI collective function: MPI_Reduce, MPI_Allreduce
7 */
8
9#include "ep_lib.hpp"
10#include <mpi.h>
11#include "ep_declaration.hpp"
12#include "ep_mpi.hpp"
13
14using namespace std;
15
16
17namespace ep_lib {
18
19
20  int MPI_Reduce_scatter(const void *sendbuf, void *recvbuf, const int recvcounts[], MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
21  {
22    if(!comm->is_ep) return ::MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
23    if(comm->is_intercomm) return MPI_Reduce_scatter_intercomm(sendbuf, recvbuf, recvcounts, datatype, op, comm);
24
25
26    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
27    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
28    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
29    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
30    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
31    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
32
33
34    ::MPI_Aint datasize, lb;
35
36    ::MPI_Type_get_extent(*(static_cast< ::MPI_Datatype*>(datatype)), &lb, &datasize);
37
38    bool is_master = ep_rank_loc==0;
39
40    void* local_recvbuf;
41
42    int count = accumulate(recvcounts, recvcounts+ep_size, 0);
43    if(is_master)
44    {
45      local_recvbuf = new void*[datasize * count]; 
46    }
47
48    MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, 0, comm);
49
50   
51    if(is_master)
52    { 
53      ::MPI_Allreduce(MPI_IN_PLACE, local_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
54    }
55
56
57    // master have reduced data
58    // local scatterv
59
60    std::vector<int> local_recvcounts(num_ep, 0);
61    std::vector<int>local_displs(num_ep, 0);
62
63    int my_recvcount = recvcounts[ep_rank];
64    MPI_Gather_local(&my_recvcount, 1, MPI_INT, local_recvcounts.data(), 0, comm);
65    MPI_Bcast_local(local_recvcounts.data(), num_ep, MPI_INT, 0, comm);
66
67    int my_displs = std::accumulate(recvcounts, recvcounts+ep_rank, 0);
68    MPI_Gather_local(&my_displs, 1, MPI_INT, local_displs.data(), 0, comm);
69    MPI_Bcast_local(local_displs.data(), num_ep, MPI_INT, 0, comm);
70
71   
72
73    MPI_Scatterv_local(local_recvbuf, local_recvcounts.data(), local_displs.data(), datatype, recvbuf, recvcounts[ep_rank], datatype, 0, comm);
74
75    if(is_master)
76    {
77      delete[] local_recvbuf;
78    }
79
80  }
81
82  int MPI_Reduce_scatter_intercomm(const void *sendbuf, void *recvbuf, const int recvcounts[], MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
83  {
84    printf("MPI_Reduce_scatter_intercomm not yet implemented\n");
85    MPI_Abort(comm, 0);
86  }
87
88 
89}
90
91#endif
Note: See TracBrowser for help on using the repository browser.