diff --git a/src/zygote.jl b/src/zygote.jl
index d1b5a09e..c30ec23e 100644
--- a/src/zygote.jl
+++ b/src/zygote.jl
@@ -43,10 +43,28 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
   A.x,literal_ArrayPartition_x_adjoint
 end
 
+#=
+
+# Define a new species of projection operator for this type:
+ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}()
+
+# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
+(::ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
+# Gradient from broadcasting will be another AbstractArray
+(::ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
+
+But this may not be necessary?
+
+=#
+
+
+# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
+# definition first, and finds its own before finding those.
+
 ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
   function AbstractVectorOfArray_getindex_adjoint(Δ)
     Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
-    (Δ′,nothing)
+    (VectorOfArray(Δ′),nothing)
   end
   VA[i],AbstractVectorOfArray_getindex_adjoint
 end
@@ -55,11 +73,11 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A
   function AbstractVectorOfArray_getindex_adjoint(Δ)
     Δ′ = zero(VA)
     Δ′[i,j...] = Δ
-    (Δ′, i,map(_ -> nothing, j)...)
+    # Indices carry no gradient: return `nothing` for `i` and for each of `j`.
+    (Δ′, nothing, map(_ -> nothing, j)...)
   end
   VA[i,j...],AbstractVectorOfArray_getindex_adjoint
 end
-
 ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
   function ArrayPartition_adjoint(_y)
     y = Array(_y)