Skip to content

Commit f2384d1

Browse files
committed
ProjectTo after OneElement and disable some excessive inferred requirments
1 parent 049b287 commit f2384d1

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

src/rulesets/Base/indexing.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ function ∇getindex(x::AbstractArray{T,N}, dy, inds...) where {T,N}
8585
# `to_indices` removes any logical indexing, colons, CartesianIndex etc,
8686
# leaving just Int / AbstractVector of Int
8787
plain_inds = Base.to_indices(x, inds)
88-
if plain_inds isa NTuple{N, Int} && T<:Number
88+
dx = if plain_inds isa NTuple{N, Int} && T<:Number
8989
# scalar indexing
90-
return OneElement(dy, plain_inds, axes(x))
90+
OneElement(dy, plain_inds, axes(x))
9191
else # some from slicing (potentially noncontigous)
9292
dx = _setindex_zero(x, dy, plain_inds...)
9393
∇getindex!(dx, dy, plain_inds...)
94-
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
9594
end
95+
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
9696
end
9797
∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z
9898

@@ -110,7 +110,7 @@ end
110110
Base.size(A::OneElement) = map(length, A.axes)
111111
Base.axes(A::OneElement) = A.axes
112112
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))
113-
# TODO: should we teach ProjectTo that OneElement is more structurally sparse than anything it intersects nonstructural zeros with?
113+
114114

115115
"""
116116
_setindex_zero(x, dy, inds...)

test/rulesets/Base/array.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,15 @@ end
358358
@test_skip test_frule(findmin, rand(3,4), output_tangent = (rand(), NoTangent()))
359359
@test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,))
360360
# These skipped tests might be fixed by https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188
361+
# or by https://github.com/JuliaLang/julia/pull/48404
361362

362363
# Reverse
363364
test_rrule(findmin, rand(10), output_tangent = (rand(), false))
364365
test_rrule(findmax, rand(10), output_tangent = (rand(), false))
365366
test_rrule(findmin, rand(5,3))
366367
test_rrule(findmax, rand(5,3))
367-
@test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2])
368-
@test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2])
368+
@test [0 0; 0 5] == unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2])
369+
@test [0 0; 0 5] == unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2])
369370

370371
# Reverse with dims:
371372
@test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2])
@@ -398,8 +399,8 @@ end
398399
@test res == @inferred unthunk(rrule(imum, [1,2,1,2,1,2])[2](1.0)[2])
399400

400401
# Structured matrix -- NB the minimum is a structral zero here
401-
@test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Union{Diagonal, OneElement}
402-
@test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa Union{UpperTriangular{Float64}, ChainRules.OneElement{Float64}} # must be at least as structured
402+
@test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Diagonal
403+
@test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa UpperTriangular{Float64}
403404
@test_skip test_rrule(imum, Diagonal(rand(3) .+ 1)) # MethodError: no method matching zero(::Type{Any}), from fill!(A::SparseArrays.SparseMatrixCSC{Any, Int64}, x::Bool)
404405
end
405406

0 commit comments

Comments
 (0)