Skip to content

Commit 4669c0d

Browse files
committed
Override carrying_mul_add in cg_llvm
1 parent 2c0c912 commit 4669c0d

File tree

5 files changed

+181
-2
lines changed

5 files changed

+181
-2
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

+31
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,37 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
340340
self.const_i32(cache_type),
341341
])
342342
}
343+
sym::carrying_mul_add => {
344+
let (size, signed) = fn_args.type_at(0).int_size_and_signed(self.tcx);
345+
346+
let wide_llty = self.type_ix(size.bits() * 2);
347+
let args = args.as_array().unwrap();
348+
let [a, b, c, d] = args.map(|a| self.intcast(a.immediate(), wide_llty, signed));
349+
350+
let wide = if signed {
351+
let prod = self.unchecked_smul(a, b);
352+
let acc = self.unchecked_sadd(prod, c);
353+
self.unchecked_sadd(acc, d)
354+
} else {
355+
let prod = self.unchecked_umul(a, b);
356+
let acc = self.unchecked_uadd(prod, c);
357+
self.unchecked_uadd(acc, d)
358+
};
359+
360+
let narrow_llty = self.type_ix(size.bits());
361+
let low = self.trunc(wide, narrow_llty);
362+
let bits_const = self.const_uint(wide_llty, size.bits());
363+
// No need for ashr when signed; LLVM changes it to lshr anyway.
364+
let high = self.lshr(wide, bits_const);
365+
// FIXME: could be `trunc nuw`, even for signed.
366+
let high = self.trunc(high, narrow_llty);
367+
368+
let pair_llty = self.type_struct(&[narrow_llty, narrow_llty], false);
369+
let pair = self.const_poison(pair_llty);
370+
let pair = self.insert_value(pair, low, 0);
371+
let pair = self.insert_value(pair, high, 1);
372+
pair
373+
}
343374
sym::ctlz
344375
| sym::ctlz_nonzero
345376
| sym::cttz

compiler/rustc_codegen_llvm/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#![feature(iter_intersperse)]
1818
#![feature(let_chains)]
1919
#![feature(rustdoc_internals)]
20+
#![feature(slice_as_array)]
2021
#![feature(try_blocks)]
2122
#![warn(unreachable_pub)]
2223
// tidy-alphabetical-end

library/core/src/intrinsics/fallback.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ impl const CarryingMulAdd for i128 {
100100
fn carrying_mul_add(self, b: i128, c: i128, d: i128) -> (u128, i128) {
101101
let (low, high) = wide_mul_u128(self as u128, b as u128);
102102
let mut high = high as i128;
103-
high = high.wrapping_add((self >> 127) * b);
104-
high = high.wrapping_add(self * (b >> 127));
103+
high = high.wrapping_add(i128::wrapping_mul(self >> 127, b));
104+
high = high.wrapping_add(i128::wrapping_mul(self, b >> 127));
105105
let (low, carry) = u128::overflowing_add(low, c as u128);
106106
high = high.wrapping_add((carry as i128) + (c >> 127));
107107
let (low, carry) = u128::overflowing_add(low, d as u128);

library/core/tests/intrinsics.rs

+10
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ fn carrying_mul_add_fallback_i32() {
153153

154154
#[test]
155155
fn carrying_mul_add_fallback_u128() {
156+
assert_eq!(fallback_cma::<u128>(u128::MAX, u128::MAX, 0, 0), (1, u128::MAX - 1));
156157
assert_eq!(fallback_cma::<u128>(1, 1, 1, 1), (3, 0));
157158
assert_eq!(fallback_cma::<u128>(0, 0, u128::MAX, u128::MAX), (u128::MAX - 1, 1));
158159
assert_eq!(
@@ -178,8 +179,17 @@ fn carrying_mul_add_fallback_u128() {
178179

179180
#[test]
180181
fn carrying_mul_add_fallback_i128() {
182+
assert_eq!(fallback_cma::<i128>(-1, -1, 0, 0), (1, 0));
181183
let r = fallback_cma::<i128>(-1, -1, -1, -1);
182184
assert_eq!(r, (u128::MAX, -1));
183185
let r = fallback_cma::<i128>(1, -1, 1, 1);
184186
assert_eq!(r, (1, 0));
187+
assert_eq!(
188+
fallback_cma::<i128>(i128::MAX, i128::MAX, i128::MAX, i128::MAX),
189+
(u128::MAX, i128::MAX / 2),
190+
);
191+
assert_eq!(
192+
fallback_cma::<i128>(i128::MIN, i128::MIN, i128::MAX, i128::MAX),
193+
(u128::MAX - 1, -(i128::MIN / 2)),
194+
);
185195
}

tests/codegen/intrinsics/carrying_mul_add.rs

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
//@ revisions: RAW OPT
2+
//@ compile-flags: -C opt-level=1
3+
//@[RAW] compile-flags: -C no-prepopulate-passes
4+
//@[OPT] min-llvm-version: 19
5+
6+
#![crate_type = "lib"]
7+
#![feature(core_intrinsics)]
8+
#![feature(core_intrinsics_fallbacks)]
9+
10+
// Note that LLVM seems to sometimes permute the order of arguments to mul and add,
11+
// so these tests don't check the arguments in the optimized revision.
12+
13+
use std::intrinsics::{carrying_mul_add, fallback};
14+
15+
// The fallbacks are emitted even when they're never used, but optimize out.
16+
17+
// RAW: wide_mul_u128
18+
// OPT-NOT: wide_mul_u128
19+
20+
// CHECK-LABEL: @cma_u8
21+
#[no_mangle]
22+
pub unsafe fn cma_u8(a: u8, b: u8, c: u8, d: u8) -> (u8, u8) {
23+
// CHECK: [[A:%.+]] = zext i8 %a to i16
24+
// CHECK: [[B:%.+]] = zext i8 %b to i16
25+
// CHECK: [[C:%.+]] = zext i8 %c to i16
26+
// CHECK: [[D:%.+]] = zext i8 %d to i16
27+
// CHECK: [[AB:%.+]] = mul nuw i16
28+
// RAW-SAME: [[A]], [[B]]
29+
// CHECK: [[ABC:%.+]] = add nuw i16
30+
// RAW-SAME: [[AB]], [[C]]
31+
// CHECK: [[ABCD:%.+]] = add nuw i16
32+
// RAW-SAME: [[ABC]], [[D]]
33+
// CHECK: [[LOW:%.+]] = trunc i16 [[ABCD]] to i8
34+
// CHECK: [[HIGHW:%.+]] = lshr i16 [[ABCD]], 8
35+
// RAW: [[HIGH:%.+]] = trunc i16 [[HIGHW]] to i8
36+
// OPT: [[HIGH:%.+]] = trunc nuw i16 [[HIGHW]] to i8
37+
// CHECK: [[PAIR0:%.+]] = insertvalue { i8, i8 } poison, i8 [[LOW]], 0
38+
// CHECK: [[PAIR1:%.+]] = insertvalue { i8, i8 } [[PAIR0]], i8 [[HIGH]], 1
39+
// OPT: ret { i8, i8 } [[PAIR1]]
40+
carrying_mul_add(a, b, c, d)
41+
}
42+
43+
// CHECK-LABEL: @cma_u32
44+
#[no_mangle]
45+
pub unsafe fn cma_u32(a: u32, b: u32, c: u32, d: u32) -> (u32, u32) {
46+
// CHECK: [[A:%.+]] = zext i32 %a to i64
47+
// CHECK: [[B:%.+]] = zext i32 %b to i64
48+
// CHECK: [[C:%.+]] = zext i32 %c to i64
49+
// CHECK: [[D:%.+]] = zext i32 %d to i64
50+
// CHECK: [[AB:%.+]] = mul nuw i64
51+
// RAW-SAME: [[A]], [[B]]
52+
// CHECK: [[ABC:%.+]] = add nuw i64
53+
// RAW-SAME: [[AB]], [[C]]
54+
// CHECK: [[ABCD:%.+]] = add nuw i64
55+
// RAW-SAME: [[ABC]], [[D]]
56+
// CHECK: [[LOW:%.+]] = trunc i64 [[ABCD]] to i32
57+
// CHECK: [[HIGHW:%.+]] = lshr i64 [[ABCD]], 32
58+
// RAW: [[HIGH:%.+]] = trunc i64 [[HIGHW]] to i32
59+
// OPT: [[HIGH:%.+]] = trunc nuw i64 [[HIGHW]] to i32
60+
// CHECK: [[PAIR0:%.+]] = insertvalue { i32, i32 } poison, i32 [[LOW]], 0
61+
// CHECK: [[PAIR1:%.+]] = insertvalue { i32, i32 } [[PAIR0]], i32 [[HIGH]], 1
62+
// OPT: ret { i32, i32 } [[PAIR1]]
63+
carrying_mul_add(a, b, c, d)
64+
}
65+
66+
// CHECK-LABEL: @cma_u128
67+
// CHECK-SAME: sret{{.+}}dereferenceable(32){{.+}}%_0,{{.+}}%a,{{.+}}%b,{{.+}}%c,{{.+}}%d
68+
#[no_mangle]
69+
pub unsafe fn cma_u128(a: u128, b: u128, c: u128, d: u128) -> (u128, u128) {
70+
// CHECK: [[A:%.+]] = zext i128 %a to i256
71+
// CHECK: [[B:%.+]] = zext i128 %b to i256
72+
// CHECK: [[C:%.+]] = zext i128 %c to i256
73+
// CHECK: [[D:%.+]] = zext i128 %d to i256
74+
// CHECK: [[AB:%.+]] = mul nuw i256
75+
// RAW-SAME: [[A]], [[B]]
76+
// CHECK: [[ABC:%.+]] = add nuw i256
77+
// RAW-SAME: [[AB]], [[C]]
78+
// CHECK: [[ABCD:%.+]] = add nuw i256
79+
// RAW-SAME: [[ABC]], [[D]]
80+
// CHECK: [[LOW:%.+]] = trunc i256 [[ABCD]] to i128
81+
// CHECK: [[HIGHW:%.+]] = lshr i256 [[ABCD]], 128
82+
// RAW: [[HIGH:%.+]] = trunc i256 [[HIGHW]] to i128
83+
// OPT: [[HIGH:%.+]] = trunc nuw i256 [[HIGHW]] to i128
84+
// RAW: [[PAIR0:%.+]] = insertvalue { i128, i128 } poison, i128 [[LOW]], 0
85+
// RAW: [[PAIR1:%.+]] = insertvalue { i128, i128 } [[PAIR0]], i128 [[HIGH]], 1
86+
// OPT: store i128 [[LOW]], ptr %_0
87+
// OPT: [[P1:%.+]] = getelementptr inbounds i8, ptr %_0, {{i32|i64}} 16
88+
// OPT: store i128 [[HIGH]], ptr [[P1]]
89+
// CHECK: ret void
90+
carrying_mul_add(a, b, c, d)
91+
}
92+
93+
// CHECK-LABEL: @cma_i128
94+
// CHECK-SAME: sret{{.+}}dereferenceable(32){{.+}}%_0,{{.+}}%a,{{.+}}%b,{{.+}}%c,{{.+}}%d
95+
#[no_mangle]
96+
pub unsafe fn cma_i128(a: i128, b: i128, c: i128, d: i128) -> (u128, i128) {
97+
// CHECK: [[A:%.+]] = sext i128 %a to i256
98+
// CHECK: [[B:%.+]] = sext i128 %b to i256
99+
// CHECK: [[C:%.+]] = sext i128 %c to i256
100+
// CHECK: [[D:%.+]] = sext i128 %d to i256
101+
// CHECK: [[AB:%.+]] = mul nsw i256
102+
// RAW-SAME: [[A]], [[B]]
103+
// CHECK: [[ABC:%.+]] = add nsw i256
104+
// RAW-SAME: [[AB]], [[C]]
105+
// CHECK: [[ABCD:%.+]] = add nsw i256
106+
// RAW-SAME: [[ABC]], [[D]]
107+
// CHECK: [[LOW:%.+]] = trunc i256 [[ABCD]] to i128
108+
// CHECK: [[HIGHW:%.+]] = lshr i256 [[ABCD]], 128
109+
// RAW: [[HIGH:%.+]] = trunc i256 [[HIGHW]] to i128
110+
// OPT: [[HIGH:%.+]] = trunc nuw i256 [[HIGHW]] to i128
111+
// RAW: [[PAIR0:%.+]] = insertvalue { i128, i128 } poison, i128 [[LOW]], 0
112+
// RAW: [[PAIR1:%.+]] = insertvalue { i128, i128 } [[PAIR0]], i128 [[HIGH]], 1
113+
// OPT: store i128 [[LOW]], ptr %_0
114+
// OPT: [[P1:%.+]] = getelementptr inbounds i8, ptr %_0, {{i32|i64}} 16
115+
// OPT: store i128 [[HIGH]], ptr [[P1]]
116+
// CHECK: ret void
117+
carrying_mul_add(a, b, c, d)
118+
}
119+
120+
// CHECK-LABEL: @fallback_cma_u32
121+
#[no_mangle]
122+
pub unsafe fn fallback_cma_u32(a: u32, b: u32, c: u32, d: u32) -> (u32, u32) {
123+
// OPT-DAG: [[A:%.+]] = zext i32 %a to i64
124+
// OPT-DAG: [[B:%.+]] = zext i32 %b to i64
125+
// OPT-DAG: [[AB:%.+]] = mul nuw i64
126+
// OPT-DAG: [[C:%.+]] = zext i32 %c to i64
127+
// OPT-DAG: [[ABC:%.+]] = add nuw i64{{.+}}[[C]]
128+
// OPT-DAG: [[D:%.+]] = zext i32 %d to i64
129+
// OPT-DAG: [[ABCD:%.+]] = add nuw i64{{.+}}[[D]]
130+
// OPT-DAG: [[LOW:%.+]] = trunc i64 [[ABCD]] to i32
131+
// OPT-DAG: [[HIGHW:%.+]] = lshr i64 [[ABCD]], 32
132+
// OPT-DAG: [[HIGH:%.+]] = trunc nuw i64 [[HIGHW]] to i32
133+
// OPT-DAG: [[PAIR0:%.+]] = insertvalue { i32, i32 } poison, i32 [[LOW]], 0
134+
// OPT-DAG: [[PAIR1:%.+]] = insertvalue { i32, i32 } [[PAIR0]], i32 [[HIGH]], 1
135+
// OPT-DAG: ret { i32, i32 } [[PAIR1]]
136+
fallback::CarryingMulAdd::carrying_mul_add(a, b, c, d)
137+
}

0 commit comments

Comments
 (0)