Skip to content

Commit b87a9c9

Browse files
committed
fix handling of NaNs in simd max/min
1 parent 2f97eb6 commit b87a9c9

File tree

3 files changed

+79
-32
lines changed

3 files changed

+79
-32
lines changed

rust-version

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
8876ca3dd46b99fe7e6ad937f11493d37996231e
1+
297273c45b205820a4c055082c71677197a40b55

src/shims/intrinsics.rs

+52-31
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
345345
bug!("simd_fabs operand is not a float")
346346
};
347347
let op = op.to_scalar()?;
348-
// FIXME: Using host floats.
349348
match float_ty {
350349
FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
351350
FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
@@ -438,12 +437,10 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
438437
}
439438
}
440439
Op::FMax => {
441-
assert!(matches!(dest.layout.ty.kind(), ty::Float(_)));
442-
this.max_op(&left, &right)?.to_scalar()?
440+
fmax_op(&left, &right)?
443441
}
444442
Op::FMin => {
445-
assert!(matches!(dest.layout.ty.kind(), ty::Float(_)));
446-
this.min_op(&left, &right)?.to_scalar()?
443+
fmin_op(&left, &right)?
447444
}
448445
};
449446
this.write_scalar(val, &dest.into())?;
@@ -499,10 +496,28 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
499496
this.binary_op(mir_op, &res, &op)?
500497
}
501498
Op::Max => {
502-
this.max_op(&res, &op)?
499+
if matches!(res.layout.ty.kind(), ty::Float(_)) {
500+
ImmTy::from_scalar(fmax_op(&res, &op)?, res.layout)
501+
} else {
502+
// Just boring integers, so NaNs to worry about
503+
if this.binary_op(BinOp::Ge, &res, &op)?.to_scalar()?.to_bool()? {
504+
res
505+
} else {
506+
op
507+
}
508+
}
503509
}
504510
Op::Min => {
505-
this.min_op(&res, &op)?
511+
if matches!(res.layout.ty.kind(), ty::Float(_)) {
512+
ImmTy::from_scalar(fmin_op(&res, &op)?, res.layout)
513+
} else {
514+
// Just boring integers, so NaNs to worry about
515+
if this.binary_op(BinOp::Le, &res, &op)?.to_scalar()?.to_bool()? {
516+
res
517+
} else {
518+
op
519+
}
520+
}
506521
}
507522
};
508523
}
@@ -1078,30 +1093,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
10781093
_ => bug!("`float_to_int_unchecked` called with non-int output type {:?}", dest_ty),
10791094
})
10801095
}
1096+
}
10811097

1082-
fn max_op(
1083-
&self,
1084-
left: &ImmTy<'tcx, Tag>,
1085-
right: &ImmTy<'tcx, Tag>,
1086-
) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> {
1087-
let this = self.eval_context_ref();
1088-
Ok(if this.binary_op(BinOp::Gt, left, right)?.to_scalar()?.to_bool()? {
1089-
*left
1090-
} else {
1091-
*right
1092-
})
1093-
}
1098+
fn fmax_op<'tcx>(
1099+
left: &ImmTy<'tcx, Tag>,
1100+
right: &ImmTy<'tcx, Tag>,
1101+
) -> InterpResult<'tcx, Scalar<Tag>> {
1102+
assert_eq!(left.layout.ty, right.layout.ty);
1103+
let ty::Float(float_ty) = left.layout.ty.kind() else {
1104+
bug!("fmax operand is not a float")
1105+
};
1106+
let left = left.to_scalar()?;
1107+
let right = right.to_scalar()?;
1108+
Ok(match float_ty {
1109+
FloatTy::F32 => Scalar::from_f32(left.to_f32()?.max(right.to_f32()?)),
1110+
FloatTy::F64 => Scalar::from_f64(left.to_f64()?.max(right.to_f64()?)),
1111+
})
1112+
}
10941113

1095-
fn min_op(
1096-
&self,
1097-
left: &ImmTy<'tcx, Tag>,
1098-
right: &ImmTy<'tcx, Tag>,
1099-
) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> {
1100-
let this = self.eval_context_ref();
1101-
Ok(if this.binary_op(BinOp::Lt, left, right)?.to_scalar()?.to_bool()? {
1102-
*left
1103-
} else {
1104-
*right
1105-
})
1106-
}
1114+
fn fmin_op<'tcx>(
1115+
left: &ImmTy<'tcx, Tag>,
1116+
right: &ImmTy<'tcx, Tag>,
1117+
) -> InterpResult<'tcx, Scalar<Tag>> {
1118+
assert_eq!(left.layout.ty, right.layout.ty);
1119+
let ty::Float(float_ty) = left.layout.ty.kind() else {
1120+
bug!("fmin operand is not a float")
1121+
};
1122+
let left = left.to_scalar()?;
1123+
let right = right.to_scalar()?;
1124+
Ok(match float_ty {
1125+
FloatTy::F32 => Scalar::from_f32(left.to_f32()?.min(right.to_f32()?)),
1126+
FloatTy::F64 => Scalar::from_f64(left.to_f64()?.min(right.to_f64()?)),
1127+
})
11071128
}

tests/run-pass/portable-simd.rs

+26
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,19 @@ fn simd_ops_f32() {
3030
assert_eq!(b.horizontal_max(), 3.0);
3131
assert_eq!(a.horizontal_min(), 10.0);
3232
assert_eq!(b.horizontal_min(), -4.0);
33+
34+
assert_eq!(
35+
f32x2::from_array([0.0, f32::NAN]).max(f32x2::from_array([f32::NAN, 0.0])),
36+
f32x2::from_array([0.0, 0.0])
37+
);
38+
assert_eq!(f32x2::from_array([0.0, f32::NAN]).horizontal_max(), 0.0);
39+
assert_eq!(f32x2::from_array([f32::NAN, 0.0]).horizontal_max(), 0.0);
40+
assert_eq!(
41+
f32x2::from_array([0.0, f32::NAN]).min(f32x2::from_array([f32::NAN, 0.0])),
42+
f32x2::from_array([0.0, 0.0])
43+
);
44+
assert_eq!(f32x2::from_array([0.0, f32::NAN]).horizontal_min(), 0.0);
45+
assert_eq!(f32x2::from_array([f32::NAN, 0.0]).horizontal_min(), 0.0);
3346
}
3447

3548
fn simd_ops_f64() {
@@ -61,6 +74,19 @@ fn simd_ops_f64() {
6174
assert_eq!(b.horizontal_max(), 3.0);
6275
assert_eq!(a.horizontal_min(), 10.0);
6376
assert_eq!(b.horizontal_min(), -4.0);
77+
78+
assert_eq!(
79+
f64x2::from_array([0.0, f64::NAN]).max(f64x2::from_array([f64::NAN, 0.0])),
80+
f64x2::from_array([0.0, 0.0])
81+
);
82+
assert_eq!(f64x2::from_array([0.0, f64::NAN]).horizontal_max(), 0.0);
83+
assert_eq!(f64x2::from_array([f64::NAN, 0.0]).horizontal_max(), 0.0);
84+
assert_eq!(
85+
f64x2::from_array([0.0, f64::NAN]).min(f64x2::from_array([f64::NAN, 0.0])),
86+
f64x2::from_array([0.0, 0.0])
87+
);
88+
assert_eq!(f64x2::from_array([0.0, f64::NAN]).horizontal_min(), 0.0);
89+
assert_eq!(f64x2::from_array([f64::NAN, 0.0]).horizontal_min(), 0.0);
6490
}
6591

6692
fn simd_ops_i32() {

0 commit comments

Comments
 (0)