#491 implements an `rrule` for `vect` for two special cases, but neglects the general case because I couldn't figure out how to implement it in a type-stable manner for `Number`s, or at all in general.
My attempt at an implementation for `Number`s was
```julia
using ChainRulesCore
import ChainRulesCore: rrule  # extend ChainRulesCore's `rrule` instead of defining a new function

# Numbers need to be projected because they don't pass straight through the function.
# More generally, we would ideally project everything.
function rrule(::typeof(Base.vect), X::Vararg{Number, N}) where {N}
    l = length(X)
    projects = map(ProjectTo, X)
    function vect_pullback(ȳ)
        X̄ = ntuple(n -> projects[n](ȳ[n]), l)
        return (NoTangent(), X̄...)
    end
    return Base.vect(X...), vect_pullback
end
```
but it's type-unstable because the type of `X̄` couldn't be inferred (Julia 1.6.2). I'm not entirely sure why inference can't work it out.
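One plausible cause, offered as an assumption rather than a confirmed diagnosis: `l` is captured by `vect_pullback` as an ordinary `Int`, so inside the pullback inference no longer knows its value, and `ntuple(f, l)` cannot be inferred as a tuple of fixed length. The plain-Julia sketch below (no ChainRules involved; `pullback_runtime_length` and `pullback_static_length` are hypothetical names, and `ProjectTo` is dropped for brevity) compares capturing the length as an `Int` with capturing `Val(N)`, using `Test.@inferred` to check what inference sees.

```julia
using Test  # for @inferred / @test_throws

# Hypothetical stand-in for the pullback above: `l` is captured as a plain Int,
# so inside the closure its value is unknown to inference and
# `ntuple(f, l)` cannot return a tuple of known length.
function pullback_runtime_length(X::Vararg{Float64,N}) where {N}
    l = length(X)
    return ȳ -> ntuple(n -> ȳ[n], l)
end

# Hypothetical variant: `Val(N)` moves the length into the captured object's type,
# so `ntuple(f, ::Val{N})` can be inferred as an `NTuple{N,Float64}`.
function pullback_static_length(X::Vararg{Float64,N}) where {N}
    len = Val(N)
    return ȳ -> ntuple(n -> ȳ[n], len)
end

pb_static = pullback_static_length(1.0, 2.0)
@inferred pb_static([3.0, 4.0])  # passes: inferred Tuple{Float64, Float64}

pb_runtime = pullback_runtime_length(1.0, 2.0)
# Same values at runtime, but inference only sees a tuple of unknown length.
@test_throws ErrorException @inferred pb_runtime([3.0, 4.0])
```

If this is the cause, capturing `Val(N)` instead of `l` in `vect_pullback` may be worth trying, though the indexing into the heterogeneous `projects` tuple would still need its own check.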
The `Number` specialisation of this function is interesting because it highlights that when `vect` constructs an array, the elements are sometimes converted away from their original types and sometimes stored as-is. For example, no conversion happens with subtypes of `AbstractArray`, which essentially pass through unchanged. I'm not sure how to deal with this in general, because
- there's no function that implements the conversion which we could explicitly hook into, and
- `ProjectTo` isn't defined for all types (only `Number`s and `AbstractArray`s AFAICT), so we can't rely on it in general.
Any thoughts on either of these issues?