Skip to content

Commit 24482c8

Browse files
committedMar 26, 2024
Instantiate closure-like bounds with placeholders to deal with binders correctly
1 parent 519d892 commit 24482c8

File tree

3 files changed

+139
-82
lines changed

3 files changed

+139
-82
lines changed
 

‎compiler/rustc_trait_selection/src/traits/select/confirmation.rs

+82-68
Original file line numberDiff line numberDiff line change
@@ -678,17 +678,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
678678
fn_host_effect: ty::Const<'tcx>,
679679
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
680680
debug!(?obligation, "confirm_fn_pointer_candidate");
681+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
682+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
681683

682684
let tcx = self.tcx();
683-
684-
let Some(self_ty) = self.infcx.shallow_resolve(obligation.self_ty().no_bound_vars()) else {
685-
// FIXME: Ideally we'd support `for<'a> fn(&'a ()): Fn(&'a ())`,
686-
// but we do not currently. Luckily, such a bound is not
687-
// particularly useful, so we don't expect users to write
688-
// them often.
689-
return Err(SelectionError::Unimplemented);
690-
};
691-
692685
let sig = self_ty.fn_sig(tcx);
693686
let trait_ref = closure_trait_ref_and_return_type(
694687
tcx,
@@ -700,7 +693,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
700693
)
701694
.map_bound(|(trait_ref, _)| trait_ref);
702695

703-
let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
696+
let mut nested = self.equate_trait_refs(
697+
&obligation.cause,
698+
obligation.param_env,
699+
placeholder_predicate.trait_ref,
700+
trait_ref,
701+
)?;
704702
let cause = obligation.derived_cause(BuiltinDerivedObligation);
705703

706704
// Confirm the `type Output: Sized;` bound that is present on `FnOnce`
@@ -748,10 +746,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
748746
&mut self,
749747
obligation: &PolyTraitObligation<'tcx>,
750748
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
751-
// Okay to skip binder because the args on coroutine types never
752-
// touch bound regions, they just capture the in-scope
753-
// type/region parameters.
754-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
749+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
750+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
755751
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
756752
bug!("closure candidate for non-closure {:?}", obligation);
757753
};
@@ -760,23 +756,19 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
760756

761757
let coroutine_sig = args.as_coroutine().sig();
762758

763-
// NOTE: The self-type is a coroutine type and hence is
764-
// in fact unparameterized (or at least does not reference any
765-
// regions bound in the obligation).
766-
let self_ty = obligation
767-
.predicate
768-
.self_ty()
769-
.no_bound_vars()
770-
.expect("unboxed closure type should not capture bound vars from the predicate");
771-
772759
let (trait_ref, _, _) = super::util::coroutine_trait_ref_and_outputs(
773760
self.tcx(),
774761
obligation.predicate.def_id(),
775762
self_ty,
776763
coroutine_sig,
777764
);
778765

779-
let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
766+
let nested = self.equate_trait_refs(
767+
&obligation.cause,
768+
obligation.param_env,
769+
placeholder_predicate.trait_ref,
770+
ty::Binder::dummy(trait_ref),
771+
)?;
780772
debug!(?trait_ref, ?nested, "coroutine candidate obligations");
781773

782774
Ok(nested)
@@ -786,10 +778,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
786778
&mut self,
787779
obligation: &PolyTraitObligation<'tcx>,
788780
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
789-
// Okay to skip binder because the args on coroutine types never
790-
// touch bound regions, they just capture the in-scope
791-
// type/region parameters.
792-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
781+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
782+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
793783
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
794784
bug!("closure candidate for non-closure {:?}", obligation);
795785
};
@@ -801,11 +791,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
801791
let (trait_ref, _) = super::util::future_trait_ref_and_outputs(
802792
self.tcx(),
803793
obligation.predicate.def_id(),
804-
obligation.predicate.no_bound_vars().expect("future has no bound vars").self_ty(),
794+
self_ty,
805795
coroutine_sig,
806796
);
807797

808-
let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
798+
let nested = self.equate_trait_refs(
799+
&obligation.cause,
800+
obligation.param_env,
801+
placeholder_predicate.trait_ref,
802+
ty::Binder::dummy(trait_ref),
803+
)?;
809804
debug!(?trait_ref, ?nested, "future candidate obligations");
810805

811806
Ok(nested)
@@ -815,10 +810,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
815810
&mut self,
816811
obligation: &PolyTraitObligation<'tcx>,
817812
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
818-
// Okay to skip binder because the args on coroutine types never
819-
// touch bound regions, they just capture the in-scope
820-
// type/region parameters.
821-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
813+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
814+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
822815
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
823816
bug!("closure candidate for non-closure {:?}", obligation);
824817
};
@@ -830,11 +823,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
830823
let (trait_ref, _) = super::util::iterator_trait_ref_and_outputs(
831824
self.tcx(),
832825
obligation.predicate.def_id(),
833-
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
826+
self_ty,
834827
gen_sig,
835828
);
836829

837-
let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
830+
let nested = self.equate_trait_refs(
831+
&obligation.cause,
832+
obligation.param_env,
833+
placeholder_predicate.trait_ref,
834+
ty::Binder::dummy(trait_ref),
835+
)?;
838836
debug!(?trait_ref, ?nested, "iterator candidate obligations");
839837

840838
Ok(nested)
@@ -844,10 +842,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
844842
&mut self,
845843
obligation: &PolyTraitObligation<'tcx>,
846844
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
847-
// Okay to skip binder because the args on coroutine types never
848-
// touch bound regions, they just capture the in-scope
849-
// type/region parameters.
850-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
845+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
846+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
851847
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
852848
bug!("closure candidate for non-closure {:?}", obligation);
853849
};
@@ -859,11 +855,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
859855
let (trait_ref, _) = super::util::async_iterator_trait_ref_and_outputs(
860856
self.tcx(),
861857
obligation.predicate.def_id(),
862-
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
858+
self_ty,
863859
gen_sig,
864860
);
865861

866-
let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
862+
let nested = self.equate_trait_refs(
863+
&obligation.cause,
864+
obligation.param_env,
865+
placeholder_predicate.trait_ref,
866+
ty::Binder::dummy(trait_ref),
867+
)?;
867868
debug!(?trait_ref, ?nested, "iterator candidate obligations");
868869

869870
Ok(nested)
@@ -874,14 +875,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
874875
&mut self,
875876
obligation: &PolyTraitObligation<'tcx>,
876877
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
877-
// Okay to skip binder because the args on closure types never
878-
// touch bound regions, they just capture the in-scope
879-
// type/region parameters.
880-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
878+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
879+
let self_ty: Ty<'_> = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
880+
881881
let trait_ref = match *self_ty.kind() {
882-
ty::Closure(_, args) => {
883-
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
884-
}
882+
ty::Closure(..) => self.closure_trait_ref_unnormalized(
883+
self_ty,
884+
obligation.predicate.def_id(),
885+
self.tcx().consts.true_,
886+
),
885887
ty::CoroutineClosure(_, args) => {
886888
args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
887889
ty::TraitRef::new(
@@ -896,16 +898,23 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
896898
}
897899
};
898900

899-
self.confirm_poly_trait_refs(obligation, trait_ref)
901+
self.equate_trait_refs(
902+
&obligation.cause,
903+
obligation.param_env,
904+
placeholder_predicate.trait_ref,
905+
trait_ref,
906+
)
900907
}
901908

902909
#[instrument(skip(self), level = "debug")]
903910
fn confirm_async_closure_candidate(
904911
&mut self,
905912
obligation: &PolyTraitObligation<'tcx>,
906913
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
914+
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
915+
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
916+
907917
let tcx = self.tcx();
908-
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
909918

910919
let mut nested = vec![];
911920
let (trait_ref, kind_ty) = match *self_ty.kind() {
@@ -972,7 +981,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
972981
_ => bug!("expected callable type for AsyncFn candidate"),
973982
};
974983

975-
nested.extend(self.confirm_poly_trait_refs(obligation, trait_ref)?);
984+
nested.extend(self.equate_trait_refs(
985+
&obligation.cause,
986+
obligation.param_env,
987+
placeholder_predicate.trait_ref,
988+
trait_ref,
989+
)?);
976990

977991
let goal_kind =
978992
self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap();
@@ -1025,42 +1039,42 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
10251039
/// selection of the impl. Therefore, if there is a mismatch, we
10261040
/// report an error to the user.
10271041
#[instrument(skip(self), level = "trace")]
1028-
fn confirm_poly_trait_refs(
1042+
fn equate_trait_refs(
10291043
&mut self,
1030-
obligation: &PolyTraitObligation<'tcx>,
1031-
self_ty_trait_ref: ty::PolyTraitRef<'tcx>,
1044+
cause: &ObligationCause<'tcx>,
1045+
param_env: ty::ParamEnv<'tcx>,
1046+
obligation_trait_ref: ty::TraitRef<'tcx>,
1047+
found_trait_ref: ty::PolyTraitRef<'tcx>,
10321048
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
1033-
let obligation_trait_ref =
1034-
self.infcx.enter_forall_and_leak_universe(obligation.predicate.to_poly_trait_ref());
1035-
let self_ty_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
1036-
obligation.cause.span,
1049+
let found_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
1050+
cause.span,
10371051
HigherRankedType,
1038-
self_ty_trait_ref,
1052+
found_trait_ref,
10391053
);
10401054
// Normalize the obligation and expected trait refs together, because why not
1041-
let Normalized { obligations: nested, value: (obligation_trait_ref, expected_trait_ref) } =
1055+
let Normalized { obligations: nested, value: (obligation_trait_ref, found_trait_ref) } =
10421056
ensure_sufficient_stack(|| {
10431057
normalize_with_depth(
10441058
self,
1045-
obligation.param_env,
1046-
obligation.cause.clone(),
1047-
obligation.recursion_depth + 1,
1048-
(obligation_trait_ref, self_ty_trait_ref),
1059+
param_env,
1060+
cause.clone(),
1061+
0,
1062+
(obligation_trait_ref, found_trait_ref),
10491063
)
10501064
});
10511065

10521066
// needed to define opaque types for tests/ui/type-alias-impl-trait/assoc-projection-ice.rs
10531067
self.infcx
1054-
.at(&obligation.cause, obligation.param_env)
1055-
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
1068+
.at(&cause, param_env)
1069+
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, found_trait_ref)
10561070
.map(|InferOk { mut obligations, .. }| {
10571071
obligations.extend(nested);
10581072
obligations
10591073
})
10601074
.map_err(|terr| {
10611075
SignatureMismatch(Box::new(SignatureMismatchData {
10621076
expected_trait_ref: ty::Binder::dummy(obligation_trait_ref),
1063-
found_trait_ref: ty::Binder::dummy(expected_trait_ref),
1077+
found_trait_ref: ty::Binder::dummy(found_trait_ref),
10641078
terr,
10651079
}))
10661080
})

‎compiler/rustc_trait_selection/src/traits/select/mod.rs

+6-14
Original file line numberDiff line numberDiff line change
@@ -2679,26 +2679,18 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
26792679
#[instrument(skip(self), level = "debug")]
26802680
fn closure_trait_ref_unnormalized(
26812681
&mut self,
2682-
obligation: &PolyTraitObligation<'tcx>,
2683-
args: GenericArgsRef<'tcx>,
2682+
self_ty: Ty<'tcx>,
2683+
fn_trait_def_id: DefId,
26842684
fn_host_effect: ty::Const<'tcx>,
26852685
) -> ty::PolyTraitRef<'tcx> {
2686+
let ty::Closure(_, args) = *self_ty.kind() else {
2687+
bug!("expected closure, found {self_ty}");
2688+
};
26862689
let closure_sig = args.as_closure().sig();
26872690

2688-
debug!(?closure_sig);
2689-
2690-
// NOTE: The self-type is an unboxed closure type and hence is
2691-
// in fact unparameterized (or at least does not reference any
2692-
// regions bound in the obligation).
2693-
let self_ty = obligation
2694-
.predicate
2695-
.self_ty()
2696-
.no_bound_vars()
2697-
.expect("unboxed closure type should not capture bound vars from the predicate");
2698-
26992691
closure_trait_ref_and_return_type(
27002692
self.tcx(),
2701-
obligation.predicate.def_id(),
2693+
fn_trait_def_id,
27022694
self_ty,
27032695
closure_sig,
27042696
util::TupleArgumentsFlag::No,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//@ edition:2024
2+
//@ compile-flags: -Zunstable-options
3+
//@ revisions: current next
4+
//@[next] compile-flags: -Znext-solver
5+
6+
#![feature(unboxed_closures, gen_blocks)]
7+
8+
trait Dispatch {
9+
fn dispatch(self);
10+
}
11+
12+
struct Fut<T>(T);
13+
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Fut<T>
14+
where
15+
for<'a> <T as FnOnce<(&'a (),)>>::Output: Future,
16+
{
17+
fn dispatch(self) {
18+
(self.0)(&());
19+
}
20+
}
21+
22+
struct Gen<T>(T);
23+
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Gen<T>
24+
where
25+
for<'a> <T as FnOnce<(&'a (),)>>::Output: Iterator,
26+
{
27+
fn dispatch(self) {
28+
(self.0)(&());
29+
}
30+
}
31+
32+
struct Closure<T>(T);
33+
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Closure<T>
34+
where
35+
for<'a> <T as FnOnce<(&'a (),)>>::Output: Fn<(&'a (),)>,
36+
{
37+
fn dispatch(self) {
38+
(self.0)(&())(&());
39+
}
40+
}
41+
42+
fn main() {
43+
async fn foo(_: &()) {}
44+
Fut(foo).dispatch();
45+
46+
gen fn bar(_: &()) {}
47+
Gen(bar).dispatch();
48+
49+
fn uwu<'a>(x: &'a ()) -> impl Fn(&'a ()) { |_| {} }
50+
Closure(uwu).dispatch();
51+
}

0 commit comments

Comments
 (0)