diff --git a/src/destructure.jl b/src/destructure.jl
index 2b91983d..d000ff75 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -127,15 +127,22 @@ function _grad!(x, dx, off, flat::AbstractVector)
   x′, _ = functor(typeof(x), x)
   dx′, _ = functor(typeof(x), base(dx))
   off′, _ = functor(typeof(x), off)
-  foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
+  for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′)
+    flat = _grad!(xᵢ, dxᵢ, oᵢ, flat)
+  end
   flat
 end
-function _grad!(x, dx, off::Integer, flat::AbstractVector)
-  @views flat[off .+ (1:length(x))] .+= vec(dx)  # must visit all tied nodes
+function _grad!(x, dx, off::Integer, flat::AbstractVector{T}) where T
+  dx_un = unthunk(dx)
+  T2 = promote_type(T, eltype(dx_un))
+  if T != T2  # then we must widen the type
+    flat = copyto!(similar(flat, T2), flat)
+  end
+  @views flat[off .+ (1:length(x))] .+= vec(dx_un)  # must visit all tied nodes
   flat
 end
-_grad!(x, dx::Zero, off, flat::AbstractVector) = dx
-_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx  # ambiguity
+_grad!(x, dx::Zero, off, flat::AbstractVector) = flat
+_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat  # ambiguity
 
 # These are only needed for 2nd derivatives:
 function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
diff --git a/test/destructure.jl b/test/destructure.jl
index 043315b3..d20f4f30 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -180,3 +180,43 @@ end
     4(sum(m.x) + sum(m.y)) + 13*sum(m.z)  # again two gradients are ===, so it eliminates one
   end == ([17,17,4,4],)  # Flux gave ([4.0, 4.0, 13.0, 13.0],)
 end
+
+@testset "DiffEqFlux issue 699" begin
+  # The gradient of `re` is a vector into which we accumulate contributions, and the issue
+  # is that one contribution may have a wider type than `v`, especially for `Dual` numbers.
+  v, re = destructure((x=Float32[1,2], y=Float32[3,4,5]))
+  _, bk = Zygote.pullback(re, ones(Float32, 5))
+  # Testing with `Complex` isn't ideal, but this was an error on 0.2.1.
+  # If some upgrade inserts ProjectTo, this will fail, and can be changed:
+  @test bk((x=[1.0,im], y=nothing)) == ([1,im,0,0,0],)
+
+  @test bk((x=nothing, y=[10,20,30]))[1] isa Vector{Float32}  # despite some ZeroTangent
+  @test bk((x=nothing, y=nothing)) == ([0,0,0,0,0],)
+  @test bk((x=nothing, y=@thunk [1,2,3] .* 10.0)) == ([0,0,10,20,30],)
+  @test bk((x=[1.2, 3.4], y=Float32[5,6,7])) == ([1.2, 3.4, 5, 6, 7],)
+end
+
+#=
+
+# Adapted from https://github.com/SciML/DiffEqFlux.jl/pull/699#issuecomment-1092846657
+using ForwardDiff, Zygote, Flux, Optimisers, Test
+
+y = Float32[0.8564646, 0.21083355]
+p = randn(Float32, 27);
+t = 1.5f0
+λ = [ForwardDiff.Dual(0.87135935, 1, 0, 0, 0, 0, 0), ForwardDiff.Dual(1.5225363, 0, 1, 0, 0, 0, 0)]
+
+model = Chain(x -> x .^ 3,
+              Dense(2 => 5, tanh),
+              Dense(5 => 2))
+
+p,re = Optimisers.destructure(model)
+f(u, p, t) = re(p)(u)
+_dy, back = Zygote.pullback(y, p) do u, p
+  vec(f(u, p, t))
+end
+tmp1, tmp2 = back(λ);
+tmp1
+@test tmp2 isa Vector{<:ForwardDiff.Dual}
+
+=#
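
To make the type-widening logic in the src/destructure.jl hunk easier to follow in isolation, here is a minimal, self-contained sketch of the same accumulate-and-widen pattern. It is illustration only, not code from this PR: the name `accumulate_slice` is invented, and it omits `unthunk` and the Functors recursion that `_grad!` performs.

# Accumulate a gradient slice into `flat`, re-allocating `flat` with a wider
# element type when the incoming contribution cannot be stored in the current one.
function accumulate_slice(flat::AbstractVector{T}, dx, off::Integer) where T
  T2 = promote_type(T, eltype(dx))
  if T2 != T  # e.g. a Float32 buffer receiving Dual or Complex contributions
    flat = copyto!(similar(flat, T2), flat)
  end
  @views flat[off .+ (1:length(dx))] .+= vec(dx)
  return flat  # the caller must keep using the (possibly re-allocated) array
end

flat = zeros(Float32, 5)
flat = accumulate_slice(flat, Float32[1, 2], 0)     # stays Vector{Float32}
flat = accumulate_slice(flat, [1.0 + 2im, 3.0], 2)  # widens to Vector{ComplexF64}

Returning `flat` rather than mutating it in place is what forces the `for` loop in `_grad!` to thread the possibly re-allocated vector through each recursive call; the old `foreach` version could not rebind `flat` once widening became necessary.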