111
111
112
112
function ChainRulesCore. rrule (s:: Sinus , x:: AbstractVector , y:: AbstractVector )
113
113
d = x - y
114
- sind = sinpi .(d)
115
- abs2_sind_r = abs2 .(sind) ./ s. r .^ 2
114
+ abs2_sind_r = (sinpi .(d) ./ s. r) .^ 2
116
115
val = sum (abs2_sind_r)
117
- gradx = twoπ .* cospi .(d) .* sind ./ s. r .^ 2
116
+ gradx = π .* sinpi .( 2 .* d) ./ s. r .^ 2
118
117
function evaluate_pullback (Δ:: Any )
119
118
r̄ = - 2 Δ .* abs2_sind_r ./ s. r
120
119
s̄ = ChainRulesCore. Tangent {typeof(s)} (; r= r̄)
@@ -136,7 +135,7 @@ function ChainRulesCore.rrule(
136
135
for j in 1 : n, i in 1 : n
137
136
xi = view (x, i, :)
138
137
xj = view (x, j, :)
139
- ds = twoπ .* Δ[i, j] .* sinpi .(xi .- xj) .* cospi . (xi .- xj) ./ d. r .^ 2
138
+ ds = π .* Δ[i, j] .* sinpi .(2 .* (xi .- xj) ) ./ d. r .^ 2
140
139
r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- xj) .^ 2 ./ d. r .^ 3
141
140
x̄[i, :] += ds
142
141
x̄[j, :] -= ds
@@ -147,8 +146,8 @@ function ChainRulesCore.rrule(
147
146
xj = view (x, :, j)
148
147
ds = twoπ .* Δ[i, j] .* sinpi .(xi .- xj) .* cospi .(xi .- xj) ./ d. r .^ 2
149
148
r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- xj) .^ 2 ./ d. r .^ 3
150
- x̄[:, i] += ds
151
- x̄[:, j] -= ds
149
+ x̄[:, i] . += ds
150
+ x̄[:, j] . -= ds
152
151
end
153
152
end
154
153
d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
@@ -173,19 +172,19 @@ function ChainRulesCore.rrule(
173
172
for j in 1 : m, i in 1 : n
174
173
xi = view (x, i, :)
175
174
yj = view (y, j, :)
176
- ds = twoπ .* Δ[i, j] .* sinpi .(xi .- yj) .* cospi . (xi .- yj) ./ d. r .^ 2
175
+ ds = π .* Δ[i, j] .* sinpi .(2 .* (xi .- yj) ) ./ d. r .^ 2
177
176
r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- yj) .^ 2 ./ d. r .^ 3
178
- x̄[i, :] += ds
179
- ȳ[j, :] -= ds
177
+ x̄[i, :] . += ds
178
+ ȳ[j, :] . -= ds
180
179
end
181
180
elseif dims == 2
182
181
for j in 1 : m, i in 1 : n
183
182
xi = view (x, :, i)
184
183
yj = view (y, :, j)
185
- ds = twoπ .* Δ[i, j] .* sinpi .(xi .- yj) .* cospi . (xi .- yj) ./ d. r .^ 2
184
+ ds = π .* Δ[i, j] .* sinpi .(2 .* (xi .- yj) ) ./ d. r .^ 2
186
185
r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- yj) .^ 2 ./ d. r .^ 3
187
- x̄[:, i] += ds
188
- ȳ[:, j] -= ds
186
+ x̄[:, i] . += ds
187
+ ȳ[:, j] . -= ds
189
188
end
190
189
end
191
190
d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
@@ -208,10 +207,10 @@ function ChainRulesCore.rrule(
208
207
for i in 1 : n
209
208
xi = view (x, :, i)
210
209
yi = view (y, :, i)
211
- ds = twoπ .* Δ[i] .* sinpi .(xi .- yi) .* cospi . (xi .- yi) ./ d. r .^ 2
210
+ ds = π .* Δ[i] .* sinpi .(2 .* (xi .- yi) ) ./ d. r .^ 2
212
211
r̄ .- = 2 .* Δ[i] .* sinpi .(xi .- yi) .^ 2 ./ d. r .^ 3
213
- x̄[:, i] += ds
214
- ȳ[:, i] -= ds
212
+ x̄[:, i] . += ds
213
+ ȳ[:, i] . -= ds
215
214
end
216
215
d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
217
216
return NoTangent (), d̄, @thunk (project_x (x̄)), @thunk (project_y (ȳ))
0 commit comments