Skip to content

rustup; ptr atomics #2341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
049308cf8b48e9d67e54d6d0b01c10c79d1efc3a
7665c3543079ebc3710b676d0fd6951bedfd4b29
4 changes: 1 addition & 3 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx

/// Test if this pointer equals 0.
fn ptr_is_null(&self, ptr: Pointer<Option<Tag>>) -> InterpResult<'tcx, bool> {
let this = self.eval_context_ref();
let null = Scalar::null_ptr(this);
this.ptr_eq(Scalar::from_maybe_pointer(ptr, this), null)
Ok(ptr.addr().bytes() == 0)
}

/// Get the `Place` for a local
Expand Down
66 changes: 32 additions & 34 deletions src/operator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use log::trace;

use rustc_middle::{mir, ty::Ty};
use rustc_target::abi::Size;

use crate::*;

Expand All @@ -11,8 +12,6 @@ pub trait EvalContextExt<'tcx> {
left: &ImmTy<'tcx, Tag>,
right: &ImmTy<'tcx, Tag>,
) -> InterpResult<'tcx, (Scalar<Tag>, bool, Ty<'tcx>)>;

fn ptr_eq(&self, left: Scalar<Tag>, right: Scalar<Tag>) -> InterpResult<'tcx, bool>;
}

impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriEvalContext<'mir, 'tcx> {
Expand All @@ -27,23 +26,8 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriEvalContext<'mir, 'tcx> {
trace!("ptr_op: {:?} {:?} {:?}", *left, bin_op, *right);

Ok(match bin_op {
Eq | Ne => {
// This supports fat pointers.
#[rustfmt::skip]
let eq = match (**left, **right) {
(Immediate::Scalar(left), Immediate::Scalar(right)) => {
self.ptr_eq(left.check_init()?, right.check_init()?)?
}
(Immediate::ScalarPair(left1, left2), Immediate::ScalarPair(right1, right2)) => {
self.ptr_eq(left1.check_init()?, right1.check_init()?)?
&& self.ptr_eq(left2.check_init()?, right2.check_init()?)?
}
_ => bug!("Type system should not allow comparing Scalar with ScalarPair"),
};
(Scalar::from_bool(if bin_op == Eq { eq } else { !eq }), false, self.tcx.types.bool)
}

Lt | Le | Gt | Ge => {
Eq | Ne | Lt | Le | Gt | Ge => {
assert_eq!(left.layout.abi, right.layout.abi); // types an differ, e.g. fn ptrs with different `for`
let size = self.pointer_size();
// Just compare the bits. ScalarPairs are compared lexicographically.
// We thus always compare pairs and simply fill scalars up with 0.
Expand All @@ -58,35 +42,49 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriEvalContext<'mir, 'tcx> {
(r1.check_init()?.to_bits(size)?, r2.check_init()?.to_bits(size)?),
};
let res = match bin_op {
Eq => left == right,
Ne => left != right,
Lt => left < right,
Le => left <= right,
Gt => left > right,
Ge => left >= right,
_ => bug!("We already established it has to be one of these operators."),
_ => bug!(),
};
(Scalar::from_bool(res), false, self.tcx.types.bool)
}

Offset => {
assert!(left.layout.ty.is_unsafe_ptr());
let ptr = self.scalar_to_ptr(left.to_scalar()?)?;
let offset = right.to_scalar()?.to_machine_isize(self)?;

let pointee_ty =
left.layout.ty.builtin_deref(true).expect("Offset called on non-ptr type").ty;
let ptr = self.ptr_offset_inbounds(
self.scalar_to_ptr(left.to_scalar()?)?,
pointee_ty,
right.to_scalar()?.to_machine_isize(self)?,
)?;
let ptr = self.ptr_offset_inbounds(ptr, pointee_ty, offset)?;
(Scalar::from_maybe_pointer(ptr, self), false, left.layout.ty)
}

_ => bug!("Invalid operator on pointers: {:?}", bin_op),
})
}
// Some more operations are possible with atomics.
// The return value always has the provenance of the *left* operand.
Add | Sub | BitOr | BitAnd | BitXor => {
assert!(left.layout.ty.is_unsafe_ptr());
assert!(right.layout.ty.is_unsafe_ptr());
let ptr = self.scalar_to_ptr(left.to_scalar()?)?;
// We do the actual operation with usize-typed scalars.
let left = ImmTy::from_uint(ptr.addr().bytes(), self.machine.layouts.usize);
let right = ImmTy::from_uint(
right.to_scalar()?.to_machine_usize(self)?,
self.machine.layouts.usize,
);
let (result, overflowing, _ty) =
self.overflowing_binary_op(bin_op, &left, &right)?;
// Construct a new pointer with the provenance of `ptr` (the LHS).
let result_ptr =
Pointer::new(ptr.provenance, Size::from_bytes(result.to_machine_usize(self)?));
(Scalar::from_maybe_pointer(result_ptr, self), overflowing, left.layout.ty)
}

fn ptr_eq(&self, left: Scalar<Tag>, right: Scalar<Tag>) -> InterpResult<'tcx, bool> {
let size = self.pointer_size();
// Just compare the integers.
let left = left.to_bits(size)?;
let right = right.to_bits(size)?;
Ok(left == right)
_ => span_bug!(self.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
})
}
}
31 changes: 19 additions & 12 deletions src/shims/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
ty::Float(FloatTy::F64) =>
this.float_to_int_unchecked(val.to_scalar()?.to_f64()?, dest.layout.ty)?,
_ =>
bug!(
span_bug!(
this.cur_span(),
"`float_to_int_unchecked` called with non-float input type {:?}",
val.layout.ty
),
Expand Down Expand Up @@ -371,7 +372,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
Op::Abs => {
// Works for f32 and f64.
let ty::Float(float_ty) = op.layout.ty.kind() else {
bug!("{} operand is not a float", intrinsic_name)
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
};
let op = op.to_scalar()?;
match float_ty {
Expand All @@ -381,7 +382,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
}
Op::HostOp(host_op) => {
let ty::Float(float_ty) = op.layout.ty.kind() else {
bug!("{} operand is not a float", intrinsic_name)
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
};
// FIXME using host floats
match float_ty {
Expand Down Expand Up @@ -546,7 +547,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx

// Works for f32 and f64.
let ty::Float(float_ty) = dest.layout.ty.kind() else {
bug!("{} operand is not a float", intrinsic_name)
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
};
let val = match float_ty {
FloatTy::F32 =>
Expand Down Expand Up @@ -763,7 +764,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx

// `index` is an array, not a SIMD type
let ty::Array(_, index_len) = index.layout.ty.kind() else {
bug!("simd_shuffle index argument has non-array type {}", index.layout.ty)
span_bug!(this.cur_span(), "simd_shuffle index argument has non-array type {}", index.layout.ty)
};
let index_len = index_len.eval_usize(*this.tcx, this.param_env());

Expand All @@ -785,10 +786,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
&this.mplace_index(&right, src_index - left_len)?.into(),
)?
} else {
bug!(
"simd_shuffle index {} is out of bounds for 2 vectors of size {}",
src_index,
left_len
span_bug!(
this.cur_span(),
"simd_shuffle index {src_index} is out of bounds for 2 vectors of size {left_len}",
);
};
this.write_immediate(*val, &dest.into())?;
Expand Down Expand Up @@ -1187,8 +1187,11 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
let [place, rhs] = check_arg_count(args)?;
let place = this.deref_operand(place)?;

if !place.layout.ty.is_integral() {
bug!("Atomic arithmetic operations only work on integer types");
if !place.layout.ty.is_integral() && !place.layout.ty.is_unsafe_ptr() {
span_bug!(
this.cur_span(),
"atomic arithmetic operations only work on integer and raw pointer types",
);
}
let rhs = this.read_immediate(rhs)?;

Expand Down Expand Up @@ -1355,7 +1358,11 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
}
}
// Nothing else
_ => bug!("`float_to_int_unchecked` called with non-int output type {dest_ty:?}"),
_ =>
span_bug!(
this.cur_span(),
"`float_to_int_unchecked` called with non-int output type {dest_ty:?}"
),
})
}
}
Expand Down
55 changes: 54 additions & 1 deletion tests/pass/atomic.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use std::sync::atomic::{compiler_fence, fence, AtomicBool, AtomicIsize, AtomicU64, Ordering::*};
// compile-flags: -Zmiri-strict-provenance
#![feature(strict_provenance, strict_provenance_atomic_ptr)]
use std::sync::atomic::{
compiler_fence, fence, AtomicBool, AtomicIsize, AtomicPtr, AtomicU64, Ordering::*,
};

fn main() {
atomic_bool();
atomic_all_ops();
atomic_u64();
atomic_fences();
atomic_ptr();
weak_sometimes_fails();
}

Expand Down Expand Up @@ -130,6 +135,54 @@ fn atomic_fences() {
compiler_fence(AcqRel);
}

fn atomic_ptr() {
use std::ptr;
let array: Vec<i32> = (0..100).into_iter().collect(); // a target to point to, to test provenance things
let x = array.as_ptr() as *mut i32;

let ptr = AtomicPtr::<i32>::new(ptr::null_mut());
assert!(ptr.load(Relaxed).addr() == 0);
ptr.store(ptr::invalid_mut(13), SeqCst);
assert!(ptr.swap(x, Relaxed).addr() == 13);
unsafe { assert!(*ptr.load(Acquire) == 0) };

// comparison ignores provenance
assert_eq!(
ptr.compare_exchange(
(&mut 0 as *mut i32).with_addr(x.addr()),
ptr::invalid_mut(0),
SeqCst,
SeqCst
)
.unwrap()
.addr(),
x.addr(),
);
assert_eq!(
ptr.compare_exchange(
(&mut 0 as *mut i32).with_addr(x.addr()),
ptr::invalid_mut(0),
SeqCst,
SeqCst
)
.unwrap_err()
.addr(),
0,
);
ptr.store(x, Relaxed);

assert_eq!(ptr.fetch_ptr_add(13, AcqRel).addr(), x.addr());
unsafe { assert_eq!(*ptr.load(SeqCst), 13) }; // points to index 13 now
assert_eq!(ptr.fetch_ptr_sub(4, AcqRel).addr(), x.addr() + 13 * 4);
unsafe { assert_eq!(*ptr.load(SeqCst), 9) };
assert_eq!(ptr.fetch_or(3, AcqRel).addr(), x.addr() + 9 * 4); // ptr is 4-aligned, so set the last 2 bits
assert_eq!(ptr.fetch_and(!3, AcqRel).addr(), (x.addr() + 9 * 4) | 3); // and unset them again
unsafe { assert_eq!(*ptr.load(SeqCst), 9) };
assert_eq!(ptr.fetch_xor(0xdeadbeef, AcqRel).addr(), x.addr() + 9 * 4);
assert_eq!(ptr.fetch_xor(0xdeadbeef, AcqRel).addr(), (x.addr() + 9 * 4) ^ 0xdeadbeef);
unsafe { assert_eq!(*ptr.load(SeqCst), 9) }; // after XORing twice with the same thing, we get our ptr back
}

fn weak_sometimes_fails() {
let atomic = AtomicBool::new(false);
let tries = 100;
Expand Down