@@ -486,12 +486,13 @@ include("llvm/transforms.jl")
486486include (" llvm/passes.jl" )
487487include (" typeutils/make_zero.jl" )
488488
489- function nested_codegen! (mode:: API.CDerivativeMode , mod:: LLVM.Module , @nospecialize (f), @nospecialize (tt:: Type ), world :: UInt )
490- funcspec = my_methodinstance (mode == API. DEM_ForwardMode ? Forward : Reverse, typeof (f), tt, world)
491- nested_codegen! (mode, mod, funcspec, world )
489+ function nested_codegen! (ctx :: EnzymeContext , mode:: API.CDerivativeMode , mod:: LLVM.Module , @nospecialize (f), @nospecialize (tt:: Type ))
490+ funcspec = my_methodinstance (mode == API. DEM_ForwardMode ? Forward : Reverse, typeof (f), tt, ctx . world)
491+ nested_codegen! (ctx, mode, mod, funcspec)
492492end
493493
494494function prepare_llvm (interp, mod:: LLVM.Module , job, meta)
495+ # TODO : remove enzymejl_world
495496 for f in functions (mod)
496497 attributes = function_attributes (f)
497498 push! (attributes, StringAttribute (" enzymejl_world" , string (job. world)))
@@ -1234,11 +1235,12 @@ const DumpPreNestedOpt = Ref(false)
12341235const DumpPostNestedOpt = Ref (false )
12351236
12361237function nested_codegen! (
1238+ ctx:: EnzymeContext ,
12371239 mode:: API.CDerivativeMode ,
12381240 mod:: LLVM.Module ,
12391241 funcspec:: Core.MethodInstance ,
1240- world:: UInt ,
12411242)
1243+ world = ctx. world
12421244 # TODO : Put a cache here index on `mod` and f->tt
12431245
12441246
@@ -1254,6 +1256,7 @@ function nested_codegen!(
12541256 GPUCompiler. prepare_job! (job)
12551257 otherMod, meta = GPUCompiler. emit_llvm (job)
12561258
1259+ # TODO : interp should be cached since it contains internal caches
12571260 interp = GPUCompiler. get_interpreter (job)
12581261 prepare_llvm (interp, otherMod, job, meta)
12591262
@@ -2395,6 +2398,7 @@ const DumpPostEnzyme = Ref(false)
23952398const DumpPostWrap = Ref (false )
23962399
23972400function enzyme! (
2401+ enzyme_context:: EnzymeContext ,
23982402 job:: CompilerJob ,
23992403 interp,
24002404 mod:: LLVM.Module ,
@@ -2510,7 +2514,6 @@ function enzyme!(
25102514 convert (API. CDIFFE_TYPE, rt)
25112515 end
25122516
2513- enzyme_context = EnzymeContext (job. world)
25142517 GC. @preserve enzyme_context begin
25152518 LLVM. @dispose logic = Logic (enzyme_context) begin
25162519
@@ -2580,6 +2583,7 @@ function enzyme!(
25802583
25812584 if wrap
25822585 augmented_primalf = create_abi_wrapper (
2586+ enzyme_context,
25832587 augmented_primalf,
25842588 TT,
25852589 rt,
@@ -2589,7 +2593,6 @@ function enzyme!(
25892593 width,
25902594 returnPrimal,
25912595 shadow_init,
2592- world,
25932596 interp,
25942597 runtimeActivity,
25952598 )
@@ -2622,6 +2625,7 @@ function enzyme!(
26222625 ) #= atomicAdd=#
26232626 if wrap
26242627 adjointf = create_abi_wrapper (
2628+ enzyme_context,
26252629 adjointf,
26262630 TT,
26272631 rt,
@@ -2631,7 +2635,6 @@ function enzyme!(
26312635 width,
26322636 false ,
26332637 shadow_init,
2634- world,
26352638 interp,
26362639 runtimeActivity
26372640 ) #= returnPrimal=#
@@ -2663,6 +2666,7 @@ function enzyme!(
26632666 augmented_primalf = nothing
26642667 if wrap
26652668 adjointf = create_abi_wrapper (
2669+ enzyme_context,
26662670 adjointf,
26672671 TT,
26682672 rt,
@@ -2672,7 +2676,6 @@ function enzyme!(
26722676 width,
26732677 returnPrimal,
26742678 shadow_init,
2675- world,
26762679 interp,
26772680 runtimeActivity
26782681 )
@@ -2708,6 +2711,7 @@ function enzyme!(
27082711 if wrap
27092712 pf = adjointf
27102713 adjointf = create_abi_wrapper (
2714+ enzyme_context,
27112715 adjointf,
27122716 TT,
27132717 rt,
@@ -2717,7 +2721,6 @@ function enzyme!(
27172721 width,
27182722 returnPrimal,
27192723 shadow_init,
2720- world,
27212724 interp,
27222725 runtimeActivity
27232726 )
@@ -2792,6 +2795,7 @@ function set_subprogram!(f::LLVM.Function, sp)
27922795end
27932796
27942797function create_abi_wrapper (
2798+ ctx:: EnzymeContext ,
27952799 enzymefn:: LLVM.Function ,
27962800 @nospecialize (TT:: Type ),
27972801 @nospecialize (rettype:: Type ),
@@ -2801,10 +2805,10 @@ function create_abi_wrapper(
28012805 width:: Int ,
28022806 returnPrimal:: Bool ,
28032807 shadow_init:: Bool ,
2804- world:: UInt ,
28052808 interp,
28062809 runtime_activity:: Bool
28072810)
2811+ world = ctx. world
28082812 is_adjoint = Mode == API. DEM_ReverseModeGradient || Mode == API. DEM_ReverseModeCombined
28092813 is_split = Mode == API. DEM_ReverseModeGradient || Mode == API. DEM_ReverseModePrimal
28102814 needs_tape = Mode == API. DEM_ReverseModeGradient
@@ -3087,6 +3091,7 @@ function create_abi_wrapper(
30873091 realparms = LLVM. Value[]
30883092 i = 1
30893093
3094+ # TODO (vchuravy): remove
30903095 for attr in collect (function_attributes (enzymefn))
30913096 if kind (attr) == " enzymejl_world"
30923097 push! (function_attributes (llvm_f), attr)
@@ -3231,7 +3236,7 @@ function create_abi_wrapper(
32313236 elseif T <: BatchDuplicatedFunc
32323237 Func = get_func (T)
32333238 funcspec = my_methodinstance (Mode == API. DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world)
3234- llvmf = nested_codegen! (Mode, mod, funcspec, world )
3239+ llvmf = nested_codegen! (ctx, Mode, mod, funcspec)
32353240 push! (function_attributes (llvmf), EnumAttribute (" alwaysinline" , 0 ))
32363241 Func_RT = return_type (interp, funcspec)
32373242 @assert Func_RT == NTuple{width,T′}
@@ -5099,15 +5104,17 @@ end
50995104 end
51005105 end
51015106
5107+ ctx = EnzymeContext (job. world)
51025108 if params. run_enzyme
51035109 # Generate the adjoint
51045110 memcpy_alloca_to_loadstore (mod)
51055111 force_recompute! (mod)
51065112 API. EnzymeDetectReadonlyOrThrow (mod)
51075113
51085114 adjointf, augmented_primalf, TapeType = enzyme! (
5115+ ctx,
51095116 job,
5110- interp,
5117+ interp,
51115118 mod,
51125119 primalf,
51135120 TT,
@@ -5205,7 +5212,7 @@ end
52055212 fname = String (name) * pf
52065213 if haskey (functions (mod), fname)
52075214 funcspec = my_methodinstance (Mode == API. DEM_ForwardMode ? Forward : Reverse, fnty, Tuple{JT}, job. world)
5208- llvmf = nested_codegen! (mode, mod, funcspec, job . world )
5215+ llvmf = nested_codegen! (ctx, mode, mod, funcspec)
52095216 push! (function_attributes (llvmf), StringAttribute (" implements" , fname))
52105217 end
52115218 end
0 commit comments