From c53df992fa37faee417db245ca1d7c569d59b201 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 23 Sep 2021 20:27:34 +0530 Subject: [PATCH 1/3] NoTangent -> nothing --- src/compiler/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index e879af3f8..79bf6b3dc 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -145,7 +145,7 @@ _project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x) # Piracy: # wrap_chainrules_input doesn't handle array of Union{Int,Nothing} -(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent() +(::ChainRulesCore.ProjectTo)(::Nothing) = nothing # ChainRulesCore.NoTangent() # CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any} (project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im)) From 4034bedaa412794066cc8c4bb787f9bc66cd5b2e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 25 Sep 2021 03:11:02 +0530 Subject: [PATCH 2/3] add wrap_chainrules_output rule --- src/compiler/chainrules.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 79bf6b3dc..4605e55cf 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -106,6 +106,7 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us # Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing. @inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing @inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing +@inlint wrap_chainrules_output(x::ChainRules.NoTangent) = nothing for T_outer in (:Tuple, :NamedTuple) # we create separate methods rather than using a `Union` + an `if` so that we avoid a # branch that changes output type, because nested AD on that kinda thing makes Zygote less @@ -145,7 +146,7 @@ _project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x) # Piracy: # wrap_chainrules_input doesn't handle array of Union{Int,Nothing} -(::ChainRulesCore.ProjectTo)(::Nothing) = nothing # ChainRulesCore.NoTangent() +(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent() # CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any} (project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im)) From 7ad6e8bd9886d967987b9cefe0f935c4bf35add3 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 25 Sep 2021 03:12:31 +0530 Subject: [PATCH 3/3] typo --- src/compiler/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 4605e55cf..74b0e802d 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -106,7 +106,7 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us # Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing. @inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing @inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing -@inlint wrap_chainrules_output(x::ChainRules.NoTangent) = nothing +@inline wrap_chainrules_output(x::ChainRules.NoTangent) = nothing for T_outer in (:Tuple, :NamedTuple) # we create separate methods rather than using a `Union` + an `if` so that we avoid a # branch that changes output type, because nested AD on that kinda thing makes Zygote less