source: XIOS/dev/dev_olga/src/extern/blitz/include/blitz/array/where.h @ 1022

Last change on this file since 1022 was 1022, checked in by mhnguyen, 7 years ago
File size: 14.9 KB
Line 
1// -*- C++ -*-
2/***************************************************************************
3 * blitz/array/where.h  where(X,Y,Z) operator for array expressions
4 *
5 * $Id$
6 *
7 * Copyright (C) 1997-2011 Todd Veldhuizen <tveldhui@acm.org>
8 *
9 * This file is a part of Blitz.
10 *
11 * Blitz is free software: you can redistribute it and/or modify
12 * it under the terms of the GNU Lesser General Public License
13 * as published by the Free Software Foundation, either version 3
14 * of the License, or (at your option) any later version.
15 *
16 * Blitz is distributed in the hope that it will be useful,
17 * but WITHOUT ANY WARRANTY; without even the implied warranty of
18 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19 * GNU Lesser General Public License for more details.
20 *
21 * You should have received a copy of the GNU Lesser General Public
22 * License along with Blitz.  If not, see <http://www.gnu.org/licenses/>.
23 *
24 * Suggestions:          blitz-devel@lists.sourceforge.net
25 * Bugs:                 blitz-support@lists.sourceforge.net   
26 *
27 * For more information, please see the Blitz++ Home Page:
28 *    https://sourceforge.net/projects/blitz/
29 *
30 ****************************************************************************/
31#ifndef BZ_ARRAYWHERE_H
32#define BZ_ARRAYWHERE_H
33
34#include <blitz/blitz.h>
35#include <blitz/promote.h>
36#include <blitz/prettyprint.h>
37#include <blitz/bounds.h>
38#include <blitz/meta/metaprog.h>
39#include <blitz/tinyvec2.h>
40#include <blitz/array/domain.h>
41#include <blitz/array/asexpr.h>
42
43BZ_NAMESPACE(blitz)
44
45template<typename P_expr1, typename P_expr2, typename P_expr3>
46class _bz_ArrayWhere {
47
48public:
49    typedef P_expr1 T_expr1;
50    typedef P_expr2 T_expr2;
51    typedef P_expr3 T_expr3;
52    typedef _bz_typename T_expr2::T_numtype T_numtype2;
53    typedef _bz_typename T_expr3::T_numtype T_numtype3;
54    typedef BZ_PROMOTE(T_numtype2, T_numtype3) T_numtype;
55    typedef T_expr1 T_ctorArg1;
56    typedef T_expr2 T_ctorArg2;
57    typedef T_expr3 T_ctorArg3;
58  typedef _bz_ArrayWhere<_bz_typename P_expr1::T_range_result,
59                         _bz_typename P_expr2::T_range_result,
60                         _bz_typename P_expr3::T_range_result> T_range_result;
61
62  // select return type
63  typedef typename unwrapET<typename T_expr1::T_result>::T_unwrapped T_unwrapped1;
64  typedef typename unwrapET<typename T_expr2::T_result>::T_unwrapped T_unwrapped2;
65  typedef typename unwrapET<typename T_expr3::T_result>::T_unwrapped T_unwrapped3;
66  typedef typename selectET2<typename T_expr1::T_typeprop, 
67                             typename T_expr2::T_typeprop, 
68                             T_numtype, 
69                             char>::T_selected T_intermediary;
70
71  typedef typename selectET2<T_intermediary,
72                             typename T_expr3::T_typeprop, 
73                             T_numtype, 
74                             _bz_ArrayWhere<typename asExpr<T_unwrapped1>::T_expr, 
75                                            typename asExpr<T_unwrapped2>::T_expr, 
76                                            typename asExpr<T_unwrapped3>::T_expr
77                                            > >::T_selected T_typeprop;
78  typedef typename unwrapET<T_typeprop>::T_unwrapped T_result;
79  typedef T_numtype T_optype;
80
81    static const int 
82        numArrayOperands = P_expr1::numArrayOperands
83                         + P_expr2::numArrayOperands
84                         + P_expr3::numArrayOperands,
85        numTVOperands = T_expr1::numTVOperands +
86      T_expr2::numTVOperands +
87      T_expr3::numTVOperands,
88        numTMOperands = T_expr1::numTMOperands +
89      T_expr2::numTMOperands +
90      T_expr3::numTMOperands,
91        numIndexPlaceholders = P_expr1::numIndexPlaceholders
92                             + P_expr2::numIndexPlaceholders
93                             + P_expr3::numIndexPlaceholders,
94      minWidth = BZ_MIN(BZ_MIN(T_expr1::minWidth, T_expr2::minWidth),
95                        T_expr3::minWidth),
96      maxWidth = BZ_MAX(BZ_MAX(T_expr1::maxWidth, T_expr2::maxWidth), 
97                        T_expr3::maxWidth),
98      rank_ = BZ_MAX(BZ_MAX(T_expr1::rank_, T_expr2::rank_),
99                     T_expr3::rank_);
100
101  template<int N> struct tvresult {
102    typedef _bz_ArrayWhere<
103      typename T_expr1::template tvresult<N>::Type,
104      typename T_expr2::template tvresult<N>::Type,
105      typename T_expr3::template tvresult<N>::Type> Type; 
106  };
107
108    _bz_ArrayWhere(const _bz_ArrayWhere<T_expr1,T_expr2,T_expr3>& a)
109      : iter1_(a.iter1_), iter2_(a.iter2_), iter3_(a.iter3_)
110    { }
111
112    template<typename T1, typename T2, typename T3>
113    _bz_ArrayWhere(BZ_ETPARM(T1) a, BZ_ETPARM(T2) b, BZ_ETPARM(T3) c)
114      : iter1_(a), iter2_(b), iter3_(c)
115    { }
116
117    T_numtype operator*() const
118    { return (*iter1_) ? (*iter2_) : (*iter3_); }
119
120  /* Functions for reading. Because they must depend on the result
121   * type, they utilize a helper class.
122   */
123
124  // For numtypes, apply operator
125  template<typename T> struct readHelper {
126    static T_result fastRead(const T_expr1& iter1, const T_expr2& iter2, 
127                             const T_expr3& iter3, diffType i) {
128      return iter1.fastRead(i) ? iter2.fastRead(i) : iter3.fastRead(i); }
129    static T_result indexop(const T_expr1& iter1, const T_expr2& iter2, 
130                            const T_expr3& iter3, int i) {
131      return iter1[i] ? iter2[i] : iter3[i]; }
132    static T_result first_value(const T_expr1& iter1, const T_expr2& iter2,
133                                const T_expr3& iter3)  {
134      return iter1.first_value() ? 
135        iter2.first_value() : iter3.first_value(); }
136    static T_result shift(const T_expr1& iter1, const T_expr2& iter2,
137                          const T_expr3& iter3, int offset, int dim) {
138      return iter1.shift(offset, dim) ? iter2.shift(offset, dim) : 
139        iter3.shift(offset, dim); }
140    static T_result shift(const T_expr1& iter1, const T_expr2& iter2,
141                          const T_expr3& iter3, int offset1, int dim1,
142                          int offset2, int dim2) {
143      return iter1.shift(offset1, dim1, offset2, dim2) ? 
144        iter2.shift(offset1, dim1, offset2, dim2) : 
145        iter3.shift(offset1, dim1, offset2, dim2); }
146    template<int N_rank>
147#ifdef BZ_ARRAY_EXPR_PASS_INDEX_BY_VALUE
148      static T_result indexop(const T_expr1& iter1, const T_expr2& iter2,
149                              const T_expr3& iter3, 
150                              const TinyVector<int, N_rank> i) {
151#else
152      static T_result indexop(const T_expr1& iter1, const T_expr2& iter2,
153                              const T_expr3& iter3, 
154                              const TinyVector<int, N_rank>& i) {
155#endif
156        return iter1(i) ? iter2(i) : iter3(i); }
157      };
158   
159    // For ET types, bypass operator and create expression
160    template<typename T> struct readHelper<ETBase<T> > {
161    static T_result fastRead(const T_expr1& iter1, const T_expr2& iter2, 
162                             const T_expr3& iter3, diffType i) {
163      return T_result(iter1.fastRead(i), iter2.fastRead(i), iter3.fastRead(i)); }
164    static T_result indexop(const T_expr1& iter1, const T_expr2& iter2, 
165                            const T_expr3& iter3, int i) {
166      return T_result(iter1[i], iter2[i], iter3[i]); };
167    static T_result first_value(const T_expr1& iter1, const T_expr2& iter2,
168                                const T_expr3& iter3)  {
169      return T_result(iter1.first_value(), iter2.first_value(),
170                      iter3.first_value()); }
171    static T_result shift(const T_expr1& iter1, const T_expr2& iter2,
172                          const T_expr3& iter3, int offset, int dim) {
173      return T_result(iter1.shift(offset, dim), iter2.shift(offset, dim),
174                      iter3.shift(offset, dim)); }
175    static T_result shift(const T_expr1& iter1, const T_expr2& iter2,
176                          const T_expr3& iter3, int offset1, int dim1,
177                          int offset2, int dim2) {
178      return T_result(iter1.shift(offset1, dim1, offset2, dim2),
179                      iter2.shift(offset1, dim1, offset2, dim2), 
180                      iter3.shift(offset1, dim1, offset2, dim2)); }
181      template<int N_rank>
182#ifdef BZ_ARRAY_EXPR_PASS_INDEX_BY_VALUE
183      static T_result indexop(const T_expr1& iter1, const T_expr2& iter2,
184                              const T_expr3& iter3, 
185                              const TinyVector<int, N_rank> i) {
186#else
187      static T_result indexop(const T_expr1& iter1, const T_expr2& iter2,
188                              const T_expr3& iter3, 
189                              const TinyVector<int, N_rank>& i) {
190#endif
191        return T_result(iter1(i), iter2(i), iter3(i) ); }
192      };
193
194    T_result fastRead(diffType i) const { 
195      return readHelper<T_typeprop>::fastRead(iter1_, iter2_, iter3_, i); }
196
197      template<int N>
198      typename tvresult<N>::Type fastRead_tv(diffType i) const
199      { return typename tvresult<N>::Type(iter1_.template fastRead_tv<N>(i),
200                                          iter2_.template fastRead_tv<N>(i),
201                                          iter3_.template fastRead_tv<N>(i)); }
202
203    T_result operator[](int i) const { 
204      return readHelper<T_typeprop>::indexop(iter1_, iter2_, iter3_, i); }
205
206    template<int N_rank>
207#ifdef BZ_ARRAY_EXPR_PASS_INDEX_BY_VALUE
208    T_result operator()(const TinyVector<int, N_rank> i) const {
209#else
210      T_result operator()(const TinyVector<int, N_rank>& i) const {
211#endif
212        return readHelper<T_typeprop>::indexop(iter1_, iter2_, iter3_, i); }
213   
214      T_result first_value() const { 
215        return readHelper<T_typeprop>::first_value(iter1_, iter2_, iter3_); }
216
217    T_result shift(int offset, int dim) const {
218      return readHelper<T_typeprop>::shift(iter1_, iter2_, iter3_, 
219                                           offset, dim); }
220
221    T_result shift(int offset1, int dim1,int offset2, int dim2) const {
222      return readHelper<T_typeprop>::shift(iter1_, iter2_, iter3_,
223                                           offset1, dim1, offset2, dim2); }
224
225      // ****** end reading
226
227  bool isVectorAligned(diffType offset) const 
228  { return iter1_.isVectorAligned(offset) && 
229      iter2_.isVectorAligned(offset) &&
230      iter3_.isVectorAligned(offset); }
231
232    T_range_result operator()(const RectDomain<rank_>& d) const
233  { return T_range_result(iter1_(d), iter2_(d), iter3_(d)); }
234
235    int ascending(const int rank) const
236    {
237        return bounds::compute_ascending(rank, bounds::compute_ascending(
238          rank, iter1_.ascending(rank), iter2_.ascending(rank)),
239          iter3_.ascending(rank));
240    }
241
242    int ordering(const int rank) const
243    {
244        return bounds::compute_ordering(rank, bounds::compute_ordering(
245          rank, iter1_.ordering(rank), iter2_.ordering(rank)),
246          iter3_.ordering(rank));
247    }
248
249    int lbound(const int rank) const
250    {
251        return bounds::compute_lbound(rank, bounds::compute_lbound(
252          rank, iter1_.lbound(rank), iter2_.lbound(rank)), 
253          iter3_.lbound(rank));
254    }
255   
256    int ubound(const int rank) const
257    {
258        return bounds::compute_ubound(rank, bounds::compute_ubound(
259          rank, iter1_.ubound(rank), iter2_.ubound(rank)), 
260          iter3_.ubound(rank));
261    } 
262
263  // defer calculation to lbound/ubound
264  RectDomain<rank_> domain() const 
265  { 
266    TinyVector<int, rank_> lb, ub;
267    for(int r=0; r<rank_; ++r) {
268      lb[r]=lbound(r); ub[r]=ubound(r); 
269    }
270    return RectDomain<rank_>(lb,ub);
271  }
272
273    void push(int position)
274    {
275        iter1_.push(position);
276        iter2_.push(position);
277        iter3_.push(position);
278    }
279
280    void pop(int position)
281    {
282        iter1_.pop(position);
283        iter2_.pop(position);
284        iter3_.pop(position);
285    }
286
287    void advance()
288    {
289        iter1_.advance();
290        iter2_.advance();
291        iter3_.advance();
292    }
293
294    void advance(int n)
295    {
296        iter1_.advance(n);
297        iter2_.advance(n);
298        iter3_.advance(n);
299    }
300
301    void loadStride(int rank)
302    {
303        iter1_.loadStride(rank);
304        iter2_.loadStride(rank);
305        iter3_.loadStride(rank);
306    }
307
308    bool isUnitStride(int rank) const
309    { 
310        return iter1_.isUnitStride(rank) 
311            && iter2_.isUnitStride(rank) 
312            && iter3_.isUnitStride(rank);
313    }
314
315    bool isUnitStride() const
316    { 
317        return iter1_.isUnitStride() 
318            && iter2_.isUnitStride() 
319            && iter3_.isUnitStride();
320    }
321
322    void advanceUnitStride()
323    {
324        iter1_.advanceUnitStride();
325        iter2_.advanceUnitStride();
326        iter3_.advanceUnitStride();
327    }
328
329    bool canCollapse(int outerLoopRank, int innerLoopRank) const
330    {
331        return iter1_.canCollapse(outerLoopRank, innerLoopRank)
332            && iter2_.canCollapse(outerLoopRank, innerLoopRank)
333            && iter3_.canCollapse(outerLoopRank, innerLoopRank);
334    }
335
336    template<int N_rank>
337    void moveTo(const TinyVector<int,N_rank>& i)
338    {
339        iter1_.moveTo(i);
340        iter2_.moveTo(i);
341        iter3_.moveTo(i);
342    }
343
344  // this is needed for the stencil expression fastRead to work
345  void _bz_offsetData(sizeType i)
346  {
347    iter1_._bz_offsetData(i);
348    iter2_._bz_offsetData(i);
349    iter3_._bz_offsetData(i);
350  }
351
352    diffType suggestStride(int rank) const
353    {
354        diffType stride1 = iter1_.suggestStride(rank);
355        diffType stride2 = iter2_.suggestStride(rank);
356        diffType stride3 = iter3_.suggestStride(rank);
357        return stride1>(stride2=(stride2>stride3?stride2:stride3))?stride1:stride2;
358    }
359
360    bool isStride(int rank, diffType stride) const
361    {
362        return iter1_.isStride(rank,stride) 
363            && iter2_.isStride(rank,stride)
364            && iter3_.isStride(rank,stride);
365    }
366
367    void prettyPrint(BZ_STD_SCOPE(string) &str, 
368        prettyPrintFormat& format) const
369    {
370        str += "where(";
371        iter1_.prettyPrint(str,format);
372        str += ",";
373        iter2_.prettyPrint(str,format);
374        str += ",";
375        iter3_.prettyPrint(str,format);
376        str += ")";
377    }
378
379    template<typename T_shape>
380    bool shapeCheck(const T_shape& shape) const
381    { 
382        int t1 = iter1_.shapeCheck(shape);
383        int t2 = iter2_.shapeCheck(shape);
384        int t3 = iter3_.shapeCheck(shape);
385
386        return t1 && t2 && t3;
387    }
388
389
390  // sliceinfo for expressions
391  template<typename T1, typename T2 = nilArraySection, 
392           class T3 = nilArraySection, typename T4 = nilArraySection, 
393           class T5 = nilArraySection, typename T6 = nilArraySection, 
394           class T7 = nilArraySection, typename T8 = nilArraySection, 
395           class T9 = nilArraySection, typename T10 = nilArraySection, 
396           class T11 = nilArraySection>
397  class SliceInfo {
398  public:
399    typedef typename T_expr1::template SliceInfo<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11>::T_slice T_slice1;
400    typedef typename T_expr2::template SliceInfo<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11>::T_slice T_slice2;
401    typedef typename T_expr3::template SliceInfo<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11>::T_slice T_slice3;
402    typedef _bz_ArrayWhere<T_slice1, T_slice2, T_slice3> T_slice;
403};
404
405    template<typename T1, typename T2, typename T3, typename T4, typename T5, typename T6,
406        typename T7, typename T8, typename T9, typename T10, typename T11>
407    typename SliceInfo<T1,T2,T3,T4,T5,T6,T7,T8,T9,T10,T11>::T_slice
408    operator()(T1 r1, T2 r2, T3 r3, T4 r4, T5 r5, T6 r6, T7 r7, T8 r8, T9 r9, T10 r10, T11 r11) const
409    {
410      return typename SliceInfo<T1,T2,T3,T4,T5,T6,T7,T8,T9,T10,T11>::T_slice
411        (iter1_(r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11),
412         iter2_(r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11),
413         iter3_(r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11));
414    }
415
416private:
417    _bz_ArrayWhere() { }
418
419    T_expr1 iter1_;
420    T_expr2 iter2_;
421    T_expr3 iter3_;
422};
423
424template<typename T1, typename T2, typename T3>
425inline
426_bz_ArrayExpr<_bz_ArrayWhere<_bz_typename asExpr<T1>::T_expr,
427    _bz_typename asExpr<T2>::T_expr, _bz_typename asExpr<T3>::T_expr> >
428where(const T1& a, const T2& b, const T3& c)
429{
430    return _bz_ArrayExpr<_bz_ArrayWhere<_bz_typename asExpr<T1>::T_expr,
431       _bz_typename asExpr<T2>::T_expr, 
432       _bz_typename asExpr<T3>::T_expr> >(a,b,c);
433}
434
435BZ_NAMESPACE_END
436
437#endif // BZ_ARRAYWHERE_H
438
Note: See TracBrowser for help on using the repository browser.