source: XIOS/dev/branch_openmp/extern/ep_dev/ep_split.cpp @ 1504

Last change on this file since 1504 was 1504, checked in by yushan, 6 years ago

MPI_split can deal with discontinuous ranking within a process

File size: 8.4 KB
Line 
1#include "ep_lib.hpp"
2#include <mpi.h>
3#include "ep_declaration.hpp"
4#include "ep_mpi.hpp"
5
6using namespace std;
7
8namespace ep_lib
9{
10
11  void vec_simplify(std::vector<int> *inout_vector)
12  {
13    std::vector<int> out_vec;
14    int found=false;
15    for(std::vector<int>::iterator it_in = inout_vector->begin() ; it_in != inout_vector->end(); ++it_in)
16    {
17      for(std::vector<int>::iterator it = out_vec.begin() ; it != out_vec.end(); ++it)
18      {
19        if(*it_in == *it)
20        {
21          found=true;
22          break;
23        }
24        else found=false;
25      }
26      if(found == false)
27      {
28        out_vec.push_back(*it_in);
29      }
30    }
31    inout_vector->swap(out_vec);
32  }
33 
34  void vec_simplify(std::vector<int> *in_vector, std::vector<int> *out_vector)
35  {
36    int found=false;
37    for(std::vector<int>::iterator it_in = in_vector->begin() ; it_in != in_vector->end(); ++it_in)
38    {
39      for(std::vector<int>::iterator it = out_vector->begin() ; it != out_vector->end(); ++it)
40      {
41        if(*it_in == *it)
42        {
43          found=true;
44          break;
45        }
46        else found=false;
47      }
48      if(found == false)
49      {
50        out_vector->push_back(*it_in);
51      }
52    }
53  }
54
55
56
57  int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm *newcomm)
58  {
59    int ep_rank, ep_rank_loc, mpi_rank;
60    int ep_size, num_ep, mpi_size;
61
62    ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
63    ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
64    mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
65    ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
66    num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
67    mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
68
69    int num_color = 0;
70
71    int color_index;
72
73    vector<int> matched_number;
74    vector<int> matched_number_loc;
75   
76    vector<int> all_color(ep_size);
77    vector<int> all_color_loc(num_ep);
78
79    MPI_Allgather(&color, 1, MPI_INT, all_color.data(), 1, MPI_INT, comm);
80    MPI_Allgather_local(&color, 1, MPI_INT, all_color_loc.data(), comm);
81
82    list<int> color_list(all_color.begin(), all_color.end());
83    list<int> color_list_loc(all_color_loc.begin(), all_color_loc.end());
84
85    vector<int> all_color_simplified;
86    vec_simplify(&all_color, &all_color_simplified);
87    int number_of_color;
88    for(int i=0; i<all_color_simplified.size(); i++)
89    {
90      if(color == all_color_simplified[i])
91      {
92        number_of_color = i;
93        break;
94      }
95    }
96
97    matched_number.resize(all_color_simplified.size(), 0);
98    matched_number_loc.resize(all_color_simplified.size(), 0);
99
100
101    while(!color_list.empty())
102    {
103      int target_color = color_list.front();
104      for(list<int>::iterator it = color_list.begin(); it != color_list.end(); ++it)
105      {
106        if(*it == target_color)
107        {
108          matched_number[num_color]++;
109        }
110      }
111      for(list<int>::iterator it = color_list_loc.begin(); it != color_list_loc.end(); ++it)
112      {
113        if(*it == target_color)
114        {
115          matched_number_loc[num_color]++;
116        }
117      }
118      color_list.remove(target_color);
119      color_list_loc.remove(target_color);
120      num_color++;
121    }
122       
123
124    vector<int> all_key(ep_size);
125    vector<int> all_key_loc(num_ep);
126   
127    vector<int> colored_key[num_color];
128    vector<int> colored_key_loc[num_color];
129   
130
131    MPI_Allgather(&key, 1, MPI_INT, all_key.data(),1, MPI_INT, comm);
132    MPI_Allgather_local(&key, 1, MPI_INT, all_key_loc.data(), comm);
133
134    for(int i=0; i<num_ep; i++)
135    {
136      for(int j = 0; j<num_color; j++)
137      {
138        if(all_color_loc[i] == all_color_simplified[j])
139        {
140          colored_key_loc[j].push_back(all_key_loc[i]);
141        }
142      }
143    }
144   
145    for(int i=0; i<ep_size; i++)
146    {
147      for(int j = 0; j<num_color; j++)
148      {
149        if(all_color[i] == all_color_simplified[j])
150        {
151          colored_key[j].push_back(all_key[i]);
152        }
153      }
154    }
155   
156    for(int i=0; i<num_color; i++)
157    {
158      std::sort(colored_key[i].begin(), colored_key[i].end());
159      std::sort(colored_key_loc[i].begin(), colored_key_loc[i].end());
160    }
161
162    int new_ep_rank;
163   
164    for(int i=0; i<colored_key[number_of_color].size(); i++)
165    {
166      if(key == colored_key[number_of_color][i])
167      {
168        new_ep_rank = i;
169        break;
170      }
171    }
172   
173    int new_ep_rank_loc;
174   
175    for(int i=0; i<colored_key_loc[number_of_color].size(); i++)
176    {
177      if(key == colored_key_loc[number_of_color][i])
178      {
179        new_ep_rank_loc = i;
180        break;
181      }
182    }
183   
184
185    ::MPI_Comm **split_mpi_comm;
186    split_mpi_comm = new ::MPI_Comm* [num_color];
187    for(int ii=0; ii<num_color; ii++)
188      split_mpi_comm[ii] = new ::MPI_Comm;
189
190    for(int j=0; j<num_color; j++)
191    {
192      if(ep_rank_loc == 0)
193      {
194        int master_color = 1;
195        if(matched_number_loc[j] == 0) master_color = MPI_UNDEFINED;
196
197        ::MPI_Comm_split(to_mpi_comm(comm->mpi_comm), master_color, mpi_rank, split_mpi_comm[j]);
198       
199        comm->ep_comm_ptr->comm_list[0]->mpi_bridge = split_mpi_comm[j];
200      }
201     
202      MPI_Barrier_local(comm);
203     
204      int num_new_ep = 0;
205
206      if(new_ep_rank_loc == 0 && color == all_color_simplified[j])
207      {
208        num_new_ep = matched_number_loc[j];
209        MPI_Info info;
210        MPI_Comm *ep_comm;
211
212        MPI_Comm_create_endpoints(comm->ep_comm_ptr->comm_list[0]->mpi_bridge, num_new_ep, info, ep_comm);
213
214        comm->ep_comm_ptr->comm_list[0]->mem_bridge = ep_comm;
215       
216        (*ep_comm)->ep_rank_map->clear();
217       
218        memcheck("in MPI_Split ep_rank="<< ep_rank <<" : *ep_comm = "<< *ep_comm);
219      }
220     
221      MPI_Barrier_local(comm);
222     
223      if(color == all_color_simplified[j])
224      {
225        *newcomm = comm->ep_comm_ptr->comm_list[0]->mem_bridge[new_ep_rank_loc];
226        memcheck("in MPI_Split ep_rank="<< ep_rank <<" : *newcomm = "<< *newcomm);
227
228        (*newcomm)->ep_comm_ptr->comm_label = color;
229       
230        (*newcomm)->ep_comm_ptr->size_rank_info[0].first = new_ep_rank;
231        (*newcomm)->ep_comm_ptr->size_rank_info[1].first = new_ep_rank_loc;
232       
233        int my_triple[3];
234        vector<int> my_triple_vector;
235        vector<int> my_triple_vector_recv;
236        my_triple[0] = new_ep_rank;
237        my_triple[1] = new_ep_rank_loc;
238        my_triple[2] = (*newcomm)->ep_comm_ptr->size_rank_info[2].first; // new_mpi_rank
239       
240        int new_ep_size = (*newcomm)->ep_comm_ptr->size_rank_info[0].second;
241        int new_num_ep  = (*newcomm)->ep_comm_ptr->size_rank_info[1].second;
242       
243        int new_mpi_size = (*newcomm)->ep_comm_ptr->size_rank_info[2].second;
244       
245        if(new_ep_rank_loc == 0) my_triple_vector.resize(3*new_ep_size);
246        if(new_ep_rank_loc == 0) my_triple_vector_recv.resize(3*new_ep_size);
247       
248        MPI_Gather_local(my_triple, 3, MPI_INT, my_triple_vector.data(), 0, *newcomm);
249       
250        if(new_ep_rank_loc == 0)
251        {
252          int *recvcounts = new int[new_mpi_size];
253          int *displs = new int[new_mpi_size];
254          int new_num_epx3 = new_num_ep * 3;
255          ::MPI_Allgather(&new_num_epx3, 1, to_mpi_type(MPI_INT), recvcounts, 1, to_mpi_type(MPI_INT), to_mpi_comm((*newcomm)->mpi_comm));
256          displs[0]=0;
257          for(int i=1; i<new_mpi_size; i++)
258            displs[i] = displs[i-1] + recvcounts[i-1];
259             
260          ::MPI_Allgatherv(my_triple_vector.data(), 3*new_num_ep, to_mpi_type(MPI_INT), my_triple_vector_recv.data(), recvcounts, displs, to_mpi_type(MPI_INT), to_mpi_comm((*newcomm)->mpi_comm));
261         
262          for(int i=0; i<new_ep_size; i++)
263          {
264            (*newcomm)->ep_comm_ptr->comm_list[0]->ep_rank_map->insert(std::pair< int, std::pair<int,int> >(my_triple_vector_recv[3*i], my_triple_vector_recv[3*i+1], my_triple_vector_recv[3*i+2]));
265          }
266         
267          (*newcomm)->ep_rank_map = (*newcomm)->ep_comm_ptr->comm_list[0]->ep_rank_map;
268         
269          delete recvcounts;
270          delete displs;
271        } 
272      }
273    }
274   
275    /*for(int i=0; i<ep_size; i++)
276    {
277      MPI_Barrier(comm);
278      MPI_Barrier(comm);
279      if(ep_rank==i)
280      {
281        printf("ep_rank_map for endpoint %d = \n", ep_rank);
282        for(std::map<int, std::pair<int, int> > :: iterator it = (*newcomm)->ep_rank_map->begin(); it != (*newcomm)->ep_rank_map->end(); it++)
283        {
284          printf("\t\t\t %d %d %d\n", it->first, it->second.first, it->second.second);
285        }
286        printf("\n");
287      }
288      MPI_Barrier(comm);
289      MPI_Barrier(comm);
290    }*/   
291   
292    return 0;
293  }
294
295}
Note: See TracBrowser for help on using the repository browser.