Skip to content

Commit 49729b5

Browse files
committed
Auto merge of #2028 - RalfJung:simd-round, r=RalfJung
implement SIMD float rounding functions. Cc #1912
2 parents 39c72db + 1f237b3 commit 49729b5

File tree

2 files changed

+121
-10
lines changed

2 files changed

+121
-10
lines changed

Diff for: src/shims/intrinsics.rs

+60-6
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,20 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
9090
}
9191

9292
// Floating-point operations
93+
"fabsf32" => {
94+
let &[ref f] = check_arg_count(args)?;
95+
let f = this.read_scalar(f)?.to_f32()?;
96+
// Can be implemented in soft-floats.
97+
this.write_scalar(Scalar::from_f32(f.abs()), dest)?;
98+
}
99+
"fabsf64" => {
100+
let &[ref f] = check_arg_count(args)?;
101+
let f = this.read_scalar(f)?.to_f64()?;
102+
// Can be implemented in soft-floats.
103+
this.write_scalar(Scalar::from_f64(f.abs()), dest)?;
104+
}
93105
#[rustfmt::skip]
94106
| "sinf32"
95-
| "fabsf32"
96107
| "cosf32"
97108
| "sqrtf32"
98109
| "expf32"
@@ -110,7 +121,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
110121
let f = f32::from_bits(this.read_scalar(f)?.to_u32()?);
111122
let f = match intrinsic_name {
112123
"sinf32" => f.sin(),
113-
"fabsf32" => f.abs(),
114124
"cosf32" => f.cos(),
115125
"sqrtf32" => f.sqrt(),
116126
"expf32" => f.exp(),
@@ -129,7 +139,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
129139

130140
#[rustfmt::skip]
131141
| "sinf64"
132-
| "fabsf64"
133142
| "cosf64"
134143
| "sqrtf64"
135144
| "expf64"
@@ -147,7 +156,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
147156
let f = f64::from_bits(this.read_scalar(f)?.to_u64()?);
148157
let f = match intrinsic_name {
149158
"sinf64" => f.sin(),
150-
"fabsf64" => f.abs(),
151159
"cosf64" => f.cos(),
152160
"sqrtf64" => f.sqrt(),
153161
"expf64" => f.exp(),
@@ -317,20 +325,37 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
317325
// SIMD operations
318326
#[rustfmt::skip]
319327
| "simd_neg"
320-
| "simd_fabs" => {
328+
| "simd_fabs"
329+
| "simd_ceil"
330+
| "simd_floor"
331+
| "simd_round"
332+
| "simd_trunc" => {
321333
let &[ref op] = check_arg_count(args)?;
322334
let (op, op_len) = this.operand_to_simd(op)?;
323335
let (dest, dest_len) = this.place_to_simd(dest)?;
324336

325337
assert_eq!(dest_len, op_len);
326338

339+
#[derive(Copy, Clone)]
340+
enum HostFloatOp {
341+
Ceil,
342+
Floor,
343+
Round,
344+
Trunc,
345+
}
346+
#[derive(Copy, Clone)]
327347
enum Op {
328348
MirOp(mir::UnOp),
329349
Abs,
350+
HostOp(HostFloatOp),
330351
}
331352
let which = match intrinsic_name {
332353
"simd_neg" => Op::MirOp(mir::UnOp::Neg),
333354
"simd_fabs" => Op::Abs,
355+
"simd_ceil" => Op::HostOp(HostFloatOp::Ceil),
356+
"simd_floor" => Op::HostOp(HostFloatOp::Floor),
357+
"simd_round" => Op::HostOp(HostFloatOp::Round),
358+
"simd_trunc" => Op::HostOp(HostFloatOp::Trunc),
334359
_ => unreachable!(),
335360
};
336361

@@ -342,14 +367,43 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
342367
Op::Abs => {
343368
// Works for f32 and f64.
344369
let ty::Float(float_ty) = op.layout.ty.kind() else {
345-
bug!("simd_fabs operand is not a float")
370+
bug!("{} operand is not a float", intrinsic_name)
346371
};
347372
let op = op.to_scalar()?;
348373
match float_ty {
349374
FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
350375
FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
351376
}
352377
}
378+
Op::HostOp(host_op) => {
379+
let ty::Float(float_ty) = op.layout.ty.kind() else {
380+
bug!("{} operand is not a float", intrinsic_name)
381+
};
382+
// FIXME using host floats
383+
match float_ty {
384+
FloatTy::F32 => {
385+
let f = f32::from_bits(op.to_scalar()?.to_u32()?);
386+
let res = match host_op {
387+
HostFloatOp::Ceil => f.ceil(),
388+
HostFloatOp::Floor => f.floor(),
389+
HostFloatOp::Round => f.round(),
390+
HostFloatOp::Trunc => f.trunc(),
391+
};
392+
Scalar::from_u32(res.to_bits())
393+
}
394+
FloatTy::F64 => {
395+
let f = f64::from_bits(op.to_scalar()?.to_u64()?);
396+
let res = match host_op {
397+
HostFloatOp::Ceil => f.ceil(),
398+
HostFloatOp::Floor => f.floor(),
399+
HostFloatOp::Round => f.round(),
400+
HostFloatOp::Trunc => f.trunc(),
401+
};
402+
Scalar::from_u64(res.to_bits())
403+
}
404+
}
405+
406+
}
353407
};
354408
this.write_scalar(val, &dest.into())?;
355409
}

Diff for: tests/run-pass/portable-simd.rs

+61-4
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,39 @@ fn simd_ops_i32() {
106106
assert_eq!(a.min(b * i32x4::splat(4)), i32x4::from_array([4, 8, 10, -16]));
107107

108108
assert_eq!(
109-
i8x4::from_array([i8::MAX, -23, 23, i8::MIN]).saturating_add(i8x4::from_array([1, i8::MIN, i8::MAX, 28])),
109+
i8x4::from_array([i8::MAX, -23, 23, i8::MIN]).saturating_add(i8x4::from_array([
110+
1,
111+
i8::MIN,
112+
i8::MAX,
113+
28
114+
])),
110115
i8x4::from_array([i8::MAX, i8::MIN, i8::MAX, -100])
111116
);
112117
assert_eq!(
113-
i8x4::from_array([i8::MAX, -28, 27, 42]).saturating_sub(i8x4::from_array([1, i8::MAX, i8::MAX, -80])),
118+
i8x4::from_array([i8::MAX, -28, 27, 42]).saturating_sub(i8x4::from_array([
119+
1,
120+
i8::MAX,
121+
i8::MAX,
122+
-80
123+
])),
114124
i8x4::from_array([126, i8::MIN, -100, 122])
115125
);
116126
assert_eq!(
117-
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_add(u8x4::from_array([1, 1, u8::MAX, 200])),
127+
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_add(u8x4::from_array([
128+
1,
129+
1,
130+
u8::MAX,
131+
200
132+
])),
118133
u8x4::from_array([u8::MAX, 1, u8::MAX, 242])
119134
);
120135
assert_eq!(
121-
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_sub(u8x4::from_array([1, 1, u8::MAX, 200])),
136+
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_sub(u8x4::from_array([
137+
1,
138+
1,
139+
u8::MAX,
140+
200
141+
])),
122142
u8x4::from_array([254, 0, 0, 0])
123143
);
124144

@@ -259,6 +279,42 @@ fn simd_gather_scatter() {
259279
assert_eq!(vec, vec![124, 11, 12, 82, 14, 15, 16, 17, 18]);
260280
}
261281

282+
// Run-pass test for the lane-wise rounding methods `ceil`, `floor`, `round`
// and `trunc` on `f32x4` and `f64x4` vectors. Every assertion uses the same
// input lanes `[0.9, 1.001, 2.0, -4.5]`: a value just below an integer, a
// value just above one, an exact integer, and a negative half-way value.
// NOTE(review): presumably these methods lower to the `simd_ceil`,
// `simd_floor`, `simd_round` and `simd_trunc` intrinsics -- confirm against
// the portable-simd implementation.
fn simd_round() {
283+
// f32 lanes.
// ceil: each lane moves up to the next integer, so -4.5 becomes -4.0.
assert_eq!(
284+
f32x4::from_array([0.9, 1.001, 2.0, -4.5]).ceil(),
285+
f32x4::from_array([1.0, 2.0, 2.0, -4.0])
286+
);
287+
// floor: each lane moves down to the previous integer, so -4.5 becomes -5.0.
assert_eq!(
288+
f32x4::from_array([0.9, 1.001, 2.0, -4.5]).floor(),
289+
f32x4::from_array([0.0, 1.0, 2.0, -5.0])
290+
);
291+
// round: the expected result asserts that the tie -4.5 rounds away from
// zero to -5.0.
assert_eq!(
292+
f32x4::from_array([0.9, 1.001, 2.0, -4.5]).round(),
293+
f32x4::from_array([1.0, 1.0, 2.0, -5.0])
294+
);
295+
// trunc: the fractional part is dropped, so -4.5 becomes -4.0.
assert_eq!(
296+
f32x4::from_array([0.9, 1.001, 2.0, -4.5]).trunc(),
297+
f32x4::from_array([0.0, 1.0, 2.0, -4.0])
298+
);
299+
300+
// f64 lanes: same inputs and expected results as the f32 assertions above.
assert_eq!(
301+
f64x4::from_array([0.9, 1.001, 2.0, -4.5]).ceil(),
302+
f64x4::from_array([1.0, 2.0, 2.0, -4.0])
303+
);
304+
assert_eq!(
305+
f64x4::from_array([0.9, 1.001, 2.0, -4.5]).floor(),
306+
f64x4::from_array([0.0, 1.0, 2.0, -5.0])
307+
);
308+
assert_eq!(
309+
f64x4::from_array([0.9, 1.001, 2.0, -4.5]).round(),
310+
f64x4::from_array([1.0, 1.0, 2.0, -5.0])
311+
);
312+
assert_eq!(
313+
f64x4::from_array([0.9, 1.001, 2.0, -4.5]).trunc(),
314+
f64x4::from_array([0.0, 1.0, 2.0, -4.0])
315+
);
316+
}
317+
262318
fn simd_intrinsics() {
263319
extern "platform-intrinsic" {
264320
fn simd_eq<T, U>(x: T, y: T) -> U;
@@ -299,5 +355,6 @@ fn main() {
299355
simd_cast();
300356
simd_swizzle();
301357
simd_gather_scatter();
358+
simd_round();
302359
simd_intrinsics();
303360
}

0 commit comments

Comments
 (0)