Skip to content

Commit 2d17212

Browse files
authored
Zygote AD failure workarounds & test cleanup (#414)
Zygote AD failures:
* revert #409 (test_utils workaround for broken Zygote - now working again)
* disable broken Zygote AD test for ChainTransform

Improved tests:
* finer-grained testsets
* add missing test cases to test_AD
* replace test_FiniteDiff with test_AD(..., :FiniteDiff, ...)
* remove code duplication
1 parent 3c49949 commit 2d17212

File tree

2 files changed

+53
-111
lines changed

2 files changed

+53
-111
lines changed

test/test_utils.jl

+50-110
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ function gradient(f, ::Val{:FiniteDiff}, args)
6969
return first(FiniteDifferences.grad(FDM, f, args))
7070
end
7171

72+
function compare_gradient(f, ::Val{:FiniteDiff}, args)
73+
@test_nowarn gradient(f, :FiniteDiff, args)
74+
end
75+
7276
function compare_gradient(f, AD::Symbol, args)
7377
grad_AD = gradient(f, AD, args)
7478
grad_FD = gradient(f, :FiniteDiff, args)
@@ -88,7 +92,7 @@ testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B))
8892
function test_ADs(
8993
kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3]
9094
)
91-
test_fd = test_FiniteDiff(kernelfunction, args, dims)
95+
test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims)
9296
if !test_fd.anynonpass
9397
for AD in ADs
9498
test_AD(AD, kernelfunction, args, dims)
@@ -100,7 +104,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context())
100104
@inferred f(args...)
101105
@inferred Zygote._pullback(ctx, f, args...)
102106
out, pb = Zygote._pullback(ctx, f, args...)
103-
@test_throws ErrorException @inferred pb(out)
107+
@inferred pb(out)
104108
end
105109

106110
function test_ADs(
@@ -114,70 +118,6 @@ function test_ADs(
114118
end
115119
end
116120

117-
function test_FiniteDiff(kernelfunction, args=nothing, dims=[3, 3])
118-
# Init arguments :
119-
k = if args === nothing
120-
kernelfunction()
121-
else
122-
kernelfunction(args)
123-
end
124-
rng = MersenneTwister(42)
125-
@testset "FiniteDifferences" begin
126-
if k isa SimpleKernel
127-
for d in log.([eps(), rand(rng)])
128-
@test_nowarn gradient(:FiniteDiff, [d]) do x
129-
kappa(k, exp(first(x)))
130-
end
131-
end
132-
end
133-
## Testing Kernel Functions
134-
x = rand(rng, dims[1])
135-
y = rand(rng, dims[1])
136-
@test_nowarn gradient(:FiniteDiff, x) do x
137-
k(x, y)
138-
end
139-
if !(args === nothing)
140-
@test_nowarn gradient(:FiniteDiff, args) do p
141-
kernelfunction(p)(x, y)
142-
end
143-
end
144-
## Testing Kernel Matrices
145-
A = rand(rng, dims...)
146-
B = rand(rng, dims...)
147-
for dim in 1:2
148-
@test_nowarn gradient(:FiniteDiff, A) do a
149-
testfunction(k, a, dim)
150-
end
151-
@test_nowarn gradient(:FiniteDiff, A) do a
152-
testfunction(k, a, B, dim)
153-
end
154-
@test_nowarn gradient(:FiniteDiff, B) do b
155-
testfunction(k, A, b, dim)
156-
end
157-
if !(args === nothing)
158-
@test_nowarn gradient(:FiniteDiff, args) do p
159-
testfunction(kernelfunction(p), A, B, dim)
160-
end
161-
end
162-
163-
@test_nowarn gradient(:FiniteDiff, A) do a
164-
testdiagfunction(k, a, dim)
165-
end
166-
@test_nowarn gradient(:FiniteDiff, A) do a
167-
testdiagfunction(k, a, B, dim)
168-
end
169-
@test_nowarn gradient(:FiniteDiff, B) do b
170-
testdiagfunction(k, A, b, dim)
171-
end
172-
if args !== nothing
173-
@test_nowarn gradient(:FiniteDiff, args) do p
174-
testdiagfunction(kernelfunction(p), A, B, dim)
175-
end
176-
end
177-
end
178-
end
179-
end
180-
181121
function test_FiniteDiff(k::MOKernel, dims=(in=3, out=2, obs=3))
182122
rng = MersenneTwister(42)
183123
@testset "FiniteDifferences" begin
@@ -224,68 +164,68 @@ end
224164

225165
function test_AD(AD::Symbol, kernelfunction, args=nothing, dims=[3, 3])
226166
@testset "$(AD)" begin
227-
# Test kappa function
228167
k = if args === nothing
229168
kernelfunction()
230169
else
231170
kernelfunction(args)
232171
end
233172
rng = MersenneTwister(42)
173+
234174
if k isa SimpleKernel
235-
for d in log.([eps(), rand(rng)])
236-
compare_gradient(AD, [d]) do x
237-
kappa(k, exp(x[1]))
175+
@testset "kappa function" begin
176+
for d in log.([eps(), rand(rng)])
177+
compare_gradient(AD, [d]) do x
178+
kappa(k, exp(x[1]))
179+
end
238180
end
239181
end
240182
end
241-
# Testing kernel evaluations
242-
x = rand(rng, dims[1])
243-
y = rand(rng, dims[1])
244-
compare_gradient(AD, x) do x
245-
k(x, y)
246-
end
247-
compare_gradient(AD, y) do y
248-
k(x, y)
249-
end
250-
if !(args === nothing)
251-
compare_gradient(AD, args) do p
252-
kernelfunction(p)(x, y)
253-
end
254-
end
255-
# Testing kernel matrices
256-
A = rand(rng, dims...)
257-
B = rand(rng, dims...)
258-
for dim in 1:2
259-
compare_gradient(AD, A) do a
260-
testfunction(k, a, dim)
261-
end
262-
compare_gradient(AD, A) do a
263-
testfunction(k, a, B, dim)
183+
184+
@testset "kernel evaluations" begin
185+
x = rand(rng, dims[1])
186+
y = rand(rng, dims[1])
187+
compare_gradient(AD, x) do x
188+
k(x, y)
264189
end
265-
compare_gradient(AD, B) do b
266-
testfunction(k, A, b, dim)
190+
compare_gradient(AD, y) do y
191+
k(x, y)
267192
end
268193
if !(args === nothing)
269-
compare_gradient(AD, args) do p
270-
testfunction(kernelfunction(p), A, dim)
194+
@testset "hyperparameters" begin
195+
compare_gradient(AD, args) do p
196+
kernelfunction(p)(x, y)
197+
end
271198
end
272199
end
200+
end
273201

274-
compare_gradient(AD, A) do a
275-
testdiagfunction(k, a, dim)
276-
end
277-
compare_gradient(AD, A) do a
278-
testdiagfunction(k, a, B, dim)
279-
end
280-
compare_gradient(AD, B) do b
281-
testdiagfunction(k, A, b, dim)
282-
end
283-
if args !== nothing
284-
compare_gradient(AD, args) do p
285-
testdiagfunction(kernelfunction(p), A, dim)
202+
@testset "kernel matrices" begin
203+
A = rand(rng, dims...)
204+
B = rand(rng, dims...)
205+
@testset "$(_testfn)" for _testfn in (testfunction, testdiagfunction)
206+
for dim in 1:2
207+
compare_gradient(AD, A) do a
208+
_testfn(k, a, dim)
209+
end
210+
compare_gradient(AD, A) do a
211+
_testfn(k, a, B, dim)
212+
end
213+
compare_gradient(AD, B) do b
214+
_testfn(k, A, b, dim)
215+
end
216+
if !(args === nothing)
217+
@testset "hyperparameters" begin
218+
compare_gradient(AD, args) do p
219+
_testfn(kernelfunction(p), A, dim)
220+
end
221+
compare_gradient(AD, args) do p
222+
_testfn(kernelfunction(p), A, B, dim)
223+
end
224+
end
225+
end
286226
end
287227
end
288-
end
228+
end # kernel matrices
289229
end
290230
end
291231

test/transform/chaintransform.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
@test repr(tp ∘ tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)"
2525
test_ADs(
2626
x -> SEKernel() ∘ (ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))),
27-
randn(rng, 4),
27+
randn(rng, 4);
28+
ADs=[:ForwardDiff, :ReverseDiff], # explicitly pass ADs to exclude :Zygote
2829
)
30+
@test_broken "test_AD of chain transform is currently broken in Zygote, see GitHub issue #263"
2931
end

0 commit comments

Comments
 (0)