Skip to content

Commit 92b4576

Browse files
committed
Merge remote-tracking branch 'origin/master' into fix-periodic
2 parents c15d095 + 9a2f7bb commit 92b4576

File tree

3 files changed

+1
-32
lines changed

3 files changed

+1
-32
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.59"
3+
version = "0.10.60"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/chainrules.jl

-21
Original file line numberDiff line numberDiff line change
@@ -219,27 +219,6 @@ function ChainRulesCore.rrule(
219219
return Distances.colwise(d, x, y), colwise_pullback
220220
end
221221

222-
## Reverse Rules SqMahalanobis
223-
224-
function ChainRulesCore.rrule(
225-
dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector
226-
)
227-
d = dist(a, b)
228-
function SqMahalanobis_pullback::Real)
229-
a_b = a - b
230-
∂qmat = InplaceableThunk(
231-
-> mul!(X̄, a_b, a_b', true, Δ), @thunk((a_b * a_b') * Δ)
232-
)
233-
∂a = InplaceableThunk(
234-
-> mul!(X̄, dist.qmat, a_b, true, 2 * Δ), @thunk((2 * Δ) * dist.qmat * a_b)
235-
)
236-
∂b = InplaceableThunk(
237-
-> mul!(X̄, dist.qmat, a_b, true, -2 * Δ), @thunk((-2 * Δ) * dist.qmat * a_b)
238-
)
239-
return Tangent{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b
240-
end
241-
return d, SqMahalanobis_pullback
242-
end
243222

244223
## Reverse Rules for matrix wrappers
245224

test/chainrules.jl

-10
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
x = rand(rng, 5)
44
y = rand(rng, 5)
55
r = rand(rng, 5)
6-
Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
7-
@assert isposdef(Q)
86

97
compare_gradient(:Zygote, [x, y]) do xy
108
Euclidean()(xy[1], xy[2])
@@ -21,14 +19,6 @@
2119
compare_gradient(:Zygote, [x, y]) do xy
2220
KernelFunctions.Sinus(r)(xy[1], xy[2])
2321
end
24-
if VERSION < v"1.6"
25-
@test_broken "Chain rule of SqMahalanobis is broken in Julia pre-1.6"
26-
else
27-
compare_gradient(:Zygote, [Q, x, y]) do Qxy
28-
SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3])
29-
end
30-
end
31-
3222
@testset "rrules for Sinus(r=$r)" for r in (rand(3),)
3323
dist = KernelFunctions.Sinus(r)
3424
@testset "$type" for type in (Vector, SVector{3})

0 commit comments

Comments
 (0)