Skip to content

Commit f5fca8b

Browse files
committed
Fix generator inlining by checking for rust-call abi and spread arg
1 parent 4d6afcb commit f5fca8b

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

compiler/rustc_mir/src/transform/inline.rs

+14-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use rustc_index::vec::Idx;
77
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
88
use rustc_middle::mir::visit::*;
99
use rustc_middle::mir::*;
10+
use rustc_middle::ty::subst::Subst;
1011
use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
1112
use rustc_span::{hygiene::ExpnKind, ExpnData, Span};
1213
use rustc_target::spec::abi::Abi;
@@ -28,6 +29,7 @@ pub struct Inline;
2829
#[derive(Copy, Clone, Debug)]
2930
struct CallSite<'tcx> {
3031
callee: Instance<'tcx>,
32+
fn_sig: ty::PolyFnSig<'tcx>,
3133
block: BasicBlock,
3234
target: Option<BasicBlock>,
3335
source_info: SourceInfo,
@@ -173,22 +175,23 @@ impl Inliner<'tcx> {
173175

174176
// Only consider direct calls to functions
175177
let terminator = bb_data.terminator();
176-
if let TerminatorKind::Call { func: ref op, ref destination, .. } = terminator.kind {
177-
if let ty::FnDef(callee_def_id, substs) = *op.ty(caller_body, self.tcx).kind() {
178-
// To resolve an instance its substs have to be fully normalized, so
179-
// we do this here.
180-
let normalized_substs = self.tcx.normalize_erasing_regions(self.param_env, substs);
178+
if let TerminatorKind::Call { ref func, ref destination, .. } = terminator.kind {
179+
let func_ty = func.ty(caller_body, self.tcx);
180+
if let ty::FnDef(def_id, substs) = *func_ty.kind() {
181+
// To resolve an instance its substs have to be fully normalized.
182+
let substs = self.tcx.normalize_erasing_regions(self.param_env, substs);
181183
let callee =
182-
Instance::resolve(self.tcx, self.param_env, callee_def_id, normalized_substs)
183-
.ok()
184-
.flatten()?;
184+
Instance::resolve(self.tcx, self.param_env, def_id, substs).ok().flatten()?;
185185

186186
if let InstanceDef::Virtual(..) | InstanceDef::Intrinsic(_) = callee.def {
187187
return None;
188188
}
189189

190+
let fn_sig = self.tcx.fn_sig(def_id).subst(self.tcx, substs);
191+
190192
return Some(CallSite {
191193
callee,
194+
fn_sig,
192195
block: bb,
193196
target: destination.map(|(_, target)| target),
194197
source_info: terminator.source_info,
@@ -437,7 +440,7 @@ impl Inliner<'tcx> {
437440
};
438441

439442
// Copy the arguments if needed.
440-
let args: Vec<_> = self.make_call_args(args, &callsite, caller_body);
443+
let args: Vec<_> = self.make_call_args(args, &callsite, caller_body, &callee_body);
441444

442445
let mut integrator = Integrator {
443446
args: &args,
@@ -518,6 +521,7 @@ impl Inliner<'tcx> {
518521
args: Vec<Operand<'tcx>>,
519522
callsite: &CallSite<'tcx>,
520523
caller_body: &mut Body<'tcx>,
524+
callee_body: &Body<'tcx>,
521525
) -> Vec<Local> {
522526
let tcx = self.tcx;
523527

@@ -544,9 +548,7 @@ impl Inliner<'tcx> {
544548
// tmp2 = tuple_tmp.2
545549
//
546550
// and the vector is `[closure_ref, tmp0, tmp1, tmp2]`.
547-
// FIXME(eddyb) make this check for `"rust-call"` ABI combined with
548-
// `callee_body.spread_arg == None`, instead of special-casing closures.
549-
if tcx.is_closure(callsite.callee.def_id()) {
551+
if callsite.fn_sig.abi() == Abi::RustCall && callee_body.spread_arg.is_none() {
550552
let mut args = args.into_iter();
551553
let self_ = self.create_temp_if_necessary(args.next().unwrap(), callsite, caller_body);
552554
let tuple = self.create_temp_if_necessary(args.next().unwrap(), callsite, caller_body);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#![feature(generators, generator_trait)]
2+
3+
use std::ops::Generator;
4+
use std::pin::Pin;
5+
6+
// EMIT_MIR inline_generator.main.Inline.diff
7+
fn main() {
8+
let _r = Pin::new(&mut g()).resume(false);
9+
}
10+
11+
#[inline(always)]
12+
pub fn g() -> impl Generator<bool> {
13+
#[inline(always)]
14+
|a| { yield if a { 7 } else { 13 } }
15+
}

0 commit comments

Comments
 (0)