source: XIOS/dev/branch_openmp/extern/ep_dev/ep_reduce_scatter.cpp @ 1527

Last change on this file since 1527 was 1527, checked in by yushan, 6 years ago

save dev

File size: 2.6 KB
Line 
1/*!
2   \file ep_reduce.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Reduce, MPI_Allreduce
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11#include "ep_mpi.hpp"
12
13using namespace std;
14
15
16namespace ep_lib {
17
18
19  int MPI_Reduce_scatter(const void *sendbuf, void *recvbuf, const int recvcounts[], MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
20  {
21    if(!comm->is_ep) return ::MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
22    if(comm->is_intercomm) return MPI_Reduce_scatter_intercomm(sendbuf, recvbuf, recvcounts, datatype, op, comm);
23
24
25    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
26    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
27    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
28    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
29    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
30    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
31
32
33    ::MPI_Aint datasize, lb;
34
35    ::MPI_Type_get_extent(*(static_cast< ::MPI_Datatype*>(datatype)), &lb, &datasize);
36
37    bool is_master = ep_rank_loc==0;
38
39    void* local_recvbuf;
40
41    int count = accumulate(recvcounts, recvcounts+ep_size, 0);
42    if(is_master)
43    {
44      local_recvbuf = new void*[datasize * count]; 
45    }
46
47    MPI_Reduce_local(sendbuf, local_recvbuf, count, datatype, op, 0, comm);
48
49   
50    if(is_master)
51    { 
52      ::MPI_Allreduce(MPI_IN_PLACE, local_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
53    }
54
55
56    // master have reduced data
57    // local scatterv
58
59    std::vector<int> local_recvcounts(num_ep, 0);
60    std::vector<int>local_displs(num_ep, 0);
61
62    int my_recvcount = recvcounts[ep_rank];
63    MPI_Gather_local(&my_recvcount, 1, MPI_INT, local_recvcounts.data(), 0, comm);
64    MPI_Bcast_local(local_recvcounts.data(), num_ep, MPI_INT, 0, comm);
65
66    int my_displs = std::accumulate(recvcounts, recvcounts+ep_rank, 0);
67    MPI_Gather_local(&my_displs, 1, MPI_INT, local_displs.data(), 0, comm);
68    MPI_Bcast_local(local_displs.data(), num_ep, MPI_INT, 0, comm);
69
70   
71
72    MPI_Scatterv_local(local_recvbuf, local_recvcounts.data(), local_displs.data(), datatype, recvbuf, recvcounts[ep_rank], datatype, 0, comm);
73
74    if(is_master)
75    {
76      delete[] local_recvbuf;
77    }
78
79  }
80
81  int MPI_Reduce_scatter_intercomm(const void *sendbuf, void *recvbuf, const int recvcounts[], MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
82  {
83    printf("MPI_Reduce_scatter_intercomm not yet implemented\n");
84    MPI_Abort(comm, 0);
85  }
86
87 
88}
89
Note: See TracBrowser for help on using the repository browser.