source: XIOS/dev/branch_openmp/src/filter/spatial_transform_filter.cpp @ 1533

Last change on this file since 1533 was 1482, checked in by yushan, 6 years ago

Branch EP merged with Dev_cmip6 @r1481

File size: 14.4 KB
RevLine 
[1328]1#include "mpi.hpp"
[644]2#include "spatial_transform_filter.hpp"
3#include "grid_transformation.hpp"
4#include "context.hpp"
5#include "context_client.hpp"
[1460]6#include "timer.hpp"
[1328]7using namespace ep_lib;
[644]8
9namespace xios
10{
[873]11  CSpatialTransformFilter::CSpatialTransformFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, double outputValue, size_t inputSlotsCount)
12    : CFilter(gc, inputSlotsCount, engine), outputDefaultValue(outputValue)
[644]13  { /* Nothing to do */ }
14
15  std::pair<boost::shared_ptr<CSpatialTransformFilter>, boost::shared_ptr<CSpatialTransformFilter> >
[1018]16  CSpatialTransformFilter::buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid, bool hasMissingValue, double missingValue)
[644]17  {
18    if (!srcGrid || !destGrid)
19      ERROR("std::pair<boost::shared_ptr<CSpatialTransformFilter>, boost::shared_ptr<CSpatialTransformFilter> >"
20            "buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid)",
21            "Impossible to build the filter graph if either the source or the destination grid are null.");
22
[790]23    boost::shared_ptr<CSpatialTransformFilter> firstFilter, lastFilter;
24    // Note that this loop goes from the last transformation to the first transformation
25    do
26    {
27      CGridTransformation* gridTransformation = destGrid->getTransformations();
28      CSpatialTransformFilterEngine* engine = CSpatialTransformFilterEngine::get(destGrid->getTransformations());
[827]29      const std::vector<StdString>& auxInputs = gridTransformation->getAuxInputs();
30      size_t inputCount = 1 + (auxInputs.empty() ? 0 : auxInputs.size());
[1076]31      double defaultValue  = (hasMissingValue) ? std::numeric_limits<double>::quiet_NaN() : 0.0;
[644]32
[1460]33
34      const CGridTransformationSelector::ListAlgoType& algoList = gridTransformation->getAlgoList() ;
35      CGridTransformationSelector::ListAlgoType::const_iterator it  ;
36
37      bool isSpatialTemporal=false ;
38      for (it=algoList.begin();it!=algoList.end();++it)  if (it->second.first == TRANS_TEMPORAL_SPLITTING) isSpatialTemporal=true ;
39
40      boost::shared_ptr<CSpatialTransformFilter> filter ;
41      if( isSpatialTemporal) filter = boost::shared_ptr<CSpatialTransformFilter>(new CSpatialTemporalFilter(gc, engine, gridTransformation, defaultValue, inputCount));
42      else filter = boost::shared_ptr<CSpatialTransformFilter>(new CSpatialTransformFilter(gc, engine, defaultValue, inputCount));
43
44     
[790]45      if (!lastFilter)
46        lastFilter = filter;
47      else
48        filter->connectOutput(firstFilter, 0);
49
50      firstFilter = filter;
[827]51      for (size_t idx = 0; idx < auxInputs.size(); ++idx)
52      {
53        CField* fieldAuxInput = CField::get(auxInputs[idx]);
54        fieldAuxInput->buildFilterGraph(gc, false);
55        fieldAuxInput->getInstantDataFilter()->connectOutput(firstFilter,idx+1);
56      }
57
[790]58      destGrid = gridTransformation->getGridSource();
59    }
60    while (destGrid != srcGrid);
61
62    return std::make_pair(firstFilter, lastFilter);
[644]63  }
64
[873]65  void CSpatialTransformFilter::onInputReady(std::vector<CDataPacketPtr> data)
66  {
67    CSpatialTransformFilterEngine* spaceFilter = static_cast<CSpatialTransformFilterEngine*>(engine);
68    CDataPacketPtr outputPacket = spaceFilter->applyFilter(data, outputDefaultValue);
69    if (outputPacket)
[1006]70      onOutputReady(outputPacket);
[873]71  }
72
[1460]73
74
75
76
77  CSpatialTemporalFilter::CSpatialTemporalFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, CGridTransformation* gridTransformation, double outputValue, size_t inputSlotsCount)
78    : CSpatialTransformFilter(gc, engine, outputValue, inputSlotsCount), record(0)
79  {
80      const CGridTransformationSelector::ListAlgoType& algoList = gridTransformation->getAlgoList() ;
81      CGridTransformationSelector::ListAlgoType::const_iterator it  ;
82
83      int pos ;
84      for (it=algoList.begin();it!=algoList.end();++it) 
85        if (it->second.first == TRANS_TEMPORAL_SPLITTING)
86        {
87          pos=it->first ;
88          if (pos < algoList.size()-1)
89            ERROR("SpatialTemporalFilter::CSpatialTemporalFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, CGridTransformation* gridTransformation, double outputValue, size_t inputSlotsCount))",
90                  "temporal splitting operation must be the last of whole transformation on same grid") ;
91        }
92         
93      CGrid* grid=gridTransformation->getGridDestination() ;
94
95      CAxis* axis = grid->getAxis(gridTransformation->getElementPositionInGridDst2AxisPosition().find(pos)->second) ;
96
97      nrecords = axis->index.numElements() ;
98  }
99
100
101  void CSpatialTemporalFilter::onInputReady(std::vector<CDataPacketPtr> data)
102  {
103    CSpatialTransformFilterEngine* spaceFilter = static_cast<CSpatialTransformFilterEngine*>(engine);
104    CDataPacketPtr outputPacket = spaceFilter->applyFilter(data, outputDefaultValue);
105
106    if (outputPacket)
107    {
108      size_t nelements=outputPacket->data.numElements() ;
109      if (!tmpData.numElements())
110      {
111        tmpData.resize(nelements);
112        tmpData=outputDefaultValue ;
113      }
114
115      nelements/=nrecords ;
116      size_t offset=nelements*record ;
117      for(size_t i=0;i<nelements;++i)  tmpData(i+offset) = outputPacket->data(i) ;
118   
119      record ++ ;
120      if (record==nrecords)
121      {
122        record=0 ;
123        CDataPacketPtr packet = CDataPacketPtr(new CDataPacket);
124        packet->date = data[0]->date;
125        packet->timestamp = data[0]->timestamp;
126        packet->status = data[0]->status;
127        packet->data.resize(tmpData.numElements());
128        packet->data = tmpData;
129        onOutputReady(packet);
130        tmpData.resize(0) ;
131      }
132    }
133  }
134
135
[644]136  CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)
137    : gridTransformation(gridTransformation)
138  {
139    if (!gridTransformation)
140      ERROR("CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)",
141            "Impossible to construct a spatial transform filter engine without a valid grid transformation.");
142  }
143
[1328]144  //std::map<CGridTransformation*, boost::shared_ptr<CSpatialTransformFilterEngine> > CSpatialTransformFilterEngine::engines;
[1134]145  std::map<CGridTransformation*, boost::shared_ptr<CSpatialTransformFilterEngine> > *CSpatialTransformFilterEngine::engines_ptr = 0;
[644]146
147  CSpatialTransformFilterEngine* CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)
148  {
149    if (!gridTransformation)
150      ERROR("CSpatialTransformFilterEngine& CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)",
151            "Impossible to get the requested engine, the grid transformation is invalid.");
[1328]152   
[1134]153    if(engines_ptr == NULL) engines_ptr = new std::map<CGridTransformation*, boost::shared_ptr<CSpatialTransformFilterEngine> >;
154
[1328]155    //std::map<CGridTransformation*, boost::shared_ptr<CSpatialTransformFilterEngine> >::iterator it = engines.find(gridTransformation);
[1134]156    std::map<CGridTransformation*, boost::shared_ptr<CSpatialTransformFilterEngine> >::iterator it = engines_ptr->find(gridTransformation);
[1328]157    //if (it == engines.end())
[1134]158    if (it == engines_ptr->end())
[644]159    {
160      boost::shared_ptr<CSpatialTransformFilterEngine> engine(new CSpatialTransformFilterEngine(gridTransformation));
[1328]161      //it = engines.insert(std::make_pair(gridTransformation, engine)).first;
[1134]162      it = engines_ptr->insert(std::make_pair(gridTransformation, engine)).first;
[644]163    }
164
165    return it->second.get();
166  }
167
168  CDataPacketPtr CSpatialTransformFilterEngine::apply(std::vector<CDataPacketPtr> data)
169  {
[873]170    /* Nothing to do */
171  }
172
173  CDataPacketPtr CSpatialTransformFilterEngine::applyFilter(std::vector<CDataPacketPtr> data, double defaultValue)
174  {
[644]175    CDataPacketPtr packet(new CDataPacket);
176    packet->date = data[0]->date;
177    packet->timestamp = data[0]->timestamp;
178    packet->status = data[0]->status;
179
180    if (packet->status == CDataPacket::NO_ERROR)
181    {
[827]182      if (1 < data.size())  // Dynamical transformations
183      {
184        std::vector<CArray<double,1>* > dataAuxInputs(data.size()-1);
185        for (size_t idx = 0; idx < dataAuxInputs.size(); ++idx) dataAuxInputs[idx] = &(data[idx+1]->data);
[832]186        gridTransformation->computeAll(dataAuxInputs, packet->timestamp);
[827]187      }
[644]188      packet->data.resize(gridTransformation->getGridDestination()->storeIndex_client.numElements());
[1018]189      if (0 != packet->data.numElements())
190        (packet->data)(0) = defaultValue;
[644]191      apply(data[0]->data, packet->data);
192    }
193
194    return packet;
195  }
196
197  void CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)
198  {
[1460]199    CTimer::get("CSpatialTransformFilterEngine::apply").resume(); 
200   
[644]201    CContextClient* client = CContext::getCurrent()->client;
202
[873]203    // Get default value for output data
[1076]204    bool ignoreMissingValue = false; 
205    double defaultValue = std::numeric_limits<double>::quiet_NaN();
[1482]206    if (0 != dataDest.numElements()) ignoreMissingValue = NumTraits<double>::isNan(dataDest(0));
[1328]207
[841]208    const std::list<CGridTransformation::SendingIndexGridSourceMap>& listLocalIndexSend = gridTransformation->getLocalIndexToSendFromGridSource();
209    const std::list<CGridTransformation::RecvIndexGridDestinationMap>& listLocalIndexToReceive = gridTransformation->getLocalIndexToReceiveOnGridDest();
210    const std::list<size_t>& listNbLocalIndexToReceive = gridTransformation->getNbLocalIndexToReceiveOnGridDest();
[873]211    const std::list<std::vector<bool> >& listLocalIndexMaskOnDest = gridTransformation->getLocalMaskIndexOnGridDest();
[888]212    const std::vector<CGenericAlgorithmTransformation*>& listAlgos = gridTransformation->getAlgos();
[644]213
[841]214    CArray<double,1> dataCurrentDest(dataSrc.copy());
[644]215
[1328]216    std::list<CGridTransformation::SendingIndexGridSourceMap>::const_iterator itListSend  = listLocalIndexSend.begin(),
217                                                                              iteListSend = listLocalIndexSend.end();
[841]218    std::list<CGridTransformation::RecvIndexGridDestinationMap>::const_iterator itListRecv = listLocalIndexToReceive.begin();
219    std::list<size_t>::const_iterator itNbListRecv = listNbLocalIndexToReceive.begin();
[873]220    std::list<std::vector<bool> >::const_iterator itLocalMaskIndexOnDest = listLocalIndexMaskOnDest.begin();
[888]221    std::vector<CGenericAlgorithmTransformation*>::const_iterator itAlgo = listAlgos.begin();
[841]222
[888]223    for (; itListSend != iteListSend; ++itListSend, ++itListRecv, ++itNbListRecv, ++itLocalMaskIndexOnDest, ++itAlgo)
[709]224    {
[841]225      CArray<double,1> dataCurrentSrc(dataCurrentDest);
226      const CGridTransformation::SendingIndexGridSourceMap& localIndexToSend = *itListSend;
[709]227
[841]228      // Sending data from field sources to do transformations
229      std::map<int, CArray<int,1> >::const_iterator itbSend = localIndexToSend.begin(), itSend,
230                                                    iteSend = localIndexToSend.end();
231      int idxSendBuff = 0;
232      std::vector<double*> sendBuff(localIndexToSend.size());
233      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
[644]234      {
[841]235        if (0 != itSend->second.numElements())
236          sendBuff[idxSendBuff] = new double[itSend->second.numElements()];
[644]237      }
238
[841]239      idxSendBuff = 0;
[1328]240      std::vector<MPI_Request> sendRecvRequest(localIndexToSend.size() + itListRecv->size());
[1203]241      int position = 0;
[841]242      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
[644]243      {
[841]244        int destRank = itSend->first;
245        const CArray<int,1>& localIndex_p = itSend->second;
246        int countSize = localIndex_p.numElements();
247        for (int idx = 0; idx < countSize; ++idx)
[644]248        {
[841]249          sendBuff[idxSendBuff][idx] = dataCurrentSrc(localIndex_p(idx));
[644]250        }
[1328]251        MPI_Isend(sendBuff[idxSendBuff], countSize, MPI_DOUBLE, destRank, 12, client->intraComm, &sendRecvRequest[position++]);
[644]252      }
253
[841]254      // Receiving data on destination fields
[1328]255      const CGridTransformation::RecvIndexGridDestinationMap& localIndexToReceive = *itListRecv;
256      CGridTransformation::RecvIndexGridDestinationMap::const_iterator itbRecv = localIndexToReceive.begin(), itRecv,
257                                                                       iteRecv = localIndexToReceive.end();
[841]258      int recvBuffSize = 0;
259      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv) recvBuffSize += itRecv->second.size(); //(recvBuffSize < itRecv->second.size())
260                                                                       //? itRecv->second.size() : recvBuffSize;
261      double* recvBuff;
262      if (0 != recvBuffSize) recvBuff = new double[recvBuffSize];
263      int currentBuff = 0;
264      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
265      {
266        int srcRank = itRecv->first;
267        int countSize = itRecv->second.size();
[1328]268        MPI_Irecv(recvBuff + currentBuff, countSize, MPI_DOUBLE, srcRank, 12, client->intraComm, &sendRecvRequest[position++]);
[841]269        currentBuff += countSize;
270      }
[1328]271      std::vector<MPI_Status> status(sendRecvRequest.size());
[841]272      MPI_Waitall(sendRecvRequest.size(), &sendRecvRequest[0], &status[0]);
[709]273
[841]274      dataCurrentDest.resize(*itNbListRecv);
[873]275      const std::vector<bool>& localMaskDest = *itLocalMaskIndexOnDest;
276      for (int i = 0; i < localMaskDest.size(); ++i)
277        if (localMaskDest[i]) dataCurrentDest(i) = 0.0;
278        else dataCurrentDest(i) = defaultValue;
279
[1042]280      std::vector<bool> localInitFlag(dataCurrentDest.numElements(), true);
[841]281      currentBuff = 0;
[1328]282      bool firstPass=true; 
[841]283      for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
284      {
285        int countSize = itRecv->second.size();
[842]286        const std::vector<std::pair<int,double> >& localIndex_p = itRecv->second;
[888]287        (*itAlgo)->apply(localIndex_p,
288                         recvBuff+currentBuff,
289                         dataCurrentDest,
[918]290                         localInitFlag,
[1328]291                         ignoreMissingValue,firstPass);
[888]292
[841]293        currentBuff += countSize;
[1328]294        firstPass=false ;
[841]295      }
296
[979]297      (*itAlgo)->updateData(dataCurrentDest);
298
[841]299      idxSendBuff = 0;
300      for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
301      {
302        if (0 != itSend->second.numElements())
303          delete [] sendBuff[idxSendBuff];
304      }
305      if (0 != recvBuffSize) delete [] recvBuff;
[709]306    }
[841]307    if (dataCurrentDest.numElements() != dataDest.numElements())
308    ERROR("CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)",
[1003]309          "Incoherent between the received size and expected size. " << std::endl
310          << "Expected size: " << dataDest.numElements() << std::endl
311          << "Received size: " << dataCurrentDest.numElements());
[841]312
313    dataDest = dataCurrentDest;
[1460]314
315    CTimer::get("CSpatialTransformFilterEngine::apply").suspend() ;
[644]316  }
[1328]317} // namespace xios
Note: See TracBrowser for help on using the repository browser.