Skip to content

Commit 805a1cd

Browse files
Merge pull request #383 from AayushSabharwal/as/fix-adjoints
fix: fix view adjoints
2 parents 75d9322 + 036cbf7 commit 805a1cd

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

ext/RecursiveArrayToolsReverseDiffExt.jl

+12
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,16 @@ end
2323
end
2424
return Array(VA), Array_adjoint
2525
end
26+
27+
@adjoint function Base.view(A::AbstractVectorOfArray{<:ReverseDiff.TrackedReal, N}, I::Colon...) where {N}
28+
view_adjoint = let A = A, I = I
29+
function (y)
30+
A = recursivecopy(A)
31+
trackedarraycopyto!(A, y)
32+
(A, map(_ -> nothing, I)...)
33+
end
34+
end
35+
return view(A, I...), view_adjoint
36+
end
37+
2638
end # module

ext/RecursiveArrayToolsZygoteExt.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ end
115115
adj = let VA = VA
116116
function Array_adjoint(y)
117117
VA = recursivecopy(VA)
118-
VA .= y
118+
copyto!(VA, y)
119119
return (VA,)
120120
end
121121
end
@@ -126,8 +126,8 @@ end
126126
view_adjoint = let A = A, I = I
127127
function (y)
128128
A = recursivecopy(A)
129-
A .= y
130-
(A, map(_ -> nothing, I)...)
129+
copyto!(A, y)
130+
return (A, map(_ -> nothing, I)...)
131131
end
132132
end
133133
return view(A, I...), view_adjoint
@@ -139,7 +139,7 @@ end
139139
A = recursivecopy(A)
140140
recursivefill!(A, zero(eltype(A)))
141141
v = view(A, I...)
142-
v .= y
142+
copyto!(v, y)
143143
return (A, map(_ -> nothing, I)...)
144144
end
145145
end

0 commit comments

Comments
 (0)