diff --git a/src/indexing.jl b/src/indexing.jl index 0db70b6..6e2961b 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -20,6 +20,11 @@ atvalue(x) = ExactValue(x) const Values = AbstractArray{<:Value} +# Trait for describing whether to take views or copies +abstract type CopyStyle end +struct Copy <: CopyStyle end +struct View <: CopyStyle end + # For throwing a BoundsError with a Value index, we need to define the following # (note that we could inherit them for free, were Value <: Number) Base.iterate(x::Value, state = false) = state ? nothing : (x, true) @@ -43,21 +48,22 @@ Base.IndexStyle(::Type{AxisArray{T,N,D,Ax}}) where {T,N,D,Ax} = IndexStyle(D) Base.eachindex(A::AxisArray) = eachindex(A.data) """ - reaxis(A::AxisArray, I...) + reaxis(A::AxisArray, copy::Val, I...) This internal function determines the new set of axes that are constructed upon -indexing with I. +indexing with I. The `copy` argument determines whether views or copies of the +axes are taken. """ -reaxis(A::AxisArray, I::Idx...) = _reaxis(make_axes_match(axes(A), I), I) +reaxis(A::AxisArray, copy::Type{<:CopyStyle}, I::Idx...) = _reaxis(make_axes_match(axes(A), I), copy, I) # Linear indexing -reaxis(A::AxisArray{<:Any,1}, I::AbstractArray{Int}) = _new_axes(A.axes[1], I) -reaxis(A::AxisArray, I::AbstractArray{Int}) = default_axes(I) -reaxis(A::AxisArray{<:Any,1}, I::Real) = () -reaxis(A::AxisArray, I::Real) = () -reaxis(A::AxisArray{<:Any,1}, I::Colon) = _new_axes(A.axes[1], Base.axes(A, 1)) -reaxis(A::AxisArray, I::Colon) = default_axes(Base.OneTo(length(A))) -reaxis(A::AxisArray{<:Any,1}, I::AbstractArray{Bool}) = _new_axes(A.axes[1], findall(I)) -reaxis(A::AxisArray, I::AbstractArray{Bool}) = default_axes(findall(I)) +reaxis(A::AxisArray{<:Any,1}, copy::Type{<:CopyStyle}, I::AbstractArray{Int}) = _new_axes(A.axes[1], copy, I) +reaxis(A::AxisArray, copy::Type{<:CopyStyle}, I::AbstractArray{Int}) = default_axes(I) +reaxis(A::AxisArray{<:Any,1}, copy::Type{<:CopyStyle}, I::Real) = () +reaxis(A::AxisArray, copy::Type{<:CopyStyle}, I::Real) = () +reaxis(A::AxisArray{<:Any,1}, copy::Type{<:CopyStyle}, I::Colon) = _new_axes(A.axes[1], copy, Base.axes(A, 1)) +reaxis(A::AxisArray, copy::Type{<:CopyStyle}, I::Colon) = default_axes(Base.OneTo(length(A))) +reaxis(A::AxisArray{<:Any,1},copy::Type{<:CopyStyle}, I::AbstractArray{Bool}) = _new_axes(A.axes[1], copy, findall(I)) +reaxis(A::AxisArray, copy::Type{<:CopyStyle}, I::AbstractArray{Bool}) = default_axes(findall(I)) # Ensure the number of axes matches the number of indexing dimensions @inline function make_axes_match(axs, idxs) @@ -66,28 +72,35 @@ reaxis(A::AxisArray, I::AbstractArray{Bool}) = default_axes(findall(I)) end # Now we can reaxis without worrying about mismatched axes/indices -@inline _reaxis(axs::Tuple{}, idxs::Tuple{}) = () +@inline _reaxis(axs::Tuple{}, copy::Type{<:CopyStyle}, idxs::Tuple{}) = () # Scalars are dropped const ScalarIndex = Union{Real, AbstractArray{<:Any, 0}} -@inline _reaxis(axs::Tuple, idxs::Tuple{ScalarIndex, Vararg{Any}}) = _reaxis(tail(axs), tail(idxs)) +@inline _reaxis(axs::Tuple, copy::Type{<:CopyStyle}, idxs::Tuple{ScalarIndex, Vararg{Any}}) = _reaxis(tail(axs), copy, tail(idxs)) # Colon passes straight through -@inline _reaxis(axs::Tuple, idxs::Tuple{Colon, Vararg{Any}}) = (axs[1], _reaxis(tail(axs), tail(idxs))...) +@inline _reaxis(axs::Tuple, copy::Type{<:CopyStyle}, idxs::Tuple{Colon, Vararg{Any}}) = (axs[1], _reaxis(tail(axs), copy, tail(idxs))...) # But arrays can add or change dimensions and accompanying axis names -@inline _reaxis(axs::Tuple, idxs::Tuple{AbstractArray, Vararg{Any}}) = - (_new_axes(axs[1], idxs[1])..., _reaxis(tail(axs), tail(idxs))...) +@inline _reaxis(axs::Tuple, copy::Type{<:CopyStyle}, idxs::Tuple{AbstractArray, Vararg{Any}}) = + (_new_axes(axs[1], copy, idxs[1])..., _reaxis(tail(axs), copy, tail(idxs))...) # Vectors simply create new axes with the same name; just subsetted by their value -@inline _new_axes(ax::Axis{name}, idx::AbstractVector) where {name} = (Axis{name}(ax.val[idx]),) +@inline _new_axes(ax::Axis{name}, ::Type{Copy}, idx::AbstractVector) where {name} = (Axis{name}(ax.val[idx]),) +@inline _new_axes(ax::Axis{name}, ::Type{View}, idx::AbstractVector) where {name} = (Axis{name}(view(ax.val, idx)),) + # Arrays create multiple axes with _N appended to the axis name containing their indices -@generated function _new_axes(ax::Axis{name}, idx::AbstractArray{<:Any,N}) where {name,N} +_new_axes(ax::Axis{name}, ::Type{Copy}, idx::AbstractArray{<:Any, N}) where {name, N} = _new_axes(ax, idx) +_new_axes(ax::Axis{name}, ::Type{View}, idx::AbstractArray{<:Any, N}) where {name, N} = _new_axes(ax, idx) +@generated function _new_axes(ax::Axis{name}, idx::AbstractArray{<:Any, N}) where {name, N} newaxes = Expr(:tuple) for i=1:N push!(newaxes.args, :($(Axis{Symbol(name, "_", i)})(Base.axes(idx, $i)))) end newaxes end + # And indexing with an AxisArray joins the name and overrides the values -@generated function _new_axes(ax::Axis{name}, idx::AxisArray{<:Any, N}) where {name,N} +_new_axes(ax::Axis{name}, ::Type{Copy}, idx::AxisArray{<:Any, N}) where {name, N} = _new_axes(ax, idx) +_new_axes(ax::Axis{name}, ::Type{View}, idx::AxisArray{<:Any, N}) where {name, N} = _new_axes(ax, idx) +@generated function _new_axes(ax::Axis{name}, idx::AxisArray{<:Any, N}) where {name, N} newaxes = Expr(:tuple) idxnames = axisnames(idx) for i=1:N @@ -97,19 +110,19 @@ end end @propagate_inbounds function Base.getindex(A::AxisArray, idxs::Idx...) - AxisArray(A.data[idxs...], reaxis(A, idxs...)) + AxisArray(A.data[idxs...], reaxis(A, Copy, idxs...)) end # To resolve ambiguities, we need several definitions using Base: AbstractCartesianIndex -@propagate_inbounds Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, idxs...)) +@propagate_inbounds Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, View, idxs...)) # Setindex is so much simpler. Just assign it to the data: @propagate_inbounds Base.setindex!(A::AxisArray, v, idxs::Idx...) = (A.data[idxs...] = v) # Logical indexing @propagate_inbounds function Base.getindex(A::AxisArray, idx::AbstractArray{Bool}) - AxisArray(A.data[idx], reaxis(A, idx)) + AxisArray(A.data[idx], reaxis(A, Copy, idx)) end @propagate_inbounds Base.setindex!(A::AxisArray, v, idx::AbstractArray{Bool}) = (A.data[idx] = v)