diff --git a/src/lib/array.jl b/src/lib/array.jl index 06ce860d2..6d914d272 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -137,8 +137,9 @@ end struct StaticGetter{i} end (::StaticGetter{i})(v) where {i} = v[i] (::StaticGetter{i})(::Nothing) where {i} = nothing -@generated function _unzip(tuples, ::Val{N}) where {N} - Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...) +function _unzip(tuples, ::Val{N}) where {N} + getters = ntuple(n -> StaticGetter{n}(), N) + map(g -> map(g, tuples), getters) end function unzip(tuples) N = length(first(tuples)) @@ -169,8 +170,11 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U) # So we keep axes(x) to restore gradient dx to its full length & correct shape. _tryaxes(x) = axes(x) _tryaxes(x::Tuple) = Val(length(x)) -_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(length, ax) - length(dx))), ax) +_tryaxes(x::Number) = x +_restore(dx::AbstractArray{Nothing}, ax::Tuple) = similar(dx, ax) +_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(map(length, ax)) - length(dx))), ax) _restore(dx, ::Val{N}) where {N} = ntuple(i -> get(dx,i,nothing), N) +_restore(dx, ::Number) = only(dx) # Sometimes a pullback doesn't return a Tuple, but rather returns only a # single nothing to say "all arguments have zero cotangent". This function is needed to @@ -268,32 +272,51 @@ end _ndims(::Base.HasShape{d}) where {d} = d _ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)) : 1 +function productfunc(xs, dy) + @assert length(first(dy)) == length(xs) + ndim = map(Zygote._ndims, xs) + cdim = cumsum((1, ndim[begin:end-1]...)) + getters = ntuple(n -> StaticGetter{n}(), length(xs)) + map(first(dy), xs, cdim, getters) do dyn, x, cd, getter + dyn === nothing && return nothing + nd = _ndims(x) + dims = nd == 0 ? (:) : ntuple(i -> i i StaticGetter{n}(), length(xs)) + map(xs, getters) do x, getter + dx = map(getter, dy) + _project(x, _restore(dx, _tryaxes(x))) end - Iterators.product(xs...), back end -@adjoint function Iterators.Zip(xs) - axs = map(_tryaxes, xs) # same function used for map - back(dy::NamedTuple{(:is,)}) = tuple(dy.is) - back(dy::AbstractArray) = ntuple(length(xs)) do d - dx = map(StaticGetter{d}(), dy) - _project(xs[d], _restore(dx, axs[d])) - end |> tuple - Iterators.Zip(xs), back +@adjoint function Iterators.zip(xs...) + zip_pullback(::AbstractArray{Nothing}) = nothing + zip_pullback(dy::NamedTuple{(:is,)}) = dy.is + zip_pullback(dy::AbstractArray) = zipfunc(xs, dy) + Iterators.zip(xs...), zip_pullback +end + +@adjoint function Base.collect(z::Base.Iterators.Zip) + collect_zip_pullback(dy::AbstractArray) = ((is=zipfunc(z.is, dy),),) + collect(z), collect_zip_pullback end # Reductions diff --git a/test/lib/array.jl b/test/lib/array.jl index a3b73aff9..8016c9541 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -18,6 +18,36 @@ test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_ # This was wrong before https://github.com/FluxML/Zygote.jl/pull/1170 @test gradient(x -> sum([y[2] * y[3] for y in Iterators.product(x, x, x, x)]), [1,2,3,4])[1] ≈ [320, 320, 320, 320] @test gradient(x -> sum(y[2] * y[3] for y in Iterators.product(x, x, x, x)), [1,2,3,4])[1] ≈ [320, 320, 320, 320] + + # Numbers failed before https://github.com/FluxML/Zygote.jl/pull/1489 + for p in (1.0, fill(1.0), [1.0]) + @test gradient(p -> sum([x*q for q in p, x in 1:3]), p) == (6p,) + @test gradient(p -> sum(x*q for (q, x) in Iterators.product(p, 1:3)), p) == (6p,) + end + + # inference would also fail before #1489 + y, back = _pullback(Iterators.product, 1:5, fill(1)) + @test @inferred back(collect(y)) == (nothing, [1.0, 2.0, 3.0, 4.0, 5.0], fill(5.0)) +end + +@testset "adjoints of Iterators.zip" begin + y, back = _pullback(Iterators.zip, 1:5, 1:3, 1:2) + @test back(collect(y)) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [1.0, 2.0]) + @test back([(nothing, j, k) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, nothing, [1.0, 2.0, 0.0], [1.0, 2.0]) + @test back([(i, nothing, k) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], nothing, [1.0, 2.0]) + @test back([(i, j, nothing) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], [1.0, 2.0, 0.0], nothing) + + + @test gradient(x -> sum([y[2] * y[3] for y in Iterators.zip(x, x, x, x)]), [1,2,3,4])[1] ≈ [2, 4, 6, 8] + @test gradient(x -> sum(y[2] * y[3] for y in Iterators.zip(x, x, x, x)), [1,2,3,4])[1] ≈ [2, 4, 6, 8] + + for p in (1.0, fill(1.0), [1.0]) + @test gradient(p_ -> sum(map(prod, Iterators.zip(p_, p))), p) == (p,) + @test gradient(p_ -> sum(x*q for (q, x) in Iterators.zip(p_, p)), p) == (p,) + end + + y, back = _pullback(Iterators.zip, 1:5, fill(1)) + @test @inferred back(collect(y)) == (nothing, [1.0, 0.0, 0.0, 0.0, 0.0], fill(1.0)) end @testset "collect" begin @@ -45,6 +75,28 @@ end g = gradient(d -> sum(x^2 for x in collect(d)), t)[1] @test g === (2.0, 4.0) end + + @testset "Iterators.ProductIterator" begin + p = Iterators.product(1:3, 1:2) + g = gradient(p -> sum(prod, collect(p)), p)[1] + @test g == (iterators=(3ones(3), 6ones(2)),) + + @test gradient(x -> sum(broadcast(prod, Iterators.product(x,x))), ones(4)) == (2*4ones(4),) + @test gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2, x))), ones(4)) == (3*4ones(4),) + @test gradient(x -> sum(broadcast(prod, Iterators.product(x, x .^ 2))), ones(4)) == (3*4ones(4),) + @test gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2, x .^ 2))), ones(4)) == (4*4ones(4),) + end + + @testset "Iterators.Zip" begin + z = Iterators.zip(1:3, 1:2) + g = gradient(z -> sum(prod, collect(z)), z)[1] + @test g == (is=([1.0, 2.0, 0.0], [1.0, 2.0]),) + + @test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x))), ones(4)) == (2ones(4),) + @test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x))), ones(4)) == (3ones(4),) + @test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x.^2))), ones(4)) == (3ones(4),) + @test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x.^2))), ones(4)) == (4ones(4),) + end end @testset "dictionary comprehension" begin