Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Breaking: Implement AbstractArray interface for AbstractDimStack #871

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/array/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ metadata(A::AbstractDimArray) = A.metadata
layerdims(A::AbstractDimArray) = basedims(A)

@inline rebuildsliced(A::AbstractBasicDimArray, args...) = rebuildsliced(getindex, A, args...)
@inline function rebuildsliced(f::Function, A::AbstractBasicDimArray, data::AbstractArray, I::Tuple, name=name(A))
@inline function rebuildsliced(f::Function, A::AbstractBasicDimArray, data, I, name=name(A))
I1 = to_indices(A, I)
rebuild(A, data, slicedims(f, A, I1)..., name)
end
Expand Down
2 changes: 1 addition & 1 deletion src/array/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,4 @@ Base.@assume_effects :foldable @inline _simplify_dim_indices() = ()
@propagate_inbounds Base.maybeview(A::AbstractDimArray, args...; kw...) =
view(A, args...; kw...)
@propagate_inbounds Base.maybeview(A::AbstractDimArray, args::Vararg{Union{Number,Base.AbstractCartesianIndex}}; kw...) =
view(A, args...; kw...)
view(A, args...; kw...)
78 changes: 14 additions & 64 deletions src/stack/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,6 @@ Base.@assume_effects :effect_free @propagate_inbounds function Base.getindex(s::
# Use dimensional indexing
Base.getindex(s, rebuild(only(dims(s)), i))
end
Base.@assume_effects :effect_free @propagate_inbounds function Base.getindex(
s::AbstractDimStack{<:Any,T}, i::Union{AbstractArray,Colon}
) where {T}
ls = _maybe_extented_layers(s)
inds = to_indices(first(ls), (i,))[1]
out = similar(inds, T)
for (i, ind) in enumerate(inds)
out[i] = T(map(v -> v[ind], ls))
end
return out
end
@propagate_inbounds function Base.getindex(s::AbstractDimStack{<:Any,<:Any,N}, i::Integer) where N
if N == 1 && hassamedims(s)
# This is a few ns faster when possible
Expand All @@ -52,37 +41,14 @@ end
@propagate_inbounds function Base.view(s::AbstractVectorDimStack, i::Union{AbstractVector{<:Integer},Colon,Integer})
Base.view(s, DimIndices(s)[i])
end
@propagate_inbounds function Base.view(s::AbstractDimStack, i::Union{AbstractArray{<:Integer},Colon,Integer})
# Pretend the stack is an AbstractArray so `SubArray` accepts it.
Base.view(OpaqueArray(s), i)
end

for f in (:getindex, :view, :dotview)
_dim_f = Symbol(:_dim_, f)
@eval begin
@propagate_inbounds function Base.$f(s::AbstractDimStack, i)
Base.$f(s, to_indices(CartesianIndices(s), Lookups._construct_types(i))...)
end
@propagate_inbounds function Base.$f(s::AbstractDimStack, i::Union{SelectorOrInterval,Extents.Extent})
Base.$f(s, dims2indices(s, i)...)
end
@propagate_inbounds function Base.$f(s::AbstractVectorDimStack, i::Union{CartesianIndices,CartesianIndex})
I = to_indices(CartesianIndices(s), (i,))
Base.$f(s, I...)
end
@propagate_inbounds function Base.$f(s::AbstractDimStack, i::Union{CartesianIndices,CartesianIndex})
I = to_indices(CartesianIndices(s), (i,))
Base.$f(s, I...)
end
@propagate_inbounds function Base.$f(s::AbstractDimStack, i1, i2, Is...)
I = to_indices(CartesianIndices(s), Lookups._construct_types(i1, i2, Is...))
# Check we have the right number of dimensions
if length(dims(s)) > length(I)
@propagate_inbounds function $_dim_f(
A::AbstractDimStack, a1::Union{Dimension,DimensionIndsArrays}, args::Union{Dimension,DimensionIndsArrays}...
)
return merge_and_index(Base.$f, A, (a1, args...))
end
throw(BoundsError(dims(s), I))
elseif length(dims(s)) < length(I)
# Allow trailing ones
Expand All @@ -95,12 +61,10 @@ for f in (:getindex, :view, :dotview)
# Convert to Dimension wrappers to handle mixed size layers
Base.$f(s, map(rebuild, dims(s), I)...)
end
@propagate_inbounds function Base.$f(
s::AbstractDimStack, D::DimensionalIndices...; kw...
)
$_dim_f(s, _simplify_dim_indices(D..., kw2dims(values(kw))...)...)
end
# Ambiguities
@propagate_inbounds Base.$f(A::AbstractDimStack, d1::DimensionalIndices, d2::DimensionalIndices, D::DimensionalIndices...; kw...) =
$_dim_f(A, _simplify_dim_indices(d1, d2, D..., kw2dims(values(kw))...)...)

@propagate_inbounds function Base.$f(s::DimensionalData.AbstractVectorDimStack,
i::Union{AbstractVector{<:DimensionalData.Dimensions.Dimension},
AbstractVector{<:Tuple{DimensionalData.Dimensions.Dimension, Vararg{DimensionalData.Dimensions.Dimension}}},
Expand All @@ -109,12 +73,6 @@ for f in (:getindex, :view, :dotview)
$_dim_f(s, _simplify_dim_indices(i)...)
end


@propagate_inbounds function $_dim_f(
A::AbstractDimStack, a1::Union{Dimension,DimensionIndsArrays}, args::Union{Dimension,DimensionIndsArrays}...
)
return merge_and_index(Base.$f, A, (a1, args...))
end
# Handle zero-argument getindex, this will error unless all layers are zero dimensional
@propagate_inbounds function $_dim_f(s::AbstractDimStack)
map(Base.$f, data(s))
Expand Down Expand Up @@ -151,25 +109,15 @@ end
#### setindex ####
@propagate_inbounds Base.setindex!(s::AbstractDimStack, xs, I...; kw...) =
map((A, x) -> setindex!(A, x, I...; kw...), layers(s), xs)
@propagate_inbounds Base.setindex!(s::AbstractDimStack, xs::NamedTuple, i::Integer; kw...) =
hassamedims(s) ? _map_setindex!(s, xs, i; kw...) : _setindex_mixed!(s, xs, i; kw...)
@propagate_inbounds Base.setindex!(s::AbstractDimStack, xs::NamedTuple, i::Colon; kw...) =
hassamedims(s) ? _map_setindex!(s, xs, i; kw...) : _setindex_mixed!(s, xs, i; kw...)
@propagate_inbounds Base.setindex!(s::AbstractDimStack, xs::NamedTuple, i::AbstractArray; kw...) =
hassamedims(s) ? _map_setindex!(s, xs, i; kw...) : _setindex_mixed!(s, xs, i; kw...)

@propagate_inbounds function Base.setindex!(
s::AbstractDimStack, xs::NamedTuple, I...; kw...
)
map((A, x) -> setindex!(A, x, I...; kw...), layers(s), xs)
end
@propagate_inbounds Base.setindex!(s::AbstractDimStack, xs::NamedTuple, i...; kw...) =
hassamedims(s) ? _map_setindex!(s, xs, i...; kw...) : _setindex_mixed!(s, xs, i...; kw...)

_map_setindex!(s, xs, i; kw...) = map((A, x) -> setindex!(A, x, i...; kw...), layers(s), xs)

_setindex_mixed!(s::AbstractDimStack, x, i::AbstractArray) =
map(A -> setindex!(A, x, DimIndices(dims(s))[i]), layers(s))
_setindex_mixed!(s::AbstractDimStack, i::Integer) =
map(A -> setindex!(A, x, DimIndices(dims(s))[i]), layers(s))
_setindex_mixed!(s::AbstractDimStack, x, i...; kw...) =
map(layers(s), x) do A, x
A[dims(DimIndices(s)[i...], dims(A))] = x
end
function _setindex_mixed!(s::AbstractDimStack, x, i::Colon)
map(DimIndices(dims(s))) do D
map(A -> setindex!(A, D), x, layers(s))
Expand All @@ -178,9 +126,6 @@ end

@noinline _keysmismatch(K1, K2) = throw(ArgumentError("NamedTuple keys $K2 do not mach stack keys $K1"))

# For @views macro to work with keywords
Base.maybeview(A::AbstractDimStack, args...; kw...) = view(A, args...; kw...)

function merge_and_index(f, s::AbstractDimStack, ds)
ds, inds_arrays = _separate_dims_arrays(_simplify_dim_indices(ds...)...)
# No arrays here, so abort (dispatch is tricky...)
Expand All @@ -198,3 +143,8 @@ function merge_and_index(f, s::AbstractDimStack, ds)
end
return rebuild_from_arrays(s, newlayers)
end

@propagate_inbounds Base.maybeview(A::AbstractDimStack, args...; kw...) =
view(A, args...; kw...)
@propagate_inbounds Base.maybeview(A::AbstractDimStack, args::Vararg{Union{Number,Base.AbstractCartesianIndex}}; kw...) =
view(A, args...; kw...)
15 changes: 5 additions & 10 deletions src/stack/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ end

# Methods with an argument that return a DimStack
for fname in (:rotl90, :rotr90, :rot180)
@eval (Base.$fname)(s::AbstractDimStack, args...) =
maplayers(A -> (Base.$fname)(A, args...), s)
@eval (Base.$fname)(s::AbstractDimStack) =
maplayers(A -> (Base.$fname)(A), s)
@eval (Base.$fname)(s::AbstractDimStack, k::Integer) =
maplayers(A -> (Base.$fname)(A, k), s)
end
for fname in (:PermutedDimsArray, :permutedims)
@eval function (Base.$fname)(s::AbstractDimStack, perm)
Expand Down Expand Up @@ -187,11 +189,4 @@ for fname in (:one, :oneunit, :zero, :copy)
end
end

Base.reverse(s::AbstractDimStack; dims=:) = maplayers(A -> reverse(A; dims=dims), s)

# Random
Random.Sampler(RNG::Type{<:AbstractRNG}, st::AbstractDimStack, n::Random.Repetition) =
Random.SamplerSimple(st, Random.Sampler(RNG, DimIndices(st), n))

Random.rand(rng::AbstractRNG, sp::Random.SamplerSimple{<:AbstractDimStack,<:Random.Sampler}) =
@inbounds return sp[][rand(rng, sp.data)...]
Base.reverse(s::AbstractDimStack; dims=:) = maplayers(A -> reverse(A; dims=dims), s)
2 changes: 1 addition & 1 deletion src/stack/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function show_main(io, mime, stack::AbstractDimStack)
_, blockwidth = print_metadata_block(io, mime, metadata(stack); displaywidth, blockwidth=min(displaywidth-2, blockwidth))
end

function show_after(io, mime, stack::AbstractDimStack)
function show_after(io::IO, mime, stack::AbstractDimStack)
blockwidth = get(io, :blockwidth, 0)
print_block_close(io, blockwidth)
end
Expand Down
80 changes: 40 additions & 40 deletions src/stack/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,19 @@ To extend `AbstractDimStack`, implement argument and keyword version of

The constructor of an `AbstractDimStack` must accept a `NamedTuple`.
"""
abstract type AbstractDimStack{K,T,N,L} end
abstract type AbstractDimStack{K,T<:NamedTuple,N,L,D} <: AbstractBasicDimArray{T,N,D} end
const AbstractVectorDimStack = AbstractDimStack{K,T,1} where {K,T}
const AbstractMatrixDimStack = AbstractDimStack{K,T,2} where {K,T}

(::Type{T})(st::AbstractDimStack) where T<:AbstractDimArray =
(::Type{T})(st::AbstractDimStack) where {T<:AbstractDimArray} =
T([st[D] for D in DimIndices(st)]; dims=dims(st), metadata=metadata(st))
DimArray(st::AbstractDimStack) =
DimArray(collect(st), dims(st); metadata=metadata(st))
DimMatrix(st::AbstractMatrixDimStack) =
DimArray(collect(st), dims(st); metadata=metadata(st))
DimVector(st::AbstractVectorDimStack) =
DimArray(collect(st), dims(st); metadata=metadata(st))


data(s::AbstractDimStack) = getfield(s, :data)
dims(s::AbstractDimStack) = getfield(s, :dims)
Expand Down Expand Up @@ -145,31 +152,12 @@ Base.:(==)(s1::AbstractDimStack, s2::AbstractDimStack) =
data(s1) == data(s2) && dims(s1) == dims(s2) && layerdims(s1) == layerdims(s2)
Base.read(s::AbstractDimStack) = maplayers(read, s)

# Array-like
Base.size(s::AbstractDimStack) = map(length, dims(s))
Base.size(s::AbstractDimStack, dims::DimOrDimType) = size(s, dimnum(s, dims))
Base.size(s::AbstractDimStack, dims::Integer) = size(s)[dims]
Base.length(s::AbstractDimStack) = prod(size(s))
Base.axes(s::AbstractDimStack) = map(first ∘ axes, dims(s))
Base.axes(s::AbstractDimStack, dims::DimOrDimType) = axes(s, dimnum(s, dims))
Base.axes(s::AbstractDimStack, dims::Integer) = axes(s)[dims]
Base.similar(s::AbstractDimStack, args...) = maplayers(A -> similar(A, args...), s)
Base.eltype(::AbstractDimStack{<:Any,T}) where T = T
Base.ndims(::AbstractDimStack{<:Any,<:Any,N}) where N = N
Base.CartesianIndices(s::AbstractDimStack) = CartesianIndices(dims(s))
Base.LinearIndices(s::AbstractDimStack) =
LinearIndices(CartesianIndices(map(l -> axes(l, 1), lookup(s))))
Base.IteratorSize(::AbstractDimStack{<:Any,<:Any,N}) where N = Base.HasShape{N}()
function Base.eachindex(s::AbstractDimStack)
li = LinearIndices(s)
first(li):last(li)
end
Base.firstindex(s::AbstractDimStack) = first(LinearIndices(s))
Base.lastindex(s::AbstractDimStack) = last(LinearIndices(s))
Base.first(s::AbstractDimStack) = s[firstindex((s))]
Base.last(s::AbstractDimStack) = s[lastindex(LinearIndices(s))]
Base.copy(s::AbstractDimStack) = modify(copy, s)
# all of methods.jl is also Array-like...
# Array interface methods
Base.IndexStyle(A::AbstractDimStack) = Base.IndexStyle(first(layers(A)))

Base.similar(s::AbstractDimStack; kw...) = maplayers(A -> similar(A; kw...), s)
Base.similar(s::AbstractDimStack, ::Type{T}; kw...) where T = maplayers(A -> similar(A, T; kw...), s)
Base.similar(s::AbstractDimStack, ::Type{T}, dt::DimTuple; kw...) where T = maplayers(A -> similar(A, T, dt; kw...), s)

# NamedTuple-like
@assume_effects :foldable Base.getproperty(s::AbstractDimStack, x::Symbol) = s[x]
Expand All @@ -178,24 +166,17 @@ Base.values(s::AbstractDimStack) = _values_gen(s)
@generated function _values_gen(s::AbstractDimStack{K}) where K
Expr(:tuple, map(k -> :(s[$(QuoteNode(k))]), K)...)
end
Base.checkbounds(s::AbstractDimStack, I...) = checkbounds(CartesianIndices(s), I...)
Base.checkbounds(T::Type, s::AbstractDimStack, I...) = checkbounds(T, CartesianIndices(s), I...)

@inline Base.keys(s::AbstractDimStack{K}) where K = K
@inline Base.propertynames(s::AbstractDimStack{K}) where K = K
@inline Base.setindex(s::AbstractDimStack, val::AbstractBasicDimArray, name::Symbol) =
rebuild_from_arrays(s, Base.setindex(layers(s), val, name))
Base.NamedTuple(s::AbstractDimStack) = NamedTuple(layers(s))
Base.collect(st::AbstractDimStack) = parent([st[D] for D in DimIndices(st)])
Base.Array(st::AbstractDimStack) = collect(st)
Base.vec(st::AbstractDimStack) = vec(collect(st))
Base.get(st::AbstractDimStack, k::Symbol, default) =
haskey(st, k) ? st[k] : default
Base.get(f::Base.Callable, st::AbstractDimStack, k::Symbol) =
haskey(st, k) ? st[k] : f()
@propagate_inbounds Base.iterate(st::AbstractDimStack) = iterate(st, 1)
@propagate_inbounds Base.iterate(st::AbstractDimStack, i) =
i > length(st) ? nothing : (st[DimIndices(st)[i]], i + 1)

# `merge` for AbstractDimStack and NamedTuple.
# One of the first three arguments must be an AbstractDimStack for dispatch to work.
Expand Down Expand Up @@ -231,14 +212,30 @@ Base.map(f, ::Union{AbstractDimStack,NamedTuple}, xs::Union{AbstractDimStack,Nam
maplayers(f, s::AbstractDimStack) =
_maybestack(s, unrolled_map(f, values(s)))
function maplayers(
f, x1::Union{AbstractDimStack,NamedTuple}, xs::Union{AbstractDimStack,NamedTuple}...
f, x1::Union{AbstractBasicDimArray,NamedTuple}, xs::Union{AbstractBasicDimArray,NamedTuple}...
)
stacks = (x1, xs...)
_check_same_names(stacks...)
vals = map(f, map(values, stacks)...)
return _maybestack(_firststack(stacks...), vals)
xs = (x1, xs...)
firststack = _firststack(xs...)
if isnothing(firststack)
if all(map(x -> x isa AbstractArray, xs))
# all arguments are arrays, but none are stacks, just apply the function
f(xs...)
else
# Arguments are some mix of arrays and NamedTuple
throw(ArgumentError("Cannot apply maplayers to NamedTuple and AbstractBasicDimArray"))
end
else
# There is at least one stack, we apply layer-wise
_check_same_names(xs...)
l = length(values(firststack))
vals = map(f, map(s -> _values_or_tuple(s, l), xs)...)
return _maybestack(firststack, vals)
end
end

_values_or_tuple(x::Union{AbstractDimStack, NamedTuple}, l) = values(x)
_values_or_tuple(x::Union{AbstractBasicDimArray}, l) = Tuple(Iterators.repeated(x, l))

# Other interfaces

Extents.extent(A::AbstractDimStack, args...) = Extents.extent(dims(A), args...)
Expand Down Expand Up @@ -287,6 +284,9 @@ _check_same_names(::Union{AbstractDimStack{names},NamedTuple{names}},
::Union{AbstractDimStack{names},NamedTuple{names}}...) where {names} = nothing
_check_same_names(::Union{AbstractDimStack,NamedTuple}, ::Union{AbstractDimStack,NamedTuple}...) =
throw(ArgumentError("Named tuple names do not match."))
_check_same_names(xs::Union{AbstractDimStack,NamedTuple,AbstractBasicDimArray}...) =
_check_same_names((x for x in xs if x isa Union{AbstractDimStack,NamedTuple})...)


_firststack(s::AbstractDimStack, args...) = s
_firststack(arg1, args...) = _firststack(args...)
Expand Down Expand Up @@ -390,7 +390,7 @@ julia> s[X(At(:a))] isa DimStack
true
```
"""
struct DimStack{K,T,N,L,D<:Tuple,R<:Tuple,LD,M,LM} <: AbstractDimStack{K,T,N,L}
struct DimStack{K,T<:NamedTuple,N,L,D<:Tuple,R<:Tuple,LD,M,LM} <: AbstractDimStack{K,T,N,L,D}
data::L
dims::D
refdims::R
Expand Down
42 changes: 21 additions & 21 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ modify(CuArray, A)
This also works for all the data layers in a `DimStack`.
"""
function modify end
modify(f, s::AbstractDimStack) = maplayers(a -> modify(f, a), s)
# Stack optimisation to avoid compilation to build all the `AbstractDimArray`
# layers, and instead just modify the parent data directly.
modify(f, s::AbstractDimStack{<:Any,<:Any,<:NamedTuple}) =
Expand Down Expand Up @@ -121,19 +120,16 @@ all passed in arrays in the order in which they are found.
This is like broadcasting over every slice of `A` if it is
sliced by the dimensions of `B`.
"""
function broadcast_dims(f, As::AbstractBasicDimArray...)
dims = combinedims(As...)
T = Base.Broadcast.combine_eltypes(f, As)
broadcast_dims!(f, similar(first(As), T, dims), As...)
end

function broadcast_dims(f, As::Union{AbstractDimStack,AbstractBasicDimArray}...)
st = _firststack(As...)
nts = _as_extended_nts(NamedTuple(st), As...)
layers = map(keys(st)) do name
broadcast_dims(f, map(nt -> nt[name], nts)...)
function broadcast_dims(f, As::AbstractBasicDimArray...; bylayer=false)
if bylayer
maplayers((As...) -> broadcast_dims(f, As...), As...)
else
A1 = first(As)
_A1 = A1 isa AbstractDimStack ? first(layers(A1)) : A1
dims = combinedims(As...)
T = Base.Broadcast.combine_eltypes(f, As)
broadcast_dims!(f, similar(_A1, T, dims), As...)
end
rebuild_from_arrays(st, layers)
end

"""
Expand All @@ -150,15 +146,19 @@ which they are found.
- `dest`: `AbstractDimArray` to update.
- `sources`: `AbstractDimArrays` to broadcast over with `f`.
"""
function broadcast_dims!(f, dest::AbstractDimArray{<:Any,N}, As::AbstractBasicDimArray...) where {N}
As = map(As) do A
isempty(otherdims(A, dims(dest))) || throw(DimensionMismatch("Cannot broadcast over dimensions not in the dest array"))
# comparedims(dest, dims(A, dims(dest)))
# Lazily permute B dims to match the order in A, if required
_maybe_lazy_permute(A, dims(dest))
function broadcast_dims!(f, dest::Union{AbstractDimArray{<:Any,N}, AbstractDimStack{<:Any,<:Any,N}}, As::AbstractBasicDimArray...; bylayer=false) where {N}
if bylayer
maplayers((dest,As) -> broadcast_dims!(f, As, dest), As, dest)
else
As = map(As) do A
isempty(otherdims(A, dims(dest))) || throw(DimensionMismatch("Cannot broadcast over dimensions not in the dest array"))
# comparedims(dest, dims(A, dims(dest)))
# Lazily permute B dims to match the order in A, if required
_maybe_lazy_permute(A, dims(dest))
end
od = map(A -> otherdims(dest, dims(A)), As)
return _broadcast_dims_inner!(f, dest, As, od)
end
od = map(A -> otherdims(dest, dims(A)), As)
return _broadcast_dims_inner!(f, dest, As, od)
end

# Function barrier
Expand Down
Loading
Loading