Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zygote AD failure workarounds & test cleanup #414

Merged
merged 10 commits into from
Dec 18, 2021
160 changes: 50 additions & 110 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ function gradient(f, ::Val{:FiniteDiff}, args)
return first(FiniteDifferences.grad(FDM, f, args))
end

function compare_gradient(f, ::Val{:FiniteDiff}, args)
@test_nowarn gradient(f, :FiniteDiff, args)
end

function compare_gradient(f, AD::Symbol, args)
grad_AD = gradient(f, AD, args)
grad_FD = gradient(f, :FiniteDiff, args)
Expand All @@ -88,7 +92,7 @@ testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B))
function test_ADs(
kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3]
)
test_fd = test_FiniteDiff(kernelfunction, args, dims)
test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims)
if !test_fd.anynonpass
for AD in ADs
test_AD(AD, kernelfunction, args, dims)
Expand All @@ -100,7 +104,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context())
@inferred f(args...)
@inferred Zygote._pullback(ctx, f, args...)
out, pb = Zygote._pullback(ctx, f, args...)
@test_throws ErrorException @inferred pb(out)
@inferred pb(out)
end

function test_ADs(
Expand All @@ -114,70 +118,6 @@ function test_ADs(
end
end

function test_FiniteDiff(kernelfunction, args=nothing, dims=[3, 3])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was doing exactly the same as test_AD (except for testdiagfunction(kernelfunction(p), A, dim) vs testdiagfunction(kernelfunction(p), A, B, dim), which I've now added to test_AD). The only difference was calling compare_gradient vs the @test_warn, so I simply added a compare_gradient(f, ::Val{:FiniteDiff}, args) so we can avoid all this code duplication.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice.

# Init arguments :
k = if args === nothing
kernelfunction()
else
kernelfunction(args)
end
rng = MersenneTwister(42)
@testset "FiniteDifferences" begin
if k isa SimpleKernel
for d in log.([eps(), rand(rng)])
@test_nowarn gradient(:FiniteDiff, [d]) do x
kappa(k, exp(first(x)))
end
end
end
## Testing Kernel Functions
x = rand(rng, dims[1])
y = rand(rng, dims[1])
@test_nowarn gradient(:FiniteDiff, x) do x
k(x, y)
end
if !(args === nothing)
@test_nowarn gradient(:FiniteDiff, args) do p
kernelfunction(p)(x, y)
end
end
## Testing Kernel Matrices
A = rand(rng, dims...)
B = rand(rng, dims...)
for dim in 1:2
@test_nowarn gradient(:FiniteDiff, A) do a
testfunction(k, a, dim)
end
@test_nowarn gradient(:FiniteDiff, A) do a
testfunction(k, a, B, dim)
end
@test_nowarn gradient(:FiniteDiff, B) do b
testfunction(k, A, b, dim)
end
if !(args === nothing)
@test_nowarn gradient(:FiniteDiff, args) do p
testfunction(kernelfunction(p), A, B, dim)
end
end

@test_nowarn gradient(:FiniteDiff, A) do a
testdiagfunction(k, a, dim)
end
@test_nowarn gradient(:FiniteDiff, A) do a
testdiagfunction(k, a, B, dim)
end
@test_nowarn gradient(:FiniteDiff, B) do b
testdiagfunction(k, A, b, dim)
end
if args !== nothing
@test_nowarn gradient(:FiniteDiff, args) do p
testdiagfunction(kernelfunction(p), A, B, dim)
end
end
end
end
end

function test_FiniteDiff(k::MOKernel, dims=(in=3, out=2, obs=3))
rng = MersenneTwister(42)
@testset "FiniteDifferences" begin
Expand Down Expand Up @@ -224,68 +164,68 @@ end

function test_AD(AD::Symbol, kernelfunction, args=nothing, dims=[3, 3])
@testset "$(AD)" begin
# Test kappa function
k = if args === nothing
kernelfunction()
else
kernelfunction(args)
end
rng = MersenneTwister(42)

if k isa SimpleKernel
for d in log.([eps(), rand(rng)])
compare_gradient(AD, [d]) do x
kappa(k, exp(x[1]))
@testset "kappa function" begin
for d in log.([eps(), rand(rng)])
compare_gradient(AD, [d]) do x
kappa(k, exp(x[1]))
end
end
end
end
# Testing kernel evaluations
x = rand(rng, dims[1])
y = rand(rng, dims[1])
compare_gradient(AD, x) do x
k(x, y)
end
compare_gradient(AD, y) do y
k(x, y)
end
if !(args === nothing)
compare_gradient(AD, args) do p
kernelfunction(p)(x, y)
end
end
# Testing kernel matrices
A = rand(rng, dims...)
B = rand(rng, dims...)
for dim in 1:2
compare_gradient(AD, A) do a
testfunction(k, a, dim)
end
compare_gradient(AD, A) do a
testfunction(k, a, B, dim)

@testset "kernel evaluations" begin
x = rand(rng, dims[1])
y = rand(rng, dims[1])
compare_gradient(AD, x) do x
k(x, y)
end
compare_gradient(AD, B) do b
testfunction(k, A, b, dim)
compare_gradient(AD, y) do y
k(x, y)
end
if !(args === nothing)
compare_gradient(AD, args) do p
testfunction(kernelfunction(p), A, dim)
@testset "hyperparameters" begin
compare_gradient(AD, args) do p
kernelfunction(p)(x, y)
end
end
end
end

compare_gradient(AD, A) do a
testdiagfunction(k, a, dim)
end
compare_gradient(AD, A) do a
testdiagfunction(k, a, B, dim)
end
compare_gradient(AD, B) do b
testdiagfunction(k, A, b, dim)
end
if args !== nothing
compare_gradient(AD, args) do p
testdiagfunction(kernelfunction(p), A, dim)
Comment on lines -274 to -285
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was doing exactly the same for testdiagfunction as the code above for testfunction, so I've unified it with a for loop over the two functions.

@testset "kernel matrices" begin
A = rand(rng, dims...)
B = rand(rng, dims...)
@testset "$(_testfn)" for _testfn in (testfunction, testdiagfunction)
for dim in 1:2
compare_gradient(AD, A) do a
_testfn(k, a, dim)
end
compare_gradient(AD, A) do a
_testfn(k, a, B, dim)
end
compare_gradient(AD, B) do b
_testfn(k, A, b, dim)
end
if !(args === nothing)
@testset "hyperparameters" begin
compare_gradient(AD, args) do p
_testfn(kernelfunction(p), A, dim)
end
compare_gradient(AD, args) do p
_testfn(kernelfunction(p), A, B, dim)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added missing case

end
end
end
end
end
end
end # kernel matrices
end
end

Expand Down
4 changes: 3 additions & 1 deletion test/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
@test repr(tp ∘ tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)"
test_ADs(
x -> SEKernel() ∘ (ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))),
randn(rng, 4),
randn(rng, 4);
ADs=[:ForwardDiff, :ReverseDiff], # explicitly pass ADs to exclude :Zygote
)
@test_broken "test_AD of chain transform is currently broken in Zygote, see GitHub issue #263"
end