Skip to content

Commit 7131806

Browse files
authored
Rollup merge of rust-lang#90999 - RalfJung:miri_simd, r=oli-obk
fix CTFE/Miri simd_insert/extract on array-style repr(simd) types The changed test would previously fail since `place_index` would just return the only field of `f32x4`, i.e., the array -- rather than *indexing into* the array which is what we have to do. The new helper methods will also be needed for rust-lang/miri#1912. r? `@oli-obk`
2 parents d1662ee + 0304e16 commit 7131806

File tree

5 files changed

+76
-38
lines changed

5 files changed

+76
-38
lines changed

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+16-31
Original file line numberDiff line numberDiff line change
@@ -413,48 +413,33 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
413413
sym::simd_insert => {
414414
let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
415415
let elem = &args[2];
416-
let input = &args[0];
417-
let (len, e_ty) = input.layout.ty.simd_size_and_type(*self.tcx);
416+
let (input, input_len) = self.operand_to_simd(&args[0])?;
417+
let (dest, dest_len) = self.place_to_simd(dest)?;
418+
assert_eq!(input_len, dest_len, "Return vector length must match input length");
418419
assert!(
419-
index < len,
420-
"Index `{}` must be in bounds of vector type `{}`: `[0, {})`",
420+
index < dest_len,
421+
"Index `{}` must be in bounds of vector with length {}`",
421422
index,
422-
e_ty,
423-
len
424-
);
425-
assert_eq!(
426-
input.layout, dest.layout,
427-
"Return type `{}` must match vector type `{}`",
428-
dest.layout.ty, input.layout.ty
429-
);
430-
assert_eq!(
431-
elem.layout.ty, e_ty,
432-
"Scalar element type `{}` must match vector element type `{}`",
433-
elem.layout.ty, e_ty
423+
dest_len
434424
);
435425

436-
for i in 0..len {
437-
let place = self.place_index(dest, i)?;
438-
let value = if i == index { *elem } else { self.operand_index(input, i)? };
439-
self.copy_op(&value, &place)?;
426+
for i in 0..dest_len {
427+
let place = self.mplace_index(&dest, i)?;
428+
let value =
429+
if i == index { *elem } else { self.mplace_index(&input, i)?.into() };
430+
self.copy_op(&value, &place.into())?;
440431
}
441432
}
442433
sym::simd_extract => {
443434
let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
444-
let (len, e_ty) = args[0].layout.ty.simd_size_and_type(*self.tcx);
435+
let (input, input_len) = self.operand_to_simd(&args[0])?;
445436
assert!(
446-
index < len,
447-
"index `{}` is out-of-bounds of vector type `{}` with length `{}`",
437+
index < input_len,
438+
"index `{}` must be in bounds of vector with length `{}`",
448439
index,
449-
e_ty,
450-
len
451-
);
452-
assert_eq!(
453-
e_ty, dest.layout.ty,
454-
"Return type `{}` must match vector element type `{}`",
455-
dest.layout.ty, e_ty
440+
input_len
456441
);
457-
self.copy_op(&self.operand_index(&args[0], index)?, dest)?;
442+
self.copy_op(&self.mplace_index(&input, index)?.into(), dest)?;
458443
}
459444
sym::likely | sym::unlikely | sym::black_box => {
460445
// These just return their argument

compiler/rustc_const_eval/src/interpret/operand.rs

+12
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,18 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
437437
})
438438
}
439439

440+
/// Converts a repr(simd) operand into an operand where `place_index` accesses the SIMD elements.
441+
/// Also returns the number of elements.
442+
pub fn operand_to_simd(
443+
&self,
444+
base: &OpTy<'tcx, M::PointerTag>,
445+
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, u64)> {
446+
// Basically we just transmute this place into an array following simd_size_and_type.
447+
// This only works in memory, but repr(simd) types should never be immediates anyway.
448+
assert!(base.layout.ty.is_simd());
449+
self.mplace_to_simd(&base.assert_mem_place())
450+
}
451+
440452
/// Read from a local. Will not actually access the local if reading from a ZST.
441453
/// Will not access memory, instead an indirect `Operand` is returned.
442454
///

compiler/rustc_const_eval/src/interpret/place.rs

+27-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ impl<'tcx, Tag: Provenance> MPlaceTy<'tcx, Tag> {
200200
}
201201
} else {
202202
// Go through the layout. There are lots of types that support a length,
203-
// e.g., SIMD types.
203+
// e.g., SIMD types. (But not all repr(simd) types even have FieldsShape::Array!)
204204
match self.layout.fields {
205205
FieldsShape::Array { count, .. } => Ok(count),
206206
_ => bug!("len not supported on sized type {:?}", self.layout.ty),
@@ -533,6 +533,22 @@ where
533533
})
534534
}
535535

536+
/// Converts a repr(simd) place into a place where `place_index` accesses the SIMD elements.
537+
/// Also returns the number of elements.
538+
pub fn mplace_to_simd(
539+
&self,
540+
base: &MPlaceTy<'tcx, M::PointerTag>,
541+
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, u64)> {
542+
// Basically we just transmute this place into an array following simd_size_and_type.
543+
// (Transmuting is okay since this is an in-memory place. We also double-check the size
544+
// stays the same.)
545+
let (len, e_ty) = base.layout.ty.simd_size_and_type(*self.tcx);
546+
let array = self.tcx.mk_array(e_ty, len);
547+
let layout = self.layout_of(array)?;
548+
assert_eq!(layout.size, base.layout.size);
549+
Ok((MPlaceTy { layout, ..*base }, len))
550+
}
551+
536552
/// Gets the place of a field inside the place, and also the field's type.
537553
/// Just a convenience function, but used quite a bit.
538554
/// This is the only projection that might have a side-effect: We cannot project
@@ -594,6 +610,16 @@ where
594610
})
595611
}
596612

613+
/// Converts a repr(simd) place into a place where `place_index` accesses the SIMD elements.
614+
/// Also returns the number of elements.
615+
pub fn place_to_simd(
616+
&mut self,
617+
base: &PlaceTy<'tcx, M::PointerTag>,
618+
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, u64)> {
619+
let mplace = self.force_allocation(base)?;
620+
self.mplace_to_simd(&mplace)
621+
}
622+
597623
/// Computes a place. You should only use this if you intend to write into this
598624
/// place; for reading, a more efficient alternative is `eval_place_for_read`.
599625
pub fn eval_place(

compiler/rustc_middle/src/ty/sty.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1805,17 +1805,22 @@ impl<'tcx> TyS<'tcx> {
18051805
pub fn simd_size_and_type(&self, tcx: TyCtxt<'tcx>) -> (u64, Ty<'tcx>) {
18061806
match self.kind() {
18071807
Adt(def, substs) => {
1808+
assert!(def.repr.simd(), "`simd_size_and_type` called on non-SIMD type");
18081809
let variant = def.non_enum_variant();
18091810
let f0_ty = variant.fields[0].ty(tcx, substs);
18101811

18111812
match f0_ty.kind() {
1813+
// If the first field is an array, we assume it is the only field and its
1814+
// elements are the SIMD components.
18121815
Array(f0_elem_ty, f0_len) => {
18131816
// FIXME(repr_simd): https://github.com/rust-lang/rust/pull/78863#discussion_r522784112
18141817
// The way we evaluate the `N` in `[T; N]` here only works since we use
18151818
// `simd_size_and_type` post-monomorphization. It will probably start to ICE
18161819
// if we use it in generic code. See the `simd-array-trait` ui test.
18171820
(f0_len.eval_usize(tcx, ParamEnv::empty()) as u64, f0_elem_ty)
18181821
}
1822+
// Otherwise, the fields of this Adt are the SIMD components (and we assume they
1823+
// all have the same type).
18191824
_ => (variant.fields.len() as u64, f0_ty),
18201825
}
18211826
}

src/test/ui/consts/const-eval/simd/insert_extract.rs

+16-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
#[repr(simd)] struct i8x1(i8);
99
#[repr(simd)] struct u16x2(u16, u16);
10-
#[repr(simd)] struct f32x4(f32, f32, f32, f32);
10+
// Make some of them array types to ensure those also work.
11+
#[repr(simd)] struct i8x1_arr([i8; 1]);
12+
#[repr(simd)] struct f32x4([f32; 4]);
1113

1214
extern "platform-intrinsic" {
1315
#[rustc_const_stable(feature = "foo", since = "1.3.37")]
@@ -25,6 +27,14 @@ fn main() {
2527
assert_eq!(X0, 42);
2628
assert_eq!(Y0, 42);
2729
}
30+
{
31+
const U: i8x1_arr = i8x1_arr([13]);
32+
const V: i8x1_arr = unsafe { simd_insert(U, 0_u32, 42_i8) };
33+
const X0: i8 = V.0[0];
34+
const Y0: i8 = unsafe { simd_extract(V, 0) };
35+
assert_eq!(X0, 42);
36+
assert_eq!(Y0, 42);
37+
}
2838
{
2939
const U: u16x2 = u16x2(13, 14);
3040
const V: u16x2 = unsafe { simd_insert(U, 1_u32, 42_u16) };
@@ -38,12 +48,12 @@ fn main() {
3848
assert_eq!(Y1, 42);
3949
}
4050
{
41-
const U: f32x4 = f32x4(13., 14., 15., 16.);
51+
const U: f32x4 = f32x4([13., 14., 15., 16.]);
4252
const V: f32x4 = unsafe { simd_insert(U, 1_u32, 42_f32) };
43-
const X0: f32 = V.0;
44-
const X1: f32 = V.1;
45-
const X2: f32 = V.2;
46-
const X3: f32 = V.3;
53+
const X0: f32 = V.0[0];
54+
const X1: f32 = V.0[1];
55+
const X2: f32 = V.0[2];
56+
const X3: f32 = V.0[3];
4757
const Y0: f32 = unsafe { simd_extract(V, 0) };
4858
const Y1: f32 = unsafe { simd_extract(V, 1) };
4959
const Y2: f32 = unsafe { simd_extract(V, 2) };

0 commit comments

Comments
 (0)