Skip to content

Commit a86b38e

Browse files
committed
Add f16 inline ASM support for RISC-V
1 parent 92af831 commit a86b38e

File tree

4 files changed

+146
-29
lines changed

4 files changed

+146
-29
lines changed

compiler/rustc_codegen_llvm/src/asm.rs

+87-24
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::slice;
2+
13
use crate::attributes;
24
use crate::builder::Builder;
35
use crate::common::Funclet;
@@ -64,7 +66,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
6466
let mut layout = None;
6567
let ty = if let Some(ref place) = place {
6668
layout = Some(&place.layout);
67-
llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout)
69+
llvm_fixup_output_type(self, reg.reg_class(), &place.layout)
6870
} else if matches!(
6971
reg.reg_class(),
7072
InlineAsmRegClass::X86(
@@ -112,7 +114,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
112114
// so we just use the type of the input.
113115
&in_value.layout
114116
};
115-
let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout);
117+
let ty = llvm_fixup_output_type(self, reg.reg_class(), layout);
116118
output_types.push(ty);
117119
op_idx.insert(idx, constraints.len());
118120
let prefix = if late { "=" } else { "=&" };
@@ -913,6 +915,46 @@ fn llvm_asm_scalar_type<'ll>(cx: &CodegenCx<'ll, '_>, scalar: Scalar) -> &'ll Ty
913915
}
914916
}
915917

918+
fn function_target_features<'ll>(builder: &Builder<'_, 'll, '_>) -> impl Iterator<Item = &'ll str> {
919+
let llfn = builder.llfn();
920+
let key = "target-features";
921+
let attr = unsafe {
922+
llvm::LLVMGetStringAttributeAtIndex(
923+
llfn,
924+
llvm::AttributePlace::Function.as_uint(),
925+
key.as_ptr().cast(),
926+
key.len().try_into().unwrap(),
927+
)
928+
};
929+
let Some(attr) = attr else {
930+
return "".split(',');
931+
};
932+
let value = unsafe {
933+
let mut length = 0;
934+
let ptr = llvm::LLVMGetStringAttributeValue(attr, &mut length);
935+
slice::from_raw_parts(ptr.cast(), length.try_into().unwrap())
936+
};
937+
let Ok(value) = std::str::from_utf8(value) else {
938+
return "".split(',');
939+
};
940+
value.split(',')
941+
}
942+
943+
fn is_zfhmin_enabled(builder: &Builder<'_, '_, '_>) -> bool {
944+
let mut zfhmin_enabled = false;
945+
let mut zfh_enabled = false;
946+
for feature in function_target_features(builder) {
947+
match feature {
948+
"+zfhmin" => zfhmin_enabled = true,
949+
"-zfhmin" => zfhmin_enabled = false,
950+
"+zfh" => zfh_enabled = true,
951+
"-zfh" => zfh_enabled = false,
952+
_ => {}
953+
}
954+
}
955+
zfhmin_enabled || zfh_enabled
956+
}
957+
916958
/// Fix up an input value to work around LLVM bugs.
917959
fn llvm_fixup_input<'ll, 'tcx>(
918960
bx: &mut Builder<'_, 'll, 'tcx>,
@@ -1029,6 +1071,15 @@ fn llvm_fixup_input<'ll, 'tcx>(
10291071
_ => value,
10301072
}
10311073
}
1074+
(InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
1075+
if s.primitive() == Primitive::Float(Float::F16) && !is_zfhmin_enabled(bx) =>
1076+
{
1077+
// Smaller floats are always "NaN-boxed" inside larger floats on RISC-V.
1078+
let value = bx.bitcast(value, bx.type_i16());
1079+
let value = bx.zext(value, bx.type_i32());
1080+
let value = bx.or(value, bx.const_u32(0xFFFF_0000));
1081+
bx.bitcast(value, bx.type_f32())
1082+
}
10321083
_ => value,
10331084
}
10341085
}
@@ -1140,56 +1191,63 @@ fn llvm_fixup_output<'ll, 'tcx>(
11401191
_ => value,
11411192
}
11421193
}
1194+
(InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
1195+
if s.primitive() == Primitive::Float(Float::F16) && !is_zfhmin_enabled(bx) =>
1196+
{
1197+
let value = bx.bitcast(value, bx.type_i32());
1198+
let value = bx.trunc(value, bx.type_i16());
1199+
bx.bitcast(value, bx.type_f16())
1200+
}
11431201
_ => value,
11441202
}
11451203
}
11461204

11471205
/// Output type to use for llvm_fixup_output.
11481206
fn llvm_fixup_output_type<'ll, 'tcx>(
1149-
cx: &CodegenCx<'ll, 'tcx>,
1207+
bx: &Builder<'_, 'll, 'tcx>,
11501208
reg: InlineAsmRegClass,
11511209
layout: &TyAndLayout<'tcx>,
11521210
) -> &'ll Type {
11531211
match (reg, layout.abi) {
11541212
(InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => {
11551213
if let Primitive::Int(Integer::I8, _) = s.primitive() {
1156-
cx.type_vector(cx.type_i8(), 8)
1214+
bx.type_vector(bx.type_i8(), 8)
11571215
} else {
1158-
layout.llvm_type(cx)
1216+
layout.llvm_type(bx)
11591217
}
11601218
}
11611219
(InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg_low16), Abi::Scalar(s)) => {
1162-
let elem_ty = llvm_asm_scalar_type(cx, s);
1220+
let elem_ty = llvm_asm_scalar_type(bx, s);
11631221
let count = 16 / layout.size.bytes();
1164-
cx.type_vector(elem_ty, count)
1222+
bx.type_vector(elem_ty, count)
11651223
}
11661224
(
11671225
InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg_low16),
11681226
Abi::Vector { element, count },
11691227
) if layout.size.bytes() == 8 => {
1170-
let elem_ty = llvm_asm_scalar_type(cx, element);
1171-
cx.type_vector(elem_ty, count * 2)
1228+
let elem_ty = llvm_asm_scalar_type(bx, element);
1229+
bx.type_vector(elem_ty, count * 2)
11721230
}
11731231
(InlineAsmRegClass::X86(X86InlineAsmRegClass::reg_abcd), Abi::Scalar(s))
11741232
if s.primitive() == Primitive::Float(Float::F64) =>
11751233
{
1176-
cx.type_i64()
1234+
bx.type_i64()
11771235
}
11781236
(
11791237
InlineAsmRegClass::X86(X86InlineAsmRegClass::xmm_reg | X86InlineAsmRegClass::zmm_reg),
11801238
Abi::Vector { .. },
1181-
) if layout.size.bytes() == 64 => cx.type_vector(cx.type_f64(), 8),
1239+
) if layout.size.bytes() == 64 => bx.type_vector(bx.type_f64(), 8),
11821240
(
11831241
InlineAsmRegClass::X86(
11841242
X86InlineAsmRegClass::xmm_reg
11851243
| X86InlineAsmRegClass::ymm_reg
11861244
| X86InlineAsmRegClass::zmm_reg,
11871245
),
11881246
Abi::Scalar(s),
1189-
) if cx.sess().asm_arch == Some(InlineAsmArch::X86)
1247+
) if bx.sess().asm_arch == Some(InlineAsmArch::X86)
11901248
&& s.primitive() == Primitive::Float(Float::F128) =>
11911249
{
1192-
cx.type_vector(cx.type_i32(), 4)
1250+
bx.type_vector(bx.type_i32(), 4)
11931251
}
11941252
(
11951253
InlineAsmRegClass::X86(
@@ -1198,7 +1256,7 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
11981256
| X86InlineAsmRegClass::zmm_reg,
11991257
),
12001258
Abi::Scalar(s),
1201-
) if s.primitive() == Primitive::Float(Float::F16) => cx.type_vector(cx.type_i16(), 8),
1259+
) if s.primitive() == Primitive::Float(Float::F16) => bx.type_vector(bx.type_i16(), 8),
12021260
(
12031261
InlineAsmRegClass::X86(
12041262
X86InlineAsmRegClass::xmm_reg
@@ -1207,16 +1265,16 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
12071265
),
12081266
Abi::Vector { element, count: count @ (8 | 16) },
12091267
) if element.primitive() == Primitive::Float(Float::F16) => {
1210-
cx.type_vector(cx.type_i16(), count)
1268+
bx.type_vector(bx.type_i16(), count)
12111269
}
12121270
(
12131271
InlineAsmRegClass::Arm(ArmInlineAsmRegClass::sreg | ArmInlineAsmRegClass::sreg_low16),
12141272
Abi::Scalar(s),
12151273
) => {
12161274
if let Primitive::Int(Integer::I32, _) = s.primitive() {
1217-
cx.type_f32()
1275+
bx.type_f32()
12181276
} else {
1219-
layout.llvm_type(cx)
1277+
layout.llvm_type(bx)
12201278
}
12211279
}
12221280
(
@@ -1228,20 +1286,25 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
12281286
Abi::Scalar(s),
12291287
) => {
12301288
if let Primitive::Int(Integer::I64, _) = s.primitive() {
1231-
cx.type_f64()
1289+
bx.type_f64()
12321290
} else {
1233-
layout.llvm_type(cx)
1291+
layout.llvm_type(bx)
12341292
}
12351293
}
12361294
(InlineAsmRegClass::Mips(MipsInlineAsmRegClass::reg), Abi::Scalar(s)) => {
12371295
match s.primitive() {
12381296
// MIPS only supports register-length arithmetics.
1239-
Primitive::Int(Integer::I8 | Integer::I16, _) => cx.type_i32(),
1240-
Primitive::Float(Float::F32) => cx.type_i32(),
1241-
Primitive::Float(Float::F64) => cx.type_i64(),
1242-
_ => layout.llvm_type(cx),
1297+
Primitive::Int(Integer::I8 | Integer::I16, _) => bx.type_i32(),
1298+
Primitive::Float(Float::F32) => bx.type_i32(),
1299+
Primitive::Float(Float::F64) => bx.type_i64(),
1300+
_ => layout.llvm_type(bx),
12431301
}
12441302
}
1245-
_ => layout.llvm_type(cx),
1303+
(InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
1304+
if s.primitive() == Primitive::Float(Float::F16) && !is_zfhmin_enabled(bx) =>
1305+
{
1306+
bx.type_f32()
1307+
}
1308+
_ => layout.llvm_type(bx),
12461309
}
12471310
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+7
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,13 @@ extern "C" {
10001000
Value: *const c_char,
10011001
ValueLen: c_uint,
10021002
) -> &Attribute;
1003+
pub fn LLVMGetStringAttributeAtIndex(
1004+
F: &Value,
1005+
Idx: c_uint,
1006+
K: *const c_char,
1007+
KLen: c_uint,
1008+
) -> Option<&Attribute>;
1009+
pub fn LLVMGetStringAttributeValue(A: &Attribute, Length: &mut c_uint) -> *const c_char;
10031010

10041011
// Operations on functions
10051012
pub fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);

compiler/rustc_target/src/asm/riscv.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@ impl RiscVInlineAsmRegClass {
4040
match self {
4141
Self::reg => {
4242
if arch == InlineAsmArch::RiscV64 {
43-
types! { _: I8, I16, I32, I64, F32, F64; }
43+
types! { _: I8, I16, I32, I64, F16, F32, F64; }
4444
} else {
45-
types! { _: I8, I16, I32, F32; }
45+
types! { _: I8, I16, I32, F16, F32; }
4646
}
4747
}
48-
Self::freg => types! { f: F32; d: F64; },
48+
// FIXME(f16_f128): Add `q: F128;` once LLVM support the `Q` extension.
49+
Self::freg => types! { f: F16, F32; d: F64; },
4950
Self::vreg => &[],
5051
}
5152
}

tests/assembly/asm/riscv-types.rs

+48-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,33 @@
1-
//@ revisions: riscv64 riscv32
1+
//@ revisions: riscv64 riscv32 riscv64-zfhmin riscv32-zfhmin riscv64-zfh riscv32-zfh
22
//@ assembly-output: emit-asm
3+
34
//@[riscv64] compile-flags: --target riscv64imac-unknown-none-elf
45
//@[riscv64] needs-llvm-components: riscv
6+
57
//@[riscv32] compile-flags: --target riscv32imac-unknown-none-elf
68
//@[riscv32] needs-llvm-components: riscv
9+
10+
//@[riscv64-zfhmin] compile-flags: --target riscv64imac-unknown-none-elf --cfg riscv64
11+
//@[riscv64-zfhmin] needs-llvm-components: riscv
12+
//@[riscv64-zfhmin] compile-flags: -C target-feature=+zfhmin
13+
//@[riscv64-zfhmin] filecheck-flags: --check-prefix riscv64
14+
15+
//@[riscv32-zfhmin] compile-flags: --target riscv32imac-unknown-none-elf
16+
//@[riscv32-zfhmin] needs-llvm-components: riscv
17+
//@[riscv32-zfhmin] compile-flags: -C target-feature=+zfhmin
18+
19+
//@[riscv64-zfh] compile-flags: --target riscv64imac-unknown-none-elf --cfg riscv64
20+
//@[riscv64-zfh] needs-llvm-components: riscv
21+
//@[riscv64-zfh] compile-flags: -C target-feature=+zfh
22+
//@[riscv64-zfh] filecheck-flags: --check-prefix riscv64
23+
24+
//@[riscv32-zfh] compile-flags: --target riscv32imac-unknown-none-elf
25+
//@[riscv32-zfh] needs-llvm-components: riscv
26+
//@[riscv32-zfh] compile-flags: -C target-feature=+zfh
27+
728
//@ compile-flags: -C target-feature=+d
829

9-
#![feature(no_core, lang_items, rustc_attrs)]
30+
#![feature(no_core, lang_items, rustc_attrs, f16)]
1031
#![crate_type = "rlib"]
1132
#![no_core]
1233
#![allow(asm_sub_register)]
@@ -33,6 +54,7 @@ type ptr = *mut u8;
3354

3455
impl Copy for i8 {}
3556
impl Copy for i16 {}
57+
impl Copy for f16 {}
3658
impl Copy for i32 {}
3759
impl Copy for f32 {}
3860
impl Copy for i64 {}
@@ -103,6 +125,12 @@ macro_rules! check_reg {
103125
// CHECK: #NO_APP
104126
check!(reg_i8 i8 reg "mv");
105127

128+
// CHECK-LABEL: reg_f16:
129+
// CHECK: #APP
130+
// CHECK: mv {{[a-z0-9]+}}, {{[a-z0-9]+}}
131+
// CHECK: #NO_APP
132+
check!(reg_f16 f16 reg "mv");
133+
106134
// CHECK-LABEL: reg_i16:
107135
// CHECK: #APP
108136
// CHECK: mv {{[a-z0-9]+}}, {{[a-z0-9]+}}
@@ -141,6 +169,12 @@ check!(reg_f64 f64 reg "mv");
141169
// CHECK: #NO_APP
142170
check!(reg_ptr ptr reg "mv");
143171

172+
// CHECK-LABEL: freg_f16:
173+
// CHECK: #APP
174+
// CHECK: fmv.s f{{[a-z0-9]+}}, f{{[a-z0-9]+}}
175+
// CHECK: #NO_APP
176+
check!(freg_f16 f16 freg "fmv.s");
177+
144178
// CHECK-LABEL: freg_f32:
145179
// CHECK: #APP
146180
// CHECK: fmv.s f{{[a-z0-9]+}}, f{{[a-z0-9]+}}
@@ -165,6 +199,12 @@ check_reg!(a0_i8 i8 "a0" "mv");
165199
// CHECK: #NO_APP
166200
check_reg!(a0_i16 i16 "a0" "mv");
167201

202+
// CHECK-LABEL: a0_f16:
203+
// CHECK: #APP
204+
// CHECK: mv a0, a0
205+
// CHECK: #NO_APP
206+
check_reg!(a0_f16 f16 "a0" "mv");
207+
168208
// CHECK-LABEL: a0_i32:
169209
// CHECK: #APP
170210
// CHECK: mv a0, a0
@@ -197,6 +237,12 @@ check_reg!(a0_f64 f64 "a0" "mv");
197237
// CHECK: #NO_APP
198238
check_reg!(a0_ptr ptr "a0" "mv");
199239

240+
// CHECK-LABEL: fa0_f16:
241+
// CHECK: #APP
242+
// CHECK: fmv.s fa0, fa0
243+
// CHECK: #NO_APP
244+
check_reg!(fa0_f16 f16 "fa0" "fmv.s");
245+
200246
// CHECK-LABEL: fa0_f32:
201247
// CHECK: #APP
202248
// CHECK: fmv.s fa0, fa0

0 commit comments

Comments
 (0)