|
| 1 | +using BlockArrays: BlockedUnitRange |
| 2 | + |
| 3 | +# TODO: Decide what to do about `dual`. Should there just |
| 4 | +# be a version in `Sectors`? |
| 5 | +## function dual end |
| 6 | + |
| 7 | +# Represents the range `1:1` or `Base.OneTo(1)`. |
| 8 | +struct OneToOne{T} <: AbstractUnitRange{T} end |
| 9 | +OneToOne() = OneToOne{Bool}() |
| 10 | +Base.first(a::OneToOne) = one(eltype(a)) |
| 11 | +Base.last(a::OneToOne) = one(eltype(a)) |
| 12 | + |
| 13 | +# https://github.com/ITensor/ITensors.jl/blob/v0.3.57/NDTensors/src/lib/GradedAxes/src/tensor_product.jl |
| 14 | +# https://en.wikipedia.org/wiki/Tensor_product |
| 15 | +# https://github.com/KeitaNakamura/Tensorial.jl |
| 16 | +function tensor_product( |
| 17 | + a1::AbstractUnitRange, |
| 18 | + a2::AbstractUnitRange, |
| 19 | + a3::AbstractUnitRange, |
| 20 | + a_rest::Vararg{AbstractUnitRange}, |
| 21 | +) |
| 22 | + return foldl(tensor_product, (a1, a2, a3, a_rest...)) |
| 23 | +end |
| 24 | + |
| 25 | +function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange) |
| 26 | + return error("Not implemented yet.") |
| 27 | +end |
| 28 | + |
| 29 | +function tensor_product(a1::Base.OneTo, a2::Base.OneTo) |
| 30 | + return Base.OneTo(length(a1) * length(a2)) |
| 31 | +end |
| 32 | + |
| 33 | +function tensor_product(a1::OneToOne, a2::AbstractUnitRange) |
| 34 | + return a2 |
| 35 | +end |
| 36 | + |
| 37 | +function tensor_product(a1::AbstractUnitRange, a2::OneToOne) |
| 38 | + return a1 |
| 39 | +end |
| 40 | + |
| 41 | +function tensor_product(a1::OneToOne, a2::OneToOne) |
| 42 | + return OneToOne() |
| 43 | +end |
| 44 | + |
| 45 | +function fuse_labels(x, y) |
| 46 | + return error( |
| 47 | + "`fuse_labels` not implemented for object of type `$(typeof(x))` and `$(typeof(y))`." |
| 48 | + ) |
| 49 | +end |
| 50 | + |
| 51 | +function fuse_blocklengths(x::Integer, y::Integer) |
| 52 | + return x * y |
| 53 | +end |
| 54 | + |
| 55 | +using ..LabelledNumbers: LabelledInteger, label, labelled, unlabel |
| 56 | +function fuse_blocklengths(x::LabelledInteger, y::LabelledInteger) |
| 57 | + return labelled(unlabel(x) * unlabel(y), fuse_labels(label(x), label(y))) |
| 58 | +end |
| 59 | + |
| 60 | +using BlockArrays: blockedrange, blocks |
| 61 | +function tensor_product(a1::BlockedUnitRange, a2::BlockedUnitRange) |
| 62 | + blocklengths = map(vec(collect(Iterators.product(blocks(a1), blocks(a2))))) do x |
| 63 | + return mapreduce(length, fuse_blocklengths, x) |
| 64 | + end |
| 65 | + return blockedrange(blocklengths) |
| 66 | +end |
| 67 | + |
| 68 | +function blocksortperm(a::BlockedUnitRange) |
| 69 | + # TODO: Figure out how to deal with dual sectors. |
| 70 | + # TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`. |
| 71 | + ## return Block.(sortperm(nondual_sectors(a); rev=isdual(a))) |
| 72 | + return Block.(sortperm(blocklabels(a))) |
| 73 | +end |
| 74 | + |
| 75 | +using BlockArrays: Block, BlockVector |
| 76 | +using SplitApplyCombine: groupcount |
| 77 | +# Get the permutation for sorting, then group by common elements. |
| 78 | +# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]] |
| 79 | +function groupsortperm(v; kwargs...) |
| 80 | + perm = sortperm(v; kwargs...) |
| 81 | + v_sorted = @view v[perm] |
| 82 | + group_lengths = collect(groupcount(identity, v_sorted)) |
| 83 | + return BlockVector(perm, group_lengths) |
| 84 | +end |
| 85 | + |
| 86 | +# Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`. |
| 87 | +# Get the permutation for sorting, then group by common elements. |
| 88 | +# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]] |
| 89 | +function blockmergesortperm(a::BlockedUnitRange) |
| 90 | + # If it is dual, reverse the sorting so the sectors |
| 91 | + # end up sorted in the same way whether or not the space |
| 92 | + # is dual. |
| 93 | + # TODO: Figure out how to deal with dual sectors. |
| 94 | + # TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`. |
| 95 | + ## return Block.(groupsortperm(nondual_sectors(a); rev=isdual(a))) |
| 96 | + return Block.(groupsortperm(blocklabels(a))) |
| 97 | +end |
| 98 | + |
| 99 | +# Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`. |
| 100 | +invblockperm(a::Vector{<:Block{1}}) = Block.(invperm(Int.(a))) |
| 101 | + |
| 102 | +# Used by `TensorAlgebra.fusedims` in `BlockSparseArraysGradedAxesExt`. |
| 103 | +function blockmergesortperm(a::GradedUnitRange) |
| 104 | + # If it is dual, reverse the sorting so the sectors |
| 105 | + # end up sorted in the same way whether or not the space |
| 106 | + # is dual. |
| 107 | + # TODO: Figure out how to deal with dual sectors. |
| 108 | + # TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`. |
| 109 | + return Block.(groupsortperm(blocklabels(a))) |
| 110 | +end |
0 commit comments