Skip to content

Commit 680e756

Browse files
fix: view adjoints
1 parent cc55ee6 commit 680e756

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

+13-7
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,21 @@ end
118118
Array(VA), adj
119119
end
120120

121+
@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
122+
function adjoint(y)
123+
(recursivecopy(parent(y)), map(_ -> nothing, I)...)
124+
end
125+
return view(A, I...), adjoint
126+
end
127+
121128
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
122-
adj = let A = A, I = I
123-
function view_adjoint(y)
124-
A = zero(A)
125-
view(A, I...) .= y
126-
return (A, map(_ -> nothing, I)...)
127-
end
129+
function view_adjoint(y)
130+
A = recursivecopy(parent(y))
131+
recursivefill!(A, zero(eltype(A)))
132+
A[I...] .= y
133+
return (A, map(_ -> nothing, I)...)
128134
end
129-
view(A, I...), adj
135+
view(A, I...), view_adjoint
130136
end
131137

132138
ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))

0 commit comments

Comments
 (0)