Skip to content

Commit e0b18dc

Browse files
Merge pull request #382 from AayushSabharwal/as/fix-adjoints
fix: fix `Base.view` adjoints
2 parents 5e9e8f8 + 97edb58 commit e0b18dc

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

+17-10
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ end
8888

8989
@adjoint function Base.copy(u::VectorOfArray)
9090
copy(u),
91-
y -> (copy(y),)
91+
tuple ∘ copy
9292
end
9393

9494
@adjoint function DiffEqArray(u, t)
@@ -115,26 +115,33 @@ end
115115
adj = let VA = VA
116116
function Array_adjoint(y)
117117
VA = recursivecopy(VA)
118-
copyto!(VA, y)
118+
VA .= y
119119
return (VA,)
120120
end
121121
end
122122
Array(VA), adj
123123
end
124124

125125
@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
126-
function adjoint(y)
127-
(recursivecopy(parent(y)), map(_ -> nothing, I)...)
126+
view_adjoint = let A = A, I = I
127+
function (y)
128+
A = recursivecopy(A)
129+
A .= y
130+
(A, map(_ -> nothing, I)...)
131+
end
128132
end
129-
return view(A, I...), adjoint
133+
return view(A, I...), view_adjoint
130134
end
131135

132136
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
133-
function view_adjoint(y)
134-
A = recursivecopy(parent(y))
135-
recursivefill!(A, zero(eltype(A)))
136-
A[I...] .= y
137-
return (A, map(_ -> nothing, I)...)
137+
view_adjoint = let A = A, I = I
138+
function (y)
139+
A = recursivecopy(A)
140+
recursivefill!(A, zero(eltype(A)))
141+
v = view(A, I...)
142+
v .= y
143+
return (A, map(_ -> nothing, I)...)
144+
end
138145
end
139146
view(A, I...), view_adjoint
140147
end

test/adjoints.jl

+12
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ function loss9(x)
6666
return VectorOfArray([collect((3i):(3i + 3)) .* x for i in 1:5])
6767
end
6868

69+
function loss10(x)
70+
voa = VectorOfArray([i * x for i in 1:5])
71+
return sum(view(voa, 2:4, 3:5))
72+
end
73+
74+
function loss11(x)
75+
voa = VectorOfArray([i * x for i in 1:5])
76+
return sum(view(voa, :, :))
77+
end
78+
6979
x = float.(6:10)
7080
loss(x)
7181
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
@@ -78,3 +88,5 @@ loss(x)
7888
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
7989
@test ForwardDiff.derivative(loss9, 0.0) ==
8090
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
91+
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
92+
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)

0 commit comments

Comments
 (0)