#include "ep_lib.hpp"
#include <mpi.h>  // NOTE(review): header name was lost in SOURCE ("#include #include"); <mpi.h> assumed because ::MPI_* is used directly — confirm
#include "ep_declaration.hpp"
#include "ep_mpi.hpp"

namespace ep_lib
{

  // Non-blocking probe on a plain (non-endpoint) communicator.
  //
  // Forwards to ::MPI_Iprobe on comm->mpi_comm, mapping a negative src/tag to
  // MPI_ANY_SOURCE / MPI_ANY_TAG.
  //
  // src    : source rank, or negative for "any source"
  // tag    : message tag, or negative for "any tag"
  // flag   : written by ::MPI_Iprobe (nonzero if a matching message is pending)
  // status : status->mpi_status receives a heap-allocated copy of the MPI
  //          status; ep_src/ep_tag receive the *requested* src/tag
  // returns: the error code returned by ::MPI_Iprobe
  int MPI_Iprobe_mpi(int src, int tag, MPI_Comm comm, int *flag, MPI_Status *status)
  {
    ::MPI_Status mpi_status;

    const int ierr = ::MPI_Iprobe(src<0? MPI_ANY_SOURCE : src, tag<0? MPI_ANY_TAG : tag,
                                  to_mpi_comm(comm->mpi_comm), flag, &mpi_status);

    status->mpi_status = new ::MPI_Status(mpi_status);
    // NOTE(review): ep_src/ep_tag echo the requested values (possibly negative
    // wildcards), not the matched message's source/tag held in mpi_status.
    status->ep_src = src;
    status->ep_tag = tag;

    // The original fell off the end of this int function (undefined behavior).
    return ierr;
  }


  // Non-blocking matched probe on a plain (non-endpoint) communicator.
  //
  // With _openmpi, emulates a matched probe as ::MPI_Iprobe followed (only if
  // something is pending) by ::MPI_Mprobe inside the "_mpi_call" critical
  // section. With _intelmpi, or when neither macro is defined, it calls
  // ::MPI_Improbe directly.
  //
  // message: *message must already point to an ep message object; its
  //          mpi_message receives a heap-allocated copy of the MPI message
  //          handle, and ep_src/ep_tag receive the requested src/tag
  // status : as in MPI_Iprobe_mpi
  // returns: the error code of the last MPI call made
  int MPI_Improbe_mpi(int src, int tag, MPI_Comm comm, int *flag, MPI_Message *message, MPI_Status *status)
  {
    const int mpi_src = src<0? MPI_ANY_SOURCE : src;
    const int mpi_tag = tag<0? MPI_ANY_TAG : tag;

    ::MPI_Status mpi_status;
    // Value-initialized so that copying it below is well-defined even when no
    // message was matched (the _openmpi path skips MPI_Mprobe in that case).
    // This value is not a usable message handle.
    ::MPI_Message mpi_message = ::MPI_Message();
    int ierr = 0;

#if defined(_openmpi)
    #pragma omp critical (_mpi_call)
    {
      ierr = ::MPI_Iprobe(mpi_src, mpi_tag, to_mpi_comm(comm->mpi_comm), flag, &mpi_status);
      if(*flag)
      {
        ierr = ::MPI_Mprobe(mpi_src, mpi_tag, to_mpi_comm(comm->mpi_comm), &mpi_message, &mpi_status);
      }
    }
#elif defined(_intelmpi)
    ierr = ::MPI_Improbe(mpi_src, mpi_tag, to_mpi_comm(comm->mpi_comm), flag, &mpi_message, &mpi_status);
#else
    // Neither implementation macro is set. Previously nothing was probed and
    // *flag was left unwritten; fall back to the standard MPI-3 matched probe.
    ierr = ::MPI_Improbe(mpi_src, mpi_tag, to_mpi_comm(comm->mpi_comm), flag, &mpi_message, &mpi_status);
#endif

    status->mpi_status = new ::MPI_Status(mpi_status);
    status->ep_src = src;
    status->ep_tag = tag;

    // Store a heap copy of the matched handle, as MPI_Improbe_endpoint does.
    // (This used to store &message, the address of this function's own
    // parameter: it dangles on return and is not an ::MPI_Message at all.)
    (*message)->mpi_message = new ::MPI_Message(mpi_message);
    memcheck("new "<< (*message)->mpi_message <<" : in ep_lib::MPI_Improbe_mpi, (*message)->mpi_message = new ::MPI_Message");
    (*message)->ep_src = src;
    (*message)->ep_tag = tag;

    return ierr;
  }


  // Public non-blocking probe.
  //
  // Non-endpoint communicators go straight to MPI. For endpoint
  // intercommunicators, a non-negative src is first translated through
  // comm->inter_rank_map; a negative src (wildcard) is passed through as-is.
  int MPI_Iprobe(int src, int tag, MPI_Comm comm, int *flag, MPI_Status *status)
  {
    if(!comm->is_ep)
    {
      Debug("MPI_Iprobe with MPI\n");
      return MPI_Iprobe_mpi(src, tag, comm, flag, status);
    }

    if(comm->is_intercomm)
    {
      if(src>=0) src = comm->inter_rank_map->at(src);
    }

    return MPI_Iprobe_endpoint(src, tag, comm, flag, status);
  }


  // Non-blocking probe on an endpoint communicator.
  //
  // Calls Message_Check(comm), then scans comm->ep_comm_ptr->message_queue
  // under the "_query" critical section for the first entry whose ep_src/ep_tag
  // match (a negative src/tag matches anything). On a match it sets *flag, and
  // status receives a heap copy of the entry's MPI status plus its ep_src/ep_tag.
  // On an intercommunicator, ep_src is mapped back through inter_rank_map
  // (reverse lookup). The entry stays in the queue.
  // returns: 0 (== MPI_SUCCESS per the MPI standard)
  int MPI_Iprobe_endpoint(int src, int tag, MPI_Comm comm, int *flag, MPI_Status *status)
  {
    Debug("MPI_Iprobe with EP\n");

    *flag = false;

    Message_Check(comm);

    #pragma omp flush

    #pragma omp critical (_query)
    for(Message_list::iterator it = comm->ep_comm_ptr->message_queue->begin(); it!= comm->ep_comm_ptr->message_queue->end(); ++it)
    {
      bool src_matched = src<0? true: (*it)->ep_src == src;
      bool tag_matched = tag<0? true: (*it)->ep_tag == tag;

      if(src_matched && tag_matched)
      {
        Debug("find message\n");

        status->mpi_status = new ::MPI_Status(*static_cast< ::MPI_Status*>((*it)->mpi_status));
        status->ep_src = (*it)->ep_src;
        status->ep_tag = (*it)->ep_tag;

        if(comm->is_intercomm)
        {
          // Reverse lookup: find the intercomm rank that maps to this source.
          for(INTER_RANK_MAP::iterator iter = comm->inter_rank_map->begin(); iter != comm->inter_rank_map->end(); iter++)
          {
            if(iter->second == (*it)->ep_src) status->ep_src=iter->first;
          }
        }

        *flag = true;
        break;
      }
    }

    // The original fell off the end of this int function (undefined behavior).
    return 0;
  }


  // Public non-blocking matched probe.
  //
  // Non-endpoint communicators go straight to MPI. For endpoint
  // intercommunicators, a non-negative src is translated through
  // comm->inter_rank_map, and a fresh ep_message is allocated into *message.
  int MPI_Improbe(int src, int tag, MPI_Comm comm, int *flag, MPI_Message *message, MPI_Status *status)
  {
    if(!comm->is_ep)
    {
      Debug("MPI_Improbe with MPI\n");
      return MPI_Improbe_mpi(src, tag, comm, flag, message, status);
    }

    if(comm->is_intercomm)
    {
      // Only translate real ranks: a negative src is a wildcard with no entry
      // in inter_rank_map, so at() must not be called for it (as in MPI_Iprobe).
      if(src>=0) src = comm->inter_rank_map->at(src);
      *message = new ep_message;
      memcheck("new "<< *message <<" : in ep_lib::MPI_Improbe, *message = new ep_message");
    }

    return MPI_Improbe_endpoint(src, tag, comm, flag, message, status);
  }


  // Non-blocking matched probe on an endpoint communicator.
  //
  // Calls Message_Check(comm), then scans the endpoint message queue under the
  // "_query" critical section for the first entry matching src/tag (a negative
  // value matches anything). On a match:
  //   - *flag is set and status receives a heap copy of the entry's MPI status;
  //   - (*message) receives a heap copy of the entry's MPI message handle and
  //     its ep_src/ep_tag;
  //   - the entry and its stored handles are freed and the entry is erased from
  //     the queue (inside the nested "_query2" critical section).
  // NOTE(review): unlike MPI_Iprobe_endpoint, status->ep_src is not mapped back
  // through inter_rank_map on intercommunicators — confirm this is intended.
  // returns: 0 (== MPI_SUCCESS per the MPI standard)
  int MPI_Improbe_endpoint(int src, int tag, MPI_Comm comm, int *flag, MPI_Message *message, MPI_Status *status)
  {
    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
    int mpi_rank    = comm->ep_comm_ptr->size_rank_info[2].first;

    *flag = false;

    Message_Check(comm);

    #pragma omp flush

    #pragma omp critical (_query)
    if(! comm->ep_comm_ptr->message_queue->empty())
    {
      for(Message_list::iterator it = comm->ep_comm_ptr->message_queue->begin(); it!= comm->ep_comm_ptr->message_queue->end(); ++it)
      {
        bool src_matched = src<0? true: (*it)->ep_src == src;
        bool tag_matched = tag<0? true: (*it)->ep_tag == tag;

        if(src_matched && tag_matched)
        {
          *flag = true;

          status->mpi_status = new ::MPI_Status(*static_cast< ::MPI_Status*>((*it)->mpi_status));
          memcheck("new "<< status->mpi_status << " : in ep_lib::MPI_Improbe, status->mpi_status = new ::MPI_Status");
          status->ep_src = (*it)->ep_src;
          status->ep_tag = (*it)->ep_tag;

          (*message)->mpi_message = new ::MPI_Message(*static_cast< ::MPI_Message*>((*it)->mpi_message));
          memcheck("new "<< (*message)->mpi_message <<" : in ep_lib::MPI_Improbe, (*message)->mpi_message = new ::MPI_Message");
          (*message)->ep_src = (*it)->ep_src;
          (*message)->ep_tag = (*it)->ep_tag;

          #pragma omp critical (_query2)
          {
            memcheck("delete "<< (*it)->mpi_message <<" : in ep_lib::Message_Check, delete (*it)->mpi_message");
            memcheck("delete "<< (*it)->mpi_status <<" : in ep_lib::Message_Check, delete (*it)->mpi_status");
            memcheck("delete "<< (*it) <<" : in ep_lib::Message_Check, delete (*it)");

            // The handles are read through static_cast above, which suggests
            // they are stored untyped; delete through the concrete types,
            // since deleting through void* is undefined.
            delete static_cast< ::MPI_Message*>((*it)->mpi_message);
            delete static_cast< ::MPI_Status*>((*it)->mpi_status);
            delete *it;

            // Safe: the loop breaks immediately, so 'it' is not used again.
            comm->ep_comm_ptr->message_queue->erase(it);
            // NOTE(review): this trace line was garbled in SOURCE; reconstructed
            // from the otherwise-unused mpi_rank/ep_rank_loc locals — confirm.
            memcheck("message_queue["<<mpi_rank<<", "<<ep_rank_loc<<"] size = "<<comm->ep_comm_ptr->message_queue->size());
            #pragma omp flush
          }

          break;
        }
      }
    }

    // The original fell off the end of this int function (undefined behavior).
    return 0;
  }

}