source: XIOS/trunk/src/filter/spatial_transform_filter.cpp @ 832

Last change on this file since 832 was 832, checked in by mhnguyen, 8 years ago

The weight computation of a dynamic transformation is done only once for each timestamp

+) Each weight computation of a dynamic transformation is attached to a timestamp
+) Remove some redundant code

Test
+) On Curie
+) All tests pass

File size: 7.5 KB
Line 
1#include "spatial_transform_filter.hpp"
2#include "grid_transformation.hpp"
3#include "context.hpp"
4#include "context_client.hpp"
5
6namespace xios
7{
8  CSpatialTransformFilter::CSpatialTransformFilter(CGarbageCollector& gc, CSpatialTransformFilterEngine* engine, size_t inputSlotsCount)
9    : CFilter(gc, inputSlotsCount, engine)
10  { /* Nothing to do */ }
11
12  std::pair<boost::shared_ptr<CSpatialTransformFilter>, boost::shared_ptr<CSpatialTransformFilter> >
13  CSpatialTransformFilter::buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid)
14  {
15    if (!srcGrid || !destGrid)
16      ERROR("std::pair<boost::shared_ptr<CSpatialTransformFilter>, boost::shared_ptr<CSpatialTransformFilter> >"
17            "buildFilterGraph(CGarbageCollector& gc, CGrid* srcGrid, CGrid* destGrid)",
18            "Impossible to build the filter graph if either the source or the destination grid are null.");
19
20    boost::shared_ptr<CSpatialTransformFilter> firstFilter, lastFilter;
21    // Note that this loop goes from the last transformation to the first transformation
22    do
23    {
24      CGridTransformation* gridTransformation = destGrid->getTransformations();
25      CSpatialTransformFilterEngine* engine = CSpatialTransformFilterEngine::get(destGrid->getTransformations());
26      const std::vector<StdString>& auxInputs = gridTransformation->getAuxInputs();
27      size_t inputCount = 1 + (auxInputs.empty() ? 0 : auxInputs.size());
28      boost::shared_ptr<CSpatialTransformFilter> filter(new CSpatialTransformFilter(gc, engine, inputCount));
29
30      if (!lastFilter)
31        lastFilter = filter;
32      else
33        filter->connectOutput(firstFilter, 0);
34
35      firstFilter = filter;
36      for (size_t idx = 0; idx < auxInputs.size(); ++idx)
37      {
38        CField* fieldAuxInput = CField::get(auxInputs[idx]);
39        fieldAuxInput->buildFilterGraph(gc, false);
40        fieldAuxInput->getInstantDataFilter()->connectOutput(firstFilter,idx+1);
41      }
42
43      destGrid = gridTransformation->getGridSource();
44    }
45    while (destGrid != srcGrid);
46
47    return std::make_pair(firstFilter, lastFilter);
48  }
49
50  CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)
51    : gridTransformation(gridTransformation)
52  {
53    if (!gridTransformation)
54      ERROR("CSpatialTransformFilterEngine::CSpatialTransformFilterEngine(CGridTransformation* gridTransformation)",
55            "Impossible to construct a spatial transform filter engine without a valid grid transformation.");
56  }
57
58  std::map<CGridTransformation*, boost::shared_ptr<CSpatialTransformFilterEngine> > CSpatialTransformFilterEngine::engines;
59
60  CSpatialTransformFilterEngine* CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)
61  {
62    if (!gridTransformation)
63      ERROR("CSpatialTransformFilterEngine& CSpatialTransformFilterEngine::get(CGridTransformation* gridTransformation)",
64            "Impossible to get the requested engine, the grid transformation is invalid.");
65
66    std::map<CGridTransformation*, boost::shared_ptr<CSpatialTransformFilterEngine> >::iterator it = engines.find(gridTransformation);
67    if (it == engines.end())
68    {
69      boost::shared_ptr<CSpatialTransformFilterEngine> engine(new CSpatialTransformFilterEngine(gridTransformation));
70      it = engines.insert(std::make_pair(gridTransformation, engine)).first;
71    }
72
73    return it->second.get();
74  }
75
76  CDataPacketPtr CSpatialTransformFilterEngine::apply(std::vector<CDataPacketPtr> data)
77  {
78    CDataPacketPtr packet(new CDataPacket);
79    packet->date = data[0]->date;
80    packet->timestamp = data[0]->timestamp;
81    packet->status = data[0]->status;
82
83    if (packet->status == CDataPacket::NO_ERROR)
84    {
85      if (1 < data.size())  // Dynamical transformations
86      {
87        std::vector<CArray<double,1>* > dataAuxInputs(data.size()-1);
88        for (size_t idx = 0; idx < dataAuxInputs.size(); ++idx) dataAuxInputs[idx] = &(data[idx+1]->data);
89        gridTransformation->computeAll(dataAuxInputs, packet->timestamp);
90      }
91      packet->data.resize(gridTransformation->getGridDestination()->storeIndex_client.numElements());
92      apply(data[0]->data, packet->data);
93    }
94
95    return packet;
96  }
97
98  void CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)
99  {
100    CContextClient* client = CContext::getCurrent()->client;
101
102    const std::map<int, CArray<int,1> >& localIndexToSend = gridTransformation->getLocalIndexToSendFromGridSource();
103    const std::map<int, std::vector<std::vector<std::pair<int,double> > > >& localIndexToReceive = gridTransformation->getLocalIndexToReceiveOnGridDest();
104
105    dataDest = 0.0;
106
107    // Sending data from field sources to do transformations
108    std::map<int, CArray<int,1> >::const_iterator itbSend = localIndexToSend.begin(), itSend,
109                                                  iteSend = localIndexToSend.end();
110    int idxSendBuff = 0;
111    std::vector<double*> sendBuff(localIndexToSend.size());
112    for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
113    {
114      if (0 != itSend->second.numElements())
115        sendBuff[idxSendBuff] = new double[itSend->second.numElements()];
116    }
117
118    idxSendBuff = 0;
119    std::vector<MPI_Request> sendRequest;
120    for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
121    {
122      int destRank = itSend->first;
123      const CArray<int,1>& localIndex_p = itSend->second;
124      int countSize = localIndex_p.numElements();
125      for (int idx = 0; idx < countSize; ++idx)
126      {
127        sendBuff[idxSendBuff][idx] = dataSrc(localIndex_p(idx));
128      }
129      sendRequest.push_back(MPI_Request());
130      MPI_Isend(sendBuff[idxSendBuff], countSize, MPI_DOUBLE, destRank, 12, client->intraComm, &sendRequest.back());
131    }
132
133    // Receiving data on destination fields
134    std::map<int,std::vector<std::vector<std::pair<int,double> > > >::const_iterator itbRecv = localIndexToReceive.begin(), itRecv,
135                                                                                     iteRecv = localIndexToReceive.end();
136    int recvBuffSize = 0;
137    for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv) recvBuffSize = (recvBuffSize < itRecv->second.size())
138                                                                     ? itRecv->second.size() : recvBuffSize;
139    double* recvBuff;
140    if (0 != recvBuffSize) recvBuff = new double[recvBuffSize];
141    for (itRecv = itbRecv; itRecv != iteRecv; ++itRecv)
142    {
143      MPI_Status status;
144      int srcRank = itRecv->first;
145      int countSize = itRecv->second.size();
146      MPI_Recv(recvBuff, recvBuffSize, MPI_DOUBLE, srcRank, 12, client->intraComm, &status);
147      int countBuff = 0;
148      MPI_Get_count(&status, MPI_DOUBLE, &countBuff);
149      if (countBuff != countSize)
150        ERROR("CSpatialTransformFilterEngine::apply(const CArray<double, 1>& dataSrc, CArray<double,1>& dataDest)",
151              "Incoherent between the received size and expected size");
152      for (int idx = 0; idx < countSize; ++idx)
153      {
154        const std::vector<std::pair<int,double> >& localIndex_p = itRecv->second[idx];
155        int numIndex = localIndex_p.size();
156        for (int i = 0; i < numIndex; ++i)
157        {
158          dataDest(localIndex_p[i].first) += recvBuff[idx] * localIndex_p[i].second;
159        }
160      }
161    }
162
163
164    if (!sendRequest.empty()) MPI_Waitall(sendRequest.size(), &sendRequest[0], MPI_STATUSES_IGNORE);
165    idxSendBuff = 0;
166    for (itSend = itbSend; itSend != iteSend; ++itSend, ++idxSendBuff)
167    {
168      if (0 != itSend->second.numElements())
169        delete [] sendBuff[idxSendBuff];
170    }
171    if (0 != recvBuffSize) delete [] recvBuff;
172  }
173} // namespace xios
Note: See TracBrowser for help on using the repository browser.