Skip to content

[NDTensors] Introduce LabelledNumbers and GradedAxesNext #1351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NDTensors/src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ for lib in [
:BroadcastMapConversion,
:RankFactorization,
:Sectors,
:LabelledNumbers,
:GradedAxesNext,
:GradedAxes,
:TensorAlgebra,
:SparseArrayInterface,
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/lib/GradedAxesNext/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
3 changes: 3 additions & 0 deletions NDTensors/src/lib/GradedAxesNext/src/GradedAxesNext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module GradedAxesNext
include("gradedunitrange.jl")
end
245 changes: 245 additions & 0 deletions NDTensors/src/lib/GradedAxesNext/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
using BlockArrays:
BlockArrays,
Block,
BlockedUnitRange,
BlockRange,
BlockVector,
blockedrange,
BlockIndexRange,
blockfirsts,
blocklasts,
blocklengths,
findblock,
findblockindex,
mortar
using ..LabelledNumbers: LabelledNumbers, LabelledInteger, label, labelled, unlabel

# Custom `BlockedUnitRange` constructor that takes a unit range
# and a set of block lengths, similar to `BlockArray(::AbstractArray, blocklengths...)`.
function blockedunitrange(a::AbstractUnitRange, blocklengths)
blocklengths_shifted = copy(blocklengths)
blocklengths_shifted[1] += (first(a) - 1)
blocklasts = cumsum(blocklengths_shifted)
return BlockArrays._BlockedUnitRange(first(a), blocklasts)
end

# Circumvents issue in `findblock` that assumes the `BlockedUnitRange`
# starts at 1.
# TODO: Raise an issue with `BlockArrays`.
function blockedunitrange_findblock(a::BlockedUnitRange, index::Integer)
@boundscheck index in 1:length(a) || throw(BoundsError(a, index))
return @inbounds findblock(a, index + first(a) - 1)
end

# Circumvents issue in `findblockindex` that assumes the `BlockedUnitRange`
# starts at 1.
# TODO: Raise an issue with `BlockArrays`.
function blockedunitrange_findblockindex(a::BlockedUnitRange, index::Integer)
@boundscheck index in 1:length(a) || throw(BoundsError())
return @inbounds findblockindex(a, index + first(a) - 1)
end

const GradedUnitRange{BlockLasts<:Vector{<:LabelledInteger}} = BlockedUnitRange{BlockLasts}

function gradedrange(lblocklengths::AbstractVector{<:LabelledInteger})
brange = blockedrange(unlabel.(lblocklengths))
lblocklasts = labelled.(blocklasts(brange), label.(lblocklengths))
# TODO: `first` is forced to be `Int` in `BlockArrays.BlockedUnitRange`,
# so this doesn't do anything right now. Make a PR to generalize it.
firstlength = first(lblocklengths)
lfirst = oneunit(firstlength)
return BlockArrays._BlockedUnitRange(lfirst, lblocklasts)
end

Base.last(a::GradedUnitRange) = isempty(a.lasts) ? first(a) - 1 : last(a.lasts)

function gradedrange(lblocklengths::AbstractVector{<:Pair{<:Any,<:Integer}})
return gradedrange(labelled.(last.(lblocklengths), first.(lblocklengths)))
end

function labelled_blocks(a::BlockedUnitRange, labels)
return BlockArrays._BlockedUnitRange(a.first, labelled.(a.lasts, labels))
end

function BlockArrays.findblock(a::GradedUnitRange, index::Integer)
return blockedunitrange_findblock(unlabel_blocks(a), index)
end

function blockedunitrange_findblock(a::GradedUnitRange, index::Integer)
return blockedunitrange_findblock(unlabel_blocks(a), index)
end

function blockedunitrange_findblockindex(a::GradedUnitRange, index::Integer)
return blockedunitrange_findblockindex(unlabel_blocks(a), index)
end

function BlockArrays.findblockindex(a::GradedUnitRange, index::Integer)
return blockedunitrange_findblockindex(unlabel_blocks(a), index)
end

## Block label interface

# Internal function
function get_label(a::BlockedUnitRange, index::Block{1})
return label(blocklasts(a)[Int(index)])
end

# Internal function
function get_label(a::BlockedUnitRange, index::Integer)
return get_label(a, blockedunitrange_findblock(a, index))
end

function blocklabels(a::BlockVector)
return map(BlockRange(a)) do block
return label(@view(a[block]))
end
end

function blocklabels(a::BlockedUnitRange)
# Using `a.lasts` here since that is what is stored
# inside of `BlockedUnitRange`, maybe change that.
# For example, it could be something like:
#
# map(BlockRange(a)) do block
# return label(@view(a[block]))
# end
#
return label.(a.lasts)
end

# TODO: This relies on internals of `BlockArrays`, maybe redesign
# to try to avoid that.
# TODO: Define `set_grades`, `set_sector_labels`, `set_labels`.
function unlabel_blocks(a::BlockedUnitRange)
return BlockArrays._BlockedUnitRange(a.first, unlabel.(a.lasts))
end

## BlockedUnitRage interface

function Base.axes(ga::GradedUnitRange)
return map(axes(unlabel_blocks(ga))) do a
return labelled_blocks(a, blocklabels(ga))
end
end

function BlockArrays.blockfirsts(a::GradedUnitRange)
return labelled.(blockfirsts(unlabel_blocks(a)), blocklabels(a))
end

function BlockArrays.blocklasts(a::GradedUnitRange)
return labelled.(blocklasts(unlabel_blocks(a)), blocklabels(a))
end

function BlockArrays.blocklengths(a::GradedUnitRange)
return labelled.(blocklengths(unlabel_blocks(a)), blocklabels(a))
end

function Base.first(a::GradedUnitRange)
return labelled(first(unlabel_blocks(a)), label(a[Block(1)]))
end

function firstblockindices(a::GradedUnitRange)
return labelled.(firstblockindices(unlabel_blocks(a)), blocklabels(a))
end

function blockedunitrange_getindex(a::GradedUnitRange, index)
# This uses `blocklasts` since that is what is stored
# in `BlockedUnitRange`, maybe abstract that away.
return labelled(unlabel_blocks(a)[index], get_label(a, index))
end

# Like `a[indices]` but preserves block structure.
using BlockArrays: block, blockindex
function blockedunitrange_getindices(
a::BlockedUnitRange, indices::AbstractUnitRange{<:Integer}
)
first_blockindex = blockedunitrange_findblockindex(a, first(indices))
last_blockindex = blockedunitrange_findblockindex(a, last(indices))
first_block = block(first_blockindex)
last_block = block(last_blockindex)
blocklengths = if first_block == last_block
[length(indices)]
else
map(first_block:last_block) do block
if block == first_block
return length(a[first_block]) - blockindex(first_blockindex) + 1
end
if block == last_block
return blockindex(last_blockindex)
end
return length(a[block])
end
end
return blockedunitrange(indices .+ (first(a) - 1), blocklengths)
end

function blockedunitrange_getindices(a::BlockedUnitRange, indices::BlockIndexRange)
return a[block(indices)][only(indices.indices)]
end

function blockedunitrange_getindices(a::BlockedUnitRange, indices::Vector{<:Integer})
return map(index -> a[index], indices)
end

function blockedunitrange_getindices(
a::BlockedUnitRange, indices::Vector{<:Union{Block{1},BlockIndexRange{1}}}
)
return mortar(map(index -> a[index], indices))
end

function blockedunitrange_getindices(a::BlockedUnitRange, indices)
return error("Not implemented.")
end

# The blocks of the corresponding slice.
_blocks(a::AbstractUnitRange, indices) = error("Not implemented")
function _blocks(a::AbstractUnitRange, indices::AbstractUnitRange)
return findblock(a, first(indices)):findblock(a, last(indices))
end
function _blocks(a::AbstractUnitRange, indices::BlockRange)
return indices
end

# The block labels of the corresponding slice.
function blocklabels(a::AbstractUnitRange, indices)
return map(_blocks(a, indices)) do block
return label(a[block])
end
end

function blockedunitrange_getindices(
ga::GradedUnitRange, indices::AbstractUnitRange{<:Integer}
)
a_indices = blockedunitrange_getindices(unlabel_blocks(ga), indices)
return labelled_blocks(a_indices, blocklabels(ga, indices))
end

function blockedunitrange_getindices(ga::GradedUnitRange, indices::BlockRange)
return labelled_blocks(unlabel_blocks(ga)[indices], blocklabels(ga, indices))
end

function Base.getindex(a::GradedUnitRange, index::Integer)
return blockedunitrange_getindex(a, index)
end

function Base.getindex(a::GradedUnitRange, index::Block{1})
return blockedunitrange_getindex(a, index)
end

function Base.getindex(a::GradedUnitRange, indices::BlockIndexRange)
return blockedunitrange_getindices(a, indices)
end

function Base.getindex(
a::GradedUnitRange, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}}
)
return blockedunitrange_getindices(a, indices)
end

function Base.getindex(a::GradedUnitRange, indices)
return blockedunitrange_getindices(a, indices)
end

function Base.getindex(a::GradedUnitRange, indices::AbstractUnitRange{<:Integer})
return blockedunitrange_getindices(a, indices)
end
4 changes: 4 additions & 0 deletions NDTensors/src/lib/GradedAxesNext/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
107 changes: 107 additions & 0 deletions NDTensors/src/lib/GradedAxesNext/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
@eval module $(gensym())
using BlockArrays:
Block, BlockVector, blockedrange, blockfirsts, blocklasts, blocklength, blocklengths
using NDTensors.GradedAxesNext: GradedUnitRange, blocklabels, gradedrange
using NDTensors.LabelledNumbers: LabelledUnitRange, label, unlabel
using Test: @test, @test_broken, @testset
@testset "GradedAxes" begin
a = gradedrange(["x" => 2, "y" => 3])
@test a isa GradedUnitRange
@test length(a) == 5
@test a[Block(2)] == 3:5
@test label(a[Block(2)]) == "y"
@test a[Block(2)] isa LabelledUnitRange
@test a[4] == 4
@test label(a[4]) == "y"
@test unlabel(a[4]) == 4
@test blocklengths(a) == [2, 3]
@test blocklabels(a) == ["x", "y"]
@test label.(blocklengths(a)) == ["x", "y"]
@test blockfirsts(a) == [1, 3]
@test label.(blockfirsts(a)) == ["x", "y"]
@test first(a) == 1
@test label(first(a)) == "x"
@test blocklasts(a) == [2, 5]
@test label.(blocklasts(a)) == ["x", "y"]
@test last(a) == 5
@test label(last(a)) == "y"
@test a[Block(2)] == 3:5
@test label(a[Block(2)]) == "y"
@test length(a[Block(2)]) == 3
@test blocklengths(only(axes(a))) == blocklengths(a)
@test blocklabels(only(axes(a))) == blocklabels(a)

# Slicing operations
x = gradedrange(["x" => 2, "y" => 3])
a = x[2:4]
@test a isa GradedUnitRange
@test length(a) == 3
@test blocklength(a) == 2
@test a[Block(1)] == 2:2
@test label(a[Block(1)]) == "x"
@test a[Block(2)] == 3:4
@test label(a[Block(2)]) == "y"
@test isone(first(only(axes(a))))
@test length(only(axes(a))) == length(a)
@test blocklengths(only(axes(a))) == blocklengths(a)

x = gradedrange(["x" => 2, "y" => 3])
a = x[3:4]
@test a isa GradedUnitRange
@test length(a) == 2
@test blocklength(a) == 1
@test a[Block(1)] == 3:4
@test label(a[Block(1)]) == "y"

x = gradedrange(["x" => 2, "y" => 3])
a = x[2:4][1:2]
@test a isa GradedUnitRange
@test length(a) == 2
@test blocklength(a) == 2
@test a[Block(1)] == 2:2
@test label(a[Block(1)]) == "x"
@test a[Block(2)] == 3:3
@test label(a[Block(2)]) == "y"

x = gradedrange(["x" => 2, "y" => 3])
a = x[Block(2)[2:3]]
@test a isa LabelledUnitRange
@test length(a) == 2
@test a == 4:5
@test label(a) == "y"

x = gradedrange(["x" => 2, "y" => 3, "z" => 4])
a = x[Block(2):Block(3)]
@test a isa GradedUnitRange
@test length(a) == 7
@test blocklength(a) == 2
@test blocklengths(a) == [3, 4]
@test blocklabels(a) == ["y", "z"]
@test a[Block(1)] == 3:5
@test a[Block(2)] == 6:9

x = gradedrange(["x" => 2, "y" => 3, "z" => 4])
a = x[[Block(3), Block(2)]]
@test a isa BlockVector
@test length(a) == 7
@test blocklength(a) == 2
# TODO: `BlockArrays` doesn't define `blocklengths`
# for `BlockVector`, should it?
@test_broken blocklengths(a) == [4, 3]
@test blocklabels(a) == ["z", "y"]
@test a[Block(1)] == 6:9
@test a[Block(2)] == 3:5

x = gradedrange(["x" => 2, "y" => 3, "z" => 4])
a = x[[Block(3)[2:3], Block(2)[2:3]]]
@test a isa BlockVector
@test length(a) == 4
@test blocklength(a) == 2
# TODO: `BlockArrays` doesn't define `blocklengths`
# for `BlockVector`, should it?
@test_broken blocklengths(a) == [2, 2]
@test blocklabels(a) == ["z", "y"]
@test a[Block(1)] == 7:8
@test a[Block(2)] == 4:5
end
end
2 changes: 2 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
7 changes: 7 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/src/LabelledNumbers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module LabelledNumbers
include("labelled_interface.jl")
include("labellednumber.jl")
include("labelledinteger.jl")
include("labelledarray.jl")
include("labelledunitrange.jl")
end
Loading