Skip to content

Commit eae488e

Browse files
Merge pull request #319 from AayushSabharwal/as/broadcast-ad
feat: add support for autodiff through broadcasting of VoA
2 parents cb9d12d + 035e20a commit eae488e

File tree

4 files changed

+68
-8
lines changed

4 files changed

+68
-8
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

+50-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ else
1111
end
1212

1313
# Define a new species of projection operator for this type:
14-
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
14+
# ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
1515

1616
function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray},
1717
xs::AbstractVectorOfArray)
@@ -117,4 +117,53 @@ end
117117
A.x, literal_ArrayPartition_x_adjoint
118118
end
119119

120+
@adjoint function Array(VA::AbstractVectorOfArray)
121+
Array(VA),
122+
y -> (Array(y),)
120123
end
124+
125+
126+
ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
127+
128+
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x)
129+
arr = reshape(x, p.sz)
130+
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
131+
end
132+
133+
function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
134+
N = ndims(x̄)
135+
if length(x) == length(x̄)
136+
Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
137+
else
138+
dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄))
139+
Zygote._project(x, Zygote.accum_sum(x̄; dims = dims))
140+
end
141+
end
142+
143+
@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b) where {F} = _broadcast_generic(__context__, f, a, b)
144+
@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b)
145+
@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b)
146+
147+
@inline function _broadcast_generic(__context__, f::F, args...) where {F}
148+
T = Broadcast.combine_eltypes(f, args)
149+
# Avoid generic broadcasting in two easy cases:
150+
if T == Bool
151+
return (f.(args...), _ -> nothing)
152+
elseif T <: Union{Real, Complex} && isconcretetype(T) && Zygote._dual_purefun(F) && all(Zygote._dual_safearg, args) && !Zygote.isderiving()
153+
return Zygote.broadcast_forward(f, args...)
154+
end
155+
len = Zygote.inclen(args)
156+
y∂b = Zygote._broadcast((x...) -> Zygote._pullback(__context__, f, x...), args...)
157+
y = broadcast(first, y∂b)
158+
function ∇broadcasted(ȳ)
159+
y∂b = y∂b isa AbstractVectorOfArray ? Iterators.flatten(y∂b.u) : y∂b
160+
ȳ = ȳ isa AbstractVectorOfArray ? Iterators.flatten.u) : ȳ
161+
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
162+
getters = ntuple(i -> Zygote.StaticGetter{i}(), len)
163+
dxs = map(g -> Zygote.collapse_nothings(map(g, dxs_zip)), getters)
164+
(nothing, Zygote.accum_sum(dxs[1]), map(Zygote.unbroadcast, args, Base.tail(dxs))...)
165+
end
166+
return y, ∇broadcasted
167+
end
168+
169+
end # module

src/utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ Calls `f` on each element of a vecvec `v`.
177177
function vecvecapply(f, v)
178178
sol = Vector{eltype(eltype(v))}()
179179
for i in eachindex(v)
180-
for j in eachindex(v[i])
181-
push!(sol, v[i][j])
180+
for j in eachindex(v[:, i])
181+
push!(sol, v[:, i][j])
182182
end
183183
end
184184
f(sol)

test/adjoints.jl

+6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ function loss6(x)
3737
sum(abs2, Array(_prob.u0))
3838
end
3939

40+
function loss7(x)
41+
_x = VectorOfArray([x .* i for i in 1:5])
42+
return sum(abs2, x .- 1)
43+
end
44+
4045
x = float.(6:10)
4146
loss(x)
4247
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
@@ -45,3 +50,4 @@ loss(x)
4550
@test Zygote.gradient(loss4, x)[1] == ForwardDiff.gradient(loss4, x)
4651
@test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x)
4752
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
53+
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)

test/downstream/symbol_indexing.jl

+10-5
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,16 @@ sol_new = DiffEqArray(sol.u[1:10],
2222
@test sol_new[t] sol_new.t
2323
@test sol_new[t, 1:5] sol_new.t[1:5]
2424
@test getp(sol, τ)(sol) == getp(sol_new, τ)(sol_new) == 3.0
25-
@test variable_symbols(sol) == variable_symbols(sol_new) == [x]
26-
@test all_variable_symbols(sol) == all_variable_symbols(sol_new) == [x, RHS]
27-
@test all_symbols(sol) == all_symbols(sol_new) == [x, RHS, τ, t]
28-
@test sol[solvedvariables, 1:10] == sol_new[solvedvariables] == sol_new[[x]]
29-
@test sol[allvariables, 1:10] == sol_new[allvariables] == sol_new[[x, RHS]]
25+
@test all(isequal.(variable_symbols(sol), variable_symbols(sol_new)))
26+
@test all(isequal.(variable_symbols(sol), [x]))
27+
@test all(isequal.(all_variable_symbols(sol), all_variable_symbols(sol_new)))
28+
@test all(isequal.(all_variable_symbols(sol), [x, RHS]))
29+
@test all(isequal.(all_symbols(sol), all_symbols(sol_new)))
30+
@test all(isequal.(all_symbols(sol), [x, RHS, τ, t]))
31+
@test sol[solvedvariables] == sol[[x]]
32+
@test sol_new[solvedvariables] == sol_new[[x]]
33+
@test sol[allvariables] == sol[[x, RHS]]
34+
@test sol_new[allvariables] == sol_new[[x, RHS]]
3035
@test_throws Exception sol[τ]
3136
@test_throws Exception sol_new[τ]
3237

0 commit comments

Comments
 (0)