Skip to content

Commit e7c7aa7

Browse files
committed
Auto merge of #98332 - oli-obk:assume, r=wesleywiser
Lower the assume intrinsic to a MIR statement This makes #96862 (comment) easier and will generally allow us to cheaply insert assume intrinsic calls in mir building. r? rust-lang/wg-mir-opt
2 parents 0568b0a + a0130e6 commit e7c7aa7

37 files changed

+338
-570
lines changed

compiler/rustc_borrowck/src/dataflow.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ impl<'tcx> rustc_mir_dataflow::GenKillAnalysis<'tcx> for Borrows<'_, 'tcx> {
391391
| mir::StatementKind::Retag { .. }
392392
| mir::StatementKind::AscribeUserType(..)
393393
| mir::StatementKind::Coverage(..)
394-
| mir::StatementKind::CopyNonOverlapping(..)
394+
| mir::StatementKind::Intrinsic(..)
395395
| mir::StatementKind::Nop => {}
396396
}
397397
}

compiler/rustc_borrowck/src/invalidation.rs

+15-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use rustc_data_structures::graph::dominators::Dominators;
22
use rustc_middle::mir::visit::Visitor;
3-
use rustc_middle::mir::{BasicBlock, Body, Location, Place, Rvalue};
3+
use rustc_middle::mir::{self, BasicBlock, Body, Location, NonDivergingIntrinsic, Place, Rvalue};
44
use rustc_middle::mir::{BorrowKind, Mutability, Operand};
55
use rustc_middle::mir::{InlineAsmOperand, Terminator, TerminatorKind};
66
use rustc_middle::mir::{Statement, StatementKind};
@@ -63,23 +63,24 @@ impl<'cx, 'tcx> Visitor<'tcx> for InvalidationGenerator<'cx, 'tcx> {
6363
StatementKind::FakeRead(box (_, _)) => {
6464
// Only relevant for initialized/liveness/safety checks.
6565
}
66-
StatementKind::CopyNonOverlapping(box rustc_middle::mir::CopyNonOverlapping {
66+
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(op)) => {
67+
self.consume_operand(location, op);
68+
}
69+
StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(mir::CopyNonOverlapping {
6770
ref src,
6871
ref dst,
6972
ref count,
70-
}) => {
73+
})) => {
7174
self.consume_operand(location, src);
7275
self.consume_operand(location, dst);
7376
self.consume_operand(location, count);
7477
}
75-
StatementKind::Nop
78+
// Only relevant for mir typeck
79+
StatementKind::AscribeUserType(..)
80+
// Doesn't have any language semantics
7681
| StatementKind::Coverage(..)
77-
| StatementKind::AscribeUserType(..)
78-
| StatementKind::Retag { .. }
79-
| StatementKind::StorageLive(..) => {
80-
// `Nop`, `AscribeUserType`, `Retag`, and `StorageLive` are irrelevant
81-
// to borrow check.
82-
}
82+
// Does not actually affect borrowck
83+
| StatementKind::StorageLive(..) => {}
8384
StatementKind::StorageDead(local) => {
8485
self.access_place(
8586
location,
@@ -88,7 +89,10 @@ impl<'cx, 'tcx> Visitor<'tcx> for InvalidationGenerator<'cx, 'tcx> {
8889
LocalMutationIsAllowed::Yes,
8990
);
9091
}
91-
StatementKind::Deinit(..) | StatementKind::SetDiscriminant { .. } => {
92+
StatementKind::Nop
93+
| StatementKind::Retag { .. }
94+
| StatementKind::Deinit(..)
95+
| StatementKind::SetDiscriminant { .. } => {
9296
bug!("Statement not allowed in this MIR phase")
9397
}
9498
}

compiler/rustc_borrowck/src/lib.rs

+14-14
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ use rustc_index::bit_set::ChunkedBitSet;
2626
use rustc_index::vec::IndexVec;
2727
use rustc_infer::infer::{DefiningAnchor, InferCtxt, TyCtxtInferExt};
2828
use rustc_middle::mir::{
29-
traversal, Body, ClearCrossCrate, Local, Location, Mutability, Operand, Place, PlaceElem,
30-
PlaceRef, VarDebugInfoContents,
29+
traversal, Body, ClearCrossCrate, Local, Location, Mutability, NonDivergingIntrinsic, Operand,
30+
Place, PlaceElem, PlaceRef, VarDebugInfoContents,
3131
};
3232
use rustc_middle::mir::{AggregateKind, BasicBlock, BorrowCheckResult, BorrowKind};
3333
use rustc_middle::mir::{Field, ProjectionElem, Promoted, Rvalue, Statement, StatementKind};
@@ -591,22 +591,19 @@ impl<'cx, 'tcx> rustc_mir_dataflow::ResultsVisitor<'cx, 'tcx> for MirBorrowckCtx
591591
flow_state,
592592
);
593593
}
594-
StatementKind::CopyNonOverlapping(box rustc_middle::mir::CopyNonOverlapping {
595-
..
596-
}) => {
597-
span_bug!(
594+
StatementKind::Intrinsic(box ref kind) => match kind {
595+
NonDivergingIntrinsic::Assume(op) => self.consume_operand(location, (op, span), flow_state),
596+
NonDivergingIntrinsic::CopyNonOverlapping(..) => span_bug!(
598597
span,
599598
"Unexpected CopyNonOverlapping, should only appear after lower_intrinsics",
600599
)
601600
}
602-
StatementKind::Nop
601+
// Only relevant for mir typeck
602+
StatementKind::AscribeUserType(..)
603+
// Doesn't have any language semantics
603604
| StatementKind::Coverage(..)
604-
| StatementKind::AscribeUserType(..)
605-
| StatementKind::Retag { .. }
606-
| StatementKind::StorageLive(..) => {
607-
// `Nop`, `AscribeUserType`, `Retag`, and `StorageLive` are irrelevant
608-
// to borrow check.
609-
}
605+
// Does not actually affect borrowck
606+
| StatementKind::StorageLive(..) => {}
610607
StatementKind::StorageDead(local) => {
611608
self.access_place(
612609
location,
@@ -616,7 +613,10 @@ impl<'cx, 'tcx> rustc_mir_dataflow::ResultsVisitor<'cx, 'tcx> for MirBorrowckCtx
616613
flow_state,
617614
);
618615
}
619-
StatementKind::Deinit(..) | StatementKind::SetDiscriminant { .. } => {
616+
StatementKind::Nop
617+
| StatementKind::Retag { .. }
618+
| StatementKind::Deinit(..)
619+
| StatementKind::SetDiscriminant { .. } => {
620620
bug!("Statement not allowed in this MIR phase")
621621
}
622622
}

compiler/rustc_borrowck/src/type_check/mod.rs

+7-6
Original file line numberDiff line numberDiff line change
@@ -1302,12 +1302,13 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
13021302
);
13031303
}
13041304
}
1305-
StatementKind::CopyNonOverlapping(box rustc_middle::mir::CopyNonOverlapping {
1306-
..
1307-
}) => span_bug!(
1308-
stmt.source_info.span,
1309-
"Unexpected StatementKind::CopyNonOverlapping, should only appear after lowering_intrinsics",
1310-
),
1305+
StatementKind::Intrinsic(box ref kind) => match kind {
1306+
NonDivergingIntrinsic::Assume(op) => self.check_operand(op, location),
1307+
NonDivergingIntrinsic::CopyNonOverlapping(..) => span_bug!(
1308+
stmt.source_info.span,
1309+
"Unexpected NonDivergingIntrinsic::CopyNonOverlapping, should only appear after lowering_intrinsics",
1310+
),
1311+
},
13111312
StatementKind::FakeRead(..)
13121313
| StatementKind::StorageLive(..)
13131314
| StatementKind::StorageDead(..)

compiler/rustc_codegen_cranelift/src/base.rs

+25-14
Original file line numberDiff line numberDiff line change
@@ -794,20 +794,31 @@ fn codegen_stmt<'tcx>(
794794
| StatementKind::AscribeUserType(..) => {}
795795

796796
StatementKind::Coverage { .. } => fx.tcx.sess.fatal("-Zcoverage is unimplemented"),
797-
StatementKind::CopyNonOverlapping(inner) => {
798-
let dst = codegen_operand(fx, &inner.dst);
799-
let pointee = dst
800-
.layout()
801-
.pointee_info_at(fx, rustc_target::abi::Size::ZERO)
802-
.expect("Expected pointer");
803-
let dst = dst.load_scalar(fx);
804-
let src = codegen_operand(fx, &inner.src).load_scalar(fx);
805-
let count = codegen_operand(fx, &inner.count).load_scalar(fx);
806-
let elem_size: u64 = pointee.size.bytes();
807-
let bytes =
808-
if elem_size != 1 { fx.bcx.ins().imul_imm(count, elem_size as i64) } else { count };
809-
fx.bcx.call_memcpy(fx.target_config, dst, src, bytes);
810-
}
797+
StatementKind::Intrinsic(ref intrinsic) => match &**intrinsic {
798+
// We ignore `assume` intrinsics, they are only useful for optimizations
799+
NonDivergingIntrinsic::Assume(_) => {}
800+
NonDivergingIntrinsic::CopyNonOverlapping(mir::CopyNonOverlapping {
801+
src,
802+
dst,
803+
count,
804+
}) => {
805+
let dst = codegen_operand(fx, dst);
806+
let pointee = dst
807+
.layout()
808+
.pointee_info_at(fx, rustc_target::abi::Size::ZERO)
809+
.expect("Expected pointer");
810+
let dst = dst.load_scalar(fx);
811+
let src = codegen_operand(fx, src).load_scalar(fx);
812+
let count = codegen_operand(fx, count).load_scalar(fx);
813+
let elem_size: u64 = pointee.size.bytes();
814+
let bytes = if elem_size != 1 {
815+
fx.bcx.ins().imul_imm(count, elem_size as i64)
816+
} else {
817+
count
818+
};
819+
fx.bcx.call_memcpy(fx.target_config, dst, src, bytes);
820+
}
821+
},
811822
}
812823
}
813824

compiler/rustc_codegen_cranelift/src/constant.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -536,9 +536,11 @@ pub(crate) fn mir_operand_get_const_val<'tcx>(
536536
{
537537
return None;
538538
}
539-
StatementKind::CopyNonOverlapping(_) => {
540-
return None;
541-
} // conservative handling
539+
StatementKind::Intrinsic(ref intrinsic) => match **intrinsic {
540+
NonDivergingIntrinsic::CopyNonOverlapping(..) => return None,
541+
NonDivergingIntrinsic::Assume(..) => {}
542+
},
543+
// conservative handling
542544
StatementKind::Assign(_)
543545
| StatementKind::FakeRead(_)
544546
| StatementKind::SetDiscriminant { .. }

compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs

-3
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,6 @@ fn codegen_regular_intrinsic_call<'tcx>(
357357
let usize_layout = fx.layout_of(fx.tcx.types.usize);
358358

359359
match intrinsic {
360-
sym::assume => {
361-
intrinsic_args!(fx, args => (_a); intrinsic);
362-
}
363360
sym::likely | sym::unlikely => {
364361
intrinsic_args!(fx, args => (a); intrinsic);
365362

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

-4
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,6 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
7777
let result = PlaceRef::new_sized(llresult, fn_abi.ret.layout);
7878

7979
let llval = match name {
80-
sym::assume => {
81-
bx.assume(args[0].immediate());
82-
return;
83-
}
8480
sym::abort => {
8581
bx.abort();
8682
return;

compiler/rustc_codegen_ssa/src/mir/statement.rs

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use rustc_middle::mir;
2+
use rustc_middle::mir::NonDivergingIntrinsic;
23

34
use super::FunctionCx;
45
use super::LocalRef;
@@ -73,11 +74,14 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
7374
self.codegen_coverage(&mut bx, coverage.clone(), statement.source_info.scope);
7475
bx
7576
}
76-
mir::StatementKind::CopyNonOverlapping(box mir::CopyNonOverlapping {
77-
ref src,
78-
ref dst,
79-
ref count,
80-
}) => {
77+
mir::StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(ref op)) => {
78+
let op_val = self.codegen_operand(&mut bx, op);
79+
bx.assume(op_val.immediate());
80+
bx
81+
}
82+
mir::StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(
83+
mir::CopyNonOverlapping { ref count, ref src, ref dst },
84+
)) => {
8185
let dst_val = self.codegen_operand(&mut bx, dst);
8286
let src_val = self.codegen_operand(&mut bx, src);
8387
let count = self.codegen_operand(&mut bx, count).immediate();

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+27-7
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rustc_hir::def_id::DefId;
88
use rustc_middle::mir::{
99
self,
1010
interpret::{ConstValue, GlobalId, InterpResult, PointerArithmetic, Scalar},
11-
BinOp,
11+
BinOp, NonDivergingIntrinsic,
1212
};
1313
use rustc_middle::ty;
1414
use rustc_middle::ty::layout::LayoutOf as _;
@@ -506,12 +506,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
506506
// These just return their argument
507507
self.copy_op(&args[0], dest, /*allow_transmute*/ false)?;
508508
}
509-
sym::assume => {
510-
let cond = self.read_scalar(&args[0])?.to_bool()?;
511-
if !cond {
512-
throw_ub_format!("`assume` intrinsic called with `false`");
513-
}
514-
}
515509
sym::raw_eq => {
516510
let result = self.raw_eq_intrinsic(&args[0], &args[1])?;
517511
self.write_scalar(result, dest)?;
@@ -536,6 +530,32 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
536530
Ok(true)
537531
}
538532

533+
pub(super) fn emulate_nondiverging_intrinsic(
534+
&mut self,
535+
intrinsic: &NonDivergingIntrinsic<'tcx>,
536+
) -> InterpResult<'tcx> {
537+
match intrinsic {
538+
NonDivergingIntrinsic::Assume(op) => {
539+
let op = self.eval_operand(op, None)?;
540+
let cond = self.read_scalar(&op)?.to_bool()?;
541+
if !cond {
542+
throw_ub_format!("`assume` called with `false`");
543+
}
544+
Ok(())
545+
}
546+
NonDivergingIntrinsic::CopyNonOverlapping(mir::CopyNonOverlapping {
547+
count,
548+
src,
549+
dst,
550+
}) => {
551+
let src = self.eval_operand(src, None)?;
552+
let dst = self.eval_operand(dst, None)?;
553+
let count = self.eval_operand(count, None)?;
554+
self.copy_intrinsic(&src, &dst, &count, /* nonoverlapping */ true)
555+
}
556+
}
557+
}
558+
539559
pub fn exact_div(
540560
&mut self,
541561
a: &ImmTy<'tcx, M::Provenance>,

compiler/rustc_const_eval/src/interpret/step.rs

+1-7
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
114114
M::retag(self, *kind, &dest)?;
115115
}
116116

117-
// Call CopyNonOverlapping
118-
CopyNonOverlapping(box rustc_middle::mir::CopyNonOverlapping { src, dst, count }) => {
119-
let src = self.eval_operand(src, None)?;
120-
let dst = self.eval_operand(dst, None)?;
121-
let count = self.eval_operand(count, None)?;
122-
self.copy_intrinsic(&src, &dst, &count, /* nonoverlapping */ true)?;
123-
}
117+
Intrinsic(box ref intrinsic) => self.emulate_nondiverging_intrinsic(intrinsic)?,
124118

125119
// Statements we do not track.
126120
AscribeUserType(..) => {}

compiler/rustc_const_eval/src/transform/check_consts/check.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
678678
| StatementKind::Retag { .. }
679679
| StatementKind::AscribeUserType(..)
680680
| StatementKind::Coverage(..)
681-
| StatementKind::CopyNonOverlapping(..)
681+
| StatementKind::Intrinsic(..)
682682
| StatementKind::Nop => {}
683683
}
684684
}

compiler/rustc_const_eval/src/transform/validate.rs

+16-8
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ use rustc_middle::mir::interpret::Scalar;
77
use rustc_middle::mir::visit::NonUseContext::VarDebugInfo;
88
use rustc_middle::mir::visit::{PlaceContext, Visitor};
99
use rustc_middle::mir::{
10-
traversal, AggregateKind, BasicBlock, BinOp, Body, BorrowKind, CastKind, Local, Location,
11-
MirPass, MirPhase, Operand, Place, PlaceElem, PlaceRef, ProjectionElem, RuntimePhase, Rvalue,
12-
SourceScope, Statement, StatementKind, Terminator, TerminatorKind, UnOp, START_BLOCK,
10+
traversal, AggregateKind, BasicBlock, BinOp, Body, BorrowKind, CastKind, CopyNonOverlapping,
11+
Local, Location, MirPass, MirPhase, NonDivergingIntrinsic, Operand, Place, PlaceElem, PlaceRef,
12+
ProjectionElem, RuntimePhase, Rvalue, SourceScope, Statement, StatementKind, Terminator,
13+
TerminatorKind, UnOp, START_BLOCK,
1314
};
1415
use rustc_middle::ty::fold::BottomUpFolder;
1516
use rustc_middle::ty::subst::Subst;
@@ -636,11 +637,18 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
636637
);
637638
}
638639
}
639-
StatementKind::CopyNonOverlapping(box rustc_middle::mir::CopyNonOverlapping {
640-
ref src,
641-
ref dst,
642-
ref count,
643-
}) => {
640+
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(op)) => {
641+
let ty = op.ty(&self.body.local_decls, self.tcx);
642+
if !ty.is_bool() {
643+
self.fail(
644+
location,
645+
format!("`assume` argument must be `bool`, but got: `{}`", ty),
646+
);
647+
}
648+
}
649+
StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(
650+
CopyNonOverlapping { src, dst, count },
651+
)) => {
644652
let src_ty = src.ty(&self.body.local_decls, self.tcx);
645653
let op_src_ty = if let Some(src_deref) = src_ty.builtin_deref(true) {
646654
src_deref.ty

compiler/rustc_middle/src/mir/mod.rs

+1-7
Original file line numberDiff line numberDiff line change
@@ -1362,13 +1362,7 @@ impl Debug for Statement<'_> {
13621362
write!(fmt, "Coverage::{:?} for {:?}", kind, rgn)
13631363
}
13641364
Coverage(box ref coverage) => write!(fmt, "Coverage::{:?}", coverage.kind),
1365-
CopyNonOverlapping(box crate::mir::CopyNonOverlapping {
1366-
ref src,
1367-
ref dst,
1368-
ref count,
1369-
}) => {
1370-
write!(fmt, "copy_nonoverlapping(src={:?}, dst={:?}, count={:?})", src, dst, count)
1371-
}
1365+
Intrinsic(box ref intrinsic) => write!(fmt, "{intrinsic}"),
13721366
Nop => write!(fmt, "nop"),
13731367
}
13741368
}

compiler/rustc_middle/src/mir/spanview.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ pub fn statement_kind_name(statement: &Statement<'_>) -> &'static str {
249249
Retag(..) => "Retag",
250250
AscribeUserType(..) => "AscribeUserType",
251251
Coverage(..) => "Coverage",
252-
CopyNonOverlapping(..) => "CopyNonOverlapping",
252+
Intrinsic(..) => "Intrinsic",
253253
Nop => "Nop",
254254
}
255255
}

0 commit comments

Comments
 (0)