Skip to content

Commit d73f176

Browse files
committed
Excise getindex adjoint
We have a better rule in Chainrules now
1 parent d39ab59 commit d73f176

File tree

1 file changed

+2
-40
lines changed

1 file changed

+2
-40
lines changed

src/lib/array.jl

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,48 +21,10 @@ end
2121
@adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
2222
@adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)
2323

24-
@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)
25-
26-
@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)
27-
28-
∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
29-
if inds isa NTuple{N,Int} && T <: Number
30-
dx = OneElement(dy, inds, axes(x))
31-
elseif inds isa NTuple{<:Any, Integer}
32-
dx = _zero(x, typeof(dy))
33-
dx[inds...] = dy
34-
else
35-
dx = _zero(x, eltype(dy))
36-
dxv = view(dx, inds...)
37-
dxv .= accum.(dxv, _droplike(dy, dxv))
38-
end
39-
return (_project(x, dx), map(_->nothing, inds)...)
40-
end
41-
42-
"""
43-
OneElement(val, ind, axes) <: AbstractArray
44-
45-
Extremely simple `struct` used for the gradient of scalar `getindex`.
46-
"""
47-
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
48-
val::T
49-
ind::I
50-
axes::A
51-
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
52-
end
53-
Base.size(A::OneElement) = map(length, A.axes)
54-
Base.axes(A::OneElement) = A.axes
55-
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))
56-
57-
5824
_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
5925
_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)
6026
_zero(xs::AbstractArray, T) = fill!(similar(xs, Union{Nothing, T}), nothing)
6127

62-
_droplike(dy, dxv) = dy
63-
_droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::AbstractVector) =
64-
dropdims(dy; dims=2)
65-
6628
@adjoint getindex(::Type{T}, xs...) where {T} = T[xs...], dy -> (nothing, dy...)
6729

6830
_throw_mutation_error(f, args...) = error("""
@@ -83,7 +45,7 @@ Possible fixes:
8345
_ -> _throw_mutation_error(copyto!, xs)
8446

8547
for f in [push!, pop!, pushfirst!, popfirst!]
86-
@eval @adjoint! $f(x::AbstractVector, ys...) = $f(x, ys...),
48+
@eval @adjoint! $f(x::AbstractVector, ys...) = $f(x, ys...),
8749
_ -> _throw_mutation_error($f, x)
8850
end
8951

@@ -310,7 +272,7 @@ end
310272
# =============
311273

312274
@adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),)
313-
@adjoint parent(x::LinearAlgebra.Transpose) = parent(x), ȳ -> (LinearAlgebra.Transpose(ȳ),)
275+
@adjoint parent(x::LinearAlgebra.Transpose) = parent(x), ȳ -> (LinearAlgebra.Transpose(ȳ),)
314276

315277
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
316278
m1, n1 = size(mat1)

0 commit comments

Comments
 (0)