Skip to content

Commit 49049b1

Browse files
authored
Fix periodic kernel AD (#531)
1 parent 3b7a2df commit 49049b1

File tree

6 files changed

+126
-7
lines changed

6 files changed

+126
-7
lines changed

src/KernelFunctions.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ export tensor, ⊗, compose
4848

4949
using Compat
5050
using ChainRulesCore: ChainRulesCore, Tangent, ZeroTangent, NoTangent
51-
using ChainRulesCore: @thunk, InplaceableThunk
51+
using ChainRulesCore: @thunk, InplaceableThunk, ProjectTo, unthunk
5252
using CompositionsBase
5353
using Distances
5454
using FillArrays

src/chainrules.jl

+101-3
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,113 @@ end
112112
function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
113113
d = x - y
114114
sind = sinpi.(d)
115-
abs2_sind_r = abs2.(sind) ./ s.r
115+
abs2_sind_r = abs2.(sind) ./ s.r .^ 2
116116
val = sum(abs2_sind_r)
117-
gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2)
117+
gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2
118118
function evaluate_pullback(Δ::Any)
119-
return (r=-2Δ .* abs2_sind_r,), Δ * gradx, -Δ * gradx
119+
r̄ = -2Δ .* abs2_sind_r ./ s.r
120+
s̄ = ChainRulesCore.Tangent{typeof(s)}(; r=r̄)
121+
return s̄, Δ * gradx, -Δ * gradx
120122
end
121123
return val, evaluate_pullback
122124
end
123125

126+
function ChainRulesCore.rrule(
127+
::typeof(Distances.pairwise), d::Sinus, x::AbstractMatrix; dims=2
128+
)
129+
project_x = ProjectTo(x)
130+
function pairwise_pullback(z̄)
131+
Δ = unthunk(z̄)
132+
n = size(x, dims)
133+
x̄ = collect(zero(x))
134+
r̄ = zero(d.r)
135+
if dims == 1
136+
for j in 1:n, i in 1:n
137+
xi = view(x, i, :)
138+
xj = view(x, j, :)
139+
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
140+
r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
141+
x̄[i, :] += ds
142+
x̄[j, :] -= ds
143+
end
144+
elseif dims == 2
145+
for j in 1:n, i in 1:n
146+
xi = view(x, :, i)
147+
xj = view(x, :, j)
148+
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
149+
r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
150+
x̄[:, i] += ds
151+
x̄[:, j] -= ds
152+
end
153+
end
154+
d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
155+
return NoTangent(), d̄, @thunk(project_x(x̄))
156+
end
157+
return Distances.pairwise(d, x; dims), pairwise_pullback
158+
end
159+
160+
function ChainRulesCore.rrule(
161+
::typeof(Distances.pairwise), d::Sinus, x::AbstractMatrix, y::AbstractMatrix; dims=2
162+
)
163+
project_x = ProjectTo(x)
164+
project_y = ProjectTo(y)
165+
function pairwise_pullback(z̄)
166+
Δ = unthunk(z̄)
167+
n = size(x, dims)
168+
m = size(y, dims)
169+
x̄ = collect(zero(x))
170+
ȳ = collect(zero(y))
171+
r̄ = zero(d.r)
172+
if dims == 1
173+
for j in 1:m, i in 1:n
174+
xi = view(x, i, :)
175+
yj = view(y, j, :)
176+
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2
177+
r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
178+
x̄[i, :] += ds
179+
ȳ[j, :] -= ds
180+
end
181+
elseif dims == 2
182+
for j in 1:m, i in 1:n
183+
xi = view(x, :, i)
184+
yj = view(y, :, j)
185+
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2
186+
r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
187+
x̄[:, i] += ds
188+
ȳ[:, j] -= ds
189+
end
190+
end
191+
d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
192+
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
193+
end
194+
return Distances.pairwise(d, x, y; dims), pairwise_pullback
195+
end
196+
197+
function ChainRulesCore.rrule(
198+
::typeof(Distances.colwise), d::Sinus, x::AbstractMatrix, y::AbstractMatrix
199+
)
200+
project_x = ProjectTo(x)
201+
project_y = ProjectTo(y)
202+
function colwise_pullback(z̄)
203+
Δ = unthunk(z̄)
204+
n = size(x, 2)
205+
x̄ = collect(zero(x))
206+
ȳ = collect(zero(y))
207+
r̄ = zero(d.r)
208+
for i in 1:n
209+
xi = view(x, :, i)
210+
yi = view(y, :, i)
211+
ds = twoπ .* Δ[i] .* sinpi.(xi .- yi) .* cospi.(xi .- yi) ./ d.r .^ 2
212+
r̄ .-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3
213+
x̄[:, i] += ds
214+
ȳ[:, i] -= ds
215+
end
216+
d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
217+
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
218+
end
219+
return Distances.colwise(d, x, y), colwise_pullback
220+
end
221+
124222
## Reverse Rules for matrix wrappers
125223

126224
function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)

test/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
[deps]
22
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
3+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
35
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
46
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
57
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

test/basekernels/periodic.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
TestUtils.test_interface(PeriodicKernel(; r=[0.9, 0.9]), ColVecs{Float64})
1616
TestUtils.test_interface(PeriodicKernel(; r=[0.8, 0.7]), RowVecs{Float64})
1717

18-
# test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff])
19-
# Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff
20-
@test_broken false
18+
test_ADs(r -> PeriodicKernel(; r=exp.(r)), log.(r))
2119
test_params(k, (r,))
2220
end

test/chainrules.jl

+19
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,23 @@
1919
compare_gradient(:Zygote, [x, y]) do xy
2020
KernelFunctions.Sinus(r)(xy[1], xy[2])
2121
end
22+
@testset "rrules for Sinus(r=$r)" for r in (rand(3),)
23+
dist = KernelFunctions.Sinus(r)
24+
@testset "$type" for type in (Vector, SVector{3})
25+
test_rrule(dist, type(rand(3)), type(rand(3)))
26+
end
27+
@testset "$type1, $type2" for type1 in (Matrix, SMatrix{3,2}),
28+
type2 in (Matrix, SMatrix{3,4})
29+
30+
test_rrule(Distances.pairwise, dist, type1(rand(3, 2)); fkwargs=(dims=2,))
31+
test_rrule(
32+
Distances.pairwise,
33+
dist,
34+
type1(rand(3, 2)),
35+
type2(rand(3, 4));
36+
fkwargs=(dims=2,),
37+
)
38+
test_rrule(Distances.colwise, dist, type1(rand(3, 2)), type1(rand(3, 2)))
39+
end
40+
end
2241
end

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using KernelFunctions
22
using AxisArrays
3+
using ChainRulesCore
4+
using ChainRulesTestUtils
35
using Distances
46
using Documenter
57
using Functors: functor

0 commit comments

Comments
 (0)