@@ -465,30 +465,30 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
465465 let ocx = ObligationCtxt :: new ( infcx) ;
466466
467467 let norm_cause = ObligationCause :: misc ( return_span, impl_m_hir_id) ;
468- let impl_return_ty = ocx. normalize (
468+ let impl_sig = ocx. normalize (
469469 norm_cause. clone ( ) ,
470470 param_env,
471- infcx
472- . replace_bound_vars_with_fresh_vars (
473- return_span,
474- infer:: HigherRankedType ,
475- tcx. fn_sig ( impl_m. def_id ) ,
476- )
477- . output ( ) ,
471+ infcx. replace_bound_vars_with_fresh_vars (
472+ return_span,
473+ infer:: HigherRankedType ,
474+ tcx. fn_sig ( impl_m. def_id ) ,
475+ ) ,
478476 ) ;
477+ let impl_return_ty = impl_sig. output ( ) ;
479478
480479 let mut collector = ImplTraitInTraitCollector :: new ( & ocx, return_span, param_env, impl_m_hir_id) ;
481- let unnormalized_trait_return_ty = tcx
480+ let unnormalized_trait_sig = tcx
482481 . liberate_late_bound_regions (
483482 impl_m. def_id ,
484483 tcx. bound_fn_sig ( trait_m. def_id ) . subst ( tcx, trait_to_placeholder_substs) ,
485484 )
486- . output ( )
487485 . fold_with ( & mut collector) ;
488- let trait_return_ty =
489- ocx . normalize ( norm_cause . clone ( ) , param_env , unnormalized_trait_return_ty ) ;
486+ let trait_sig = ocx . normalize ( norm_cause . clone ( ) , param_env , unnormalized_trait_sig ) ;
487+ let trait_return_ty = trait_sig . output ( ) ;
490488
491- let wf_tys = FxHashSet :: from_iter ( [ unnormalized_trait_return_ty, trait_return_ty] ) ;
489+ let wf_tys = FxHashSet :: from_iter (
490+ unnormalized_trait_sig. inputs_and_output . iter ( ) . chain ( trait_sig. inputs_and_output . iter ( ) ) ,
491+ ) ;
492492
493493 match infcx. at ( & cause, param_env) . eq ( trait_return_ty, impl_return_ty) {
494494 Ok ( infer:: InferOk { value : ( ) , obligations } ) => {
@@ -521,6 +521,26 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
521521 }
522522 }
523523
524+ // Unify the whole function signature. We need to do this to fully infer
525+ // the lifetimes of the return type, but do this after unifying just the
526+ // return types, since we want to avoid duplicating errors from
527+ // `compare_predicate_entailment`.
528+ match infcx
529+ . at ( & cause, param_env)
530+ . eq ( tcx. mk_fn_ptr ( ty:: Binder :: dummy ( trait_sig) ) , tcx. mk_fn_ptr ( ty:: Binder :: dummy ( impl_sig) ) )
531+ {
532+ Ok ( infer:: InferOk { value : ( ) , obligations } ) => {
533+ ocx. register_obligations ( obligations) ;
534+ }
535+ Err ( terr) => {
536+ let guar = tcx. sess . delay_span_bug (
537+ return_span,
538+ format ! ( "could not unify `{trait_sig}` and `{impl_sig}`: {terr:?}" ) ,
539+ ) ;
540+ return Err ( guar) ;
541+ }
542+ }
543+
524544 // Check that all obligations are satisfied by the implementation's
525545 // RPITs.
526546 let errors = ocx. select_all_or_error ( ) ;
0 commit comments