Skip to content

Commit fc037f9

Browse files
fix: fix view adjoints
1 parent b75de6e commit fc037f9

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

Diff for: ext/RecursiveArrayToolsZygoteExt.jl

+15-9
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ end
8888

8989
@adjoint function Base.copy(u::VectorOfArray)
9090
copy(u),
91-
y -> (copy(y),)
91+
tuple copy
9292
end
9393

9494
@adjoint function DiffEqArray(u, t)
@@ -123,18 +123,24 @@ end
123123
end
124124

125125
@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
126-
function adjoint(y)
127-
(recursivecopy(parent(y)), map(_ -> nothing, I)...)
126+
view_adjoint = let A = A, I = I
127+
function (y)
128+
A = recursivecopy(A)
129+
copyto!(A, y)
130+
(A, map(_ -> nothing, I)...)
131+
end
128132
end
129-
return view(A, I...), adjoint
133+
return view(A, I...), view_adjoint
130134
end
131135

132136
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
133-
function view_adjoint(y)
134-
A = recursivecopy(parent(y))
135-
recursivefill!(A, zero(eltype(A)))
136-
A[I...] .= y
137-
return (A, map(_ -> nothing, I)...)
137+
view_adjoint = let A = A, I = I
138+
function (y)
139+
A = recursivecopy(A)
140+
recursivefill!(A, zero(eltype(A)))
141+
A[I...] .= y
142+
return (A, map(_ -> nothing, I)...)
143+
end
138144
end
139145
view(A, I...), view_adjoint
140146
end

Diff for: test/adjoints.jl

+12
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ function loss9(x)
6666
return VectorOfArray([collect((3i):(3i + 3)) .* x for i in 1:5])
6767
end
6868

69+
function loss10(x)
70+
voa = VectorOfArray([i * x for i in 1:5])
71+
return sum(view(voa, 2:4, 3:5))
72+
end
73+
74+
function loss11(x)
75+
voa = VectorOfArray([i * x for i in 1:5])
76+
return sum(view(voa, :, :))
77+
end
78+
6979
x = float.(6:10)
7080
loss(x)
7181
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
@@ -78,3 +88,5 @@ loss(x)
7888
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
7989
@test ForwardDiff.derivative(loss9, 0.0) ==
8090
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
91+
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
92+
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)

0 commit comments

Comments
 (0)