Skip to content

Commit

Permalink
Pipeline for nested enzyme differentiation (#452)
Browse files Browse the repository at this point in the history
* Pipeline for enzyme

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Nested AD

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update Compiler.jl

* Update src/Compiler.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update Project.toml

* Update autodiff.jl

* Update test/autodiff.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update autodiff.jl

* fixbug

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
wsmoses and github-actions[bot] authored Jan 2, 2025
1 parent 327d252 commit f7e361e
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.34"
Reactant_jll = "0.0.36"
Scratch = "1.2"
SpecialFunctions = "2"
Statistics = "1.10"
Expand Down
41 changes: 31 additions & 10 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,9 @@ function create_result(
return Meta.quot(tocopy)
end

const opt_passes::String = join(
# Optimization passes via transform dialect
const transform_passes::String = join(
[
"inline{default-pipeline=canonicalize max-iterations=4}",
"canonicalize,cse",
"canonicalize",
"enzyme-hlo-generate-td{" *
join(
[
Expand Down Expand Up @@ -273,9 +271,22 @@ const opt_passes::String = join(
"transform-interpreter",
"enzyme-hlo-remove-transform",
],
',',
",",
)

# Optimization passes which apply to an individual function
const func_passes::String = join(
["canonicalize,cse", "canonicalize", transform_passes], ","
)

const opt_passes::String = join(
["inline{default-pipeline=canonicalize max-iterations=4}", func_passes], ','
)

# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"

function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
pm = MLIR.IR.PassManager()
MLIR.IR.enable_verifier!(pm, enable_verifier)
Expand Down Expand Up @@ -335,7 +346,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}"
if optimize === :all
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod,
join(
Expand All @@ -351,7 +364,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
)
elseif optimize === :before_kernel
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod,
join(
Expand Down Expand Up @@ -381,7 +396,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
)
elseif optimize === :only_enzyme
run_pass_pipeline!(mod, "enzyme-batch")
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod,
join(
Expand All @@ -391,7 +408,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
)
elseif optimize === :after_enzyme
run_pass_pipeline!(mod, "enzyme-batch")
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod,
join(
Expand All @@ -407,7 +426,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
)
elseif optimize === :before_enzyme
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
)
Expand Down
2 changes: 1 addition & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ include("Overlay.jl")

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
)::RT where {copy_if_inactive,RT<:RArray}
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
if haskey(seen, prev)
return seen[prev]
end
Expand Down
11 changes: 11 additions & 0 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,14 @@ end
@test stret.st2 x .+ 1
@test stret.st1 === stret.st2
end

@testset "Nested AD" begin
x = ConcreteRNumber(3.1)
f(x) = x * x * x * x
df(x) = Enzyme.gradient(Reverse, f, x)[1]
res1 = @jit df(x)
@test res1 4 * 3.1^3
ddf(x) = Enzyme.gradient(Reverse, df, x)[1]
res2 = @jit ddf(x)
@test res2 4 * 3 * 3.1^2
end

2 comments on commit f7e361e

@wsmoses
Copy link
Member Author

@wsmoses wsmoses commented on f7e361e Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/122287

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.14 -m "<description of version>" f7e361ed31515c58eb0e0df139c87940d8e52492
git push origin v0.2.14

Please sign in to comment.