1 | // -*- C++ -*- |
---|
2 | /*************************************************************************** |
---|
3 | * blitz/array/where.h where(X,Y,Z) operator for array expressions |
---|
4 | * |
---|
5 | * $Id$ |
---|
6 | * |
---|
7 | * Copyright (C) 1997-2011 Todd Veldhuizen <tveldhui@acm.org> |
---|
8 | * |
---|
9 | * This file is a part of Blitz. |
---|
10 | * |
---|
11 | * Blitz is free software: you can redistribute it and/or modify |
---|
12 | * it under the terms of the GNU Lesser General Public License |
---|
13 | * as published by the Free Software Foundation, either version 3 |
---|
14 | * of the License, or (at your option) any later version. |
---|
15 | * |
---|
16 | * Blitz is distributed in the hope that it will be useful, |
---|
17 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
18 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
19 | * GNU Lesser General Public License for more details. |
---|
20 | * |
---|
21 | * You should have received a copy of the GNU Lesser General Public |
---|
22 | * License along with Blitz. If not, see <http://www.gnu.org/licenses/>. |
---|
23 | * |
---|
24 | * Suggestions: blitz-devel@lists.sourceforge.net |
---|
25 | * Bugs: blitz-support@lists.sourceforge.net |
---|
26 | * |
---|
27 | * For more information, please see the Blitz++ Home Page: |
---|
28 | * https://sourceforge.net/projects/blitz/ |
---|
29 | * |
---|
30 | ****************************************************************************/ |
---|
31 | #ifndef BZ_ARRAYWHERE_H |
---|
32 | #define BZ_ARRAYWHERE_H |
---|
33 | |
---|
34 | #include <blitz/blitz.h> |
---|
35 | #include <blitz/promote.h> |
---|
36 | #include <blitz/prettyprint.h> |
---|
37 | #include <blitz/bounds.h> |
---|
38 | #include <blitz/meta/metaprog.h> |
---|
39 | #include <blitz/tinyvec2.h> |
---|
40 | #include <blitz/array/domain.h> |
---|
41 | #include <blitz/array/asexpr.h> |
---|
42 | |
---|
43 | BZ_NAMESPACE(blitz) |
---|
44 | |
---|
45 | template<typename P_expr1, typename P_expr2, typename P_expr3> |
---|
46 | class _bz_ArrayWhere { |
---|
47 | |
---|
48 | public: |
---|
49 | typedef P_expr1 T_expr1; |
---|
50 | typedef P_expr2 T_expr2; |
---|
51 | typedef P_expr3 T_expr3; |
---|
52 | typedef _bz_typename T_expr2::T_numtype T_numtype2; |
---|
53 | typedef _bz_typename T_expr3::T_numtype T_numtype3; |
---|
54 | typedef BZ_PROMOTE(T_numtype2, T_numtype3) T_numtype; |
---|
55 | typedef T_expr1 T_ctorArg1; |
---|
56 | typedef T_expr2 T_ctorArg2; |
---|
57 | typedef T_expr3 T_ctorArg3; |
---|
58 | typedef _bz_ArrayWhere<_bz_typename P_expr1::T_range_result, |
---|
59 | _bz_typename P_expr2::T_range_result, |
---|
60 | _bz_typename P_expr3::T_range_result> T_range_result; |
---|
61 | |
---|
62 | // select return type |
---|
63 | typedef typename unwrapET<typename T_expr1::T_result>::T_unwrapped T_unwrapped1; |
---|
64 | typedef typename unwrapET<typename T_expr2::T_result>::T_unwrapped T_unwrapped2; |
---|
65 | typedef typename unwrapET<typename T_expr3::T_result>::T_unwrapped T_unwrapped3; |
---|
66 | typedef typename selectET2<typename T_expr1::T_typeprop, |
---|
67 | typename T_expr2::T_typeprop, |
---|
68 | T_numtype, |
---|
69 | char>::T_selected T_intermediary; |
---|
70 | |
---|
71 | typedef typename selectET2<T_intermediary, |
---|
72 | typename T_expr3::T_typeprop, |
---|
73 | T_numtype, |
---|
74 | _bz_ArrayWhere<typename asExpr<T_unwrapped1>::T_expr, |
---|
75 | typename asExpr<T_unwrapped2>::T_expr, |
---|
76 | typename asExpr<T_unwrapped3>::T_expr |
---|
77 | > >::T_selected T_typeprop; |
---|
78 | typedef typename unwrapET<T_typeprop>::T_unwrapped T_result; |
---|
79 | typedef T_numtype T_optype; |
---|
80 | |
---|
81 | static const int |
---|
82 | numArrayOperands = P_expr1::numArrayOperands |
---|
83 | + P_expr2::numArrayOperands |
---|
84 | + P_expr3::numArrayOperands, |
---|
85 | numTVOperands = T_expr1::numTVOperands + |
---|
86 | T_expr2::numTVOperands + |
---|
87 | T_expr3::numTVOperands, |
---|
88 | numTMOperands = T_expr1::numTMOperands + |
---|
89 | T_expr2::numTMOperands + |
---|
90 | T_expr3::numTMOperands, |
---|
91 | numIndexPlaceholders = P_expr1::numIndexPlaceholders |
---|
92 | + P_expr2::numIndexPlaceholders |
---|
93 | + P_expr3::numIndexPlaceholders, |
---|
94 | minWidth = BZ_MIN(BZ_MIN(T_expr1::minWidth, T_expr2::minWidth), |
---|
95 | T_expr3::minWidth), |
---|
96 | maxWidth = BZ_MAX(BZ_MAX(T_expr1::maxWidth, T_expr2::maxWidth), |
---|
97 | T_expr3::maxWidth), |
---|
98 | rank_ = BZ_MAX(BZ_MAX(T_expr1::rank_, T_expr2::rank_), |
---|
99 | T_expr3::rank_); |
---|
100 | |
---|
101 | template<int N> struct tvresult { |
---|
102 | typedef _bz_ArrayWhere< |
---|
103 | typename T_expr1::template tvresult<N>::Type, |
---|
104 | typename T_expr2::template tvresult<N>::Type, |
---|
105 | typename T_expr3::template tvresult<N>::Type> Type; |
---|
106 | }; |
---|
107 | |
---|
108 | _bz_ArrayWhere(const _bz_ArrayWhere<T_expr1,T_expr2,T_expr3>& a) |
---|
109 | : iter1_(a.iter1_), iter2_(a.iter2_), iter3_(a.iter3_) |
---|
110 | { } |
---|
111 | |
---|
112 | template<typename T1, typename T2, typename T3> |
---|
113 | _bz_ArrayWhere(BZ_ETPARM(T1) a, BZ_ETPARM(T2) b, BZ_ETPARM(T3) c) |
---|
114 | : iter1_(a), iter2_(b), iter3_(c) |
---|
115 | { } |
---|
116 | |
---|
117 | T_numtype operator*() const |
---|
118 | { return (*iter1_) ? (*iter2_) : (*iter3_); } |
---|
119 | |
---|
120 | /* Functions for reading. Because they must depend on the result |
---|
121 | * type, they utilize a helper class. |
---|
122 | */ |
---|
123 | |
---|
124 | // For numtypes, apply operator |
---|
125 | template<typename T> struct readHelper { |
---|
126 | static T_result fastRead(const T_expr1& iter1, const T_expr2& iter2, |
---|
127 | const T_expr3& iter3, diffType i) { |
---|
128 | return iter1.fastRead(i) ? iter2.fastRead(i) : iter3.fastRead(i); } |
---|
129 | static T_result indexop(const T_expr1& iter1, const T_expr2& iter2, |
---|
130 | const T_expr3& iter3, int i) { |
---|
131 | return iter1[i] ? iter2[i] : iter3[i]; } |
---|
132 | static T_result first_value(const T_expr1& iter1, const T_expr2& iter2, |
---|
133 | const T_expr3& iter3) { |
---|
134 | return iter1.first_value() ? |
---|
135 | iter2.first_value() : iter3.first_value(); } |
---|
136 | static T_result shift(const T_expr1& iter1, const T_expr2& iter2, |
---|
137 | const T_expr3& iter3, int offset, int dim) { |
---|
138 | return iter1.shift(offset, dim) ? iter2.shift(offset, dim) : |
---|
139 | iter3.shift(offset, dim); } |
---|
140 | static T_result shift(const T_expr1& iter1, const T_expr2& iter2, |
---|
141 | const T_expr3& iter3, int offset1, int dim1, |
---|
142 | int offset2, int dim2) { |
---|
143 | return iter1.shift(offset1, dim1, offset2, dim2) ? |
---|
144 | iter2.shift(offset1, dim1, offset2, dim2) : |
---|
145 | iter3.shift(offset1, dim1, offset2, dim2); } |
---|
146 | template<int N_rank> |
---|
147 | #ifdef BZ_ARRAY_EXPR_PASS_INDEX_BY_VALUE |
---|
148 | static T_result indexop(const T_expr1& iter1, const T_expr2& iter2, |
---|
149 | const T_expr3& iter3, |
---|
150 | const TinyVector<int, N_rank> i) { |
---|
151 | #else |
---|
152 | static T_result indexop(const T_expr1& iter1, const T_expr2& iter2, |
---|
153 | const T_expr3& iter3, |
---|
154 | const TinyVector<int, N_rank>& i) { |
---|
155 | #endif |
---|
156 | return iter1(i) ? iter2(i) : iter3(i); } |
---|
157 | }; |
---|
158 | |
---|
159 | // For ET types, bypass operator and create expression |
---|
160 | template<typename T> struct readHelper<ETBase<T> > { |
---|
161 | static T_result fastRead(const T_expr1& iter1, const T_expr2& iter2, |
---|
162 | const T_expr3& iter3, diffType i) { |
---|
163 | return T_result(iter1.fastRead(i), iter2.fastRead(i), iter3.fastRead(i)); } |
---|
164 | static T_result indexop(const T_expr1& iter1, const T_expr2& iter2, |
---|
165 | const T_expr3& iter3, int i) { |
---|
166 | return T_result(iter1[i], iter2[i], iter3[i]); }; |
---|
167 | static T_result first_value(const T_expr1& iter1, const T_expr2& iter2, |
---|
168 | const T_expr3& iter3) { |
---|
169 | return T_result(iter1.first_value(), iter2.first_value(), |
---|
170 | iter3.first_value()); } |
---|
171 | static T_result shift(const T_expr1& iter1, const T_expr2& iter2, |
---|
172 | const T_expr3& iter3, int offset, int dim) { |
---|
173 | return T_result(iter1.shift(offset, dim), iter2.shift(offset, dim), |
---|
174 | iter3.shift(offset, dim)); } |
---|
175 | static T_result shift(const T_expr1& iter1, const T_expr2& iter2, |
---|
176 | const T_expr3& iter3, int offset1, int dim1, |
---|
177 | int offset2, int dim2) { |
---|
178 | return T_result(iter1.shift(offset1, dim1, offset2, dim2), |
---|
179 | iter2.shift(offset1, dim1, offset2, dim2), |
---|
180 | iter3.shift(offset1, dim1, offset2, dim2)); } |
---|
181 | template<int N_rank> |
---|
182 | #ifdef BZ_ARRAY_EXPR_PASS_INDEX_BY_VALUE |
---|
183 | static T_result indexop(const T_expr1& iter1, const T_expr2& iter2, |
---|
184 | const T_expr3& iter3, |
---|
185 | const TinyVector<int, N_rank> i) { |
---|
186 | #else |
---|
187 | static T_result indexop(const T_expr1& iter1, const T_expr2& iter2, |
---|
188 | const T_expr3& iter3, |
---|
189 | const TinyVector<int, N_rank>& i) { |
---|
190 | #endif |
---|
191 | return T_result(iter1(i), iter2(i), iter3(i) ); } |
---|
192 | }; |
---|
193 | |
---|
194 | T_result fastRead(diffType i) const { |
---|
195 | return readHelper<T_typeprop>::fastRead(iter1_, iter2_, iter3_, i); } |
---|
196 | |
---|
197 | template<int N> |
---|
198 | typename tvresult<N>::Type fastRead_tv(diffType i) const |
---|
199 | { return typename tvresult<N>::Type(iter1_.template fastRead_tv<N>(i), |
---|
200 | iter2_.template fastRead_tv<N>(i), |
---|
201 | iter3_.template fastRead_tv<N>(i)); } |
---|
202 | |
---|
203 | T_result operator[](int i) const { |
---|
204 | return readHelper<T_typeprop>::indexop(iter1_, iter2_, iter3_, i); } |
---|
205 | |
---|
206 | template<int N_rank> |
---|
207 | #ifdef BZ_ARRAY_EXPR_PASS_INDEX_BY_VALUE |
---|
208 | T_result operator()(const TinyVector<int, N_rank> i) const { |
---|
209 | #else |
---|
210 | T_result operator()(const TinyVector<int, N_rank>& i) const { |
---|
211 | #endif |
---|
212 | return readHelper<T_typeprop>::indexop(iter1_, iter2_, iter3_, i); } |
---|
213 | |
---|
214 | T_result first_value() const { |
---|
215 | return readHelper<T_typeprop>::first_value(iter1_, iter2_, iter3_); } |
---|
216 | |
---|
217 | T_result shift(int offset, int dim) const { |
---|
218 | return readHelper<T_typeprop>::shift(iter1_, iter2_, iter3_, |
---|
219 | offset, dim); } |
---|
220 | |
---|
221 | T_result shift(int offset1, int dim1,int offset2, int dim2) const { |
---|
222 | return readHelper<T_typeprop>::shift(iter1_, iter2_, iter3_, |
---|
223 | offset1, dim1, offset2, dim2); } |
---|
224 | |
---|
225 | // ****** end reading |
---|
226 | |
---|
227 | bool isVectorAligned(diffType offset) const |
---|
228 | { return iter1_.isVectorAligned(offset) && |
---|
229 | iter2_.isVectorAligned(offset) && |
---|
230 | iter3_.isVectorAligned(offset); } |
---|
231 | |
---|
232 | T_range_result operator()(const RectDomain<rank_>& d) const |
---|
233 | { return T_range_result(iter1_(d), iter2_(d), iter3_(d)); } |
---|
234 | |
---|
235 | int ascending(const int rank) const |
---|
236 | { |
---|
237 | return bounds::compute_ascending(rank, bounds::compute_ascending( |
---|
238 | rank, iter1_.ascending(rank), iter2_.ascending(rank)), |
---|
239 | iter3_.ascending(rank)); |
---|
240 | } |
---|
241 | |
---|
242 | int ordering(const int rank) const |
---|
243 | { |
---|
244 | return bounds::compute_ordering(rank, bounds::compute_ordering( |
---|
245 | rank, iter1_.ordering(rank), iter2_.ordering(rank)), |
---|
246 | iter3_.ordering(rank)); |
---|
247 | } |
---|
248 | |
---|
249 | int lbound(const int rank) const |
---|
250 | { |
---|
251 | return bounds::compute_lbound(rank, bounds::compute_lbound( |
---|
252 | rank, iter1_.lbound(rank), iter2_.lbound(rank)), |
---|
253 | iter3_.lbound(rank)); |
---|
254 | } |
---|
255 | |
---|
256 | int ubound(const int rank) const |
---|
257 | { |
---|
258 | return bounds::compute_ubound(rank, bounds::compute_ubound( |
---|
259 | rank, iter1_.ubound(rank), iter2_.ubound(rank)), |
---|
260 | iter3_.ubound(rank)); |
---|
261 | } |
---|
262 | |
---|
263 | // defer calculation to lbound/ubound |
---|
264 | RectDomain<rank_> domain() const |
---|
265 | { |
---|
266 | TinyVector<int, rank_> lb, ub; |
---|
267 | for(int r=0; r<rank_; ++r) { |
---|
268 | lb[r]=lbound(r); ub[r]=ubound(r); |
---|
269 | } |
---|
270 | return RectDomain<rank_>(lb,ub); |
---|
271 | } |
---|
272 | |
---|
273 | void push(int position) |
---|
274 | { |
---|
275 | iter1_.push(position); |
---|
276 | iter2_.push(position); |
---|
277 | iter3_.push(position); |
---|
278 | } |
---|
279 | |
---|
280 | void pop(int position) |
---|
281 | { |
---|
282 | iter1_.pop(position); |
---|
283 | iter2_.pop(position); |
---|
284 | iter3_.pop(position); |
---|
285 | } |
---|
286 | |
---|
287 | void advance() |
---|
288 | { |
---|
289 | iter1_.advance(); |
---|
290 | iter2_.advance(); |
---|
291 | iter3_.advance(); |
---|
292 | } |
---|
293 | |
---|
294 | void advance(int n) |
---|
295 | { |
---|
296 | iter1_.advance(n); |
---|
297 | iter2_.advance(n); |
---|
298 | iter3_.advance(n); |
---|
299 | } |
---|
300 | |
---|
301 | void loadStride(int rank) |
---|
302 | { |
---|
303 | iter1_.loadStride(rank); |
---|
304 | iter2_.loadStride(rank); |
---|
305 | iter3_.loadStride(rank); |
---|
306 | } |
---|
307 | |
---|
308 | bool isUnitStride(int rank) const |
---|
309 | { |
---|
310 | return iter1_.isUnitStride(rank) |
---|
311 | && iter2_.isUnitStride(rank) |
---|
312 | && iter3_.isUnitStride(rank); |
---|
313 | } |
---|
314 | |
---|
315 | bool isUnitStride() const |
---|
316 | { |
---|
317 | return iter1_.isUnitStride() |
---|
318 | && iter2_.isUnitStride() |
---|
319 | && iter3_.isUnitStride(); |
---|
320 | } |
---|
321 | |
---|
322 | void advanceUnitStride() |
---|
323 | { |
---|
324 | iter1_.advanceUnitStride(); |
---|
325 | iter2_.advanceUnitStride(); |
---|
326 | iter3_.advanceUnitStride(); |
---|
327 | } |
---|
328 | |
---|
329 | bool canCollapse(int outerLoopRank, int innerLoopRank) const |
---|
330 | { |
---|
331 | return iter1_.canCollapse(outerLoopRank, innerLoopRank) |
---|
332 | && iter2_.canCollapse(outerLoopRank, innerLoopRank) |
---|
333 | && iter3_.canCollapse(outerLoopRank, innerLoopRank); |
---|
334 | } |
---|
335 | |
---|
336 | template<int N_rank> |
---|
337 | void moveTo(const TinyVector<int,N_rank>& i) |
---|
338 | { |
---|
339 | iter1_.moveTo(i); |
---|
340 | iter2_.moveTo(i); |
---|
341 | iter3_.moveTo(i); |
---|
342 | } |
---|
343 | |
---|
344 | // this is needed for the stencil expression fastRead to work |
---|
345 | void _bz_offsetData(sizeType i) |
---|
346 | { |
---|
347 | iter1_._bz_offsetData(i); |
---|
348 | iter2_._bz_offsetData(i); |
---|
349 | iter3_._bz_offsetData(i); |
---|
350 | } |
---|
351 | |
---|
352 | diffType suggestStride(int rank) const |
---|
353 | { |
---|
354 | diffType stride1 = iter1_.suggestStride(rank); |
---|
355 | diffType stride2 = iter2_.suggestStride(rank); |
---|
356 | diffType stride3 = iter3_.suggestStride(rank); |
---|
357 | return stride1>(stride2=(stride2>stride3?stride2:stride3))?stride1:stride2; |
---|
358 | } |
---|
359 | |
---|
360 | bool isStride(int rank, diffType stride) const |
---|
361 | { |
---|
362 | return iter1_.isStride(rank,stride) |
---|
363 | && iter2_.isStride(rank,stride) |
---|
364 | && iter3_.isStride(rank,stride); |
---|
365 | } |
---|
366 | |
---|
367 | void prettyPrint(BZ_STD_SCOPE(string) &str, |
---|
368 | prettyPrintFormat& format) const |
---|
369 | { |
---|
370 | str += "where("; |
---|
371 | iter1_.prettyPrint(str,format); |
---|
372 | str += ","; |
---|
373 | iter2_.prettyPrint(str,format); |
---|
374 | str += ","; |
---|
375 | iter3_.prettyPrint(str,format); |
---|
376 | str += ")"; |
---|
377 | } |
---|
378 | |
---|
379 | template<typename T_shape> |
---|
380 | bool shapeCheck(const T_shape& shape) const |
---|
381 | { |
---|
382 | int t1 = iter1_.shapeCheck(shape); |
---|
383 | int t2 = iter2_.shapeCheck(shape); |
---|
384 | int t3 = iter3_.shapeCheck(shape); |
---|
385 | |
---|
386 | return t1 && t2 && t3; |
---|
387 | } |
---|
388 | |
---|
389 | |
---|
390 | // sliceinfo for expressions |
---|
391 | template<typename T1, typename T2 = nilArraySection, |
---|
392 | class T3 = nilArraySection, typename T4 = nilArraySection, |
---|
393 | class T5 = nilArraySection, typename T6 = nilArraySection, |
---|
394 | class T7 = nilArraySection, typename T8 = nilArraySection, |
---|
395 | class T9 = nilArraySection, typename T10 = nilArraySection, |
---|
396 | class T11 = nilArraySection> |
---|
397 | class SliceInfo { |
---|
398 | public: |
---|
399 | typedef typename T_expr1::template SliceInfo<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11>::T_slice T_slice1; |
---|
400 | typedef typename T_expr2::template SliceInfo<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11>::T_slice T_slice2; |
---|
401 | typedef typename T_expr3::template SliceInfo<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11>::T_slice T_slice3; |
---|
402 | typedef _bz_ArrayWhere<T_slice1, T_slice2, T_slice3> T_slice; |
---|
403 | }; |
---|
404 | |
---|
405 | template<typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, |
---|
406 | typename T7, typename T8, typename T9, typename T10, typename T11> |
---|
407 | typename SliceInfo<T1,T2,T3,T4,T5,T6,T7,T8,T9,T10,T11>::T_slice |
---|
408 | operator()(T1 r1, T2 r2, T3 r3, T4 r4, T5 r5, T6 r6, T7 r7, T8 r8, T9 r9, T10 r10, T11 r11) const |
---|
409 | { |
---|
410 | return typename SliceInfo<T1,T2,T3,T4,T5,T6,T7,T8,T9,T10,T11>::T_slice |
---|
411 | (iter1_(r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11), |
---|
412 | iter2_(r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11), |
---|
413 | iter3_(r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11)); |
---|
414 | } |
---|
415 | |
---|
416 | private: |
---|
417 | _bz_ArrayWhere() { } |
---|
418 | |
---|
419 | T_expr1 iter1_; |
---|
420 | T_expr2 iter2_; |
---|
421 | T_expr3 iter3_; |
---|
422 | }; |
---|
423 | |
---|
424 | template<typename T1, typename T2, typename T3> |
---|
425 | inline |
---|
426 | _bz_ArrayExpr<_bz_ArrayWhere<_bz_typename asExpr<T1>::T_expr, |
---|
427 | _bz_typename asExpr<T2>::T_expr, _bz_typename asExpr<T3>::T_expr> > |
---|
428 | where(const T1& a, const T2& b, const T3& c) |
---|
429 | { |
---|
430 | return _bz_ArrayExpr<_bz_ArrayWhere<_bz_typename asExpr<T1>::T_expr, |
---|
431 | _bz_typename asExpr<T2>::T_expr, |
---|
432 | _bz_typename asExpr<T3>::T_expr> >(a,b,c); |
---|
433 | } |
---|
434 | |
---|
435 | BZ_NAMESPACE_END |
---|
436 | |
---|
437 | #endif // BZ_ARRAYWHERE_H |
---|
438 | |
---|