Skip to content

define Base.permutedims! #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.3"
version = "0.5.4"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
13 changes: 11 additions & 2 deletions src/fusiontensor/base_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ Base.imag(ft::FusionTensor) = set_data_matrix(ft, imag(data_matrix(ft)))

Base.permutedims(ft::FusionTensor, args...) = fusiontensor_permutedims(ft, args...)

function Base.permutedims!(ftdst::FusionTensor, ftsrc::FusionTensor, args...)
return fusiontensor_permutedims!(ftdst, ftsrc, args...)
end

Base.real(ft::FusionTensor{<:Real}) = ft # same object
Base.real(ft::FusionTensor) = set_data_matrix(ft, real(data_matrix(ft)))

Expand All @@ -103,13 +107,18 @@ end
function Base.similar(::FusionTensor, ::Type, ::Tuple{})
throw(MethodError(similar, (Tuple{},)))
end

function Base.similar(
ft::FusionTensor, ::Type{T}, new_axes::Tuple{<:Tuple,<:Tuple}
) where {T}
return similar(ft, T, tuplemortar(new_axes))
end
function Base.similar(::FusionTensor, ::Type{T}, new_axes::BlockedTuple{2}) where {T}
function Base.similar(ft::FusionTensor, ::Type{T}, new_axes::BlockedTuple{2}) where {T}
return similar(ft, T, FusionTensorAxes(new_axes))
end
function Base.similar(ft::FusionTensor, new_axes::FusionTensorAxes)
return similar(ft, eltype(ft), new_axes)
end
function Base.similar(::FusionTensor, ::Type{T}, new_axes::FusionTensorAxes) where {T}
return FusionTensor{T}(undef, new_axes)
end

Expand Down
2 changes: 2 additions & 0 deletions src/fusiontensor/fusiontensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ function GradedArrays.sector_type(::Type{FT}) where {FT<:FusionTensor}
return sector_type(type_parameters(FT, 3))
end

SymmetryStyle(::Type{FT}) where {FT<:FusionTensor} = SymmetryStyle(sector_type(FT))

# ============================== FusionTensor interface ==================================

# misc access
Expand Down
5 changes: 5 additions & 0 deletions src/fusiontensor/fusiontensoraxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using GradedArrays:
GradedArrays,
AbstractGradedUnitRange,
AbstractSector,
SymmetryStyle,
TrivialSector,
dual,
sector_type,
Expand Down Expand Up @@ -110,6 +111,10 @@ function GradedArrays.sector_type(::Type{FTA}) where {BT,FTA<:FusionTensorAxes{B
return sector_type(type_parameters(type_parameters(BT, 3), 1))
end

function GradedArrays.SymmetryStyle(::Type{FTA}) where {FTA<:FusionTensorAxes}
return SymmetryStyle(sector_type(FTA))
end

function GradedArrays.checkspaces(
::Type{Bool}, left::FusionTensorAxes, right::FusionTensorAxes
)
Expand Down
83 changes: 52 additions & 31 deletions src/permutedims/permutedims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,54 +3,75 @@
using BlockArrays: blocklengths
using Strided: Strided, @strided

using TensorAlgebra: BlockedPermutation, permmortar, blockpermute
using GradedArrays: AbelianStyle, NotAbelianStyle, SymmetryStyle, checkspaces
using TensorAlgebra: AbstractBlockPermutation, permmortar

function naive_permutedims(ft, biperm::BlockedPermutation{2})
@assert ndims(ft) == length(biperm)

# naive permute: cast to dense, permutedims, cast to FusionTensor
arr = Array(ft)
permuted_arr = permutedims(arr, Tuple(biperm))
permuted = to_fusiontensor(permuted_arr, blocks(axes(ft)[biperm])...)
return permuted
# permutedims with 1 tuple of 2 separate tuples
function fusiontensor_permutedims(ft, new_leg_dims::Tuple{Tuple,Tuple})
return fusiontensor_permutedims(ft, new_leg_dims...)
end

# permutedims with 1 tuple of 2 separate tuples
function fusiontensor_permutedims(ft, new_leg_indices::Tuple{Tuple,Tuple})
return fusiontensor_permutedims(ft, new_leg_indices...)
function fusiontensor_permutedims!(ftdst, ftsrc, new_leg_dims::Tuple{Tuple,Tuple})
return fusiontensor_permutedims!(ftdst, ftsrc, new_leg_dims...)
end

# permutedims with 2 separate tuples
function fusiontensor_permutedims(
ft, new_codomain_indices::Tuple, new_domain_indices::Tuple
)
biperm = permmortar((new_codomain_indices, new_domain_indices))
function fusiontensor_permutedims(ft, new_codomain_dims::Tuple, new_domain_dims::Tuple)
biperm = permmortar((new_codomain_dims, new_domain_dims))
return fusiontensor_permutedims(ft, biperm)
end

function fusiontensor_permutedims(ft, biperm::BlockedPermutation{2})
function fusiontensor_permutedims!(
ftdst, ftsrc, new_codomain_dims::Tuple, new_domain_dims::Tuple
)
biperm = permmortar((new_codomain_dims, new_domain_dims))
return fusiontensor_permutedims!(ftdst, ftsrc, biperm)
end

# permutedims with BlockedPermutation
function fusiontensor_permutedims(ft, biperm::AbstractBlockPermutation{2})
ndims(ft) == length(biperm) || throw(ArgumentError("Invalid permutation length"))
ftdst = similar(ft, axes(ft)[biperm])
fusiontensor_permutedims!(ftdst, ft, biperm)
return ftdst
end

function fusiontensor_permutedims!(ftdst, ftsrc, biperm::AbstractBlockPermutation{2})
ndims(ftsrc) == length(biperm) || throw(ArgumentError("Invalid permutation length"))
blocklengths(axes(ftdst)) == blocklengths(biperm) ||
throw(ArgumentError("Destination tensor does not match bipermutation"))
checkspaces(axes(ftdst), axes(ftsrc)[biperm])

# early return for identity operation. Do not copy. Also handle tricky 0-dim case.
if ndims_codomain(ft) == first(blocklengths(biperm)) # compile time
if Tuple(biperm) == ntuple(identity, ndims(ft))
return ft
# early return for identity operation. Also handle tricky 0-dim case.
if ndims_codomain(ftdst) == ndims_codomain(ftsrc) # compile time
if Tuple(biperm) == ntuple(identity, ndims(ftdst))
copy!(data_matrix(ftdst), data_matrix(ftsrc))
return ftdst
end
end
return permute_data!(SymmetryStyle(ftdst), ftdst, ftsrc, Tuple(biperm))
end

new_ft = FusionTensor{eltype(ft)}(undef, axes(ft)[biperm])
fusiontensor_permutedims!(new_ft, ft, Tuple(biperm))
return new_ft
# =============================== Internal =============================================
function permute_data!(::AbelianStyle, ftdst, ftsrc, flatperm)
# abelian case: all unitary blocks are 1x1 identity matrices
# compute_unitary is only called to get block positions
unitary = compute_unitary(ftdst, ftsrc, flatperm)
for ((old_trees, new_trees), _) in unitary
new_block = view(ftdst, new_trees...)
old_block = view(ftsrc, old_trees...)
@strided new_block .= permutedims(old_block, flatperm)
end
return ftdst
end

function fusiontensor_permutedims!(
new_ft::FusionTensor{T,N}, old_ft::FusionTensor{T,N}, flatperm::NTuple{N,Integer}
) where {T,N}
foreach(m -> fill!(m, zero(T)), eachstoredblock(data_matrix(new_ft)))
unitary = compute_unitary(new_ft, old_ft, flatperm)
function permute_data!(::NotAbelianStyle, ftdst, ftsrc, flatperm)
foreach(m -> fill!(m, zero(eltype(ftdst))), eachstoredblock(data_matrix(ftdst)))
unitary = compute_unitary(ftdst, ftsrc, flatperm)
for ((old_trees, new_trees), coeff) in unitary
new_block = view(new_ft, new_trees...)
old_block = view(old_ft, old_trees...)
new_block = view(ftdst, new_trees...)
old_block = view(ftsrc, old_trees...)
@strided new_block .+= coeff .* permutedims(old_block, flatperm)
end
return ftdst
end
55 changes: 50 additions & 5 deletions test/test_permutedims.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
using Test: @test, @testset, @test_broken, @test_throws
using BlockArrays: blocks

using FusionTensors:
FusionTensor,
FusionTensorAxes,
data_matrix,
codomain_axis,
domain_axis,
naive_permutedims,
ndims_domain,
ndims_codomain,
to_fusiontensor
Expand All @@ -15,27 +15,47 @@ using TensorAlgebra: permmortar, tuplemortar

include("setup.jl")

function naive_permutedims(ft, biperm)
@assert ndims(ft) == length(biperm)

# naive permute: cast to dense, permutedims, cast to FusionTensor
arr = Array(ft)
permuted_arr = permutedims(arr, Tuple(biperm))
permuted = to_fusiontensor(permuted_arr, blocks(axes(ft)[biperm])...)
return permuted
end

@testset "Abelian permutedims" begin
@testset "dummy" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])
ftaxes1 = FusionTensorAxes((g1, g2), (dual(g3), dual(g4)))

for elt in (Float64, ComplexF64)
ft1 = FusionTensor{elt}(undef, (g1, g2), dual.((g3, g4)))
ft1 = randn(elt, ftaxes1)
@test isnothing(check_sanity(ft1))

# test permutedims interface
ft2 = permutedims(ft1, (1, 2), (3, 4)) # trivial with 2 tuples
@test ft2 === ft1 # same object
@test ft2 ≈ ft1
@test ft2 !== ft1
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
@test data_matrix(ft2) == data_matrix(ft1) # check copy

ft2 = permutedims(ft1, ((1, 2), (3, 4))) # trivial with tuple of 2 tuples
@test ft2 === ft1 # same object
@test ft2 ≈ ft1
@test ft2 !== ft1
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
@test data_matrix(ft2) == data_matrix(ft1) # check copy

biperm = permmortar(((1, 2), (3, 4)))
ft2 = permutedims(ft1, biperm) # trivial with biperm
@test ft2 === ft1 # same object
@test ft2 ≈ ft1
@test ft2 !== ft1
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
@test data_matrix(ft2) == data_matrix(ft1) # check copy

ft3 = permutedims(ft1, (4,), (1, 2, 3))
@test ft3 !== ft1
Expand All @@ -49,8 +69,33 @@ include("setup.jl")
@test space_isequal(domain_axis(ft1), domain_axis(ft4))
@test ft4 ≈ ft1

# test permutedims! interface
ft2 = randn(elt, axes(ft1))
permutedims!(ft2, ft1, (1, 2), (3, 4))
@test ft2 ≈ ft1
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
@test data_matrix(ft2) == data_matrix(ft1) # check copy

ft2 = randn(elt, axes(ft1))
permutedims!(ft2, ft1, ((1, 2), (3, 4)))
@test ft2 ≈ ft1
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
@test data_matrix(ft2) == data_matrix(ft1) # check copy

ft2 = randn(elt, axes(ft1))
permutedims!(ft2, ft1, biperm)
@test ft2 ≈ ft1
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
@test data_matrix(ft2) == data_matrix(ft1) # check copy

# test clean errors
ft2 = randn(elt, axes(ft1))
@test_throws MethodError permutedims(ft1, (2, 3, 4, 1))
@test_throws ArgumentError permutedims(ft1, (2, 3), (5, 4, 1))
@test_throws MethodError permutedims!(ft2, ft1, (2, 3, 4, 1))
@test_throws ArgumentError permutedims!(ft2, ft1, (2, 3), (5, 4, 1))
@test_throws ArgumentError permutedims!(ft2, ft1, (1, 2, 3), (4,))
@test_throws ArgumentError permutedims!(ft2, ft1, (1, 2), (4, 3))
end
end

Expand Down
Loading