source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_reduce_scatter.cpp @ 1287

Last change on this file since 1287 was 1287, checked in by yushan, 7 years ago

EP updated

File size: 2.2 KB
Line 
1/*!
2   \file ep_reduce.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Reduce, MPI_Allreduce
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15
16namespace ep_lib {
17
18
19  int MPI_Reduce_scatter(const void *sendbuf, void *recvbuf, const int recvcounts[], MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
20  {
21    if(!comm.is_ep)
22    {
23      return ::MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
24    }
25
26
27    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
28    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
29    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
30    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
31    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
32    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
33
34
35    ::MPI_Aint datasize, lb;
36
37    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
38
39    bool is_master = ep_rank_loc==0;
40
41    void* local_recvbuf;
42
43    int count = accumulate(recvcounts, recvcounts+ep_size, 0);
44    if(is_master)
45    {
46      local_recvbuf = new void*[datasize * count]; 
47    }
48
49    MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, 0, comm);
50
51   
52    if(is_master)
53    { 
54      ::MPI_Allreduce(MPI_IN_PLACE, local_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm.mpi_comm));
55    }
56
57
58    // master have reduced data
59    // local scatterv
60
61    std::vector<int> local_recvcounts(num_ep, 0);
62    std::vector<int>local_displs(num_ep, 0);
63
64    int my_recvcount = recvcounts[ep_rank];
65    MPI_Gather_local(&my_recvcount, 1, MPI_INT, local_recvcounts.data(), 0, comm);
66    MPI_Bcast_local(local_recvcounts.data(), num_ep, MPI_INT, 0, comm);
67
68    int my_displs = std::accumulate(recvcounts, recvcounts+ep_rank, 0);
69    MPI_Gather_local(&my_displs, 1, MPI_INT, local_displs.data(), 0, comm);
70    MPI_Bcast_local(local_displs.data(), num_ep, MPI_INT, 0, comm);
71
72   
73
74    MPI_Scatterv_local(local_recvbuf, local_recvcounts.data(), local_displs.data(), datatype, recvbuf, recvcounts[ep_rank], datatype, 0, comm);
75
76    if(is_master)
77    {
78      delete[] local_recvbuf;
79    }
80
81  }
82
83 
84}
85
Note: See TracBrowser for help on using the repository browser.