Skip to content

Commit 9a2f7bb

Browse files
authored
Remove chainrule and test for SqMahalanobis (#539)
1 parent ec19a94 commit 9a2f7bb

File tree

3 files changed

+1
-32
lines changed

3 files changed

+1
-32
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.59"
3+
version = "0.10.60"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/chainrules.jl

-22
Original file line numberDiff line numberDiff line change
@@ -121,28 +121,6 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
121121
return val, evaluate_pullback
122122
end
123123

124-
## Reverse Rules SqMahalanobis
125-
126-
function ChainRulesCore.rrule(
127-
dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector
128-
)
129-
d = dist(a, b)
130-
function SqMahalanobis_pullback::Real)
131-
a_b = a - b
132-
∂qmat = InplaceableThunk(
133-
-> mul!(X̄, a_b, a_b', true, Δ), @thunk((a_b * a_b') * Δ)
134-
)
135-
∂a = InplaceableThunk(
136-
-> mul!(X̄, dist.qmat, a_b, true, 2 * Δ), @thunk((2 * Δ) * dist.qmat * a_b)
137-
)
138-
∂b = InplaceableThunk(
139-
-> mul!(X̄, dist.qmat, a_b, true, -2 * Δ), @thunk((-2 * Δ) * dist.qmat * a_b)
140-
)
141-
return Tangent{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b
142-
end
143-
return d, SqMahalanobis_pullback
144-
end
145-
146124
## Reverse Rules for matrix wrappers
147125

148126
function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)

test/chainrules.jl

-9
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
x = rand(rng, 5)
44
y = rand(rng, 5)
55
r = rand(rng, 5)
6-
Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
7-
@assert isposdef(Q)
86

97
compare_gradient(:Zygote, [x, y]) do xy
108
Euclidean()(xy[1], xy[2])
@@ -21,11 +19,4 @@
2119
compare_gradient(:Zygote, [x, y]) do xy
2220
KernelFunctions.Sinus(r)(xy[1], xy[2])
2321
end
24-
if VERSION < v"1.6"
25-
@test_broken "Chain rule of SqMahalanobis is broken in Julia pre-1.6"
26-
else
27-
compare_gradient(:Zygote, [Q, x, y]) do Qxy
28-
SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3])
29-
end
30-
end
3122
end

0 commit comments

Comments
 (0)