Skip to content

Commit 6b4a313

Browse files
committed
Add autocast for bf16 and bf16xN
1 parent 3bec545 commit 6b4a313

File tree

4 files changed

+35
-5
lines changed

4 files changed

+35
-5
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -820,12 +820,20 @@ fn equate_ty<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll Typ
820820
},
821821
)
822822
}
823-
TypeKind::Vector if cx.element_type(llvm_ty) == cx.type_i1() => {
823+
TypeKind::Vector => {
824+
let llvm_element_ty = cx.element_type(llvm_ty);
824825
let element_count = cx.vector_length(llvm_ty) as u64;
825-
let int_width = element_count.next_power_of_two().max(8);
826826

827-
rust_ty == cx.type_ix(int_width)
827+
if llvm_element_ty == cx.type_bf16() {
828+
rust_ty == cx.type_vector(cx.type_i16(), element_count)
829+
} else if llvm_element_ty == cx.type_i1() {
830+
let int_width = element_count.next_power_of_two().max(8);
831+
rust_ty == cx.type_ix(int_width)
832+
} else {
833+
false
834+
}
828835
}
836+
TypeKind::BFloat => rust_ty == cx.type_i16(),
829837
_ => false,
830838
}
831839
}
@@ -890,7 +898,7 @@ fn autocast<'ll>(
890898
bx.bitcast(val, dest_ty)
891899
}
892900
}
893-
_ => unreachable!(),
901+
_ => bx.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
894902
}
895903
}
896904

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,9 @@ unsafe extern "C" {
920920
pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type;
921921
pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type;
922922

923+
// Operations on non-IEEE real types
924+
pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type;
925+
923926
// Operations on function types
924927
pub(crate) fn LLVMFunctionType<'a>(
925928
ReturnType: &'a Type,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
183183
)
184184
}
185185
}
186+
187+
/// Returns the LLVM type obtained from `LLVMBFloatTypeInContext` for the
/// context given by `self.llcx()`.
// NOTE(review): presumably this is the type LLVM IR prints as `bfloat` — confirm.
pub(crate) fn type_bf16(&self) -> &'ll Type {
188+
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
189+
}
186190
}
187191

188192
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {

tests/codegen-llvm/inject-autocast.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#![feature(link_llvm_intrinsics, abi_unadjusted, repr_simd, simd_ffi, portable_simd, f16)]
55
#![crate_type = "lib"]
66

7-
use std::simd::i64x2;
7+
use std::simd::{f32x4, i16x8, i64x2};
88

99
#[repr(C, packed)]
1010
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
@@ -72,8 +72,23 @@ pub unsafe fn i1_vector_autocast(a: f16x8) -> u8 {
7272
foo(a, 1)
7373
}
7474

75+
// CHECK-LABEL: @bf16_vector_autocast
76+
#[no_mangle]
77+
// Declares `llvm.x86.vcvtneps2bf16128` on the Rust side as returning `i16x8` and calls it.
// The FileCheck directives in the body expect the emitted call to return `<8 x bfloat>`
// and that result to be bitcast to `<8 x i16>`.
pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
78+
extern "unadjusted" {
79+
#[link_name = "llvm.x86.vcvtneps2bf16128"]
80+
fn foo(a: f32x4) -> i16x8;
81+
}
82+
83+
// CHECK: [[A:%[0-9]+]] = call <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float> {{.*}})
84+
// CHECK: bitcast <8 x bfloat> [[A]] to <8 x i16>
85+
foo(a)
86+
}
87+
7588
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
7689

7790
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
7891

7992
// CHECK: declare <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half>, i32 immarg)
93+
94+
// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)

0 commit comments

Comments
 (0)