source: XIOS/dev/dev_trunk_omp/src/filter/spatial_transform_filter.cpp @ 1730

Last change on this file since 1730 was 1689, checked in by yushan, 5 years ago

dev for graph. up to date with trunk at r1684

File size: 16.5 KB
Rev Line
[1601]1#include "mpi.hpp"
[644]2#include "spatial_transform_filter.hpp"
3#include "grid_transformation.hpp"
4#include "context.hpp"
5#include "context_client.hpp"
[1412]6#include "timer.hpp"
[1646]7#ifdef _usingEP
[1601]8using namespace ep_lib;
[1646]9#endif
[1668]10#include "workflow_graph.hpp"
[1680]11#include "file.hpp"
[644]12namespace xios
13{
[1689]14  CSpatialTransformFilter::CSpatialTransformFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, double outputValue, size_t inputSlotsCount)
[1679]15    : CFilter(gc, inputSlotsCount, engine), outputDefaultValue(outputValue)
[644]16  { /* Nothing to do */ }
17
[1542]18  std::pair<std::shared_ptr<CSpatialTransformFilter>, std::shared_ptr<CSpatialTransformFilter> >
[1679]19  CSpatialTransformFilter::buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid, bool hasMissingValue, double missingValue)
[644]20  {
21    if (!srcGrid || !destGrid)
[1542]22      ERROR("std::pair<std::shared_ptr<CSpatialTransformFilter>, std::shared_ptr<CSpatialTransformFilter> >"
[644]23            "buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid)",
24            "Impossible to build the filter graph if either the source or the destination grid are null.");
25
[1542]26    std::shared_ptr<CSpatialTransformFilter> firstFilter, lastFilter;
[790]27    // Note that this loop goes from the last transformation to the first transformation
28    do
29    {
30      CGridTransformation* gridTransformation = destGrid->getTransformations();
31      CSpatialTransformFilterEngine* engine = CSpatialTransformFilterEngine::get(destGrid->getTransformations());
[827]32      const std::vector<StdString>& auxInputs = gridTransformation->getAuxInputs();
33      size_t inputCount = 1 + (auxInputs.empty() ? 0 : auxInputs.size());
[1158]34      double defaultValue  = (hasMissingValue) ? std::numeric_limits<double>::quiet_NaN() : 0.0;
[644]35
[1275]36
37      const CGridTransformationSelector::ListAlgoType& algoList = gridTransformation->getAlgoList() ;
38      CGridTransformationSelector::ListAlgoType::const_iterator it  ;
39
40      bool isSpatialTemporal=false ;
41      for (it=algoList.begin();it!=algoList.end();++it)  if (it->second.first == TRANS_TEMPORAL_SPLITTING) isSpatialTemporal=true ;
42
[1542]43      std::shared_ptr<CSpatialTransformFilter> filter ;
[1689]44      if( isSpatialTemporal) filter = std::shared_ptr<CSpatialTransformFilter>(new CSpatialTemporalFilter(gc, engine, gridTransformation, defaultValue, inputCount));
45      else filter = std::shared_ptr<CSpatialTransformFilter>(new CSpatialTransformFilter(gc, engine, defaultValue, inputCount));
[1275]46
47     
[790]48      if (!lastFilter)
49        lastFilter = filter;
50      else
51        filter->connectOutput(firstFilter, 0);
52
53      firstFilter = filter;
[827]54      for (size_t idx = 0; idx < auxInputs.size(); ++idx)
55      {
56        CField* fieldAuxInput = CField::get(auxInputs[idx]);
57        fieldAuxInput->buildFilterGraph(gc, false);
58        fieldAuxInput->getInstantDataFilter()->connectOutput(firstFilter,idx+1);
59      }
60
[790]61      destGrid = gridTransformation->getGridSource();
62    }
63    while (destGrid != srcGrid);
64
65    return std::make_pair(firstFilter, lastFilter);
[644]66  }
67
[873]68  void CSpatialTransformFilter::onInputReady(std::vector<CDataPacketPtr> data)
69  {
70    CSpatialTransformFilterEngine* spaceFilter = static_cast<CSpatialTransformFilterEngine*>(engine);
[1685]71    CDataPacketPtr outputPacket = spaceFilter->applyFilter(data, outputDefaultValue, this->tag, this->start_graph, this->end_graph, this->field);
[873]72    if (outputPacket)
[1021]73      onOutputReady(outputPacket);
[873]74  }
75
[1689]76  CSpatialTemporalFilter::CSpatialTemporalFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, CGridTransformation* gridTransformation, double outputValue, size_t inputSlotsCount)
[1679]77    : CSpatialTransformFilter(gc, engine, outputValue, inputSlotsCount), record(0)
[1275]78  {
79      const CGridTransformationSelector::ListAlgoType& algoList = gridTransformation->getAlgoList() ;
80      CGridTransformationSelector::ListAlgoType::const_iterator it  ;
81
82      int pos ;
83      for (it=algoList.begin();it!=algoList.end();++it) 
84        if (it->second.first == TRANS_TEMPORAL_SPLITTING)
85        {
86          pos=it->first ;
87          if (pos < algoList.size()-1)
88            ERROR("SpatialTemporalFilter::CSpatialTemporalFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, CGridTransformation* gridTransformation, double outputValue, size_t inputSlotsCount))",
89                  "temporal splitting operation must be the last of whole transformation on same grid") ;
90        }
91         
92      CGrid* grid=gridTransformation->getGridDestination() ;
93
94      CAxis* axis = grid->getAxis(gridTransformation->getElementPositionInGridDst2AxisPosition().find(pos)->second) ;
95
96      nrecords = axis->index.numElements() ;
97  }
98
99
100  void CSpatialTemporalFilter::onInputReady(std::vector<CDataPacketPtr> data)
101  {
102    CSpatialTransformFilterEngine* spaceFilter = static_cast<CSpatialTransformFilterEngine*>(engine);
[1685]103    CDataPacketPtr outputPacket = spaceFilter->applyFilter(data, outputDefaultValue, this->tag, this->start_graph, this->end_graph, this->field);
[1275]104
105    if (outputPacket)
106    {
107      size_t nelements=outputPacket->data.numElements() ;
108      if (!tmpData.numElements())
109      {
110        tmpData.resize(nelements);
111        tmpData=outputDefaultValue ;
112      }
113
114      nelements/=nrecords ;
115      size_t offset=nelements*record ;
116      for(size_t i=0;i<nelements;++i)  tmpData(i+offset) = outputPacket->data(i) ;
117   
118      record ++ ;
119      if (record==nrecords)
120      {
121        record=0 ;
122        CDataPacketPtr packet = CDataPacketPtr(new CDataPacket);
123        packet->date = data[0]->date;
124        packet->timestamp = data[0]->timestamp;
125        packet->status = data[0]->status;
126        packet->data.resize(tmpData.numElements());
127        packet->data = tmpData;
[1679]128        packet->field = this->field;
[1275]129        onOutputReady(packet);
130        tmpData.resize(0) ;
131      }
132    }
133  }
134
135
[644]136  CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)
137    : gridTransformation(gridTransformation)
138  {
139    if (!gridTransformation)
140      ERROR("CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)",
141            "Impossible to construct a spatial transform filter engine without a valid grid transformation.");
142  }
143
[1601]144  std::map<CGridTransformation*, std::shared_ptr<CSpatialTransformFilterEngine> > *CSpatialTransformFilterEngine::engines_ptr = 0;
[644]145
146  CSpatialTransformFilterEngine* CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)
147  {
148    if (!gridTransformation)
149      ERROR("CSpatialTransformFilterEngine& CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)",
150            "Impossible to get the requested engine, the grid transformation is invalid.");
[1601]151    if(engines_ptr == NULL) engines_ptr = new std::map<CGridTransformation*, std::shared_ptr<CSpatialTransformFilterEngine> >;
[644]152
[1601]153    std::map<CGridTransformation*, std::shared_ptr<CSpatialTransformFilterEngine> >::iterator it = engines_ptr->find(gridTransformation);
154    if (it == engines_ptr->end())
[644]155    {
[1542]156      std::shared_ptr<CSpatialTransformFilterEngine> engine(new CSpatialTransformFilterEngine(gridTransformation));
[1601]157      it = engines_ptr->insert(std::make_pair(gridTransformation, engine)).first;
[644]158    }
159
160    return it->second.get();
161  }
162
163  CDataPacketPtr CSpatialTransformFilterEngine::apply(std::vector<CDataPacketPtr> data)
164  {
[873]165    /* Nothing to do */
166  }
167
[1685]168  bool CSpatialTransformFilterEngine::buildGraph(std::vector<CDataPacketPtr> data, int tag, Time start_graph, Time end_graph, CField *field)
[873]169  {
[1679]170    bool building_graph = tag ? data[0]->timestamp >= start_graph && data[0]->timestamp <= end_graph : false;
171    if(building_graph)
[1677]172    {
173      this->filterID = InvalidableObject::filterIdGenerator++;
[1680]174      int edgeID = InvalidableObject::edgeIdGenerator++;   
[1677]175
[1680]176      CWorkflowGraph::allocNodeEdge();
[1677]177
[1680]178      CWorkflowGraph::addNode(this->filterID, "Spatial Transform Filter", 4, 1, 1, data[0]);
[1681]179      (*CWorkflowGraph::mapFilters_ptr_with_info)[this->filterID].distance = data[0]->distance+1;
180      (*CWorkflowGraph::mapFilters_ptr_with_info)[this->filterID].attributes = field->record4graphXiosAttributes();
181      if(field->file) (*CWorkflowGraph::mapFilters_ptr_with_info)[this->filterID].attributes += "</br>file attributes : </br>" +field->file->record4graphXiosAttributes();
[1677]182
183
[1679]184      if(CWorkflowGraph::build_begin)
185      {
186        CWorkflowGraph::addEdge(edgeID, this->filterID, data[0]);
187
188        (*CWorkflowGraph::mapFilters_ptr_with_info)[data[0]->src_filterID].filter_filled = 0;
189      }
190      else CWorkflowGraph::build_begin = true;
[1677]191    }
192
[1681]193    return building_graph;
194  }
[1677]195
[1685]196  CDataPacketPtr CSpatialTransformFilterEngine::applyFilter(std::vector<CDataPacketPtr> data, double defaultValue, int tag, Time start_graph, Time end_graph, CField *field)
[1681]197  {
198   
[1685]199    bool BG = buildGraph(data, tag, start_graph, end_graph, field);
[1681]200
[644]201    CDataPacketPtr packet(new CDataPacket);
202    packet->date = data[0]->date;
203    packet->timestamp = data[0]->timestamp;
204    packet->status = data[0]->status;
205
206    if (packet->status == CDataPacket::NO_ERROR)
207    {
[827]208      if (1 < data.size())  // Dynamical transformations
209      {
210        std::vector<CArray<double,1>* > dataAuxInputs(data.size()-1);
211        for (size_t idx = 0; idx < dataAuxInputs.size(); ++idx) dataAuxInputs[idx] = &(data[idx+1]->data);
[832]212        gridTransformation->computeAll(dataAuxInputs, packet->timestamp);
[827]213      }
[644]214      packet->data.resize(gridTransformation->getGridDestination()->storeIndex_client.numElements());
[1158]215      if (0 != packet->data.numElements())
216        (packet->data)(0) = defaultValue;
[1681]217      if(BG) apply(data[0]->data, packet->data, this->filterID);
[1680]218      else apply(data[0]->data, packet->data);
[644]219    }
220
[1681]221    if(BG) packet->src_filterID=this->filterID;
222    if(BG) packet->distance=data[0]->distance+1;
[1679]223    packet->field = field;
[1677]224
[644]225    return packet;
226  }
227
[1679]228  void CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest, int filterID)
[644]229  {
[1414]230    CTimer::get("CSpatialTransformFilterEngine::apply").resume(); 
231   
[644]232    CContextClient* client = CContext::getCurrent()->client;
[1646]233    int rank;
234    MPI_Comm_rank (client->intraComm, &rank);
[644]235
[873]236    // Get default value for output data
[1158]237    bool ignoreMissingValue = false; 
238    double defaultValue = std::numeric_limits<double>::quiet_NaN();
[1474]239    if (0 != dataDest.numElements()) ignoreMissingValue = NumTraits<double>::isNan(dataDest(0));
[873]240
[841]241    const std::list<CGridTransformation::SendingIndexGridSourceMap>& listLocalIndexSend = gridTransformation->getLocalIndexToSendFromGridSource();
242    const std::list<CGridTransformation::RecvIndexGridDestinationMap>& listLocalIndexToReceive = gridTransformation->getLocalIndexToReceiveOnGridDest();
243    const std::list<size_t>& listNbLocalIndexToReceive = gridTransformation->getNbLocalIndexToReceiveOnGridDest();
[888]244    const std::vector<CGenericAlgorithmTransformation*>& listAlgos = gridTransformation->getAlgos();
[644]245
[841]246    CArray<double,1> dataCurrentDest(dataSrc.copy());
[644]247
[841]248    std::list<CGridTransformation::SendingIndexGridSourceMap>::const_iterator itListSend  = listLocalIndexSend.begin(),
249                                                                              iteListSend = listLocalIndexSend.end();
250    std::list<CGridTransformation::RecvIndexGridDestinationMap>::const_iterator itListRecv = listLocalIndexToReceive.begin();
251    std::list<size_t>::const_iterator itNbListRecv = listNbLocalIndexToReceive.begin();
[888]252    std::vector<CGenericAlgorithmTransformation*>::const_iterator itAlgo = listAlgos.begin();
[841]253
[1646]254    for (; itListSend != iteListSend; ++itListSend, ++itListRecv, ++itNbListRecv, ++itAlgo)
[709]255    {
[841]256      CArray<double,1> dataCurrentSrc(dataCurrentDest);
257      const CGridTransformation::SendingIndexGridSourceMap& localIndexToSend = *itListSend;
[709]258
[841]259      // Sending data from field sources to do transformations
260      std::map<int, CArray<int,1> >::const_iterator itbSend = localIndexToSend.begin(), itSend,
261                                                    iteSend = localIndexToSend.end();
262      int idxSendBuff = 0;
263      std::vector<double*> sendBuff(localIndexToSend.size());
[1646]264      double* sendBuffRank;
[841]265      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
[644]266      {
[1646]267        int destRank = itSend->first;
[841]268        if (0 != itSend->second.numElements())
[1646]269        {
270          if (rank != itSend->first)
271            sendBuff[idxSendBuff] = new double[itSend->second.numElements()];
272          else
273            sendBuffRank = new double[itSend->second.numElements()];
274        }
[644]275      }
276
[841]277      idxSendBuff = 0;
[1601]278      std::vector<MPI_Request> sendRecvRequest(localIndexToSend.size() + itListRecv->size());
279      int position = 0;
[841]280      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
[644]281      {
[841]282        int destRank = itSend->first;
283        const CArray<int,1>& localIndex_p = itSend->second;
284        int countSize = localIndex_p.numElements();
[1646]285        if (destRank != rank)
[644]286        {
[1646]287          for (int idx = 0; idx < countSize; ++idx)
288          {
289            sendBuff[idxSendBuff][idx] = dataCurrentSrc(localIndex_p(idx));
290          } 
291          MPI_Isend(sendBuff[idxSendBuff], countSize, MPI_DOUBLE, destRank, 12, client->intraComm, &sendRecvRequest[position++]);
292         
[644]293        }
[1646]294        else
295        {
296          for (int idx = 0; idx < countSize; ++idx)
297          {
298            sendBuffRank[idx] = dataCurrentSrc(localIndex_p(idx));
299          }
300        }
[644]301      }
302
[841]303      // Receiving data on destination fields
304      const CGridTransformation::RecvIndexGridDestinationMap& localIndexToReceive = *itListRecv;
305      CGridTransformation::RecvIndexGridDestinationMap::const_iterator itbRecv = localIndexToReceive.begin(), itRecv,
306                                                                       iteRecv = localIndexToReceive.end();
307      int recvBuffSize = 0;
[1646]308      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
309      {
310        if (itRecv->first != rank )
311          recvBuffSize += itRecv->second.size();
312      }
313      //(recvBuffSize < itRecv->second.size()) ? itRecv->second.size() : recvBuffSize;
[841]314      double* recvBuff;
[1646]315
[841]316      if (0 != recvBuffSize) recvBuff = new double[recvBuffSize];
317      int currentBuff = 0;
318      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
319      {
320        int srcRank = itRecv->first;
[1646]321        if (srcRank != rank)
322        {
323          int countSize = itRecv->second.size();
324          MPI_Irecv(recvBuff + currentBuff, countSize, MPI_DOUBLE, srcRank, 12, client->intraComm, &sendRecvRequest[position++]);
325          currentBuff += countSize;
326        }
[841]327      }
328      std::vector<MPI_Status> status(sendRecvRequest.size());
[1646]329      MPI_Waitall(position, &sendRecvRequest[0], &status[0]);
[709]330
[1646]331
332
[841]333      dataCurrentDest.resize(*itNbListRecv);
[1646]334      dataCurrentDest = 0.0;
[873]335
[1158]336      std::vector<bool> localInitFlag(dataCurrentDest.numElements(), true);
[841]337      currentBuff = 0;
[1260]338      bool firstPass=true; 
[841]339      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
340      {
[842]341        const std::vector<std::pair<int,double> >& localIndex_p = itRecv->second;
[1646]342        int srcRank = itRecv->first;
[1679]343
[1681]344        if(filterID >=0) // building_graph
[1680]345        {
346           (*CWorkflowGraph::mapFilters_ptr_with_info)[filterID].filter_name = (*itAlgo)->getName();
347        } 
[1646]348        if (srcRank != rank)
349        {
350          int countSize = itRecv->second.size();
351          (*itAlgo)->apply(localIndex_p,
352                           recvBuff+currentBuff,
353                           dataCurrentDest,
354                           localInitFlag,
355                           ignoreMissingValue,firstPass);
356          currentBuff += countSize;
357        }
358        else
359        {
360          (*itAlgo)->apply(localIndex_p,
361                           sendBuffRank,
362                           dataCurrentDest,
363                           localInitFlag,
364                           ignoreMissingValue,firstPass);
365        }
[888]366
[1260]367        firstPass=false ;
[841]368      }
369
[979]370      (*itAlgo)->updateData(dataCurrentDest);
371
[841]372      idxSendBuff = 0;
373      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
374      {
375        if (0 != itSend->second.numElements())
[1646]376        {
377          if (rank != itSend->first)
378            delete [] sendBuff[idxSendBuff];
379          else
380            delete [] sendBuffRank;
381        }
[841]382      }
383      if (0 != recvBuffSize) delete [] recvBuff;
[709]384    }
[841]385    if (dataCurrentDest.numElements() != dataDest.numElements())
386    ERROR("CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)",
[1021]387          "Incoherent between the received size and expected size. " << std::endl
388          << "Expected size: " << dataDest.numElements() << std::endl
389          << "Received size: " << dataCurrentDest.numElements());
[841]390
391    dataDest = dataCurrentDest;
[1412]392
[1414]393    CTimer::get("CSpatialTransformFilterEngine::apply").suspend() ;
[644]394  }
395} // namespace xios
Note: See TracBrowser for help on using the repository browser.