Skip to content

Commit

Permalink
feat: metadata_getindex for AbstractMillNodes
Browse files Browse the repository at this point in the history
  • Loading branch information
simonmandlik committed Nov 26, 2024
1 parent 29e4578 commit 04f586b
Show file tree
Hide file tree
Showing 8 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/src/manual/custom.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Base.show(io::IO, n::PathNode) = print(io, "PathNode ($(numobs(n)) obs)")
Base.ndims(::PathNode) = Colon()
numobs(n::PathNode) = length(n.data)
catobs(ns::PathNode) = PathNode(vcat(data.(ns)...), catobs(metadata.(as)...))
Base.getindex(n::PathNode, i::VecOrRange{<:Int}) = PathNode(n.data[i], Mill.metadata_getindex(n.metadata, i))
Base.getindex(n::PathNode, i::VecOrRange{<:Int}) = PathNode(n.data[i], Mill.metadata_getindex(n, i))
NodeType(::Type{<:PathNode}) = LeafNode()
nothing # hide
```
Expand Down
2 changes: 1 addition & 1 deletion src/Mill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ include("datanodes/datanode.jl")
export AbstractMillNode, AbstractProductNode, AbstractBagNode
export ArrayNode, BagNode, WeightedBagNode, ProductNode, LazyNode
export numobs, getobs, catobs, removeinstances, dropmeta
@compat public data, metadata, mapdata, unpack2mill
@compat public data, metadata, metadata_getindex, mapdata, unpack2mill

include("special_arrays/special_arrays.jl")
export MaybeHotVector, MaybeHotMatrix, maybehot, maybehotbatch, maybecold
Expand Down
2 changes: 1 addition & 1 deletion src/datanodes/arraynode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function Base.reduce(::typeof(hcat), as::Vector{<:ArrayNode})
end

function Base.getindex(x::ArrayNode, i::VecOrRange{<:Integer})
ArrayNode(x.data[:, i], metadata_getindex(x.metadata, i))
ArrayNode(x.data[:, i], metadata_getindex(x, i))
end

_arraynode(m) = ArrayNode(m)
Expand Down
2 changes: 1 addition & 1 deletion src/datanodes/bagnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ dropmeta(x::BagNode) = BagNode(dropmeta(x.data), x.bags)
function Base.getindex(x::BagNode, i::VecOrRange{<:Integer})
nb, ii = remapbags(x.bags, i)
emptyismissing() && isempty(ii) && return(BagNode(missing, nb, nothing))
BagNode(x.data[ii], nb, metadata_getindex(x.metadata, i))
BagNode(x.data[ii], nb, metadata_getindex(x, i))
end

function Base.reduce(::typeof(catobs), as::Vector{<:BagNode})
Expand Down
1 change: 1 addition & 0 deletions src/datanodes/datanode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ julia> Mill.metadata_getindex([1 2 3; 4 5 6], [1, 3])
See also: [`Mill.metadata`](@ref), [`Mill.dropmeta`](@ref).
"""
metadata_getindex(x, _) = x
metadata_getindex(x::AbstractMillNode, i) = metadata_getindex(metadata(x), i)
@generated function metadata_getindex(x::AbstractArray{T, U}, i) where {T, U}
U == 1 ? :(getindex(x, i)) : :(getindex(x, :, i, $(Colon() for _ in 3:U)...))
end
Expand Down
2 changes: 1 addition & 1 deletion src/datanodes/lazynode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ function Base.reduce(::typeof(catobs), as::Vector{<: LazyNode{N}}) where N
end

function Base.getindex(x::LazyNode{N}, i::VecOrRange{<:Int}) where N
LazyNode{N}(x.data[i], metadata_getindex(x.metadata, i))
LazyNode{N}(x.data[i], metadata_getindex(x, i))
end

Base.hash(n::LazyNode, h::UInt) = hash((n.data), h)
Expand Down
2 changes: 1 addition & 1 deletion src/datanodes/productnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function Base.reduce(::typeof(catobs), as::Vector{<:ProductNode})
end

function Base.getindex(x::ProductNode, i::VecOrRange{<:Integer})
ProductNode(map(v -> v[i], x.data), metadata_getindex(x.metadata, i))
ProductNode(map(v -> v[i], x.data), metadata_getindex(x, i))
end

Base.hash(n::ProductNode, h::UInt) = hash((n.data, n.metadata), h)
Expand Down
2 changes: 1 addition & 1 deletion src/datanodes/weighted_bagnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ dropmeta(x::WeightedBagNode) = WeightedBagNode(dropmeta(x.data), x.bags, x.weigh
function Base.getindex(x::WeightedBagNode{T, B, W}, i::VecOrRange{<:Integer}) where {T, B, W}
nb, ii = remapbags(x.bags, i)
emptyismissing() && isempty(ii) && return (WeightedBagNode(missing, nb, W[], nothing))
WeightedBagNode(x.data[ii], nb, x.weights[ii], metadata_getindex(x.metadata, i))
WeightedBagNode(x.data[ii], nb, x.weights[ii], metadata_getindex(x, i))
end

function Base.reduce(::typeof(catobs), as::Vector{<:WeightedBagNode})
Expand Down

0 comments on commit 04f586b

Please sign in to comment.