source: XIOS/trunk/src/transformation/domain_algorithm_interpolate.cpp @ 709

Last change on this file since 709 was 709, checked in by mhnguyen, 9 years ago

Correcting a bug in interpolation domain

+) Replace shared send buffer by seperate buffer for each proc
+) Remove some redundant codes

Test
+) On Curie
+) test_client, test_complete and test_remap pass

File size: 13.3 KB
Line 
1/*!
2   \file domain_algorithm_interpolate_from_file.cpp
3   \author Ha NGUYEN
4   \since 09 Jul 2015
5   \date 15 Sep 2015
6
7   \brief Algorithm for interpolation on a domain.
8 */
9#include "domain_algorithm_interpolate.hpp"
10#include <boost/unordered_map.hpp>
11#include "context.hpp"
12#include "context_client.hpp"
13#include "distribution_client.hpp"
14#include "client_server_mapping_distributed.hpp"
15#include "netcdf.hpp"
16#include "mapper.hpp"
17
18namespace xios {
19
20CDomainAlgorithmInterpolate::CDomainAlgorithmInterpolate(CDomain* domainDestination, CDomain* domainSource, CInterpolateDomain* interpDomain)
21: CDomainAlgorithmTransformation(domainDestination, domainSource), interpDomain_(interpDomain)
22{
23  interpDomain_->checkValid(domainSource);
24  computeIndexSourceMapping();
25}
26
27/*!
28  Compute remap with integrated remap calculation module
29*/
30void CDomainAlgorithmInterpolate::computeRemap()
31{
32  using namespace sphereRemap;
33
34  CContext* context = CContext::getCurrent();
35  CContextClient* client=context->client;
36  int clientRank = client->clientRank;
37  int i, j, k, idx;
38  std::vector<double> srcPole(3,0), dstPole(3,0);
39        int orderInterp = interpDomain_->order.getValue();
40
41  int constNVertex = 4; // Value by default number of vertex for rectangular domain
42  int nVertexSrc, nVertexDest;
43  nVertexSrc = nVertexDest = constNVertex;
44
45  // First of all, try to retrieve the boundary values of domain source and domain destination
46  int localDomainSrcSize = domainSrc_->i_index.numElements();
47  int niSrc = domainSrc_->ni.getValue(), njSrc = domainSrc_->nj.getValue();
48  bool hasBoundSrc = domainSrc_->hasBounds;
49  if (hasBoundSrc) nVertexSrc = domainSrc_->nvertex.getValue();
50  CArray<double,2> boundsLonSrc(nVertexSrc,localDomainSrcSize);
51  CArray<double,2> boundsLatSrc(nVertexSrc,localDomainSrcSize);
52
53  if (CDomain::type_attr::rectilinear == domainSrc_->type) srcPole[2] = 1;
54  if (hasBoundSrc)  // Suppose that domain source is curvilinear or unstructured
55  {
56    if (!domainSrc_->bounds_lon_2d.isEmpty())
57    {
58      for (j = 0; j < njSrc; ++j)
59        for (i = 0; i < niSrc; ++i)
60        {
61          k=j*niSrc+i;
62          for(int n=0;n<nVertexSrc;++n)
63          {
64            boundsLonSrc(n,k) = domainSrc_->bounds_lon_2d(n,i,j);
65            boundsLatSrc(n,k) = domainSrc_->bounds_lat_2d(n,i,j);
66          }
67        }
68    }
69    else
70    {
71      boundsLonSrc = domainSrc_->bounds_lon_1d;
72      boundsLatSrc = domainSrc_->bounds_lat_1d;
73    }
74  }
75  else // if domain source is rectilinear, not do anything now
76  {
77    nVertexSrc = constNVertex;
78    domainSrc_->fillInRectilinearBoundLonLat(boundsLonSrc, boundsLatSrc);
79  }
80
81  int localDomainDestSize = domainDest_->i_index.numElements();
82  int niDest = domainDest_->ni.getValue(), njDest = domainDest_->nj.getValue();
83  bool hasBoundDest = domainDest_->hasBounds;
84  if (hasBoundDest) nVertexDest = domainDest_->nvertex.getValue();
85  CArray<double,2> boundsLonDest(nVertexDest,localDomainDestSize);
86  CArray<double,2> boundsLatDest(nVertexDest,localDomainDestSize);
87
88  if (CDomain::type_attr::rectilinear == domainDest_->type) dstPole[2] = 1;
89  if (hasBoundDest)
90  {
91    if (!domainDest_->bounds_lon_2d.isEmpty())
92    {
93      for (j = 0; j < njDest; ++j)
94        for (i = 0; i < niDest; ++i)
95        {
96          k=j*niDest+i;
97          for(int n=0;n<nVertexDest;++n)
98          {
99            boundsLonDest(n,k) = domainDest_->bounds_lon_2d(n,i,j);
100            boundsLatDest(n,k) = domainDest_->bounds_lat_2d(n,i,j);
101          }
102        }
103    }
104    else
105    {
106      boundsLonDest = domainDest_->bounds_lon_1d;
107      boundsLatDest = domainDest_->bounds_lat_1d;
108    }
109  }
110  else
111  {
112    // Ok, fill in boundary values for rectangular domain
113    domainDest_->fillInRectilinearBoundLonLat(boundsLonDest, boundsLatDest);
114    nVertexDest = constNVertex;
115  }
116
117
118
119  // Ok, now use mapper to calculate
120  int nSrcLocal = domainSrc_->i_index.numElements();
121  int nDstLocal = domainDest_->i_index.numElements();
122  long int * globalSrc = new long int [nSrcLocal];
123  long int * globalDst = new long int [nDstLocal];
124
125  long int globalIndex;
126  int i_ind, j_ind;
127  for (int idx = 0; idx < nSrcLocal; ++idx)
128  {
129    i_ind=domainSrc_->i_index(idx) ;
130    j_ind=domainSrc_->j_index(idx) ;
131
132    globalIndex = i_ind + j_ind * domainSrc_->ni_glo;
133    globalSrc[idx] = globalIndex;
134  }
135
136  for (int idx = 0; idx < nDstLocal; ++idx)
137  {
138    i_ind=domainDest_->i_index(idx) ;
139    j_ind=domainDest_->j_index(idx) ;
140
141    globalIndex = i_ind + j_ind * domainDest_->ni_glo;
142    globalDst[idx] = globalIndex;
143  }
144
145
146  // Calculate weight index
147  Mapper mapper(client->intraComm);
148  mapper.setVerbosity(PROGRESS) ;
149  mapper.setSourceMesh(boundsLonSrc.dataFirst(), boundsLatSrc.dataFirst(), nVertexSrc, nSrcLocal, &srcPole[0], globalSrc);
150  mapper.setTargetMesh(boundsLonDest.dataFirst(), boundsLatDest.dataFirst(), nVertexDest, nDstLocal, &dstPole[0], globalDst);
151  std::vector<double> timings = mapper.computeWeights(orderInterp);
152
153  std::map<int,std::vector<std::pair<int,double> > > interpMapValue;
154  for (int idx = 0;  idx < mapper.nWeights; ++idx)
155  {
156    interpMapValue[mapper.targetWeightId[idx]].push_back(make_pair(mapper.sourceWeightId[idx],mapper.remapMatrix[idx]));
157  }
158  exchangeRemapInfo(interpMapValue);
159
160  delete [] globalSrc;
161  delete [] globalDst;
162}
163
164/*!
165  Compute the index mapping between domain on grid source and one on grid destination
166*/
167void CDomainAlgorithmInterpolate::computeIndexSourceMapping()
168{
169  if (!interpDomain_->file.isEmpty())
170    readRemapInfo();
171  else
172    computeRemap();
173}
174
175void CDomainAlgorithmInterpolate::readRemapInfo()
176{
177  CContext* context = CContext::getCurrent();
178  CContextClient* client=context->client;
179  int clientRank = client->clientRank;
180
181  std::string filename = interpDomain_->file.getValue();
182  std::map<int,std::vector<std::pair<int,double> > > interpMapValue;
183  readInterpolationInfo(filename, interpMapValue);
184
185  exchangeRemapInfo(interpMapValue);
186}
187
188
189/*!
190  Read remap information from file then distribute it among clients
191*/
192void CDomainAlgorithmInterpolate::exchangeRemapInfo(const std::map<int,std::vector<std::pair<int,double> > >& interpMapValue)
193{
194  CContext* context = CContext::getCurrent();
195  CContextClient* client=context->client;
196  int clientRank = client->clientRank;
197
198  boost::unordered_map<size_t,int> globalIndexOfDomainDest;
199  int ni = domainDest_->ni.getValue();
200  int nj = domainDest_->nj.getValue();
201  int ni_glo = domainDest_->ni_glo.getValue();
202  size_t globalIndex;
203  int nIndexSize = domainDest_->i_index.numElements(), i_ind, j_ind;
204  for (int idx = 0; idx < nIndexSize; ++idx)
205  {
206    i_ind=domainDest_->i_index(idx) ;
207    j_ind=domainDest_->j_index(idx) ;
208
209    globalIndex = i_ind + j_ind * ni_glo;
210    globalIndexOfDomainDest[globalIndex] = clientRank;
211  }
212
213  CClientServerMappingDistributed domainIndexClientClientMapping(globalIndexOfDomainDest,
214                                                                 client->intraComm,
215                                                                 true);
216  CArray<size_t,1> globalIndexInterp(interpMapValue.size());
217  std::map<int,std::vector<std::pair<int,double> > >::const_iterator itb = interpMapValue.begin(), it,
218                                                                     ite = interpMapValue.end();
219  size_t globalIndexCount = 0;
220  for (it = itb; it != ite; ++it)
221  {
222    globalIndexInterp(globalIndexCount) = it->first;
223    ++globalIndexCount;
224  }
225
226  domainIndexClientClientMapping.computeServerIndexMapping(globalIndexInterp);
227  const std::map<int, std::vector<size_t> >& globalIndexInterpSendToClient = domainIndexClientClientMapping.getGlobalIndexOnServer();
228
229  //Inform each client number of index they will receive
230  int nbClient = client->clientSize;
231  int* sendBuff = new int[nbClient];
232  int* recvBuff = new int[nbClient];
233  for (int i = 0; i < nbClient; ++i)
234  {
235    sendBuff[i] = 0;
236    recvBuff[i] = 0;
237  }
238  int sendBuffSize = 0;
239  std::map<int, std::vector<size_t> >::const_iterator itbMap = globalIndexInterpSendToClient.begin(), itMap,
240                                                      iteMap = globalIndexInterpSendToClient.end();
241  for (itMap = itbMap; itMap != iteMap; ++itMap)
242  {
243    const std::vector<size_t>& tmp = itMap->second;
244    int sizeIndex = 0, mapSize = (itMap->second).size();
245    for (int idx = 0; idx < mapSize; ++idx)
246    {
247      sizeIndex += interpMapValue.at((itMap->second)[idx]).size();
248    }
249    sendBuff[itMap->first] = sizeIndex;
250    sendBuffSize += sizeIndex;
251  }
252
253
254  MPI_Allreduce(sendBuff, recvBuff, nbClient, MPI_INT, MPI_SUM, client->intraComm);
255
256  int* sendIndexDestBuff = new int [sendBuffSize];
257  int* sendIndexSrcBuff  = new int [sendBuffSize];
258  double* sendWeightBuff = new double [sendBuffSize];
259
260  std::vector<MPI_Request> sendRequest;
261
262  int sendOffSet = 0, l = 0;
263  for (itMap = itbMap; itMap != iteMap; ++itMap)
264  {
265    const std::vector<size_t>& indexToSend = itMap->second;
266    int mapSize = indexToSend.size();
267    int k = 0;
268    for (int idx = 0; idx < mapSize; ++idx)
269    {
270      const std::vector<std::pair<int,double> >& interpMap = interpMapValue.at(indexToSend[idx]);
271      for (int i = 0; i < interpMap.size(); ++i)
272      {
273        sendIndexDestBuff[l] = indexToSend[idx];
274        sendIndexSrcBuff[l]  = interpMap[i].first;
275        sendWeightBuff[l]    = interpMap[i].second;
276        ++k;
277        ++l;
278      }
279    }
280
281    sendRequest.push_back(MPI_Request());
282    MPI_Isend(sendIndexDestBuff + sendOffSet,
283             k,
284             MPI_INT,
285             itMap->first,
286             7,
287             client->intraComm,
288             &sendRequest.back());
289    sendRequest.push_back(MPI_Request());
290    MPI_Isend(sendIndexSrcBuff + sendOffSet,
291             k,
292             MPI_INT,
293             itMap->first,
294             8,
295             client->intraComm,
296             &sendRequest.back());
297    sendRequest.push_back(MPI_Request());
298    MPI_Isend(sendWeightBuff + sendOffSet,
299             k,
300             MPI_DOUBLE,
301             itMap->first,
302             9,
303             client->intraComm,
304             &sendRequest.back());
305    sendOffSet += k;
306  }
307
308  int recvBuffSize = recvBuff[clientRank];
309  int* recvIndexDestBuff = new int [recvBuffSize];
310  int* recvIndexSrcBuff  = new int [recvBuffSize];
311  double* recvWeightBuff = new double [recvBuffSize];
312  int receivedSize = 0;
313  int clientSrcRank;
314  while (receivedSize < recvBuffSize)
315  {
316    MPI_Status recvStatus;
317    MPI_Recv((recvIndexDestBuff + receivedSize),
318             recvBuffSize,
319             MPI_INT,
320             MPI_ANY_SOURCE,
321             7,
322             client->intraComm,
323             &recvStatus);
324
325    int countBuff = 0;
326    MPI_Get_count(&recvStatus, MPI_INT, &countBuff);
327    clientSrcRank = recvStatus.MPI_SOURCE;
328
329    MPI_Recv((recvIndexSrcBuff + receivedSize),
330             recvBuffSize,
331             MPI_INT,
332             clientSrcRank,
333             8,
334             client->intraComm,
335             &recvStatus);
336
337    MPI_Recv((recvWeightBuff + receivedSize),
338             recvBuffSize,
339             MPI_DOUBLE,
340             clientSrcRank,
341             9,
342             client->intraComm,
343             &recvStatus);
344
345    for (int idx = 0; idx < countBuff; ++idx)
346    {
347      transformationMapping_[*(recvIndexDestBuff + receivedSize + idx)].push_back(*(recvIndexSrcBuff + receivedSize + idx));
348      transformationWeight_[*(recvIndexDestBuff + receivedSize + idx)].push_back(*(recvWeightBuff + receivedSize + idx));
349    }
350    receivedSize += countBuff;
351  }
352
353  std::vector<MPI_Status> requestStatus(sendRequest.size());
354  MPI_Waitall(sendRequest.size(), &sendRequest[0], MPI_STATUS_IGNORE);
355
356  delete [] sendIndexDestBuff;
357  delete [] sendIndexSrcBuff;
358  delete [] sendWeightBuff;
359  delete [] recvIndexDestBuff;
360  delete [] recvIndexSrcBuff;
361  delete [] recvWeightBuff;
362  delete [] sendBuff;
363  delete [] recvBuff;
364}
365
366/*!
367  Read interpolation information from a file
368  \param [in] filename interpolation file
369  \param [in/out] interpMapValue Mapping between (global) index of domain on grid destination and
370         corresponding global index of domain and associated weight value on grid source
371*/
372void CDomainAlgorithmInterpolate::readInterpolationInfo(std::string& filename,
373                                                        std::map<int,std::vector<std::pair<int,double> > >& interpMapValue)
374{
375  int ncid ;
376  int weightDimId ;
377  size_t nbWeightGlo ;
378
379  CContext* context = CContext::getCurrent();
380  CContextClient* client=context->client;
381  int clientRank = client->clientRank;
382  int clientSize = client->clientSize;
383
384  nc_open(filename.c_str(),NC_NOWRITE, &ncid) ;
385  nc_inq_dimid(ncid,"n_weight",&weightDimId) ;
386  nc_inq_dimlen(ncid,weightDimId,&nbWeightGlo) ;
387
388  size_t nbWeight ;
389  size_t start ;
390  size_t div = nbWeightGlo/clientSize ;
391  size_t mod = nbWeightGlo%clientSize ;
392  if (clientRank < mod)
393  {
394    nbWeight=div+1 ;
395    start=clientRank*(div+1) ;
396  }
397  else
398  {
399    nbWeight=div ;
400    start= mod * (div+1) + (clientRank-mod) * div ;
401  }
402
403  double* weight=new double[nbWeight] ;
404  int weightId ;
405  nc_inq_varid (ncid, "weight", &weightId) ;
406  nc_get_vara_double(ncid, weightId, &start, &nbWeight, weight) ;
407
408  long* srcIndex=new long[nbWeight] ;
409  int srcIndexId ;
410  nc_inq_varid (ncid, "src_idx", &srcIndexId) ;
411  nc_get_vara_long(ncid, srcIndexId, &start, &nbWeight, srcIndex) ;
412
413  long* dstIndex=new long[nbWeight] ;
414  int dstIndexId ;
415  nc_inq_varid (ncid, "dst_idx", &dstIndexId) ;
416  nc_get_vara_long(ncid, dstIndexId, &start, &nbWeight, dstIndex) ;
417
418  for(size_t ind=0; ind<nbWeight;++ind)
419    interpMapValue[dstIndex[ind]-1].push_back(make_pair(srcIndex[ind]-1,weight[ind]));
420}
421
422}
Note: See TracBrowser for help on using the repository browser.