-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathchainrules.jl
39 lines (37 loc) · 1.26 KB
/
chainrules.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
@testset "Chain Rules" begin
    # Seeded RNG: every random input below (including the rrule checks at the end)
    # is drawn from it so the whole testset is reproducible.
    rng = MersenneTwister(123456)
    x = rand(rng, 5)
    y = rand(rng, 5)
    r = rand(rng, 5)

    # `Matrix` of a `Cholesky` with upper factor `U` reconstructs `U' * U`, so `Q` is
    # positive semi-definite by construction; confirm it is strictly positive definite
    # before using it as the SqMahalanobis metric below. `@test` (not `@assert`) so the
    # check is recorded by the test framework and cannot be compiled away.
    Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
    @test isposdef(Q)

    # Zygote gradient checks of each distance w.r.t. both of its arguments.
    dists = (
        Euclidean(),
        SqEuclidean(),
        KernelFunctions.DotProduct(),
        KernelFunctions.Delta(),
        KernelFunctions.Sinus(r),
    )
    @testset "Zygote gradient of $(nameof(typeof(d)))" for d in dists
        compare_gradient(:Zygote, [x, y]) do xy
            d(xy[1], xy[2])
        end
    end

    if VERSION < v"1.6"
        # Chain rule of SqMahalanobis is broken in Julia pre-1.6.
        # `@test_broken` expects a Bool-valued expression (a String result is
        # recorded as an error by newer versions of Test), so mark the known
        # breakage with a literal `false`.
        @test_broken false
    else
        compare_gradient(:Zygote, [Q, x, y]) do Qxy
            SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3])
        end
    end

    # Direct rrule checks for `Sinus`, including its `pairwise`/`colwise` methods.
    # Loop variable named `rs` to avoid shadowing the outer `r`.
    @testset "rrules for Sinus(r=$rs)" for rs in (rand(rng, 3),)
        dist = KernelFunctions.Sinus(rs)
        test_rrule(dist, rand(rng, 3), rand(rng, 3))
        test_rrule(Distances.pairwise, dist, rand(rng, 3, 2); fkwargs=(dims=2,))
        test_rrule(
            Distances.pairwise, dist, rand(rng, 3, 2), rand(rng, 3, 3);
            fkwargs=(dims=2,),
        )
        test_rrule(Distances.colwise, dist, rand(rng, 3, 2), rand(rng, 3, 2))
    end
end