From 0f351036f601b2abe25eed9dda47dc95557970bc Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 18 May 2023 14:42:34 +0800 Subject: [PATCH 1/6] =Bring over OneElement for scalar getindex --- src/rulesets/Base/indexing.jl | 29 +++++++++++++++++++++++++---- test/rulesets/Base/array.jl | 4 ++-- test/rulesets/Base/indexing.jl | 2 +- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 0ca102143..d1ffcb136 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -81,16 +81,37 @@ For the `rrule` of `y = x[inds...]`, this function is roughly `setindex(zero(x), dy, inds...)`, returning the array `dx`. Differentiable. Includes `ProjectTo(x)(dx)`. """ -function ∇getindex(x::AbstractArray, dy, inds...) +function ∇getindex(x::AbstractArray{T,N}, dy, inds...) where {T,N} # `to_indices` removes any logical indexing, colons, CartesianIndex etc, # leaving just Int / AbstractVector of Int plain_inds = Base.to_indices(x, inds) - dx = _setindex_zero(x, dy, plain_inds...) - ∇getindex!(dx, dy, plain_inds...) - return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules + if plain_inds isa NTuple{N, Int} && T<:Number + # scalar indexing + return OneElement(dy, plain_inds, axes(x)) + else # some from slicing (potentially noncontigous) + dx = _setindex_zero(x, dy, plain_inds...) + ∇getindex!(dx, dy, plain_inds...) + return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules + end end ∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z +""" + OneElement(val, ind, axes) <: AbstractArray + +Extremely simple `struct` used for the gradient of scalar `getindex`. 
+""" +struct OneElement{T,N,I,A} <: AbstractArray{T,N} + val::T + ind::I + axes::A + OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes) +end +Base.size(A::OneElement) = map(length, A.axes) +Base.axes(A::OneElement) = A.axes +Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) +# TODO: should we teach ProjectTo that OneElement is more structurally sparse than anything it intersects nonstructural zeros with? + """ _setindex_zero(x, dy, inds...) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index c50008430..2baf5952b 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -398,8 +398,8 @@ end @test res == @inferred unthunk(rrule(imum, [1,2,1,2,1,2])[2](1.0)[2]) # Structured matrix -- NB the minimum is a structral zero here - @test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Diagonal - @test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa UpperTriangular{Float64} + @test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Union{Diagonal, OneElement} + @test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa Union{UpperTriangular{Float64}, ChainRules.OneElement{Float64}} # must be at least as structured @test_skip test_rrule(imum, Diagonal(rand(3) .+ 1)) # MethodError: no method matching zero(::Type{Any}), from fill!(A::SparseArrays.SparseMatrixCSC{Any, Int64}, x::Bool) end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 3dbcd0bc9..d131dc876 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -89,7 +89,7 @@ test_rrule(getindex, Symmetric(rand(3, 3)), 2, 2) sgrad = rrule(getindex, Symmetric(rand(3, 3)), 2, 3)[2](1.0)[2] - @test unthunk(sgrad) ≈ [0 0 0; 0 0 1/2; 0 1/2 0] + @test unthunk(sgrad) ≈ [0 0 0; 0 0 1/2; 0 1/2 0] # We are actually getting this wrong 
now end @testset "getindex(::Array{<:Array})" begin From add78390046a3e8da42393f949ffc23792c0a1d0 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 19 May 2023 18:35:01 +0800 Subject: [PATCH 2/6] ProjectTo after OneElement and disable some excessive inferred requirements --- src/rulesets/Base/indexing.jl | 8 ++++---- test/rulesets/Base/array.jl | 9 +++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index d1ffcb136..156db42b9 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -85,14 +85,14 @@ function ∇getindex(x::AbstractArray{T,N}, dy, inds...) where {T,N} # `to_indices` removes any logical indexing, colons, CartesianIndex etc, # leaving just Int / AbstractVector of Int plain_inds = Base.to_indices(x, inds) - if plain_inds isa NTuple{N, Int} && T<:Number + dx = if plain_inds isa NTuple{N, Int} && T<:Number # scalar indexing - return OneElement(dy, plain_inds, axes(x)) + OneElement(dy, plain_inds, axes(x)) else # some from slicing (potentially noncontigous) dx = _setindex_zero(x, dy, plain_inds...) ∇getindex!(dx, dy, plain_inds...) - return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules end + return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules end ∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z @@ -110,7 +110,7 @@ end Base.size(A::OneElement) = map(length, A.axes) Base.axes(A::OneElement) = A.axes Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) -# TODO: should we teach ProjectTo that OneElement is more structurally sparse than anything it intersects nonstructural zeros with? + """ _setindex_zero(x, dy, inds...) 
diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 2baf5952b..fa5f0b808 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -358,14 +358,15 @@ end @test_skip test_frule(findmin, rand(3,4), output_tangent = (rand(), NoTangent())) @test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,)) # These skipped tests might be fixed by https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 + # or by https://github.com/JuliaLang/julia/pull/48404 # Reverse test_rrule(findmin, rand(10), output_tangent = (rand(), false)) test_rrule(findmax, rand(10), output_tangent = (rand(), false)) test_rrule(findmin, rand(5,3)) test_rrule(findmax, rand(5,3)) - @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2]) - @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2]) + @test [0 0; 0 5] == unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2]) + @test [0 0; 0 5] == unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2]) # Reverse with dims: @test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2]) @@ -398,8 +399,8 @@ end @test res == @inferred unthunk(rrule(imum, [1,2,1,2,1,2])[2](1.0)[2]) # Structured matrix -- NB the minimum is a structral zero here - @test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Union{Diagonal, OneElement} - @test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa Union{UpperTriangular{Float64}, ChainRules.OneElement{Float64}} # must be at least as structured + @test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Diagonal + @test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa UpperTriangular{Float64} @test_skip test_rrule(imum, Diagonal(rand(3) .+ 1)) # MethodError: no method matching zero(::Type{Any}), from fill!(A::SparseArrays.SparseMatrixCSC{Any, Int64}, x::Bool) end From 470e2145defbd81849362038181a80bdd0153366 Mon Sep 
17 00:00:00 2001 From: Frames White Date: Fri, 19 May 2023 20:53:17 +0800 Subject: [PATCH 3/6] Check a few less inference --- test/rulesets/Base/array.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index fa5f0b808..5dead0dba 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -363,8 +363,8 @@ end # Reverse test_rrule(findmin, rand(10), output_tangent = (rand(), false)) test_rrule(findmax, rand(10), output_tangent = (rand(), false)) - test_rrule(findmin, rand(5,3)) - test_rrule(findmax, rand(5,3)) + test_rrule(findmin, rand(5,3); check_inferred=false) + test_rrule(findmax, rand(5,3); check_inferred=false) @test [0 0; 0 5] == unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2]) @test [0 0; 0 5] == unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2]) @@ -386,7 +386,7 @@ end # Reverse test_rrule(imum, rand(10)) - test_rrule(imum, rand(3,4)) + test_rrule(imum, rand(3,4); check_inferred=false) @gpu test_rrule(imum, rand(3,4), fkwargs=(dims=1,)) test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),)) From 0094209d31145dc256e9e9f3ecac3ff60c253938 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 19 May 2023 21:44:12 +0800 Subject: [PATCH 4/6] Disable more inference tests --- test/rulesets/Base/indexing.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index d131dc876..8928c55e7 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -34,10 +34,10 @@ @testset "single element" begin test_rrule(getindex, x, 2) - test_rrule(getindex, x, 2, 1) - test_rrule(getindex, x, 2, 2) + test_rrule(getindex, x, 2, 1; check_inferred=false) + test_rrule(getindex, x, 2, 2; check_inferred=false) - test_rrule(getindex, x, CartesianIndex(2, 3)) + test_rrule(getindex, x, CartesianIndex(2, 3); check_inferred=false) end @testset "slice/index positions" 
begin @@ -87,9 +87,9 @@ dgrad = rrule(getindex, Diagonal(rand(3)), 2, :)[2]([1,2,3])[2] @test unthunk(dgrad) ≈ Diagonal([0, 2, 0]) - test_rrule(getindex, Symmetric(rand(3, 3)), 2, 2) + test_rrule(getindex, Symmetric(rand(3, 3)), 2, 2; check_inferred=false) # Infers to Any sgrad = rrule(getindex, Symmetric(rand(3, 3)), 2, 3)[2](1.0)[2] - @test unthunk(sgrad) ≈ [0 0 0; 0 0 1/2; 0 1/2 0] # We are actually getting this wrong now + @test unthunk(sgrad) ≈ [0 0 0; 0 0 1/2; 0 1/2 0] end @testset "getindex(::Array{<:Array})" begin From 77e8b8404a22a176dd7a3f33c4017afe50219c75 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 22 May 2023 13:52:20 +0800 Subject: [PATCH 5/6] optimize + --- src/rulesets/Base/indexing.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 156db42b9..7e1befd14 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -111,6 +111,17 @@ Base.size(A::OneElement) = map(length, A.axes) Base.axes(A::OneElement) = A.axes Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) +function ChainRulesCore.add!!(xs::AbstractArray{<:Any,N}, oe::OneElement{<:Any,N}) where {N} + if !ChainRulesCore.is_inplaceable_destination(xs) + xs = collect(xs) + end + xs[oe.ind...] += oe.val + return xs +end + +Base.:(+)(xs::AbstractArray, oe::OneElement) = add!!(copy(xs), oe) +Base.:(+)(oe::OneElement, xs::AbstractArray) = +(xs, oe) +Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2) """ _setindex_zero(x, dy, inds...) 
From 488dca634dfe3fd726f94cfda0b910dfa271faf2 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 12 Jun 2023 11:51:27 +0800 Subject: [PATCH 6/6] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 254f1ad7c..7a9746437 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.50.0" +version = "1.51.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"