Skip to content

Commit 25600ca

Browse files
authored
Add function_annotation to AutoEnzyme (#77)
* Add function_annotation to AutoEnzyme * Fix printing * Test warning
1 parent 1b5cad0 commit 25600ca

File tree

4 files changed

+35
-17
lines changed

4 files changed

+35
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = [
44
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
55
]
6-
version = "1.6.2"
6+
version = "1.7.0"
77

88
[deps]
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/dense.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,36 +39,46 @@ struct AutoDiffractor <: AbstractADType end
3939
mode(::AutoDiffractor) = ForwardOrReverseMode()
4040

4141
"""
42-
AutoEnzyme{M}
42+
AutoEnzyme{M,A}
4343
4444
Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation.
4545
4646
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
4747
4848
# Constructors
4949
50-
AutoEnzyme(; mode=nothing)
50+
AutoEnzyme(; mode::M=nothing, function_annotation::Type{A}=Nothing)
51+
52+
# Type parameters
53+
54+
- `A` determines how the function `f` to differentiate is passed to Enzyme. It can be:
55+
56+
+ a subtype of `EnzymeCore.Annotation` (like `EnzymeCore.Const` or `EnzymeCore.Duplicated`) to enforce a given annotation
57+
+ `Nothing` to simply pass `f` and let Enzyme choose the most appropriate annotation
5158
5259
# Fields
5360
54-
- `mode::M`: can be either
61+
- `mode::M` determines the autodiff mode (forward or reverse). It can be:
5562
5663
+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
5764
+ `nothing` to choose the best mode automatically
5865
"""
59-
struct AutoEnzyme{M} <: AbstractADType
66+
struct AutoEnzyme{M, A} <: AbstractADType
6067
mode::M
6168
end
6269

63-
function AutoEnzyme(; mode::M = nothing) where {M}
64-
return AutoEnzyme{M}(mode)
70+
function AutoEnzyme(;
71+
mode::M = nothing, function_annotation::Type{A} = Nothing) where {M, A}
72+
return AutoEnzyme{M, A}(mode)
6573
end
6674

6775
mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension
6876

69-
function Base.show(io::IO, backend::AutoEnzyme)
77+
function Base.show(io::IO, backend::AutoEnzyme{M, A}) where {M, A}
7078
print(io, AutoEnzyme, "(")
7179
!isnothing(backend.mode) && print(io, "mode=", repr(backend.mode; context = io))
80+
!isnothing(backend.mode) && !(A <: Nothing) && print(io, ", ")
81+
!(A <: Nothing) && print(io, "function_annotation=", repr(A; context = io))
7282
print(io, ")")
7383
end
7484

test/dense.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,26 @@ end
2828
@testset "AutoEnzyme" begin
2929
ad = AutoEnzyme()
3030
@test ad isa AbstractADType
31-
@test ad isa AutoEnzyme{Nothing}
31+
@test ad isa AutoEnzyme{Nothing, Nothing}
3232
@test mode(ad) isa ForwardOrReverseMode
3333
@test ad.mode === nothing
3434

35-
ad = AutoEnzyme(EnzymeCore.Forward)
35+
ad = AutoEnzyme(; mode = EnzymeCore.Forward)
3636
@test ad isa AbstractADType
37-
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)}
37+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), Nothing}
3838
@test mode(ad) isa ForwardMode
3939
@test ad.mode == EnzymeCore.Forward
4040

41-
ad = AutoEnzyme(; mode = EnzymeCore.Forward)
41+
ad = AutoEnzyme(; function_annotation = EnzymeCore.Const)
4242
@test ad isa AbstractADType
43-
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)}
44-
@test mode(ad) isa ForwardMode
45-
@test ad.mode == EnzymeCore.Forward
43+
@test ad isa AutoEnzyme{Nothing, EnzymeCore.Const}
44+
@test mode(ad) isa ForwardOrReverseMode
45+
@test ad.mode === nothing
4646

47-
ad = AutoEnzyme(; mode = EnzymeCore.Reverse)
47+
ad = AutoEnzyme(;
48+
mode = EnzymeCore.Reverse, function_annotation = EnzymeCore.Duplicated)
4849
@test ad isa AbstractADType
49-
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse)}
50+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), EnzymeCore.Duplicated}
5051
@test mode(ad) isa ReverseMode
5152
@test ad.mode == EnzymeCore.Reverse
5253
end

test/misc.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,19 @@ end
2121
@test length(string(sparse_backend1)) < length(string(sparse_backend2))
2222
end
2323

24+
#=
25+
The following tests are only for visual assessment of the printing behavior.
26+
They do not correspond to proper use of ADTypes constructors.
27+
Please refer to the docstrings for that.
28+
=#
2429
for backend in [
2530
# dense
2631
ADTypes.AutoChainRules(; ruleconfig = :rc),
2732
ADTypes.AutoDiffractor(),
2833
ADTypes.AutoEnzyme(),
2934
ADTypes.AutoEnzyme(mode = :forward),
35+
ADTypes.AutoEnzyme(function_annotation = Val{:forward}),
36+
ADTypes.AutoEnzyme(mode = :reverse, function_annotation = Val{:duplicated}),
3037
ADTypes.AutoFastDifferentiation(),
3138
ADTypes.AutoFiniteDiff(),
3239
ADTypes.AutoFiniteDiff(fdtype = :fd, fdjtype = :fdj, fdhtype = :fdh),

0 commit comments

Comments
 (0)