diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index e879af3f8..74b0e802d 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -106,6 +106,7 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us # Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing. @inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing @inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing +@inline wrap_chainrules_output(x::ChainRules.NoTangent) = nothing for T_outer in (:Tuple, :NamedTuple) # we create separate methods rather than using a `Union` + an `if` so that we avoid a # branch that changes output type, because nested AD on that kinda thing makes Zygote less