@@ -69,6 +69,10 @@ function gradient(f, ::Val{:FiniteDiff}, args)
69
69
return first (FiniteDifferences. grad (FDM, f, args))
70
70
end
71
71
72
+ function compare_gradient (f, :: Val{:FiniteDiff} , args)
73
+ @test_nowarn gradient (f, :FiniteDiff , args)
74
+ end
75
+
72
76
function compare_gradient (f, AD:: Symbol , args)
73
77
grad_AD = gradient (f, AD, args)
74
78
grad_FD = gradient (f, :FiniteDiff , args)
@@ -88,7 +92,7 @@ testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B))
88
92
function test_ADs (
89
93
kernelfunction, args= nothing ; ADs= [:Zygote , :ForwardDiff , :ReverseDiff ], dims= [3 , 3 ]
90
94
)
91
- test_fd = test_FiniteDiff ( kernelfunction, args, dims)
95
+ test_fd = test_AD ( :FiniteDiff , kernelfunction, args, dims)
92
96
if ! test_fd. anynonpass
93
97
for AD in ADs
94
98
test_AD (AD, kernelfunction, args, dims)
@@ -100,7 +104,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context())
100
104
@inferred f (args... )
101
105
@inferred Zygote. _pullback (ctx, f, args... )
102
106
out, pb = Zygote. _pullback (ctx, f, args... )
103
- @test_throws ErrorException @ inferred pb (out)
107
+ @inferred pb (out)
104
108
end
105
109
106
110
function test_ADs (
@@ -114,70 +118,6 @@ function test_ADs(
114
118
end
115
119
end
116
120
117
- function test_FiniteDiff (kernelfunction, args= nothing , dims= [3 , 3 ])
118
- # Init arguments :
119
- k = if args === nothing
120
- kernelfunction ()
121
- else
122
- kernelfunction (args)
123
- end
124
- rng = MersenneTwister (42 )
125
- @testset " FiniteDifferences" begin
126
- if k isa SimpleKernel
127
- for d in log .([eps (), rand (rng)])
128
- @test_nowarn gradient (:FiniteDiff , [d]) do x
129
- kappa (k, exp (first (x)))
130
- end
131
- end
132
- end
133
- # # Testing Kernel Functions
134
- x = rand (rng, dims[1 ])
135
- y = rand (rng, dims[1 ])
136
- @test_nowarn gradient (:FiniteDiff , x) do x
137
- k (x, y)
138
- end
139
- if ! (args === nothing )
140
- @test_nowarn gradient (:FiniteDiff , args) do p
141
- kernelfunction (p)(x, y)
142
- end
143
- end
144
- # # Testing Kernel Matrices
145
- A = rand (rng, dims... )
146
- B = rand (rng, dims... )
147
- for dim in 1 : 2
148
- @test_nowarn gradient (:FiniteDiff , A) do a
149
- testfunction (k, a, dim)
150
- end
151
- @test_nowarn gradient (:FiniteDiff , A) do a
152
- testfunction (k, a, B, dim)
153
- end
154
- @test_nowarn gradient (:FiniteDiff , B) do b
155
- testfunction (k, A, b, dim)
156
- end
157
- if ! (args === nothing )
158
- @test_nowarn gradient (:FiniteDiff , args) do p
159
- testfunction (kernelfunction (p), A, B, dim)
160
- end
161
- end
162
-
163
- @test_nowarn gradient (:FiniteDiff , A) do a
164
- testdiagfunction (k, a, dim)
165
- end
166
- @test_nowarn gradient (:FiniteDiff , A) do a
167
- testdiagfunction (k, a, B, dim)
168
- end
169
- @test_nowarn gradient (:FiniteDiff , B) do b
170
- testdiagfunction (k, A, b, dim)
171
- end
172
- if args != = nothing
173
- @test_nowarn gradient (:FiniteDiff , args) do p
174
- testdiagfunction (kernelfunction (p), A, B, dim)
175
- end
176
- end
177
- end
178
- end
179
- end
180
-
181
121
function test_FiniteDiff (k:: MOKernel , dims= (in= 3 , out= 2 , obs= 3 ))
182
122
rng = MersenneTwister (42 )
183
123
@testset " FiniteDifferences" begin
@@ -224,68 +164,68 @@ end
224
164
225
165
function test_AD (AD:: Symbol , kernelfunction, args= nothing , dims= [3 , 3 ])
226
166
@testset " $(AD) " begin
227
- # Test kappa function
228
167
k = if args === nothing
229
168
kernelfunction ()
230
169
else
231
170
kernelfunction (args)
232
171
end
233
172
rng = MersenneTwister (42 )
173
+
234
174
if k isa SimpleKernel
235
- for d in log .([eps (), rand (rng)])
236
- compare_gradient (AD, [d]) do x
237
- kappa (k, exp (x[1 ]))
175
+ @testset " kappa function" begin
176
+ for d in log .([eps (), rand (rng)])
177
+ compare_gradient (AD, [d]) do x
178
+ kappa (k, exp (x[1 ]))
179
+ end
238
180
end
239
181
end
240
182
end
241
- # Testing kernel evaluations
242
- x = rand (rng, dims[1 ])
243
- y = rand (rng, dims[1 ])
244
- compare_gradient (AD, x) do x
245
- k (x, y)
246
- end
247
- compare_gradient (AD, y) do y
248
- k (x, y)
249
- end
250
- if ! (args === nothing )
251
- compare_gradient (AD, args) do p
252
- kernelfunction (p)(x, y)
253
- end
254
- end
255
- # Testing kernel matrices
256
- A = rand (rng, dims... )
257
- B = rand (rng, dims... )
258
- for dim in 1 : 2
259
- compare_gradient (AD, A) do a
260
- testfunction (k, a, dim)
261
- end
262
- compare_gradient (AD, A) do a
263
- testfunction (k, a, B, dim)
183
+
184
+ @testset " kernel evaluations" begin
185
+ x = rand (rng, dims[1 ])
186
+ y = rand (rng, dims[1 ])
187
+ compare_gradient (AD, x) do x
188
+ k (x, y)
264
189
end
265
- compare_gradient (AD, B ) do b
266
- testfunction (k, A, b, dim )
190
+ compare_gradient (AD, y ) do y
191
+ k (x, y )
267
192
end
268
193
if ! (args === nothing )
269
- compare_gradient (AD, args) do p
270
- testfunction (kernelfunction (p), A, dim)
194
+ @testset " hyperparameters" begin
195
+ compare_gradient (AD, args) do p
196
+ kernelfunction (p)(x, y)
197
+ end
271
198
end
272
199
end
200
+ end
273
201
274
- compare_gradient (AD, A) do a
275
- testdiagfunction (k, a, dim)
276
- end
277
- compare_gradient (AD, A) do a
278
- testdiagfunction (k, a, B, dim)
279
- end
280
- compare_gradient (AD, B) do b
281
- testdiagfunction (k, A, b, dim)
282
- end
283
- if args != = nothing
284
- compare_gradient (AD, args) do p
285
- testdiagfunction (kernelfunction (p), A, dim)
202
+ @testset " kernel matrices" begin
203
+ A = rand (rng, dims... )
204
+ B = rand (rng, dims... )
205
+ @testset " $(_testfn) " for _testfn in (testfunction, testdiagfunction)
206
+ for dim in 1 : 2
207
+ compare_gradient (AD, A) do a
208
+ _testfn (k, a, dim)
209
+ end
210
+ compare_gradient (AD, A) do a
211
+ _testfn (k, a, B, dim)
212
+ end
213
+ compare_gradient (AD, B) do b
214
+ _testfn (k, A, b, dim)
215
+ end
216
+ if ! (args === nothing )
217
+ @testset " hyperparameters" begin
218
+ compare_gradient (AD, args) do p
219
+ _testfn (kernelfunction (p), A, dim)
220
+ end
221
+ compare_gradient (AD, args) do p
222
+ _testfn (kernelfunction (p), A, B, dim)
223
+ end
224
+ end
225
+ end
286
226
end
287
227
end
288
- end
228
+ end # kernel matrices
289
229
end
290
230
end
291
231
0 commit comments