source: XIOS/dev/dev_ym/XIOS_COUPLING/src/node/scalar.cpp @ 2267

Last change on this file since 2267 was 2267, checked in by ymipsl, 2 years ago

tracking memory leak
Elements, views, and connectors are now managed with shared pointer.
YM

File size: 20.3 KB
Line 
1#include "scalar.hpp"
2
3#include "attribute_template.hpp"
4#include "object_template.hpp"
5#include "group_template.hpp"
6#include "object_factory.hpp"
7#include "xios_spl.hpp"
8#include "type.hpp"
9#include "context.hpp"
10
11#include <algorithm>
12#include <regex>
13
14namespace xios
15{
16
17  /// ////////////////////// Définitions ////////////////////// ///
18
19  CScalar::CScalar(void)
20     : CObjectTemplate<CScalar>()
21     , CScalarAttributes()
22     , relFiles()
23  { /* Ne rien faire de plus */ }
24
25  CScalar::CScalar(const StdString & id)
26     : CObjectTemplate<CScalar>(id)
27     , CScalarAttributes()
28     , relFiles()
29  { /* Ne rien faire de plus */ }
30
31  CScalar::~CScalar(void)
32  { /* Ne rien faire de plus */ }
33
34  std::map<StdString, ETranformationType> CScalar::transformationMapList_ = std::map<StdString, ETranformationType>();
35  bool CScalar::dummyTransformationMapList_ = CScalar::initializeTransformationMap(CScalar::transformationMapList_);
36  bool CScalar::initializeTransformationMap(std::map<StdString, ETranformationType>& m)
37  {
38    m["reduce_axis"]   = TRANS_REDUCE_AXIS_TO_SCALAR;
39    m["extract_axis"]  = TRANS_EXTRACT_AXIS_TO_SCALAR;
40    m["reduce_domain"] = TRANS_REDUCE_DOMAIN_TO_SCALAR;
41    m["reduce_scalar"] = TRANS_REDUCE_SCALAR_TO_SCALAR;
42    return true;
43  }
44
45  StdString CScalar::GetName(void)   { return (StdString("scalar")); }
46  StdString CScalar::GetDefName(void){ return (CScalar::GetName()); }
47  ENodeType CScalar::GetType(void)   { return (eScalar); }
48
49  CScalar* CScalar::createScalar()
50  {
51    CScalar* scalar = CScalarGroup::get("scalar_definition")->createChild();
52    return scalar;
53  }
54
55  CScalar* CScalar::get(const string& id, bool noError)
56  {
57    const regex r("::");
58    smatch m;
59    if (regex_search(id, m, r))
60    {
61      if (m.size()!=1) ERROR("CScalar* CScalar::get(string& id)", <<" id = "<<id<< "  -> bad format id, separator :: append more than one time");
62      string fieldId=m.prefix() ;
63      if (fieldId.empty()) ERROR("CScalar* CScalar::get(string& id)", <<" id = "<<id<< "  -> bad format id, field name is empty");
64      string suffix=m.suffix() ;
65      if (!CField::has(fieldId)) 
66          if (noError)  return nullptr ;
67          else ERROR("CScalar* CScalar::get(const string& id, bool noError)", <<" id = "<<id<< "  -> field Id : < "<<fieldId<<" > doesn't exist");
68      CField* field=CField::get(fieldId) ;
69      return field->getAssociatedScalar(suffix, noError) ;
70    }
71    else
72    {
73       if (noError) if(!CObjectFactory::HasObject<CScalar>(id)) return nullptr ;
74       return CObjectFactory::GetObject<CScalar>(id).get();
75     }
76   }
77
78   bool CScalar::has(const string& id)
79   {
80     if (CScalar::get(id,true)==nullptr) return false ;
81     else return true ;
82   }
83   
84   CField* CScalar::getFieldFromId(const string& id)
85   {
86     const regex r("::");
87     smatch m;
88     if (regex_search(id, m, r))
89     {
90        if (m.size()!=1) ERROR("CField* CScalar::getFieldFromId(const string& id)", <<" id = "<<id<< "  -> bad format id, separator :: append more than one time");
91        string fieldId=m.prefix() ;
92        if (fieldId.empty()) ERROR("CField* CScalar::getFieldFromId(const string& id)", <<" id = "<<id<< "  -> bad format id, field name is empty");
93        string suffix=m.suffix() ;
94        CField* field=CField::get(fieldId) ;
95        return field ;
96     }
97     else return nullptr;
98   }
99
100  bool CScalar::IsWritten(const StdString & filename) const
101  {
102    return (this->relFiles.find(filename) != this->relFiles.end());
103  }
104
105  void CScalar::addRelFile(const StdString& filename)
106  {
107      this->relFiles.insert(filename);
108  }
109
110  void CScalar::checkAttributes(void)
111  {
112    if (checkAttributes_done_) return ;
113    checkAttributes_done_ = true ; 
114   
115    if (n.isEmpty()) n=1 ;
116    if (mask.isEmpty()) mask=true ;
117
118    initializeLocalElement() ;
119    addFullView() ;
120    addWorkflowView() ;
121    addModelView() ;
122  }
123
124  /*!
125    Compare two scalar objects.
126    They are equal if only if they have identical attributes as well as their values.
127    Moreover, they must have the same transformations.
128  \param [in] scalar Compared scalar
129  \return result of the comparison
130  */
131  bool CScalar::isEqual(CScalar* obj)
132  {
133    vector<StdString> excludedAttr;
134    excludedAttr.push_back("scalar_ref");
135    bool objEqual = SuperClass::isEqual(obj, excludedAttr);
136    if (!objEqual) return objEqual;
137
138    TransMapTypes thisTrans = this->getAllTransformations();
139    TransMapTypes objTrans  = obj->getAllTransformations();
140
141    TransMapTypes::const_iterator it, itb, ite;
142    std::vector<ETranformationType> thisTransType, objTransType;
143    for (it = thisTrans.begin(); it != thisTrans.end(); ++it)
144      thisTransType.push_back(it->first);
145    for (it = objTrans.begin(); it != objTrans.end(); ++it)
146      objTransType.push_back(it->first);
147
148    if (thisTransType.size() != objTransType.size()) return false;
149    for (int idx = 0; idx < thisTransType.size(); ++idx)
150      objEqual &= (thisTransType[idx] == objTransType[idx]);
151
152    return objEqual;
153  }
154
155  CTransformation<CScalar>* CScalar::addTransformation(ETranformationType transType, const StdString& id)
156  {
157    transformationMap_.push_back(std::make_pair(transType, CTransformation<CScalar>::createTransformation(transType,id)));
158    return transformationMap_.back().second;
159  }
160
161  bool CScalar::hasTransformation()
162  {
163    return (!transformationMap_.empty());
164  }
165
166  void CScalar::setTransformations(const TransMapTypes& scalarTrans)
167  {
168    transformationMap_ = scalarTrans;
169  }
170
171  CScalar::TransMapTypes CScalar::getAllTransformations(void)
172  {
173    return transformationMap_;
174  }
175
176  void CScalar::duplicateTransformation(CScalar* src)
177  {
178    if (src->hasTransformation())
179    {
180      this->setTransformations(src->getAllTransformations());
181    }
182  }
183
184  /*!
185   * Go through the hierarchy to find the scalar from which the transformations must be inherited
186   */
187  void CScalar::solveInheritanceTransformation_old()
188  {
189    if (hasTransformation() || !hasDirectScalarReference())
190      return;
191
192    CScalar* scalar = this;
193    std::vector<CScalar*> refScalar;
194    while (!scalar->hasTransformation() && scalar->hasDirectScalarReference())
195    {
196      refScalar.push_back(scalar);
197      scalar = scalar->getDirectScalarReference();
198    }
199
200    if (scalar->hasTransformation())
201      for (size_t i = 0; i < refScalar.size(); ++i)
202        refScalar[i]->setTransformations(scalar->getAllTransformations());
203  }
204 
205  void CScalar::solveInheritanceTransformation()
206  TRY
207  {
208    if (solveInheritanceTransformation_done_) return;
209    else solveInheritanceTransformation_done_=true ;
210
211    CScalar* scalar = this;
212    CScalar* Lastscalar ;
213    std::list<CScalar*> refScalars;
214    bool out=false ;
215    vector<StdString> excludedAttr;
216    excludedAttr.push_back("scalar_ref");
217   
218    refScalars.push_front(scalar) ;
219    while (scalar->hasDirectScalarReference() && !out)
220    {
221      CScalar* lastScalar=scalar ;
222      scalar = scalar->getDirectScalarReference();
223      scalar->solveRefInheritance() ;
224      if (!scalar->SuperClass::isEqual(lastScalar,excludedAttr)) out=true ;
225      refScalars.push_front(scalar) ;
226    }
227
228    CTransformationPaths::TPath path ;
229    auto& pathList = std::get<2>(path) ;
230    std::get<0>(path) = EElement::SCALAR ;
231    std::get<1>(path) = refScalars.front()->getId() ;
232    for (auto& scalar : refScalars)
233    {
234      CScalar::TransMapTypes transformations = scalar->getAllTransformations();
235      for(auto& transformation : transformations) pathList.push_back({transformation.second->getTransformationType(), 
236                                                                      transformation.second->getId()}) ;
237    }
238    transformationPaths_.addPath(path) ;
239
240  }
241  CATCH_DUMP_ATTR
242
243  bool CScalar::activateFieldWorkflow(CGarbageCollector& gc)
244  TRY
245  {
246    if (!scalar_ref.isEmpty())
247    {
248      CField* field=getFieldFromId(scalar_ref) ;
249      if (field!=nullptr)
250      {
251        bool ret = field->buildWorkflowGraph(gc) ;
252        if (!ret) return false ; // cannot build workflow graph at this state
253      }
254      else 
255      {
256        CScalar* scalar = get(scalar_ref) ;
257        bool ret = scalar->activateFieldWorkflow(gc) ;
258        if (!ret) return false ; // cannot build workflow graph at this state
259        scalar_ref=scalar->getId() ; // replace domain_ref by solved reference
260      }
261    }
262    activateFieldWorkflow_done_=true ;
263    return true ;
264  }
265  CATCH_DUMP_ATTR
266
267
268  /* obsolete, to remove after reimplementing coupling */
269  void CScalar::sendScalarToCouplerOut(CContextClient* client, const string& fieldId, int posInGrid)
270  {
271    if (sendScalarToCouplerOut_done_.count(client)!=0) return ;
272    else sendScalarToCouplerOut_done_.insert(client) ;
273
274    string scalarId = getCouplingAlias(fieldId, posInGrid) ;
275
276    this->sendAllAttributesToServer(client, scalarId);
277  } 
278
279  string CScalar::getCouplingAlias(const string& fieldId, int posInGrid)
280  {
281    return "_scalar["+std::to_string(posInGrid)+"]_of_"+fieldId ;
282  }
283
284  void CScalar::makeAliasForCoupling(const string& fieldId, int posInGrid)
285  {
286    const string scalarId = getCouplingAlias(fieldId, posInGrid) ;
287    this->createAlias(scalarId) ;
288  }
289
290  void CScalar::setContextClient(CContextClient* contextClient)
291  TRY
292  {
293    if (clientsSet.find(contextClient)==clientsSet.end())
294    {
295      clients.push_back(contextClient) ;
296      clientsSet.insert(contextClient);
297    }
298  }
299  CATCH_DUMP_ATTR
300  /*!
301    Parse children nodes of a scalar in xml file.
302    \param node child node to process
303  */
304  void CScalar::parse(xml::CXMLNode & node)
305  {
306    SuperClass::parse(node);
307
308    if (node.goToChildElement())
309    {
310      StdString nodeElementName;
311      do
312      {
313        StdString nodeId("");
314        if (node.getAttributes().end() != node.getAttributes().find("id"))
315        { nodeId = node.getAttributes()["id"]; }
316
317        nodeElementName = node.getElementName();
318        std::map<StdString, ETranformationType>::const_iterator ite = transformationMapList_.end(), it;
319        it = transformationMapList_.find(nodeElementName);
320        if (ite != it)
321        {
322          transformationMap_.push_back(std::make_pair(it->second, CTransformation<CScalar>::createTransformation(it->second,
323                                                                                                                 nodeId,
324                                                                                                                 &node)));
325        }
326        else
327        {
328          ERROR("void CScalar::parse(xml::CXMLNode & node)",
329                << "The transformation " << nodeElementName << " has not been supported yet.");
330        }
331      } while (node.goToNextElement()) ;
332      node.goToParentElement();
333    }
334  }
335
336   //////////////////////////////////////////////////////////////////////////////////////
337   //  this part is related to distribution, element definition, views and connectors  //
338   //////////////////////////////////////////////////////////////////////////////////////
339
340   void CScalar::initializeLocalElement(void)
341   {
342      // after checkAttribute index of size n
343      int rank = CContext::getCurrent()->getIntraCommRank() ;
344     
345     
346      CArray<size_t,1> index(n) ;
347      if (n==1) index(0)=0 ;
348      localElement_ = make_shared<CLocalElement>(rank, 1, index) ;
349   }
350
351   void CScalar::addFullView(void)
352   {
353      CArray<int,1> index(n) ;
354      if (n==1) index(0)=0 ;
355      localElement_ -> addView(CElementView::FULL, index) ;
356   }
357
358   void CScalar::addWorkflowView(void)
359   {
360      CArray<int,1> index ;
361      if (mask && n==1)
362      {
363        index.resize(1) ;
364        index(0)=0 ;
365      }
366      else index.resize(0) ;
367      localElement_ -> addView(CElementView::WORKFLOW, index) ;
368   }
369
370   void CScalar::addModelView(void)
371   {
372     CArray<int,1> index(1) ;
373     for(int i=0; i<1 ; i++) index(0)=0 ;
374     localElement_->addView(CElementView::MODEL, index) ;
375   }
376
377   void CScalar::computeModelToWorkflowConnector(void)
378   { 
379     shared_ptr<CLocalView> srcView=getLocalView(CElementView::MODEL) ;
380     shared_ptr<CLocalView> dstView=getLocalView(CElementView::WORKFLOW) ;
381     modelToWorkflowConnector_ = make_shared<CLocalConnector>(srcView, dstView); 
382     modelToWorkflowConnector_->computeConnector() ;
383   }
384
385
386  void CScalar::computeRemoteElement(CContextClient* client, EDistributionType type)
387  {
388    CContext* context = CContext::getCurrent();
389    map<int, CArray<size_t,1>> globalIndex ;
390    size_t nglo=1 ;
391
392    if (type==EDistributionType::ROOT) // Bands distribution to send to file server
393    {
394      for (auto& rankServer : client->getRanksServerLeader())
395      {
396        auto& globalInd =  globalIndex[rankServer] ;
397        if (rankServer==0) 
398        {
399          globalInd.resize(1) ;
400          globalInd(0)=0 ;
401        }
402      }
403    }
404    else if (type==EDistributionType::NONE) // domain is not distributed ie all servers get the same local domain
405    {
406      int nbServer = client->serverSize;
407      CArray<size_t,1> indGlo(nglo) ;
408      for(size_t i=0;i<nglo;i++) indGlo(i) = i ;
409      for (auto& rankServer : client->getRanksServerLeader()) globalIndex[rankServer].reference(indGlo.copy()) ; 
410    }
411    remoteElement_[client] = make_shared<CDistributedElement>(nglo, globalIndex) ;
412    remoteElement_[client]->addFullView() ;
413  }
414 
415  void CScalar::distributeToServer(CContextClient* client, std::map<int, CArray<size_t,1>>& globalIndex, 
416                                   shared_ptr<CScattererConnector> &scattererConnector, const string& scalarId)
417  {
418    string serverScalarId = scalarId.empty() ? this->getId() : scalarId ;
419    CContext* context = CContext::getCurrent();
420
421    this->sendAllAttributesToServer(client, serverScalarId)  ;
422
423    auto scatteredElement = make_shared<CDistributedElement>(1,globalIndex) ;
424    scatteredElement->addFullView() ;
425    scattererConnector = make_shared<CScattererConnector>(localElement_->getView(CElementView::FULL), scatteredElement->getView(CElementView::FULL), 
426                                                          context->getIntraComm(), client->getRemoteSize()) ;
427    scattererConnector->computeConnector() ;
428   
429    // phase 0
430    // send remote element to construct the full view on server, ie without hole
431    CEventClient event0(getType(), EVENT_ID_SCALAR_DISTRIBUTION);
432    CMessage message0 ;
433    message0<<serverScalarId<<0 ; 
434    remoteElement_[client]->sendToServer(client,event0,message0) ; 
435   
436    // phase 1
437    // send the full view of element to construct the connector which connect distributed data coming from client to the full local view
438    CEventClient event1(getType(), EVENT_ID_SCALAR_DISTRIBUTION);
439    CMessage message1 ;
440    message1<<serverScalarId<<1<<localElement_->getView(CElementView::FULL)->getGlobalSize() ; 
441    scattererConnector->transfer(localElement_->getView(CElementView::FULL)->getGlobalIndex(),client,event1,message1) ;
442
443    sendDistributedAttributes(client, scattererConnector, scalarId) ;
444 
445    // phase 2 send the mask : data index + mask2D
446    CArray<bool,1> maskIn(localElement_->getView(CElementView::WORKFLOW)->getSize());
447    CArray<bool,1> maskOut ;
448    auto workflowToFull = make_shared<CLocalConnector>(localElement_->getView(CElementView::WORKFLOW), localElement_->getView(CElementView::FULL)) ;
449    workflowToFull->computeConnector() ;
450    maskIn=true ;
451    workflowToFull->transfer(maskIn,maskOut,false) ;
452
453    // phase 3 : prepare grid scatterer connector to send data from client to server
454    map<int,CArray<size_t,1>> workflowGlobalIndex ;
455    map<int,CArray<bool,1>> maskOut2 ; 
456    scattererConnector->transfer(maskOut, maskOut2) ;
457    scatteredElement->addView(CElementView::WORKFLOW, maskOut2) ;
458    scatteredElement->getView(CElementView::WORKFLOW)->getGlobalIndexView(workflowGlobalIndex) ;
459    // create new workflow view for scattered element
460    auto clientToServerElement = make_shared<CDistributedElement>(scatteredElement->getGlobalSize(), workflowGlobalIndex) ;
461    clientToServerElement->addFullView() ;
462    CEventClient event2(getType(), EVENT_ID_SCALAR_DISTRIBUTION);
463    CMessage message2 ;
464    message2<<serverScalarId<<2 ; 
465    clientToServerElement->sendToServer(client, event2, message2) ; 
466    clientToServerConnector_[client] = make_shared<CScattererConnector>(localElement_->getView(CElementView::WORKFLOW), clientToServerElement->getView(CElementView::FULL),
467                                                                        context->getIntraComm(), client->getRemoteSize()) ;
468    clientToServerConnector_[client]->computeConnector() ;
469
470    clientFromServerConnector_[client] = make_shared<CGathererConnector>(clientToServerElement->getView(CElementView::FULL), localElement_->getView(CElementView::WORKFLOW));
471    clientFromServerConnector_[client]->computeConnector() ;
472
473  }
474 
475  void CScalar::recvScalarDistribution(CEventServer& event)
476  TRY
477  {
478    string scalarId;
479    int phasis ;
480    for (auto& subEvent : event.subEvents) (*subEvent.buffer) >> scalarId >> phasis ;
481    get(scalarId)->receivedScalarDistribution(event, phasis);
482  }
483  CATCH
484 
485  void CScalar::receivedScalarDistribution(CEventServer& event, int phasis)
486  TRY
487  {
488    CContext* context = CContext::getCurrent();
489    if (phasis==0) // receive the remote element to construct the full view
490    {
491      localElement_ = make_shared<CLocalElement>(context->getIntraCommRank(),event) ;
492      localElement_->addFullView() ;
493      // construct the local dimension and indexes
494      auto& globalIndex=localElement_->getGlobalIndex() ;
495      n=globalIndex.numElements() ;
496      // no distribution for scalar => nk ==1 or maybe 0 ?
497    }
498    else if (phasis==1) // receive the sent view from client to construct the full distributed full view on server
499    {
500      CContext* context = CContext::getCurrent();
501      shared_ptr<CDistributedElement> elementFrom = make_shared<CDistributedElement>(event) ;
502      elementFrom->addFullView() ;
503      gathererConnector_ = make_shared<CGathererConnector>(elementFrom->getView(CElementView::FULL), localElement_->getView(CElementView::FULL)) ;
504      gathererConnector_->computeConnector() ; 
505    }
506    else if (phasis==2)
507    {
508//      delete gathererConnector_ ;
509      elementFrom_ = make_shared<CDistributedElement>(event) ;
510      elementFrom_->addFullView() ;
511//      gathererConnector_ =  make_shared<CGathererConnector>(elementFrom_->getView(CElementView::FULL), localElement_->getView(CElementView::FULL)) ;
512//      gathererConnector_ -> computeConnector() ;
513    }
514  }
515  CATCH
516
517  void CScalar::setServerMask(CArray<bool,1>& serverMask, CContextClient* client)
518  TRY
519  {
520    CContext* context = CContext::getCurrent();
521    localElement_->addView(CElementView::WORKFLOW, serverMask) ;
522    if (serverMask.numElements()==1) mask = serverMask(0) ;
523 
524    serverFromClientConnector_ = make_shared<CGathererConnector>(elementFrom_->getView(CElementView::FULL), localElement_->getView(CElementView::WORKFLOW)) ;
525    serverFromClientConnector_->computeConnector() ;
526     
527    serverToClientConnector_ = make_shared<CScattererConnector>(localElement_->getView(CElementView::WORKFLOW), elementFrom_->getView(CElementView::FULL),
528                                                                context->getIntraComm(), client->getRemoteSize()) ;
529    serverToClientConnector_->computeConnector() ;
530  }
531  CATCH_DUMP_ATTR
532
533  void CScalar::sendDistributedAttributes(CContextClient* client, shared_ptr<CScattererConnector> scattererConnector, const string& scalarId)
534  {
535    string serverScalarId = scalarId.empty() ? this->getId() : scalarId ;
536    CContext* context = CContext::getCurrent();
537
538    // nothing for now
539  }
540
541  void CScalar::recvDistributedAttributes(CEventServer& event)
542  TRY
543  {
544    string scalarId;
545    string type ;
546    for (auto& subEvent : event.subEvents) (*subEvent.buffer) >> scalarId >> type ;
547    get(scalarId)->recvDistributedAttributes(event, type);
548  }
549  CATCH
550
551  void CScalar::recvDistributedAttributes(CEventServer& event, const string& type)
552  TRY
553  {
554    // nothing for now
555  }
556  CATCH 
557
558  bool CScalar::dispatchEvent(CEventServer& event)
559  TRY
560  {
561     if (SuperClass::dispatchEvent(event)) return true;
562     else
563     {
564       switch(event.type)
565       {
566          case EVENT_ID_SCALAR_DISTRIBUTION:
567            recvScalarDistribution(event);
568            return true;
569            break;
570          case EVENT_ID_SEND_DISTRIBUTED_ATTRIBUTE:
571            recvDistributedAttributes(event);
572            return true;
573            break;
574          default :
575            ERROR("bool CScalar::dispatchEvent(CEventServer& event)",
576                   << "Unknown Event");
577          return false;
578        }
579     }
580  }
581  CATCH
582
583
584  // Definition of some macros
585  DEFINE_REF_FUNC(Scalar,scalar)
586
587} // namespace xios
Note: See TracBrowser for help on using the repository browser.