-
-
Notifications
You must be signed in to change notification settings - Fork 61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Gradient definitions & supertypes #168
Conversation
@show Δ′ | ||
# (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug? | ||
(Δ′, nothing, map(_ -> nothing, j)...) | ||
# (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why would it not be this one?
@@ -55,11 +73,13 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A | |||
function AbstractVectorOfArray_getindex_adjoint(Δ) | |||
Δ′ = zero(VA) | |||
Δ′[i,j...] = Δ | |||
(Δ′, i,map(_ -> nothing, j)...) | |||
@show Δ′ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@show Δ′ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, very WIP!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh no worries, still worth the downstream tests
Thanks! Running downstream tests to check it out. |
@@ -55,11 +73,13 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A | |||
function AbstractVectorOfArray_getindex_adjoint(Δ) | |||
Δ′ = zero(VA) | |||
Δ′[i,j...] = Δ | |||
(Δ′, i,map(_ -> nothing, j)...) | |||
@show Δ′ | |||
# (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah that's a bug
This is inspired by SciML/SciMLSensitivity.jl#493 (comment) . I wrote many examples there, but the core surprise was that a
VectorOfArrays{T} <: AbstractArray{T,3}
has a gradientVector{Matrix{T}}
which is not a 3-array. If I understood correctly, it would desirable that the result be another VectorOfArray, i.e. somethingAbstractArray{T,3}
which iterates like Vector{Matrix}.One way to do that would be using ProjectTo. But in fact I believe the Vector{Matrix} is being explicitly generated by a rule defined here, so perhaps it is simpler just to alter that rule to wrap this?
RFC, WIP, needs tests... and needs thought about other cases:
VectorOfArray(collect(eachslice(Δ)))
, to return a comparable object, i.e. one which broadcasts one way and iterates a different way. This would needProjectTo
. It would also apply to things likegradient(va -> sum(abs, [1 2; 3 4] * va), VectorOfArray([[1,2], [3,4]]))
.va[:,1]
expected to work? What should it return? The other way around should I think do this: