Skip to content

Commit 12fbcc2

Browse files
authored
[NDTensors] Introduce LabelledNumbers and GradedAxesNext (#1351)
1 parent 638624f commit 12fbcc2

16 files changed

+592
-0
lines changed

NDTensors/src/imports.jl

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ for lib in [
3636
:BroadcastMapConversion,
3737
:RankFactorization,
3838
:Sectors,
39+
:LabelledNumbers,
40+
:GradedAxesNext,
3941
:GradedAxes,
4042
:TensorAlgebra,
4143
:SparseArrayInterface,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
indent = 2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module GradedAxesNext
2+
include("gradedunitrange.jl")
3+
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
using BlockArrays:
2+
BlockArrays,
3+
Block,
4+
BlockedUnitRange,
5+
BlockRange,
6+
BlockVector,
7+
blockedrange,
8+
BlockIndexRange,
9+
blockfirsts,
10+
blocklasts,
11+
blocklengths,
12+
findblock,
13+
findblockindex,
14+
mortar
15+
using ..LabelledNumbers: LabelledNumbers, LabelledInteger, label, labelled, unlabel
16+
17+
# Custom `BlockedUnitRange` constructor that takes a unit range
18+
# and a set of block lengths, similar to `BlockArray(::AbstractArray, blocklengths...)`.
19+
function blockedunitrange(a::AbstractUnitRange, blocklengths)
20+
blocklengths_shifted = copy(blocklengths)
21+
blocklengths_shifted[1] += (first(a) - 1)
22+
blocklasts = cumsum(blocklengths_shifted)
23+
return BlockArrays._BlockedUnitRange(first(a), blocklasts)
24+
end
25+
26+
# Circumvents issue in `findblock` that assumes the `BlockedUnitRange`
27+
# starts at 1.
28+
# TODO: Raise an issue with `BlockArrays`.
29+
function blockedunitrange_findblock(a::BlockedUnitRange, index::Integer)
30+
@boundscheck index in 1:length(a) || throw(BoundsError(a, index))
31+
return @inbounds findblock(a, index + first(a) - 1)
32+
end
33+
34+
# Circumvents issue in `findblockindex` that assumes the `BlockedUnitRange`
35+
# starts at 1.
36+
# TODO: Raise an issue with `BlockArrays`.
37+
function blockedunitrange_findblockindex(a::BlockedUnitRange, index::Integer)
38+
@boundscheck index in 1:length(a) || throw(BoundsError())
39+
return @inbounds findblockindex(a, index + first(a) - 1)
40+
end
41+
42+
const GradedUnitRange{BlockLasts<:Vector{<:LabelledInteger}} = BlockedUnitRange{BlockLasts}
43+
44+
function gradedrange(lblocklengths::AbstractVector{<:LabelledInteger})
45+
brange = blockedrange(unlabel.(lblocklengths))
46+
lblocklasts = labelled.(blocklasts(brange), label.(lblocklengths))
47+
# TODO: `first` is forced to be `Int` in `BlockArrays.BlockedUnitRange`,
48+
# so this doesn't do anything right now. Make a PR to generalize it.
49+
firstlength = first(lblocklengths)
50+
lfirst = oneunit(firstlength)
51+
return BlockArrays._BlockedUnitRange(lfirst, lblocklasts)
52+
end
53+
54+
Base.last(a::GradedUnitRange) = isempty(a.lasts) ? first(a) - 1 : last(a.lasts)
55+
56+
function gradedrange(lblocklengths::AbstractVector{<:Pair{<:Any,<:Integer}})
57+
return gradedrange(labelled.(last.(lblocklengths), first.(lblocklengths)))
58+
end
59+
60+
function labelled_blocks(a::BlockedUnitRange, labels)
61+
return BlockArrays._BlockedUnitRange(a.first, labelled.(a.lasts, labels))
62+
end
63+
64+
function BlockArrays.findblock(a::GradedUnitRange, index::Integer)
65+
return blockedunitrange_findblock(unlabel_blocks(a), index)
66+
end
67+
68+
function blockedunitrange_findblock(a::GradedUnitRange, index::Integer)
69+
return blockedunitrange_findblock(unlabel_blocks(a), index)
70+
end
71+
72+
function blockedunitrange_findblockindex(a::GradedUnitRange, index::Integer)
73+
return blockedunitrange_findblockindex(unlabel_blocks(a), index)
74+
end
75+
76+
function BlockArrays.findblockindex(a::GradedUnitRange, index::Integer)
77+
return blockedunitrange_findblockindex(unlabel_blocks(a), index)
78+
end
79+
80+
## Block label interface
81+
82+
# Internal function
83+
function get_label(a::BlockedUnitRange, index::Block{1})
84+
return label(blocklasts(a)[Int(index)])
85+
end
86+
87+
# Internal function
88+
function get_label(a::BlockedUnitRange, index::Integer)
89+
return get_label(a, blockedunitrange_findblock(a, index))
90+
end
91+
92+
function blocklabels(a::BlockVector)
93+
return map(BlockRange(a)) do block
94+
return label(@view(a[block]))
95+
end
96+
end
97+
98+
function blocklabels(a::BlockedUnitRange)
99+
# Using `a.lasts` here since that is what is stored
100+
# inside of `BlockedUnitRange`, maybe change that.
101+
# For example, it could be something like:
102+
#
103+
# map(BlockRange(a)) do block
104+
# return label(@view(a[block]))
105+
# end
106+
#
107+
return label.(a.lasts)
108+
end
109+
110+
# TODO: This relies on internals of `BlockArrays`, maybe redesign
111+
# to try to avoid that.
112+
# TODO: Define `set_grades`, `set_sector_labels`, `set_labels`.
113+
function unlabel_blocks(a::BlockedUnitRange)
114+
return BlockArrays._BlockedUnitRange(a.first, unlabel.(a.lasts))
115+
end
116+
117+
## BlockedUnitRage interface
118+
119+
function Base.axes(ga::GradedUnitRange)
120+
return map(axes(unlabel_blocks(ga))) do a
121+
return labelled_blocks(a, blocklabels(ga))
122+
end
123+
end
124+
125+
function BlockArrays.blockfirsts(a::GradedUnitRange)
126+
return labelled.(blockfirsts(unlabel_blocks(a)), blocklabels(a))
127+
end
128+
129+
function BlockArrays.blocklasts(a::GradedUnitRange)
130+
return labelled.(blocklasts(unlabel_blocks(a)), blocklabels(a))
131+
end
132+
133+
function BlockArrays.blocklengths(a::GradedUnitRange)
134+
return labelled.(blocklengths(unlabel_blocks(a)), blocklabels(a))
135+
end
136+
137+
function Base.first(a::GradedUnitRange)
138+
return labelled(first(unlabel_blocks(a)), label(a[Block(1)]))
139+
end
140+
141+
function firstblockindices(a::GradedUnitRange)
142+
return labelled.(firstblockindices(unlabel_blocks(a)), blocklabels(a))
143+
end
144+
145+
function blockedunitrange_getindex(a::GradedUnitRange, index)
146+
# This uses `blocklasts` since that is what is stored
147+
# in `BlockedUnitRange`, maybe abstract that away.
148+
return labelled(unlabel_blocks(a)[index], get_label(a, index))
149+
end
150+
151+
# Like `a[indices]` but preserves block structure.
152+
using BlockArrays: block, blockindex
153+
function blockedunitrange_getindices(
154+
a::BlockedUnitRange, indices::AbstractUnitRange{<:Integer}
155+
)
156+
first_blockindex = blockedunitrange_findblockindex(a, first(indices))
157+
last_blockindex = blockedunitrange_findblockindex(a, last(indices))
158+
first_block = block(first_blockindex)
159+
last_block = block(last_blockindex)
160+
blocklengths = if first_block == last_block
161+
[length(indices)]
162+
else
163+
map(first_block:last_block) do block
164+
if block == first_block
165+
return length(a[first_block]) - blockindex(first_blockindex) + 1
166+
end
167+
if block == last_block
168+
return blockindex(last_blockindex)
169+
end
170+
return length(a[block])
171+
end
172+
end
173+
return blockedunitrange(indices .+ (first(a) - 1), blocklengths)
174+
end
175+
176+
function blockedunitrange_getindices(a::BlockedUnitRange, indices::BlockIndexRange)
177+
return a[block(indices)][only(indices.indices)]
178+
end
179+
180+
function blockedunitrange_getindices(a::BlockedUnitRange, indices::Vector{<:Integer})
181+
return map(index -> a[index], indices)
182+
end
183+
184+
function blockedunitrange_getindices(
185+
a::BlockedUnitRange, indices::Vector{<:Union{Block{1},BlockIndexRange{1}}}
186+
)
187+
return mortar(map(index -> a[index], indices))
188+
end
189+
190+
function blockedunitrange_getindices(a::BlockedUnitRange, indices)
191+
return error("Not implemented.")
192+
end
193+
194+
# The blocks of the corresponding slice.
195+
_blocks(a::AbstractUnitRange, indices) = error("Not implemented")
196+
function _blocks(a::AbstractUnitRange, indices::AbstractUnitRange)
197+
return findblock(a, first(indices)):findblock(a, last(indices))
198+
end
199+
function _blocks(a::AbstractUnitRange, indices::BlockRange)
200+
return indices
201+
end
202+
203+
# The block labels of the corresponding slice.
204+
function blocklabels(a::AbstractUnitRange, indices)
205+
return map(_blocks(a, indices)) do block
206+
return label(a[block])
207+
end
208+
end
209+
210+
function blockedunitrange_getindices(
211+
ga::GradedUnitRange, indices::AbstractUnitRange{<:Integer}
212+
)
213+
a_indices = blockedunitrange_getindices(unlabel_blocks(ga), indices)
214+
return labelled_blocks(a_indices, blocklabels(ga, indices))
215+
end
216+
217+
function blockedunitrange_getindices(ga::GradedUnitRange, indices::BlockRange)
218+
return labelled_blocks(unlabel_blocks(ga)[indices], blocklabels(ga, indices))
219+
end
220+
221+
function Base.getindex(a::GradedUnitRange, index::Integer)
222+
return blockedunitrange_getindex(a, index)
223+
end
224+
225+
function Base.getindex(a::GradedUnitRange, index::Block{1})
226+
return blockedunitrange_getindex(a, index)
227+
end
228+
229+
function Base.getindex(a::GradedUnitRange, indices::BlockIndexRange)
230+
return blockedunitrange_getindices(a, indices)
231+
end
232+
233+
function Base.getindex(
234+
a::GradedUnitRange, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}}
235+
)
236+
return blockedunitrange_getindices(a, indices)
237+
end
238+
239+
function Base.getindex(a::GradedUnitRange, indices)
240+
return blockedunitrange_getindices(a, indices)
241+
end
242+
243+
function Base.getindex(a::GradedUnitRange, indices::AbstractUnitRange{<:Integer})
244+
return blockedunitrange_getindices(a, indices)
245+
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
3+
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
4+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
@eval module $(gensym())
2+
using BlockArrays:
3+
Block, BlockVector, blockedrange, blockfirsts, blocklasts, blocklength, blocklengths
4+
using NDTensors.GradedAxesNext: GradedUnitRange, blocklabels, gradedrange
5+
using NDTensors.LabelledNumbers: LabelledUnitRange, label, unlabel
6+
using Test: @test, @test_broken, @testset
7+
@testset "GradedAxes" begin
8+
a = gradedrange(["x" => 2, "y" => 3])
9+
@test a isa GradedUnitRange
10+
@test length(a) == 5
11+
@test a[Block(2)] == 3:5
12+
@test label(a[Block(2)]) == "y"
13+
@test a[Block(2)] isa LabelledUnitRange
14+
@test a[4] == 4
15+
@test label(a[4]) == "y"
16+
@test unlabel(a[4]) == 4
17+
@test blocklengths(a) == [2, 3]
18+
@test blocklabels(a) == ["x", "y"]
19+
@test label.(blocklengths(a)) == ["x", "y"]
20+
@test blockfirsts(a) == [1, 3]
21+
@test label.(blockfirsts(a)) == ["x", "y"]
22+
@test first(a) == 1
23+
@test label(first(a)) == "x"
24+
@test blocklasts(a) == [2, 5]
25+
@test label.(blocklasts(a)) == ["x", "y"]
26+
@test last(a) == 5
27+
@test label(last(a)) == "y"
28+
@test a[Block(2)] == 3:5
29+
@test label(a[Block(2)]) == "y"
30+
@test length(a[Block(2)]) == 3
31+
@test blocklengths(only(axes(a))) == blocklengths(a)
32+
@test blocklabels(only(axes(a))) == blocklabels(a)
33+
34+
# Slicing operations
35+
x = gradedrange(["x" => 2, "y" => 3])
36+
a = x[2:4]
37+
@test a isa GradedUnitRange
38+
@test length(a) == 3
39+
@test blocklength(a) == 2
40+
@test a[Block(1)] == 2:2
41+
@test label(a[Block(1)]) == "x"
42+
@test a[Block(2)] == 3:4
43+
@test label(a[Block(2)]) == "y"
44+
@test isone(first(only(axes(a))))
45+
@test length(only(axes(a))) == length(a)
46+
@test blocklengths(only(axes(a))) == blocklengths(a)
47+
48+
x = gradedrange(["x" => 2, "y" => 3])
49+
a = x[3:4]
50+
@test a isa GradedUnitRange
51+
@test length(a) == 2
52+
@test blocklength(a) == 1
53+
@test a[Block(1)] == 3:4
54+
@test label(a[Block(1)]) == "y"
55+
56+
x = gradedrange(["x" => 2, "y" => 3])
57+
a = x[2:4][1:2]
58+
@test a isa GradedUnitRange
59+
@test length(a) == 2
60+
@test blocklength(a) == 2
61+
@test a[Block(1)] == 2:2
62+
@test label(a[Block(1)]) == "x"
63+
@test a[Block(2)] == 3:3
64+
@test label(a[Block(2)]) == "y"
65+
66+
x = gradedrange(["x" => 2, "y" => 3])
67+
a = x[Block(2)[2:3]]
68+
@test a isa LabelledUnitRange
69+
@test length(a) == 2
70+
@test a == 4:5
71+
@test label(a) == "y"
72+
73+
x = gradedrange(["x" => 2, "y" => 3, "z" => 4])
74+
a = x[Block(2):Block(3)]
75+
@test a isa GradedUnitRange
76+
@test length(a) == 7
77+
@test blocklength(a) == 2
78+
@test blocklengths(a) == [3, 4]
79+
@test blocklabels(a) == ["y", "z"]
80+
@test a[Block(1)] == 3:5
81+
@test a[Block(2)] == 6:9
82+
83+
x = gradedrange(["x" => 2, "y" => 3, "z" => 4])
84+
a = x[[Block(3), Block(2)]]
85+
@test a isa BlockVector
86+
@test length(a) == 7
87+
@test blocklength(a) == 2
88+
# TODO: `BlockArrays` doesn't define `blocklengths`
89+
# for `BlockVector`, should it?
90+
@test_broken blocklengths(a) == [4, 3]
91+
@test blocklabels(a) == ["z", "y"]
92+
@test a[Block(1)] == 6:9
93+
@test a[Block(2)] == 3:5
94+
95+
x = gradedrange(["x" => 2, "y" => 3, "z" => 4])
96+
a = x[[Block(3)[2:3], Block(2)[2:3]]]
97+
@test a isa BlockVector
98+
@test length(a) == 4
99+
@test blocklength(a) == 2
100+
# TODO: `BlockArrays` doesn't define `blocklengths`
101+
# for `BlockVector`, should it?
102+
@test_broken blocklengths(a) == [2, 2]
103+
@test blocklabels(a) == ["z", "y"]
104+
@test a[Block(1)] == 7:8
105+
@test a[Block(2)] == 4:5
106+
end
107+
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
indent = 2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
module LabelledNumbers
2+
include("labelled_interface.jl")
3+
include("labellednumber.jl")
4+
include("labelledinteger.jl")
5+
include("labelledarray.jl")
6+
include("labelledunitrange.jl")
7+
end

0 commit comments

Comments
 (0)