Skip to content

Commit f94bd10

Browse files
committed
Fix several issues related to HIP compilation for bfloat16
1 parent 846de1f commit f94bd10

File tree

4 files changed

+68
-14
lines changed

4 files changed

+68
-14
lines changed

include/kernel_float/bf16.h

+28-2
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,6 @@ struct allow_float_fallback<bfloat16_t> {
6060
};
6161
}; // namespace detail
6262

63-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
6463
#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \
6564
namespace ops { \
6665
template<> \
@@ -81,6 +80,7 @@ struct allow_float_fallback<bfloat16_t> {
8180
}; \
8281
}
8382

83+
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
8484
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
8585
KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)
8686

@@ -101,9 +101,34 @@ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
101101
KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
102102
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
103103
KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
104+
105+
// For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops.
106+
// For CUDA, we can just use the regular bfloat16 functions (see above).
107+
#elif KERNEL_FLOAT_IS_HIP
108+
KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) {
109+
__hip_bfloat16 res = a;
110+
res.data &= 0x7FFF;
111+
return res;
112+
}
113+
114+
KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) {
115+
__hip_bfloat16 res = a;
116+
res.data ^= 0x8000;
117+
return res;
118+
}
119+
120+
KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) {
121+
return {hip_habs(a.x), hip_habs(a.y)};
122+
}
123+
124+
KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2(const __hip_bfloat162 a) {
125+
return {hip_hneg(a.x), hip_hneg(a.y)};
126+
}
127+
128+
KERNEL_FLOAT_BF16_UNARY_FUN(abs, hip_habs, hip_habs2)
129+
KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2)
104130
#endif
105131

106-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
107132
#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \
108133
namespace ops { \
109134
template<> \
@@ -133,6 +158,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
133158
}; \
134159
}
135160

161+
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
136162
KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2)
137163
KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2)
138164
KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2)

include/kernel_float/binops.h

+4-2
Original file line number | Diff line number | Diff line change
@@ -189,14 +189,16 @@ namespace ops {
189189
template<typename T>
190190
struct min {
191191
KERNEL_FLOAT_INLINE T operator()(T left, T right) {
192-
return left < right ? left : right;
192+
auto cond = less<T> {}(left, right);
193+
return cast<decltype(cond), bool> {}(cond) ? left : right;
193194
}
194195
};
195196

196197
template<typename T>
197198
struct max {
198199
KERNEL_FLOAT_INLINE T operator()(T left, T right) {
199-
return left > right ? left : right;
200+
auto cond = greater<T> {}(left, right);
201+
return cast<decltype(cond), bool> {}(cond) ? left : right;
200202
}
201203
};
202204

include/kernel_float/macros.h

+1-2
Original file line number | Diff line number | Diff line change
@@ -20,12 +20,11 @@
2020
#elif defined(__HIPCC__)
2121
#define KERNEL_FLOAT_IS_HIP (1)
2222
#define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__
23+
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
2324

2425
#ifdef __HIP_DEVICE_COMPILE__
25-
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
2626
#define KERNEL_FLOAT_IS_DEVICE (1)
2727
#else
28-
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
2928
#define KERNEL_FLOAT_IS_HOST (1)
3029
#endif
3130

single_include/kernel_float.h

+35-8
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2024-12-02 10:59:19.296684
20-
// git hash: a2b08a56e31d1c9a6302c8a49c740cf56fcc1607
19+
// date: 2024-12-02 18:48:50.243676
20+
// git hash: 846de1f9aefaef76da15ebb5474080d531efaf38
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -42,12 +42,11 @@
4242
#elif defined(__HIPCC__)
4343
#define KERNEL_FLOAT_IS_HIP (1)
4444
#define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__
45+
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
4546

4647
#ifdef __HIP_DEVICE_COMPILE__
47-
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
4848
#define KERNEL_FLOAT_IS_DEVICE (1)
4949
#else
50-
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
5150
#define KERNEL_FLOAT_IS_HOST (1)
5251
#endif
5352

@@ -1875,14 +1874,16 @@ namespace ops {
18751874
template<typename T>
18761875
struct min {
18771876
KERNEL_FLOAT_INLINE T operator()(T left, T right) {
1878-
return left < right ? left : right;
1877+
auto cond = less<T> {}(left, right);
1878+
return cast<decltype(cond), bool> {}(cond) ? left : right;
18791879
}
18801880
};
18811881

18821882
template<typename T>
18831883
struct max {
18841884
KERNEL_FLOAT_INLINE T operator()(T left, T right) {
1885-
return left > right ? left : right;
1885+
auto cond = greater<T> {}(left, right);
1886+
return cast<decltype(cond), bool> {}(cond) ? left : right;
18861887
}
18871888
};
18881889

@@ -4307,7 +4308,6 @@ struct allow_float_fallback<bfloat16_t> {
43074308
};
43084309
}; // namespace detail
43094310

4310-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
43114311
#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \
43124312
namespace ops { \
43134313
template<> \
@@ -4328,6 +4328,7 @@ struct allow_float_fallback<bfloat16_t> {
43284328
}; \
43294329
}
43304330

4331+
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
43314332
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
43324333
KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)
43334334

@@ -4348,9 +4349,34 @@ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
43484349
KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
43494350
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
43504351
KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
4352+
4353+
// For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops.
4354+
// For CUDA, we can just use the regular bfloat16 functions (see above).
4355+
#elif KERNEL_FLOAT_IS_HIP
4356+
KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) {
4357+
__hip_bfloat16 res = a;
4358+
res.data &= 0x7FFF;
4359+
return res;
4360+
}
4361+
4362+
KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) {
4363+
__hip_bfloat16 res = a;
4364+
res.data ^= 0x8000;
4365+
return res;
4366+
}
4367+
4368+
KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) {
4369+
return {hip_habs(a.x), hip_habs(a.y)};
4370+
}
4371+
4372+
KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2(const __hip_bfloat162 a) {
4373+
return {hip_hneg(a.x), hip_hneg(a.y)};
4374+
}
4375+
4376+
KERNEL_FLOAT_BF16_UNARY_FUN(abs, hip_habs, hip_habs2)
4377+
KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2)
43514378
#endif
43524379

4353-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
43544380
#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \
43554381
namespace ops { \
43564382
template<> \
@@ -4380,6 +4406,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
43804406
}; \
43814407
}
43824408

4409+
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
43834410
KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2)
43844411
KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2)
43854412
KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2)

0 commit comments

Comments
 (0)