/*!
   \file ep_scan.cpp
   \since 2 may 2016

   \brief Definitions of MPI collective function: MPI_Scan

   MPI_Scan entry point for ep_lib communicators:
   - Communicators without `is_ep` are forwarded to the underlying MPI library.
   - EP communicators ("EP" presumably = endpoint; ranks described by `size_rank_info`)
     are handled by a point-to-point exchange, a local fold through the shared table
     `comm->my_buffer->void_buffer`, and a call on `comm->mpi_comm`.

   NOTE(review): this copy of the file appears to be damaged. Every span of text
   between '<' and '>' seems to have been stripped. This removed:
     - the template parameter lists (`template T max_op` ...);
     - the template arguments of static_cast / const_cast / std::vector / std::plus
       and of the max_op / min_op functors;
     - one #include;
     - the comparison part of several `for` conditions;
     - a whole section of MPI_Scan.
   The code below does not compile as shown. Restore those spans from version
   control rather than reconstructing them by hand.
 */
#include "ep_lib.hpp"
// NOTE(review): the header name of the next #include was lost (presumably an angle-bracket header).
#include
#include "ep_declaration.hpp"
#include "ep_mpi.hpp"

using namespace std;

namespace ep_lib
{

  //! Returns the larger of a and b (std::max); used as the binary functor for MPI_MAX.
  template T max_op(T a, T b)
  {
    return max(a,b);
  }

  //! Returns the smaller of a and b (std::min); used as the binary functor for MPI_MIN.
  template T min_op(T a, T b)
  {
    return min(a,b);
  }

  /*!
    \brief Element-wise, in-place fold with max.
    For every i in [0, count): recvbuf[i] = max_op(buffer[i], recvbuf[i]).
    \param buffer  input values (read only), at least `count` elements
    \param recvbuf accumulator, read and overwritten, at least `count` elements
    \param count   number of elements
  */
  template void reduce_max(const T * buffer, T* recvbuf, int count)
  {
    transform(buffer, buffer+count, recvbuf, recvbuf, max_op);
  }

  /*!
    \brief Element-wise, in-place fold with min.
    For every i in [0, count): recvbuf[i] = min_op(buffer[i], recvbuf[i]).
  */
  template void reduce_min(const T * buffer, T* recvbuf, int count)
  {
    transform(buffer, buffer+count, recvbuf, recvbuf, min_op);
  }

  /*!
    \brief Element-wise, in-place fold with addition.
    For every i in [0, count): recvbuf[i] = buffer[i] + recvbuf[i].
  */
  template void reduce_sum(const T * buffer, T* recvbuf, int count)
  {
    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus());
  }


  /*!
    \brief Local scan step among the endpoints that share `comm->my_buffer`.

    Phase 1 (publish):
      - Local rank 0 on an MPI rank other than 0 folds `sendbuf` into `recvbuf` with `op`.
        This implies `recvbuf` arrives already holding a contribution, presumably the
        prefix from lower MPI ranks computed by the caller (TODO confirm).
        It then publishes `recvbuf` in void_buffer[0].
      - Every other endpoint publishes `sendbuf` in void_buffer[ep_rank_loc] and copies
        it into `recvbuf`.
    Phase 2 (fold), after MPI_Barrier_local:
      - Every endpoint copies void_buffer[0] into `recvbuf`.
      - It then folds void_buffer[i] into `recvbuf`, starting from i = 1.
    A final MPI_Barrier_local presumably keeps the published buffers alive until every
    endpoint has read them.

    Supported ops:
      - MPI_SUM and MPI_MAX;
      - any other op falls into the MIN branch (the final `else` has no test).
    Supported datatypes:
      - MPI_INT, MPI_FLOAT, MPI_DOUBLE, MPI_CHAR, MPI_LONG, MPI_UNSIGNED_LONG;
      - any other type only prints "datatype Error" and is otherwise skipped.

    NOTE(review): declared to return int but has no return statement; flowing off the
    end of a non-void function is undefined behaviour in C++.
  */
  int MPI_Scan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
  {
    valid_op(op);

    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
    // NOTE(review): num_ep is not referenced in the visible code (it may have been used in a stripped loop bound).
    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;

    // Extent (in bytes) of one element of the underlying MPI type; the asserts below check it against the C++ type.
    ::MPI_Aint datasize, lb;
    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);

    // ---- Phase 1: publish one buffer per endpoint in comm->my_buffer->void_buffer ----
    if(ep_rank_loc == 0 && mpi_rank != 0)
    {
      // recvbuf = op(sendbuf, recvbuf), element-wise, dispatched on op and datatype.
      if(op == MPI_SUM)
      {
        if(datatype == MPI_INT)
        {
          assert(datasize == sizeof(int));
          reduce_sum(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_FLOAT)
        {
          assert( datasize == sizeof(float));
          reduce_sum(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_DOUBLE )
        {
          assert( datasize == sizeof(double));
          reduce_sum(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_CHAR)
        {
          assert( datasize == sizeof(char));
          reduce_sum(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_LONG)
        {
          assert( datasize == sizeof(long));
          reduce_sum(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_UNSIGNED_LONG)
        {
          assert(datasize == sizeof(unsigned long));
          reduce_sum(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else printf("datatype Error\n");
      }
      else if(op == MPI_MAX)
      {
        if(datatype == MPI_INT)
        {
          assert( datasize == sizeof(int));
          reduce_max(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_FLOAT )
        {
          assert( datasize == sizeof(float));
          reduce_max(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_DOUBLE )
        {
          assert( datasize == sizeof(double));
          reduce_max(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_CHAR )
        {
          assert(datasize == sizeof(char));
          reduce_max(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_LONG)
        {
          assert( datasize == sizeof(long));
          reduce_max(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_UNSIGNED_LONG)
        {
          assert( datasize == sizeof(unsigned long));
          reduce_max(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else printf("datatype Error\n");
      }
      // NOTE(review): this branch is taken for every op other than MPI_SUM / MPI_MAX, not only MPI_MIN.
      else //(op == MPI_MIN)
      {
        if(datatype == MPI_INT )
        {
          assert (datasize == sizeof(int));
          reduce_min(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_FLOAT )
        {
          assert( datasize == sizeof(float));
          reduce_min(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_DOUBLE )
        {
          assert( datasize == sizeof(double));
          reduce_min(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_CHAR )
        {
          assert( datasize == sizeof(char));
          reduce_min(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_LONG )
        {
          assert( datasize == sizeof(long));
          reduce_min(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else if(datatype == MPI_UNSIGNED_LONG )
        {
          assert( datasize == sizeof(unsigned long));
          reduce_min(static_cast(const_cast(sendbuf)), static_cast(recvbuf), count);
        }
        else printf("datatype Error\n");
      }

      // Slot 0 holds the already-combined buffer of local rank 0.
      comm->my_buffer->void_buffer[0] = recvbuf;
    }
    else
    {
      // Publish this endpoint's own input and seed recvbuf with it.
      comm->my_buffer->void_buffer[ep_rank_loc] = const_cast(sendbuf);
      memcpy(recvbuf, sendbuf, datasize*count);
    }

    // Wait until every endpoint has published its slot (presumably a barrier over the endpoints sharing my_buffer).
    MPI_Barrier_local(comm);

    // ---- Phase 2: recvbuf = void_buffer[0] (op) void_buffer[1] (op) ... ----
    // NOTE(review): on local rank 0 of an MPI rank other than 0, void_buffer[0] == recvbuf (set above),
    // so this is a self-copy; memcpy with overlapping source/destination is undefined behaviour.
    memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize*count);

    // NOTE(review): in every loop below, the loop bound, the increment and the reduce_* callee were stripped
    // from this copy. For an inclusive scan the bound is presumably i <= ep_rank_loc, and the callee is
    // presumably reduce_sum / reduce_max / reduce_min per branch — confirm against version control.
    if(op == MPI_SUM)
    {
      if(datatype == MPI_INT )
      {
        assert (datasize == sizeof(int));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_FLOAT )
      {
        assert(datasize == sizeof(float));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_DOUBLE )
      {
        assert(datasize == sizeof(double));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_CHAR )
      {
        assert(datasize == sizeof(char));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_LONG )
      {
        assert(datasize == sizeof(long));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_UNSIGNED_LONG )
      {
        assert(datasize == sizeof(unsigned long));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else printf("datatype Error\n");
    }
    else if(op == MPI_MAX)
    {
      if(datatype == MPI_INT)
      {
        assert(datasize == sizeof(int));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_FLOAT )
      {
        assert(datasize == sizeof(float));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_DOUBLE )
      {
        assert(datasize == sizeof(double));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_CHAR )
      {
        assert(datasize == sizeof(char));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_LONG )
      {
        assert(datasize == sizeof(long));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_UNSIGNED_LONG )
      {
        assert(datasize == sizeof(unsigned long));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else printf("datatype Error\n");
    }
    // NOTE(review): as in phase 1, every op other than MPI_SUM / MPI_MAX lands here.
    else //if(op == MPI_MIN)
    {
      if(datatype == MPI_INT )
      {
        assert(datasize == sizeof(int));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_FLOAT )
      {
        assert(datasize == sizeof(float));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_DOUBLE )
      {
        assert(datasize == sizeof(double));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_CHAR )
      {
        assert(datasize == sizeof(char));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_LONG )
      {
        assert(datasize == sizeof(long));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else if(datatype == MPI_UNSIGNED_LONG )
      {
        assert(datasize == sizeof(unsigned long));
        for(int i=1; i(static_cast(comm->my_buffer->void_buffer[i]), static_cast(recvbuf), count);
      }
      else printf("datatype Error\n");
    }

    // Presumably prevents any endpoint from returning (and freeing its published buffer) while others still read it.
    MPI_Barrier_local(comm);
  }


  /*!
    \brief MPI_Scan for ep_lib communicators.

    - If `comm->is_ep` is false, forwards to the underlying library's ::MPI_Scan with
      the converted datatype, op and communicator.
    - If `comm->is_intercomm` is set, delegates to MPI_Scan_intercomm (which aborts).
    - Otherwise, for an EP communicator:
      1. Validates `datatype`.
      2. Counts in my_map[r] the ep_rank_map entries whose `.second` equals r
         (presumably the number of endpoints hosted by MPI rank r).
      3. Runs a section that is missing from this copy: it computes my_src / my_dst,
         allocates tmp_recvbuf, and ends with a call taking `comm->mpi_comm`.
      4. Calls MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, ...).
      5. Sends tmp_recvbuf to endpoint my_src (tag my_src) and receives this rank's
         result into recvbuf from endpoint my_dst (tag ep_rank). If ep_rank == my_src,
         the result is copied locally instead.

    NOTE(review): no return statement on the EP path (undefined behaviour in C++);
    presumably should return MPI_SUCCESS / 0 — confirm the ep_lib convention.
  */
  int MPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
  {
    if(!comm->is_ep) return ::MPI_Scan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
    if(comm->is_intercomm) return MPI_Scan_intercomm(sendbuf, recvbuf, count, datatype, op, comm);

    valid_type(datatype);

    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
    // NOTE(review): ep_size and num_ep are not referenced in the visible code.
    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;

    ::MPI_Aint datasize, lb;
    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);

    // NOTE(review): this allocates datasize*count *pointers*, i.e. sizeof(void*) times the bytes needed.
    // It is later released with `delete[]` through a void*, which is undefined behaviour
    // (compilers warn "deleting 'void*' is undefined"). A std::vector<char>(datasize*count) buffer
    // (or new char[] / delete[] on a char*) would fix both problems.
    void* tmp_sendbuf;
    tmp_sendbuf = new void*[datasize * count];

    int my_src = 0;
    int my_dst = ep_rank;

    std::vector my_map(mpi_size, 0);

    // Loop condition stripped in this copy; presumably iterates over every entry of comm->ep_rank_map.
    for(int i=0; iep_rank_map->size(); i++) my_map[comm->ep_rank_map->at(i).second]++;

    // NOTE(review): a whole section of this function is missing from this copy, between the `i` below
    // and `mpi_comm));`. It includes at least the declaration/allocation of tmp_recvbuf and the final
    // values of my_src / my_dst. Restore it from version control.
    for(int i=0; impi_comm));

    //printf(" ID=%d : %d %d \n", ep_rank, static_cast(tmp_recvbuf)[0], static_cast(tmp_recvbuf)[1]);

    MPI_Scan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);

    // printf(" ID=%d : after local tmp_sendbuf = %d %d ; tmp_recvbuf = %d %d \n", ep_rank, static_cast(tmp_sendbuf)[0], static_cast(tmp_sendbuf)[1], static_cast(tmp_recvbuf)[0], static_cast(tmp_recvbuf)[1]);

    // Route each scan result back: presumably the inverse of a permutation set up in the missing section.
    if(ep_rank != my_src)
    {
      MPI_Request request[2];
      MPI_Status status[2];

      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src, comm, &request[0]);
      MPI_Irecv(recvbuf, count, datatype, my_dst, ep_rank, comm, &request[1]);
      MPI_Waitall(2, request, status);
    }
    else memcpy(recvbuf, tmp_recvbuf, datasize*count);

    delete[] tmp_sendbuf;
    delete[] tmp_recvbuf;
  }


  /*!
    \brief Placeholder for MPI_Scan on inter-communicators.

    Prints "MPI_Scan_intercomm not yet implemented" and calls MPI_Abort(comm, 0).

    NOTE(review):
    - There is no return statement after MPI_Abort. That is undefined behaviour if
      control ever reaches the end of the function.
    - Error code 0 may be reported to the launching environment as success; consider
      a non-zero code.
  */
  int MPI_Scan_intercomm(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
  {
    printf("MPI_Scan_intercomm not yet implemented\n");
    MPI_Abort(comm, 0);
  }

}