diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 9b8b60552..fe6b9499d 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -152,6 +152,7 @@ Convert `dx` from the format Zygote uses internally to differentials types Chain @inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent() @inline wrap_chainrules_input(::Tuple{Vararg{Nothing}}) = ChainRules.ZeroTangent() @inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent() +@inline wrap_chainrules_input(dxs::AbstractArray{T}) where {T<:AbstractZero} = first(dxs) @inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple}) xp = map(wrap_chainrules_input, dxs) # This produces Tangent{Any} since it does not get to see the primal, `x`.