source: XIOS/trunk/src/transformation/axis_algorithm_interpolate.cpp @ 668

Last change on this file since 668 was 668, checked in by mhnguyen, 9 years ago

Implement some code refactoring

+) Replace some slow search functions with faster ones

Test
+) On Curie
+) test_client and test_complete are correct

File size: 7.5 KB
Line 
1/*!
2   \file axis_algorithm_interpolate.cpp
3   \author Ha NGUYEN
4   \since 23 June 2015
5   \date 02 Jul 2015
6
7   \brief Algorithm for interpolation on an axis.
8 */
9#include "axis_algorithm_interpolate.hpp"
10#include <algorithm>
11#include "context.hpp"
12#include "context_client.hpp"
13#include "utils.hpp"
14
15namespace xios {
16
17CAxisAlgorithmInterpolate::CAxisAlgorithmInterpolate(CAxis* axisDestination, CAxis* axisSource, CInterpolateAxis* interpAxis)
18: CAxisAlgorithmTransformation(axisDestination, axisSource)
19{
20  interpAxis->checkValid(axisSource);
21  order_ = interpAxis->order.getValue();
22  if (order_ >= axisSource->n_glo.getValue())
23  {
24    ERROR("CAxisAlgorithmInterpolate::CAxisAlgorithmInterpolate(CAxis* axisDestination, CAxis* axisSource, CInterpolateAxis* interpAxis)",
25           << "Order of interpolation is greater than global size of axis source"
26           << "Size of axis source " <<axisSource->getId() << " is " << axisSource->n_glo.getValue()  << std::endl
27           << "Order of interpolation is " << order_ );
28  }
29
30  computeIndexSourceMapping();
31}
32
33/*!
34  Compute the index mapping between axis on grid source and one on grid destination
35*/
36void CAxisAlgorithmInterpolate::computeIndexSourceMapping()
37{
38  CArray<double,1>& axisValue = axisSrc_->value;
39  CArray<bool,1>& axisMask = axisSrc_->mask;
40
41  CContext* context = CContext::getCurrent();
42  CContextClient* client=context->client;
43  int nbClient = client->clientSize;
44
45  int srcSize  = axisSrc_->n_glo.getValue();
46  int numValue = axisValue.numElements();
47
48  std::vector<double> recvBuff(srcSize);
49  std::vector<int> indexVec(srcSize);
50
51  retrieveAllAxisValue(recvBuff, indexVec);
52  XIOSAlgorithms::sortWithIndex<double, CVectorStorage>(recvBuff, indexVec);
53  computeInterpolantPoint(recvBuff, indexVec);
54}
55
56/*!
57  Compute the interpolant points
58  Assume that we have all value of axis source, with these values, need to calculate weight (coeff) of Lagrange polynomial
59  \param [in] axisValue all value of axis source
60  \param [in] indexVec permutation index of axisValue
61*/
62void CAxisAlgorithmInterpolate::computeInterpolantPoint(const std::vector<double>& axisValue, const std::vector<int>& indexVec)
63{
64  std::vector<double>::const_iterator itb = axisValue.begin(), ite = axisValue.end();
65  std::vector<double>::const_iterator itLowerBound, itUpperBound, it;
66  std::vector<int>::const_iterator itbVec = indexVec.begin(), itVec;
67  const double sfmax = NumTraits<double>::sfmax();
68
69  int ibegin = axisDest_->begin.getValue();
70  CArray<double,1>& axisDestValue = axisDest_->value;
71  int numValue = axisDestValue.numElements();
72  std::map<int, std::vector<std::pair<int,double> > > interpolatingIndexValues;
73
74  for (int idx = 0; idx < numValue; ++idx)
75  {
76    double destValue = axisDestValue(idx);
77    itLowerBound = std::lower_bound(itb, ite, destValue);
78    itUpperBound = std::upper_bound(itb, ite, destValue);
79    if ((ite != itUpperBound) && (sfmax == *itUpperBound)) itUpperBound = ite;
80
81
82    // If the value is not in the range, that means we'll do extra-polation
83    if (ite == itLowerBound) // extra-polation
84    {
85      itLowerBound = itb;
86      itUpperBound = itb + order_+1;
87    }
88    else if (ite == itUpperBound) // extra-polation
89    {
90      itLowerBound = itUpperBound - order_-1;
91    }
92    else
93    {
94      if (itb != itLowerBound) --itLowerBound;
95      if (ite != itUpperBound) ++itUpperBound;
96      int order = (order_ + 1) - 2;
97      bool down = true;
98      for (int k = 0; k < order; ++k)
99      {
100        if ((itb != itLowerBound) && down)
101        {
102          --itLowerBound;
103          down = false;
104          continue;
105        }
106        if ((ite != itUpperBound) && (sfmax != *itUpperBound))
107        {
108          ++itUpperBound;
109          down = true;
110        }
111      }
112    }
113
114    for (it = itLowerBound; it != itUpperBound; ++it)
115    {
116      int index = std::distance(itb, it);
117      interpolatingIndexValues[idx+ibegin].push_back(make_pair(indexVec[index],*it));
118    }
119  }
120  computeWeightedValueAndMapping(interpolatingIndexValues);
121}
122
123/*!
124  Compute weight (coeff) of Lagrange's polynomial
125  \param [in] interpolatingIndexValues the necessary axis value to calculate the coeffs
126*/
127void CAxisAlgorithmInterpolate::computeWeightedValueAndMapping(const std::map<int, std::vector<std::pair<int,double> > >& interpolatingIndexValues)
128{
129  std::map<int, std::vector<int> >& transMap = this->transformationMapping_;
130  std::map<int, std::vector<double> >& transWeight = this->transformationWeight_;
131  std::map<int, std::vector<std::pair<int,double> > >::const_iterator itb = interpolatingIndexValues.begin(), it,
132                                                                      ite = interpolatingIndexValues.end();
133  int ibegin = axisDest_->begin.getValue();
134  for (it = itb; it != ite; ++it)
135  {
136    int globalIndexDest = it->first;
137    double localValue = axisDest_->value(globalIndexDest - ibegin);
138    const std::vector<std::pair<int,double> >& interpVal = it->second;
139    int interpSize = interpVal.size();
140    for (int idx = 0; idx < interpSize; ++idx)
141    {
142      int index = interpVal[idx].first;
143      double weight = 1.0;
144
145      for (int k = 0; k < interpSize; ++k)
146      {
147        if (k == idx) continue;
148        weight *= (localValue - interpVal[k].second);
149        weight /= (interpVal[idx].second - interpVal[k].second);
150      }
151      transMap[globalIndexDest].push_back(index);
152      transWeight[globalIndexDest].push_back(weight);
153    }
154  }
155}
156
157/*!
158  Each client retrieves all values of an axis
159  \param [in/out] recvBuff buffer for receiving values (already allocated)
160  \param [in/out] indexVec mapping between values and global index of axis
161*/
162void CAxisAlgorithmInterpolate::retrieveAllAxisValue(std::vector<double>& recvBuff, std::vector<int>& indexVec)
163{
164  CArray<double,1>& axisValue = axisSrc_->value;
165  CArray<bool,1>& axisMask = axisSrc_->mask;
166
167  CContext* context = CContext::getCurrent();
168  CContextClient* client=context->client;
169  int nbClient = client->clientSize;
170
171  int srcSize  = axisSrc_->n_glo.getValue();
172  int numValue = axisValue.numElements();
173
174  if (srcSize == numValue)  // Only one client or axis not distributed
175  {
176    for (int idx = 0; idx < srcSize; ++idx)
177    {
178      if (axisMask(idx))
179      {
180        recvBuff[idx] = axisValue(idx);
181        indexVec[idx] = idx;
182      }
183      else recvBuff[idx] = NumTraits<double>::sfmax();
184    }
185
186  }
187  else // Axis distributed
188  {
189    double* sendValueBuff = new double [numValue];
190    int* sendIndexBuff = new int [numValue];
191    int* recvIndexBuff = new int [srcSize];
192
193    int ibegin = axisSrc_->begin.getValue();
194    for (int idx = 0; idx < numValue; ++idx)
195    {
196      if (axisMask(idx))
197      {
198        sendValueBuff[idx] = axisValue(idx);
199        sendIndexBuff[idx] = idx + ibegin;
200      }
201      else
202      {
203        sendValueBuff[idx] = NumTraits<double>::sfmax();
204        sendIndexBuff[idx] = -1;
205      }
206    }
207
208    int* recvCount=new int[nbClient];
209    MPI_Allgather(&numValue,1,MPI_INT,recvCount,1,MPI_INT,client->intraComm);
210
211    int* displ=new int[nbClient];
212    displ[0]=0 ;
213    for(int n=1;n<nbClient;n++) displ[n]=displ[n-1]+recvCount[n-1];
214
215    // Each client have enough global info of axis
216    MPI_Allgatherv(sendIndexBuff,numValue,MPI_INT,recvIndexBuff,recvCount,displ,MPI_INT,client->intraComm);
217    MPI_Allgatherv(sendValueBuff,numValue,MPI_DOUBLE,&(recvBuff[0]),recvCount,displ,MPI_DOUBLE,client->intraComm);
218
219    for (int idx = 0; idx < srcSize; ++idx)
220    {
221      indexVec[idx] = recvIndexBuff[idx];
222    }
223
224    delete [] displ;
225    delete [] recvCount;
226    delete [] recvIndexBuff;
227    delete [] sendIndexBuff;
228    delete [] sendValueBuff;
229  }
230}
231
232}
Note: See TracBrowser for help on using the repository browser.