14
14
15
15
/* * \file ein_reduce.h
16
16
* \brief Optional helper for computing Einstein reductions on arrays.
17
- */
17
+ */
18
18
19
19
#ifndef NDARRAY_EIN_REDUCE_H
20
20
#define NDARRAY_EIN_REDUCE_H
@@ -61,22 +61,38 @@ struct ein_op {
61
61
}
62
62
63
63
template <class T , class = enable_if_ein_op<T>>
64
- auto operator +(const T& r) const { return make_ein_op_add (*this , r); }
64
+ auto operator +(const T& r) const {
65
+ return make_ein_op_add (*this , r);
66
+ }
65
67
template <class T , class = enable_if_ein_op<T>>
66
- auto operator -(const T& r) const { return make_ein_op_sub (*this , r); }
68
+ auto operator -(const T& r) const {
69
+ return make_ein_op_sub (*this , r);
70
+ }
67
71
template <class T , class = enable_if_ein_op<T>>
68
- auto operator *(const T& r) const { return make_ein_op_mul (*this , r); }
72
+ auto operator *(const T& r) const {
73
+ return make_ein_op_mul (*this , r);
74
+ }
69
75
template <class T , class = enable_if_ein_op<T>>
70
- auto operator /(const T& r) const { return make_ein_op_div (*this , r); }
76
+ auto operator /(const T& r) const {
77
+ return make_ein_op_div (*this , r);
78
+ }
71
79
72
80
template <class T , class = enable_if_ein_op<T>>
73
- auto operator =(const T& r) const { return make_ein_op_assign (*this , r); }
81
+ auto operator =(const T& r) const {
82
+ return make_ein_op_assign (*this , r);
83
+ }
74
84
template <class T , class = enable_if_ein_op<T>>
75
- auto operator +=(const T& r) const { return make_ein_op_add_assign (*this , r); }
85
+ auto operator +=(const T& r) const {
86
+ return make_ein_op_add_assign (*this , r);
87
+ }
76
88
template <class T , class = enable_if_ein_op<T>>
77
- auto operator -=(const T& r) const { return make_ein_op_sub_assign (*this , r); }
89
+ auto operator -=(const T& r) const {
90
+ return make_ein_op_sub_assign (*this , r);
91
+ }
78
92
template <class T , class = enable_if_ein_op<T>>
79
- auto operator *=(const T& r) const { return make_ein_op_mul_assign (*this , r); }
93
+ auto operator *=(const T& r) const {
94
+ return make_ein_op_mul_assign (*this , r);
95
+ }
80
96
};
81
97
82
98
// A binary operation of two operands.
@@ -94,46 +110,55 @@ struct ein_bin_op {
94
110
const Derived& derived () const { return *static_cast <const Derived*>(this ); }
95
111
96
112
template <class T , class = enable_if_ein_op<T>>
97
- auto operator +(const T& r) const { return make_ein_op_add (derived (), r); }
113
+ auto operator +(const T& r) const {
114
+ return make_ein_op_add (derived (), r);
115
+ }
98
116
template <class T , class = enable_if_ein_op<T>>
99
- auto operator -(const T& r) const { return make_ein_op_sub (derived (), r); }
117
+ auto operator -(const T& r) const {
118
+ return make_ein_op_sub (derived (), r);
119
+ }
100
120
template <class T , class = enable_if_ein_op<T>>
101
- auto operator *(const T& r) const { return make_ein_op_mul (derived (), r); }
121
+ auto operator *(const T& r) const {
122
+ return make_ein_op_mul (derived (), r);
123
+ }
102
124
template <class T , class = enable_if_ein_op<T>>
103
- auto operator /(const T& r) const { return make_ein_op_div (derived (), r); }
125
+ auto operator /(const T& r) const {
126
+ return make_ein_op_div (derived (), r);
127
+ }
104
128
};
105
129
106
- #define NDARRAY_MAKE_EIN_BIN_HELPERS (name, op ) \
107
- template <class OpA , class OpB > auto make_##name(const OpA& a, const OpB& b) { \
108
- name<OpA, OpB> result; \
109
- result.op_a = a; \
110
- result.op_b = b; \
111
- return result; \
112
- }
130
+ #define NDARRAY_MAKE_EIN_BIN_HELPERS (name, op ) \
131
+ template <class OpA , class OpB > \
132
+ auto make_##name(const OpA& a, const OpB& b) { \
133
+ name<OpA, OpB> result; \
134
+ result.op_a = a; \
135
+ result.op_b = b; \
136
+ return result; \
137
+ }
113
138
114
- #define NDARRAY_MAKE_EIN_BIN_OP (name, op, is_assign_ ) \
115
- template <class OpA , class OpB > \
116
- struct name : public ein_bin_op <OpA, OpB, name<OpA, OpB>> { \
117
- using is_assign = is_assign_; \
118
- template <class Idx > \
119
- NDARRAY_INLINE auto operator ()(const Idx& i) const { \
120
- using base = ein_bin_op<OpA, OpB, name>; \
121
- return base::op_a (i) op base::op_b (i); \
122
- } \
123
- }; \
124
- NDARRAY_MAKE_EIN_BIN_HELPERS (name, op)
125
-
126
- #define NDARRAY_MAKE_EIN_BIN_FN (name, fn, is_assign_ ) \
127
- template <class OpA , class OpB > \
128
- struct name : public ein_bin_op <OpA, OpB, name<OpA, OpB>> { \
129
- using is_assign = is_assign_; \
130
- template <class Idx > \
131
- NDARRAY_INLINE auto operator ()(const Idx& i) const { \
132
- using base = ein_bin_op<OpA, OpB, name>; \
133
- return fn (base::op_a (i), base::op_b (i)); \
134
- } \
135
- }; \
136
- NDARRAY_MAKE_EIN_BIN_HELPERS (name, op)
139
+ #define NDARRAY_MAKE_EIN_BIN_OP (name, op, is_assign_ ) \
140
+ template <class OpA , class OpB > \
141
+ struct name : public ein_bin_op <OpA, OpB, name<OpA, OpB>> { \
142
+ using is_assign = is_assign_; \
143
+ template <class Idx > \
144
+ NDARRAY_INLINE auto operator ()(const Idx& i) const { \
145
+ using base = ein_bin_op<OpA, OpB, name>; \
146
+ return base::op_a (i) op base::op_b (i); \
147
+ } \
148
+ }; \
149
+ NDARRAY_MAKE_EIN_BIN_HELPERS (name, op)
150
+
151
+ #define NDARRAY_MAKE_EIN_BIN_FN (name, fn, is_assign_ ) \
152
+ template <class OpA , class OpB > \
153
+ struct name : public ein_bin_op <OpA, OpB, name<OpA, OpB>> { \
154
+ using is_assign = is_assign_; \
155
+ template <class Idx > \
156
+ NDARRAY_INLINE auto operator ()(const Idx& i) const { \
157
+ using base = ein_bin_op<OpA, OpB, name>; \
158
+ return fn (base::op_a (i), base::op_b (i)); \
159
+ } \
160
+ }; \
161
+ NDARRAY_MAKE_EIN_BIN_HELPERS (name, op)
137
162
138
163
// Define the expression types for the operations we support.
139
164
NDARRAY_MAKE_EIN_BIN_OP (ein_op_add, +, std::false_type);
@@ -153,9 +178,13 @@ NDARRAY_MAKE_EIN_BIN_OP(ein_op_mul_assign, *=, std::true_type);
153
178
#undef NDARRAY_MAKE_EIN_BIN_HELPERS
154
179
155
180
template <class OpA , class OpB >
156
- auto min (const OpA& a, const OpB& b) { return make_ein_op_min (a, b); }
181
+ auto min (const OpA& a, const OpB& b) {
182
+ return make_ein_op_min (a, b);
183
+ }
157
184
template <class OpA , class OpB >
158
- auto max (const OpA& a, const OpB& b) { return make_ein_op_max (a, b); }
185
+ auto max (const OpA& a, const OpB& b) {
186
+ return make_ein_op_max (a, b);
187
+ }
159
188
160
189
// Helper to reinterpret a dim/shape with a new stride.
161
190
template <index_t NewStride, index_t Min, index_t Extent, index_t Stride>
@@ -200,14 +229,18 @@ auto reconcile_dim(const std::tuple<Dims...>& dims) {
200
229
201
230
// Get the shape of an ein_reduce operand, or an empty shape if not an array.
202
231
template <class T , class Shape >
203
- const auto & dims_of (const array_ref<T, Shape>& op) { return op.shape ().dims (); }
232
+ const auto & dims_of (const array_ref<T, Shape>& op) {
233
+ return op.shape ().dims ();
234
+ }
204
235
template <class T >
205
- auto dims_of (const T& op) { return std::tuple<>(); }
236
+ auto dims_of (const T& op) {
237
+ return std::tuple<>();
238
+ }
206
239
207
240
// These types are flags that let us overload behavior based on these 3 options.
208
- class is_inferred_shape {};
209
- class is_result_shape {};
210
- class is_operand_shape {};
241
+ class is_inferred_shape {};
242
+ class is_result_shape {};
243
+ class is_operand_shape {};
211
244
212
245
// Get a dim from an operand, depending on the intended use of the shape.
213
246
template <size_t Dim, class Dims , size_t ... Is>
@@ -240,30 +273,37 @@ auto make_ein_reduce_shape(index_sequence<Is...>, const Ops&... ops) {
240
273
return make_shape (gather_dims<Is>(ops...)...);
241
274
}
242
275
243
- } // namespace internal
276
+ } // namespace internal
244
277
245
278
/* * Operand for an Einstein summation, which is an array or other
246
279
* callable object, along with a set of dimension indices.
247
280
* `ein<i, j, ...>(a)` means the dimensions `i, j, ...` of the
248
281
* summation index are used to address `a` during Einstein
249
282
* summation. See `ein_reduce` for more details. */
250
- template <size_t ... Is, class Op ,
251
- class = internal::enable_if_callable<Op, decltype(Is)...>>
252
- auto ein (Op op) { return internal::ein_op<Op, Is...>{op}; }
283
+ template <size_t ... Is, class Op , class = internal::enable_if_callable<Op, decltype(Is)...>>
284
+ auto ein (Op op) {
285
+ return internal::ein_op<Op, Is...>{op};
286
+ }
253
287
template <size_t ... Is, class T , class Shape , class Alloc ,
254
288
class = std::enable_if_t <sizeof ...(Is) == Shape::rank()>>
255
- auto ein (array<T, Shape, Alloc>& op) { return ein<Is...>(op.ref ()); }
289
+ auto ein (array<T, Shape, Alloc>& op) {
290
+ return ein<Is...>(op.ref ());
291
+ }
256
292
template <size_t ... Is, class T , class Shape , class Alloc ,
257
293
class = std::enable_if_t <sizeof ...(Is) == Shape::rank()>>
258
- auto ein (const array<T, Shape, Alloc>& op) { return ein<Is...>(op.ref ()); }
294
+ auto ein (const array<T, Shape, Alloc>& op) {
295
+ return ein<Is...>(op.ref ());
296
+ }
259
297
260
298
/* * Define an Einstein summation operand for a scalar. The scalar
261
299
* is broadcasted as needed during the summation. Because this
262
300
* operand does not provide a shape, the dimensions of the sum
263
301
* must be inferred from other operands. See `ein_reduce` for more
264
302
* details. */
265
303
template <class T >
266
- auto ein (T& scalar) { return ein<>(array_ref<T, shape<>>(&scalar, {})); }
304
+ auto ein (T& scalar) {
305
+ return ein<>(array_ref<T, shape<>>(&scalar, {}));
306
+ }
267
307
268
308
/* * Compute an Einstein reduction. This function allows one to specify
269
309
* many kinds of array transformations and reductions using Einstein
@@ -322,8 +362,7 @@ NDARRAY_UNIQUE auto ein_reduce(const Expr& expr) {
322
362
// first dimension it finds, so we want that to be the result dimension if it
323
363
// is present. If not, this selects one of the operand dimensions, which are
324
364
// given stride 0.
325
- auto reduction_shape = internal::make_ein_reduce_shape (
326
- internal::make_index_sequence<loop_rank>(),
365
+ auto reduction_shape = internal::make_ein_reduce_shape (internal::make_index_sequence<loop_rank>(),
327
366
std::make_tuple (internal::is_result_shape (), expr.op_a ),
328
367
std::make_tuple (internal::is_operand_shape (), expr.op_b ));
329
368
@@ -338,19 +377,16 @@ NDARRAY_UNIQUE auto ein_reduce(const Expr& expr) {
338
377
339
378
/* * Wrapper for `ein_reduce` computing the sum of the operand operand
340
379
* expression via `ein_reduce(result += expr)`. */
341
- template <class Expr , class Result ,
342
- class = internal::enable_if_ein_op<Expr>,
380
+ template <class Expr , class Result , class = internal::enable_if_ein_op<Expr>,
343
381
class = internal::enable_if_ein_op<Result>>
344
382
NDARRAY_UNIQUE auto ein_sum (const Expr& expr, const Result& result) {
345
383
return ein_reduce (result += expr);
346
384
}
347
385
348
386
/* * Infer the shape of the result of `make_ein_reduce`. */
349
- template <size_t ... ResultIs, class Expr ,
350
- class = internal::enable_if_ein_op<Expr>>
387
+ template <size_t ... ResultIs, class Expr , class = internal::enable_if_ein_op<Expr>>
351
388
auto make_ein_reduce_shape (const Expr& expr) {
352
- auto result_shape = internal::make_ein_reduce_shape (
353
- internal::index_sequence<ResultIs...>(),
389
+ auto result_shape = internal::make_ein_reduce_shape (internal::index_sequence<ResultIs...>(),
354
390
std::make_tuple (internal::is_inferred_shape (), expr));
355
391
// TODO: This would really benefit from addressing https://github.com/dsharlet/array/issues/31
356
392
return make_compact (result_shape);
@@ -388,6 +424,6 @@ NDARRAY_UNIQUE auto make_ein_sum(
388
424
return result;
389
425
}
390
426
391
- } // namespace nda
427
+ } // namespace nda
392
428
393
- #endif // NDARRAY_EIN_REDUCE_H
429
+ #endif // NDARRAY_EIN_REDUCE_H
0 commit comments