Skip to content

Commit 2f97eb6

Browse files
committed
implement simd_fmax/fmin
1 parent 9851b74 commit 2f97eb6

File tree

2 files changed

+103
-59
lines changed

2 files changed

+103
-59
lines changed

src/shims/intrinsics.rs

+86-53
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
371371
| "simd_lt"
372372
| "simd_le"
373373
| "simd_gt"
374-
| "simd_ge" => {
374+
| "simd_ge"
375+
| "simd_fmax"
376+
| "simd_fmin" => {
375377
use mir::BinOp;
376378

377379
let &[ref left, ref right] = check_arg_count(args)?;
@@ -382,50 +384,69 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
382384
assert_eq!(dest_len, left_len);
383385
assert_eq!(dest_len, right_len);
384386

385-
let mir_op = match intrinsic_name {
386-
"simd_add" => BinOp::Add,
387-
"simd_sub" => BinOp::Sub,
388-
"simd_mul" => BinOp::Mul,
389-
"simd_div" => BinOp::Div,
390-
"simd_rem" => BinOp::Rem,
391-
"simd_shl" => BinOp::Shl,
392-
"simd_shr" => BinOp::Shr,
393-
"simd_and" => BinOp::BitAnd,
394-
"simd_or" => BinOp::BitOr,
395-
"simd_xor" => BinOp::BitXor,
396-
"simd_eq" => BinOp::Eq,
397-
"simd_ne" => BinOp::Ne,
398-
"simd_lt" => BinOp::Lt,
399-
"simd_le" => BinOp::Le,
400-
"simd_gt" => BinOp::Gt,
401-
"simd_ge" => BinOp::Ge,
387+
enum Op {
388+
MirOp(BinOp),
389+
FMax,
390+
FMin,
391+
}
392+
let which = match intrinsic_name {
393+
"simd_add" => Op::MirOp(BinOp::Add),
394+
"simd_sub" => Op::MirOp(BinOp::Sub),
395+
"simd_mul" => Op::MirOp(BinOp::Mul),
396+
"simd_div" => Op::MirOp(BinOp::Div),
397+
"simd_rem" => Op::MirOp(BinOp::Rem),
398+
"simd_shl" => Op::MirOp(BinOp::Shl),
399+
"simd_shr" => Op::MirOp(BinOp::Shr),
400+
"simd_and" => Op::MirOp(BinOp::BitAnd),
401+
"simd_or" => Op::MirOp(BinOp::BitOr),
402+
"simd_xor" => Op::MirOp(BinOp::BitXor),
403+
"simd_eq" => Op::MirOp(BinOp::Eq),
404+
"simd_ne" => Op::MirOp(BinOp::Ne),
405+
"simd_lt" => Op::MirOp(BinOp::Lt),
406+
"simd_le" => Op::MirOp(BinOp::Le),
407+
"simd_gt" => Op::MirOp(BinOp::Gt),
408+
"simd_ge" => Op::MirOp(BinOp::Ge),
409+
"simd_fmax" => Op::FMax,
410+
"simd_fmin" => Op::FMin,
402411
_ => unreachable!(),
403412
};
404413

405414
for i in 0..dest_len {
406415
let left = this.read_immediate(&this.mplace_index(&left, i)?.into())?;
407416
let right = this.read_immediate(&this.mplace_index(&right, i)?.into())?;
408417
let dest = this.mplace_index(&dest, i)?;
409-
let (val, overflowed, ty) = this.overflowing_binary_op(mir_op, &left, &right)?;
410-
if matches!(mir_op, BinOp::Shl | BinOp::Shr) {
411-
// Shifts have extra UB as SIMD operations that the MIR binop does not have.
412-
// See <https://github.com/rust-lang/rust/issues/91237>.
413-
if overflowed {
414-
let r_val = right.to_scalar()?.to_bits(right.layout.size)?;
415-
throw_ub_format!("overflowing shift by {} in `{}` in SIMD lane {}", r_val, intrinsic_name, i);
418+
let val = match which {
419+
Op::MirOp(mir_op) => {
420+
let (val, overflowed, ty) = this.overflowing_binary_op(mir_op, &left, &right)?;
421+
if matches!(mir_op, BinOp::Shl | BinOp::Shr) {
422+
// Shifts have extra UB as SIMD operations that the MIR binop does not have.
423+
// See <https://github.com/rust-lang/rust/issues/91237>.
424+
if overflowed {
425+
let r_val = right.to_scalar()?.to_bits(right.layout.size)?;
426+
throw_ub_format!("overflowing shift by {} in `{}` in SIMD lane {}", r_val, intrinsic_name, i);
427+
}
428+
}
429+
if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
430+
// Special handling for boolean-returning operations
431+
assert_eq!(ty, this.tcx.types.bool);
432+
let val = val.to_bool().unwrap();
433+
bool_to_simd_element(val, dest.layout.size)
434+
} else {
435+
assert_ne!(ty, this.tcx.types.bool);
436+
assert_eq!(ty, dest.layout.ty);
437+
val
438+
}
439+
}
440+
Op::FMax => {
441+
assert!(matches!(dest.layout.ty.kind(), ty::Float(_)));
442+
this.max_op(&left, &right)?.to_scalar()?
416443
}
417-
}
418-
if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
419-
// Special handling for boolean-returning operations
420-
assert_eq!(ty, this.tcx.types.bool);
421-
let val = val.to_bool().unwrap();
422-
let val = bool_to_simd_element(val, dest.layout.size);
423-
this.write_scalar(val, &dest.into())?;
424-
} else {
425-
assert_ne!(ty, this.tcx.types.bool);
426-
assert_eq!(ty, dest.layout.ty);
427-
this.write_scalar(val, &dest.into())?;
428-
}
444+
Op::FMin => {
445+
assert!(matches!(dest.layout.ty.kind(), ty::Float(_)));
446+
this.min_op(&left, &right)?.to_scalar()?
447+
}
448+
};
449+
this.write_scalar(val, &dest.into())?;
429450
}
430451
}
431452
#[rustfmt::skip]
@@ -478,24 +499,10 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
478499
this.binary_op(mir_op, &res, &op)?
479500
}
480501
Op::Max => {
481-
// if `op > res`...
482-
if this.binary_op(BinOp::Gt, &op, &res)?.to_scalar()?.to_bool()? {
483-
// update accumulator
484-
op
485-
} else {
486-
// no change
487-
res
488-
}
502+
this.max_op(&res, &op)?
489503
}
490504
Op::Min => {
491-
// if `op < res`...
492-
if this.binary_op(BinOp::Lt, &op, &res)?.to_scalar()?.to_bool()? {
493-
// update accumulator
494-
op
495-
} else {
496-
// no change
497-
res
498-
}
505+
this.min_op(&res, &op)?
499506
}
500507
};
501508
}
@@ -1071,4 +1078,30 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
10711078
_ => bug!("`float_to_int_unchecked` called with non-int output type {:?}", dest_ty),
10721079
})
10731080
}
1081+
1082+
fn max_op(
1083+
&self,
1084+
left: &ImmTy<'tcx, Tag>,
1085+
right: &ImmTy<'tcx, Tag>,
1086+
) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> {
1087+
let this = self.eval_context_ref();
1088+
Ok(if this.binary_op(BinOp::Gt, left, right)?.to_scalar()?.to_bool()? {
1089+
*left
1090+
} else {
1091+
*right
1092+
})
1093+
}
1094+
1095+
fn min_op(
1096+
&self,
1097+
left: &ImmTy<'tcx, Tag>,
1098+
right: &ImmTy<'tcx, Tag>,
1099+
) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> {
1100+
let this = self.eval_context_ref();
1101+
Ok(if this.binary_op(BinOp::Lt, left, right)?.to_scalar()?.to_bool()? {
1102+
*left
1103+
} else {
1104+
*right
1105+
})
1106+
}
10741107
}

tests/run-pass/portable-simd.rs

+17-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ fn simd_ops_f32() {
1212
assert_eq!(a / f32x4::splat(2.0), f32x4::splat(5.0));
1313
assert_eq!(a % b, f32x4::from_array([0.0, 0.0, 1.0, 2.0]));
1414
assert_eq!(b.abs(), f32x4::from_array([1.0, 2.0, 3.0, 4.0]));
15+
assert_eq!(a.max(b * f32x4::splat(4.0)), f32x4::from_array([10.0, 10.0, 12.0, 10.0]));
16+
assert_eq!(a.min(b * f32x4::splat(4.0)), f32x4::from_array([4.0, 8.0, 10.0, -16.0]));
1517

1618
assert_eq!(a.lanes_eq(f32x4::splat(5.0) * b), Mask::from_array([false, true, false, false]));
1719
assert_eq!(a.lanes_ne(f32x4::splat(5.0) * b), Mask::from_array([true, false, true, true]));
@@ -41,6 +43,8 @@ fn simd_ops_f64() {
4143
assert_eq!(a / f64x4::splat(2.0), f64x4::splat(5.0));
4244
assert_eq!(a % b, f64x4::from_array([0.0, 0.0, 1.0, 2.0]));
4345
assert_eq!(b.abs(), f64x4::from_array([1.0, 2.0, 3.0, 4.0]));
46+
assert_eq!(a.max(b * f64x4::splat(4.0)), f64x4::from_array([10.0, 10.0, 12.0, 10.0]));
47+
assert_eq!(a.min(b * f64x4::splat(4.0)), f64x4::from_array([4.0, 8.0, 10.0, -16.0]));
4448

4549
assert_eq!(a.lanes_eq(f64x4::splat(5.0) * b), Mask::from_array([false, true, false, false]));
4650
assert_eq!(a.lanes_ne(f64x4::splat(5.0) * b), Mask::from_array([true, false, true, true]));
@@ -71,6 +75,12 @@ fn simd_ops_i32() {
7175
assert_eq!(i32x2::splat(i32::MIN) / i32x2::splat(-1), i32x2::splat(i32::MIN));
7276
assert_eq!(a % b, i32x4::from_array([0, 0, 1, 2]));
7377
assert_eq!(i32x2::splat(i32::MIN) % i32x2::splat(-1), i32x2::splat(0));
78+
assert_eq!(b.abs(), i32x4::from_array([1, 2, 3, 4]));
79+
// FIXME not a per-lane method (https://github.com/rust-lang/rust/issues/94682)
80+
// assert_eq!(a.max(b * i32x4::splat(4)), i32x4::from_array([10, 10, 12, 10]));
81+
// assert_eq!(a.min(b * i32x4::splat(4)), i32x4::from_array([4, 8, 10, -16]));
82+
83+
assert_eq!(!b, i32x4::from_array([!1, !2, !3, !-4]));
7484
assert_eq!(b << i32x4::splat(2), i32x4::from_array([4, 8, 12, -16]));
7585
assert_eq!(b >> i32x4::splat(1), i32x4::from_array([0, 1, 1, -2]));
7686
assert_eq!(b & i32x4::splat(2), i32x4::from_array([0, 2, 2, 0]));
@@ -84,12 +94,6 @@ fn simd_ops_i32() {
8494
assert_eq!(a.lanes_ge(i32x4::splat(5) * b), Mask::from_array([true, true, false, true]));
8595
assert_eq!(a.lanes_gt(i32x4::splat(5) * b), Mask::from_array([true, false, false, true]));
8696

87-
assert_eq!(a.horizontal_and(), 10);
88-
assert_eq!(b.horizontal_and(), 0);
89-
assert_eq!(a.horizontal_or(), 10);
90-
assert_eq!(b.horizontal_or(), -1);
91-
assert_eq!(a.horizontal_xor(), 0);
92-
assert_eq!(b.horizontal_xor(), -4);
9397
assert_eq!(a.horizontal_sum(), 40);
9498
assert_eq!(b.horizontal_sum(), 2);
9599
assert_eq!(a.horizontal_product(), 100 * 100);
@@ -98,6 +102,13 @@ fn simd_ops_i32() {
98102
assert_eq!(b.horizontal_max(), 3);
99103
assert_eq!(a.horizontal_min(), 10);
100104
assert_eq!(b.horizontal_min(), -4);
105+
106+
assert_eq!(a.horizontal_and(), 10);
107+
assert_eq!(b.horizontal_and(), 0);
108+
assert_eq!(a.horizontal_or(), 10);
109+
assert_eq!(b.horizontal_or(), -1);
110+
assert_eq!(a.horizontal_xor(), 0);
111+
assert_eq!(b.horizontal_xor(), -4);
101112
}
102113

103114
fn simd_mask() {

0 commit comments

Comments
 (0)