@@ -6,7 +6,7 @@ use std::{fs, slice, str};
6
6
7
7
use libc:: { c_char, c_int, c_uint, c_void, size_t} ;
8
8
use llvm:: {
9
- IntPredicate , LLVMRustLLVMHasZlibCompressionForDebugSymbols ,
9
+ LLVMRustLLVMHasZlibCompressionForDebugSymbols ,
10
10
LLVMRustLLVMHasZstdCompressionForDebugSymbols ,
11
11
} ;
12
12
use rustc_ast:: expand:: autodiff_attrs:: AutoDiffItem ;
@@ -44,19 +44,14 @@ use crate::errors::{
44
44
use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind :: * ;
45
45
use crate :: llvm:: {
46
46
self , AttributeKind , DiagnosticInfo ,
47
- LLVMAppendBasicBlockInContext , LLVMBuildCall2 ,
48
- LLVMBuildCondBr , LLVMBuildExtractValue , LLVMBuildICmp , LLVMBuildRet , LLVMBuildRetVoid ,
49
- LLVMCountParams , LLVMCountStructElementTypes , LLVMCreateBuilderInContext ,
50
- LLVMCreateStringAttribute , LLVMDisposeBuilder , LLVMDumpModule , LLVMGetFirstBasicBlock ,
51
- LLVMGetFirstFunction , LLVMGetNextFunction , LLVMGetParams , LLVMGetReturnType ,
52
- LLVMGetStringAttributeAtIndex , LLVMGlobalGetValueType , LLVMIsEnumAttribute ,
53
- LLVMIsStringAttribute , LLVMMetadataAsValue , LLVMPositionBuilderAtEnd ,
47
+ LLVMCreateStringAttribute , LLVMDumpModule ,
48
+ LLVMGetFirstFunction , LLVMGetNextFunction ,
49
+ LLVMGetStringAttributeAtIndex , LLVMIsEnumAttribute ,
50
+ LLVMIsStringAttribute ,
54
51
LLVMRemoveStringAttributeAtIndex , LLVMRustAddEnumAttributeAtIndex ,
55
- LLVMRustAddFunctionAttributes , LLVMRustDIGetInstMetadata , LLVMRustEraseInstBefore ,
56
- LLVMRustEraseInstFromParent , LLVMRustGetEnumAttributeAtIndex , LLVMRustGetFunctionType ,
57
- LLVMRustGetLastInstruction , LLVMRustGetTerminator , LLVMRustHasMetadata ,
58
- LLVMRustRemoveEnumAttributeAtIndex , LLVMVerifyFunction , LLVMVoidTypeInContext , PassManager ,
59
- Value ,
52
+ LLVMRustAddFunctionAttributes ,
53
+ LLVMRustGetEnumAttributeAtIndex ,
54
+ LLVMRustRemoveEnumAttributeAtIndex , PassManager ,
60
55
} ;
61
56
use crate :: type_:: Type ;
62
57
use crate :: { LlvmCodegenBackend , ModuleLlvm , base, common, llvm_util} ;
@@ -652,6 +647,14 @@ pub(crate) unsafe fn llvm_optimize(
652
647
result. into_result ( ) . map_err ( |( ) | llvm_err ( dcx, LlvmError :: RunLlvmPasses ) )
653
648
}
654
649
650
+ #[ allow( non_snake_case) ]
651
+ unsafe fn EnzymeSetCLBool ( ptr : * mut c_int , val : u8 ) {
652
+ unsafe {
653
+ //let ptr = ptr as *mut c_int;
654
+ * ptr = val as c_int ;
655
+ }
656
+ }
657
+
655
658
pub ( crate ) fn differentiate (
656
659
module : & ModuleCodegen < ModuleLlvm > ,
657
660
cgcx : & CodegenContext < LlvmCodegenBackend > ,
@@ -672,8 +675,13 @@ pub(crate) fn differentiate(
672
675
673
676
if ad. contains ( & AutoDiff :: LooseTypes ) {
674
677
dbg ! ( "Setting loose types to true" ) ;
675
- //llvm::set_loose_types(true);
676
- todo ! ( ) ;
678
+ //extern "C" {
679
+ #[ allow( non_upper_case_globals) ]
680
+ static mut looseTypeAnalysis: c_int = 0 ;
681
+ //}
682
+ unsafe {
683
+ EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( looseTypeAnalysis) , true as u8 ) ;
684
+ }
677
685
}
678
686
679
687
// Before dumping the module, we want all the tt to become part of the module.
@@ -702,20 +710,13 @@ pub(crate) fn differentiate(
702
710
} ) ) ,
703
711
} ;
704
712
705
- // Before dumping the module, we also might want to add dummy functions, which will
706
- // trigger the LLVMEnzyme pass to run on them, if we invoke the opt binary.
707
- // This is super helpfull if we want to create a MWE bug reproducer, e.g. to run in
708
- // Enzyme's compiler explorer.
709
- if ad. contains ( & AutoDiff :: OPT2 ) {
710
- dbg ! ( "Enable extra debug helper to debug Enzyme through the opt plugin" ) ;
711
- crate :: builder:: add_opt_dbg_helper2 (
712
- llmod,
713
- llcx,
714
- fn_def,
715
- fn_target,
716
- item. attrs . clone ( ) ,
717
- ) ;
718
- }
713
+ crate :: builder:: add_opt_dbg_helper2 (
714
+ llmod,
715
+ llcx,
716
+ fn_def,
717
+ fn_target,
718
+ item. attrs . clone ( ) ,
719
+ ) ;
719
720
}
720
721
721
722
if ad. contains ( & AutoDiff :: PrintModBefore ) || ad. contains ( & AutoDiff :: OPT ) {
@@ -725,36 +726,17 @@ pub(crate) fn differentiate(
725
726
}
726
727
727
728
if ad. contains ( & AutoDiff :: Inline ) {
728
- dbg ! ( "Setting inline to true" ) ;
729
- //llvm::set_inline(true);
730
- todo ! ( ) ;
731
- }
732
-
733
- if ad. contains ( & AutoDiff :: RuntimeActivity ) {
734
- dbg ! ( "Setting runtime activity check to true" ) ;
735
- //llvm::set_runtime_activity_check(true);
736
- todo ! ( ) ;
737
- }
738
-
739
- for val in ad {
740
- match & val {
741
- AutoDiff :: TTDepth ( depth) => {
742
- assert ! ( * depth >= 1 ) ;
743
- //llvm::set_max_int_offset(*depth);
744
- todo ! ( ) ;
745
- }
746
- AutoDiff :: TTWidth ( width) => {
747
- assert ! ( * width >= 1 ) ;
748
- todo ! ( ) ;
749
- //llvm::set_max_type_offset(*width);
750
- }
751
- _ => { }
729
+ trace ! ( "Setting Enzyme inline to true" ) ;
730
+ //extern "C" {
731
+ // static mut EnzymeInline: c_void;
732
+ //}
733
+ #[ allow( non_upper_case_globals) ]
734
+ static mut EnzymeInline : c_int = 0 ;
735
+ unsafe {
736
+ EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymeInline ) , true as u8 ) ;
752
737
}
753
738
}
754
739
755
- let differentiate = !diff_items. is_empty ( ) ;
756
- let _fnc_opt = ad. contains ( & AutoDiff :: EnableFncOpt ) ;
757
-
758
740
unsafe {
759
741
let mut f = LLVMGetFirstFunction ( llmod) ;
760
742
loop {
@@ -787,9 +769,6 @@ pub(crate) fn differentiate(
787
769
}
788
770
}
789
771
790
- //if ad.contains(&AutoDiff::NoModOptAfter) || !differentiate {
791
- // trace!("Skipping module optimization after automatic differentiation");
792
- //} else {
793
772
if let Some ( opt_level) = config. opt_level {
794
773
let opt_stage = match cgcx. lto {
795
774
Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
0 commit comments