diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 168b7e71..ead05de1 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -11,7 +11,7 @@ else end # Define a new species of projection operator for this type: -ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() +# ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray}, xs::AbstractVectorOfArray) @@ -117,4 +117,53 @@ end A.x, literal_ArrayPartition_x_adjoint end +@adjoint function Array(VA::AbstractVectorOfArray) + Array(VA), + y -> (Array(y),) end + + +ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) + +function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x) + arr = reshape(x, p.sz) + return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) +end + +function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄) + N = ndims(x̄) + if length(x) == length(x̄) + Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors + else + dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄)) + Zygote._project(x, Zygote.accum_sum(x̄; dims = dims)) + end +end + +@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b) where {F} = _broadcast_generic(__context__, f, a, b) +@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b) +@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b) + +@inline function _broadcast_generic(__context__, f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + # Avoid generic broadcasting in two easy cases: + if T == Bool + return (f.(args...), _ -> nothing) + elseif T <: Union{Real, Complex} && isconcretetype(T) && Zygote._dual_purefun(F) && all(Zygote._dual_safearg, args) && !Zygote.isderiving() + return Zygote.broadcast_forward(f, args...) + end + len = Zygote.inclen(args) + y∂b = Zygote._broadcast((x...) -> Zygote._pullback(__context__, f, x...), args...) + y = broadcast(first, y∂b) + function ∇broadcasted(ȳ) + y∂b = y∂b isa AbstractVectorOfArray ? Iterators.flatten(y∂b.u) : y∂b + ȳ = ȳ isa AbstractVectorOfArray ? Iterators.flatten(ȳ.u) : ȳ + dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ) + getters = ntuple(i -> Zygote.StaticGetter{i}(), len) + dxs = map(g -> Zygote.collapse_nothings(map(g, dxs_zip)), getters) + (nothing, Zygote.accum_sum(dxs[1]), map(Zygote.unbroadcast, args, Base.tail(dxs))...) + end + return y, ∇broadcasted +end + +end # module diff --git a/src/utils.jl b/src/utils.jl index c59a9710..d9c1585a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -177,8 +177,8 @@ Calls `f` on each element of a vecvec `v`. function vecvecapply(f, v) sol = Vector{eltype(eltype(v))}() for i in eachindex(v) - for j in eachindex(v[i]) - push!(sol, v[i][j]) + for j in eachindex(v[:, i]) + push!(sol, v[:, i][j]) end end f(sol) diff --git a/test/adjoints.jl b/test/adjoints.jl index a6abb445..e06035af 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -37,6 +37,11 @@ function loss6(x) sum(abs2, Array(_prob.u0)) end +function loss7(x) + _x = VectorOfArray([x .* i for i in 1:5]) + return sum(abs2, x .- 1) +end + x = float.(6:10) loss(x) @test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x) @@ -45,3 +50,4 @@ loss(x) @test Zygote.gradient(loss4, x)[1] == ForwardDiff.gradient(loss4, x) @test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x) @test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x) +@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 1669032b..a437fa46 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -22,11 +22,16 @@ sol_new = DiffEqArray(sol.u[1:10], @test sol_new[t] ≈ sol_new.t @test sol_new[t, 1:5] ≈ sol_new.t[1:5] @test getp(sol, τ)(sol) == getp(sol_new, τ)(sol_new) == 3.0 -@test variable_symbols(sol) == variable_symbols(sol_new) == [x] -@test all_variable_symbols(sol) == all_variable_symbols(sol_new) == [x, RHS] -@test all_symbols(sol) == all_symbols(sol_new) == [x, RHS, τ, t] -@test sol[solvedvariables, 1:10] == sol_new[solvedvariables] == sol_new[[x]] -@test sol[allvariables, 1:10] == sol_new[allvariables] == sol_new[[x, RHS]] +@test all(isequal.(variable_symbols(sol), variable_symbols(sol_new))) +@test all(isequal.(variable_symbols(sol), [x])) +@test all(isequal.(all_variable_symbols(sol), all_variable_symbols(sol_new))) +@test all(isequal.(all_variable_symbols(sol), [x, RHS])) +@test all(isequal.(all_symbols(sol), all_symbols(sol_new))) +@test all(isequal.(all_symbols(sol), [x, RHS, τ, t])) +@test sol[solvedvariables] == sol[[x]] +@test sol_new[solvedvariables] == sol_new[[x]] +@test sol[allvariables] == sol[[x, RHS]] +@test sol_new[allvariables] == sol_new[[x, RHS]] @test_throws Exception sol[τ] @test_throws Exception sol_new[τ]