Skip to content

Commit f49e5b0

Browse files
committed
pass context to enzyme!
1 parent d184159 commit f49e5b0

File tree

4 files changed

+32
-30
lines changed

4 files changed

+32
-30
lines changed

src/compiler.jl

Lines changed: 20 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -486,12 +486,13 @@ include("llvm/transforms.jl")
486486
include("llvm/passes.jl")
487487
include("typeutils/make_zero.jl")
488488

489-
function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt)
490-
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world)
491-
nested_codegen!(mode, mod, funcspec, world)
489+
function nested_codegen!(ctx::EnzymeContext, mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type))
490+
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, ctx.world)
491+
nested_codegen!(ctx, mode, mod, funcspec)
492492
end
493493

494494
function prepare_llvm(interp, mod::LLVM.Module, job, meta)
495+
# TODO: remove enzymejl_world
495496
for f in functions(mod)
496497
attributes = function_attributes(f)
497498
push!(attributes, StringAttribute("enzymejl_world", string(job.world)))
@@ -1234,11 +1235,12 @@ const DumpPreNestedOpt = Ref(false)
12341235
const DumpPostNestedOpt = Ref(false)
12351236

12361237
function nested_codegen!(
1238+
ctx::EnzymeContext,
12371239
mode::API.CDerivativeMode,
12381240
mod::LLVM.Module,
12391241
funcspec::Core.MethodInstance,
1240-
world::UInt,
12411242
)
1243+
world = ctx.world
12421244
# TODO: Put a cache here index on `mod` and f->tt
12431245

12441246

@@ -1254,6 +1256,7 @@ function nested_codegen!(
12541256
GPUCompiler.prepare_job!(job)
12551257
otherMod, meta = GPUCompiler.emit_llvm(job)
12561258

1259+
# TODO: interp should be cached since it contains internal caches
12571260
interp = GPUCompiler.get_interpreter(job)
12581261
prepare_llvm(interp, otherMod, job, meta)
12591262

@@ -2395,6 +2398,7 @@ const DumpPostEnzyme = Ref(false)
23952398
const DumpPostWrap = Ref(false)
23962399

23972400
function enzyme!(
2401+
enzyme_context::EnzymeContext,
23982402
job::CompilerJob,
23992403
interp,
24002404
mod::LLVM.Module,
@@ -2510,7 +2514,6 @@ function enzyme!(
25102514
convert(API.CDIFFE_TYPE, rt)
25112515
end
25122516

2513-
enzyme_context = EnzymeContext(job.world)
25142517
GC.@preserve enzyme_context begin
25152518
LLVM.@dispose logic = Logic(enzyme_context) begin
25162519

@@ -2580,6 +2583,7 @@ function enzyme!(
25802583

25812584
if wrap
25822585
augmented_primalf = create_abi_wrapper(
2586+
enzyme_context,
25832587
augmented_primalf,
25842588
TT,
25852589
rt,
@@ -2589,7 +2593,6 @@ function enzyme!(
25892593
width,
25902594
returnPrimal,
25912595
shadow_init,
2592-
world,
25932596
interp,
25942597
runtimeActivity,
25952598
)
@@ -2622,6 +2625,7 @@ function enzyme!(
26222625
) #=atomicAdd=#
26232626
if wrap
26242627
adjointf = create_abi_wrapper(
2628+
enzyme_context,
26252629
adjointf,
26262630
TT,
26272631
rt,
@@ -2631,7 +2635,6 @@ function enzyme!(
26312635
width,
26322636
false,
26332637
shadow_init,
2634-
world,
26352638
interp,
26362639
runtimeActivity
26372640
) #=returnPrimal=#
@@ -2663,6 +2666,7 @@ function enzyme!(
26632666
augmented_primalf = nothing
26642667
if wrap
26652668
adjointf = create_abi_wrapper(
2669+
enzyme_context,
26662670
adjointf,
26672671
TT,
26682672
rt,
@@ -2672,7 +2676,6 @@ function enzyme!(
26722676
width,
26732677
returnPrimal,
26742678
shadow_init,
2675-
world,
26762679
interp,
26772680
runtimeActivity
26782681
)
@@ -2708,6 +2711,7 @@ function enzyme!(
27082711
if wrap
27092712
pf = adjointf
27102713
adjointf = create_abi_wrapper(
2714+
enzyme_context,
27112715
adjointf,
27122716
TT,
27132717
rt,
@@ -2717,7 +2721,6 @@ function enzyme!(
27172721
width,
27182722
returnPrimal,
27192723
shadow_init,
2720-
world,
27212724
interp,
27222725
runtimeActivity
27232726
)
@@ -2792,6 +2795,7 @@ function set_subprogram!(f::LLVM.Function, sp)
27922795
end
27932796

27942797
function create_abi_wrapper(
2798+
ctx::EnzymeContext,
27952799
enzymefn::LLVM.Function,
27962800
@nospecialize(TT::Type),
27972801
@nospecialize(rettype::Type),
@@ -2801,10 +2805,10 @@ function create_abi_wrapper(
28012805
width::Int,
28022806
returnPrimal::Bool,
28032807
shadow_init::Bool,
2804-
world::UInt,
28052808
interp,
28062809
runtime_activity::Bool
28072810
)
2811+
world = ctx.world
28082812
is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined
28092813
is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal
28102814
needs_tape = Mode == API.DEM_ReverseModeGradient
@@ -3087,6 +3091,7 @@ function create_abi_wrapper(
30873091
realparms = LLVM.Value[]
30883092
i = 1
30893093

3094+
# TODO(vchuravy): remove
30903095
for attr in collect(function_attributes(enzymefn))
30913096
if kind(attr) == "enzymejl_world"
30923097
push!(function_attributes(llvm_f), attr)
@@ -3231,7 +3236,7 @@ function create_abi_wrapper(
32313236
elseif T <: BatchDuplicatedFunc
32323237
Func = get_func(T)
32333238
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world)
3234-
llvmf = nested_codegen!(Mode, mod, funcspec, world)
3239+
llvmf = nested_codegen!(ctx, Mode, mod, funcspec)
32353240
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
32363241
Func_RT = return_type(interp, funcspec)
32373242
@assert Func_RT == NTuple{width,T′}
@@ -5099,15 +5104,17 @@ end
50995104
end
51005105
end
51015106

5107+
ctx = EnzymeContext(job.world)
51025108
if params.run_enzyme
51035109
# Generate the adjoint
51045110
memcpy_alloca_to_loadstore(mod)
51055111
force_recompute!(mod)
51065112
API.EnzymeDetectReadonlyOrThrow(mod)
51075113

51085114
adjointf, augmented_primalf, TapeType = enzyme!(
5115+
ctx,
51095116
job,
5110-
interp,
5117+
interp,
51115118
mod,
51125119
primalf,
51135120
TT,
@@ -5205,7 +5212,7 @@ end
52055212
fname = String(name) * pf
52065213
if haskey(functions(mod), fname)
52075214
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, fnty, Tuple{JT}, job.world)
5208-
llvmf = nested_codegen!(mode, mod, funcspec, job.world)
5215+
llvmf = nested_codegen!(ctx, mode, mod, funcspec)
52095216
push!(function_attributes(llvmf), StringAttribute("implements", fname))
52105217
end
52115218
end

src/rules/customrules.jl

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -578,10 +578,8 @@ end
578578

579579
curent_bb = position(B)
580580
fn = LLVM.parent(curent_bb)
581-
world = enzyme_extract_world(fn)
582-
@assert world == enzyme_context(gutils).world
583581

584-
llvmf = nested_codegen!(mode, mod, fmi, world)
582+
llvmf = nested_codegen!(enzyme_context(gutils), mode, mod, fmi)
585583

586584
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
587585

@@ -1039,8 +1037,8 @@ function enzyme_custom_common_rev(
10391037

10401038
curent_bb = position(B)
10411039
fn = LLVM.parent(curent_bb)
1042-
world = enzyme_extract_world(fn)
1043-
@assert world == enzyme_context(gutils).world
1040+
ctx = enzyme_context(gutils)
1041+
world = ctx.world
10441042

10451043
mode = get_mode(gutils)
10461044

@@ -1101,7 +1099,7 @@ function enzyme_custom_common_rev(
11011099
applicablefn = true
11021100

11031101
if forward
1104-
llvmf = nested_codegen!(mode, mod, ami, world)
1102+
llvmf = nested_codegen!(ctx, mode, mod, ami)
11051103
@assert llvmf !== nothing
11061104
rev_RT = nothing
11071105
else
@@ -1143,7 +1141,7 @@ function enzyme_custom_common_rev(
11431141

11441142
rmi = rmi::Core.MethodInstance
11451143
rev_RT = rev_RT::Type
1146-
llvmf = nested_codegen!(mode, mod, rmi, world)
1144+
llvmf = nested_codegen!(ctx, mode, mod, rmi)
11471145
end
11481146

11491147
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))

src/rules/parallelrules.jl

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -550,9 +550,8 @@ end
550550

551551
tt = Tuple{thunkTy,dfuncT,Bool}
552552
mode = get_mode(gutils)
553-
world = enzyme_extract_world(LLVM.parent(position(B)))
554-
@assert world == enzyme_context(gutils).world
555-
entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt, world)
553+
ctx = enzyme_context(gutils)
554+
entry = nested_codegen!(ctx, mode, mod, runtime_pfor_fwd, tt)
556555
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
557556

558557
pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid}))
@@ -595,9 +594,8 @@ end
595594
Bool,
596595
}
597596
mode = get_mode(gutils)
598-
world = enzyme_extract_world(LLVM.parent(position(B)))
599-
@assert world == enzyme_context(gutils).world
600-
entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt, world)
597+
ctx = enzyme_context(gutils)
598+
entry = nested_codegen!(ctx, mode, mod, runtime_pfor_augfwd, tt)
601599
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
602600

603601
pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid}))
@@ -629,8 +627,6 @@ end
629627

630628
@register_rev function threadsfor_rev(B, orig, gutils, tape)
631629
mod = LLVM.parent(LLVM.parent(LLVM.parent(orig)))
632-
world = enzyme_extract_world(LLVM.parent(position(B)))
633-
@assert world == enzyme_context(gutils).world
634630
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
635631
return
636632
end
@@ -653,7 +649,8 @@ end
653649
Bool,
654650
}
655651
mode = get_mode(gutils)
656-
entry = nested_codegen!(mode, mod, runtime_pfor_rev, tt, world)
652+
ctx = enzyme_context(gutils)
653+
entry = nested_codegen!(ctx, mode, mod, runtime_pfor_rev, tt)
657654
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
658655

659656
pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid}))

src/rules/typeunstablerules.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1036,7 +1036,7 @@ end
10361036
if legal
10371037
@assert legal
10381038
world = enzyme_extract_world(LLVM.parent(position(B)))
1039-
@assert world == enzyme_context(gutils).world
1039+
@assert world == enzyme_context(gutils).world
10401040
torun = !guaranteed_nonactive(TT, world)
10411041
else
10421042
torun = true

0 commit comments

Comments
 (0)