Skip to content

Commit 70052f1

Browse files
authored
[GradedAxes] Replace GradedAxes with GradedAxesNext (#1355)
1 parent 12fbcc2 commit 70052f1

29 files changed

+557
-754
lines changed

NDTensors/src/imports.jl

-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ for lib in [
3737
:RankFactorization,
3838
:Sectors,
3939
:LabelledNumbers,
40-
:GradedAxesNext,
4140
:GradedAxes,
4241
:TensorAlgebra,
4342
:SparseArrayInterface,

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl

+3-10
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,18 @@ module BlockSparseArraysGradedAxesExt
22
using BlockArrays: AbstractBlockVector, Block, BlockedUnitRange
33
using ..BlockSparseArrays: BlockSparseArrays, block_merge
44
using ...GradedAxes:
5-
AbstractGradedUnitRange,
6-
OneToOne,
7-
blockmergesortperm,
8-
blocksortperm,
9-
invblockperm,
10-
tensor_product
5+
GradedUnitRange, OneToOne, blockmergesortperm, blocksortperm, invblockperm, tensor_product
116
using ...TensorAlgebra:
127
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
138

149
# TODO: Make a `ReduceWhile` library.
1510
include("reducewhile.jl")
1611

17-
TensorAlgebra.FusionStyle(::AbstractGradedUnitRange) = SectorFusion()
12+
TensorAlgebra.FusionStyle(::GradedUnitRange) = SectorFusion()
1813

1914
# TODO: Need to implement this! Will require implementing
2015
# `block_merge(a::AbstractUnitRange, blockmerger::BlockedUnitRange)`.
21-
function BlockSparseArrays.block_merge(
22-
a::AbstractGradedUnitRange, blockmerger::BlockedUnitRange
23-
)
16+
function BlockSparseArrays.block_merge(a::GradedUnitRange, blockmerger::BlockedUnitRange)
2417
return a
2518
end
2619

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ end
2222

2323
# TODO: Implement as `copy(@view a[I...])`, which is then implemented
2424
# through `ArrayLayouts.sub_materialize`.
25+
using ..SparseArrayInterface: set_getindex_zero_function
2526
function blocksparse_getindex(
2627
a::AbstractArray{<:Any,N}, I::Vararg{AbstractVector{<:Block{1}},N}
2728
) where {N}
@@ -30,8 +31,9 @@ function blocksparse_getindex(
3031
CI = map(i -> Int.(i), I)
3132
subblocks_a = blocks_a[CI...]
3233
subaxes = ntuple(ndims(a)) do i
33-
return axes(a, i)[I[i]]
34+
return only(axes(axes(a, i)[I[i]]))
3435
end
36+
subblocks_a = set_getindex_zero_function(subblocks_a, BlockZero(subaxes))
3537
return typeof(a)(subblocks_a, subaxes)
3638
end
3739

NDTensors/src/lib/GradedAxes/Project.toml

-2
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[deps]
2-
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
2+
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module GradedAxesSectorsExt
22
using ..GradedAxes: GradedAxes
3-
using ...Sectors: Sectors, AbstractCategory, , dual
3+
using ...Sectors: Sectors, AbstractCategory, # , dual
44

5-
GradedAxes.fuse(c1::AbstractCategory, c2::AbstractCategory) = only(c1 c2)
5+
GradedAxes.fuse_labels(c1::AbstractCategory, c2::AbstractCategory) = only(c1 c2)
66

7-
GradedAxes.dual(c::AbstractCategory) = dual(c)
7+
# TODO: Decide the fate of `dual`.
8+
## GradedAxes.dual(c::AbstractCategory) = dual(c)
89
end

NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/runtests.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
@eval module $(gensym())
2-
using NDTensors.GradedAxes: dual, fuse
2+
using NDTensors.GradedAxes: dual, fuse_labels
33
using NDTensors.Sectors: U1, Z
44
using Test: @test, @testset
55

66
@testset "GradedAxesSectorsExt" begin
7-
@test fuse(U1(1), U1(2)) == U1(3)
7+
@test fuse_labels(U1(1), U1(2)) == U1(3)
88
@test dual(U1(2)) == U1(-2)
99

10-
@test fuse(Z{2}(1), Z{2}(1)) == Z{2}(0)
11-
@test fuse(Z{2}(0), Z{2}(1)) == Z{2}(1)
10+
@test fuse_labels(Z{2}(1), Z{2}(1)) == Z{2}(0)
11+
@test fuse_labels(Z{2}(0), Z{2}(1)) == Z{2}(1)
1212
@test dual(Z{2}(1)) == Z{2}(1)
1313
@test dual(Z{2}(0)) == Z{2}(0)
1414
end
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
module GradedAxes
2-
include("groupsortperm.jl")
3-
include("tensor_product.jl")
4-
include("abstractgradedunitrange.jl")
52
include("gradedunitrange.jl")
3+
include("fusion.jl")
64
include("../ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl")
75
end

NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl

-150
This file was deleted.
+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
using BlockArrays: BlockedUnitRange
2+
3+
# TODO: Decide what to do about `dual`. Should there just
4+
# be a version in `Sectors`?
5+
## function dual end
6+
7+
# Represents the range `1:1` or `Base.OneTo(1)`.
8+
struct OneToOne{T} <: AbstractUnitRange{T} end
9+
OneToOne() = OneToOne{Bool}()
10+
Base.first(a::OneToOne) = one(eltype(a))
11+
Base.last(a::OneToOne) = one(eltype(a))
12+
13+
# https://github.com/ITensor/ITensors.jl/blob/v0.3.57/NDTensors/src/lib/GradedAxes/src/tensor_product.jl
14+
# https://en.wikipedia.org/wiki/Tensor_product
15+
# https://github.com/KeitaNakamura/Tensorial.jl
16+
function tensor_product(
17+
a1::AbstractUnitRange,
18+
a2::AbstractUnitRange,
19+
a3::AbstractUnitRange,
20+
a_rest::Vararg{AbstractUnitRange},
21+
)
22+
return foldl(tensor_product, (a1, a2, a3, a_rest...))
23+
end
24+
25+
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
26+
return error("Not implemented yet.")
27+
end
28+
29+
function tensor_product(a1::Base.OneTo, a2::Base.OneTo)
30+
return Base.OneTo(length(a1) * length(a2))
31+
end
32+
33+
function tensor_product(a1::OneToOne, a2::AbstractUnitRange)
34+
return a2
35+
end
36+
37+
function tensor_product(a1::AbstractUnitRange, a2::OneToOne)
38+
return a1
39+
end
40+
41+
function tensor_product(a1::OneToOne, a2::OneToOne)
42+
return OneToOne()
43+
end
44+
45+
function fuse_labels(x, y)
46+
return error(
47+
"`fuse_labels` not implemented for object of type `$(typeof(x))` and `$(typeof(y))`."
48+
)
49+
end
50+
51+
function fuse_blocklengths(x::Integer, y::Integer)
52+
return x * y
53+
end
54+
55+
using ..LabelledNumbers: LabelledInteger, label, labelled, unlabel
56+
function fuse_blocklengths(x::LabelledInteger, y::LabelledInteger)
57+
return labelled(unlabel(x) * unlabel(y), fuse_labels(label(x), label(y)))
58+
end
59+
60+
using BlockArrays: blockedrange, blocks
61+
function tensor_product(a1::BlockedUnitRange, a2::BlockedUnitRange)
62+
blocklengths = map(vec(collect(Iterators.product(blocks(a1), blocks(a2))))) do x
63+
return mapreduce(length, fuse_blocklengths, x)
64+
end
65+
return blockedrange(blocklengths)
66+
end
67+
68+
function blocksortperm(a::BlockedUnitRange)
69+
# TODO: Figure out how to deal with dual sectors.
70+
# TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`.
71+
## return Block.(sortperm(nondual_sectors(a); rev=isdual(a)))
72+
return Block.(sortperm(blocklabels(a)))
73+
end
74+
75+
using BlockArrays: Block, BlockVector
76+
using SplitApplyCombine: groupcount
77+
# Get the permutation for sorting, then group by common elements.
78+
# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]
79+
function groupsortperm(v; kwargs...)
80+
perm = sortperm(v; kwargs...)
81+
v_sorted = @view v[perm]
82+
group_lengths = collect(groupcount(identity, v_sorted))
83+
return BlockVector(perm, group_lengths)
84+
end
85+
86+
# Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`.
87+
# Get the permutation for sorting, then group by common elements.
88+
# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]
89+
function blockmergesortperm(a::BlockedUnitRange)
90+
# If it is dual, reverse the sorting so the sectors
91+
# end up sorted in the same way whether or not the space
92+
# is dual.
93+
# TODO: Figure out how to deal with dual sectors.
94+
# TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`.
95+
## return Block.(groupsortperm(nondual_sectors(a); rev=isdual(a)))
96+
return Block.(groupsortperm(blocklabels(a)))
97+
end
98+
99+
# Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`.
100+
invblockperm(a::Vector{<:Block{1}}) = Block.(invperm(Int.(a)))
101+
102+
# Used by `TensorAlgebra.fusedims` in `BlockSparseArraysGradedAxesExt`.
103+
function blockmergesortperm(a::GradedUnitRange)
104+
# If it is dual, reverse the sorting so the sectors
105+
# end up sorted in the same way whether or not the space
106+
# is dual.
107+
# TODO: Figure out how to deal with dual sectors.
108+
# TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`.
109+
return Block.(groupsortperm(blocklabels(a)))
110+
end

0 commit comments

Comments
 (0)