diff --git a/Project.toml b/Project.toml index 91dda09..d325fcb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ZygoteRules" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.2" +version = "0.2.3" [deps] MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" diff --git a/src/adjoint.jl b/src/adjoint.jl index 47f628e..ed8755f 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -37,7 +37,11 @@ function unthunk_tangent end @inline unthunk_tangent(x::NamedTuple) = map(unthunk_tangent, x) -function gradm(ex, mut = false, keepthunks = false) +@inline maybe_final(cx::AContext, y) = nothing +@inline maybe_final(x) = nothing + + +function gradm(ex, mut = false, keepthunks = false, finalise = false) @capture(shortdef(ex), (name_(args__) = body_) | (name_(args__) where {Ts__} = body_)) || error("Need a function definition") kw = length(args) > 1 && isexpr(args[1], :parameters) ? esc(popfirst!(args)) : nothing @@ -59,18 +63,19 @@ function gradm(ex, mut = false, keepthunks = false) gradtuplekw = isclosure ? gradtuple2 : gradtuple3 adj = @q @inline ZygoteRules.adjoint($(fargs...)) where $(Ts...) = $(esc(body)) maybe_unthunked_Δ = keepthunks ? :Δ : :(unthunk_tangent(Δ)) + maybe_finalise_y = finalise ? :(maybe_final(__context__,y)) : nothing quote $adj @inline function ZygoteRules._pullback($cx, $f::$T, $(args...)) where $(Ts...) y, _back = adjoint(__context__, $f, $(argnames...)) $(mut ? nothing : :(back(::Nothing) = nothing)) - back(Δ) = $gradtuple(_back($maybe_unthunked_Δ)) + back(Δ) = begin ∇s = $gradtuple(_back($maybe_unthunked_Δ)); $maybe_finalise_y; ∇s end return y, back end @inline function ZygoteRules._pullback($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...) y, _back = adjoint(__context__, $f, $(argnames...); kw...) $(mut ? 
nothing : :(back(::Nothing) = nothing)) - back(Δ) = $gradtuplekw(_back($maybe_unthunked_Δ)) + back(Δ) = begin ∇s = $gradtuplekw(_back($maybe_unthunked_Δ)); $maybe_finalise_y; ∇s end return y, back end nothing @@ -78,9 +83,31 @@ function gradm(ex, mut = false, keepthunks = false) end macro adjoint(ex) - gradm(ex, false, false) + gradm(ex, false, false, false) end macro adjoint!(ex) - gradm(ex, true, false) + gradm(ex, true, false, false) +end + +""" + @adjoint_final function f(x) ... + +This differs from `@adjoint` in that it may call `finalize(y)` on +the forward result `y = f(x)`, after the backward pass. +(This will never happen within `jacobian`, where the pullback is +called several times, hence `y` must be kept.) + +For correctness, by using this macro you guarantee that `y = f(x)` +is a new array, never `===` nor aliasing `x`. +(Gradients `dx` also must not be thunks closing over `y`, +but like `@adjoint` this should automatically un-thunk.) + +Calling `finalize` is an optimisation to help free up GPU memory. +To free up intermediate steps by hand, use `maybe_final(__context__, z)` +inside the reverse pass (for `z` which must survive until then) +or `maybe_final(z)` to discard `z` immediately (even inside `jacobian`). +""" +macro adjoint_final(ex) + ZygoteRules.gradm(ex, false, false, true) end