@@ -43,10 +43,28 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
43
43
A. x,literal_ArrayPartition_x_adjoint
44
44
end
45
45
46
+ #=
47
+
48
+ # Define a new species of projection operator for this type:
49
+ ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}()
50
+
51
+ # Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
52
+ (::ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
53
+ # Gradient from broadcasting will be another AbstractArray
54
+ (::ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
55
+
56
+ But this may not be necessary?
57
+
58
+ =#
59
+
60
+
61
+ # These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
62
+ # definition first, and finds its own before finding those.
63
+
46
64
ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} )
47
65
function AbstractVectorOfArray_getindex_adjoint (Δ)
48
66
Δ′ = [ (i == j ? Δ : zero (x)) for (x,j) in zip (VA. u, 1 : length (VA))]
49
- (Δ′ ,nothing )
67
+ (VectorOfArray (Δ′) ,nothing )
50
68
end
51
69
VA[i],AbstractVectorOfArray_getindex_adjoint
52
70
end
@@ -55,11 +73,13 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A
55
73
function AbstractVectorOfArray_getindex_adjoint (Δ)
56
74
Δ′ = zero (VA)
57
75
Δ′[i,j... ] = Δ
58
- (Δ′, i,map (_ -> nothing , j)... )
76
+ @show Δ′
77
+ # (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug?
78
+ (Δ′, nothing , map (_ -> nothing , j)... )
79
+ # (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
59
80
end
60
81
VA[i,j... ],AbstractVectorOfArray_getindex_adjoint
61
82
end
62
-
63
83
ZygoteRules. @adjoint function ArrayPartition (x:: S , :: Type{Val{copy_x}} = Val{false }) where {S<: Tuple ,copy_x}
64
84
function ArrayPartition_adjoint (_y)
65
85
y = Array (_y)
0 commit comments