-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Zygote AD failure workarounds & test cleanup #414
Changes from all commits
7ccfbef
313535e
11b2bcd
feb5b5c
919e445
7d38977
a4ad008
51c9250
821503c
8b1cddc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -69,6 +69,10 @@ function gradient(f, ::Val{:FiniteDiff}, args) | |
return first(FiniteDifferences.grad(FDM, f, args)) | ||
end | ||
|
||
function compare_gradient(f, ::Val{:FiniteDiff}, args) | ||
@test_nowarn gradient(f, :FiniteDiff, args) | ||
end | ||
|
||
function compare_gradient(f, AD::Symbol, args) | ||
grad_AD = gradient(f, AD, args) | ||
grad_FD = gradient(f, :FiniteDiff, args) | ||
|
@@ -88,7 +92,7 @@ testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B)) | |
function test_ADs( | ||
kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3] | ||
) | ||
test_fd = test_FiniteDiff(kernelfunction, args, dims) | ||
test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims) | ||
if !test_fd.anynonpass | ||
for AD in ADs | ||
test_AD(AD, kernelfunction, args, dims) | ||
|
@@ -100,7 +104,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) | |
@inferred f(args...) | ||
@inferred Zygote._pullback(ctx, f, args...) | ||
out, pb = Zygote._pullback(ctx, f, args...) | ||
@test_throws ErrorException @inferred pb(out) | ||
@inferred pb(out) | ||
end | ||
|
||
function test_ADs( | ||
|
@@ -114,70 +118,6 @@ function test_ADs( | |
end | ||
end | ||
|
||
function test_FiniteDiff(kernelfunction, args=nothing, dims=[3, 3]) | ||
# Init arguments : | ||
k = if args === nothing | ||
kernelfunction() | ||
else | ||
kernelfunction(args) | ||
end | ||
rng = MersenneTwister(42) | ||
@testset "FiniteDifferences" begin | ||
if k isa SimpleKernel | ||
for d in log.([eps(), rand(rng)]) | ||
@test_nowarn gradient(:FiniteDiff, [d]) do x | ||
kappa(k, exp(first(x))) | ||
end | ||
end | ||
end | ||
## Testing Kernel Functions | ||
x = rand(rng, dims[1]) | ||
y = rand(rng, dims[1]) | ||
@test_nowarn gradient(:FiniteDiff, x) do x | ||
k(x, y) | ||
end | ||
if !(args === nothing) | ||
@test_nowarn gradient(:FiniteDiff, args) do p | ||
kernelfunction(p)(x, y) | ||
end | ||
end | ||
## Testing Kernel Matrices | ||
A = rand(rng, dims...) | ||
B = rand(rng, dims...) | ||
for dim in 1:2 | ||
@test_nowarn gradient(:FiniteDiff, A) do a | ||
testfunction(k, a, dim) | ||
end | ||
@test_nowarn gradient(:FiniteDiff, A) do a | ||
testfunction(k, a, B, dim) | ||
end | ||
@test_nowarn gradient(:FiniteDiff, B) do b | ||
testfunction(k, A, b, dim) | ||
end | ||
if !(args === nothing) | ||
@test_nowarn gradient(:FiniteDiff, args) do p | ||
testfunction(kernelfunction(p), A, B, dim) | ||
end | ||
end | ||
|
||
@test_nowarn gradient(:FiniteDiff, A) do a | ||
testdiagfunction(k, a, dim) | ||
end | ||
@test_nowarn gradient(:FiniteDiff, A) do a | ||
testdiagfunction(k, a, B, dim) | ||
end | ||
@test_nowarn gradient(:FiniteDiff, B) do b | ||
testdiagfunction(k, A, b, dim) | ||
end | ||
if args !== nothing | ||
@test_nowarn gradient(:FiniteDiff, args) do p | ||
testdiagfunction(kernelfunction(p), A, B, dim) | ||
end | ||
end | ||
end | ||
end | ||
end | ||
|
||
function test_FiniteDiff(k::MOKernel, dims=(in=3, out=2, obs=3)) | ||
rng = MersenneTwister(42) | ||
@testset "FiniteDifferences" begin | ||
|
@@ -224,68 +164,68 @@ end | |
|
||
function test_AD(AD::Symbol, kernelfunction, args=nothing, dims=[3, 3]) | ||
@testset "$(AD)" begin | ||
# Test kappa function | ||
k = if args === nothing | ||
kernelfunction() | ||
else | ||
kernelfunction(args) | ||
end | ||
rng = MersenneTwister(42) | ||
|
||
if k isa SimpleKernel | ||
for d in log.([eps(), rand(rng)]) | ||
compare_gradient(AD, [d]) do x | ||
kappa(k, exp(x[1])) | ||
@testset "kappa function" begin | ||
for d in log.([eps(), rand(rng)]) | ||
compare_gradient(AD, [d]) do x | ||
kappa(k, exp(x[1])) | ||
end | ||
end | ||
end | ||
end | ||
# Testing kernel evaluations | ||
x = rand(rng, dims[1]) | ||
y = rand(rng, dims[1]) | ||
compare_gradient(AD, x) do x | ||
k(x, y) | ||
end | ||
compare_gradient(AD, y) do y | ||
k(x, y) | ||
end | ||
if !(args === nothing) | ||
compare_gradient(AD, args) do p | ||
kernelfunction(p)(x, y) | ||
end | ||
end | ||
# Testing kernel matrices | ||
A = rand(rng, dims...) | ||
B = rand(rng, dims...) | ||
for dim in 1:2 | ||
compare_gradient(AD, A) do a | ||
testfunction(k, a, dim) | ||
end | ||
compare_gradient(AD, A) do a | ||
testfunction(k, a, B, dim) | ||
|
||
@testset "kernel evaluations" begin | ||
x = rand(rng, dims[1]) | ||
y = rand(rng, dims[1]) | ||
compare_gradient(AD, x) do x | ||
k(x, y) | ||
end | ||
compare_gradient(AD, B) do b | ||
testfunction(k, A, b, dim) | ||
compare_gradient(AD, y) do y | ||
k(x, y) | ||
end | ||
if !(args === nothing) | ||
compare_gradient(AD, args) do p | ||
testfunction(kernelfunction(p), A, dim) | ||
@testset "hyperparameters" begin | ||
compare_gradient(AD, args) do p | ||
kernelfunction(p)(x, y) | ||
end | ||
end | ||
end | ||
end | ||
|
||
compare_gradient(AD, A) do a | ||
testdiagfunction(k, a, dim) | ||
end | ||
compare_gradient(AD, A) do a | ||
testdiagfunction(k, a, B, dim) | ||
end | ||
compare_gradient(AD, B) do b | ||
testdiagfunction(k, A, b, dim) | ||
end | ||
if args !== nothing | ||
compare_gradient(AD, args) do p | ||
testdiagfunction(kernelfunction(p), A, dim) | ||
Comment on lines
-274
to
-285
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This was doing exactly the same for `testfunction` and `testdiagfunction`. |
||
@testset "kernel matrices" begin | ||
A = rand(rng, dims...) | ||
B = rand(rng, dims...) | ||
@testset "$(_testfn)" for _testfn in (testfunction, testdiagfunction) | ||
for dim in 1:2 | ||
compare_gradient(AD, A) do a | ||
_testfn(k, a, dim) | ||
end | ||
compare_gradient(AD, A) do a | ||
_testfn(k, a, B, dim) | ||
end | ||
compare_gradient(AD, B) do b | ||
_testfn(k, A, b, dim) | ||
end | ||
if !(args === nothing) | ||
@testset "hyperparameters" begin | ||
compare_gradient(AD, args) do p | ||
_testfn(kernelfunction(p), A, dim) | ||
end | ||
compare_gradient(AD, args) do p | ||
_testfn(kernelfunction(p), A, B, dim) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added the missing case. |
||
end | ||
end | ||
end | ||
end | ||
end | ||
end | ||
end # kernel matrices | ||
end | ||
end | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was doing exactly the same as `test_AD` (except for `testdiagfunction(kernelfunction(p), A, dim)` vs `testdiagfunction(kernelfunction(p), A, B, dim)`, which I've now added to `test_AD`). The only difference was calling `compare_gradient` vs the `@test_nowarn`, so I simply added a `compare_gradient(f, ::Val{:FiniteDiff}, args)` method so we can avoid all this code duplication.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice.