// NOTE(review): this file appears to have lost text during extraction. The
// operands of one #include and of the std::map declarations are missing, and
// so are several spans inside the function body (for example right after
// "priority = leader_mpi_rank_in_peer", and after the loop headers that now
// read "for(int i=0; i"). The comments below describe only what the surviving
// text shows; restore the missing spans from the upstream file before relying
// on any of it.
//
// ep_lib::MPI_Intercomm_create(local_comm, local_leader, peer_comm,
//                              remote_leader, tag, newintercomm)
//
// Produces *newintercomm for the calling endpoint and marks it as an
// inter-communicator (is_intercomm = true), using information exchanged
// between this side's leader and a remote leader over peer_comm.
//
// Parameters
//   local_comm    Must satisfy local_comm->is_ep (asserted). Its
//                 ep_comm_ptr->size_rank_info[0..2] entries are read as
//                 (ep_rank, ep_size), (ep_rank_loc, num_ep) and
//                 (mpi_rank, mpi_size).
//   local_leader  Compared with ep_rank to decide is_leader; also the root of
//                 every MPI_Bcast on local_comm.
//   peer_comm     Communicator used for the leader-to-leader point-to-point
//                 exchanges.
//   remote_leader Peer of those exchanges in peer_comm. If the leader's own
//                 rank in peer_comm equals remote_leader, the function prints
//                 "same leader in peer_comm" and calls exit(1).
//   tag           Message tag of every exchange, and the key under which
//                 groups / endpoint communicators are published in
//                 tag_group_map and tag_comm_map.
//   newintercomm  Output: receives one element of the endpoint communicator
//                 array found in tag_comm_map.
//
// Return value
//   NOTE(review): the function is declared to return int, but no return
//   statement is visible before the closing brace - confirm against upstream
//   and add one (flowing off the end of a non-void function is undefined
//   behaviour in C++).
//
// Roles
//   is_leader        ep_rank == local_leader.
//   is_local_leader  is_leader, or ep_rank_loc == 0 on a process whose
//                    mpi_rank differs from
//                    local_comm->ep_rank_map->at(local_leader).second
//                    (presumably the MPI rank hosting the leader endpoint -
//                    TODO confirm).
//   priority         Derived from the leader's rank in peer_comm; the exact
//                    expression is truncated. The side with priority sends
//                    first and then receives; the other side receives first.
//   is_new_leader    Starts as is_local_leader; the code that may clear it
//                    when !priority is truncated.
//
// Visible sequence
//   1. Leaders swap mpi_size and ep_size; the remote values are broadcast on
//      local_comm from local_leader.
//   2. Leaders swap the (world rank, num_ep) tables
//      (local_world_rank_and_num_ep / remote_world_rank_and_num_ep, two ints
//      per MPI process); the remote table is broadcast on local_comm.
//   3. A group pointer (local_group) is inserted into tag_group_map under
//      (tag, priority ? 1 : 2). How the group is built is not visible.
//   4. MPI_Barrier(local_comm), a one-int handshake between the leaders, and
//      a second MPI_Barrier(local_comm).
//   5. Each is_new_leader reads the groups stored under (tag, 1) and (tag, 2),
//      substituting a heap-allocated MPI_GROUP_EMPTY for a missing key, forms
//      their ::MPI_Group_union, and erases both keys from the map.
//   6. summed_world_rank_and_num_ep is allocated with room for both tables;
//      the code that fills it and creates mpi_comm from union_group is
//      truncated. MPI_Comm_create_endpoints(&mpi_comm, new_num_ep, info,
//      ep_comm) is then called and ep_comm is published in tag_comm_map as
//      tag -> (ep_comm, (new_num_ep, 0)).
//   7. Every caller polls tag_comm_map until an entry for tag exists, takes
//      element [new_ep_rank_loc] of the stored array as *newintercomm, and
//      increments the entry's counter; the entry is erased once the counter
//      equals the stored new_num_ep.
//   8. Each endpoint contributes (ep_rank, its rank in the new communicator)
//      to an MPI_Allgather on local_comm; leaders swap the gathered tables,
//      the remote table is broadcast, and its pairs are inserted into
//      (*newintercomm)->inter_rank_map.
//   9. (*newintercomm)->ep_comm_ptr->size_rank_info[0] is overwritten with
//      local_comm's entry, and the temporary arrays and groups are freed.
//
// Review notes (to be confirmed, not fixed here)
//   - tag_comm_map is written inside "omp critical (write_to_tag_comm_map)"
//     but read and erased inside "omp critical (read_from_tag_comm_map)".
//     OpenMP only serialises critical regions that share a name, so the
//     insert and the lookups do not exclude each other - confirm whether
//     this is intended.
//   - "MPI_Info info;" is passed to MPI_Comm_create_endpoints with no
//     initialisation in the visible text.
//   - The tag_group_map / tag_comm_map objects created with new, the ep_comm
//     array, and inter_rank_map are not released in this function;
//     ownership is not visible here.
#include "ep_lib.hpp" #include #include "ep_declaration.hpp" #include "ep_mpi.hpp" using namespace std; extern std::map, MPI_Group* > * tag_group_map; extern std::map > > * tag_comm_map; extern MPI_Group MPI_GROUP_WORLD; namespace ep_lib { int MPI_Intercomm_create(MPI_Comm local_comm, int local_leader, MPI_Comm peer_comm, int remote_leader, int tag, MPI_Comm *newintercomm) { assert(local_comm->is_ep); int ep_rank, ep_rank_loc, mpi_rank; int ep_size, num_ep, mpi_size; ep_rank = local_comm->ep_comm_ptr->size_rank_info[0].first; ep_rank_loc = local_comm->ep_comm_ptr->size_rank_info[1].first; mpi_rank = local_comm->ep_comm_ptr->size_rank_info[2].first; ep_size = local_comm->ep_comm_ptr->size_rank_info[0].second; num_ep = local_comm->ep_comm_ptr->size_rank_info[1].second; mpi_size = local_comm->ep_comm_ptr->size_rank_info[2].second; int world_rank_and_num_ep[2]; MPI_Comm_rank(MPI_COMM_WORLD, &world_rank_and_num_ep[0]); world_rank_and_num_ep[1] = num_ep; int remote_mpi_size; int remote_ep_size; int *local_world_rank_and_num_ep; int *remote_world_rank_and_num_ep; int *summed_world_rank_and_num_ep; bool is_leader = ep_rank==local_leader? true : false; bool is_local_leader = is_leader? true: (ep_rank_loc==0 && mpi_rank!=local_comm->ep_rank_map->at(local_leader).second ? 
true : false); bool priority; if(is_leader) { int leader_mpi_rank_in_peer; MPI_Comm_rank(peer_comm, &leader_mpi_rank_in_peer); if(leader_mpi_rank_in_peer == remote_leader) { printf("same leader in peer_comm\n"); exit(1); } priority = leader_mpi_rank_in_peermpi_comm)); } if(is_leader) { MPI_Request request; MPI_Status status; if(priority) { MPI_Isend(&mpi_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Irecv(&remote_mpi_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Isend(&ep_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Irecv(&remote_ep_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); } else { MPI_Irecv(&remote_mpi_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Isend(&mpi_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Irecv(&remote_ep_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Isend(&ep_size, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); } } MPI_Bcast(&remote_mpi_size, 1, MPI_INT, local_leader, local_comm); MPI_Bcast(&remote_ep_size, 1, MPI_INT, local_leader, local_comm); remote_world_rank_and_num_ep = new int[2*remote_mpi_size]; if(is_leader) { MPI_Request request; MPI_Status status; if(priority) { MPI_Isend(local_world_rank_and_num_ep, 2*mpi_size, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Irecv(remote_world_rank_and_num_ep, 2*remote_mpi_size, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); } else { MPI_Irecv(remote_world_rank_and_num_ep, 2*remote_mpi_size, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Isend(local_world_rank_and_num_ep, 2*mpi_size, MPI_INT, remote_leader, tag, peer_comm, &request); 
MPI_Wait(&request, &status); } } MPI_Bcast(remote_world_rank_and_num_ep, 2*remote_mpi_size, MPI_INT, local_leader, local_comm); bool is_new_leader = is_local_leader; if(is_local_leader && !priority) { for(int i=0; i, ::MPI_Group * >; tag_group_map->insert(std::make_pair(std::make_pair(tag, priority? 1 : 2), local_group)); } } MPI_Barrier(local_comm); if(is_leader) { MPI_Request request; MPI_Status status; int send_signal=0; int recv_signal; if(priority) { MPI_Isend(&send_signal, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Irecv(&recv_signal, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); } else { MPI_Irecv(&recv_signal, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Isend(&send_signal, 1, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); } } MPI_Barrier(local_comm); if(is_new_leader) { ::MPI_Group *group1; ::MPI_Group *group2; empty_group = new ::MPI_Group; *empty_group = MPI_GROUP_EMPTY; #pragma omp flush #pragma omp critical (read_from_tag_group_map) { group1 = tag_group_map->find(make_pair(tag, 1)) != tag_group_map->end()? tag_group_map->at(std::make_pair(tag, 1)) : empty_group; group2 = tag_group_map->find(make_pair(tag, 2)) != tag_group_map->end()? 
tag_group_map->at(std::make_pair(tag, 2)) : empty_group; } #ifdef _showinfo int group1_rank, group1_size; int group2_rank, group2_size; ::MPI_Group_rank(*group1, &group1_rank); ::MPI_Group_size(*group1, &group1_size); ::MPI_Group_rank(*group2, &group2_rank); ::MPI_Group_size(*group2, &group2_size); #endif ::MPI_Group_union(*group1, *group2, &union_group); #pragma omp critical (read_from_tag_group_map) { tag_group_map->erase(make_pair(tag, 1)); tag_group_map->erase(make_pair(tag, 2)); } #ifdef _showinfo int group_rank, group_size; ::MPI_Group_rank(union_group, &group_rank); ::MPI_Group_size(union_group, &group_size); printf("rank = %d : map = %p, group1_rank/size = %d/%d, group2_rank/size = %d/%d, union_rank/size = %d/%d\n", ep_rank, tag_group_map, group1_rank, group1_size, group2_rank, group2_size, group_rank, group_size); #endif } int summed_world_rank_and_num_ep_size=mpi_size; summed_world_rank_and_num_ep = new int[2*(mpi_size+remote_mpi_size)]; if(is_leader) { for(int i=0; impi_comm), union_group, tag, &mpi_comm); MPI_Comm *ep_comm; MPI_Info info; MPI_Comm_create_endpoints(&mpi_comm, new_num_ep, info, ep_comm); #pragma omp critical (write_to_tag_comm_map) { if(tag_comm_map == 0) tag_comm_map = new std::map > >; tag_comm_map->insert(std::make_pair(tag, std::make_pair(ep_comm, std::make_pair(new_num_ep, 0)))); } #pragma omp flush } bool found=false; while(!found) { #pragma omp flush #pragma omp critical (read_from_tag_comm_map) { if(tag_comm_map!=0) { if(tag_comm_map->find(tag) != tag_comm_map->end()) { *newintercomm = tag_comm_map->at(tag).first[new_ep_rank_loc]; tag_comm_map->at(tag).second.second++; if(tag_comm_map->at(tag).second.second == tag_comm_map->at(tag).second.first) { tag_comm_map->erase(tag_comm_map->find(tag)); } found=true; } } } } (*newintercomm)->is_intercomm = true; (*newintercomm)->inter_rank_map = new INTER_RANK_MAP; int rank_info[2]; rank_info[0] = ep_rank; rank_info[1] = (*newintercomm)->ep_comm_ptr->size_rank_info[0].first; #ifdef _showinfo 
printf("priority = %d, ep_rank = %d, new_ep_rank = %d\n", priority, rank_info[0], rank_info[1]); #endif int *local_rank_info = new int[2*ep_size]; int *remote_rank_info = new int[2*remote_ep_size]; MPI_Allgather(rank_info, 2, MPI_INT, local_rank_info, 2, MPI_INT, local_comm); if(is_leader) { MPI_Request request; MPI_Status status; if(priority) { MPI_Isend(local_rank_info, 2*ep_size, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Irecv(remote_rank_info, 2*remote_ep_size, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); } else { MPI_Irecv(remote_rank_info, 2*remote_ep_size, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); MPI_Isend(local_rank_info, 2*ep_size, MPI_INT, remote_leader, tag, peer_comm, &request); MPI_Wait(&request, &status); } } MPI_Bcast(remote_rank_info, 2*remote_ep_size, MPI_INT, local_leader, local_comm); for(int i=0; iinter_rank_map->insert(make_pair(remote_rank_info[2*i], remote_rank_info[2*i+1])); } #ifdef _showinfo if(ep_rank==4 && !priority) { for(std::map :: iterator it=(*newintercomm)->inter_rank_map->begin(); it != (*newintercomm)->inter_rank_map->end(); it++) { printf("inter_rank_map[%d] = %d\n", it->first, it->second); } } #endif (*newintercomm)->ep_comm_ptr->size_rank_info[0] = local_comm->ep_comm_ptr->size_rank_info[0]; if(is_local_leader) { delete[] local_world_rank_and_num_ep; MPI_Group_free(local_group); delete local_group; } if(is_new_leader) { MPI_Group_free(&union_group); delete empty_group; } delete[] remote_world_rank_and_num_ep; delete[] summed_world_rank_and_num_ep; delete[] local_rank_info; delete[] remote_rank_info; } }