1616
1717// ================================================================================
1818// this file has been auto-generated, do not modify its contents!
19- // date: 2024-12-02 10:59:19.296684
20- // git hash: a2b08a56e31d1c9a6302c8a49c740cf56fcc1607
19+ // date: 2024-12-02 18:48:50.243676
20+ // git hash: 846de1f9aefaef76da15ebb5474080d531efaf38
2121// ================================================================================
2222
2323#ifndef KERNEL_FLOAT_MACROS_H
4242#elif defined(__HIPCC__)
4343 #define KERNEL_FLOAT_IS_HIP (1 )
4444 #define KERNEL_FLOAT_DEVICE __attribute__ ((always_inline)) __device__
45+ #define KERNEL_FLOAT_INLINE __attribute__ ((always_inline)) __host__ __device__
4546
4647 #ifdef __HIP_DEVICE_COMPILE__
47- #define KERNEL_FLOAT_INLINE __attribute__ ((always_inline)) __host__ __device__
4848 #define KERNEL_FLOAT_IS_DEVICE (1 )
4949 #else
50- #define KERNEL_FLOAT_INLINE __attribute__ ((always_inline)) __host__ __device__
5150 #define KERNEL_FLOAT_IS_HOST (1 )
5251 #endif
5352
@@ -1875,14 +1874,16 @@ namespace ops {
18751874template <typename T>
18761875struct min {
18771876 KERNEL_FLOAT_INLINE T operator ()(T left, T right) {
1878- return left < right ? left : right;
1877+ auto cond = less<T> {}(left, right);
1878+ return cast<decltype (cond), bool > {}(cond) ? left : right;
18791879 }
18801880};
18811881
18821882template <typename T>
18831883struct max {
18841884 KERNEL_FLOAT_INLINE T operator ()(T left, T right) {
1885- return left > right ? left : right;
1885+ auto cond = greater<T> {}(left, right);
1886+ return cast<decltype (cond), bool > {}(cond) ? left : right;
18861887 }
18871888};
18881889
@@ -4307,7 +4308,6 @@ struct allow_float_fallback<bfloat16_t> {
43074308};
43084309}; // namespace detail
43094310
4310- #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
43114311#define KERNEL_FLOAT_BF16_UNARY_FUN (NAME, FUN1, FUN2 ) \
43124312 namespace ops { \
43134313 template <> \
@@ -4328,6 +4328,7 @@ struct allow_float_fallback<bfloat16_t> {
43284328 }; \
43294329 }
43304330
4331+ #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
43314332KERNEL_FLOAT_BF16_UNARY_FUN (sin, ::hsin, ::h2sin)
43324333KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)
43334334
@@ -4348,9 +4349,34 @@ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
43484349KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
43494350KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
43504351KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
4352+
4353+ // For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops.
4354+ // For CUDA, we can just use the regular bfloat16 functions (see above).
4355+ #elif KERNEL_FLOAT_IS_HIP
4356+ KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs (const __hip_bfloat16 a) {
4357+ __hip_bfloat16 res = a;
4358+ res.data &= 0x7FFF ;
4359+ return res;
4360+ }
4361+
4362+ KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg (const __hip_bfloat16 a) {
4363+ __hip_bfloat16 res = a;
4364+ res.data ^= 0x8000 ;
4365+ return res;
4366+ }
4367+
4368+ KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2 (const __hip_bfloat162 a) {
4369+ return {hip_habs (a.x ), hip_habs (a.y )};
4370+ }
4371+
4372+ KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2 (const __hip_bfloat162 a) {
4373+ return {hip_hneg (a.x ), hip_hneg (a.y )};
4374+ }
4375+
4376+ KERNEL_FLOAT_BF16_UNARY_FUN (abs, hip_habs, hip_habs2)
4377+ KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2)
43514378#endif
43524379
4353- #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
43544380#define KERNEL_FLOAT_BF16_BINARY_FUN (NAME, FUN1, FUN2 ) \
43554381 namespace ops { \
43564382 template <> \
@@ -4380,6 +4406,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
43804406 }; \
43814407 }
43824408
4409+ #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
43834410KERNEL_FLOAT_BF16_BINARY_FUN (add, __hadd, __hadd2)
43844411KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2)
43854412KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2)
0 commit comments