source: XIOS/dev/XIOS_DEV_CMIP6/src/filter/spatial_transform_filter.cpp @ 1275

Last change on this file since 1275 was 1275, checked in by ymipsl, 3 years ago

Implement the diurnal cycle transformation as a grid transformation: scalar -> axis

YM

File size: 13.7 KB
Line 
1#include "spatial_transform_filter.hpp"
2#include "grid_transformation.hpp"
3#include "context.hpp"
4#include "context_client.hpp"
5
6namespace xios
7{
8  CSpatialTransformFilter::CSpatialTransformFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, double outputValue, size_t inputSlotsCount)
9    : CFilter(gc, inputSlotsCount, engine), outputDefaultValue(outputValue)
10  { /* Nothing to do */ }
11
12  std::pair<boost::shared_ptr<CSpatialTransformFilter>, boost::shared_ptr<CSpatialTransformFilter> >
13  CSpatialTransformFilter::buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid, bool hasMissingValue, double missingValue)
14  {
15    if (!srcGrid || !destGrid)
16      ERROR("std::pair<boost::shared_ptr<CSpatialTransformFilter>, boost::shared_ptr<CSpatialTransformFilter> >"
17            "buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid)",
18            "Impossible to build the filter graph if either the source or the destination grid are null.");
19
20    boost::shared_ptr<CSpatialTransformFilter> firstFilter, lastFilter;
21    // Note that this loop goes from the last transformation to the first transformation
22    do
23    {
24      CGridTransformation* gridTransformation = destGrid->getTransformations();
25      CSpatialTransformFilterEngine* engine = CSpatialTransformFilterEngine::get(destGrid->getTransformations());
26      const std::vector<StdString>& auxInputs = gridTransformation->getAuxInputs();
27      size_t inputCount = 1 + (auxInputs.empty() ? 0 : auxInputs.size());
28      double defaultValue  = (hasMissingValue) ? std::numeric_limits<double>::quiet_NaN() : 0.0;
29
30
31      const CGridTransformationSelector::ListAlgoType& algoList = gridTransformation->getAlgoList() ;
32      CGridTransformationSelector::ListAlgoType::const_iterator it  ;
33
34      bool isSpatialTemporal=false ;
35      for (it=algoList.begin();it!=algoList.end();++it)  if (it->second.first == TRANS_TEMPORAL_SPLITTING) isSpatialTemporal=true ;
36
37      boost::shared_ptr<CSpatialTransformFilter> filter ;
38      if( isSpatialTemporal) filter = boost::shared_ptr<CSpatialTransformFilter>(new CSpatialTemporalFilter(gc, engine, gridTransformation, defaultValue, inputCount));
39      else filter = boost::shared_ptr<CSpatialTransformFilter>(new CSpatialTransformFilter(gc, engine, defaultValue, inputCount));
40
41     
42      if (!lastFilter)
43        lastFilter = filter;
44      else
45        filter->connectOutput(firstFilter, 0);
46
47      firstFilter = filter;
48      for (size_t idx = 0; idx < auxInputs.size(); ++idx)
49      {
50        CField* fieldAuxInput = CField::get(auxInputs[idx]);
51        fieldAuxInput->buildFilterGraph(gc, false);
52        fieldAuxInput->getInstantDataFilter()->connectOutput(firstFilter,idx+1);
53      }
54
55      destGrid = gridTransformation->getGridSource();
56    }
57    while (destGrid != srcGrid);
58
59    return std::make_pair(firstFilter, lastFilter);
60  }
61
62  void CSpatialTransformFilter::onInputReady(std::vector<CDataPacketPtr> data)
63  {
64    CSpatialTransformFilterEngine* spaceFilter = static_cast<CSpatialTransformFilterEngine*>(engine);
65    CDataPacketPtr outputPacket = spaceFilter->applyFilter(data, outputDefaultValue);
66    if (outputPacket)
67      onOutputReady(outputPacket);
68  }
69
70
71
72
73
74  CSpatialTemporalFilter::CSpatialTemporalFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, CGridTransformation* gridTransformation, double outputValue, size_t inputSlotsCount)
75    : CSpatialTransformFilter(gc, engine, outputValue, inputSlotsCount), record(0)
76  {
77      const CGridTransformationSelector::ListAlgoType& algoList = gridTransformation->getAlgoList() ;
78      CGridTransformationSelector::ListAlgoType::const_iterator it  ;
79
80      int pos ;
81      for (it=algoList.begin();it!=algoList.end();++it) 
82        if (it->second.first == TRANS_TEMPORAL_SPLITTING)
83        {
84          pos=it->first ;
85          if (pos < algoList.size()-1)
86            ERROR("SpatialTemporalFilter::CSpatialTemporalFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, CGridTransformation* gridTransformation, double outputValue, size_t inputSlotsCount))",
87                  "temporal splitting operation must be the last of whole transformation on same grid") ;
88        }
89         
90      CGrid* grid=gridTransformation->getGridDestination() ;
91
92      CAxis* axis = grid->getAxis(gridTransformation->getElementPositionInGridDst2AxisPosition().find(pos)->second) ;
93
94      nrecords = axis->index.numElements() ;
95  }
96
97
98  void CSpatialTemporalFilter::onInputReady(std::vector<CDataPacketPtr> data)
99  {
100    CSpatialTransformFilterEngine* spaceFilter = static_cast<CSpatialTransformFilterEngine*>(engine);
101    CDataPacketPtr outputPacket = spaceFilter->applyFilter(data, outputDefaultValue);
102
103    if (outputPacket)
104    {
105      size_t nelements=outputPacket->data.numElements() ;
106      if (!tmpData.numElements())
107      {
108        tmpData.resize(nelements);
109        tmpData=outputDefaultValue ;
110      }
111
112      nelements/=nrecords ;
113      size_t offset=nelements*record ;
114      for(size_t i=0;i<nelements;++i)  tmpData(i+offset) = outputPacket->data(i) ;
115   
116      record ++ ;
117      if (record==nrecords)
118      {
119        record=0 ;
120        CDataPacketPtr packet = CDataPacketPtr(new CDataPacket);
121        packet->date = data[0]->date;
122        packet->timestamp = data[0]->timestamp;
123        packet->status = data[0]->status;
124        packet->data.resize(tmpData.numElements());
125        packet->data = tmpData;
126        onOutputReady(packet);
127        tmpData.resize(0) ;
128      }
129    }
130  }
131
132
133  CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)
134    : gridTransformation(gridTransformation)
135  {
136    if (!gridTransformation)
137      ERROR("CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)",
138            "Impossible to construct a spatial transform filter engine without a valid grid transformation.");
139  }
140
141  std::map<CGridTransformation*, boost::shared_ptr<CSpatialTransformFilterEngine> > CSpatialTransformFilterEngine::engines;
142
143  CSpatialTransformFilterEngine* CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)
144  {
145    if (!gridTransformation)
146      ERROR("CSpatialTransformFilterEngine& CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)",
147            "Impossible to get the requested engine, the grid transformation is invalid.");
148
149    std::map<CGridTransformation*, boost::shared_ptr<CSpatialTransformFilterEngine> >::iterator it = engines.find(gridTransformation);
150    if (it == engines.end())
151    {
152      boost::shared_ptr<CSpatialTransformFilterEngine> engine(new CSpatialTransformFilterEngine(gridTransformation));
153      it = engines.insert(std::make_pair(gridTransformation, engine)).first;
154    }
155
156    return it->second.get();
157  }
158
159  CDataPacketPtr CSpatialTransformFilterEngine::apply(std::vector<CDataPacketPtr> data)
160  {
161    /* Nothing to do */
162  }
163
164  CDataPacketPtr CSpatialTransformFilterEngine::applyFilter(std::vector<CDataPacketPtr> data, double defaultValue)
165  {
166    CDataPacketPtr packet(new CDataPacket);
167    packet->date = data[0]->date;
168    packet->timestamp = data[0]->timestamp;
169    packet->status = data[0]->status;
170
171    if (packet->status == CDataPacket::NO_ERROR)
172    {
173      if (1 < data.size())  // Dynamical transformations
174      {
175        std::vector<CArray<double,1>* > dataAuxInputs(data.size()-1);
176        for (size_t idx = 0; idx < dataAuxInputs.size(); ++idx) dataAuxInputs[idx] = &(data[idx+1]->data);
177        gridTransformation->computeAll(dataAuxInputs, packet->timestamp);
178      }
179      packet->data.resize(gridTransformation->getGridDestination()->storeIndex_client.numElements());
180      if (0 != packet->data.numElements())
181        (packet->data)(0) = defaultValue;
182      apply(data[0]->data, packet->data);
183    }
184
185    return packet;
186  }
187
188  void CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)
189  {
190    CContextClient* client = CContext::getCurrent()->client;
191
192    // Get default value for output data
193    bool ignoreMissingValue = false; 
194    double defaultValue = std::numeric_limits<double>::quiet_NaN();
195    if (0 != dataDest.numElements()) ignoreMissingValue = NumTraits<double>::isnan(dataDest(0));
196
197    const std::list<CGridTransformation::SendingIndexGridSourceMap>& listLocalIndexSend = gridTransformation->getLocalIndexToSendFromGridSource();
198    const std::list<CGridTransformation::RecvIndexGridDestinationMap>& listLocalIndexToReceive = gridTransformation->getLocalIndexToReceiveOnGridDest();
199    const std::list<size_t>& listNbLocalIndexToReceive = gridTransformation->getNbLocalIndexToReceiveOnGridDest();
200    const std::list<std::vector<bool> >& listLocalIndexMaskOnDest = gridTransformation->getLocalMaskIndexOnGridDest();
201    const std::vector<CGenericAlgorithmTransformation*>& listAlgos = gridTransformation->getAlgos();
202
203    CArray<double,1> dataCurrentDest(dataSrc.copy());
204
205    std::list<CGridTransformation::SendingIndexGridSourceMap>::const_iterator itListSend  = listLocalIndexSend.begin(),
206                                                                              iteListSend = listLocalIndexSend.end();
207    std::list<CGridTransformation::RecvIndexGridDestinationMap>::const_iterator itListRecv = listLocalIndexToReceive.begin();
208    std::list<size_t>::const_iterator itNbListRecv = listNbLocalIndexToReceive.begin();
209    std::list<std::vector<bool> >::const_iterator itLocalMaskIndexOnDest = listLocalIndexMaskOnDest.begin();
210    std::vector<CGenericAlgorithmTransformation*>::const_iterator itAlgo = listAlgos.begin();
211
212    for (; itListSend != iteListSend; ++itListSend, ++itListRecv, ++itNbListRecv, ++itLocalMaskIndexOnDest, ++itAlgo)
213    {
214      CArray<double,1> dataCurrentSrc(dataCurrentDest);
215      const CGridTransformation::SendingIndexGridSourceMap& localIndexToSend = *itListSend;
216
217      // Sending data from field sources to do transformations
218      std::map<int, CArray<int,1> >::const_iterator itbSend = localIndexToSend.begin(), itSend,
219                                                    iteSend = localIndexToSend.end();
220      int idxSendBuff = 0;
221      std::vector<double*> sendBuff(localIndexToSend.size());
222      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
223      {
224        if (0 != itSend->second.numElements())
225          sendBuff[idxSendBuff] = new double[itSend->second.numElements()];
226      }
227
228      idxSendBuff = 0;
229      std::vector<MPI_Request> sendRecvRequest;
230      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
231      {
232        int destRank = itSend->first;
233        const CArray<int,1>& localIndex_p = itSend->second;
234        int countSize = localIndex_p.numElements();
235        for (int idx = 0; idx < countSize; ++idx)
236        {
237          sendBuff[idxSendBuff][idx] = dataCurrentSrc(localIndex_p(idx));
238        }
239        sendRecvRequest.push_back(MPI_Request());
240        MPI_Isend(sendBuff[idxSendBuff], countSize, MPI_DOUBLE, destRank, 12, client->intraComm, &sendRecvRequest.back());
241      }
242
243      // Receiving data on destination fields
244      const CGridTransformation::RecvIndexGridDestinationMap& localIndexToReceive = *itListRecv;
245      CGridTransformation::RecvIndexGridDestinationMap::const_iterator itbRecv = localIndexToReceive.begin(), itRecv,
246                                                                       iteRecv = localIndexToReceive.end();
247      int recvBuffSize = 0;
248      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv) recvBuffSize += itRecv->second.size(); //(recvBuffSize < itRecv->second.size())
249                                                                       //? itRecv->second.size() : recvBuffSize;
250      double* recvBuff;
251      if (0 != recvBuffSize) recvBuff = new double[recvBuffSize];
252      int currentBuff = 0;
253      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
254      {
255        int srcRank = itRecv->first;
256        int countSize = itRecv->second.size();
257        sendRecvRequest.push_back(MPI_Request());
258        MPI_Irecv(recvBuff + currentBuff, countSize, MPI_DOUBLE, srcRank, 12, client->intraComm, &sendRecvRequest.back());
259        currentBuff += countSize;
260      }
261      std::vector<MPI_Status> status(sendRecvRequest.size());
262      MPI_Waitall(sendRecvRequest.size(), &sendRecvRequest[0], &status[0]);
263
264      dataCurrentDest.resize(*itNbListRecv);
265      const std::vector<bool>& localMaskDest = *itLocalMaskIndexOnDest;
266      for (int i = 0; i < localMaskDest.size(); ++i)
267        if (localMaskDest[i]) dataCurrentDest(i) = 0.0;
268        else dataCurrentDest(i) = defaultValue;
269
270      std::vector<bool> localInitFlag(dataCurrentDest.numElements(), true);
271      currentBuff = 0;
272      bool firstPass=true; 
273      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
274      {
275        int countSize = itRecv->second.size();
276        const std::vector<std::pair<int,double> >& localIndex_p = itRecv->second;
277        (*itAlgo)->apply(localIndex_p,
278                         recvBuff+currentBuff,
279                         dataCurrentDest,
280                         localInitFlag,
281                         ignoreMissingValue,firstPass);
282
283        currentBuff += countSize;
284        firstPass=false ;
285      }
286
287      (*itAlgo)->updateData(dataCurrentDest);
288
289      idxSendBuff = 0;
290      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
291      {
292        if (0 != itSend->second.numElements())
293          delete [] sendBuff[idxSendBuff];
294      }
295      if (0 != recvBuffSize) delete [] recvBuff;
296    }
297    if (dataCurrentDest.numElements() != dataDest.numElements())
298    ERROR("CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)",
299          "Incoherent between the received size and expected size. " << std::endl
300          << "Expected size: " << dataDest.numElements() << std::endl
301          << "Received size: " << dataCurrentDest.numElements());
302
303    dataDest = dataCurrentDest;
304  }
305} // namespace xios
Note: See TracBrowser for help on using the repository browser.