From 7f6af69c61e4f4ac5c5d2d247ffdf6318fb64ef9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 02:03:26 -0500 Subject: [PATCH 01/15] Add EnzymeRule for conv --- ext/NNlibEnzymeExt.jl | 88 +++++++++++++++++++++++++++++++++++++++++++ src/NNlib.jl | 8 ++++ 2 files changed, 96 insertions(+) create mode 100644 ext/NNlibEnzymeExt.jl diff --git a/ext/NNlibEnzymeExt.jl b/ext/NNlibEnzymeExt.jl new file mode 100644 index 000000000..16d3dd809 --- /dev/null +++ b/ext/NNlibEnzymeExt.jl @@ -0,0 +1,88 @@ +module NNlibEnzymeExt + +using NNlib +isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) + +using Enzyme + +using EnzymeCore + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims::Const; kwargs...) +) where {OutType, RT} + + @assert !(OutType <: Const) + if OutType <: Duplicated || OutType <: DuplicatedNoNeed + func.val(y.val, x.val, y.val, cdims.val; kwargs...) + end + + dres = if EnzymeRules.width(config) == 1 + func.val(prob.dval, alg.val; kwargs...) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + func.val(prob.dval[i], alg.val; kwargs...) + end + end + + primal = if EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + # Cache x if its overwritten and w is active (and thus required) + cache_x = ( EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing + + # Cache w if its overwritten and x is active (and thus required) + cache_w = ( EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing + + cache = (cache_x, cache_w) + + return EnzymeCore.EnzymeRules.AugmentedReturn(y.val, y.dval, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims::Const; kwargs...) where {RT} + cache_x, cache_w = cache + + # Don't cache x if not overwritten and w is active (and thus required) + if !(typeof(w) <: Const) + if !EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + # Don't cache w if not overwritten and x is active (and thus required) + if !(typeof(x) <: Const) + if !EnzymeRules.overwritten(config)[4] + cache_w = w.val + end + end + + dys = y.dval + dxs = (typeof(x) <: Const) ? nothing : x.dval + dws = (typeof(w) <: Const) ? nothing : w.dval + + if EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dws = (dws,) + end + + for (dy, dx, dw) in (dys, dxs, dws) + if !(typeof(x) <: Const) + # dx += grad wrt x + NNlib.∇conv_data!(dx, dy, cache_w, cdims; alpha=1, beta=1, kwargs...) + end + if !(typeof(y) <: Const) + # dw += grad wrt w + NNlib.∇conv_filter!(dw, cache_x, dy, cdims; alpha=1, beta=1, kwargs...) + end + end + + return (nothing, nothing, nothing, nothing) +end \ No newline at end of file diff --git a/src/NNlib.jl b/src/NNlib.jl index 8450a0261..2e39f9448 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -123,4 +123,12 @@ include("impl/depthwiseconv_im2col.jl") include("impl/pooling_direct.jl") include("deprecations.jl") +function __init__() + @static if !isdefined(Base, :get_extension) + @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/NNlibEnzymeExt.jl") + end + end +end + end # module NNlib From 113c47b0c2337e04e61abd774ab3611e78237316 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 02:52:38 -0500 Subject: [PATCH 02/15] Fix --- ext/NNlibEnzymeExt.jl | 88 ------------------------------------------- src/NNlib.jl | 8 ++-- 2 files changed, 3 insertions(+), 93 deletions(-) delete mode 100644 ext/NNlibEnzymeExt.jl diff --git a/ext/NNlibEnzymeExt.jl b/ext/NNlibEnzymeExt.jl deleted file mode 100644 index 16d3dd809..000000000 --- a/ext/NNlibEnzymeExt.jl +++ /dev/null @@ -1,88 +0,0 @@ -module NNlibEnzymeExt - -using NNlib -isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) - -using Enzyme - -using EnzymeCore - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims::Const; kwargs...) -) where {OutType, RT} - - @assert !(OutType <: Const) - if OutType <: Duplicated || OutType <: DuplicatedNoNeed - func.val(y.val, x.val, y.val, cdims.val; kwargs...) - end - - dres = if EnzymeRules.width(config) == 1 - func.val(prob.dval, alg.val; kwargs...) - else - ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - func.val(prob.dval[i], alg.val; kwargs...) - end - end - - primal = if EnzymeRules.needs_primal(config) - y.val - else - nothing - end - shadow = if EnzymeRules.needs_shadow(config) - y.dval - else - nothing - end - - # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing - - # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing - - cache = (cache_x, cache_w) - - return EnzymeCore.EnzymeRules.AugmentedReturn(y.val, y.dval, cache) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims::Const; kwargs...) where {RT} - cache_x, cache_w = cache - - # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: Const) - if !EnzymeRules.overwritten(config)[3] - cache_x = x.val - end - end - - # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: Const) - if !EnzymeRules.overwritten(config)[4] - cache_w = w.val - end - end - - dys = y.dval - dxs = (typeof(x) <: Const) ? nothing : x.dval - dws = (typeof(w) <: Const) ? nothing : w.dval - - if EnzymeRules.width(config) == 1 - dys = (dys,) - dxs = (dxs,) - dws = (dws,) - end - - for (dy, dx, dw) in (dys, dxs, dws) - if !(typeof(x) <: Const) - # dx += grad wrt x - NNlib.∇conv_data!(dx, dy, cache_w, cdims; alpha=1, beta=1, kwargs...) - end - if !(typeof(y) <: Const) - # dw += grad wrt w - NNlib.∇conv_filter!(dw, cache_x, dy, cdims; alpha=1, beta=1, kwargs...) - end - end - - return (nothing, nothing, nothing, nothing) -end \ No newline at end of file diff --git a/src/NNlib.jl b/src/NNlib.jl index 2e39f9448..fd9249176 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -123,11 +123,9 @@ include("impl/depthwiseconv_im2col.jl") include("impl/pooling_direct.jl") include("deprecations.jl") -function __init__() - @static if !isdefined(Base, :get_extension) - @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("../ext/NNlibEnzymeExt.jl") - end +@init @static if !isdefined(Base, :get_extension) + @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/NNlibEnzymeExt/NNlibEnzymeExt.jl") end end From 50552324a61e2b0a62c2f5293c8f31dca008ce0b Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 02:53:59 -0500 Subject: [PATCH 03/15] Add missing file --- ext/NNlibEnzymeExt/NNlibEnzymeExt.jl | 79 ++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 ext/NNlibEnzymeExt/NNlibEnzymeExt.jl diff --git a/ext/NNlibEnzymeExt/NNlibEnzymeExt.jl b/ext/NNlibEnzymeExt/NNlibEnzymeExt.jl new file mode 100644 index 000000000..874764374 --- /dev/null +++ b/ext/NNlibEnzymeExt/NNlibEnzymeExt.jl @@ -0,0 +1,79 @@ +module NNlibEnzymeExt + +using NNlib +isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) + +using EnzymeCore + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} + + @assert !(OutType <: Const) + if OutType <: Duplicated || OutType <: DuplicatedNoNeed + func.val(y.val, x.val, w.val, cdims.val; kwargs...) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + # Cache x if its overwritten and w is active (and thus required) + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing + + # Cache w if its overwritten and x is active (and thus required) + cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing + + cache = (cache_x, cache_w) + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} + cache_x, cache_w = cache + + # Don't cache x if not overwritten and w is active (and thus required) + if !(typeof(w) <: Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + # Don't cache w if not overwritten and x is active (and thus required) + if !(typeof(x) <: Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_w = w.val + end + end + + dys = y.dval + dxs = (typeof(x) <: Const) ? dys : x.dval + dws = (typeof(w) <: Const) ? dys : w.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dws = (dws,) + end + + for (dy, dx, dw) in zip(dys, dxs, dws) + if !(typeof(x) <: Const) && dx !== x + # dx += grad wrt x + NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + if !(typeof(w) <: Const) && dw !== w + # dw += grad wrt w + NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + dy .= 0 + end + + return (nothing, nothing, nothing, nothing) +end + +end \ No newline at end of file From 8bc816dbc9e33402e1d484ac8248fd79c5bb9178 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 11:00:56 -0500 Subject: [PATCH 04/15] Change to enzymecore ext --- ext/NNlibEnzymeExt/NNlibEnzymeExt.jl | 79 ---------------------------- src/NNlib.jl | 4 +- 2 files changed, 2 insertions(+), 81 deletions(-) delete mode 100644 ext/NNlibEnzymeExt/NNlibEnzymeExt.jl diff --git a/ext/NNlibEnzymeExt/NNlibEnzymeExt.jl b/ext/NNlibEnzymeExt/NNlibEnzymeExt.jl deleted file mode 100644 index 874764374..000000000 --- a/ext/NNlibEnzymeExt/NNlibEnzymeExt.jl +++ /dev/null @@ -1,79 +0,0 @@ -module NNlibEnzymeExt - -using NNlib -isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) - -using EnzymeCore - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} - - @assert !(OutType <: Const) - if OutType <: Duplicated || OutType <: DuplicatedNoNeed - func.val(y.val, x.val, w.val, cdims.val; kwargs...) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - y.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - y.dval - else - nothing - end - - # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing - - # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing - - cache = (cache_x, cache_w) - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} - cache_x, cache_w = cache - - # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[3] - cache_x = x.val - end - end - - # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] - cache_w = w.val - end - end - - dys = y.dval - dxs = (typeof(x) <: Const) ? dys : x.dval - dws = (typeof(w) <: Const) ? dys : w.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - dys = (dys,) - dxs = (dxs,) - dws = (dws,) - end - - for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(x) <: Const) && dx !== x - # dx += grad wrt x - NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) - end - if !(typeof(w) <: Const) && dw !== w - # dw += grad wrt w - NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) - end - dy .= 0 - end - - return (nothing, nothing, nothing, nothing) -end - -end \ No newline at end of file diff --git a/src/NNlib.jl b/src/NNlib.jl index fd9249176..c4ad18750 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -124,8 +124,8 @@ include("impl/pooling_direct.jl") include("deprecations.jl") @init @static if !isdefined(Base, :get_extension) - @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("../ext/NNlibEnzymeExt/NNlibEnzymeExt.jl") + @require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin + include("../ext/NNlibEnzymeCoreExt/NNlibEnzymeCoresExt.jl") end end From 330c2b12334b4d9849c956f612dc1e2d9ccd3ee2 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 12:15:49 -0500 Subject: [PATCH 05/15] attempt fix --- src/NNlib.jl | 6 +----- test/conv.jl | 2 +- test/test_utils.jl | 18 +++++++++++++++++- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/NNlib.jl b/src/NNlib.jl index c4ad18750..14ba2c70f 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -123,10 +123,6 @@ include("impl/depthwiseconv_im2col.jl") include("impl/pooling_direct.jl") include("deprecations.jl") -@init @static if !isdefined(Base, :get_extension) - @require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin - include("../ext/NNlibEnzymeCoreExt/NNlibEnzymeCoresExt.jl") - end -end +include("enzyme.jl") end # module NNlib diff --git a/test/conv.jl b/test/conv.jl index dc3fc57f5..404d1a63e 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -870,7 +870,7 @@ end w = rand(rng, repeat([3], spatial_rank)..., 3, 3) cdims = DenseConvDims(x, w) gradtest((x, w) -> conv(x, w, cdims), x, w) - gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055 + gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rule=true) # https://github.com/FluxML/Flux.jl/issues/1055 y = conv(x, w, cdims) gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w) diff --git a/test/test_utils.jl b/test/test_utils.jl index 16b3998dc..598f0f66b 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -12,7 +12,7 @@ Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly """ function gradtest( f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(), - check_rrule = false, fdm = :central, check_broadcast = false, + check_rrule = false, check_enzyme_rrule = false, fdm = :central, check_broadcast = false, skip = false, broken = false, ) # TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166 @@ -20,6 +20,22 @@ function gradtest( if check_rrule test_rrule(f, xs...; fkwargs = fkwargs) end + if check_enzyme_rrule + if len(xs) == 2 + for Tret in (Const, Active), + Tx in (Const, Duplicated, BatchDuplicated), + Ty in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx, Ty) || continue + + test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol) + end + else + throw(AssertionError("Unsupported arg count for testing")) + end + + EnzymeTestUtils.test_rrule(f, xs...; fkwargs = fkwargs) + end if check_broadcast length(fkwargs) > 0 && @warn("CHECK_BROADCAST: dropping keywords args") From 97991337c99388ec614fe931e6066d0396ce9c64 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 12:32:28 -0500 Subject: [PATCH 06/15] Add missing file --- src/enzyme.jl | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 src/enzyme.jl diff --git a/src/enzyme.jl b/src/enzyme.jl new file mode 100644 index 000000000..8aaf38704 --- /dev/null +++ b/src/enzyme.jl @@ -0,0 +1,72 @@ +import EnzymeCore + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} + + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed + func.val(y.val, x.val, w.val, cdims.val; kwargs...) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + # Cache x if its overwritten and w is active (and thus required) + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing + + # Cache w if its overwritten and x is active (and thus required) + cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing + + cache = (cache_x, cache_w) + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} + cache_x, cache_w = cache + + # Don't cache x if not overwritten and w is active (and thus required) + if !(typeof(w) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + # Don't cache w if not overwritten and x is active (and thus required) + if !(typeof(x) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_w = w.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dws = (dws,) + end + + for (dy, dx, dw) in zip(dys, dxs, dws) + if !(typeof(x) <: EnzymeCore.Const) && dx !== x + # dx += grad wrt x + NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + if !(typeof(w) <: EnzymeCore.Const) && dw !== w + # dw += grad wrt w + NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + dy .= 0 + end + + return (nothing, nothing, nothing, nothing) +end \ No newline at end of file From b8f44933bfec84c5cbf7bd119711092e7516979e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 12:55:23 -0500 Subject: [PATCH 07/15] Also add gather --- src/enzyme.jl | 63 +++++++++++++++++++++++++++++++++++++++++++--- test/conv.jl | 2 +- test/gather.jl | 11 ++++++++ test/test_utils.jl | 10 ++++---- 4 files changed, 76 insertions(+), 10 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 8aaf38704..7bf15f997 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -57,16 +57,71 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(x) <: EnzymeCore.Const) && dx !== x - # dx += grad wrt x + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + # dx += grad wrt x.val NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) end - if !(typeof(w) <: EnzymeCore.Const) && dw !== w - # dw += grad wrt w + if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val + # dw += grad wrt w.val NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) end dy .= 0 end + return (nothing, nothing, nothing, nothing) +end + + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed + func.val(dst.val, src.val, idx.val) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + # Don't cache idx if not overwritten + if !(typeof(src) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_idx = idx.val + end + end + + ddsts = dst.dval + dsrcs = src.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(src) <: EnzymeCore.Const) && ddst !== dst.val + src_size = size(src.val) + NNlib.∇gather_src(ddst, src_size, cache_idx) + end + if !(typeof(w) <: EnzymeCore.Const) && dw !== w + ddst .= 0 + end + end + return (nothing, nothing, nothing, nothing) end \ No newline at end of file diff --git a/test/conv.jl b/test/conv.jl index 404d1a63e..2400595b4 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -870,7 +870,7 @@ end w = rand(rng, repeat([3], spatial_rank)..., 3, 3) cdims = DenseConvDims(x, w) gradtest((x, w) -> conv(x, w, cdims), x, w) - gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rule=true) # https://github.com/FluxML/Flux.jl/issues/1055 + gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rrule=true) # https://github.com/FluxML/Flux.jl/issues/1055 y = conv(x, w, cdims) gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w) diff --git a/test/gather.jl b/test/gather.jl index 92e3bfb7d..359d772b7 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -152,6 +152,17 @@ function gather_testsuite(Backend) Backend == CPU ? gradtest_fn(xs -> gather(xs, idx), src) : gradtest_fn((s, i) -> gather(s, i), src, idx) + + if Backend == CPU + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(fun, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) + end + end end @static if Test_Enzyme diff --git a/test/test_utils.jl b/test/test_utils.jl index 598f0f66b..2e4f19d5f 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -22,13 +22,13 @@ function gradtest( end if check_enzyme_rrule if len(xs) == 2 - for Tret in (Const, Active), - Tx in (Const, Duplicated, BatchDuplicated), - Ty in (Const, Duplicated, BatchDuplicated) + for Tret in (EnzymeCore.Const, EnzymeCore.Active), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - are_activities_compatible(Tret, Tx, Ty) || continue + EnzymeTestUtils.are_activities_compatible(Tret, Tx, Ty) || continue - test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol) + EnzymeTestUtils.test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol) end else throw(AssertionError("Unsupported arg count for testing")) From e6e98b4a9dfc7b6141cd6414b5469e4864f96d1f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 16:48:54 -0500 Subject: [PATCH 08/15] Additional functions, tests, and fixes --- src/enzyme.jl | 85 +++++++++++++++++++++++++++++++++++++++++----- test/conv.jl | 2 +- test/gather.jl | 21 ++++++++---- test/test_utils.jl | 18 +--------- 4 files changed, 92 insertions(+), 34 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 7bf15f997..d466f05c0 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -1,9 +1,12 @@ import EnzymeCore -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} +for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!)) + @eval begin + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} @assert !(OutType <: EnzymeCore.Const) - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(y.val, x.val, w.val, cdims.val; kwargs...) end @@ -29,7 +32,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) end -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} cache_x, cache_w = cache # Don't cache x if not overwritten and w is active (and thus required) @@ -71,11 +74,13 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN return (nothing, nothing, nothing, nothing) end +end +end function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} @assert !(OutType <: EnzymeCore.Const) - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(dst.val, src.val, idx.val) end @@ -114,14 +119,76 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(src) <: EnzymeCore.Const) && ddst !== dst.val - src_size = size(src.val) - NNlib.∇gather_src(ddst, src_size, cache_idx) + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val && + !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + NNlib.scatter!(+, dsrc, ddst, cache_idx) end - if !(typeof(w) <: EnzymeCore.Const) && dw !== w + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val ddst .= 0 end end + return (nothing, nothing, nothing) +end + + + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(op.val, dst.val, src.val, idx.val) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + # Don't cache idx if not overwritten + if !(typeof(src) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_idx = idx.val + end + end + + ddsts = dst.dval + dsrcs = src.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val && + !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if eltype(typeof(op)) == typeof(+) + dsrc .+= NNlib.gather(ddst, cache_idx) + else + @assert eltype(typeof(op)) == typeof(-) + dsrc .-= NNlib.gather(ddst, cache_idx) + end + end + end + return (nothing, nothing, nothing, nothing) -end \ No newline at end of file +end + + + diff --git a/test/conv.jl b/test/conv.jl index 2400595b4..dc3fc57f5 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -870,7 +870,7 @@ end w = rand(rng, repeat([3], spatial_rank)..., 3, 3) cdims = DenseConvDims(x, w) gradtest((x, w) -> conv(x, w, cdims), x, w) - gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rrule=true) # https://github.com/FluxML/Flux.jl/issues/1055 + gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055 y = conv(x, w, cdims) gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w) diff --git a/test/gather.jl b/test/gather.jl index 359d772b7..d143220e9 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -152,16 +152,23 @@ function gather_testsuite(Backend) Backend == CPU ? gradtest_fn(xs -> gather(xs, idx), src) : gradtest_fn((s, i) -> gather(s, i), src, idx) + end - if Backend == CPU - for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated), - Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), - Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + @testset "EnzymeRules: gather! gradient for scalar index" begin + src = device(Float64[3, 4, 5, 6, 7]) + idx = device([ + 1 2 3 4; + 4 2 1 3; + 3 5 5 3]) + dst = gather(src, idx) + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - EnzymeTestUtils.test_reverse(fun, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) - end + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(gather!, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) end end diff --git a/test/test_utils.jl b/test/test_utils.jl index 2e4f19d5f..16b3998dc 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -12,7 +12,7 @@ Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly """ function gradtest( f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(), - check_rrule = false, check_enzyme_rrule = false, fdm = :central, check_broadcast = false, + check_rrule = false, fdm = :central, check_broadcast = false, skip = false, broken = false, ) # TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166 @@ -20,22 +20,6 @@ function gradtest( if check_rrule test_rrule(f, xs...; fkwargs = fkwargs) end - if check_enzyme_rrule - if len(xs) == 2 - for Tret in (EnzymeCore.Const, EnzymeCore.Active), - Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), - Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - - EnzymeTestUtils.are_activities_compatible(Tret, Tx, Ty) || continue - - EnzymeTestUtils.test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol) - end - else - throw(AssertionError("Unsupported arg count for testing")) - end - - EnzymeTestUtils.test_rrule(f, xs...; fkwargs = fkwargs) - end if check_broadcast length(fkwargs) > 0 && @warn("CHECK_BROADCAST: dropping keywords args") From c32de740c53b016eb2440152cf247c5eb1fe2389 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 19:51:24 -0500 Subject: [PATCH 09/15] Add pooling --- src/enzyme.jl | 161 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 131 insertions(+), 30 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index d466f05c0..a80f0f932 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -5,7 +5,6 @@ for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!)) function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} - @assert !(OutType <: EnzymeCore.Const) if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(y.val, x.val, w.val, cdims.val; kwargs...) end @@ -22,10 +21,16 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ end # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] + && !(typeof(w) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing + cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(w.val) : nothing cache = (cache_x, cache_w) @@ -36,14 +41,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, : cache_x, cache_w = cache # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: EnzymeCore.Const) + if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[3] cache_x = x.val end end # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: EnzymeCore.Const) + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[4] cache_w = w.val end @@ -60,15 +65,19 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, : end for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - # dx += grad wrt x.val - NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) - end - if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val - # dw += grad wrt w.val - NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + if !(typeof(y) <: EnzymeCore.Const) && dy !== w.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + # dx += grad wrt x.val + NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val + # dw += grad wrt w.val + NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + + dy .= 0 end - dy .= 0 end return (nothing, nothing, nothing, nothing) @@ -79,7 +88,6 @@ end function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} - @assert !(OutType <: EnzymeCore.Const) if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(dst.val, src.val, idx.val) end @@ -96,7 +104,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ end # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) end @@ -104,7 +115,7 @@ end function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} # Don't cache idx if not overwritten - if !(typeof(src) <: EnzymeCore.Const) + if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[4] cache_idx = idx.val end @@ -119,11 +130,12 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val && - !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val - NNlib.scatter!(+, dsrc, ddst, cache_idx) - end if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + NNlib.scatter!(+, dsrc, ddst, cache_idx) + end + ddst .= 0 end end @@ -152,7 +164,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ end # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) end @@ -160,7 +175,7 @@ end function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} # Don't cache idx if not overwritten - if !(typeof(src) <: EnzymeCore.Const) + if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[4] cache_idx = idx.val end @@ -175,15 +190,20 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val && - !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val - - if eltype(typeof(op)) == typeof(+) - dsrc .+= NNlib.gather(ddst, cache_idx) - else - @assert eltype(typeof(op)) == typeof(-) - dsrc .-= NNlib.gather(ddst, cache_idx) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + + if eltype(typeof(op)) == typeof(+) + dsrc .+= NNlib.gather(ddst, cache_idx) + else + @assert eltype(typeof(op)) == typeof(-) + dsrc .-= NNlib.gather(ddst, cache_idx) + end end + + ddst .= 0 + end end @@ -192,3 +212,84 @@ end +for pool in [:maxpool, :meanpool, :lpnormpool] + pool! = Symbol(pool, :!) + ∇pool = Symbol(:∇, pool) + + @eval begin + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT} + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(y.val, x.val, dims.val; kwargs...) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + cache_y = ( EnzymeCore.EnzymeRules.overwritten(config)[2] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(y.val) : nothing + + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing + + cache = (cache_y, cache_x) + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT} + cache_y, cache_x = cache + + # Don't cache y if not overwritten + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[2] + cache_y = y.val + end + end + + # Don't cache x if not overwritten + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + end + + for (dy, dx, dw) in zip(dys, dxs) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...) + end + + dy .= 0 + end + end + + return (nothing, nothing, nothing) +end + +end +end + + From 0ccb6c97c0dece05935cb1ab39e5396625937564 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 21:16:29 -0500 Subject: [PATCH 10/15] Add dropout --- src/enzyme.jl | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/src/enzyme.jl b/src/enzyme.jl index a80f0f932..9eb391616 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -292,4 +292,75 @@ end end end +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT} + T = float(real(eltype(dst.val))) + val = convert(T, 1/(1-p.val)) + keep = if dims.val isa Colon + similar(dst.val, T, size(dst.val)) + else + similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val))) + end + rand!(rng.val, keep) + + keep = keep .> p.val + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + dst.val .= (keep .* val) .* src.val + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const + keep = nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT} + T = float(real(eltype(dst.val))) + val = convert(T, 1/(1-p.val)) + + ddsts = dst.dval + dsrcs = src.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + dsrc .+= (keep .* val) .* ddst + end + + ddst .= 0 + end + end + + dp = if typeof(p) <: EnzymeCore.Active + typeof(p.val)(0) + else + nothing + end + + return (nothing, nothing, nothing, dp, nothing) +end From 72a996371a0ec3b1605ea74fa68ba8a86d9f61d2 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 21:45:03 -0500 Subject: [PATCH 11/15] Fix scatter bug --- src/enzyme.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 9eb391616..773066cff 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -202,8 +202,6 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end end - ddst .= 0 - end end From 6e64553ec993178d38f512ce291a8f83c9f3b77c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 22:00:00 -0500 Subject: [PATCH 12/15] fix pool --- src/enzyme.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 773066cff..0e1b75d8a 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -273,7 +273,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($p dxs = (dxs,) end - for (dy, dx, dw) in zip(dys, dxs) + for (dy, dx) in zip(dys, dxs) if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val From 256a4fbc75b9fd04f6aab307b7c039a6039ca0c5 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 22:44:15 -0500 Subject: [PATCH 13/15] More fixups --- src/enzyme.jl | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 0e1b75d8a..1d4d0c20c 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -122,7 +122,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) @@ -182,7 +182,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) @@ -322,12 +322,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ keep = nothing end - # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(src) <: EnzymeCore.Const) - && !(typeof(dst) <: EnzymeCore.Const) - ) ? copy(idx.val) : nothing - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) end @@ -336,7 +330,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN val = convert(T, 1/(1-p.val)) ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) From 597bcd730334bdceefb753979157e3f4579f966c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 22:45:06 -0500 Subject: [PATCH 14/15] Add batchnorm derivatives --- ext/NNlibCUDACUDNNExt/batchnorm.jl | 147 +++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) diff --git a/ext/NNlibCUDACUDNNExt/batchnorm.jl b/ext/NNlibCUDACUDNNExt/batchnorm.jl index 2c38f009e..2c83d92c2 100644 --- a/ext/NNlibCUDACUDNNExt/batchnorm.jl +++ b/ext/NNlibCUDACUDNNExt/batchnorm.jl @@ -3,6 +3,8 @@ using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, cudnnBatchNormalizationForwardTraining import NNlib: batchnorm, ∇batchnorm +using EnzymeCore + # TODO: replace with new cudnn normalization interface # https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl @@ -153,3 +155,148 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta), xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar) end + + + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT}, + y::OutType, + g, + b, + x, + running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT} + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(y.val, b.val, x.val, running_mean.val, running_var.val, momentum.val; kws...) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + cache_g = nothing + cache_x = nothing + cache_running_mean = nothing + cache_running_var = nothing + + if !(typeof(y) <: EnzymeCore.Const) + if !(typeof(x) <: EnzymeCore.Const) + || !(typeof(g) <: EnzymeCore.Const) + || !(typeof(b) <: EnzymeCore.Const) + + if EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_g = copy(g.val) + end + if EnzymeCore.EnzymeRules.overwritten(config)[5] + cache_x = copy(x.val) + end + if EnzymeCore.EnzymeRules.overwritten(config)[6] + cache_running_mean = copy(running_mean.val) + end + if EnzymeCore.EnzymeRules.overwritten(config)[7] + cache_running_var = copy(running_var.val) + end + + end + end + + cache = (cache_g, cache_x, cache_running_mean, cache_running_var) + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT}, + cache, + y::OutType, g, b, x, running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT} + + cache_g, cache_x, cache_running_mean, cache_running_var = cache + + if !(typeof(y) <: EnzymeCore.Const) + if !(typeof(x) <: EnzymeCore.Const) + || !(typeof(g) <: EnzymeCore.Const) + || !(typeof(b) <: EnzymeCore.Const) + + if EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_g = g.val + end + if EnzymeCore.EnzymeRules.overwritten(config)[5] + cache_x = x.val + end + if EnzymeCore.EnzymeRules.overwritten(config)[6] + cache_running_mean = running_mean.val + end + if EnzymeCore.EnzymeRules.overwritten(config)[7] + cache_running_var = running_var.val + end + + end + end + + dys = y.dval + dgs = (typeof(g) <: EnzymeCore.Const) ? dys : g.dval + dbs = (typeof(b) <: EnzymeCore.Const) ? dbs : b.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dxs : x.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dgs = (dgs,) + dbs = (dbs,) + end + + for (dy, dx, dg, db) in zip(dys, dxs, dgs, dbs) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + + if !((typeof(x) <: EnzymeCore.Const) || dx === x.val) + || !((typeof(g) <: EnzymeCore.Const) || dg === g.val) + || !((typeof(b) <: EnzymeCore.Const) || db === b.val) + + # dx values + alpha = T(1) + beta = T(1) + + # dx = alpha * newVal + beta old(dx) + # if x is constant, we can use zero for both + # otherwise we want to do dx += newVal, aka alpha=beta=1 + if x <: EnzymeCore.Const + alpha = T(0) + beta = T(0) + dx = similar(x.val) + end + + # dg / db values + alpha = T(1) + beta = T(1) + + if g <: EnzymeCore.Const && b <: EnzymeCore.Const + dalpha = T(0) + dbeta = T(0) + end + + if g <: EnzymeCore.Const + dg = similar(g.val) + end + + if b <: EnzymeCore.Const + db = similar(b.val) + end + + cudnnBNBackward!(dg, cache_g, db, dx, cache_x, dy, + cache_running_mean, cache_running_var, + momentum.val; alpha, beta, dalpha, dbeta; kw...) + + end + + dy .= 0 + + end + end + + return (nothing, nothing, nothing, nothing, nothing, nothing, nothing) +end From 1933a9afe32ff020ba550f34ee882c1a557f5f33 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 8 Oct 2023 13:38:00 -0700 Subject: [PATCH 15/15] rebase --- ext/NNlibCUDACUDNNExt/batchnorm.jl | 8 +- src/NNlib.jl | 2 - src/enzyme.jl | 358 ----------------------------- test/gather.jl | 18 -- 4 files changed, 2 insertions(+), 384 deletions(-) delete mode 100644 src/enzyme.jl diff --git a/ext/NNlibCUDACUDNNExt/batchnorm.jl b/ext/NNlibCUDACUDNNExt/batchnorm.jl index 2c83d92c2..4b7793b91 100644 --- a/ext/NNlibCUDACUDNNExt/batchnorm.jl +++ b/ext/NNlibCUDACUDNNExt/batchnorm.jl @@ -186,9 +186,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ cache_running_var = nothing if !(typeof(y) <: EnzymeCore.Const) - if !(typeof(x) <: EnzymeCore.Const) - || !(typeof(g) <: EnzymeCore.Const) - || !(typeof(b) <: EnzymeCore.Const) + if !(typeof(x) <: EnzymeCore.Const) || !(typeof(g) <: EnzymeCore.Const) || !(typeof(b) <: EnzymeCore.Const) if EnzymeCore.EnzymeRules.overwritten(config)[3] cache_g = copy(g.val) @@ -218,9 +216,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(cu cache_g, cache_x, cache_running_mean, cache_running_var = cache if !(typeof(y) <: EnzymeCore.Const) - if !(typeof(x) <: EnzymeCore.Const) - || !(typeof(g) <: EnzymeCore.Const) - || !(typeof(b) <: EnzymeCore.Const) + if !(typeof(x) <: EnzymeCore.Const) || !(typeof(g) <: EnzymeCore.Const) || !(typeof(b) <: EnzymeCore.Const) if EnzymeCore.EnzymeRules.overwritten(config)[3] cache_g = g.val diff --git a/src/NNlib.jl b/src/NNlib.jl index 14ba2c70f..8450a0261 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -123,6 +123,4 @@ include("impl/depthwiseconv_im2col.jl") include("impl/pooling_direct.jl") include("deprecations.jl") -include("enzyme.jl") - end # module NNlib diff --git a/src/enzyme.jl b/src/enzyme.jl deleted file mode 100644 index 1d4d0c20c..000000000 --- a/src/enzyme.jl +++ /dev/null @@ -1,358 +0,0 @@ -import EnzymeCore - -for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!)) - @eval begin - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} - - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated - func.val(y.val, x.val, w.val, cdims.val; kwargs...) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - y.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - y.dval - else - nothing - end - - # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] - && !(typeof(w) <: EnzymeCore.Const) - && !(typeof(y) <: EnzymeCore.Const) - ) ? copy(x.val) : nothing - - # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(x) <: EnzymeCore.Const) - && !(typeof(y) <: EnzymeCore.Const) - ) ? copy(w.val) : nothing - - cache = (cache_x, cache_w) - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} - cache_x, cache_w = cache - - # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[3] - cache_x = x.val - end - end - - # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] - cache_w = w.val - end - end - - dys = y.dval - dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval - dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - dys = (dys,) - dxs = (dxs,) - dws = (dws,) - end - - for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(y) <: EnzymeCore.Const) && dy !== w.val - - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - # dx += grad wrt x.val - NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) - end - if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val - # dw += grad wrt w.val - NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) - end - - dy .= 0 - end - end - - return (nothing, nothing, nothing, nothing) -end - -end -end - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} - - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated - func.val(dst.val, src.val, idx.val) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - dst.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - dst.dval - else - nothing - end - - # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(src) <: EnzymeCore.Const) - && !(typeof(dst) <: EnzymeCore.Const) - ) ? copy(idx.val) : nothing - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} - - # Don't cache idx if not overwritten - if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] - cache_idx = idx.val - end - end - - ddsts = dst.dval - dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - ddsts = (ddsts,) - dsrcs = (dsrcs,) - end - - for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val - - if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val - NNlib.scatter!(+, dsrc, ddst, cache_idx) - end - - ddst .= 0 - end - end - - return (nothing, nothing, nothing) -end - - - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} - - @assert !(OutType <: EnzymeCore.Const) - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated - func.val(op.val, dst.val, src.val, idx.val) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - dst.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - dst.dval - else - nothing - end - - # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(src) <: EnzymeCore.Const) - && !(typeof(dst) <: EnzymeCore.Const) - ) ? copy(idx.val) : nothing - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} - - # Don't cache idx if not overwritten - if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] - cache_idx = idx.val - end - end - - ddsts = dst.dval - dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - ddsts = (ddsts,) - dsrcs = (dsrcs,) - end - - for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val - - if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val - - if eltype(typeof(op)) == typeof(+) - dsrc .+= NNlib.gather(ddst, cache_idx) - else - @assert eltype(typeof(op)) == typeof(-) - dsrc .-= NNlib.gather(ddst, cache_idx) - end - end - - end - end - - return (nothing, nothing, nothing, nothing) -end - - - -for pool in [:maxpool, :meanpool, :lpnormpool] - pool! = Symbol(pool, :!) - ∇pool = Symbol(:∇, pool) - - @eval begin - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT} - - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated - func.val(y.val, x.val, dims.val; kwargs...) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - y.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - y.dval - else - nothing - end - - cache_y = ( EnzymeCore.EnzymeRules.overwritten(config)[2] - && !(typeof(x) <: EnzymeCore.Const) - && !(typeof(y) <: EnzymeCore.Const) - ) ? copy(y.val) : nothing - - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] - && !(typeof(x) <: EnzymeCore.Const) - && !(typeof(y) <: EnzymeCore.Const) - ) ? copy(x.val) : nothing - - cache = (cache_y, cache_x) - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT} - cache_y, cache_x = cache - - # Don't cache y if not overwritten - if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[2] - cache_y = y.val - end - end - - # Don't cache x if not overwritten - if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[3] - cache_x = x.val - end - end - - dys = y.dval - dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - dys = (dys,) - dxs = (dxs,) - end - - for (dy, dx) in zip(dys, dxs) - if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val - - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...) - end - - dy .= 0 - end - end - - return (nothing, nothing, nothing) -end - -end -end - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT} - - T = float(real(eltype(dst.val))) - val = convert(T, 1/(1-p.val)) - keep = if dims.val isa Colon - similar(dst.val, T, size(dst.val)) - else - similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val))) - end - rand!(rng.val, keep) - - keep = keep .> p.val - - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated - dst.val .= (keep .* val) .* src.val - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - dst.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - dst.dval - else - nothing - end - - if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const - keep = nothing - end - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT} - T = float(real(eltype(dst.val))) - val = convert(T, 1/(1-p.val)) - - ddsts = dst.dval - dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - ddsts = (ddsts,) - dsrcs = (dsrcs,) - end - - for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val - - if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val - dsrc .+= (keep .* val) .* ddst - end - - ddst .= 0 - end - end - - dp = if typeof(p) <: EnzymeCore.Active - typeof(p.val)(0) - else - nothing - end - - return (nothing, nothing, nothing, dp, nothing) -end diff --git a/test/gather.jl b/test/gather.jl index d143220e9..92e3bfb7d 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -154,24 +154,6 @@ function gather_testsuite(Backend) gradtest_fn((s, i) -> gather(s, i), src, idx) end - - @testset "EnzymeRules: gather! gradient for scalar index" begin - src = device(Float64[3, 4, 5, 6, 7]) - idx = device([ - 1 2 3 4; - 4 2 1 3; - 3 5 5 3]) - dst = gather(src, idx) - for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), - Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), - Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - - EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue - - EnzymeTestUtils.test_reverse(gather!, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) - end - end - @static if Test_Enzyme @testset "EnzymeRules: gather! gradient for scalar index" begin