source: XIOS/dev/dev_ym/XIOS_COUPLING/src/filter/spatial_transform_filter.cpp @ 1794

Last change on this file since 1794 was 1794, checked in by ymipsl, 2 years ago
  • add some comment about grid map/array/indexes
  • Add some "_" to suffix data members of the class

YM

File size: 14.4 KB
Line 
1#include "spatial_transform_filter.hpp"
2#include "grid_transformation.hpp"
3#include "context.hpp"
4#include "context_client.hpp"
5#include "timer.hpp"
6
7namespace xios
8{
9  CSpatialTransformFilter::CSpatialTransformFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, double outputValue, size_t inputSlotsCount)
10    : CFilter(gc, inputSlotsCount, engine), outputDefaultValue(outputValue)
11  { /* Nothing to do */ }
12
13  std::pair<std::shared_ptr<CSpatialTransformFilter>, std::shared_ptr<CSpatialTransformFilter> >
14  CSpatialTransformFilter::buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid, bool hasMissingValue, double missingValue)
15  {
16    if (!srcGrid || !destGrid)
17      ERROR("std::pair<std::shared_ptr<CSpatialTransformFilter>, std::shared_ptr<CSpatialTransformFilter> >"
18            "buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid)",
19            "Impossible to build the filter graph if either the source or the destination grid are null.");
20
21    std::shared_ptr<CSpatialTransformFilter> firstFilter, lastFilter;
22    // Note that this loop goes from the last transformation to the first transformation
23    do
24    {
25      CGridTransformation* gridTransformation = destGrid->getTransformations();
26      CSpatialTransformFilterEngine* engine = CSpatialTransformFilterEngine::get(destGrid->getTransformations());
27      const std::vector<StdString>& auxInputs = gridTransformation->getAuxInputs();
28      size_t inputCount = 1 + (auxInputs.empty() ? 0 : auxInputs.size());
29      double defaultValue  = (hasMissingValue) ? std::numeric_limits<double>::quiet_NaN() : 0.0;
30
31
32      const CGridTransformationSelector::ListAlgoType& algoList = gridTransformation->getAlgoList() ;
33      CGridTransformationSelector::ListAlgoType::const_iterator it  ;
34
35      bool isSpatialTemporal=false ;
36      for (it=algoList.begin();it!=algoList.end();++it)  if (it->second.first == TRANS_TEMPORAL_SPLITTING) isSpatialTemporal=true ;
37
38      std::shared_ptr<CSpatialTransformFilter> filter ;
39      if( isSpatialTemporal) filter = std::shared_ptr<CSpatialTransformFilter>(new CSpatialTemporalFilter(gc, engine, gridTransformation, defaultValue, inputCount));
40      else filter = std::shared_ptr<CSpatialTransformFilter>(new CSpatialTransformFilter(gc, engine, defaultValue, inputCount));
41
42     
43      if (!lastFilter)
44        lastFilter = filter;
45      else
46        filter->connectOutput(firstFilter, 0);
47
48      firstFilter = filter;
49      for (size_t idx = 0; idx < auxInputs.size(); ++idx)
50      {
51        CField* fieldAuxInput = CField::get(auxInputs[idx]);
52        fieldAuxInput->buildFilterGraph(gc, false);
53        fieldAuxInput->getInstantDataFilter()->connectOutput(firstFilter,idx+1);
54      }
55
56      destGrid = gridTransformation->getGridSource();
57    }
58    while (destGrid != srcGrid);
59
60    return std::make_pair(firstFilter, lastFilter);
61  }
62
63  void CSpatialTransformFilter::onInputReady(std::vector<CDataPacketPtr> data)
64  {
65    CSpatialTransformFilterEngine* spaceFilter = static_cast<CSpatialTransformFilterEngine*>(engine);
66    CDataPacketPtr outputPacket = spaceFilter->applyFilter(data, outputDefaultValue);
67    if (outputPacket)
68      onOutputReady(outputPacket);
69  }
70
71  CSpatialTemporalFilter::CSpatialTemporalFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, CGridTransformation* gridTransformation, double outputValue, size_t inputSlotsCount)
72    : CSpatialTransformFilter(gc, engine, outputValue, inputSlotsCount), record(0)
73  {
74      const CGridTransformationSelector::ListAlgoType& algoList = gridTransformation->getAlgoList() ;
75      CGridTransformationSelector::ListAlgoType::const_iterator it  ;
76
77      int pos ;
78      for (it=algoList.begin();it!=algoList.end();++it) 
79        if (it->second.first == TRANS_TEMPORAL_SPLITTING)
80        {
81          pos=it->first ;
82          if (pos < algoList.size()-1)
83            ERROR("SpatialTemporalFilter::CSpatialTemporalFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, CGridTransformation* gridTransformation, double outputValue, size_t inputSlotsCount))",
84                  "temporal splitting operation must be the last of whole transformation on same grid") ;
85        }
86         
87      CGrid* grid=gridTransformation->getGridDestination() ;
88
89      CAxis* axis = grid->getAxis(gridTransformation->getElementPositionInGridDst2AxisPosition().find(pos)->second) ;
90
91      nrecords = axis->index.numElements() ;
92  }
93
94
95  void CSpatialTemporalFilter::onInputReady(std::vector<CDataPacketPtr> data)
96  {
97    CSpatialTransformFilterEngine* spaceFilter = static_cast<CSpatialTransformFilterEngine*>(engine);
98    CDataPacketPtr outputPacket = spaceFilter->applyFilter(data, outputDefaultValue);
99
100    if (outputPacket)
101    {
102      size_t nelements=outputPacket->data.numElements() ;
103      if (!tmpData.numElements())
104      {
105        tmpData.resize(nelements);
106        tmpData=outputDefaultValue ;
107      }
108
109      nelements/=nrecords ;
110      size_t offset=nelements*record ;
111      for(size_t i=0;i<nelements;++i)  tmpData(i+offset) = outputPacket->data(i) ;
112   
113      record ++ ;
114      if (record==nrecords)
115      {
116        record=0 ;
117        CDataPacketPtr packet = CDataPacketPtr(new CDataPacket);
118        packet->date = data[0]->date;
119        packet->timestamp = data[0]->timestamp;
120        packet->status = data[0]->status;
121        packet->data.resize(tmpData.numElements());
122        packet->data = tmpData;
123        onOutputReady(packet);
124        tmpData.resize(0) ;
125      }
126    }
127  }
128
129
130  CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)
131    : gridTransformation(gridTransformation)
132  {
133    if (!gridTransformation)
134      ERROR("CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)",
135            "Impossible to construct a spatial transform filter engine without a valid grid transformation.");
136  }
137
138  std::map<CGridTransformation*, std::shared_ptr<CSpatialTransformFilterEngine> > CSpatialTransformFilterEngine::engines;
139
140  CSpatialTransformFilterEngine* CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)
141  {
142    if (!gridTransformation)
143      ERROR("CSpatialTransformFilterEngine& CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)",
144            "Impossible to get the requested engine, the grid transformation is invalid.");
145
146    std::map<CGridTransformation*, std::shared_ptr<CSpatialTransformFilterEngine> >::iterator it = engines.find(gridTransformation);
147    if (it == engines.end())
148    {
149      std::shared_ptr<CSpatialTransformFilterEngine> engine(new CSpatialTransformFilterEngine(gridTransformation));
150      it = engines.insert(std::make_pair(gridTransformation, engine)).first;
151    }
152
153    return it->second.get();
154  }
155
156  CDataPacketPtr CSpatialTransformFilterEngine::apply(std::vector<CDataPacketPtr> data)
157  {
158    /* Nothing to do */
159  }
160
161  CDataPacketPtr CSpatialTransformFilterEngine::applyFilter(std::vector<CDataPacketPtr> data, double defaultValue)
162  {
163    CDataPacketPtr packet(new CDataPacket);
164    packet->date = data[0]->date;
165    packet->timestamp = data[0]->timestamp;
166    packet->status = data[0]->status;
167
168    if (packet->status == CDataPacket::NO_ERROR)
169    {
170      if (1 < data.size())  // Dynamical transformations
171      {
172        std::vector<CArray<double,1>* > dataAuxInputs(data.size()-1);
173        for (size_t idx = 0; idx < dataAuxInputs.size(); ++idx) dataAuxInputs[idx] = &(data[idx+1]->data);
174        gridTransformation->computeAll(dataAuxInputs, packet->timestamp);
175      }
176      packet->data.resize(gridTransformation->getGridDestination()->storeIndex_client_.numElements());
177      if (0 != packet->data.numElements())
178        (packet->data)(0) = defaultValue;
179      apply(data[0]->data, packet->data);
180    }
181
182    return packet;
183  }
184
185  void CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)
186  {
187    CTimer::get("CSpatialTransformFilterEngine::apply").resume(); 
188   
189    CContext* context = CContext::getCurrent();
190    int rank;
191    MPI_Comm_rank (context->intraComm_, &rank);
192
193    // Get default value for output data
194    bool ignoreMissingValue = false; 
195    double defaultValue = std::numeric_limits<double>::quiet_NaN();
196    if (0 != dataDest.numElements()) ignoreMissingValue = NumTraits<double>::isNan(dataDest(0));
197
198    const std::list<CGridTransformation::SendingIndexGridSourceMap>& listLocalIndexSend = gridTransformation->getLocalIndexToSendFromGridSource();
199    const std::list<CGridTransformation::RecvIndexGridDestinationMap>& listLocalIndexToReceive = gridTransformation->getLocalIndexToReceiveOnGridDest();
200    const std::list<size_t>& listNbLocalIndexToReceive = gridTransformation->getNbLocalIndexToReceiveOnGridDest();
201    const std::vector<CGenericAlgorithmTransformation*>& listAlgos = gridTransformation->getAlgos();
202
203    CArray<double,1> dataCurrentDest(dataSrc.copy());
204
205    std::list<CGridTransformation::SendingIndexGridSourceMap>::const_iterator itListSend  = listLocalIndexSend.begin(),
206                                                                              iteListSend = listLocalIndexSend.end();
207    std::list<CGridTransformation::RecvIndexGridDestinationMap>::const_iterator itListRecv = listLocalIndexToReceive.begin();
208    std::list<size_t>::const_iterator itNbListRecv = listNbLocalIndexToReceive.begin();
209    std::vector<CGenericAlgorithmTransformation*>::const_iterator itAlgo = listAlgos.begin();
210
211    for (; itListSend != iteListSend; ++itListSend, ++itListRecv, ++itNbListRecv, ++itAlgo)
212    {
213      CArray<double,1> dataCurrentSrc(dataCurrentDest);
214      const CGridTransformation::SendingIndexGridSourceMap& localIndexToSend = *itListSend;
215
216      // Sending data from field sources to do transformations
217      std::map<int, CArray<int,1> >::const_iterator itbSend = localIndexToSend.begin(), itSend,
218                                                    iteSend = localIndexToSend.end();
219      int idxSendBuff = 0;
220      std::vector<double*> sendBuff(localIndexToSend.size());
221      double* sendBuffRank;
222      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
223      {
224        int destRank = itSend->first;
225        if (0 != itSend->second.numElements())
226        {
227          if (rank != itSend->first)
228            sendBuff[idxSendBuff] = new double[itSend->second.numElements()];
229          else
230            sendBuffRank = new double[itSend->second.numElements()];
231        }
232      }
233
234      idxSendBuff = 0;
235      std::vector<MPI_Request> sendRecvRequest;
236      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
237      {
238        int destRank = itSend->first;
239        const CArray<int,1>& localIndex_p = itSend->second;
240        int countSize = localIndex_p.numElements();
241        if (destRank != rank)
242        {
243          for (int idx = 0; idx < countSize; ++idx)
244          {
245            sendBuff[idxSendBuff][idx] = dataCurrentSrc(localIndex_p(idx));
246          }
247          sendRecvRequest.push_back(MPI_Request());
248          MPI_Isend(sendBuff[idxSendBuff], countSize, MPI_DOUBLE, destRank, 12, context->intraComm_, &sendRecvRequest.back());
249        }
250        else
251        {
252          for (int idx = 0; idx < countSize; ++idx)
253          {
254            sendBuffRank[idx] = dataCurrentSrc(localIndex_p(idx));
255          }
256        }
257      }
258
259      // Receiving data on destination fields
260      const CGridTransformation::RecvIndexGridDestinationMap& localIndexToReceive = *itListRecv;
261      CGridTransformation::RecvIndexGridDestinationMap::const_iterator itbRecv = localIndexToReceive.begin(), itRecv,
262                                                                       iteRecv = localIndexToReceive.end();
263      int recvBuffSize = 0;
264      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
265      {
266        if (itRecv->first != rank )
267          recvBuffSize += itRecv->second.size();
268      }
269      //(recvBuffSize < itRecv->second.size()) ? itRecv->second.size() : recvBuffSize;
270      double* recvBuff;
271
272      if (0 != recvBuffSize) recvBuff = new double[recvBuffSize];
273      int currentBuff = 0;
274      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
275      {
276        int srcRank = itRecv->first;
277        if (srcRank != rank)
278        {
279          int countSize = itRecv->second.size();
280          sendRecvRequest.push_back(MPI_Request());
281          MPI_Irecv(recvBuff + currentBuff, countSize, MPI_DOUBLE, srcRank, 12, context->intraComm_, &sendRecvRequest.back());
282          currentBuff += countSize;
283        }
284      }
285      std::vector<MPI_Status> status(sendRecvRequest.size());
286      MPI_Waitall(sendRecvRequest.size(), &sendRecvRequest[0], &status[0]);
287
288      dataCurrentDest.resize(*itNbListRecv);
289      dataCurrentDest = 0.0;
290
291      std::vector<bool> localInitFlag(dataCurrentDest.numElements(), true);
292      currentBuff = 0;
293      bool firstPass=true; 
294      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
295      {
296        const std::vector<std::pair<int,double> >& localIndex_p = itRecv->second;
297        int srcRank = itRecv->first;
298        if (srcRank != rank)
299        {
300          int countSize = itRecv->second.size();
301          (*itAlgo)->apply(localIndex_p,
302                           recvBuff+currentBuff,
303                           dataCurrentDest,
304                           localInitFlag,
305                           ignoreMissingValue,firstPass);
306          currentBuff += countSize;
307        }
308        else
309        {
310          (*itAlgo)->apply(localIndex_p,
311                           sendBuffRank,
312                           dataCurrentDest,
313                           localInitFlag,
314                           ignoreMissingValue,firstPass);
315        }
316
317        firstPass=false ;
318      }
319
320      (*itAlgo)->updateData(dataCurrentDest);
321
322      idxSendBuff = 0;
323      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
324      {
325        if (0 != itSend->second.numElements())
326        {
327          if (rank != itSend->first)
328            delete [] sendBuff[idxSendBuff];
329          else
330            delete [] sendBuffRank;
331        }
332      }
333      if (0 != recvBuffSize) delete [] recvBuff;
334    }
335    if (dataCurrentDest.numElements() != dataDest.numElements())
336    ERROR("CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)",
337          "Incoherent between the received size and expected size. " << std::endl
338          << "Expected size: " << dataDest.numElements() << std::endl
339          << "Received size: " << dataCurrentDest.numElements());
340
341    dataDest = dataCurrentDest;
342
343    CTimer::get("CSpatialTransformFilterEngine::apply").suspend() ;
344  }
345} // namespace xios
Note: See TracBrowser for help on using the repository browser.