From 8fb20af6cb81a9493b4832466a4de5214d23fd16 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 25 Aug 2021 21:34:10 +0100 Subject: [PATCH 1/2] Use zygote2differential to wrap chainrules inputs --- src/compiler/chainrules.jl | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index aaec3951f..3191e6b1f 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -117,27 +117,16 @@ for T_outer in (:Tuple, :NamedTuple) end """ - wrap_chainrules_input(x) - -Convert `x` from the format Zygote uses internally to differentials types ChainRules uses. -""" -@inline wrap_chainrules_input(x) = x -@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent() -@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) - xp = map(wrap_chainrules_input, xs) - ChainRules.Tangent{Any, typeof(xp)}(xp) -end - -""" - ZBack{F}(back) <: Function + ZBack{F,P}(back::F, primals::P) <: Function Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conventions. (A functor here is used rather than a closure to avoid boxing issues); """ -struct ZBack{F} <: Function +struct ZBack{F,P} <: Function + primals::P back::F end -@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy))) +@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(zygote2differential(dy, s.primals))) # `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 # though it might be worth keeping as a performance optimization (benchmarking pending) @inline (s::ZBack)(::Nothing) = nothing @@ -150,7 +139,7 @@ The pullback is appropriately wrapped up to follow Zygote conventions. """ @inline function chain_rrule(config, f, args...) y, back = rrule(config, f, args...) - return y, ZBack(back) + return y, ZBack(back, args...) 
end From 9ca44028b6503d03e5c34c04cfa0c8b7c68a4199 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 25 Aug 2021 21:42:30 +0100 Subject: [PATCH 2/2] fix typos and missed bits --- src/compiler/chainrules.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 3191e6b1f..ab9acea44 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -123,8 +123,8 @@ Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conven (A functor here is used rather than a closure to avoid boxing issues); """ struct ZBack{F,P} <: Function - primals::P back::F + primals::P end @inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(zygote2differential(dy, s.primals))) # `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 @@ -139,7 +139,7 @@ The pullback is appropriately wrapped up to follow Zygote conventions. """ @inline function chain_rrule(config, f, args...) y, back = rrule(config, f, args...) - return y, ZBack(back, args...) + return y, ZBack(back, args) end @@ -152,7 +152,7 @@ As per [`chain_rrule`](@ref) but with support for kwargs. @inline function chain_rrule_kw(config, kwf, kwargs, f, args...) y, back = rrule(config, f, args...; kwargs...) function kw_zpullback(dy) - dxs = ZBack(back)(dy) + dxs = ZBack(back, args)(dy) if dxs === nothing # if dxs is nothing, then all partiaols are nothing # Zygote convention is a single nothing no mather how partials, if all are nothing return nothing