Gradient definitions & supertypes #168

Merged 1 commit on Sep 30, 2021
26 changes: 23 additions & 3 deletions src/zygote.jl
@@ -43,10 +43,28 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
A.x,literal_ArrayPartition_x_adjoint
end

#=

# Define a new species of projection operator for this type:
ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}()

# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
(::ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
# Gradient from broadcasting will be another AbstractArray
(::ProjectTo{VectorOfArray})(dx::AbstractArray) = dx

But this may not be necessary?

=#


# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
# definition first, and finds its own before finding those.

ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
(Δ′,nothing)
(VectorOfArray(Δ′),nothing)
end
VA[i],AbstractVectorOfArray_getindex_adjoint
end
@@ -55,11 +73,13 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = zero(VA)
Δ′[i,j...] = Δ
(Δ′, i,map(_ -> nothing, j)...)
@show Δ′
Member: Suggested change: remove `@show Δ′`

Contributor Author: Sorry, very WIP!

Member: oh no worries, still worth the downstream tests

# (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug?
Member: yeah that's a bug

(Δ′, nothing, map(_ -> nothing, j)...)
# (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
Member: why would it not be this one?

end
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
end
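
Note: the pullback convention discussed in this thread can be sketched without Zygote or RecursiveArrayTools. The block below is plain Julia on an ordinary `Array`, and `getindex_with_pullback` is a hypothetical name that belongs to neither package. It assumes the usual convention that a pullback returns one value per argument of the primal function, so each index argument gets `nothing` rather than the index `i` itself.

```julia
# Stand-alone sketch: plain Julia, no Zygote or RecursiveArrayTools.
# `getindex_with_pullback` is a hypothetical name used only for illustration.
function getindex_with_pullback(A::AbstractArray, i::Int, j::Int...)
    function pullback(Δ)
        Δ′ = zero(A)       # gradient has the same shape as the input array
        Δ′[i, j...] = Δ    # only the indexed entry receives the incoming cotangent
        # One return value per argument: Δ′ for the array, `nothing` for every index.
        return (Δ′, nothing, map(_ -> nothing, j)...)
    end
    return A[i, j...], pullback
end

A = [1.0 2.0; 3.0 4.0]
y, back = getindex_with_pullback(A, 2, 1)
grads = back(1.0)

@assert y == 3.0
@assert length(grads) == 3                 # array + two indices
@assert grads[1] == [0.0 0.0; 1.0 0.0]
@assert grads[2] === nothing && grads[3] === nothing
```

Whether the array gradient should additionally be wrapped (for example as `VectorOfArray(Δ′)`) is the open question in the review comment above; the sketch does not take a position on it.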

ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
function ArrayPartition_adjoint(_y)
y = Array(_y)