Skip to content

Commit 5babe9f

Browse files
Instantiate closure-like bounds with placeholders to deal with binders correctly
1 parent fd27e87 commit 5babe9f

File tree

3 files changed

+139
-82
lines changed

3 files changed

+139
-82
lines changed

compiler/rustc_trait_selection/src/traits/select/confirmation.rs

+82-68
Original file line numberDiff line numberDiff line change
@@ -676,17 +676,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
676676
fn_host_effect: ty::Const<'tcx>,
677677
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
678678
debug!(?obligation, "confirm_fn_pointer_candidate");
679+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
680+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
679681

680682
let tcx = self.tcx();
681-
682-
let Some(self_ty) = self.infcx.shallow_resolve(obligation.self_ty().no_bound_vars()) else {
683-
// FIXME: Ideally we'd support `for<'a> fn(&'a ()): Fn(&'a ())`,
684-
// but we do not currently. Luckily, such a bound is not
685-
// particularly useful, so we don't expect users to write
686-
// them often.
687-
return Err(SelectionError::Unimplemented);
688-
};
689-
690683
let sig = self_ty.fn_sig(tcx);
691684
let trait_ref = closure_trait_ref_and_return_type(
692685
tcx,
@@ -698,7 +691,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
698691
)
699692
.map_bound(|(trait_ref, _)| trait_ref);
700693

701-
let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
694+
let mut nested = self.equate_trait_refs(
695+
&obligation.cause,
696+
obligation.param_env,
697+
placeholder_predicate.trait_ref,
698+
trait_ref,
699+
)?;
702700
let cause = obligation.derived_cause(BuiltinDerivedObligation);
703701

704702
// Confirm the `type Output: Sized;` bound that is present on `FnOnce`
@@ -746,10 +744,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
746744
&mut self,
747745
obligation: &PolyTraitObligation<'tcx>,
748746
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
749-
// Okay to skip binder because the args on coroutine types never
750-
// touch bound regions, they just capture the in-scope
751-
// type/region parameters.
752-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
747+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
748+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
753749
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
754750
bug!("closure candidate for non-closure {:?}", obligation);
755751
};
@@ -758,23 +754,19 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
758754

759755
let coroutine_sig = args.as_coroutine().sig();
760756

761-
// NOTE: The self-type is a coroutine type and hence is
762-
// in fact unparameterized (or at least does not reference any
763-
// regions bound in the obligation).
764-
let self_ty = obligation
765-
.predicate
766-
.self_ty()
767-
.no_bound_vars()
768-
.expect("unboxed closure type should not capture bound vars from the predicate");
769-
770757
let (trait_ref, _, _) = super::util::coroutine_trait_ref_and_outputs(
771758
self.tcx(),
772759
obligation.predicate.def_id(),
773760
self_ty,
774761
coroutine_sig,
775762
);
776763

777-
let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
764+
let nested = self.equate_trait_refs(
765+
&obligation.cause,
766+
obligation.param_env,
767+
placeholder_predicate.trait_ref,
768+
ty::Binder::dummy(trait_ref),
769+
)?;
778770
debug!(?trait_ref, ?nested, "coroutine candidate obligations");
779771

780772
Ok(nested)
@@ -784,10 +776,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
784776
&mut self,
785777
obligation: &PolyTraitObligation<'tcx>,
786778
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
787-
// Okay to skip binder because the args on coroutine types never
788-
// touch bound regions, they just capture the in-scope
789-
// type/region parameters.
790-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
779+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
780+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
791781
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
792782
bug!("closure candidate for non-closure {:?}", obligation);
793783
};
@@ -799,11 +789,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
799789
let (trait_ref, _) = super::util::future_trait_ref_and_outputs(
800790
self.tcx(),
801791
obligation.predicate.def_id(),
802-
obligation.predicate.no_bound_vars().expect("future has no bound vars").self_ty(),
792+
self_ty,
803793
coroutine_sig,
804794
);
805795

806-
let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
796+
let nested = self.equate_trait_refs(
797+
&obligation.cause,
798+
obligation.param_env,
799+
placeholder_predicate.trait_ref,
800+
ty::Binder::dummy(trait_ref),
801+
)?;
807802
debug!(?trait_ref, ?nested, "future candidate obligations");
808803

809804
Ok(nested)
@@ -813,10 +808,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
813808
&mut self,
814809
obligation: &PolyTraitObligation<'tcx>,
815810
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
816-
// Okay to skip binder because the args on coroutine types never
817-
// touch bound regions, they just capture the in-scope
818-
// type/region parameters.
819-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
811+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
812+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
820813
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
821814
bug!("closure candidate for non-closure {:?}", obligation);
822815
};
@@ -828,11 +821,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
828821
let (trait_ref, _) = super::util::iterator_trait_ref_and_outputs(
829822
self.tcx(),
830823
obligation.predicate.def_id(),
831-
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
824+
self_ty,
832825
gen_sig,
833826
);
834827

835-
let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
828+
let nested = self.equate_trait_refs(
829+
&obligation.cause,
830+
obligation.param_env,
831+
placeholder_predicate.trait_ref,
832+
ty::Binder::dummy(trait_ref),
833+
)?;
836834
debug!(?trait_ref, ?nested, "iterator candidate obligations");
837835

838836
Ok(nested)
@@ -842,10 +840,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
842840
&mut self,
843841
obligation: &PolyTraitObligation<'tcx>,
844842
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
845-
// Okay to skip binder because the args on coroutine types never
846-
// touch bound regions, they just capture the in-scope
847-
// type/region parameters.
848-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
843+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
844+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
849845
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
850846
bug!("closure candidate for non-closure {:?}", obligation);
851847
};
@@ -857,11 +853,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
857853
let (trait_ref, _) = super::util::async_iterator_trait_ref_and_outputs(
858854
self.tcx(),
859855
obligation.predicate.def_id(),
860-
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
856+
self_ty,
861857
gen_sig,
862858
);
863859

864-
let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
860+
let nested = self.equate_trait_refs(
861+
&obligation.cause,
862+
obligation.param_env,
863+
placeholder_predicate.trait_ref,
864+
ty::Binder::dummy(trait_ref),
865+
)?;
865866
debug!(?trait_ref, ?nested, "iterator candidate obligations");
866867

867868
Ok(nested)
@@ -872,14 +873,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
872873
&mut self,
873874
obligation: &PolyTraitObligation<'tcx>,
874875
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
875-
// Okay to skip binder because the args on closure types never
876-
// touch bound regions, they just capture the in-scope
877-
// type/region parameters.
878-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
876+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
877+
let self_ty: Ty<'_> = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
878+
879879
let trait_ref = match *self_ty.kind() {
880-
ty::Closure(_, args) => {
881-
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
882-
}
880+
ty::Closure(..) => self.closure_trait_ref_unnormalized(
881+
self_ty,
882+
obligation.predicate.def_id(),
883+
self.tcx().consts.true_,
884+
),
883885
ty::CoroutineClosure(_, args) => {
884886
args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
885887
ty::TraitRef::new(
@@ -894,16 +896,23 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
894896
}
895897
};
896898

897-
self.confirm_poly_trait_refs(obligation, trait_ref)
899+
self.equate_trait_refs(
900+
&obligation.cause,
901+
obligation.param_env,
902+
placeholder_predicate.trait_ref,
903+
trait_ref,
904+
)
898905
}
899906

900907
#[instrument(skip(self), level = "debug")]
901908
fn confirm_async_closure_candidate(
902909
&mut self,
903910
obligation: &PolyTraitObligation<'tcx>,
904911
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
912+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
913+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
914+
905915
let tcx = self.tcx();
906-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
907916

908917
let mut nested = vec![];
909918
let (trait_ref, kind_ty) = match *self_ty.kind() {
@@ -970,7 +979,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
970979
_ => bug!("expected callable type for AsyncFn candidate"),
971980
};
972981

973-
nested.extend(self.confirm_poly_trait_refs(obligation, trait_ref)?);
982+
nested.extend(self.equate_trait_refs(
983+
&obligation.cause,
984+
obligation.param_env,
985+
placeholder_predicate.trait_ref,
986+
trait_ref,
987+
)?);
974988

975989
let goal_kind =
976990
self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap();
@@ -1023,42 +1037,42 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
10231037
/// selection of the impl. Therefore, if there is a mismatch, we
10241038
/// report an error to the user.
10251039
#[instrument(skip(self), level = "trace")]
1026-
fn confirm_poly_trait_refs(
1040+
fn equate_trait_refs(
10271041
&mut self,
1028-
obligation: &PolyTraitObligation<'tcx>,
1029-
self_ty_trait_ref: ty::PolyTraitRef<'tcx>,
1042+
cause: &ObligationCause<'tcx>,
1043+
param_env: ty::ParamEnv<'tcx>,
1044+
obligation_trait_ref: ty::TraitRef<'tcx>,
1045+
found_trait_ref: ty::PolyTraitRef<'tcx>,
10301046
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
1031-
let obligation_trait_ref =
1032-
self.infcx.enter_forall_and_leak_universe(obligation.predicate.to_poly_trait_ref());
1033-
let self_ty_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
1034-
obligation.cause.span,
1047+
let found_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
1048+
cause.span,
10351049
HigherRankedType,
1036-
self_ty_trait_ref,
1050+
found_trait_ref,
10371051
);
10381052
// Normalize the obligation and expected trait refs together, because why not
1039-
let Normalized { obligations: nested, value: (obligation_trait_ref, expected_trait_ref) } =
1053+
let Normalized { obligations: nested, value: (obligation_trait_ref, found_trait_ref) } =
10401054
ensure_sufficient_stack(|| {
10411055
normalize_with_depth(
10421056
self,
1043-
obligation.param_env,
1044-
obligation.cause.clone(),
1045-
obligation.recursion_depth + 1,
1046-
(obligation_trait_ref, self_ty_trait_ref),
1057+
param_env,
1058+
cause.clone(),
1059+
0,
1060+
(obligation_trait_ref, found_trait_ref),
10471061
)
10481062
});
10491063

10501064
// needed to define opaque types for tests/ui/type-alias-impl-trait/assoc-projection-ice.rs
10511065
self.infcx
1052-
.at(&obligation.cause, obligation.param_env)
1053-
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
1066+
.at(&cause, param_env)
1067+
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, found_trait_ref)
10541068
.map(|InferOk { mut obligations, .. }| {
10551069
obligations.extend(nested);
10561070
obligations
10571071
})
10581072
.map_err(|terr| {
10591073
SignatureMismatch(Box::new(SignatureMismatchData {
10601074
expected_trait_ref: ty::Binder::dummy(obligation_trait_ref),
1061-
found_trait_ref: ty::Binder::dummy(expected_trait_ref),
1075+
found_trait_ref: ty::Binder::dummy(found_trait_ref),
10621076
terr,
10631077
}))
10641078
})

compiler/rustc_trait_selection/src/traits/select/mod.rs

+6-14
Original file line numberDiff line numberDiff line change
@@ -2667,26 +2667,18 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
26672667
#[instrument(skip(self), level = "debug")]
26682668
fn closure_trait_ref_unnormalized(
26692669
&mut self,
2670-
obligation: &PolyTraitObligation<'tcx>,
2671-
args: GenericArgsRef<'tcx>,
2670+
self_ty: Ty<'tcx>,
2671+
fn_trait_def_id: DefId,
26722672
fn_host_effect: ty::Const<'tcx>,
26732673
) -> ty::PolyTraitRef<'tcx> {
2674+
let ty::Closure(_, args) = *self_ty.kind() else {
2675+
bug!("expected closure, found {self_ty}");
2676+
};
26742677
let closure_sig = args.as_closure().sig();
26752678

2676-
debug!(?closure_sig);
2677-
2678-
// NOTE: The self-type is an unboxed closure type and hence is
2679-
// in fact unparameterized (or at least does not reference any
2680-
// regions bound in the obligation).
2681-
let self_ty = obligation
2682-
.predicate
2683-
.self_ty()
2684-
.no_bound_vars()
2685-
.expect("unboxed closure type should not capture bound vars from the predicate");
2686-
26872679
closure_trait_ref_and_return_type(
26882680
self.tcx(),
2689-
obligation.predicate.def_id(),
2681+
fn_trait_def_id,
26902682
self_ty,
26912683
closure_sig,
26922684
util::TupleArgumentsFlag::No,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//@ edition:2024
2+
//@ compile-flags: -Zunstable-options
3+
//@ revisions: current next
4+
//@[next] compile-flags: -Znext-solver
5+
6+
#![feature(unboxed_closures, gen_blocks)]
7+
8+
trait Dispatch {
9+
fn dispatch(self);
10+
}
11+
12+
struct Fut<T>(T);
13+
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Fut<T>
14+
where
15+
for<'a> <T as FnOnce<(&'a (),)>>::Output: Future,
16+
{
17+
fn dispatch(self) {
18+
(self.0)(&());
19+
}
20+
}
21+
22+
struct Gen<T>(T);
23+
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Gen<T>
24+
where
25+
for<'a> <T as FnOnce<(&'a (),)>>::Output: Iterator,
26+
{
27+
fn dispatch(self) {
28+
(self.0)(&());
29+
}
30+
}
31+
32+
struct Closure<T>(T);
33+
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Closure<T>
34+
where
35+
for<'a> <T as FnOnce<(&'a (),)>>::Output: Fn<(&'a (),)>,
36+
{
37+
fn dispatch(self) {
38+
(self.0)(&())(&());
39+
}
40+
}
41+
42+
fn main() {
43+
async fn foo(_: &()) {}
44+
Fut(foo).dispatch();
45+
46+
gen fn bar(_: &()) {}
47+
Gen(bar).dispatch();
48+
49+
fn uwu<'a>(x: &'a ()) -> impl Fn(&'a ()) { |_| {} }
50+
Closure(uwu).dispatch();
51+
}

0 commit comments

Comments
 (0)