@@ -157,31 +157,53 @@ using default_policy = KERNEL_FLOAT_POLICY;
157
157
158
158
namespace detail {
159
159
160
+ //
160
161
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
161
- struct apply_base_impl {
162
+ struct apply_fallback_impl {
162
163
KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
163
- #pragma unroll
164
- for (size_t i = 0 ; i < N; i++) {
165
- output[i] = fun (args[i]...);
166
- }
164
+ static_assert (N > 0 , " operation not implemented" );
167
165
}
168
166
};
169
167
168
+ template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
169
+ struct apply_base_impl : apply_fallback_impl<Policy, F, N, Output, Args...> {};
170
+
170
171
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
171
172
struct apply_impl : apply_base_impl<Policy, F, N, Output, Args...> {};
172
173
174
+ // `fast_policy` falls back to `accurate_policy`
173
175
template <typename F, size_t N, typename Output, typename ... Args>
174
- struct apply_base_impl <fast_policy, F, N, Output, Args...>:
176
+ struct apply_fallback_impl <fast_policy, F, N, Output, Args...>:
175
177
apply_impl<accurate_policy, F, N, Output, Args...> {};
176
178
179
+ // `approx_policy` falls back to `fast_policy`
177
180
template <typename F, size_t N, typename Output, typename ... Args>
178
- struct apply_base_impl <approx_policy, F, N, Output, Args...>:
181
+ struct apply_fallback_impl <approx_policy, F, N, Output, Args...>:
179
182
apply_impl<fast_policy, F, N, Output, Args...> {};
180
183
184
+ // `approx_level_policy` falls back to `approx_policy`
181
185
template <int Level, typename F, size_t N, typename Output, typename ... Args>
182
- struct apply_base_impl <approx_level_policy<Level>, F, N, Output, Args...>:
186
+ struct apply_fallback_impl <approx_level_policy<Level>, F, N, Output, Args...>:
183
187
apply_impl<approx_policy, F, N, Output, Args...> {};
184
188
189
+ template <typename F, typename Output, typename ... Args>
190
+ struct invoke_impl {
191
+ KERNEL_FLOAT_INLINE static Output call (F fun, Args... args) {
192
+ return fun (args...);
193
+ }
194
+ };
195
+
196
+ // Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`.
197
+ template <typename F, size_t N, typename Output, typename ... Args>
198
+ struct apply_impl <accurate_policy, F, N, Output, Args...> {
199
+ KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
200
+ #pragma unroll
201
+ for (size_t i = 0 ; i < N; i++) {
202
+ output[i] = invoke_impl<F, Output, Args...>::call (fun, args[i]...);
203
+ }
204
+ }
205
+ };
206
+
185
207
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
186
208
struct map_impl {
187
209
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
0 commit comments