Skip to content

Commit 6548485

Browse files
committed
Lower return types for gen fn to impl Iterator
1 parent 869a494 commit 6548485

File tree

7 files changed

+167
-80
lines changed

7 files changed

+167
-80
lines changed

compiler/rustc_ast_lowering/src/item.rs

Lines changed: 100 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::FnReturnTransformation;
2+
13
use super::errors::{InvalidAbi, InvalidAbiReason, InvalidAbiSuggestion, MisplacedRelaxTraitBound};
24
use super::ResolverAstLoweringExt;
35
use super::{AstOwner, ImplTraitContext, ImplTraitPosition};
@@ -208,13 +210,33 @@ impl<'hir> LoweringContext<'_, 'hir> {
208210
// only cares about the input argument patterns in the function
209211
// declaration (decl), not the return types.
210212
let asyncness = header.asyncness;
211-
let body_id =
212-
this.lower_maybe_async_body(span, hir_id, decl, asyncness, body.as_deref());
213+
let genness = header.genness;
214+
let body_id = this.lower_maybe_coroutine_body(
215+
span,
216+
hir_id,
217+
decl,
218+
asyncness,
219+
genness,
220+
body.as_deref(),
221+
);
213222

214223
let itctx = ImplTraitContext::Universal;
215224
let (generics, decl) =
216225
this.lower_generics(generics, header.constness, id, &itctx, |this| {
217-
let ret_id = asyncness.opt_return_id();
226+
let ret_id = asyncness
227+
.opt_return_id()
228+
.map(|(node_id, span)| {
229+
crate::FnReturnTransformation::Async(node_id, span)
230+
})
231+
.or_else(|| match genness {
232+
Gen::Yes { span, closure_id: _, return_impl_trait_id } => {
233+
Some(crate::FnReturnTransformation::Iterator(
234+
return_impl_trait_id,
235+
span,
236+
))
237+
}
238+
_ => None,
239+
});
218240
this.lower_fn_decl(decl, id, *fn_sig_span, FnDeclKind::Fn, ret_id)
219241
});
220242
let sig = hir::FnSig {
@@ -733,20 +755,31 @@ impl<'hir> LoweringContext<'_, 'hir> {
733755
sig,
734756
i.id,
735757
FnDeclKind::Trait,
736-
asyncness.opt_return_id(),
758+
asyncness
759+
.opt_return_id()
760+
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
737761
);
738762
(generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Required(names)), false)
739763
}
740764
AssocItemKind::Fn(box Fn { sig, generics, body: Some(body), .. }) => {
741765
let asyncness = sig.header.asyncness;
742-
let body_id =
743-
self.lower_maybe_async_body(i.span, hir_id, &sig.decl, asyncness, Some(body));
766+
let genness = sig.header.genness;
767+
let body_id = self.lower_maybe_coroutine_body(
768+
i.span,
769+
hir_id,
770+
&sig.decl,
771+
asyncness,
772+
genness,
773+
Some(body),
774+
);
744775
let (generics, sig) = self.lower_method_sig(
745776
generics,
746777
sig,
747778
i.id,
748779
FnDeclKind::Trait,
749-
asyncness.opt_return_id(),
780+
asyncness
781+
.opt_return_id()
782+
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
750783
);
751784
(generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Provided(body_id)), true)
752785
}
@@ -836,19 +869,23 @@ impl<'hir> LoweringContext<'_, 'hir> {
836869
),
837870
AssocItemKind::Fn(box Fn { sig, generics, body, .. }) => {
838871
let asyncness = sig.header.asyncness;
839-
let body_id = self.lower_maybe_async_body(
872+
let genness = sig.header.genness;
873+
let body_id = self.lower_maybe_coroutine_body(
840874
i.span,
841875
hir_id,
842876
&sig.decl,
843877
asyncness,
878+
genness,
844879
body.as_deref(),
845880
);
846881
let (generics, sig) = self.lower_method_sig(
847882
generics,
848883
sig,
849884
i.id,
850885
if self.is_in_trait_impl { FnDeclKind::Impl } else { FnDeclKind::Inherent },
851-
asyncness.opt_return_id(),
886+
asyncness
887+
.opt_return_id()
888+
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
852889
);
853890

854891
(generics, hir::ImplItemKind::Fn(sig, body_id))
@@ -1012,16 +1049,22 @@ impl<'hir> LoweringContext<'_, 'hir> {
10121049
})
10131050
}
10141051

1015-
fn lower_maybe_async_body(
1052+
/// Takes what may be the body of an `async fn` or a `gen fn` and wraps it in an `async {}` or
1053+
/// `gen {}` block as appropriate.
1054+
fn lower_maybe_coroutine_body(
10161055
&mut self,
10171056
span: Span,
10181057
fn_id: hir::HirId,
10191058
decl: &FnDecl,
10201059
asyncness: Async,
1060+
genness: Gen,
10211061
body: Option<&Block>,
10221062
) -> hir::BodyId {
1023-
let (closure_id, body) = match (asyncness, body) {
1024-
(Async::Yes { closure_id, .. }, Some(body)) => (closure_id, body),
1063+
let (closure_id, body) = match (asyncness, genness, body) {
1064+
// FIXME(eholk): do something reasonable for `async gen fn`. Probably that's an error
1065+
// for now since it's not supported.
1066+
(Async::Yes { closure_id, .. }, _, Some(body))
1067+
| (_, Gen::Yes { closure_id, .. }, Some(body)) => (closure_id, body),
10251068
_ => return self.lower_fn_body_block(span, decl, body),
10261069
};
10271070

@@ -1164,44 +1207,55 @@ impl<'hir> LoweringContext<'_, 'hir> {
11641207
parameters.push(new_parameter);
11651208
}
11661209

1167-
let async_expr = this.make_async_expr(
1168-
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
1169-
closure_id,
1170-
None,
1171-
body.span,
1172-
hir::CoroutineSource::Fn,
1173-
|this| {
1174-
// Create a block from the user's function body:
1175-
let user_body = this.lower_block_expr(body);
1210+
let mkbody = |this: &mut LoweringContext<'_, 'hir>| {
1211+
// Create a block from the user's function body:
1212+
let user_body = this.lower_block_expr(body);
11761213

1177-
// Transform into `drop-temps { <user-body> }`, an expression:
1178-
let desugared_span =
1179-
this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None);
1180-
let user_body =
1181-
this.expr_drop_temps(desugared_span, this.arena.alloc(user_body));
1214+
// Transform into `drop-temps { <user-body> }`, an expression:
1215+
let desugared_span =
1216+
this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None);
1217+
let user_body = this.expr_drop_temps(desugared_span, this.arena.alloc(user_body));
11821218

1183-
// As noted above, create the final block like
1184-
//
1185-
// ```
1186-
// {
1187-
// let $param_pattern = $raw_param;
1188-
// ...
1189-
// drop-temps { <user-body> }
1190-
// }
1191-
// ```
1192-
let body = this.block_all(
1193-
desugared_span,
1194-
this.arena.alloc_from_iter(statements),
1195-
Some(user_body),
1196-
);
1219+
// As noted above, create the final block like
1220+
//
1221+
// ```
1222+
// {
1223+
// let $param_pattern = $raw_param;
1224+
// ...
1225+
// drop-temps { <user-body> }
1226+
// }
1227+
// ```
1228+
let body = this.block_all(
1229+
desugared_span,
1230+
this.arena.alloc_from_iter(statements),
1231+
Some(user_body),
1232+
);
11971233

1198-
this.expr_block(body)
1199-
},
1200-
);
1234+
this.expr_block(body)
1235+
};
1236+
let coroutine_expr = match (asyncness, genness) {
1237+
(Async::Yes { .. }, _) => this.make_async_expr(
1238+
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
1239+
closure_id,
1240+
None,
1241+
body.span,
1242+
hir::CoroutineSource::Fn,
1243+
mkbody,
1244+
),
1245+
(_, Gen::Yes { .. }) => this.make_gen_expr(
1246+
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
1247+
closure_id,
1248+
None,
1249+
body.span,
1250+
hir::CoroutineSource::Fn,
1251+
mkbody,
1252+
),
1253+
_ => unreachable!("we must have either an async fn or a gen fn"),
1254+
};
12011255

12021256
let hir_id = this.lower_node_id(closure_id);
12031257
this.maybe_forward_track_caller(body.span, fn_id, hir_id);
1204-
let expr = hir::Expr { hir_id, kind: async_expr, span: this.lower_span(body.span) };
1258+
let expr = hir::Expr { hir_id, kind: coroutine_expr, span: this.lower_span(body.span) };
12051259

12061260
(this.arena.alloc_from_iter(parameters), expr)
12071261
})
@@ -1213,13 +1267,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
12131267
sig: &FnSig,
12141268
id: NodeId,
12151269
kind: FnDeclKind,
1216-
is_async: Option<(NodeId, Span)>,
1270+
transform_return_type: Option<FnReturnTransformation>,
12171271
) -> (&'hir hir::Generics<'hir>, hir::FnSig<'hir>) {
12181272
let header = self.lower_fn_header(sig.header);
12191273
let itctx = ImplTraitContext::Universal;
12201274
let (generics, decl) =
12211275
self.lower_generics(generics, sig.header.constness, id, &itctx, |this| {
1222-
this.lower_fn_decl(&sig.decl, id, sig.span, kind, is_async)
1276+
this.lower_fn_decl(&sig.decl, id, sig.span, kind, transform_return_type)
12231277
});
12241278
(generics, hir::FnSig { header, decl, span: self.lower_span(sig.span) })
12251279
}

compiler/rustc_ast_lowering/src/lib.rs

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,21 @@ enum ParenthesizedGenericArgs {
494494
Err,
495495
}
496496

497+
/// Describes a return type transformation that can be performed by `LoweringContext::lower_fn_decl`
498+
#[derive(Debug)]
499+
enum FnReturnTransformation {
500+
/// Replaces a return type `T` with `impl Future<Output = T>`.
501+
///
502+
/// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the
503+
/// `async` keyword.
504+
Async(NodeId, Span),
505+
/// Replaces a return type `T` with `impl Iterator<Item = T>`.
506+
///
507+
/// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the
508+
/// `gen` keyword.
509+
Iterator(NodeId, Span),
510+
}
511+
497512
impl<'a, 'hir> LoweringContext<'a, 'hir> {
498513
fn create_def(
499514
&mut self,
@@ -1783,21 +1798,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
17831798
}))
17841799
}
17851800

1786-
// Lowers a function declaration.
1787-
//
1788-
// `decl`: the unlowered (AST) function declaration.
1789-
// `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given `NodeId`.
1790-
// `make_ret_async`: if `Some`, converts `-> T` into `-> impl Future<Output = T>` in the
1791-
// return type. This is used for `async fn` declarations. The `NodeId` is the ID of the
1792-
// return type `impl Trait` item, and the `Span` points to the `async` keyword.
1801+
/// Lowers a function declaration.
1802+
///
1803+
/// `decl`: the unlowered (AST) function declaration.
1804+
///
1805+
/// `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given
1806+
/// `NodeId`.
1807+
///
1808+
/// `transform_return_type`: if `Some`, applies some conversion to the return type, such as is
1809+
/// needed for `async fn` and `gen fn`. See [`FnReturnTransformation`] for more details.
17931810
#[instrument(level = "debug", skip(self))]
17941811
fn lower_fn_decl(
17951812
&mut self,
17961813
decl: &FnDecl,
17971814
fn_node_id: NodeId,
17981815
fn_span: Span,
17991816
kind: FnDeclKind,
1800-
make_ret_async: Option<(NodeId, Span)>,
1817+
transform_return_type: Option<FnReturnTransformation>,
18011818
) -> &'hir hir::FnDecl<'hir> {
18021819
let c_variadic = decl.c_variadic();
18031820

@@ -1826,11 +1843,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
18261843
self.lower_ty_direct(&param.ty, &itctx)
18271844
}));
18281845

1829-
let output = if let Some((ret_id, _span)) = make_ret_async {
1830-
let fn_def_id = self.local_def_id(fn_node_id);
1831-
self.lower_async_fn_ret_ty(&decl.output, fn_def_id, ret_id, kind, fn_span)
1832-
} else {
1833-
match &decl.output {
1846+
let output = match transform_return_type {
1847+
Some(transform) => {
1848+
let fn_def_id = self.local_def_id(fn_node_id);
1849+
self.lower_coroutine_fn_ret_ty(&decl.output, fn_def_id, transform, kind, fn_span)
1850+
}
1851+
None => match &decl.output {
18341852
FnRetTy::Ty(ty) => {
18351853
let context = if kind.return_impl_trait_allowed() {
18361854
let fn_def_id = self.local_def_id(fn_node_id);
@@ -1854,7 +1872,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
18541872
hir::FnRetTy::Return(self.lower_ty(ty, &context))
18551873
}
18561874
FnRetTy::Default(span) => hir::FnRetTy::DefaultReturn(self.lower_span(*span)),
1857-
}
1875+
},
18581876
};
18591877

18601878
self.arena.alloc(hir::FnDecl {
@@ -1893,17 +1911,22 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
18931911
// `fn_node_id`: `NodeId` of the parent function (used to create child impl trait definition)
18941912
// `opaque_ty_node_id`: `NodeId` of the opaque `impl Trait` type that should be created
18951913
#[instrument(level = "debug", skip(self))]
1896-
fn lower_async_fn_ret_ty(
1914+
fn lower_coroutine_fn_ret_ty(
18971915
&mut self,
18981916
output: &FnRetTy,
18991917
fn_def_id: LocalDefId,
1900-
opaque_ty_node_id: NodeId,
1918+
transform: FnReturnTransformation,
19011919
fn_kind: FnDeclKind,
19021920
fn_span: Span,
19031921
) -> hir::FnRetTy<'hir> {
19041922
let span = self.lower_span(fn_span);
19051923
let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::Async, span, None);
19061924

1925+
let opaque_ty_node_id = match transform {
1926+
FnReturnTransformation::Async(opaque_ty_node_id, _)
1927+
| FnReturnTransformation::Iterator(opaque_ty_node_id, _) => opaque_ty_node_id,
1928+
};
1929+
19071930
let captured_lifetimes: Vec<_> = self
19081931
.resolver
19091932
.take_extra_lifetime_params(opaque_ty_node_id)
@@ -1919,8 +1942,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
19191942
span,
19201943
opaque_ty_span,
19211944
|this| {
1922-
let future_bound = this.lower_async_fn_output_type_to_future_bound(
1945+
let future_bound = this.lower_coroutine_fn_output_type_to_future_bound(
19231946
output,
1947+
transform,
19241948
span,
19251949
ImplTraitContext::ReturnPositionOpaqueTy {
19261950
origin: hir::OpaqueTyOrigin::FnReturn(fn_def_id),
@@ -1936,9 +1960,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
19361960
}
19371961

19381962
/// Transforms `-> T` into `Future<Output = T>`.
1939-
fn lower_async_fn_output_type_to_future_bound(
1963+
fn lower_coroutine_fn_output_type_to_future_bound(
19401964
&mut self,
19411965
output: &FnRetTy,
1966+
transform: FnReturnTransformation,
19421967
span: Span,
19431968
nested_impl_trait_context: ImplTraitContext,
19441969
) -> hir::GenericBound<'hir> {
@@ -1953,17 +1978,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
19531978
FnRetTy::Default(ret_ty_span) => self.arena.alloc(self.ty_tup(*ret_ty_span, &[])),
19541979
};
19551980

1956-
// "<Output = T>"
1981+
// "<Output|Item = T>"
1982+
let (symbol, lang_item) = match transform {
1983+
FnReturnTransformation::Async(..) => (hir::FN_OUTPUT_NAME, hir::LangItem::Future),
1984+
FnReturnTransformation::Iterator(..) => {
1985+
(hir::ITERATOR_ITEM_NAME, hir::LangItem::Iterator)
1986+
}
1987+
};
1988+
19571989
let future_args = self.arena.alloc(hir::GenericArgs {
19581990
args: &[],
1959-
bindings: arena_vec![self; self.output_ty_binding(span, output_ty)],
1991+
bindings: arena_vec![self; self.assoc_ty_binding(symbol, span, output_ty)],
19601992
parenthesized: hir::GenericArgsParentheses::No,
19611993
span_ext: DUMMY_SP,
19621994
});
19631995

19641996
hir::GenericBound::LangItemTrait(
1965-
// ::std::future::Future<future_params>
1966-
hir::LangItem::Future,
1997+
lang_item,
19671998
self.lower_span(span),
19681999
self.next_id(),
19692000
future_args,

0 commit comments

Comments
 (0)