From 0cab7a5f28f559fdba3ecc5032189af172ef203f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 8 Feb 2024 22:19:09 +0100 Subject: [PATCH 1/5] Remove third argument to `similar` --- src/rulesets/Base/indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 830571ecd..d19619c39 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -130,7 +130,7 @@ It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't allow `eltype(dy)`, nor does it work for many structured matrices. """ _setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false) -_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false) function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), # but always makes an abstract type. TODO: make it infer concrete type for e.g. 
vectors of SVectors From 2aa8d26d91d7a50e7706ca4e4b5e34dd37b09c79 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 8 Feb 2024 22:52:05 +0100 Subject: [PATCH 2/5] Add test, fix existing tests --- Project.toml | 2 ++ src/ChainRules.jl | 1 + src/rulesets/Base/indexing.jl | 3 ++- test/rulesets/Base/indexing.jl | 7 +++++++ test/runtests.jl | 1 + 5 files changed, 13 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fedd2a600..c8e05e63b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "1.61.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -20,6 +21,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0, 4" +AxisArrays = "0.4.7" ChainRulesCore = "1.20" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 6d33a22e7..6eb6128c9 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -1,6 +1,7 @@ module ChainRules using Adapt: adapt +using AxisArrays: AxisArray, AxisArrays using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using ChainRulesCore using Compat diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index d19619c39..681c25203 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -130,7 +130,8 @@ It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't allow `eltype(dy)`, nor does it work for many structured matrices. """ _setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false) -_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) 
= fill!(similar(x, eltype(dy), axes(x)), false) +_setindex_zero(x::AxisArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), AxisArrays.axes(x)), false) function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index e878dd061..423a7afe7 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -128,6 +128,13 @@ end @test dx23[3] == dxfix[3] end + @testset "getindex(::AxisArray{<:Number})" begin + X = randn((2, 3)) + A = AxisArray(X; row=[:a, :b], col=[:x, :y, :z]) + dA, back = rrule(getindex, A, [:a], [:x, :z]) + @test unthunk(back(ones(1, 2))[2]) == [1.0 0.0 1.0; 0.0 0.0 0.0] + end + @testset "second derivatives: ∇getindex" begin @eval using ChainRules: ∇getindex # Forward, scalar result diff --git a/test/runtests.jl b/test/runtests.jl index 768f7c208..81bb4ee22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils @nospecialize using Adapt +using AxisArrays using Base.Broadcast: broadcastable using ChainRules using ChainRules: stack From 1d779d27e409658da883d6731be55809a85e0a31 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 8 Feb 2024 22:55:02 +0100 Subject: [PATCH 3/5] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c8e05e63b..11af953aa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.61.0" +version = "1.61.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 348fb86e6d0af0ddc7679e88f21a65513dedcc7f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 8 Feb 2024 23:49:25 +0100 Subject: 
[PATCH 4/5] Use two-arg `similar` in `_setindex_zero` --- Project.toml | 4 ++-- src/ChainRules.jl | 1 - src/rulesets/Base/indexing.jl | 12 ++++-------- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 11af953aa..124a1fbaf 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ version = "1.61.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -43,6 +42,7 @@ SuiteSparse = "1" julia = "1.6" [extras] +AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" @@ -52,4 +52,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] +test = ["AxisArrays", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 6eb6128c9..6d33a22e7 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -1,7 +1,6 @@ module ChainRules using Adapt: adapt -using AxisArrays: AxisArray, AxisArrays using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using ChainRulesCore using Compat diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 681c25203..f81ffda59 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -125,22 +125,18 @@ Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2) This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutable via `similar`, and 
its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what `∇getindex` does next. - -It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't -allow `eltype(dy)`, nor does it work for many structured matrices. """ -_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false) -_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false) -_setindex_zero(x::AxisArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), AxisArrays.axes(x)), false) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy)), false) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false) function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors T = Union{typeof(dy), ZeroTangent} - return fill!(similar(x, T, axes(x)), ZeroTangent()) + return fill!(similar(x, T), ZeroTangent()) end function _setindex_zero(x::AbstractArray, dy, inds...) T = Union{eltype(dy), ZeroTangent} - return fill!(similar(x, T, axes(x)), ZeroTangent()) + return fill!(similar(x, T), ZeroTangent()) end ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...) 
From 9ce63a54d39c1052fe95a6537f628106a10820ef Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 9 Feb 2024 00:06:48 +0100 Subject: [PATCH 5/5] Update src/rulesets/Base/indexing.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rulesets/Base/indexing.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index f81ffda59..ea081c99c 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -126,8 +126,11 @@ This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutabl and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what `∇getindex` does next. """ -_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy)), false) -_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = + fill!(similar(x, typeof(dy)), false) +function _setindex_zero(x::AbstractArray{<:Number}, dy, inds...) + return fill!(similar(x, eltype(dy)), false) +end function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors