diff --git a/Project.toml b/Project.toml
index ecf5418d..3a356377 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,6 +7,7 @@ version = "2.17.2"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl
index 59ad81e4..88fb4240 100644
--- a/src/RecursiveArrayTools.jl
+++ b/src/RecursiveArrayTools.jl
@@ -11,6 +11,9 @@ using Requires, RecipesBase, StaticArrays, Statistics,
 import ChainRulesCore
 import ChainRulesCore: NoTangent
 import ZygoteRules
+
+using FillArrays
+
 abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
 abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
 
diff --git a/src/zygote.jl b/src/zygote.jl
index d1b5a09e..fa7bd019 100644
--- a/src/zygote.jl
+++ b/src/zygote.jl
@@ -1,7 +1,7 @@
 function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
   function AbstractVectorOfArray_getindex_adjoint(Δ)
     Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
-    (NoTangent(),Δ′,NoTangent())
+    (NoTangent(),VectorOfArray(Δ′),NoTangent())
   end
   VA[i],AbstractVectorOfArray_getindex_adjoint
 end
@@ -10,7 +10,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indi
   function AbstractVectorOfArray_getindex_adjoint(Δ)
     Δ′ = zero(VA)
     Δ′[indices...] = Δ
-    (NoTangent(), Δ′, indices[1],map(_ -> NoTangent(), indices[2:end])...)
+    (NoTangent(), VectorOfArray(Δ′), map(_ -> NoTangent(), indices)...)
   end
   VA[indices...],AbstractVectorOfArray_getindex_adjoint
 end
@@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}
   function ArrayPartition_adjoint(_y)
     y = Array(_y)
     starts = vcat(0,cumsum(reduce(vcat,length.(x))))
-    NoTangent(), ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), NoTangent()
+    NoTangent(), ArrayPartition(ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x))), NoTangent()
   end
 
   ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
@@ -43,23 +43,33 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
   A.x,literal_ArrayPartition_x_adjoint
 end
 
+# Define a new species of projection operator for this type:
+ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
+
+# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
+#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
+# Gradient from broadcasting will be another AbstractArray
+#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
+
+# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
+# definition first, and finds its own before finding those.
+
 ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
   function AbstractVectorOfArray_getindex_adjoint(Δ)
-    Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
-    (Δ′,nothing)
+    Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))]
+    (VectorOfArray(Δ′),nothing)
   end
   VA[i],AbstractVectorOfArray_getindex_adjoint
 end
 
 ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...)
   function AbstractVectorOfArray_getindex_adjoint(Δ)
-    Δ′ = zero(VA)
-    Δ′[i,j...] = Δ
-    (Δ′, i,map(_ -> nothing, j)...)
+    Δ′ = [(i == j ? zero(x) : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))]
+    Δ′[i][j...] = Δ
+    (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
   end
   VA[i,j...],AbstractVectorOfArray_getindex_adjoint
 end
-
 ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
   function ArrayPartition_adjoint(_y)
     y = Array(_y)
@@ -71,11 +81,11 @@ ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{fal
 end
 
 ZygoteRules.@adjoint function VectorOfArray(u)
-  VectorOfArray(u),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],)
+  VectorOfArray(u),y -> (VectorOfArray([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]]),)
 end
 
 ZygoteRules.@adjoint function DiffEqArray(u,t)
-  DiffEqArray(u,t),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],nothing)
+  DiffEqArray(u,t),y -> (DiffEqArray([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],t),nothing)
 end
 
 ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})