21
21
@adjoint (:: Type{T} )(sz) where {T<: Zeros } = T (sz), Δ-> (nothing ,)
22
22
@adjoint (:: Type{T} )(sz) where {T<: Ones } = T (sz), Δ-> (nothing ,)
23
23
24
- @adjoint getindex (x:: AbstractArray , inds... ) = x[inds... ], ∇getindex (x, inds)
25
-
26
- @adjoint view (x:: AbstractArray , inds... ) = view (x, inds... ), ∇getindex (x, inds)
27
-
28
- ∇getindex (x:: AbstractArray{T,N} , inds) where {T,N} = dy -> begin
29
- if inds isa NTuple{N,Int} && T <: Number
30
- dx = OneElement (dy, inds, axes (x))
31
- elseif inds isa NTuple{<: Any , Integer}
32
- dx = _zero (x, typeof (dy))
33
- dx[inds... ] = dy
34
- else
35
- dx = _zero (x, eltype (dy))
36
- dxv = view (dx, inds... )
37
- dxv .= accum .(dxv, _droplike (dy, dxv))
38
- end
39
- return (_project (x, dx), map (_-> nothing , inds)... )
40
- end
41
-
42
- """
43
- OneElement(val, ind, axes) <: AbstractArray
44
-
45
- Extremely simple `struct` used for the gradient of scalar `getindex`.
46
- """
47
- struct OneElement{T,N,I,A} <: AbstractArray{T,N}
48
- val:: T
49
- ind:: I
50
- axes:: A
51
- OneElement (val:: T , ind:: I , axes:: A ) where {T<: Number , I<: NTuple{N,Int} , A<: NTuple{N,AbstractUnitRange} } where {N} = new {T,N,I,A} (val, ind, axes)
52
- end
53
- Base. size (A:: OneElement ) = map (length, A. axes)
54
- Base. axes (A:: OneElement ) = A. axes
55
- Base. getindex (A:: OneElement{T,N} , i:: Vararg{Int,N} ) where {T,N} = ifelse (i== A. ind, A. val, zero (T))
56
-
57
-
58
24
_zero (xs:: AbstractArray{<:Number} , T:: Type{Nothing} ) = fill! (similar (xs), zero (eltype (xs)))
59
25
_zero (xs:: AbstractArray{<:Number} , T) = fill! (similar (xs, T), false )
60
26
_zero (xs:: AbstractArray , T) = fill! (similar (xs, Union{Nothing, T}), nothing )
61
27
62
- _droplike (dy, dxv) = dy
63
- _droplike (dy:: Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose} , dxv:: AbstractVector ) =
64
- dropdims (dy; dims= 2 )
65
-
66
28
@adjoint getindex (:: Type{T} , xs... ) where {T} = T[xs... ], dy -> (nothing , dy... )
67
29
68
30
_throw_mutation_error (f, args... ) = error ("""
@@ -83,7 +45,7 @@ Possible fixes:
83
45
_ -> _throw_mutation_error (copyto!, xs)
84
46
85
47
for f in [push!, pop!, pushfirst!, popfirst!]
86
- @eval @adjoint! $ f (x:: AbstractVector , ys... ) = $ f (x, ys... ),
48
+ @eval @adjoint! $ f (x:: AbstractVector , ys... ) = $ f (x, ys... ),
87
49
_ -> _throw_mutation_error ($ f, x)
88
50
end
89
51
310
272
# =============
311
273
312
274
@adjoint parent (x:: LinearAlgebra.Adjoint ) = parent (x), ȳ -> (LinearAlgebra. Adjoint (ȳ),)
313
- @adjoint parent (x:: LinearAlgebra.Transpose ) = parent (x), ȳ -> (LinearAlgebra. Transpose (ȳ),)
275
+ @adjoint parent (x:: LinearAlgebra.Transpose ) = parent (x), ȳ -> (LinearAlgebra. Transpose (ȳ),)
314
276
315
277
function _kron (mat1:: AbstractMatrix ,mat2:: AbstractMatrix )
316
278
m1, n1 = size (mat1)
0 commit comments