Skip to content

Commit 5c6d7fd

Browse files
committed
take 1
1 parent 196ecd7 commit 5c6d7fd

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

src/zygote.jl

+23-3
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,28 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
4343
A.x,literal_ArrayPartition_x_adjoint
4444
end
4545

46+
#=
47+
48+
# Define a new species of projection operator for this type:
49+
ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}()
50+
51+
# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
52+
(::ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
53+
# Gradient from broadcasting will be another AbstractArray
54+
(::ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
55+
56+
But this may not be necessary?
57+
58+
=#
59+
60+
61+
# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
62+
# definition first, and finds its own before finding those.
63+
4664
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
4765
function AbstractVectorOfArray_getindex_adjoint(Δ)
4866
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
49-
(Δ′,nothing)
67+
(VectorOfArray(Δ′),nothing)
5068
end
5169
VA[i],AbstractVectorOfArray_getindex_adjoint
5270
end
@@ -55,11 +73,13 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A
5573
function AbstractVectorOfArray_getindex_adjoint(Δ)
5674
Δ′ = zero(VA)
5775
Δ′[i,j...] = Δ
58-
(Δ′, i,map(_ -> nothing, j)...)
76+
@show Δ′
77+
# (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug?
78+
(Δ′, nothing, map(_ -> nothing, j)...)
79+
# (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
5980
end
6081
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
6182
end
62-
6383
ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
6484
function ArrayPartition_adjoint(_y)
6585
y = Array(_y)

0 commit comments

Comments
 (0)