@@ -112,15 +112,113 @@ end
112
112
function ChainRulesCore. rrule (s:: Sinus , x:: AbstractVector , y:: AbstractVector )
113
113
d = x - y
114
114
sind = sinpi .(d)
115
- abs2_sind_r = abs2 .(sind) ./ s. r
115
+ abs2_sind_r = abs2 .(sind) ./ s. r .^ 2
116
116
val = sum (abs2_sind_r)
117
- gradx = twoπ .* cospi .(d) .* sind ./ ( s. r .^ 2 )
117
+ gradx = twoπ .* cospi .(d) .* sind ./ s. r .^ 2
118
118
function evaluate_pullback (Δ:: Any )
119
- return (r= - 2 Δ .* abs2_sind_r,), Δ * gradx, - Δ * gradx
119
+ r̄ = - 2 Δ .* abs2_sind_r ./ s. r
120
+ s̄ = ChainRulesCore. Tangent {typeof(s)} (; r= r̄)
121
+ return s̄, Δ * gradx, - Δ * gradx
120
122
end
121
123
return val, evaluate_pullback
122
124
end
123
125
126
+ function ChainRulesCore. rrule (
127
+ :: typeof (Distances. pairwise), d:: Sinus , x:: AbstractMatrix ; dims= 2
128
+ )
129
+ project_x = ProjectTo (x)
130
+ function pairwise_pullback (z̄)
131
+ Δ = unthunk (z̄)
132
+ n = size (x, dims)
133
+ x̄ = collect (zero (x))
134
+ r̄ = zero (d. r)
135
+ if dims == 1
136
+ for j in 1 : n, i in 1 : n
137
+ xi = view (x, i, :)
138
+ xj = view (x, j, :)
139
+ ds = twoπ .* Δ[i, j] .* sinpi .(xi .- xj) .* cospi .(xi .- xj) ./ d. r .^ 2
140
+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- xj) .^ 2 ./ d. r .^ 3
141
+ x̄[i, :] += ds
142
+ x̄[j, :] -= ds
143
+ end
144
+ elseif dims == 2
145
+ for j in 1 : n, i in 1 : n
146
+ xi = view (x, :, i)
147
+ xj = view (x, :, j)
148
+ ds = twoπ .* Δ[i, j] .* sinpi .(xi .- xj) .* cospi .(xi .- xj) ./ d. r .^ 2
149
+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- xj) .^ 2 ./ d. r .^ 3
150
+ x̄[:, i] += ds
151
+ x̄[:, j] -= ds
152
+ end
153
+ end
154
+ d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
155
+ return NoTangent (), d̄, @thunk (project_x (x̄))
156
+ end
157
+ return Distances. pairwise (d, x; dims), pairwise_pullback
158
+ end
159
+
160
+ function ChainRulesCore. rrule (
161
+ :: typeof (Distances. pairwise), d:: Sinus , x:: AbstractMatrix , y:: AbstractMatrix ; dims= 2
162
+ )
163
+ project_x = ProjectTo (x)
164
+ project_y = ProjectTo (y)
165
+ function pairwise_pullback (z̄)
166
+ Δ = unthunk (z̄)
167
+ n = size (x, dims)
168
+ m = size (y, dims)
169
+ x̄ = collect (zero (x))
170
+ ȳ = collect (zero (y))
171
+ r̄ = zero (d. r)
172
+ if dims == 1
173
+ for j in 1 : m, i in 1 : n
174
+ xi = view (x, i, :)
175
+ yj = view (y, j, :)
176
+ ds = twoπ .* Δ[i, j] .* sinpi .(xi .- yj) .* cospi .(xi .- yj) ./ d. r .^ 2
177
+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- yj) .^ 2 ./ d. r .^ 3
178
+ x̄[i, :] += ds
179
+ ȳ[j, :] -= ds
180
+ end
181
+ elseif dims == 2
182
+ for j in 1 : m, i in 1 : n
183
+ xi = view (x, :, i)
184
+ yj = view (y, :, j)
185
+ ds = twoπ .* Δ[i, j] .* sinpi .(xi .- yj) .* cospi .(xi .- yj) ./ d. r .^ 2
186
+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- yj) .^ 2 ./ d. r .^ 3
187
+ x̄[:, i] += ds
188
+ ȳ[:, j] -= ds
189
+ end
190
+ end
191
+ d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
192
+ return NoTangent (), d̄, @thunk (project_x (x̄)), @thunk (project_y (ȳ))
193
+ end
194
+ return Distances. pairwise (d, x, y; dims), pairwise_pullback
195
+ end
196
+
197
+ function ChainRulesCore. rrule (
198
+ :: typeof (Distances. colwise), d:: Sinus , x:: AbstractMatrix , y:: AbstractMatrix
199
+ )
200
+ project_x = ProjectTo (x)
201
+ project_y = ProjectTo (y)
202
+ function colwise_pullback (z̄)
203
+ Δ = unthunk (z̄)
204
+ n = size (x, 2 )
205
+ x̄ = collect (zero (x))
206
+ ȳ = collect (zero (y))
207
+ r̄ = zero (d. r)
208
+ for i in 1 : n
209
+ xi = view (x, :, i)
210
+ yi = view (y, :, i)
211
+ ds = twoπ .* Δ[i] .* sinpi .(xi .- yi) .* cospi .(xi .- yi) ./ d. r .^ 2
212
+ r̄ .- = 2 .* Δ[i] .* sinpi .(xi .- yi) .^ 2 ./ d. r .^ 3
213
+ x̄[:, i] += ds
214
+ ȳ[:, i] -= ds
215
+ end
216
+ d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
217
+ return NoTangent (), d̄, @thunk (project_x (x̄)), @thunk (project_y (ȳ))
218
+ end
219
+ return Distances. colwise (d, x, y), colwise_pullback
220
+ end
221
+
124
222
# # Reverse Rules for matrix wrappers
125
223
126
224
function ChainRulesCore. rrule (:: Type{<:ColVecs} , X:: AbstractMatrix )
0 commit comments