Skip to content

Commit 003ce36

Browse files
committed
Add apply_fallback_impl struct
1 parent 014e32f commit 003ce36

File tree

4 files changed

+37
-26
lines changed

4 files changed

+37
-26
lines changed

include/kernel_float/apply.h

+30-8
Original file line numberDiff line numberDiff line change
@@ -157,31 +157,53 @@ using default_policy = KERNEL_FLOAT_POLICY;
157157

158158
namespace detail {
159159

160+
// Primary fallback template: selected when no policy-specific fallback specialization applies.
160161
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
161-
struct apply_base_impl {
162+
struct apply_fallback_impl {
162163
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
163-
#pragma unroll
164-
for (size_t i = 0; i < N; i++) {
165-
output[i] = fun(args[i]...);
166-
}
164+
static_assert(N > 0, "operation not implemented");
167165
}
168166
};
169167

168+
// Primary template of `apply_base_impl`: declares no members of its own and
// inherits everything from `apply_fallback_impl` instantiated with the same
// template arguments.
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
169+
struct apply_base_impl: apply_fallback_impl<Policy, F, N, Output, Args...> {};
170+
170171
// Primary template of `apply_impl`: declares no members of its own and
// inherits everything from `apply_base_impl` instantiated with the same
// template arguments.
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
171172
struct apply_impl: apply_base_impl<Policy, F, N, Output, Args...> {};
172173

174+
// `fast_policy` falls back to `accurate_policy`
173175
template<typename F, size_t N, typename Output, typename... Args>
174-
struct apply_base_impl<fast_policy, F, N, Output, Args...>:
176+
struct apply_fallback_impl<fast_policy, F, N, Output, Args...>:
175177
apply_impl<accurate_policy, F, N, Output, Args...> {};
176178

179+
// `approx_policy` falls back to `fast_policy`
177180
template<typename F, size_t N, typename Output, typename... Args>
178-
struct apply_base_impl<approx_policy, F, N, Output, Args...>:
181+
struct apply_fallback_impl<approx_policy, F, N, Output, Args...>:
179182
apply_impl<fast_policy, F, N, Output, Args...> {};
180183

184+
// `approx_level_policy` falls back to `approx_policy`
181185
template<int Level, typename F, size_t N, typename Output, typename... Args>
182-
struct apply_base_impl<approx_level_policy<Level>, F, N, Output, Args...>:
186+
struct apply_fallback_impl<approx_level_policy<Level>, F, N, Output, Args...>:
183187
apply_impl<approx_policy, F, N, Output, Args...> {};
184188

189+
// Scalar invocation helper: wraps a single call of the functor `fun`.
//   F      - type of the callable.
//   Output - return type of `call`.
//   Args   - types of the arguments, taken by value.
template<typename F, typename Output, typename... Args>
190+
struct invoke_impl {
191+
// Calls `fun(args...)` and returns the result as `Output`.
KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) {
192+
return fun(args...);
193+
}
194+
};
195+
196+
// Only for `accurate_policy` do we implement `apply_impl`; the other policies fall back to `apply_base_impl`.
197+
// Specialization of `apply_impl` for `accurate_policy`: applies `fun`
// element by element over `N` entries.
//   fun    - callable forwarded to `invoke_impl<F, Output, Args...>::call`.
//   output - destination; entries `output[0]` .. `output[N-1]` are written.
//   args   - one input pointer per argument; entry `i` of each is read.
template<typename F, size_t N, typename Output, typename... Args>
198+
struct apply_impl<accurate_policy, F, N, Output, Args...> {
199+
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
200+
// Unrolling hint for the compiler; `N` is a compile-time constant.
#pragma unroll
201+
for (size_t i = 0; i < N; i++) {
202+
// One scalar invocation per element, routed through `invoke_impl`.
output[i] = invoke_impl<F, Output, Args...>::call(fun, args[i]...);
203+
}
204+
}
205+
};
206+
185207
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
186208
struct map_impl {
187209
static constexpr size_t packet_size = preferred_vector_size<Output>::value;

include/kernel_float/binops.h

+2-10
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ struct multiply<bool> {
291291

292292
namespace detail {
293293
template<typename Policy, typename T, size_t N>
294-
struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
294+
struct apply_base_impl<Policy, ops::divide<T>, N, T, T, T> {
295295
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
296296
T rhs_rcp[N];
297297

@@ -301,10 +301,6 @@ struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
301301
}
302302
};
303303

304-
template<typename T, size_t N>
305-
struct apply_impl<accurate_policy, ops::divide<T>, N, T, T, T>:
306-
apply_base_impl<accurate_policy, ops::divide<T>, N, T, T, T> {};
307-
308304
#if KERNEL_FLOAT_IS_DEVICE
309305
template<>
310306
struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
@@ -319,7 +315,7 @@ struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
319315
namespace detail {
320316
// Override `pow` using `log2` and `exp2`
321317
template<typename Policy, typename T, size_t N>
322-
struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
318+
struct apply_base_impl<Policy, ops::pow<T>, N, T, T, T> {
323319
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
324320
T lhs_log[N];
325321
T result_log[N];
@@ -330,10 +326,6 @@ struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
330326
apply_impl<Policy, ops::exp2<T>, N, T, T, T>::call({}, result, result_log);
331327
}
332328
};
333-
334-
template<typename T, size_t N>
335-
struct apply_impl<accurate_policy, ops::pow<T>, N, T, T, T>:
336-
apply_base_impl<accurate_policy, ops::pow<T>, N, T, T, T> {};
337329
} // namespace detail
338330

339331
template<typename L, typename R, typename T = promoted_vector_value_type<L, R>>

include/kernel_float/fp16.h

+1-4
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ namespace kernel_float {
2222
using half_t = ::__half;
2323
using half2_t = ::__half2;
2424

25-
using __half = void;
26-
using __half2 = void;
27-
2825
template<>
2926
struct preferred_vector_size<half_t> {
3027
static constexpr size_t value = 2;
@@ -50,7 +47,7 @@ template<>
5047
struct allow_float_fallback<half_t> {
5148
static constexpr bool value = true;
5249
};
53-
}; // namespace detail
50+
} // namespace detail
5451

5552
#if KERNEL_FLOAT_IS_DEVICE
5653
#define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \

include/kernel_float/triops.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,13 @@ struct fma {
9898
} // namespace ops
9999

100100
namespace detail {
101-
template<typename Policy, typename T, size_t N>
102-
struct apply_impl<Policy, ops::fma<T>, N, T, T, T, T> {
101+
template<typename T, size_t N>
102+
struct apply_impl<accurate_policy, ops::fma<T>, N, T, T, T, T> {
103103
KERNEL_FLOAT_INLINE
104104
static void call(ops::fma<T>, T* output, const T* a, const T* b, const T* c) {
105105
T temp[N];
106-
apply_impl<Policy, ops::multiply<T>, N, T, T, T>::call({}, temp, a, b);
107-
apply_impl<Policy, ops::add<T>, N, T, T, T>::call({}, output, temp, c);
106+
apply_impl<accurate_policy, ops::multiply<T>, N, T, T, T>::call({}, temp, a, b);
107+
apply_impl<accurate_policy, ops::add<T>, N, T, T, T>::call({}, output, temp, c);
108108
}
109109
};
110110
} // namespace detail

0 commit comments

Comments
 (0)