Skip to content

Commit 2af86b7

Browse files
fix: view adjoints
1 parent cc55ee6 commit 2af86b7

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

+10-2
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,19 @@ end
118118
Array(VA), adj
119119
end
120120

121+
@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
122+
function adjoint(y)
123+
(recursivecopy(A), map(_ -> nothing, I)...)
124+
end
125+
return view(A, I...), adjoint
126+
end
127+
121128
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
122129
adj = let A = A, I = I
123130
function view_adjoint(y)
124-
A = zero(A)
125-
view(A, I...) .= y
131+
A = recursivecopy(A)
132+
recursivefill!(A, zero(eltype(A)))
133+
A[I...] .= y
126134
return (A, map(_ -> nothing, I)...)
127135
end
128136
end

0 commit comments

Comments
 (0)