Skip to content

Commit 16a9828

Browse files
simsuracedevmotion
andauthored
Simplify rule according to feedback in #531 (#549)
Co-authored-by: David Widmann <[email protected]>
1 parent 49049b1 commit 16a9828

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

src/chainrules.jl

+14-15
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,9 @@ end
111111

112112
function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
113113
d = x - y
114-
sind = sinpi.(d)
115-
abs2_sind_r = abs2.(sind) ./ s.r .^ 2
114+
abs2_sind_r = (sinpi.(d) ./ s.r) .^ 2
116115
val = sum(abs2_sind_r)
117-
gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2
116+
gradx = π .* sinpi.(2 .* d) ./ s.r .^ 2
118117
function evaluate_pullback::Any)
119118
= -2Δ .* abs2_sind_r ./ s.r
120119
= ChainRulesCore.Tangent{typeof(s)}(; r=r̄)
@@ -136,7 +135,7 @@ function ChainRulesCore.rrule(
136135
for j in 1:n, i in 1:n
137136
xi = view(x, i, :)
138137
xj = view(x, j, :)
139-
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
138+
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- xj)) ./ d.r .^ 2
140139
.-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
141140
x̄[i, :] += ds
142141
x̄[j, :] -= ds
@@ -147,8 +146,8 @@ function ChainRulesCore.rrule(
147146
xj = view(x, :, j)
148147
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
149148
.-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
150-
x̄[:, i] += ds
151-
x̄[:, j] -= ds
149+
x̄[:, i] .+= ds
150+
x̄[:, j] .-= ds
152151
end
153152
end
154153
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
@@ -173,19 +172,19 @@ function ChainRulesCore.rrule(
173172
for j in 1:m, i in 1:n
174173
xi = view(x, i, :)
175174
yj = view(y, j, :)
176-
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2
175+
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2
177176
.-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
178-
x̄[i, :] += ds
179-
ȳ[j, :] -= ds
177+
x̄[i, :] .+= ds
178+
ȳ[j, :] .-= ds
180179
end
181180
elseif dims == 2
182181
for j in 1:m, i in 1:n
183182
xi = view(x, :, i)
184183
yj = view(y, :, j)
185-
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2
184+
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2
186185
.-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
187-
x̄[:, i] += ds
188-
ȳ[:, j] -= ds
186+
x̄[:, i] .+= ds
187+
ȳ[:, j] .-= ds
189188
end
190189
end
191190
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
@@ -208,10 +207,10 @@ function ChainRulesCore.rrule(
208207
for i in 1:n
209208
xi = view(x, :, i)
210209
yi = view(y, :, i)
211-
ds = twoπ .* Δ[i] .* sinpi.(xi .- yi) .* cospi.(xi .- yi) ./ d.r .^ 2
210+
ds = π .* Δ[i] .* sinpi.(2 .* (xi .- yi)) ./ d.r .^ 2
212211
.-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3
213-
x̄[:, i] += ds
214-
ȳ[:, i] -= ds
212+
x̄[:, i] .+= ds
213+
ȳ[:, i] .-= ds
215214
end
216215
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
217216
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))

0 commit comments

Comments
 (0)