source: XIOS/dev/dev_trunk_omp/src/filter/spatial_transform_filter.cpp @ 1668

Last change on this file since 1668 was 1668, checked in by yushan, 2 years ago

MARK: branch merged with trunk @1663. static graph OK with EP

File size: 15.8 KB
Line 
1#include "mpi.hpp"
2#include "spatial_transform_filter.hpp"
3#include "grid_transformation.hpp"
4#include "context.hpp"
5#include "context_client.hpp"
6#include "timer.hpp"
7#ifdef _usingEP
8using namespace ep_lib;
9#endif
10#include "workflow_graph.hpp"
11namespace xios
12{
13  CSpatialTransformFilter::CSpatialTransformFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine,
14                                                   double outputValue, size_t inputSlotsCount, bool buildWorkflowGraph /*= false*/)
15    : CFilter(gc, inputSlotsCount, engine, buildWorkflowGraph), outputDefaultValue(outputValue)
16  { /* Nothing to do */ }
17
18  std::pair<std::shared_ptr<CSpatialTransformFilter>, std::shared_ptr<CSpatialTransformFilter> >
19  CSpatialTransformFilter::buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid, bool hasMissingValue, double missingValue,
20                                            bool buildWorkflowGraph)
21  {
22    if (!srcGrid || !destGrid)
23      ERROR("std::pair<std::shared_ptr<CSpatialTransformFilter>, std::shared_ptr<CSpatialTransformFilter> >"
24            "buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid)",
25            "Impossible to build the filter graph if either the source or the destination grid are null.");
26
27    std::shared_ptr<CSpatialTransformFilter> firstFilter, lastFilter;
28    // Note that this loop goes from the last transformation to the first transformation
29    do
30    {
31      CGridTransformation* gridTransformation = destGrid->getTransformations();
32      CSpatialTransformFilterEngine* engine = CSpatialTransformFilterEngine::get(destGrid->getTransformations());
33      const std::vector<StdString>& auxInputs = gridTransformation->getAuxInputs();
34      size_t inputCount = 1 + (auxInputs.empty() ? 0 : auxInputs.size());
35      double defaultValue  = (hasMissingValue) ? std::numeric_limits<double>::quiet_NaN() : 0.0;
36
37
38      const CGridTransformationSelector::ListAlgoType& algoList = gridTransformation->getAlgoList() ;
39      CGridTransformationSelector::ListAlgoType::const_iterator it  ;
40
41      bool isSpatialTemporal=false ;
42      for (it=algoList.begin();it!=algoList.end();++it)  if (it->second.first == TRANS_TEMPORAL_SPLITTING) isSpatialTemporal=true ;
43
44      std::shared_ptr<CSpatialTransformFilter> filter ;
45      if( isSpatialTemporal)
46        filter = std::shared_ptr<CSpatialTransformFilter>(new CSpatialTemporalFilter(gc, engine, gridTransformation, defaultValue, inputCount, buildWorkflowGraph));
47      else
48        filter = std::shared_ptr<CSpatialTransformFilter>(new CSpatialTransformFilter(gc, engine, defaultValue, inputCount, buildWorkflowGraph));
49
50     
51      if (!lastFilter)
52        lastFilter = filter;
53      else
54      {
55        filter->connectOutput(firstFilter, 0);
56        if (buildWorkflowGraph)
57        {
58          if(CWorkflowGraph::mapFilters_ptr==0) CWorkflowGraph::mapFilters_ptr = new std::unordered_map <int, StdString>;
59          if(CWorkflowGraph::mapFieldToFilters_ptr==0) CWorkflowGraph::mapFieldToFilters_ptr = new std::unordered_map <StdString, vector <int> >;
60          int filterOut = (std::static_pointer_cast<COutputPin>(filter))->getFilterId();
61          int filterIn = (std::static_pointer_cast<COutputPin>(firstFilter))->getFilterId();
62          // PASS field's id here
63          (*CWorkflowGraph::mapFieldToFilters_ptr)["XXX"].push_back(filterOut);
64          (*CWorkflowGraph::mapFieldToFilters_ptr)["XXX"].push_back(filterIn);
65          (*CWorkflowGraph::mapFilters_ptr)[filterOut] = "Spatial transform filter";
66          (*CWorkflowGraph::mapFilters_ptr)[filterIn] = "Spatial transform filter";
67        }
68      }
69
70      firstFilter = filter;
71      for (size_t idx = 0; idx < auxInputs.size(); ++idx)
72      {
73        CField* fieldAuxInput = CField::get(auxInputs[idx]);
74        fieldAuxInput->buildFilterGraph(gc, false);
75        fieldAuxInput->getInstantDataFilter()->connectOutput(firstFilter,idx+1);
76      }
77
78      destGrid = gridTransformation->getGridSource();
79    }
80    while (destGrid != srcGrid);
81
82    return std::make_pair(firstFilter, lastFilter);
83  }
84
85  void CSpatialTransformFilter::onInputReady(std::vector<CDataPacketPtr> data)
86  {
87    CSpatialTransformFilterEngine* spaceFilter = static_cast<CSpatialTransformFilterEngine*>(engine);
88    CDataPacketPtr outputPacket = spaceFilter->applyFilter(data, outputDefaultValue);
89    if (outputPacket)
90      onOutputReady(outputPacket);
91  }
92
93  CSpatialTemporalFilter::CSpatialTemporalFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine,
94                                                  CGridTransformation* gridTransformation, double outputValue,
95                                                  size_t inputSlotsCount, bool buildWorkflowGraph)
96    : CSpatialTransformFilter(gc, engine, outputValue, inputSlotsCount, buildWorkflowGraph), record(0)
97  {
98      const CGridTransformationSelector::ListAlgoType& algoList = gridTransformation->getAlgoList() ;
99      CGridTransformationSelector::ListAlgoType::const_iterator it  ;
100
101      int pos ;
102      for (it=algoList.begin();it!=algoList.end();++it) 
103        if (it->second.first == TRANS_TEMPORAL_SPLITTING)
104        {
105          pos=it->first ;
106          if (pos < algoList.size()-1)
107            ERROR("SpatialTemporalFilter::CSpatialTemporalFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, CGridTransformation* gridTransformation, double outputValue, size_t inputSlotsCount))",
108                  "temporal splitting operation must be the last of whole transformation on same grid") ;
109        }
110         
111      CGrid* grid=gridTransformation->getGridDestination() ;
112
113      CAxis* axis = grid->getAxis(gridTransformation->getElementPositionInGridDst2AxisPosition().find(pos)->second) ;
114
115      nrecords = axis->index.numElements() ;
116  }
117
118
119  void CSpatialTemporalFilter::onInputReady(std::vector<CDataPacketPtr> data)
120  {
121    CSpatialTransformFilterEngine* spaceFilter = static_cast<CSpatialTransformFilterEngine*>(engine);
122    CDataPacketPtr outputPacket = spaceFilter->applyFilter(data, outputDefaultValue);
123
124    if (outputPacket)
125    {
126      size_t nelements=outputPacket->data.numElements() ;
127      if (!tmpData.numElements())
128      {
129        tmpData.resize(nelements);
130        tmpData=outputDefaultValue ;
131      }
132
133      nelements/=nrecords ;
134      size_t offset=nelements*record ;
135      for(size_t i=0;i<nelements;++i)  tmpData(i+offset) = outputPacket->data(i) ;
136   
137      record ++ ;
138      if (record==nrecords)
139      {
140        record=0 ;
141        CDataPacketPtr packet = CDataPacketPtr(new CDataPacket);
142        packet->date = data[0]->date;
143        packet->timestamp = data[0]->timestamp;
144        packet->status = data[0]->status;
145        packet->data.resize(tmpData.numElements());
146        packet->data = tmpData;
147        onOutputReady(packet);
148        tmpData.resize(0) ;
149      }
150    }
151  }
152
153
154  CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)
155    : gridTransformation(gridTransformation)
156  {
157    if (!gridTransformation)
158      ERROR("CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)",
159            "Impossible to construct a spatial transform filter engine without a valid grid transformation.");
160  }
161
162  std::map<CGridTransformation*, std::shared_ptr<CSpatialTransformFilterEngine> > *CSpatialTransformFilterEngine::engines_ptr = 0;
163
164  CSpatialTransformFilterEngine* CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)
165  {
166    if (!gridTransformation)
167      ERROR("CSpatialTransformFilterEngine& CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)",
168            "Impossible to get the requested engine, the grid transformation is invalid.");
169   
170    if(engines_ptr == NULL) engines_ptr = new std::map<CGridTransformation*, std::shared_ptr<CSpatialTransformFilterEngine> >;
171
172
173    std::map<CGridTransformation*, std::shared_ptr<CSpatialTransformFilterEngine> >::iterator it = engines_ptr->find(gridTransformation);
174    if (it == engines_ptr->end())
175    {
176      std::shared_ptr<CSpatialTransformFilterEngine> engine(new CSpatialTransformFilterEngine(gridTransformation));
177      it = engines_ptr->insert(std::make_pair(gridTransformation, engine)).first;
178    }
179
180    return it->second.get();
181  }
182
183  CDataPacketPtr CSpatialTransformFilterEngine::apply(std::vector<CDataPacketPtr> data)
184  {
185    /* Nothing to do */
186  }
187
188  CDataPacketPtr CSpatialTransformFilterEngine::applyFilter(std::vector<CDataPacketPtr> data, double defaultValue)
189  {
190    CDataPacketPtr packet(new CDataPacket);
191    packet->date = data[0]->date;
192    packet->timestamp = data[0]->timestamp;
193    packet->status = data[0]->status;
194
195    if (packet->status == CDataPacket::NO_ERROR)
196    {
197      if (1 < data.size())  // Dynamical transformations
198      {
199        std::vector<CArray<double,1>* > dataAuxInputs(data.size()-1);
200        for (size_t idx = 0; idx < dataAuxInputs.size(); ++idx) dataAuxInputs[idx] = &(data[idx+1]->data);
201        gridTransformation->computeAll(dataAuxInputs, packet->timestamp);
202      }
203      packet->data.resize(gridTransformation->getGridDestination()->storeIndex_client.numElements());
204      if (0 != packet->data.numElements())
205        (packet->data)(0) = defaultValue;
206      apply(data[0]->data, packet->data);
207    }
208
209    return packet;
210  }
211
212  void CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)
213  {
214    CTimer::get("CSpatialTransformFilterEngine::apply").resume(); 
215   
216    CContextClient* client = CContext::getCurrent()->client;
217    int rank;
218    MPI_Comm_rank (client->intraComm, &rank);
219
220    // Get default value for output data
221    bool ignoreMissingValue = false; 
222    double defaultValue = std::numeric_limits<double>::quiet_NaN();
223    if (0 != dataDest.numElements()) ignoreMissingValue = NumTraits<double>::isNan(dataDest(0));
224
225    const std::list<CGridTransformation::SendingIndexGridSourceMap>& listLocalIndexSend = gridTransformation->getLocalIndexToSendFromGridSource();
226    const std::list<CGridTransformation::RecvIndexGridDestinationMap>& listLocalIndexToReceive = gridTransformation->getLocalIndexToReceiveOnGridDest();
227    const std::list<size_t>& listNbLocalIndexToReceive = gridTransformation->getNbLocalIndexToReceiveOnGridDest();
228    const std::vector<CGenericAlgorithmTransformation*>& listAlgos = gridTransformation->getAlgos();
229
230    CArray<double,1> dataCurrentDest(dataSrc.copy());
231
232    std::list<CGridTransformation::SendingIndexGridSourceMap>::const_iterator itListSend  = listLocalIndexSend.begin(),
233                                                                              iteListSend = listLocalIndexSend.end();
234    std::list<CGridTransformation::RecvIndexGridDestinationMap>::const_iterator itListRecv = listLocalIndexToReceive.begin();
235    std::list<size_t>::const_iterator itNbListRecv = listNbLocalIndexToReceive.begin();
236    std::vector<CGenericAlgorithmTransformation*>::const_iterator itAlgo = listAlgos.begin();
237
238    for (; itListSend != iteListSend; ++itListSend, ++itListRecv, ++itNbListRecv, ++itAlgo)
239    {
240      CArray<double,1> dataCurrentSrc(dataCurrentDest);
241      const CGridTransformation::SendingIndexGridSourceMap& localIndexToSend = *itListSend;
242
243      // Sending data from field sources to do transformations
244      std::map<int, CArray<int,1> >::const_iterator itbSend = localIndexToSend.begin(), itSend,
245                                                    iteSend = localIndexToSend.end();
246      int idxSendBuff = 0;
247      std::vector<double*> sendBuff(localIndexToSend.size());
248      double* sendBuffRank;
249      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
250      {
251        int destRank = itSend->first;
252        if (0 != itSend->second.numElements())
253        {
254          if (rank != itSend->first)
255            sendBuff[idxSendBuff] = new double[itSend->second.numElements()];
256          else
257            sendBuffRank = new double[itSend->second.numElements()];
258        }
259      }
260
261      idxSendBuff = 0;
262      std::vector<MPI_Request> sendRecvRequest(localIndexToSend.size() + itListRecv->size());
263      int position = 0;
264      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
265      {
266        int destRank = itSend->first;
267        const CArray<int,1>& localIndex_p = itSend->second;
268        int countSize = localIndex_p.numElements();
269        if (destRank != rank)
270        {
271          for (int idx = 0; idx < countSize; ++idx)
272          {
273            sendBuff[idxSendBuff][idx] = dataCurrentSrc(localIndex_p(idx));
274          } 
275          MPI_Isend(sendBuff[idxSendBuff], countSize, MPI_DOUBLE, destRank, 12, client->intraComm, &sendRecvRequest[position++]);
276         
277        }
278        else
279        {
280          for (int idx = 0; idx < countSize; ++idx)
281          {
282            sendBuffRank[idx] = dataCurrentSrc(localIndex_p(idx));
283          }
284        }
285      }
286
287      // Receiving data on destination fields
288      const CGridTransformation::RecvIndexGridDestinationMap& localIndexToReceive = *itListRecv;
289      CGridTransformation::RecvIndexGridDestinationMap::const_iterator itbRecv = localIndexToReceive.begin(), itRecv,
290                                                                       iteRecv = localIndexToReceive.end();
291      int recvBuffSize = 0;
292      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
293      {
294        if (itRecv->first != rank )
295          recvBuffSize += itRecv->second.size();
296      }
297      //(recvBuffSize < itRecv->second.size()) ? itRecv->second.size() : recvBuffSize;
298      double* recvBuff;
299
300      if (0 != recvBuffSize) recvBuff = new double[recvBuffSize];
301      int currentBuff = 0;
302      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
303      {
304        int srcRank = itRecv->first;
305        if (srcRank != rank)
306        {
307          int countSize = itRecv->second.size();
308          MPI_Irecv(recvBuff + currentBuff, countSize, MPI_DOUBLE, srcRank, 12, client->intraComm, &sendRecvRequest[position++]);
309          currentBuff += countSize;
310        }
311      }
312      std::vector<MPI_Status> status(sendRecvRequest.size());
313      MPI_Waitall(position, &sendRecvRequest[0], &status[0]);
314
315
316
317      dataCurrentDest.resize(*itNbListRecv);
318      dataCurrentDest = 0.0;
319
320      std::vector<bool> localInitFlag(dataCurrentDest.numElements(), true);
321      currentBuff = 0;
322      bool firstPass=true; 
323      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
324      {
325        const std::vector<std::pair<int,double> >& localIndex_p = itRecv->second;
326        int srcRank = itRecv->first;
327        if (srcRank != rank)
328        {
329          int countSize = itRecv->second.size();
330          (*itAlgo)->apply(localIndex_p,
331                           recvBuff+currentBuff,
332                           dataCurrentDest,
333                           localInitFlag,
334                           ignoreMissingValue,firstPass);
335          currentBuff += countSize;
336        }
337        else
338        {
339          (*itAlgo)->apply(localIndex_p,
340                           sendBuffRank,
341                           dataCurrentDest,
342                           localInitFlag,
343                           ignoreMissingValue,firstPass);
344        }
345
346        firstPass=false ;
347      }
348
349      (*itAlgo)->updateData(dataCurrentDest);
350
351      idxSendBuff = 0;
352      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
353      {
354        if (0 != itSend->second.numElements())
355        {
356          if (rank != itSend->first)
357            delete [] sendBuff[idxSendBuff];
358          else
359            delete [] sendBuffRank;
360        }
361      }
362      if (0 != recvBuffSize) delete [] recvBuff;
363    }
364    if (dataCurrentDest.numElements() != dataDest.numElements())
365    ERROR("CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)",
366          "Incoherent between the received size and expected size. " << std::endl
367          << "Expected size: " << dataDest.numElements() << std::endl
368          << "Received size: " << dataCurrentDest.numElements());
369
370    dataDest = dataCurrentDest;
371
372    CTimer::get("CSpatialTransformFilterEngine::apply").suspend() ;
373  }
374} // namespace xios
Note: See TracBrowser for help on using the repository browser.