Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly stall coroutine witnesses in new solver #138845

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,15 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// Resume type defaults to `()` if the coroutine has no argument.
let resume_ty = liberated_sig.inputs().get(0).copied().unwrap_or(tcx.types.unit);

let interior = self.next_ty_var(expr_span);
// In the new solver, we can just instantiate this eagerly
// with the witness. This will ensure that goals that don't need
// to stall on interior types will get processed eagerly.
let interior = if self.next_trait_solver() {
Ty::new_coroutine_witness(tcx, expr_def_id.to_def_id(), parent_args)
} else {
self.next_ty_var(expr_span)
};

self.deferred_coroutine_interiors.borrow_mut().push((expr_def_id, interior));

// Coroutines that come from coroutine closures have not yet determined
Expand Down
53 changes: 29 additions & 24 deletions compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,34 +635,39 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

let mut obligations = vec![];

for &(coroutine_def_id, interior) in coroutines.iter() {
debug!(?coroutine_def_id);
if !self.next_trait_solver() {
for &(coroutine_def_id, interior) in coroutines.iter() {
debug!(?coroutine_def_id);

// Create the `CoroutineWitness` type that we will unify with `interior`.
let args = ty::GenericArgs::identity_for_item(
self.tcx,
self.tcx.typeck_root_def_id(coroutine_def_id.to_def_id()),
);
let witness =
Ty::new_coroutine_witness(self.tcx, coroutine_def_id.to_def_id(), args);

// Create the `CoroutineWitness` type that we will unify with `interior`.
let args = ty::GenericArgs::identity_for_item(
self.tcx,
self.tcx.typeck_root_def_id(coroutine_def_id.to_def_id()),
);
let witness = Ty::new_coroutine_witness(self.tcx, coroutine_def_id.to_def_id(), args);

// Unify `interior` with `witness` and collect all the resulting obligations.
let span = self.tcx.hir_body_owned_by(coroutine_def_id).value.span;
let ty::Infer(ty::InferTy::TyVar(_)) = interior.kind() else {
span_bug!(span, "coroutine interior witness not infer: {:?}", interior.kind())
};
let ok = self
.at(&self.misc(span), self.param_env)
// Will never define opaque types, as all we do is instantiate a type variable.
.eq(DefineOpaqueTypes::Yes, interior, witness)
.expect("Failed to unify coroutine interior type");

obligations.extend(ok.obligations);
// Unify `interior` with `witness` and collect all the resulting obligations.
let span = self.tcx.hir_body_owned_by(coroutine_def_id).value.span;
let ty::Infer(ty::InferTy::TyVar(_)) = interior.kind() else {
span_bug!(span, "coroutine interior witness not infer: {:?}", interior.kind())
};
let ok = self
.at(&self.misc(span), self.param_env)
// Will never define opaque types, as all we do is instantiate a type variable.
.eq(DefineOpaqueTypes::Yes, interior, witness)
.expect("Failed to unify coroutine interior type");

obligations.extend(ok.obligations);
}
}

// FIXME: Use a real visitor for unstalled obligations in the new solver.
if !coroutines.is_empty() {
obligations
.extend(self.fulfillment_cx.borrow_mut().drain_unstalled_obligations(&self.infcx));
obligations.extend(
self.fulfillment_cx
.borrow_mut()
.drain_stalled_obligations_for_coroutines(&self.infcx),
);
}

self.typeck_results
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_typeck/src/typeck_root_ctxt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<'tcx> TypeckRootCtxt<'tcx> {
let hir_owner = tcx.local_def_id_to_hir_id(def_id).owner;

let infcx =
tcx.infer_ctxt().ignoring_regions().build(TypingMode::analysis_in_body(tcx, def_id));
tcx.infer_ctxt().ignoring_regions().build(TypingMode::typeck_for_body(tcx, def_id));
let typeck_results = RefCell::new(ty::TypeckResults::new(hir_owner));

TypeckRootCtxt {
Expand Down
53 changes: 46 additions & 7 deletions compiler/rustc_hir_typeck/src/writeback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use rustc_data_structures::unord::ExtendUnord;
use rustc_errors::ErrorGuaranteed;
use rustc_hir::intravisit::{self, InferKind, Visitor};
use rustc_hir::{self as hir, AmbigArg, HirId};
use rustc_infer::traits::solve::Goal;
use rustc_middle::span_bug;
use rustc_middle::traits::ObligationCause;
use rustc_middle::ty::adjustment::{Adjust, Adjustment, PointerCoercion};
Expand Down Expand Up @@ -731,7 +732,32 @@ impl<'cx, 'tcx> WritebackCx<'cx, 'tcx> {
T: TypeFoldable<TyCtxt<'tcx>>,
{
let value = self.fcx.resolve_vars_if_possible(value);
let value = value.fold_with(&mut Resolver::new(self.fcx, span, self.body, true));

let mut goals = vec![];
let value =
value.fold_with(&mut Resolver::new(self.fcx, span, self.body, true, &mut goals));

// Ensure that we resolve goals we get from normalizing coroutine interiors,
// but we shouldn't expect those goals to need normalizing (or else we'd get
// into a somewhat awkward fixpoint situation, and we don't need it anyways).
let mut unexpected_goals = vec![];
self.typeck_results.coroutine_stalled_predicates.extend(
goals
.into_iter()
.map(|pred| {
self.fcx.resolve_vars_if_possible(pred).fold_with(&mut Resolver::new(
self.fcx,
span,
self.body,
false,
&mut unexpected_goals,
))
})
// FIXME: throwing away the param-env :(
.map(|goal| (goal.predicate, self.fcx.misc(span.to_span(self.fcx.tcx)))),
);
assert_eq!(unexpected_goals, vec![]);

assert!(!value.has_infer());

// We may have introduced e.g. `ty::Error`, if inference failed, make sure
Expand All @@ -749,7 +775,12 @@ impl<'cx, 'tcx> WritebackCx<'cx, 'tcx> {
T: TypeFoldable<TyCtxt<'tcx>>,
{
let value = self.fcx.resolve_vars_if_possible(value);
let value = value.fold_with(&mut Resolver::new(self.fcx, span, self.body, false));

let mut goals = vec![];
let value =
value.fold_with(&mut Resolver::new(self.fcx, span, self.body, false, &mut goals));
assert_eq!(goals, vec![]);

assert!(!value.has_infer());

// We may have introduced e.g. `ty::Error`, if inference failed, make sure
Expand Down Expand Up @@ -786,6 +817,7 @@ struct Resolver<'cx, 'tcx> {
/// Whether we should normalize using the new solver, disabled
/// both when using the old solver and when resolving predicates.
should_normalize: bool,
nested_goals: &'cx mut Vec<Goal<'tcx, ty::Predicate<'tcx>>>,
}

impl<'cx, 'tcx> Resolver<'cx, 'tcx> {
Expand All @@ -794,8 +826,9 @@ impl<'cx, 'tcx> Resolver<'cx, 'tcx> {
span: &'cx dyn Locatable,
body: &'tcx hir::Body<'tcx>,
should_normalize: bool,
nested_goals: &'cx mut Vec<Goal<'tcx, ty::Predicate<'tcx>>>,
) -> Resolver<'cx, 'tcx> {
Resolver { fcx, span, body, should_normalize }
Resolver { fcx, span, body, nested_goals, should_normalize }
}

fn report_error(&self, p: impl Into<ty::GenericArg<'tcx>>) -> ErrorGuaranteed {
Expand Down Expand Up @@ -832,12 +865,18 @@ impl<'cx, 'tcx> Resolver<'cx, 'tcx> {
let cause = ObligationCause::misc(self.span.to_span(tcx), body_id);
let at = self.fcx.at(&cause, self.fcx.param_env);
let universes = vec![None; outer_exclusive_binder(value).as_usize()];
solve::deeply_normalize_with_skipped_universes(at, value, universes).unwrap_or_else(
|errors| {
match solve::deeply_normalize_with_skipped_universes_and_ambiguous_goals(
at, value, universes,
) {
Ok((value, goals)) => {
self.nested_goals.extend(goals);
value
}
Err(errors) => {
let guar = self.fcx.err_ctxt().report_fulfillment_errors(errors);
new_err(tcx, guar)
},
)
}
}
} else {
value
};
Expand Down
9 changes: 5 additions & 4 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -966,9 +966,10 @@ impl<'tcx> InferCtxt<'tcx> {
pub fn can_define_opaque_ty(&self, id: impl Into<DefId>) -> bool {
debug_assert!(!self.next_trait_solver());
match self.typing_mode() {
TypingMode::Analysis { defining_opaque_types } => {
id.into().as_local().is_some_and(|def_id| defining_opaque_types.contains(&def_id))
}
TypingMode::Analysis { defining_opaque_types_and_generators } => id
.into()
.as_local()
.is_some_and(|def_id| defining_opaque_types_and_generators.contains(&def_id)),
// FIXME(#132279): This function is quite weird in post-analysis
// and post-borrowck analysis mode. We may need to modify its uses
// to support PostBorrowckAnalysis in the old solver as well.
Expand Down Expand Up @@ -1260,7 +1261,7 @@ impl<'tcx> InferCtxt<'tcx> {
// to handle them without proper canonicalization. This means we may cause cycle
// errors and fail to reveal opaques while inside of bodies. We should rename this
// function and require explicit comments on all use-sites in the future.
ty::TypingMode::Analysis { defining_opaque_types: _ } => {
ty::TypingMode::Analysis { defining_opaque_types_and_generators: _ } => {
TypingMode::non_body_analysis()
}
mode @ (ty::TypingMode::Coherence
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/traits/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub trait TraitEngine<'tcx, E: 'tcx>: 'tcx {
/// Among all pending obligations, collect those are stalled on a inference variable which has
/// changed since the last call to `select_where_possible`. Those obligations are marked as
/// successful and returned.
fn drain_unstalled_obligations(
fn drain_stalled_obligations_for_coroutines(
&mut self,
infcx: &InferCtxt<'tcx>,
) -> PredicateObligations<'tcx>;
Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,15 @@ rustc_queries! {
}
}

query stalled_generators_within(
key: LocalDefId
) -> &'tcx ty::List<LocalDefId> {
desc {
|tcx| "computing the opaque types defined by `{}`",
tcx.def_path_str(key.to_def_id())
}
}

/// Returns the explicitly user-written *bounds* on the associated or opaque type given by `DefId`
/// that must be proven true at definition site (and which can be assumed at usage sites).
///
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/query/plumbing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ macro_rules! define_callbacks {

pub type Storage<'tcx> = <$($K)* as keys::Key>::Cache<Erase<$V>>;

// Ensure that keys grow no larger than 80 bytes by accident.
// Ensure that keys grow no larger than 88 bytes by accident.
// Increase this limit if necessary, but do try to keep the size low if possible
#[cfg(target_pointer_width = "64")]
const _: () = {
Expand Down
23 changes: 19 additions & 4 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
) -> Self::PredefinedOpaques {
self.mk_predefined_opaques_in_body(data)
}
type DefiningOpaqueTypes = &'tcx ty::List<LocalDefId>;
type LocalDefIds = &'tcx ty::List<LocalDefId>;
type CanonicalVars = CanonicalVarInfos<'tcx>;
fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars {
self.mk_canonical_var_infos(infos)
Expand Down Expand Up @@ -663,9 +663,24 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
self.anonymize_bound_vars(binder)
}

fn opaque_types_defined_by(self, defining_anchor: LocalDefId) -> Self::DefiningOpaqueTypes {
fn opaque_types_defined_by(self, defining_anchor: LocalDefId) -> Self::LocalDefIds {
self.opaque_types_defined_by(defining_anchor)
}

fn opaque_types_and_generators_defined_by(
self,
defining_anchor: Self::LocalDefId,
) -> Self::LocalDefIds {
if self.next_trait_solver_globally() {
self.mk_local_def_ids_from_iter(
self.opaque_types_defined_by(defining_anchor)
.iter()
.chain(self.stalled_generators_within(defining_anchor)),
)
} else {
self.opaque_types_defined_by(defining_anchor)
}
}
}

macro_rules! bidirectional_lang_item_map {
Expand Down Expand Up @@ -2871,11 +2886,11 @@ impl<'tcx> TyCtxt<'tcx> {
self.interners.intern_clauses(clauses)
}

pub fn mk_local_def_ids(self, clauses: &[LocalDefId]) -> &'tcx List<LocalDefId> {
pub fn mk_local_def_ids(self, def_ids: &[LocalDefId]) -> &'tcx List<LocalDefId> {
// FIXME consider asking the input slice to be sorted to avoid
// re-interning permutations, in which case that would be asserted
// here.
self.intern_local_def_ids(clauses)
self.intern_local_def_ids(def_ids)
}

pub fn mk_local_def_ids_from_iter<I, T>(self, iter: I) -> T::Output
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_next_trait_solver/src/solve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ where
TypingMode::Coherence | TypingMode::PostAnalysis => false,
// During analysis, opaques are rigid unless they may be defined by
// the current body.
TypingMode::Analysis { defining_opaque_types: non_rigid_opaques }
TypingMode::Analysis { defining_opaque_types_and_generators: non_rigid_opaques }
| TypingMode::PostBorrowckAnalysis { defined_opaque_types: non_rigid_opaques } => {
!def_id.as_local().is_some_and(|def_id| non_rigid_opaques.contains(&def_id))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ where
);
self.evaluate_added_goals_and_make_canonical_response(Certainty::AMBIGUOUS)
}
TypingMode::Analysis { defining_opaque_types } => {
TypingMode::Analysis { defining_opaque_types_and_generators } => {
let Some(def_id) = opaque_ty
.def_id
.as_local()
.filter(|&def_id| defining_opaque_types.contains(&def_id))
.filter(|&def_id| defining_opaque_types_and_generators.contains(&def_id))
else {
self.structurally_instantiate_normalizes_to_term(goal, goal.predicate.alias);
return self.evaluate_added_goals_and_make_canonical_response(Certainty::Yes);
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,21 @@ where
debug_assert!(ecx.opaque_type_is_rigid(opaque_ty.def_id));
}

if let ty::CoroutineWitness(def_id, _) = goal.predicate.self_ty().kind() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should pull this out into a helper, b/c I think I need to also apply this hack to copy/clone. I think those are it tho.

match ecx.typing_mode() {
TypingMode::Analysis { defining_opaque_types_and_generators } => {
if def_id.as_local().is_some_and(|def_id| {
defining_opaque_types_and_generators.contains(&def_id)
}) {
return ecx.forced_ambiguity(MaybeCause::Ambiguity);
}
}
TypingMode::Coherence
| TypingMode::PostAnalysis
| TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ } => {}
}
}

ecx.probe_and_evaluate_goal_for_constituent_tys(
CandidateSource::BuiltinImpl(BuiltinImplSource::Misc),
goal,
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_trait_selection/src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,8 @@ mod select;
pub(crate) use delegate::SolverDelegate;
pub use fulfill::{FulfillmentCtxt, NextSolverError};
pub(crate) use normalize::deeply_normalize_for_diagnostics;
pub use normalize::{deeply_normalize, deeply_normalize_with_skipped_universes};
pub use normalize::{
deeply_normalize, deeply_normalize_with_skipped_universes,
deeply_normalize_with_skipped_universes_and_ambiguous_goals,
};
pub use select::InferCtxtSelectExt;
Loading
Loading