Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rollup of 10 pull requests #138792

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
c58c06a
Return blocks from DropTree::build_mir
bjorn3 Mar 11, 2025
8dc0c0e
Simplify lit_to_mir_constant a bit
bjorn3 Mar 12, 2025
eca391f
Cleanup `LangString::parse`
yotamofek Mar 15, 2025
aa2c24b
uefi: Add OwnedEvent abstraction
Ayush1325 Mar 8, 2025
03ece26
update tests
ZuseZ4 Mar 17, 2025
f5c37c3
[NFC] split up gen_body_helper
ZuseZ4 Mar 17, 2025
f9d0a14
resolve repeated attribute fixme
ZuseZ4 Mar 17, 2025
5f7ff88
[NFC] use outer_normal_attr helper
ZuseZ4 Mar 17, 2025
f4c2978
[NFC] extract autodiff call lowering in cg_llvm into own function
ZuseZ4 Mar 17, 2025
47c07ed
[NFC] simplify matching
ZuseZ4 Mar 17, 2025
bcac931
Update test for SGX now implementing read_buf
thaliaarchi Mar 18, 2025
69a3ad0
Add `MutMirVisitor`
makai410 Mar 18, 2025
ad315f6
Add test for `MutMirVisitor`
makai410 Mar 18, 2025
81b2d55
addressing feedback, removing unused arg
ZuseZ4 Mar 18, 2025
526a689
std: fs: uefi: Implement canonicalize
Ayush1325 Mar 18, 2025
456eedc
std: fs: uefi: Make lstat call stat
Ayush1325 Mar 18, 2025
80229a2
std: fs: uefi: Implement OpenOptions
Ayush1325 Mar 18, 2025
a3669b8
updated compiler tests for rustc_intrinsic'
BLANKatGITHUB Mar 12, 2025
9b88fd0
Update GCC submodule
Kobzol Mar 19, 2025
485c14f
Make `crate_hash` not iterate over `hir_crate` owners anymore
oli-obk Mar 19, 2025
6d0c404
Rollup merge of #138236 - Ayush1325:uefi-event, r=petrochenkov
Kobzol Mar 21, 2025
0ea049b
Rollup merge of #138364 - BLANKatGITHUB:compiler, r=RalfJung
Kobzol Mar 21, 2025
bec1b11
Rollup merge of #138410 - bjorn3:misc_cleanups, r=compiler-errors
Kobzol Mar 21, 2025
0e34454
Rollup merge of #138535 - yotamofek:pr/rustdoc/lang-string-parse-clea…
Kobzol Mar 21, 2025
ee2615f
Rollup merge of #138536 - makai410:mut-mir-visitor, r=celinval
Kobzol Mar 21, 2025
3d6105b
Rollup merge of #138627 - EnzymeAD:autodiff-cleanups, r=oli-obk
Kobzol Mar 21, 2025
4787eed
Rollup merge of #138631 - thaliaarchi:sgx-read-buf-test, r=workingjub…
Kobzol Mar 21, 2025
79f5a8c
Rollup merge of #138662 - Ayush1325:uefi-fs-1, r=nicholasbishop,petro…
Kobzol Mar 21, 2025
b88c7f9
Rollup merge of #138709 - Kobzol:update-gcc, r=GuillaumeGomez
Kobzol Mar 21, 2025
d5b4a5f
Rollup merge of #138750 - oli-obk:decouple-hir-queries, r=fee1-dead
Kobzol Mar 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 95 additions & 58 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ mod llvm_enzyme {

use crate::errors;

pub(crate) fn outer_normal_attr(
kind: &P<rustc_ast::NormalAttr>,
id: rustc_ast::AttrId,
span: Span,
) -> rustc_ast::Attribute {
let style = rustc_ast::AttrStyle::Outer;
let kind = rustc_ast::AttrKind::Normal(kind.clone());
rustc_ast::Attribute { kind, id, style, span }
}

// If we have a default `()` return type or explicitley `()` return type,
// then we often can skip doing some work.
fn has_ret(ty: &FnRetTy) -> bool {
Expand Down Expand Up @@ -224,20 +234,8 @@ mod llvm_enzyme {
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
.count() as u32;
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
let new_decl_span = d_sig.span;
let d_body = gen_enzyme_body(
ecx,
&x,
n_active,
&sig,
&d_sig,
primal,
&new_args,
span,
sig_span,
new_decl_span,
idents,
errored,
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
);
let d_ident = first_ident(&meta_item_vec[0]);

Expand Down Expand Up @@ -270,36 +268,39 @@ mod llvm_enzyme {
};
let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let attr: ast::Attribute = ast::Attribute {
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
id: new_id,
style: ast::AttrStyle::Outer,
span,
};
let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let inline_never: ast::Attribute = ast::Attribute {
kind: ast::AttrKind::Normal(inline_never_attr),
id: new_id,
style: ast::AttrStyle::Outer,
span,
};
let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);

// We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
match (attr, item) {
(ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
let a = &a.item.path;
let b = &b.item.path;
a.segments.len() == b.segments.len()
&& a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
}
_ => false,
}
}

// Don't add it multiple times:
let orig_annotatable: Annotatable = match item {
Annotatable::Item(ref mut iitem) => {
if !iitem.attrs.iter().any(|a| a.id == attr.id) {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
iitem.attrs.push(attr);
}
if !iitem.attrs.iter().any(|a| a.id == inline_never.id) {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
iitem.attrs.push(inline_never.clone());
}
Annotatable::Item(iitem.clone())
}
Annotatable::AssocItem(ref mut assoc_item, i @ Impl) => {
if !assoc_item.attrs.iter().any(|a| a.id == attr.id) {
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
assoc_item.attrs.push(attr);
}
if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) {
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
assoc_item.attrs.push(inline_never.clone());
}
Annotatable::AssocItem(assoc_item.clone(), i)
Expand All @@ -314,13 +315,7 @@ mod llvm_enzyme {
delim: rustc_ast::token::Delimiter::Parenthesis,
tokens: ts,
});
let d_attr: ast::Attribute = ast::Attribute {
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
id: new_id,
style: ast::AttrStyle::Outer,
span,
};

let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
let d_annotatable = if is_impl {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
let d_fn = P(ast::AssocItem {
Expand Down Expand Up @@ -361,30 +356,27 @@ mod llvm_enzyme {
ty
}

/// We only want this function to type-check, since we will replace the body
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
/// so instead we build something that should pass. We also add a inline_asm
/// line, as one more barrier for rustc to prevent inlining of this function.
/// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
/// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
/// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
/// this function (which should never happen, since it is only a placeholder).
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
/// from optimizing any arguments away.
fn gen_enzyme_body(
// Will generate a body of the type:
// ```
// {
// unsafe {
// asm!("NOP");
// }
// ::core::hint::black_box(primal(args));
// ::core::hint::black_box((args, ret));
// <This part remains to be done by following function>
// }
// ```
fn init_body_helper(
ecx: &ExtCtxt<'_>,
x: &AutoDiffAttrs,
n_active: u32,
sig: &ast::FnSig,
d_sig: &ast::FnSig,
span: Span,
primal: Ident,
new_names: &[String],
span: Span,
sig_span: Span,
new_decl_span: Span,
idents: Vec<Ident>,
idents: &[Ident],
errored: bool,
) -> P<ast::Block> {
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
let noop = ast::InlineAsm {
asm_macro: ast::AsmMacro::Asm,
Expand Down Expand Up @@ -433,6 +425,51 @@ mod llvm_enzyme {
}
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));

(body, primal_call, black_box_primal_call, blackbox_call_expr)
}

/// We only want this function to type-check, since we will replace the body
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
/// so instead we manually build something that should pass the type checker.
/// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
/// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
/// bug would ever try to accidentially differentiate this placeholder function body.
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
/// from optimizing any arguments away.
fn gen_enzyme_body(
ecx: &ExtCtxt<'_>,
x: &AutoDiffAttrs,
n_active: u32,
sig: &ast::FnSig,
d_sig: &ast::FnSig,
primal: Ident,
new_names: &[String],
span: Span,
sig_span: Span,
idents: Vec<Ident>,
errored: bool,
) -> P<ast::Block> {
let new_decl_span = d_sig.span;

// Just adding some default inline-asm and black_box usages to prevent early inlining
// and optimizations which alter the function signature.
//
// The bb_primal_call is the black_box call of the primal function. We keep it around,
// since it has the convenient property of returning the type of the primal function,
// Remember, we only care to match types here.
// No matter which return we pick, we always wrap it into a std::hint::black_box call,
// to prevent rustc from propagating it into the caller.
let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
ecx,
span,
primal,
new_names,
sig_span,
new_decl_span,
&idents,
errored,
);

if !has_ret(&d_sig.decl.output) {
// there is no return type that we have to match, () works fine.
return body;
Expand All @@ -444,7 +481,7 @@ mod llvm_enzyme {

if primal_ret && n_active == 0 && x.mode.is_rev() {
// We only have the primal ret.
body.stmts.push(ecx.stmt_expr(black_box_primal_call));
body.stmts.push(ecx.stmt_expr(bb_primal_call));
return body;
}

Expand Down Expand Up @@ -536,11 +573,11 @@ mod llvm_enzyme {
return body;
}
[arg] => {
ret = ecx.expr_call(new_decl_span, blackbox_call_expr, thin_vec![arg.clone()]);
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
}
args => {
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
ret = ecx.expr_call(new_decl_span, blackbox_call_expr, thin_vec![ret_tuple]);
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
}
}
assert!(has_ret(&d_sig.decl.output));
Expand All @@ -553,7 +590,7 @@ mod llvm_enzyme {
ecx: &ExtCtxt<'_>,
span: Span,
primal: Ident,
idents: Vec<Ident>,
idents: &[Ident],
) -> P<ast::Expr> {
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
if has_self {
Expand Down
Loading
Loading