Skip to content

Commit e5e26d0

Browse files
authored
bump Enzyme version in v0.2 (#132)
* bump Enzyme version, update Enzyme interface * drop testing on Julia 1.6 for v0.2 * fix disable testing on Enzyme
1 parent 4af5f82 commit e5e26d0

File tree

4 files changed

+15
-15
lines changed

4 files changed

+15
-15
lines changed

.github/workflows/CI.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
- os: macOS-latest
2626
arch: x86
2727
include:
28-
- version: '1.6'
28+
- version: '1.10'
2929
os: ubuntu-latest
3030
arch: x64
3131
- os: ubuntu-latest

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ DiffResults = "1"
3737
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
3838
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
3939
DocStringExtensions = "0.8, 0.9"
40-
Enzyme = "0.11"
40+
Enzyme = "0.13"
4141
LinearAlgebra = "1.6"
4242
ForwardDiff = "0.10.3"
4343
Flux = "0.14"

ext/AdvancedVIEnzymeExt.jl

+12-9
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,22 @@ function AdvancedVI.grad!(
2323
out::DiffResults.MutableDiffResult,
2424
args...
2525
)
26-
f(θ) =
27-
if (q isa Distributions.Distribution)
28-
-vo(alg, AdvancedVI.update(q, θ), model, args...)
29-
else
30-
-vo(alg, q(θ), model, args...)
31-
end
32-
# Use `Enzyme.ReverseWithPrimal` once it is released:
33-
# https://github.com/EnzymeAD/Enzyme.jl/pull/598
26+
f(θ) = if (q isa Distributions.Distribution)
27+
-vo(alg, AdvancedVI.update(q, θ), model, args...)
28+
else
29+
-vo(alg, q(θ), model, args...)
30+
end
31+
3432
y = f(θ)
3533
DiffResults.value!(out, y)
3634
dy = DiffResults.gradient(out)
3735
fill!(dy, 0)
38-
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
36+
Enzyme.autodiff(
37+
Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true),
38+
Enzyme.Const(f),
39+
Enzyme.Active,
40+
Enzyme.Duplicated(θ, dy)
41+
)
3942
return out
4043
end
4144

test/runtests.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ using ReverseDiff: ReverseDiff
66
using Tracker: Tracker
77
using Zygote: Zygote
88
using Enzyme: Enzyme
9-
Enzyme.API.runtimeActivity!(true);
10-
Enzyme.API.typeWarning!(false);
119

1210
using AdvancedVI
1311

@@ -22,7 +20,7 @@ include("optimisers.jl")
2220
AutoReverseDiff(),
2321
AutoTracker(),
2422
AutoZygote(),
25-
# AutoEnzyme() # results in incorrect result
23+
# AutoEnzyme()
2624
]
2725
target = MvNormal(ones(2))
2826
logπ(z) = logpdf(target, z)
@@ -42,5 +40,4 @@ include("optimisers.jl")
4240

4341
xs = rand(target, 10)
4442
@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) 0.05
45-
4643
end

0 commit comments

Comments
 (0)