How to implement vect #492

@willtebbutt

#491 implements vect for two special cases but not the general case, because I couldn't figure out how to implement it in a type-stable way for Numbers, or at all for arbitrary types.

My attempt at an implementation for Numbers was

using ChainRulesCore          # provides NoTangent and ProjectTo
import ChainRulesCore: rrule  # required to add a method to rrule outside ChainRules.jl

# Numbers need to be projected because they don't pass straight through the function.
# More generally, we would ideally project everything.
function rrule(::typeof(Base.vect), X::Vararg{Number, N}) where {N}
    l = length(X)
    projects = map(ProjectTo, X)
    function vect_pullback(ȳ)
        X̄ = ntuple(n -> projects[n](ȳ[n]), l)
        return (NoTangent(), X̄...)
    end
    return Base.vect(X...), vect_pullback
end

but it's type-unstable: the type of X̄ can't be inferred (Julia 1.6.2), and I'm not entirely sure why.
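A minimal sketch of one plausible cause, assuming the problem lies in the tuple length rather than in the projectors: l = length(X) is stored in the closure as a plain Int, so inside vect_pullback the compiler can't know how many elements ntuple(f, l) returns. The sketch uses only Base and the Test stdlib; make_pullback_runtime and make_pullback_static are hypothetical names, and float stands in for a projector, so ChainRulesCore isn't needed. It compares a closure that captures the length as an Int with one that captures it as a Val.

```julia
using Test  # stdlib; provides @inferred

# Hypothetical stand-ins, not ChainRulesCore API: `fs` plays the role of the
# tuple of projectors and `ys` the role of the cotangent ȳ.

# Mirrors the rrule above: the length is an Int stored in the closure, so the
# compiler can't tell how many elements ntuple will return.
function make_pullback_runtime(fs::Tuple)
    l = length(fs)
    return ys -> ntuple(n -> fs[n](ys[n]), l)
end

# Same computation, but the length is captured as a Val, i.e. it lives in the
# closure's type, so the result length is known at compile time.
function make_pullback_static(fs::Tuple)
    len = Val(length(fs))
    return ys -> ntuple(n -> fs[n](ys[n]), len)
end

fs = (float, float, float)  # one "projector" per input
ys = [1, 2, 3]              # cotangent for the three inputs

pb_runtime = make_pullback_runtime(fs)
pb_static = make_pullback_static(fs)

# Both return (1.0, 2.0, 3.0); compare the inferred return types.
println(Base.return_types(pb_runtime, (Vector{Int},)))
println(Base.return_types(pb_static, (Vector{Int},)))

@inferred pb_static(ys)  # passes: the return type is concrete
```

If this is the cause, the analogous change in the rrule above would be to build Val(N) before defining vect_pullback and pass that to ntuple instead of l. This is a suggestion based on the sketch, not a confirmed fix for the ChainRulesCore version.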

The Number specialisation of this function is interesting because it highlights that when vect constructs an array, the elements are converted from their original types some of the time but not others: Numbers are promoted to a common type (Base.vect(1, 2.0) stores the 1 as 1.0), whereas subtypes of AbstractArray essentially pass through unchanged. I'm not sure how to deal with this in general, because

  1. there isn't a dedicated function to hook into that implements the conversion
  2. ProjectTo isn't defined for all types (only Numbers and AbstractArrays AFAICT), so we can't rely on it in general.
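To make the conversion step concrete, here is a small sketch using only Base. It shows that Numbers are promoted while same-typed arrays are stored as the very same objects, and it reproduces the conversion explicitly with a hypothetical helper, vect_like, assuming vect's behaviour amounts to Base.promote_typeof followed by convert (Base.promote_typeof is an unexported internal).

```julia
# Numbers: Int and Float64 are promoted to Float64, so the stored 1 is no longer an Int.
v = Base.vect(1, 2.0)

# Arrays of a single type: nothing to promote, so vect stores the very same objects.
a, b = [1.0], [2.0, 3.0]
w = Base.vect(a, b)

# Hypothetical helper (not part of Base): performs vect's element conversion
# explicitly, assuming it is promote_typeof followed by convert.
function vect_like(X...)
    T = Base.promote_typeof(X...)       # common element type, e.g. Float64 for (1, 2.0)
    return T[convert(T, x) for x in X]  # convert each argument to T
end

println(typeof(v), " ", v)   # element type and values after promotion
println(w[1] === a)          # same object, not a copy
println(vect_like(1, 2.0) == v)
```

Even written out this way, the work is done by promote_typeof and convert, which are generic, so this doesn't give a vect-specific function to hook into either.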

Any thoughts on either of these issues?
