source: XIOS/dev/dev_ym/XIOS_COUPLING/src/node/scalar.cpp @ 2022

Last change on this file since 2022 was 2022, checked in by ymipsl, 3 years ago

Reimplement coupling in the new infrastructure.
Tested for 2-way coupling toy model.

YM

File size: 17.0 KB
Line 
1#include "scalar.hpp"
2
3#include "attribute_template.hpp"
4#include "object_template.hpp"
5#include "group_template.hpp"
6#include "object_factory.hpp"
7#include "xios_spl.hpp"
8#include "type.hpp"
9#include "context.hpp"
10
11namespace xios
12{
13
14  /// ////////////////////// Définitions ////////////////////// ///
15
16  CScalar::CScalar(void)
17     : CObjectTemplate<CScalar>()
18     , CScalarAttributes()
19     , relFiles()
20  { /* Ne rien faire de plus */ }
21
22  CScalar::CScalar(const StdString & id)
23     : CObjectTemplate<CScalar>(id)
24     , CScalarAttributes()
25     , relFiles()
26  { /* Ne rien faire de plus */ }
27
28  CScalar::~CScalar(void)
29  { /* Ne rien faire de plus */ }
30
31  std::map<StdString, ETranformationType> CScalar::transformationMapList_ = std::map<StdString, ETranformationType>();
32  bool CScalar::dummyTransformationMapList_ = CScalar::initializeTransformationMap(CScalar::transformationMapList_);
33  bool CScalar::initializeTransformationMap(std::map<StdString, ETranformationType>& m)
34  {
35    m["reduce_axis"]   = TRANS_REDUCE_AXIS_TO_SCALAR;
36    m["extract_axis"]  = TRANS_EXTRACT_AXIS_TO_SCALAR;
37    m["reduce_domain"] = TRANS_REDUCE_DOMAIN_TO_SCALAR;
38    m["reduce_scalar"] = TRANS_REDUCE_SCALAR_TO_SCALAR;
39  }
40
41  StdString CScalar::GetName(void)   { return (StdString("scalar")); }
42  StdString CScalar::GetDefName(void){ return (CScalar::GetName()); }
43  ENodeType CScalar::GetType(void)   { return (eScalar); }
44
45  CScalar* CScalar::createScalar()
46  {
47    CScalar* scalar = CScalarGroup::get("scalar_definition")->createChild();
48    return scalar;
49  }
50
51  bool CScalar::IsWritten(const StdString & filename) const
52  {
53    return (this->relFiles.find(filename) != this->relFiles.end());
54  }
55
56  void CScalar::addRelFile(const StdString& filename)
57  {
58      this->relFiles.insert(filename);
59  }
60
61  void CScalar::checkAttributes(void)
62  {
63    if (checkAttributes_done_) return ;
64    checkAttributes_done_ = true ; 
65
66    if (mask.isEmpty()) mask=true ;
67
68    initializeLocalElement() ;
69    addFullView() ;
70    addWorkflowView() ;
71    addModelView() ;
72  }
73
74  /*!
75    Compare two scalar objects.
76    They are equal if only if they have identical attributes as well as their values.
77    Moreover, they must have the same transformations.
78  \param [in] scalar Compared scalar
79  \return result of the comparison
80  */
81  bool CScalar::isEqual(CScalar* obj)
82  {
83    vector<StdString> excludedAttr;
84    excludedAttr.push_back("scalar_ref");
85    bool objEqual = SuperClass::isEqual(obj, excludedAttr);
86    if (!objEqual) return objEqual;
87
88    TransMapTypes thisTrans = this->getAllTransformations();
89    TransMapTypes objTrans  = obj->getAllTransformations();
90
91    TransMapTypes::const_iterator it, itb, ite;
92    std::vector<ETranformationType> thisTransType, objTransType;
93    for (it = thisTrans.begin(); it != thisTrans.end(); ++it)
94      thisTransType.push_back(it->first);
95    for (it = objTrans.begin(); it != objTrans.end(); ++it)
96      objTransType.push_back(it->first);
97
98    if (thisTransType.size() != objTransType.size()) return false;
99    for (int idx = 0; idx < thisTransType.size(); ++idx)
100      objEqual &= (thisTransType[idx] == objTransType[idx]);
101
102    return objEqual;
103  }
104
105  CTransformation<CScalar>* CScalar::addTransformation(ETranformationType transType, const StdString& id)
106  {
107    transformationMap_.push_back(std::make_pair(transType, CTransformation<CScalar>::createTransformation(transType,id)));
108    return transformationMap_.back().second;
109  }
110
111  bool CScalar::hasTransformation()
112  {
113    return (!transformationMap_.empty());
114  }
115
116  void CScalar::setTransformations(const TransMapTypes& scalarTrans)
117  {
118    transformationMap_ = scalarTrans;
119  }
120
121  CScalar::TransMapTypes CScalar::getAllTransformations(void)
122  {
123    return transformationMap_;
124  }
125
126  void CScalar::duplicateTransformation(CScalar* src)
127  {
128    if (src->hasTransformation())
129    {
130      this->setTransformations(src->getAllTransformations());
131    }
132  }
133
134  /*!
135   * Go through the hierarchy to find the scalar from which the transformations must be inherited
136   */
137  void CScalar::solveInheritanceTransformation_old()
138  {
139    if (hasTransformation() || !hasDirectScalarReference())
140      return;
141
142    CScalar* scalar = this;
143    std::vector<CScalar*> refScalar;
144    while (!scalar->hasTransformation() && scalar->hasDirectScalarReference())
145    {
146      refScalar.push_back(scalar);
147      scalar = scalar->getDirectScalarReference();
148    }
149
150    if (scalar->hasTransformation())
151      for (size_t i = 0; i < refScalar.size(); ++i)
152        refScalar[i]->setTransformations(scalar->getAllTransformations());
153  }
154 
155  void CScalar::solveInheritanceTransformation()
156  TRY
157  {
158    if (solveInheritanceTransformation_done_) return;
159    else solveInheritanceTransformation_done_=true ;
160
161    CScalar* scalar = this;
162    CScalar* Lastscalar ;
163    std::list<CScalar*> refScalars;
164    bool out=false ;
165    vector<StdString> excludedAttr;
166    excludedAttr.push_back("scalar_ref");
167   
168    refScalars.push_front(scalar) ;
169    while (scalar->hasDirectScalarReference() && !out)
170    {
171      CScalar* lastScalar=scalar ;
172      scalar = scalar->getDirectScalarReference();
173      scalar->solveRefInheritance() ;
174      if (!scalar->SuperClass::isEqual(lastScalar,excludedAttr)) out=true ;
175      refScalars.push_front(scalar) ;
176    }
177
178    CTransformationPaths::TPath path ;
179    auto& pathList = std::get<2>(path) ;
180    std::get<0>(path) = EElement::SCALAR ;
181    std::get<1>(path) = refScalars.front()->getId() ;
182    for (auto& scalar : refScalars)
183    {
184      CScalar::TransMapTypes transformations = scalar->getAllTransformations();
185      for(auto& transformation : transformations) pathList.push_back({transformation.second->getTransformationType(), 
186                                                                      transformation.second->getId()}) ;
187    }
188    transformationPaths_.addPath(path) ;
189
190  }
191  CATCH_DUMP_ATTR
192
193  /* obsolete, to remove after reimplementing coupling */
194  void CScalar::sendScalarToCouplerOut(CContextClient* client, const string& fieldId, int posInGrid)
195  {
196    if (sendScalarToCouplerOut_done_.count(client)!=0) return ;
197    else sendScalarToCouplerOut_done_.insert(client) ;
198
199    string scalarId = getCouplingAlias(fieldId, posInGrid) ;
200
201    this->sendAllAttributesToServer(client, scalarId);
202  } 
203
204  string CScalar::getCouplingAlias(const string& fieldId, int posInGrid)
205  {
206    return "_scalar["+std::to_string(posInGrid)+"]_of_"+fieldId ;
207  }
208
209  void CScalar::makeAliasForCoupling(const string& fieldId, int posInGrid)
210  {
211    const string scalarId = getCouplingAlias(fieldId, posInGrid) ;
212    this->createAlias(scalarId) ;
213  }
214
215  void CScalar::setContextClient(CContextClient* contextClient)
216  TRY
217  {
218    if (clientsSet.find(contextClient)==clientsSet.end())
219    {
220      clients.push_back(contextClient) ;
221      clientsSet.insert(contextClient);
222    }
223  }
224  CATCH_DUMP_ATTR
225  /*!
226    Parse children nodes of a scalar in xml file.
227    \param node child node to process
228  */
229  void CScalar::parse(xml::CXMLNode & node)
230  {
231    SuperClass::parse(node);
232
233    if (node.goToChildElement())
234    {
235      StdString nodeElementName;
236      do
237      {
238        StdString nodeId("");
239        if (node.getAttributes().end() != node.getAttributes().find("id"))
240        { nodeId = node.getAttributes()["id"]; }
241
242        nodeElementName = node.getElementName();
243        std::map<StdString, ETranformationType>::const_iterator ite = transformationMapList_.end(), it;
244        it = transformationMapList_.find(nodeElementName);
245        if (ite != it)
246        {
247          transformationMap_.push_back(std::make_pair(it->second, CTransformation<CScalar>::createTransformation(it->second,
248                                                                                                                 nodeId,
249                                                                                                                 &node)));
250        }
251        else
252        {
253          ERROR("void CScalar::parse(xml::CXMLNode & node)",
254                << "The transformation " << nodeElementName << " has not been supported yet.");
255        }
256      } while (node.goToNextElement()) ;
257      node.goToParentElement();
258    }
259  }
260
261   //////////////////////////////////////////////////////////////////////////////////////
262   //  this part is related to distribution, element definition, views and connectors  //
263   //////////////////////////////////////////////////////////////////////////////////////
264
265   void CScalar::initializeLocalElement(void)
266   {
267      // after checkAttribute index of size n
268      int rank = CContext::getCurrent()->getIntraCommRank() ;
269     
270      CArray<size_t,1> ind(1) ;
271      ind(0)=0 ;
272      localElement_ = new CLocalElement(rank, 1, ind) ;
273   }
274
275   void CScalar::addFullView(void)
276   {
277      CArray<int,1> index(1) ;
278      for(int i=0; i<1 ; i++) index(0)=0 ;
279      localElement_ -> addView(CElementView::FULL, index) ;
280   }
281
282   void CScalar::addWorkflowView(void)
283   {
284      CArray<int,1> index ;
285      if (mask) 
286      {
287        index.resize(1) ;
288        index(0)=0 ;
289      }
290      else index.resize(0) ;
291      localElement_ -> addView(CElementView::WORKFLOW, index) ;
292   }
293
294   void CScalar::addModelView(void)
295   {
296     CArray<int,1> index(1) ;
297     for(int i=0; i<1 ; i++) index(0)=0 ;
298     localElement_->addView(CElementView::MODEL, index) ;
299   }
300
301   void CScalar::computeModelToWorkflowConnector(void)
302   { 
303     CLocalView* srcView=getLocalView(CElementView::MODEL) ;
304     CLocalView* dstView=getLocalView(CElementView::WORKFLOW) ;
305     modelToWorkflowConnector_ = new CLocalConnector(srcView, dstView); 
306     modelToWorkflowConnector_->computeConnector() ;
307   }
308
309
310  void CScalar::computeRemoteElement(CContextClient* client, EDistributionType type)
311  {
312    CContext* context = CContext::getCurrent();
313    map<int, CArray<size_t,1>> globalIndex ;
314
315    int nbServer = client->serverSize;
316    size_t nglo=1 ;
317    CArray<size_t,1> indGlo(nglo) ;
318    for(size_t i=0;i<nglo;i++) indGlo(i) = i ;
319    for (auto& rankServer : client->getRanksServerLeader()) globalIndex[rankServer].reference(indGlo.copy()) ; 
320
321    remoteElement_[client] = new CDistributedElement(nglo, globalIndex) ;
322    remoteElement_[client]->addFullView() ;
323  }
324 
325  void CScalar::distributeToServer(CContextClient* client, std::map<int, CArray<size_t,1>>& globalIndex, 
326                                   CScattererConnector* &scattererConnector, const string& scalarId)
327  {
328    string serverScalarId = scalarId.empty() ? this->getId() : scalarId ;
329    CContext* context = CContext::getCurrent();
330
331    this->sendAllAttributesToServer(client, serverScalarId)  ;
332
333    CDistributedElement scatteredElement(1,globalIndex) ;
334    scatteredElement.addFullView() ;
335    scattererConnector = new CScattererConnector(localElement_->getView(CElementView::FULL), scatteredElement.getView(CElementView::FULL), 
336                                                 context->getIntraComm(), client->getRemoteSize()) ;
337    scattererConnector->computeConnector() ;
338   
339    // phase 0
340    // send remote element to construct the full view on server, ie without hole
341    CEventClient event0(getType(), EVENT_ID_SCALAR_DISTRIBUTION);
342    CMessage message0 ;
343    message0<<serverScalarId<<0 ; 
344    remoteElement_[client]->sendToServer(client,event0,message0) ; 
345   
346    // phase 1
347    // send the full view of element to construct the connector which connect distributed data coming from client to the full local view
348    CEventClient event1(getType(), EVENT_ID_SCALAR_DISTRIBUTION);
349    CMessage message1 ;
350    message1<<serverScalarId<<1<<localElement_->getView(CElementView::FULL)->getGlobalSize() ; 
351    scattererConnector->transfer(localElement_->getView(CElementView::FULL)->getGlobalIndex(),client,event1,message1) ;
352
353    sendDistributedAttributes(client, *scattererConnector, scalarId) ;
354 
355    // phase 2 send the mask : data index + mask2D
356    CArray<bool,1> maskIn(localElement_->getView(CElementView::WORKFLOW)->getSize());
357    CArray<bool,1> maskOut ;
358    CLocalConnector workflowToFull(localElement_->getView(CElementView::WORKFLOW), localElement_->getView(CElementView::FULL)) ;
359    workflowToFull.computeConnector() ;
360    maskIn=true ;
361    workflowToFull.transfer(maskIn,maskOut,false) ;
362
363    // phase 3 : prepare grid scatterer connector to send data from client to server
364    map<int,CArray<size_t,1>> workflowGlobalIndex ;
365    map<int,CArray<bool,1>> maskOut2 ; 
366    scattererConnector->transfer(maskOut, maskOut2) ;
367    scatteredElement.addView(CElementView::WORKFLOW, maskOut2) ;
368    scatteredElement.getView(CElementView::WORKFLOW)->getGlobalIndexView(workflowGlobalIndex) ;
369    // create new workflow view for scattered element
370    CDistributedElement clientToServerElement(scatteredElement.getGlobalSize(), workflowGlobalIndex) ;
371    clientToServerElement.addFullView() ;
372    CEventClient event2(getType(), EVENT_ID_SCALAR_DISTRIBUTION);
373    CMessage message2 ;
374    message2<<serverScalarId<<2 ; 
375    clientToServerElement.sendToServer(client, event2, message2) ; 
376    clientToServerConnector_[client] = new CScattererConnector(localElement_->getView(CElementView::WORKFLOW), clientToServerElement.getView(CElementView::FULL),
377                                                               context->getIntraComm(), client->getRemoteSize()) ;
378    clientToServerConnector_[client]->computeConnector() ;
379
380    clientFromServerConnector_[client] = new CGathererConnector(clientToServerElement.getView(CElementView::FULL), localElement_->getView(CElementView::WORKFLOW));
381    clientFromServerConnector_[client]->computeConnector() ;
382
383  }
384 
385  void CScalar::recvScalarDistribution(CEventServer& event)
386  TRY
387  {
388    string scalarId;
389    int phasis ;
390    for (auto& subEvent : event.subEvents) (*subEvent.buffer) >> scalarId >> phasis ;
391    get(scalarId)->receivedScalarDistribution(event, phasis);
392  }
393  CATCH
394 
395  void CScalar::receivedScalarDistribution(CEventServer& event, int phasis)
396  TRY
397  {
398    CContext* context = CContext::getCurrent();
399    if (phasis==0) // receive the remote element to construct the full view
400    {
401      localElement_ = new  CLocalElement(context->getIntraCommRank(),event) ;
402      localElement_->addFullView() ;
403      // construct the local dimension and indexes
404      auto& globalIndex=localElement_->getGlobalIndex() ;
405      int nk=globalIndex.numElements() ;
406      // no distribution for scalar => nk ==1 or maybe 0 ?
407    }
408    else if (phasis==1) // receive the sent view from client to construct the full distributed full view on server
409    {
410      CContext* context = CContext::getCurrent();
411      CDistributedElement* elementFrom = new  CDistributedElement(event) ;
412      elementFrom->addFullView() ;
413      gathererConnector_ = new CGathererConnector(elementFrom->getView(CElementView::FULL), localElement_->getView(CElementView::FULL)) ;
414      gathererConnector_->computeConnector() ; 
415    }
416    else if (phasis==2)
417    {
418//      delete gathererConnector_ ;
419      elementFrom_ = new  CDistributedElement(event) ;
420      elementFrom_->addFullView() ;
421//      gathererConnector_ =  new CGathererConnector(elementFrom_->getView(CElementView::FULL), localElement_->getView(CElementView::FULL)) ;
422//      gathererConnector_ -> computeConnector() ;
423    }
424  }
425  CATCH
426
427  void CScalar::setServerMask(CArray<bool,1>& serverMask, CContextClient* client)
428  TRY
429  {
430    CContext* context = CContext::getCurrent();
431    localElement_->addView(CElementView::WORKFLOW, serverMask) ;
432    mask = serverMask(0) ;
433 
434    serverFromClientConnector_ = new CGathererConnector(elementFrom_->getView(CElementView::FULL), localElement_->getView(CElementView::WORKFLOW)) ;
435    serverFromClientConnector_->computeConnector() ;
436     
437    serverToClientConnector_ = new CScattererConnector(localElement_->getView(CElementView::WORKFLOW), elementFrom_->getView(CElementView::FULL),
438                                                         context->getIntraComm(), client->getRemoteSize()) ;
439    serverToClientConnector_->computeConnector() ;
440  }
441  CATCH_DUMP_ATTR
442
443  void CScalar::sendDistributedAttributes(CContextClient* client, CScattererConnector& scattererConnector, const string& scalarId)
444  {
445    string serverScalarId = scalarId.empty() ? this->getId() : scalarId ;
446    CContext* context = CContext::getCurrent();
447
448    // nothing for now
449  }
450
451  void CScalar::recvDistributedAttributes(CEventServer& event)
452  TRY
453  {
454    string scalarId;
455    string type ;
456    for (auto& subEvent : event.subEvents) (*subEvent.buffer) >> scalarId >> type ;
457    get(scalarId)->recvDistributedAttributes(event, type);
458  }
459  CATCH
460
461  void CScalar::recvDistributedAttributes(CEventServer& event, const string& type)
462  TRY
463  {
464    // nothing for now
465  }
466  CATCH 
467
468  bool CScalar::dispatchEvent(CEventServer& event)
469  TRY
470  {
471     if (SuperClass::dispatchEvent(event)) return true;
472     else
473     {
474       switch(event.type)
475       {
476          case EVENT_ID_SCALAR_DISTRIBUTION:
477            recvScalarDistribution(event);
478            return true;
479            break;
480          case EVENT_ID_SEND_DISTRIBUTED_ATTRIBUTE:
481            recvDistributedAttributes(event);
482            return true;
483            break;
484          default :
485            ERROR("bool CScalar::dispatchEvent(CEventServer& event)",
486                   << "Unknown Event");
487          return false;
488        }
489     }
490  }
491  CATCH
492
493
494  // Definition of some macros
495  DEFINE_REF_FUNC(Scalar,scalar)
496
497} // namespace xios
Note: See TracBrowser for help on using the repository browser.