Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve adjoint for product and zip #1489

Merged
merged 16 commits into from
Jan 19, 2024
73 changes: 48 additions & 25 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ end
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing
@generated function _unzip(tuples, ::Val{N}) where {N}
Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...)
function _unzip(tuples, ::Val{N}) where {N}
getters = ntuple(n -> StaticGetter{n}(), N)
map(g -> map(g, tuples), getters)
end
function unzip(tuples)
N = length(first(tuples))
Expand Down Expand Up @@ -169,8 +170,11 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)
# So we keep axes(x) to restore gradient dx to its full length & correct shape.
_tryaxes(x) = axes(x)
_tryaxes(x::Tuple) = Val(length(x))
_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(length, ax) - length(dx))), ax)
_tryaxes(x::Number) = x
_restore(dx::AbstractArray{Nothing}, ax::Tuple) = similar(dx, ax)
_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(map(length, ax)) - length(dx))), ax)
_restore(dx, ::Val{N}) where {N} = ntuple(i -> get(dx,i,nothing), N)
_restore(dx, ::Number) = only(dx)

# Sometimes a pullback doesn't return a Tuple, but rather returns only a
# single nothing to say "all arguments have zero cotangent". This function is needed to
Expand Down Expand Up @@ -268,32 +272,51 @@ end
_ndims(::Base.HasShape{d}) where {d} = d
_ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)) : 1

function productfunc(xs, dy)
@assert length(first(dy)) == length(xs)
ndim = map(Zygote._ndims, xs)
cdim = cumsum((1, ndim[begin:end-1]...))
getters = ntuple(n -> StaticGetter{n}(), length(xs))
map(first(dy), xs, cdim, getters) do dyn, x, cd, getter
dyn === nothing && return nothing
nd = _ndims(x)
dims = nd == 0 ? (:) : ntuple(i -> i<cd ? i : i+nd, Val(ndims(dy)-nd))
init = map(zero, dyn) # allows for tuples, which accum can add:
red = mapreduce(getter, accum, dy; dims, init)
return _project(x, nd == 0 ? red : reshape(red, axes(x)))
end
end

@adjoint function Iterators.product(xs...)
back(::AbstractArray{Nothing}) = nothing
back(dy::NamedTuple{(:iterators,)}) = dy.iterators
function back(dy::AbstractArray)
d = 1
ntuple(length(xs)) do n
nd = _ndims(xs[n])
dims = ntuple(i -> i<d ? i : i+nd, ndims(dy)-nd)
d += nd
first(dy)[n] === nothing && return nothing
init = zero.(first(dy)[n]) # allows for tuples, which accum can add:
red = mapreduce(StaticGetter{n}(), accum, dy; dims=dims, init=init)
return _project(xs[n], reshape(red, axes(xs[n])))
end
product_pullback(::AbstractArray{Nothing}) = nothing
product_pullback(dy::NamedTuple{(:iterators,)}) = dy.iterators
product_pullback(dy::AbstractArray) = productfunc(xs, dy)
Iterators.product(xs...), product_pullback
end

@adjoint function Base.collect(p::Base.Iterators.ProductIterator)
collect_product_pullback(dy) = ((iterators=productfunc(p.iterators, dy),),)
return collect(p), collect_product_pullback
end

function zipfunc(xs, dy)
getters = ntuple(n -> StaticGetter{n}(), length(xs))
map(xs, getters) do x, getter
dx = map(getter, dy)
_project(x, _restore(dx, _tryaxes(x)))
end
Iterators.product(xs...), back
end

@adjoint function Iterators.Zip(xs)
axs = map(_tryaxes, xs) # same function used for map
back(dy::NamedTuple{(:is,)}) = tuple(dy.is)
back(dy::AbstractArray) = ntuple(length(xs)) do d
dx = map(StaticGetter{d}(), dy)
_project(xs[d], _restore(dx, axs[d]))
end |> tuple
Iterators.Zip(xs), back
@adjoint function Iterators.zip(xs...)
zip_pullback(::AbstractArray{Nothing}) = nothing
zip_pullback(dy::NamedTuple{(:is,)}) = dy.is
zip_pullback(dy::AbstractArray) = zipfunc(xs, dy)
Iterators.zip(xs...), zip_pullback
end

@adjoint function Base.collect(z::Base.Iterators.Zip)
collect_zip_pullback(dy::AbstractArray) = ((is=zipfunc(z.is, dy),),)
collect(z), collect_zip_pullback
end

# Reductions
Expand Down
52 changes: 52 additions & 0 deletions test/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,36 @@ test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_
# This was wrong before https://github.com/FluxML/Zygote.jl/pull/1170
@test gradient(x -> sum([y[2] * y[3] for y in Iterators.product(x, x, x, x)]), [1,2,3,4])[1] ≈ [320, 320, 320, 320]
@test gradient(x -> sum(y[2] * y[3] for y in Iterators.product(x, x, x, x)), [1,2,3,4])[1] ≈ [320, 320, 320, 320]

# Numbers failed before https://github.com/FluxML/Zygote.jl/pull/1489
for p in (1.0, fill(1.0), [1.0])
@test gradient(p -> sum([x*q for q in p, x in 1:3]), p) == (6p,)
@test gradient(p -> sum(x*q for (q, x) in Iterators.product(p, 1:3)), p) == (6p,)
end

# inference would also fail before #1489
y, back = _pullback(Iterators.product, 1:5, fill(1))
@test @inferred back(collect(y)) == (nothing, [1.0, 2.0, 3.0, 4.0, 5.0], fill(5.0))
end

@testset "adjoints of Iterators.zip" begin
y, back = _pullback(Iterators.zip, 1:5, 1:3, 1:2)
@test back(collect(y)) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [1.0, 2.0])
@test back([(nothing, j, k) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, nothing, [1.0, 2.0, 0.0], [1.0, 2.0])
@test back([(i, nothing, k) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], nothing, [1.0, 2.0])
@test back([(i, j, nothing) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], [1.0, 2.0, 0.0], nothing)


@test gradient(x -> sum([y[2] * y[3] for y in Iterators.zip(x, x, x, x)]), [1,2,3,4])[1] ≈ [2, 4, 6, 8]
@test gradient(x -> sum(y[2] * y[3] for y in Iterators.zip(x, x, x, x)), [1,2,3,4])[1] ≈ [2, 4, 6, 8]

for p in (1.0, fill(1.0), [1.0])
@test gradient(p_ -> sum(map(prod, Iterators.zip(p_, p))), p) == (p,)
@test gradient(p_ -> sum(x*q for (q, x) in Iterators.zip(p_, p)), p) == (p,)
end

y, back = _pullback(Iterators.zip, 1:5, fill(1))
@test @inferred back(collect(y)) == (nothing, [1.0, 0.0, 0.0, 0.0, 0.0], fill(1.0))
end

@testset "collect" begin
Expand Down Expand Up @@ -45,6 +75,28 @@ end
g = gradient(d -> sum(x^2 for x in collect(d)), t)[1]
@test g === (2.0, 4.0)
end

@testset "Iterators.ProductIterator" begin
p = Iterators.product(1:3, 1:2)
g = gradient(p -> sum(prod, collect(p)), p)[1]
@test g == (iterators=(3ones(3), 6ones(2)),)

@test gradient(x -> sum(broadcast(prod, Iterators.product(x,x))), ones(4)) == (2*4ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2, x))), ones(4)) == (3*4ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.product(x, x .^ 2))), ones(4)) == (3*4ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2, x .^ 2))), ones(4)) == (4*4ones(4),)
end

@testset "Iterators.Zip" begin
z = Iterators.zip(1:3, 1:2)
g = gradient(z -> sum(prod, collect(z)), z)[1]
@test g == (is=([1.0, 2.0, 0.0], [1.0, 2.0]),)

@test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x))), ones(4)) == (2ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x))), ones(4)) == (3ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x.^2))), ones(4)) == (3ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x.^2))), ones(4)) == (4ones(4),)
end
end

@testset "dictionary comprehension" begin
Expand Down