16
16
17
17
// ================================================================================
18
18
// this file has been auto-generated, do not modify its contents!
19
- // date: 2024-12-02 10:59:19.296684
20
- // git hash: a2b08a56e31d1c9a6302c8a49c740cf56fcc1607
19
+ // date: 2024-12-02 18:48:50.243676
20
+ // git hash: 846de1f9aefaef76da15ebb5474080d531efaf38
21
21
// ================================================================================
22
22
23
23
#ifndef KERNEL_FLOAT_MACROS_H
42
42
#elif defined(__HIPCC__)
43
43
#define KERNEL_FLOAT_IS_HIP (1 )
44
44
#define KERNEL_FLOAT_DEVICE __attribute__ ((always_inline)) __device__
45
+ #define KERNEL_FLOAT_INLINE __attribute__ ((always_inline)) __host__ __device__
45
46
46
47
#ifdef __HIP_DEVICE_COMPILE__
47
- #define KERNEL_FLOAT_INLINE __attribute__ ((always_inline)) __host__ __device__
48
48
#define KERNEL_FLOAT_IS_DEVICE (1 )
49
49
#else
50
- #define KERNEL_FLOAT_INLINE __attribute__ ((always_inline)) __host__ __device__
51
50
#define KERNEL_FLOAT_IS_HOST (1 )
52
51
#endif
53
52
@@ -1875,14 +1874,16 @@ namespace ops {
1875
1874
template <typename T>
1876
1875
struct min {
1877
1876
KERNEL_FLOAT_INLINE T operator ()(T left, T right) {
1878
- return left < right ? left : right;
1877
+ auto cond = less<T> {}(left, right);
1878
+ return cast<decltype (cond), bool > {}(cond) ? left : right;
1879
1879
}
1880
1880
};
1881
1881
1882
1882
template <typename T>
1883
1883
struct max {
1884
1884
KERNEL_FLOAT_INLINE T operator ()(T left, T right) {
1885
- return left > right ? left : right;
1885
+ auto cond = greater<T> {}(left, right);
1886
+ return cast<decltype (cond), bool > {}(cond) ? left : right;
1886
1887
}
1887
1888
};
1888
1889
@@ -4307,7 +4308,6 @@ struct allow_float_fallback<bfloat16_t> {
4307
4308
};
4308
4309
}; // namespace detail
4309
4310
4310
- #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4311
4311
#define KERNEL_FLOAT_BF16_UNARY_FUN (NAME, FUN1, FUN2 ) \
4312
4312
namespace ops { \
4313
4313
template <> \
@@ -4328,6 +4328,7 @@ struct allow_float_fallback<bfloat16_t> {
4328
4328
}; \
4329
4329
}
4330
4330
4331
+ #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4331
4332
KERNEL_FLOAT_BF16_UNARY_FUN (sin, ::hsin, ::h2sin)
4332
4333
KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)
4333
4334
@@ -4348,9 +4349,34 @@ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
4348
4349
KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
4349
4350
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
4350
4351
KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
4352
+
4353
+ // For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops.
4354
+ // For CUDA, we can just use the regular bfloat16 functions (see above).
4355
+ #elif KERNEL_FLOAT_IS_HIP
4356
+ KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs (const __hip_bfloat16 a) {
4357
+ __hip_bfloat16 res = a;
4358
+ res.data &= 0x7FFF ;
4359
+ return res;
4360
+ }
4361
+
4362
+ KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg (const __hip_bfloat16 a) {
4363
+ __hip_bfloat16 res = a;
4364
+ res.data ^= 0x8000 ;
4365
+ return res;
4366
+ }
4367
+
4368
+ KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2 (const __hip_bfloat162 a) {
4369
+ return {hip_habs (a.x ), hip_habs (a.y )};
4370
+ }
4371
+
4372
+ KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2 (const __hip_bfloat162 a) {
4373
+ return {hip_hneg (a.x ), hip_hneg (a.y )};
4374
+ }
4375
+
4376
+ KERNEL_FLOAT_BF16_UNARY_FUN (abs, hip_habs, hip_habs2)
4377
+ KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2)
4351
4378
#endif
4352
4379
4353
- #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4354
4380
#define KERNEL_FLOAT_BF16_BINARY_FUN (NAME, FUN1, FUN2 ) \
4355
4381
namespace ops { \
4356
4382
template <> \
@@ -4380,6 +4406,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
4380
4406
}; \
4381
4407
}
4382
4408
4409
+ #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4383
4410
KERNEL_FLOAT_BF16_BINARY_FUN (add, __hadd, __hadd2)
4384
4411
KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2)
4385
4412
KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2)
0 commit comments