Skip to content

Commit 814d252

Browse files
committed
Auto merge of #75562 - oli-obk:const_prop_no_aggregates, r=wesleywiser
Check that we don't use `Rvalue::Aggregate` after the deaggregator fixes #75481 r? @wesleywiser cc @RalfJung (modified the validator)
2 parents 5fff382 + dcc2027 commit 814d252

20 files changed

+274
-83
lines changed

src/librustc_middle/mir/mod.rs

+23-3
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,35 @@ impl<'tcx> HasLocalDecls<'tcx> for Body<'tcx> {
7373

7474
/// The various "big phases" that MIR goes through.
7575
///
76+
/// These phases all describe dialects of MIR. Since all MIR uses the same datastructures, the
77+
/// dialects forbid certain variants or values in certain phases.
78+
///
79+
/// Note: Each phase's validation checks all invariants of the *previous* phases' dialects. A phase
80+
/// that changes the dialect documents what invariants must be upheld *after* that phase finishes.
81+
///
7682
/// Warning: ordering of variants is significant.
7783
#[derive(Copy, Clone, TyEncodable, TyDecodable, Debug, PartialEq, Eq, PartialOrd, Ord)]
7884
#[derive(HashStable)]
7985
pub enum MirPhase {
8086
Build = 0,
87+
// FIXME(oli-obk): it's unclear whether we still need this phase (and its corresponding query).
88+
// We used to have this for pre-miri MIR based const eval.
8189
Const = 1,
82-
Validated = 2,
83-
DropElab = 3,
84-
Optimized = 4,
90+
/// This phase checks the MIR for promotable elements and takes them out of the main MIR body
91+
/// by creating a new MIR body per promoted element. After this phase (and thus the termination
92+
/// of the `mir_promoted` query), these promoted elements are available in the `promoted_mir`
93+
/// query.
94+
ConstPromotion = 2,
95+
/// After this phase
96+
/// * the only `AggregateKind`s allowed are `Array` and `Generator`,
97+
/// * `DropAndReplace` is gone for good
98+
/// * `Drop` now uses explicit drop flags visible in the MIR and reaching a `Drop` terminator
99+
/// means that the auto-generated drop glue will be invoked.
100+
DropLowering = 3,
101+
/// After this phase, generators are explicit state machines (no more `Yield`).
102+
/// `AggregateKind::Generator` is gone for good.
103+
GeneratorLowering = 4,
104+
Optimization = 5,
85105
}
86106

87107
impl MirPhase {

src/librustc_middle/query/mod.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ rustc_queries! {
247247
desc { |tcx| "elaborating drops for `{}`", tcx.def_path_str(key.did.to_def_id()) }
248248
}
249249

250-
query mir_validated(key: ty::WithOptConstParam<LocalDefId>) ->
250+
query mir_promoted(key: ty::WithOptConstParam<LocalDefId>) ->
251251
(
252252
&'tcx Steal<mir::Body<'tcx>>,
253253
&'tcx Steal<IndexVec<mir::Promoted, mir::Body<'tcx>>>
@@ -281,6 +281,11 @@ rustc_queries! {
281281
cache_on_disk_if { key.is_local() }
282282
}
283283

284+
/// The `DefId` is the `DefId` of the containing MIR body. Promoteds do not have their own
285+
/// `DefId`. This function returns all promoteds in the specified body. The body references
286+
/// promoteds by the `DefId` and the `mir::Promoted` index. This is necessary, because
287+
/// after inlining a body may refer to promoteds from other bodies. In that case you still
288+
/// need to use the `DefId` of the original body.
284289
query promoted_mir(key: DefId) -> &'tcx IndexVec<mir::Promoted, mir::Body<'tcx>> {
285290
desc { |tcx| "optimizing promoted MIR for `{}`", tcx.def_path_str(key) }
286291
cache_on_disk_if { key.is_local() }

src/librustc_middle/ty/query/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ rustc_query_append! { [define_queries!][<'tcx>] }
133133
/// `DefPathHash` in the current codebase to the corresponding `DefId`, we have
134134
/// everything we need to re-run the query.
135135
///
136-
/// Take the `mir_validated` query as an example. Like many other queries, it
136+
/// Take the `mir_promoted` query as an example. Like many other queries, it
137137
/// just has a single parameter: the `DefId` of the item it will compute the
138138
/// validated MIR for. Now, when we call `force_from_dep_node()` on a `DepNode`
139139
/// with kind `MirValidated`, we know that the GUID/fingerprint of the `DepNode`

src/librustc_mir/borrow_check/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ fn mir_borrowck<'tcx>(
106106
tcx: TyCtxt<'tcx>,
107107
def: ty::WithOptConstParam<LocalDefId>,
108108
) -> &'tcx BorrowCheckResult<'tcx> {
109-
let (input_body, promoted) = tcx.mir_validated(def);
109+
let (input_body, promoted) = tcx.mir_promoted(def);
110110
debug!("run query mir_borrowck: {}", tcx.def_path_str(def.did.to_def_id()));
111111

112112
let opt_closure_req = tcx.infer_ctxt().enter(|infcx| {

src/librustc_mir/interpret/intern.rs

+24-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ use super::validity::RefTracking;
77
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
88
use rustc_hir as hir;
99
use rustc_middle::mir::interpret::InterpResult;
10-
use rustc_middle::ty::{self, query::TyCtxtAt, Ty};
10+
use rustc_middle::ty::{self, layout::TyAndLayout, query::TyCtxtAt, Ty};
11+
use rustc_target::abi::Size;
1112

1213
use rustc_ast::Mutability;
1314

@@ -430,3 +431,25 @@ pub fn intern_const_alloc_recursive<M: CompileTimeMachine<'mir, 'tcx>>(
430431
}
431432
}
432433
}
434+
435+
impl<'mir, 'tcx: 'mir, M: super::intern::CompileTimeMachine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
436+
/// A helper function that allocates memory for the layout given and gives you access to mutate
437+
/// it. Once your own mutation code is done, the backing `Allocation` is removed from the
438+
/// current `Memory` and returned.
439+
pub(crate) fn intern_with_temp_alloc(
440+
&mut self,
441+
layout: TyAndLayout<'tcx>,
442+
f: impl FnOnce(
443+
&mut InterpCx<'mir, 'tcx, M>,
444+
MPlaceTy<'tcx, M::PointerTag>,
445+
) -> InterpResult<'tcx, ()>,
446+
) -> InterpResult<'tcx, &'tcx Allocation> {
447+
let dest = self.allocate(layout, MemoryKind::Stack);
448+
f(self, dest)?;
449+
let ptr = dest.ptr.assert_ptr();
450+
assert_eq!(ptr.offset, Size::ZERO);
451+
let mut alloc = self.memory.alloc_map.remove(&ptr.alloc_id).unwrap().1;
452+
alloc.mutability = Mutability::Not;
453+
Ok(self.tcx.intern_const_alloc(alloc))
454+
}
455+
}

src/librustc_mir/transform/const_prop.rs

+37-24
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ use rustc_middle::mir::visit::{
1414
MutVisitor, MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor,
1515
};
1616
use rustc_middle::mir::{
17-
AggregateKind, AssertKind, BasicBlock, BinOp, Body, ClearCrossCrate, Constant, Local,
18-
LocalDecl, LocalKind, Location, Operand, Place, Rvalue, SourceInfo, SourceScope,
19-
SourceScopeData, Statement, StatementKind, Terminator, TerminatorKind, UnOp, RETURN_PLACE,
17+
AssertKind, BasicBlock, BinOp, Body, ClearCrossCrate, Constant, Local, LocalDecl, LocalKind,
18+
Location, Operand, Place, Rvalue, SourceInfo, SourceScope, SourceScopeData, Statement,
19+
StatementKind, Terminator, TerminatorKind, UnOp, RETURN_PLACE,
2020
};
2121
use rustc_middle::ty::layout::{HasTyCtxt, LayoutError, TyAndLayout};
2222
use rustc_middle::ty::subst::{InternalSubsts, Subst};
@@ -28,9 +28,9 @@ use rustc_trait_selection::traits;
2828

2929
use crate::const_eval::ConstEvalErr;
3030
use crate::interpret::{
31-
self, compile_time_machine, truncate, AllocId, Allocation, Frame, ImmTy, Immediate, InterpCx,
32-
LocalState, LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy,
33-
Pointer, ScalarMaybeUninit, StackPopCleanup,
31+
self, compile_time_machine, truncate, AllocId, Allocation, ConstValue, Frame, ImmTy, Immediate,
32+
InterpCx, LocalState, LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand as InterpOperand,
33+
PlaceTy, Pointer, ScalarMaybeUninit, StackPopCleanup,
3434
};
3535
use crate::transform::{MirPass, MirSource};
3636

@@ -824,44 +824,57 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
824824
));
825825
}
826826
Immediate::ScalarPair(
827-
ScalarMaybeUninit::Scalar(one),
828-
ScalarMaybeUninit::Scalar(two),
827+
ScalarMaybeUninit::Scalar(_),
828+
ScalarMaybeUninit::Scalar(_),
829829
) => {
830-
// Found a value represented as a pair. For now only do cont-prop if type of
831-
// Rvalue is also a pair with two scalars. The more general case is more
832-
// complicated to implement so we'll do it later.
833-
// FIXME: implement the general case stated above ^.
834-
let ty = &value.layout.ty.kind;
830+
// Found a value represented as a pair. For now only do const-prop if the type
831+
// of `rvalue` is also a tuple with two scalars.
832+
// FIXME: enable the general case stated above ^.
833+
let ty = &value.layout.ty;
835834
// Only do it for tuples
836-
if let ty::Tuple(substs) = ty {
835+
if let ty::Tuple(substs) = ty.kind {
837836
// Only do it if tuple is also a pair with two scalars
838837
if substs.len() == 2 {
839-
let opt_ty1_ty2 = self.use_ecx(|this| {
838+
let alloc = self.use_ecx(|this| {
840839
let ty1 = substs[0].expect_ty();
841840
let ty2 = substs[1].expect_ty();
842841
let ty_is_scalar = |ty| {
843842
this.ecx.layout_of(ty).ok().map(|layout| layout.abi.is_scalar())
844843
== Some(true)
845844
};
846845
if ty_is_scalar(ty1) && ty_is_scalar(ty2) {
847-
Ok(Some((ty1, ty2)))
846+
let alloc = this
847+
.ecx
848+
.intern_with_temp_alloc(value.layout, |ecx, dest| {
849+
ecx.write_immediate_to_mplace(*imm, dest)
850+
})
851+
.unwrap();
852+
Ok(Some(alloc))
848853
} else {
849854
Ok(None)
850855
}
851856
});
852857

853-
if let Some(Some((ty1, ty2))) = opt_ty1_ty2 {
854-
*rval = Rvalue::Aggregate(
855-
Box::new(AggregateKind::Tuple),
856-
vec![
857-
self.operand_from_scalar(one, ty1, source_info.span),
858-
self.operand_from_scalar(two, ty2, source_info.span),
859-
],
860-
);
858+
if let Some(Some(alloc)) = alloc {
859+
// Assign entire constant in a single statement.
860+
// We can't use aggregates, as we run after the aggregate-lowering `MirPhase`.
861+
*rval = Rvalue::Use(Operand::Constant(Box::new(Constant {
862+
span: source_info.span,
863+
user_ty: None,
864+
literal: self.ecx.tcx.mk_const(ty::Const {
865+
ty,
866+
val: ty::ConstKind::Value(ConstValue::ByRef {
867+
alloc,
868+
offset: Size::ZERO,
869+
}),
870+
}),
871+
})));
861872
}
862873
}
863874
}
864875
}
876+
// Scalars or scalar pairs that contain undef values are assumed to not have
877+
// successfully evaluated and are thus not propagated.
865878
_ => {}
866879
}
867880
}

src/librustc_mir/transform/generator.rs

+25-12
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ use crate::transform::no_landing_pads::no_landing_pads;
5757
use crate::transform::simplify;
5858
use crate::transform::{MirPass, MirSource};
5959
use crate::util::dump_mir;
60+
use crate::util::expand_aggregate;
6061
use crate::util::storage;
6162
use rustc_data_structures::fx::FxHashMap;
6263
use rustc_hir as hir;
@@ -66,7 +67,7 @@ use rustc_index::bit_set::{BitMatrix, BitSet};
6667
use rustc_index::vec::{Idx, IndexVec};
6768
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
6869
use rustc_middle::mir::*;
69-
use rustc_middle::ty::subst::SubstsRef;
70+
use rustc_middle::ty::subst::{Subst, SubstsRef};
7071
use rustc_middle::ty::GeneratorSubsts;
7172
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
7273
use rustc_target::abi::VariantIdx;
@@ -236,10 +237,28 @@ struct TransformVisitor<'tcx> {
236237
}
237238

238239
impl TransformVisitor<'tcx> {
239-
// Make a GeneratorState rvalue
240-
fn make_state(&self, idx: VariantIdx, val: Operand<'tcx>) -> Rvalue<'tcx> {
241-
let adt = AggregateKind::Adt(self.state_adt_ref, idx, self.state_substs, None, None);
242-
Rvalue::Aggregate(box adt, vec![val])
240+
// Make a GeneratorState variant assignment. `core::ops::GeneratorState` only has single
241+
// element tuple variants, so we can just write to the downcasted first field and then set the
242+
// discriminant to the appropriate variant.
243+
fn make_state(
244+
&self,
245+
idx: VariantIdx,
246+
val: Operand<'tcx>,
247+
source_info: SourceInfo,
248+
) -> impl Iterator<Item = Statement<'tcx>> {
249+
let kind = AggregateKind::Adt(self.state_adt_ref, idx, self.state_substs, None, None);
250+
assert_eq!(self.state_adt_ref.variants[idx].fields.len(), 1);
251+
let ty = self
252+
.tcx
253+
.type_of(self.state_adt_ref.variants[idx].fields[0].did)
254+
.subst(self.tcx, self.state_substs);
255+
expand_aggregate(
256+
Place::return_place(),
257+
std::iter::once((val, ty)),
258+
kind,
259+
source_info,
260+
self.tcx,
261+
)
243262
}
244263

245264
// Create a Place referencing a generator struct field
@@ -325,13 +344,7 @@ impl MutVisitor<'tcx> for TransformVisitor<'tcx> {
325344
if let Some((state_idx, resume, v, drop)) = ret_val {
326345
let source_info = data.terminator().source_info;
327346
// We must assign the value first in case it gets declared dead below
328-
data.statements.push(Statement {
329-
source_info,
330-
kind: StatementKind::Assign(box (
331-
Place::return_place(),
332-
self.make_state(state_idx, v),
333-
)),
334-
});
347+
data.statements.extend(self.make_state(state_idx, v, source_info));
335348
let state = if let Some((resume, resume_arg)) = resume {
336349
// Yield
337350
let state = 3 + self.suspension_points.len();

0 commit comments

Comments
 (0)