@@ -26,6 +26,16 @@ mod llvm_enzyme {
26
26
27
27
use crate :: errors;
28
28
29
+ pub ( crate ) fn outer_normal_attr (
30
+ kind : & P < rustc_ast:: NormalAttr > ,
31
+ id : rustc_ast:: AttrId ,
32
+ span : Span ,
33
+ ) -> rustc_ast:: Attribute {
34
+ let style = rustc_ast:: AttrStyle :: Outer ;
35
+ let kind = rustc_ast:: AttrKind :: Normal ( kind. clone ( ) ) ;
36
+ rustc_ast:: Attribute { kind, id, style, span }
37
+ }
38
+
29
39
// If we have a default `()` return type or explicitley `()` return type,
30
40
// then we often can skip doing some work.
31
41
fn has_ret ( ty : & FnRetTy ) -> bool {
@@ -224,20 +234,8 @@ mod llvm_enzyme {
224
234
. filter ( |a| * * a == DiffActivity :: Active || * * a == DiffActivity :: ActiveOnly )
225
235
. count ( ) as u32 ;
226
236
let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
227
- let new_decl_span = d_sig. span ;
228
237
let d_body = gen_enzyme_body (
229
- ecx,
230
- & x,
231
- n_active,
232
- & sig,
233
- & d_sig,
234
- primal,
235
- & new_args,
236
- span,
237
- sig_span,
238
- new_decl_span,
239
- idents,
240
- errored,
238
+ ecx, & x, n_active, & sig, & d_sig, primal, & new_args, span, sig_span, idents, errored,
241
239
) ;
242
240
let d_ident = first_ident ( & meta_item_vec[ 0 ] ) ;
243
241
@@ -270,36 +268,39 @@ mod llvm_enzyme {
270
268
} ;
271
269
let inline_never_attr = P ( ast:: NormalAttr { item : inline_item, tokens : None } ) ;
272
270
let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
273
- let attr: ast:: Attribute = ast:: Attribute {
274
- kind : ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) ,
275
- id : new_id,
276
- style : ast:: AttrStyle :: Outer ,
277
- span,
278
- } ;
271
+ let attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
279
272
let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
280
- let inline_never: ast:: Attribute = ast:: Attribute {
281
- kind : ast:: AttrKind :: Normal ( inline_never_attr) ,
282
- id : new_id,
283
- style : ast:: AttrStyle :: Outer ,
284
- span,
285
- } ;
273
+ let inline_never = outer_normal_attr ( & inline_never_attr, new_id, span) ;
274
+
275
+ // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
276
+ fn same_attribute ( attr : & ast:: AttrKind , item : & ast:: AttrKind ) -> bool {
277
+ match ( attr, item) {
278
+ ( ast:: AttrKind :: Normal ( a) , ast:: AttrKind :: Normal ( b) ) => {
279
+ let a = & a. item . path ;
280
+ let b = & b. item . path ;
281
+ a. segments . len ( ) == b. segments . len ( )
282
+ && a. segments . iter ( ) . zip ( b. segments . iter ( ) ) . all ( |( a, b) | a. ident == b. ident )
283
+ }
284
+ _ => false ,
285
+ }
286
+ }
286
287
287
288
// Don't add it multiple times:
288
289
let orig_annotatable: Annotatable = match item {
289
290
Annotatable :: Item ( ref mut iitem) => {
290
- if !iitem. attrs . iter ( ) . any ( |a| a . id == attr. id ) {
291
+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind , & attr. kind ) ) {
291
292
iitem. attrs . push ( attr) ;
292
293
}
293
- if !iitem. attrs . iter ( ) . any ( |a| a . id == inline_never. id ) {
294
+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind , & inline_never. kind ) ) {
294
295
iitem. attrs . push ( inline_never. clone ( ) ) ;
295
296
}
296
297
Annotatable :: Item ( iitem. clone ( ) )
297
298
}
298
299
Annotatable :: AssocItem ( ref mut assoc_item, i @ Impl ) => {
299
- if !assoc_item. attrs . iter ( ) . any ( |a| a . id == attr. id ) {
300
+ if !assoc_item. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind , & attr. kind ) ) {
300
301
assoc_item. attrs . push ( attr) ;
301
302
}
302
- if !assoc_item. attrs . iter ( ) . any ( |a| a . id == inline_never. id ) {
303
+ if !assoc_item. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind , & inline_never. kind ) ) {
303
304
assoc_item. attrs . push ( inline_never. clone ( ) ) ;
304
305
}
305
306
Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
@@ -314,13 +315,7 @@ mod llvm_enzyme {
314
315
delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
315
316
tokens : ts,
316
317
} ) ;
317
- let d_attr: ast:: Attribute = ast:: Attribute {
318
- kind : ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) ,
319
- id : new_id,
320
- style : ast:: AttrStyle :: Outer ,
321
- span,
322
- } ;
323
-
318
+ let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
324
319
let d_annotatable = if is_impl {
325
320
let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
326
321
let d_fn = P ( ast:: AssocItem {
@@ -361,30 +356,27 @@ mod llvm_enzyme {
361
356
ty
362
357
}
363
358
364
- /// We only want this function to type-check, since we will replace the body
365
- /// later on llvm level. Using `loop {}` does not cover all return types anymore,
366
- /// so instead we build something that should pass. We also add a inline_asm
367
- /// line, as one more barrier for rustc to prevent inlining of this function.
368
- /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
369
- /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
370
- /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
371
- /// this function (which should never happen, since it is only a placeholder).
372
- /// Finally, we also add back_box usages of all input arguments, to prevent rustc
373
- /// from optimizing any arguments away.
374
- fn gen_enzyme_body (
359
+ // Will generate a body of the type:
360
+ // ```
361
+ // {
362
+ // unsafe {
363
+ // asm!("NOP");
364
+ // }
365
+ // ::core::hint::black_box(primal(args));
366
+ // ::core::hint::black_box((args, ret));
367
+ // <This part remains to be done by following function>
368
+ // }
369
+ // ```
370
+ fn init_body_helper (
375
371
ecx : & ExtCtxt < ' _ > ,
376
- x : & AutoDiffAttrs ,
377
- n_active : u32 ,
378
- sig : & ast:: FnSig ,
379
- d_sig : & ast:: FnSig ,
372
+ span : Span ,
380
373
primal : Ident ,
381
374
new_names : & [ String ] ,
382
- span : Span ,
383
375
sig_span : Span ,
384
376
new_decl_span : Span ,
385
- idents : Vec < Ident > ,
377
+ idents : & [ Ident ] ,
386
378
errored : bool ,
387
- ) -> P < ast:: Block > {
379
+ ) -> ( P < ast:: Block > , P < ast :: Expr > , P < ast :: Expr > , P < ast :: Expr > ) {
388
380
let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
389
381
let noop = ast:: InlineAsm {
390
382
asm_macro : ast:: AsmMacro :: Asm ,
@@ -433,6 +425,51 @@ mod llvm_enzyme {
433
425
}
434
426
body. stmts . push ( ecx. stmt_semi ( black_box_remaining_args) ) ;
435
427
428
+ ( body, primal_call, black_box_primal_call, blackbox_call_expr)
429
+ }
430
+
431
+ /// We only want this function to type-check, since we will replace the body
432
+ /// later on llvm level. Using `loop {}` does not cover all return types anymore,
433
+ /// so instead we manually build something that should pass the type checker.
434
+ /// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
435
+ /// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
436
+ /// bug would ever try to accidentially differentiate this placeholder function body.
437
+ /// Finally, we also add back_box usages of all input arguments, to prevent rustc
438
+ /// from optimizing any arguments away.
439
+ fn gen_enzyme_body (
440
+ ecx : & ExtCtxt < ' _ > ,
441
+ x : & AutoDiffAttrs ,
442
+ n_active : u32 ,
443
+ sig : & ast:: FnSig ,
444
+ d_sig : & ast:: FnSig ,
445
+ primal : Ident ,
446
+ new_names : & [ String ] ,
447
+ span : Span ,
448
+ sig_span : Span ,
449
+ idents : Vec < Ident > ,
450
+ errored : bool ,
451
+ ) -> P < ast:: Block > {
452
+ let new_decl_span = d_sig. span ;
453
+
454
+ // Just adding some default inline-asm and black_box usages to prevent early inlining
455
+ // and optimizations which alter the function signature.
456
+ //
457
+ // The bb_primal_call is the black_box call of the primal function. We keep it around,
458
+ // since it has the convenient property of returning the type of the primal function,
459
+ // Remember, we only care to match types here.
460
+ // No matter which return we pick, we always wrap it into a std::hint::black_box call,
461
+ // to prevent rustc from propagating it into the caller.
462
+ let ( mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper (
463
+ ecx,
464
+ span,
465
+ primal,
466
+ new_names,
467
+ sig_span,
468
+ new_decl_span,
469
+ & idents,
470
+ errored,
471
+ ) ;
472
+
436
473
if !has_ret ( & d_sig. decl . output ) {
437
474
// there is no return type that we have to match, () works fine.
438
475
return body;
@@ -444,7 +481,7 @@ mod llvm_enzyme {
444
481
445
482
if primal_ret && n_active == 0 && x. mode . is_rev ( ) {
446
483
// We only have the primal ret.
447
- body. stmts . push ( ecx. stmt_expr ( black_box_primal_call ) ) ;
484
+ body. stmts . push ( ecx. stmt_expr ( bb_primal_call ) ) ;
448
485
return body;
449
486
}
450
487
@@ -536,11 +573,11 @@ mod llvm_enzyme {
536
573
return body;
537
574
}
538
575
[ arg] => {
539
- ret = ecx. expr_call ( new_decl_span, blackbox_call_expr , thin_vec ! [ arg. clone( ) ] ) ;
576
+ ret = ecx. expr_call ( new_decl_span, bb_call_expr , thin_vec ! [ arg. clone( ) ] ) ;
540
577
}
541
578
args => {
542
579
let ret_tuple: P < ast:: Expr > = ecx. expr_tuple ( span, args. into ( ) ) ;
543
- ret = ecx. expr_call ( new_decl_span, blackbox_call_expr , thin_vec ! [ ret_tuple] ) ;
580
+ ret = ecx. expr_call ( new_decl_span, bb_call_expr , thin_vec ! [ ret_tuple] ) ;
544
581
}
545
582
}
546
583
assert ! ( has_ret( & d_sig. decl. output) ) ;
@@ -553,7 +590,7 @@ mod llvm_enzyme {
553
590
ecx : & ExtCtxt < ' _ > ,
554
591
span : Span ,
555
592
primal : Ident ,
556
- idents : Vec < Ident > ,
593
+ idents : & [ Ident ] ,
557
594
) -> P < ast:: Expr > {
558
595
let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
559
596
if has_self {
0 commit comments