Skip to content

Propagate AxisArray copy / view down to taking copies / views of its axes as well #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ atvalue(x) = ExactValue(x)

const Values = AbstractArray{<:Value}

# Trait for describing whether to take views or copies
abstract type CopyStyle end
struct Copy <: CopyStyle end
struct View <: CopyStyle end

# For throwing a BoundsError with a Value index, we need to define the following
# (note that we could inherit them for free, were Value <: Number)
Base.iterate(x::Value, state = false) = state ? nothing : (x, true)
Expand All @@ -43,21 +48,22 @@ Base.IndexStyle(::Type{AxisArray{T,N,D,Ax}}) where {T,N,D,Ax} = IndexStyle(D)
Base.eachindex(A::AxisArray) = eachindex(A.data)

"""
reaxis(A::AxisArray, I...)
reaxis(A::AxisArray, copy::Val, I...)

This internal function determines the new set of axes that are constructed upon
indexing with I.
indexing with I. The `copy` argument determines whether views or copies of the
axes are taken.
"""
reaxis(A::AxisArray, I::Idx...) = _reaxis(make_axes_match(axes(A), I), I)
reaxis(A::AxisArray, copy::Type{<:CopyStyle}, I::Idx...) = _reaxis(make_axes_match(axes(A), I), copy, I)
# Linear indexing
reaxis(A::AxisArray{<:Any,1}, I::AbstractArray{Int}) = _new_axes(A.axes[1], I)
reaxis(A::AxisArray, I::AbstractArray{Int}) = default_axes(I)
reaxis(A::AxisArray{<:Any,1}, I::Real) = ()
reaxis(A::AxisArray, I::Real) = ()
reaxis(A::AxisArray{<:Any,1}, I::Colon) = _new_axes(A.axes[1], Base.axes(A, 1))
reaxis(A::AxisArray, I::Colon) = default_axes(Base.OneTo(length(A)))
reaxis(A::AxisArray{<:Any,1}, I::AbstractArray{Bool}) = _new_axes(A.axes[1], findall(I))
reaxis(A::AxisArray, I::AbstractArray{Bool}) = default_axes(findall(I))
reaxis(A::AxisArray{<:Any,1}, copy::Type{<:CopyStyle}, I::AbstractArray{Int}) = _new_axes(A.axes[1], copy, I)
reaxis(A::AxisArray, copy::Type{<:CopyStyle}, I::AbstractArray{Int}) = default_axes(I)
reaxis(A::AxisArray{<:Any,1}, copy::Type{<:CopyStyle}, I::Real) = ()
reaxis(A::AxisArray, copy::Type{<:CopyStyle}, I::Real) = ()
reaxis(A::AxisArray{<:Any,1}, copy::Type{<:CopyStyle}, I::Colon) = _new_axes(A.axes[1], copy, Base.axes(A, 1))
reaxis(A::AxisArray, copy::Type{<:CopyStyle}, I::Colon) = default_axes(Base.OneTo(length(A)))
reaxis(A::AxisArray{<:Any,1},copy::Type{<:CopyStyle}, I::AbstractArray{Bool}) = _new_axes(A.axes[1], copy, findall(I))
reaxis(A::AxisArray, copy::Type{<:CopyStyle}, I::AbstractArray{Bool}) = default_axes(findall(I))

# Ensure the number of axes matches the number of indexing dimensions
@inline function make_axes_match(axs, idxs)
Expand All @@ -66,28 +72,35 @@ reaxis(A::AxisArray, I::AbstractArray{Bool}) = default_axes(findall(I))
end

# Now we can reaxis without worrying about mismatched axes/indices
@inline _reaxis(axs::Tuple{}, idxs::Tuple{}) = ()
@inline _reaxis(axs::Tuple{}, copy::Type{<:CopyStyle}, idxs::Tuple{}) = ()
# Scalars are dropped
const ScalarIndex = Union{Real, AbstractArray{<:Any, 0}}
@inline _reaxis(axs::Tuple, idxs::Tuple{ScalarIndex, Vararg{Any}}) = _reaxis(tail(axs), tail(idxs))
@inline _reaxis(axs::Tuple, copy::Type{<:CopyStyle}, idxs::Tuple{ScalarIndex, Vararg{Any}}) = _reaxis(tail(axs), copy, tail(idxs))
# Colon passes straight through
@inline _reaxis(axs::Tuple, idxs::Tuple{Colon, Vararg{Any}}) = (axs[1], _reaxis(tail(axs), tail(idxs))...)
@inline _reaxis(axs::Tuple, copy::Type{<:CopyStyle}, idxs::Tuple{Colon, Vararg{Any}}) = (axs[1], _reaxis(tail(axs), copy, tail(idxs))...)
# But arrays can add or change dimensions and accompanying axis names
@inline _reaxis(axs::Tuple, idxs::Tuple{AbstractArray, Vararg{Any}}) =
(_new_axes(axs[1], idxs[1])..., _reaxis(tail(axs), tail(idxs))...)
@inline _reaxis(axs::Tuple, copy::Type{<:CopyStyle}, idxs::Tuple{AbstractArray, Vararg{Any}}) =
(_new_axes(axs[1], copy, idxs[1])..., _reaxis(tail(axs), copy, tail(idxs))...)

# Vectors simply create new axes with the same name; just subsetted by their value
@inline _new_axes(ax::Axis{name}, idx::AbstractVector) where {name} = (Axis{name}(ax.val[idx]),)
@inline _new_axes(ax::Axis{name}, ::Type{Copy}, idx::AbstractVector) where {name} = (Axis{name}(ax.val[idx]),)
@inline _new_axes(ax::Axis{name}, ::Type{View}, idx::AbstractVector) where {name} = (Axis{name}(view(ax.val, idx)),)

# Arrays create multiple axes with _N appended to the axis name containing their indices
@generated function _new_axes(ax::Axis{name}, idx::AbstractArray{<:Any,N}) where {name,N}
_new_axes(ax::Axis{name}, ::Type{Copy}, idx::AbstractArray{<:Any, N}) where {name, N} = _new_axes(ax, idx)
_new_axes(ax::Axis{name}, ::Type{View}, idx::AbstractArray{<:Any, N}) where {name, N} = _new_axes(ax, idx)
@generated function _new_axes(ax::Axis{name}, idx::AbstractArray{<:Any, N}) where {name, N}
newaxes = Expr(:tuple)
for i=1:N
push!(newaxes.args, :($(Axis{Symbol(name, "_", i)})(Base.axes(idx, $i))))
end
newaxes
end

# And indexing with an AxisArray joins the name and overrides the values
@generated function _new_axes(ax::Axis{name}, idx::AxisArray{<:Any, N}) where {name,N}
_new_axes(ax::Axis{name}, ::Type{Copy}, idx::AxisArray{<:Any, N}) where {name, N} = _new_axes(ax, idx)
_new_axes(ax::Axis{name}, ::Type{View}, idx::AxisArray{<:Any, N}) where {name, N} = _new_axes(ax, idx)
@generated function _new_axes(ax::Axis{name}, idx::AxisArray{<:Any, N}) where {name, N}
newaxes = Expr(:tuple)
idxnames = axisnames(idx)
for i=1:N
Expand All @@ -97,19 +110,19 @@ end
end

@propagate_inbounds function Base.getindex(A::AxisArray, idxs::Idx...)
AxisArray(A.data[idxs...], reaxis(A, idxs...))
AxisArray(A.data[idxs...], reaxis(A, Copy, idxs...))
end

# To resolve ambiguities, we need several definitions
using Base: AbstractCartesianIndex
@propagate_inbounds Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
@propagate_inbounds Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, View, idxs...))

# Setindex is so much simpler. Just assign it to the data:
@propagate_inbounds Base.setindex!(A::AxisArray, v, idxs::Idx...) = (A.data[idxs...] = v)

# Logical indexing
@propagate_inbounds function Base.getindex(A::AxisArray, idx::AbstractArray{Bool})
AxisArray(A.data[idx], reaxis(A, idx))
AxisArray(A.data[idx], reaxis(A, Copy, idx))
end
@propagate_inbounds Base.setindex!(A::AxisArray, v, idx::AbstractArray{Bool}) = (A.data[idx] = v)

Expand Down