From 692e4825eff9e7ed77f7de64ef087efdff225f16 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 20 Dec 2022 21:47:05 -0500 Subject: [PATCH 1/5] non_differentiable _denom --- src/rulesets/Statistics/statistics.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index 08be133fd..6dc00faae 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -6,6 +6,8 @@ _denom(x, dims) = size(x, dims) _denom(x, dims::Colon) = length(x) _denom(x, dims::Union{Tuple, AbstractArray}) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1) +@non_differentiable _denom(::Any, ::Any) # else Zygote tries to AD unique(::Tuple) + function rrule(::typeof(mean), x::AbstractArray{<:Union{Real,Complex,AbstractArray}}; dims=:) y_sum, sum_pullback = rrule(sum, x; dims) n = _denom(x, dims) From c5794107c0eb3f6051d905ae187c04248f62fdcc Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 20 Dec 2022 21:49:12 -0500 Subject: [PATCH 2/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 51337c18f..b378a42bd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.46.0" +version = "1.46.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 7bb7d986ce03f1c6ed62121b1c26bb0dda9f46c8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 20 Dec 2022 22:04:31 -0500 Subject: [PATCH 3/5] =?UTF-8?q?=E2=88=87getindex(::AbstractZero)=20paths?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rulesets/Base/indexing.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 6136fa6a2..beb042cf3 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -61,10 +61,9 @@ function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...) end function rrule(::typeof(getindex), x::AbstractArray, inds...) - function getindex_pullback(dy) - nots = map(Returns(NoTangent()), inds) - return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) - end + nots = map(Returns(NoTangent()), inds) + getindex_pullback(dy) = (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) + getindex_pullback(z::AbstractZero) = (NoTangent(), z, nots...) return x[inds...], getindex_pullback end @@ -90,6 +89,7 @@ function ∇getindex(x::AbstractArray, dy, inds...) ∇getindex!(dx, dy, plain_inds...) return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules end +∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z """ _setindex_zero(x, dy, inds...) @@ -191,10 +191,9 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...) end function rrule(::typeof(view), x::AbstractArray, inds...) - function view_pullback(dy) - nots = map(Returns(NoTangent()), inds) - return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) - end + nots = map(Returns(NoTangent()), inds) + view_pullback(dy) = (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) + view_pullback(z::AbstractZero) = (NoTangent(), z, nots...) return view(x, inds...), view_pullback end From 532f72443a4a70aa9ce43d4052e0c64c548d90dc Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 21 Dec 2022 13:36:44 -0500 Subject: [PATCH 4/5] non_differentiable foreach(f, ()) --- src/rulesets/Base/nondiff.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 22aeb1748..8392f8076 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -189,6 +189,7 @@ @non_differentiable floatmax(::Any) @non_differentiable floatmin(::Any) @non_differentiable flush(::Any) +@non_differentiable foreach(::Any, ::Tuple{}) @non_differentiable gensym(::Symbol) @non_differentiable gensym(::String...) From f33353f6d2c5326a21a31aa78c1c56c1b74f0191 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 5 Jan 2023 23:45:51 -0500 Subject: [PATCH 5/5] also fix summary --- src/rulesets/Base/nondiff.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 8392f8076..2bcac37ca 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -423,6 +423,7 @@ end @non_differentiable supertype(::Any) @non_differentiable Symbol(::Any...) @non_differentiable symlink(::AbstractString, ::AbstractString) +@non_differentiable summary(::Any) @non_differentiable take!(::Base.GenericIOBuffer) @non_differentiable take!(::IOStream) @@ -473,6 +474,7 @@ elseif isdefined(Base, :cumulative_compile_time_ns) end @non_differentiable Base.time_print(::Any...) @non_differentiable Base.OneTo(::Any...) +@non_differentiable Base.array_summary(::Any) @non_differentiable Broadcast.combine_styles(::Any...) @non_differentiable Broadcast.result_style(::Any)