@@ -465,30 +465,30 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
465465 let ocx = ObligationCtxt :: new ( infcx) ;
466466
467467 let norm_cause = ObligationCause :: misc ( return_span, impl_m_hir_id) ;
468- let impl_return_ty = ocx. normalize (
468+ let impl_sig = ocx. normalize (
469469 norm_cause. clone ( ) ,
470470 param_env,
471- infcx
472- . replace_bound_vars_with_fresh_vars (
473- return_span,
474- infer:: HigherRankedType ,
475- tcx. fn_sig ( impl_m. def_id ) ,
476- )
477- . output ( ) ,
471+ infcx. replace_bound_vars_with_fresh_vars (
472+ return_span,
473+ infer:: HigherRankedType ,
474+ tcx. fn_sig ( impl_m. def_id ) ,
475+ ) ,
478476 ) ;
477+ let impl_return_ty = impl_sig. output ( ) ;
479478
480479 let mut collector = ImplTraitInTraitCollector :: new ( & ocx, return_span, param_env, impl_m_hir_id) ;
481- let unnormalized_trait_return_ty = tcx
480+ let unnormalized_trait_sig = tcx
482481 . liberate_late_bound_regions (
483482 impl_m. def_id ,
484483 tcx. bound_fn_sig ( trait_m. def_id ) . subst ( tcx, trait_to_placeholder_substs) ,
485484 )
486- . output ( )
487485 . fold_with ( & mut collector) ;
488- let trait_return_ty =
489- ocx . normalize ( norm_cause . clone ( ) , param_env , unnormalized_trait_return_ty ) ;
486+ let trait_sig = ocx . normalize ( norm_cause . clone ( ) , param_env , unnormalized_trait_sig ) ;
487+ let trait_return_ty = trait_sig . output ( ) ;
490488
491- let wf_tys = FxHashSet :: from_iter ( [ unnormalized_trait_return_ty, trait_return_ty] ) ;
489+ let wf_tys = FxHashSet :: from_iter (
490+ unnormalized_trait_sig. inputs_and_output . iter ( ) . chain ( trait_sig. inputs_and_output . iter ( ) ) ,
491+ ) ;
492492
493493 match infcx. at ( & cause, param_env) . eq ( trait_return_ty, impl_return_ty) {
494494 Ok ( infer:: InferOk { value : ( ) , obligations } ) => {
@@ -521,6 +521,26 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
521521 }
522522 }
523523
524+ // Unify the whole function signature. We need to do this to fully infer
525+ // the lifetimes of the return type, but do this after unifying just the
526+ // return types, since we want to avoid duplicating errors from
527+ // `compare_predicate_entailment`.
528+ match infcx
529+ . at ( & cause, param_env)
530+ . eq ( tcx. mk_fn_ptr ( ty:: Binder :: dummy ( trait_sig) ) , tcx. mk_fn_ptr ( ty:: Binder :: dummy ( impl_sig) ) )
531+ {
532+ Ok ( infer:: InferOk { value : ( ) , obligations } ) => {
533+ ocx. register_obligations ( obligations) ;
534+ }
535+ Err ( terr) => {
536+ let guar = tcx. sess . delay_span_bug (
537+ return_span,
538+ format ! ( "could not unify `{trait_sig}` and `{impl_sig}`: {terr:?}" ) ,
539+ ) ;
540+ return Err ( guar) ;
541+ }
542+ }
543+
524544 // Check that all obligations are satisfied by the implementation's
525545 // RPITs.
526546 let errors = ocx. select_all_or_error ( ) ;
@@ -551,15 +571,40 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
551571 let id_substs = InternalSubsts :: identity_for_item ( tcx, def_id) ;
552572 debug ! ( ?id_substs, ?substs) ;
553573 let map: FxHashMap < ty:: GenericArg < ' tcx > , ty:: GenericArg < ' tcx > > =
554- substs . iter ( ) . enumerate ( ) . map ( | ( index , arg ) | ( arg , id_substs[ index ] ) ) . collect ( ) ;
574+ std :: iter:: zip ( substs , id_substs) . collect ( ) ;
555575 debug ! ( ?map) ;
556576
577+ // NOTE(compiler-errors): RPITITs, like all other RPITs, have early-bound
578+ // region substs that are synthesized during AST lowering. These are substs
579+ // that are appended to the parent substs (trait and trait method). However,
580+ // we're trying to infer the unsubstituted type value of the RPITIT inside
581+ // the *impl*, so we can later use the impl's method substs to normalize
582+ // an RPITIT to a concrete type (`confirm_impl_trait_in_trait_candidate`).
583+ //
584+ // Due to the design of RPITITs, during AST lowering, we have no idea that
585+ // an impl method corresponds to a trait method with RPITITs in it. Therefore,
586+ // we don't have a list of early-bound region substs for the RPITIT in the impl.
587+ // Since early region parameters are index-based, we can't just rebase these
588+ // (trait method) early-bound region substs onto the impl, and there's no
589+ // guarantee that the indices from the trait substs and impl substs line up.
590+ // So to fix this, we subtract the number of trait substs and add the number of
591+ // impl substs to *renumber* these early-bound regions to their corresponding
592+ // indices in the impl's substitutions list.
593+ //
594+ // Also, we only need to account for a difference in trait and impl substs,
595+ // since we previously enforce that the trait method and impl method have the
596+ // same generics.
597+ let num_trait_substs = trait_to_impl_substs. len ( ) ;
598+ let num_impl_substs = tcx. generics_of ( impl_m. container_id ( tcx) ) . params . len ( ) ;
557599 let ty = tcx. fold_regions ( ty, |region, _| {
558- if let ty:: ReFree ( _) = region. kind ( ) {
559- map[ & region. into ( ) ] . expect_region ( )
560- } else {
561- region
562- }
600+ let ty:: ReFree ( _) = region. kind ( ) else { return region; } ;
601+ let ty:: ReEarlyBound ( e) = map[ & region. into ( ) ] . expect_region ( ) . kind ( )
602+ else { bug ! ( "expected ReFree to map to ReEarlyBound" ) ; } ;
603+ tcx. mk_region ( ty:: ReEarlyBound ( ty:: EarlyBoundRegion {
604+ def_id : e. def_id ,
605+ name : e. name ,
606+ index : ( e. index as usize - num_trait_substs + num_impl_substs) as u32 ,
607+ } ) )
563608 } ) ;
564609 debug ! ( %ty) ;
565610 collected_tys. insert ( def_id, ty) ;
0 commit comments