|
11 | 11 | end
|
12 | 12 |
|
13 | 13 | # Define a new species of projection operator for this type:
|
14 |
| -ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() |
| 14 | +# ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() |
15 | 15 |
|
16 | 16 | function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray},
|
17 | 17 | xs::AbstractVectorOfArray)
|
|
117 | 117 | A.x, literal_ArrayPartition_x_adjoint
|
118 | 118 | end
|
119 | 119 |
|
| 120 | +@adjoint function Array(VA::AbstractVectorOfArray) |
| 121 | + Array(VA), |
| 122 | + y -> (Array(y),) |
120 | 123 | end
|
| 124 | + |
| 125 | + |
| 126 | +ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) |
| 127 | + |
| 128 | +function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x) |
| 129 | + arr = reshape(x, p.sz) |
| 130 | + return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) |
| 131 | +end |
| 132 | + |
| 133 | +function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄) |
| 134 | + N = ndims(x̄) |
| 135 | + if length(x) == length(x̄) |
| 136 | + Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors |
| 137 | + else |
| 138 | + dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄)) |
| 139 | + Zygote._project(x, Zygote.accum_sum(x̄; dims = dims)) |
| 140 | + end |
| 141 | +end |
| 142 | + |
| 143 | +@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b) where {F} = _broadcast_generic(__context__, f, a, b) |
| 144 | +@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b) |
| 145 | +@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b) |
| 146 | + |
| 147 | +@inline function _broadcast_generic(__context__, f::F, args...) where {F} |
| 148 | + T = Broadcast.combine_eltypes(f, args) |
| 149 | + # Avoid generic broadcasting in two easy cases: |
| 150 | + if T == Bool |
| 151 | + return (f.(args...), _ -> nothing) |
| 152 | + elseif T <: Union{Real, Complex} && isconcretetype(T) && Zygote._dual_purefun(F) && all(Zygote._dual_safearg, args) && !Zygote.isderiving() |
| 153 | + return Zygote.broadcast_forward(f, args...) |
| 154 | + end |
| 155 | + len = Zygote.inclen(args) |
| 156 | + y∂b = Zygote._broadcast((x...) -> Zygote._pullback(__context__, f, x...), args...) |
| 157 | + y = broadcast(first, y∂b) |
| 158 | + function ∇broadcasted(ȳ) |
| 159 | + y∂b = y∂b isa AbstractVectorOfArray ? Iterators.flatten(y∂b.u) : y∂b |
| 160 | + ȳ = ȳ isa AbstractVectorOfArray ? Iterators.flatten(ȳ.u) : ȳ |
| 161 | + dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ) |
| 162 | + getters = ntuple(i -> Zygote.StaticGetter{i}(), len) |
| 163 | + dxs = map(g -> Zygote.collapse_nothings(map(g, dxs_zip)), getters) |
| 164 | + (nothing, Zygote.accum_sum(dxs[1]), map(Zygote.unbroadcast, args, Base.tail(dxs))...) |
| 165 | + end |
| 166 | + return y, ∇broadcasted |
| 167 | +end |
| 168 | + |
| 169 | +end # module |
0 commit comments