Ignore:
Timestamp:
10/04/17 17:02:13 (7 years ago)
Author:
yushan
Message:

EP update part 2

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_alltoall.cpp

    r1287 r1289  
    11#include "ep_lib.hpp" 
    22#include <mpi.h> 
    3 #include "ep_mpi.hpp" 
    4  
     3//#include "ep_declaration.hpp" 
    54 
    65namespace ep_lib 
    76{ 
    87 
     8 
     9 
    910  int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm) 
    1011  { 
    11     if(!comm.is_ep) 
    12     { 
    13       return ::MPI_Alltoall(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype), to_mpi_comm(comm.mpi_comm)); 
    14     } 
     12    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype)); 
     13    ::MPI_Aint typesize, llb; 
     14    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &llb, &typesize); 
     15     
     16    int ep_size; 
     17    MPI_Comm_size(comm, &ep_size); 
     18     
    1519 
    16  
    17     assert(valid_type(sendtype) && valid_type(recvtype)); 
    18     assert(sendcount == recvcount); 
    19  
    20     ::MPI_Aint datasize, llb; 
    21     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &llb, &datasize); 
    22  
    23     int count = sendcount; 
    24      
    25     int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    26     int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    27     int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    28     int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    29     int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    30     int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    31  
    32     void* tmp_recvbuf; 
    33     if(ep_rank == 0) tmp_recvbuf = new void*[count * ep_size * ep_size * datasize]; 
    34  
    35     MPI_Gather(sendbuf, count*ep_size, sendtype, tmp_recvbuf, count*ep_size, recvtype, 0, comm); 
    36  
    37      
    38      
    39     // reorder tmp_buf 
    40     void* tmp_sendbuf; 
    41     if(ep_rank == 0) tmp_sendbuf = new void*[count * ep_size * ep_size * datasize]; 
    42  
    43     if(ep_rank == 0) 
    4420    for(int i=0; i<ep_size; i++) 
    4521    { 
    46       for(int j=0; j<ep_size; j++) 
    47       { 
    48         //printf("tmp_recv[%d] = tmp_send[%d]\n", i*ep_size*count + j*count, j*ep_size*count + i*count); 
    49  
    50         memcpy(tmp_sendbuf + j*ep_size*count*datasize + i*count*datasize, tmp_recvbuf + i*ep_size*count*datasize + j*count*datasize, count*datasize); 
    51       } 
     22      ep_lib::MPI_Gather(sendbuf+i*sendcount*typesize, sendcount, sendtype, recvbuf, recvcount, recvtype, i, comm); 
    5223    } 
    53  
    54     MPI_Scatter(tmp_sendbuf, ep_size*count, sendtype, recvbuf, ep_size*recvcount, recvtype, 0, comm); 
    55  
    56     if(ep_rank == 0) 
    57     { 
    58       delete[] tmp_recvbuf; 
    59       delete[] tmp_sendbuf; 
    60     } 
     24     
    6125 
    6226    return 0; 
Note: See TracChangeset for help on using the changeset viewer.