Skip to content

Commit faa6095

Browse files
authored
Demote(B)Float16 pass: only keep enabled for PPC. (#55486)
LLVM should handle this properly now for everything but PPC (where BFloat16 isn't supported anyway).
1 parent 8a19b74 commit faa6095

File tree

3 files changed

+87
-57
lines changed

3 files changed

+87
-57
lines changed

src/llvm-demote-float16.cpp

+19-28
Original file line numberDiff line numberDiff line change
@@ -49,37 +49,28 @@ extern JuliaOJIT *jl_ExecutionEngine;
4949

5050
namespace {
5151

52-
static bool have_fp16(Function &caller, const Triple &TT) {
53-
Attribute FSAttr = caller.getFnAttribute("target-features");
54-
StringRef FS = "";
55-
if (FSAttr.isValid())
56-
FS = FSAttr.getValueAsString();
57-
else if (jl_ExecutionEngine)
58-
FS = jl_ExecutionEngine->getTargetFeatureString();
59-
// else probably called from opt, just do nothing
60-
if (TT.isAArch64()) {
61-
if (FS.find("+fp16fml") != llvm::StringRef::npos || FS.find("+fullfp16") != llvm::StringRef::npos){
62-
return true;
63-
}
64-
} else if (TT.getArch() == Triple::x86_64) {
65-
if (FS.find("+avx512fp16") != llvm::StringRef::npos){
66-
return true;
67-
}
68-
}
69-
if (caller.hasFnAttribute("julia.hasfp16")) {
70-
return true;
71-
}
72-
return false;
52+
static bool have_fp16(Function &F, const Triple &TT) {
53+
// for testing purposes
54+
Attribute Attr = F.getFnAttribute("julia.hasfp16");
55+
if (Attr.isValid())
56+
return Attr.getValueAsBool();
57+
58+
// llvm/llvm-project#97975: on some platforms, `half` uses excessive precision
59+
if (TT.isPPC())
60+
return false;
61+
62+
return true;
7363
}
7464

75-
static bool have_bf16(Function &caller, const Triple &TT) {
76-
if (caller.hasFnAttribute("julia.hasbf16")) {
77-
return true;
78-
}
65+
static bool have_bf16(Function &F, const Triple &TT) {
66+
// for testing purposes
67+
Attribute Attr = F.getFnAttribute("julia.hasbf16");
68+
if (Attr.isValid())
69+
return Attr.getValueAsBool();
7970

80-
// there's no targets that fully support bfloat yet;,
81-
// AVX512BF16 only provides conversion and dot product instructions.
82-
return false;
71+
// https://github.com/llvm/llvm-project/issues/97975#issuecomment-2218770199:
72+
// on current versions of LLVM, bf16 always uses TypeSoftPromoteHalf
73+
return true;
8374
}
8475

8576
static bool demoteFloat16(Function &F)

test/llvmpasses/fastmath.jl

-26
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,3 @@ import Base.FastMath
1616

1717
# CHECK: call fast float @llvm.sqrt.f32(float %"x::Float32")
1818
emit(FastMath.sqrt_fast, Float32)
19-
20-
21-
# Float16 operations should be performed as Float32, unless @fastmath is specified
22-
# TODO: this is not true for platforms that natively support Float16
23-
24-
foo(x::T,y::T) where T = x-y == zero(T)
25-
# CHECK: define {{(swiftcc )?}}i8 @julia_foo_{{[0-9]+}}({{.*}}half %[[X:"x::Float16"]], half %[[Y:"y::Float16"]]) {{.*}}{
26-
# CHECK-DAG: %[[XEXT:[0-9]+]] = fpext half %[[X]] to float
27-
# CHECK-DAG: %[[YEXT:[0-9]+]] = fpext half %[[Y]] to float
28-
# CHECK: %[[DIFF:[0-9]+]] = fsub float %[[XEXT]], %[[YEXT]]
29-
# CHECK: %[[TRUNC:[0-9]+]] = fptrunc float %[[DIFF]] to half
30-
# CHECK: %[[DIFFEXT:[0-9]+]] = fpext half %[[TRUNC]] to float
31-
# CHECK: %[[CMP:[0-9]+]] = fcmp oeq float %[[DIFFEXT]], 0.000000e+00
32-
# CHECK: %[[ZEXT:[0-9]+]] = zext i1 %[[CMP]] to i8
33-
# CHECK: ret i8 %[[ZEXT]]
34-
# CHECK: }
35-
emit(foo, Float16, Float16)
36-
37-
@fastmath foo(x::T,y::T) where T = x-y == zero(T)
38-
# CHECK: define {{(swiftcc )?}}i8 @julia_foo_{{[0-9]+}}({{.*}}half %[[X:"x::Float16"]], half %[[Y:"y::Float16"]]) {{.*}}{
39-
# CHECK: %[[DIFF:[0-9]+]] = fsub fast half %[[X]], %[[Y]]
40-
# CHECK: %[[CMP:[0-9]+]] = fcmp fast oeq half %[[DIFF]], 0xH0000
41-
# CHECK: %[[ZEXT:[0-9]+]] = zext i1 %[[CMP]] to i8
42-
# CHECK: ret i8 %[[ZEXT]]
43-
# CHECK: }
44-
emit(foo, Float16, Float16)

test/llvmpasses/float16.ll

+68-3
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ top:
9999
ret half %13
100100
}
101101

102-
define bfloat @demote_bfloat_test(bfloat %a, bfloat %b) {
102+
define bfloat @demote_bfloat_test(bfloat %a, bfloat %b) #2 {
103103
top:
104104
; CHECK-LABEL: @demote_bfloat_test(
105105
; CHECK-NEXT: top:
@@ -160,5 +160,70 @@ top:
160160
ret bfloat %13
161161
}
162162

163-
attributes #0 = { "target-features"="-avx512fp16" }
164-
attributes #1 = { "target-features"="+avx512fp16" }
163+
define bfloat @native_bfloat_test(bfloat %a, bfloat %b) #3 {
164+
top:
165+
; CHECK-LABEL: @native_bfloat_test(
166+
; CHECK-NEXT: top:
167+
; CHECK-NEXT: %0 = fadd bfloat %a, %b
168+
; CHECK-NEXT: %1 = fadd bfloat %0, %b
169+
; CHECK-NEXT: %2 = fadd bfloat %1, %b
170+
; CHECK-NEXT: %3 = fmul bfloat %2, %b
171+
; CHECK-NEXT: %4 = fdiv bfloat %3, %b
172+
; CHECK-NEXT: %5 = insertelement <2 x bfloat> undef, bfloat %a, i32 0
173+
; CHECK-NEXT: %6 = insertelement <2 x bfloat> %5, bfloat %b, i32 1
174+
; CHECK-NEXT: %7 = insertelement <2 x bfloat> undef, bfloat %b, i32 0
175+
; CHECK-NEXT: %8 = insertelement <2 x bfloat> %7, bfloat %b, i32 1
176+
; CHECK-NEXT: %9 = fadd <2 x bfloat> %6, %8
177+
; CHECK-NEXT: %10 = extractelement <2 x bfloat> %9, i32 0
178+
; CHECK-NEXT: %11 = extractelement <2 x bfloat> %9, i32 1
179+
; CHECK-NEXT: %12 = fadd bfloat %10, %11
180+
; CHECK-NEXT: %13 = fadd bfloat %12, %4
181+
; CHECK-NEXT: ret bfloat %13
182+
;
183+
%0 = fadd bfloat %a, %b
184+
%1 = fadd bfloat %0, %b
185+
%2 = fadd bfloat %1, %b
186+
%3 = fmul bfloat %2, %b
187+
%4 = fdiv bfloat %3, %b
188+
%5 = insertelement <2 x bfloat> undef, bfloat %a, i32 0
189+
%6 = insertelement <2 x bfloat> %5, bfloat %b, i32 1
190+
%7 = insertelement <2 x bfloat> undef, bfloat %b, i32 0
191+
%8 = insertelement <2 x bfloat> %7, bfloat %b, i32 1
192+
%9 = fadd <2 x bfloat> %6, %8
193+
%10 = extractelement <2 x bfloat> %9, i32 0
194+
%11 = extractelement <2 x bfloat> %9, i32 1
195+
%12 = fadd bfloat %10, %11
196+
%13 = fadd bfloat %12, %4
197+
ret bfloat %13
198+
}
199+
200+
define i1 @fast_half_test(half %0, half %1) #0 {
201+
top:
202+
; CHECK-LABEL: @fast_half_test(
203+
; CHECK-NEXT: top:
204+
; CHECK-NEXT: %2 = fsub fast half %0, %1
205+
; CHECK-NEXT: %3 = fcmp fast oeq half %2, 0xH0000
206+
; CHECK-NEXT: ret i1 %3
207+
;
208+
%2 = fsub fast half %0, %1
209+
%3 = fcmp fast oeq half %2, 0xH0000
210+
ret i1 %3
211+
}
212+
213+
define i1 @fast_bfloat_test(bfloat %0, bfloat %1) #2 {
214+
top:
215+
; CHECK-LABEL: @fast_bfloat_test(
216+
; CHECK-NEXT: top:
217+
; CHECK-NEXT: %2 = fsub fast bfloat %0, %1
218+
; CHECK-NEXT: %3 = fcmp fast oeq bfloat %2, 0xR0000
219+
; CHECK-NEXT: ret i1 %3
220+
;
221+
%2 = fsub fast bfloat %0, %1
222+
%3 = fcmp fast oeq bfloat %2, 0xR0000
223+
ret i1 %3
224+
}
225+
226+
attributes #0 = { "julia.hasfp16"="false" }
227+
attributes #1 = { "julia.hasfp16"="true" }
228+
attributes #2 = { "julia.hasbf16"="false" }
229+
attributes #3 = { "julia.hasbf16"="true" }

0 commit comments

Comments
 (0)