From 818c3cdbf8fede9c1ae8c011d56cb885cf5879e6 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 9 Jan 2025 17:45:15 +0000 Subject: [PATCH] krontrav with Eye (#141) * krontrav with Eye * add tests * Update test_lazybandedinf.jl --- ext/LazyBandedMatricesInfiniteArraysExt.jl | 5 ++++- src/LazyBandedMatrices.jl | 4 ++-- src/blockkron.jl | 13 ++++++++----- test/test_blockkron.jl | 5 +++++ test/test_lazybandedinf.jl | 4 +++- 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/ext/LazyBandedMatricesInfiniteArraysExt.jl b/ext/LazyBandedMatricesInfiniteArraysExt.jl index 66aca9b..ddbf1e7 100644 --- a/ext/LazyBandedMatricesInfiniteArraysExt.jl +++ b/ext/LazyBandedMatricesInfiniteArraysExt.jl @@ -4,7 +4,7 @@ using LazyBandedMatrices.BlockArrays using LazyBandedMatrices.ArrayLayouts import Base: BroadcastStyle, copy, OneTo, oneto -import LazyBandedMatrices: _krontrav_axes, _block_interlace_axes, _broadcast_sub_arguments, AbstractLazyBandedBlockBandedLayout, KronTravBandedBlockBandedLayout, krontravargs, DiagTravLayout +import LazyBandedMatrices: _krontrav_axes, _block_interlace_axes, _broadcast_sub_arguments, AbstractLazyBandedBlockBandedLayout, KronTravBandedBlockBandedLayout, krontravargs, DiagTravLayout, krontrav_materialize_layout, krontrav import InfiniteArrays: InfFill, TridiagonalToeplitzLayout, BidiagonalToeplitzLayout, LazyArrayStyle, OneToInf import LazyBandedMatrices.ArrayLayouts: MemoryLayout, sublayout, RangeCumsum, Mul import LazyBandedMatrices.BlockArrays: sizes_from_blocks, BlockedOneTo, BlockSlice1, BlockSlice @@ -58,4 +58,7 @@ _block_interlace_axes(nbc::Int, ax::NTuple{2,BlockedOneTo{Int,OneToInf{Int}}}... copy(M::Mul{InfKronTravBandedBlockBandedLayout, Lay}) where Lay<:DiagTravLayout{<:AbstractPaddedLayout} = copy(Mul{KronTravBandedBlockBandedLayout, Lay}(M.A, M.B)) +krontrav_materialize_layout(::InfKronTravBandedBlockBandedLayout, K) = K + + end \ No newline at end of file diff --git a/src/LazyBandedMatrices.jl b/src/LazyBandedMatrices.jl index d38908d..ae6d15d 100644 --- a/src/LazyBandedMatrices.jl +++ b/src/LazyBandedMatrices.jl @@ -1,6 +1,6 @@ module LazyBandedMatrices using ArrayLayouts: symmetriclayout -using BandedMatrices, BlockBandedMatrices, BlockArrays, LazyArrays, +using BandedMatrices, BlockBandedMatrices, BlockArrays, LazyArrays, FillArrays, ArrayLayouts, MatrixFactorizations, Base, StaticArrays, LinearAlgebra import Base: -, +, *, /, \, ==, AbstractMatrix, Matrix, Array, size, conj, real, imag, copy, copymutable, @@ -22,7 +22,7 @@ import BandedMatrices: AbstractBandedMatrix, BandedStyle, bandwidths, isbanded import BlockBandedMatrices: AbstractBlockBandedLayout, AbstractBandedBlockBandedLayout, BlockRange1, Block1, blockbandwidths, subblockbandwidths, BlockBandedStyle, BandedBlockBandedStyle, isblockbanded, isbandedblockbanded import BlockArrays: BlockSlices, BlockSlice1, BlockSlice, blockvec, AbstractBlockLayout, blockcolsupport, blockrowsupport, BlockLayout, block, blockindex, viewblock, AbstractBlockedUnitRange - +import FillArrays: SquareEye const LazyArraysBlockBandedMatricesExt = Base.get_extension(LazyArrays, :LazyArraysBlockBandedMatricesExt) diff --git a/src/blockkron.jl b/src/blockkron.jl index 95c27cb..4c511d3 100644 --- a/src/blockkron.jl +++ b/src/blockkron.jl @@ -391,11 +391,14 @@ function _krontrav_mul_diagtrav((A,B,C), X::AbstractArray{<:Any,3}, ::Type{T}) w DiagTrav(Y) end -kron_materialize_layout(_, K) = BlockedArray(K) -kron_materialize_layout(::AbstractBandedBlockBandedLayout, K) = BandedBlockBandedMatrix(K) -kron_materialize(K) = kron_materialize_layout(MemoryLayout(K), K) -krontrav(A...) = kron_materialize(KronTrav(A...)) - +krontrav_materialize_layout(_, K) = BlockedArray(K) +krontrav_materialize_layout(::AbstractBandedBlockBandedLayout, K) = BandedBlockBandedMatrix(K) +krontrav_materialize(K) = krontrav_materialize_layout(MemoryLayout(K), K) +krontrav(A...) = krontrav_materialize(KronTrav(A...)) +function krontrav(a::SquareEye{T}, b::SquareEye{V}) where {T,V} + size(a) == size(b) || throw(ArgumentError("size must match")) + SquareEye{promote_type(T,V)}((blockedrange(oneto(size(a,1))),)) +end # C = α*B*X*A' + β*C \ No newline at end of file diff --git a/test/test_blockkron.jl b/test/test_blockkron.jl index eda2b8d..3e589ea 100644 --- a/test/test_blockkron.jl +++ b/test/test_blockkron.jl @@ -353,6 +353,11 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data) Δ = BandedMatrix(1 => Ones(n-1), 0 => Fill(-2,n), -1 => Ones(n-1)) @test krontrav(Δ,Eye(n)) == KronTrav(Δ, Eye(n)) @test krontrav(Δ,Eye(n)) isa BandedBlockBandedMatrix + + @test krontrav(Eye(4), Eye(4)) isa Eye + @test krontrav(Eye(4), Eye(4)) == KronTrav(Eye(4), Eye(4)) + + @test_throws ArgumentError krontrav(Eye(4), Eye(5)) end end diff --git a/test/test_lazybandedinf.jl b/test/test_lazybandedinf.jl index 9f2da32..f3e0202 100644 --- a/test/test_lazybandedinf.jl +++ b/test/test_lazybandedinf.jl @@ -3,7 +3,7 @@ using InfiniteArrays: TridiagonalToeplitzLayout, BidiagonalToeplitzLayout, Tridi using Base: oneto using BlockArrays: blockcolsupport using LazyArrays: arguments, simplifiable -using LazyBandedMatrices: BroadcastBandedBlockBandedLayout +using LazyBandedMatrices: BroadcastBandedBlockBandedLayout, krontrav const InfiniteArraysBlockArraysExt = Base.get_extension(InfiniteArrays, :InfiniteArraysBlockArraysExt) const LazyBandedMatricesInfiniteArraysExt = Base.get_extension(LazyBandedMatrices, :LazyBandedMatricesInfiniteArraysExt) @@ -195,6 +195,8 @@ const InfKronTravBandedBlockBandedLayout = LazyBandedMatricesInfiniteArraysExt.I @test subblockbandwidths(A + B) == (1, 1) @test subblockbandwidths(2A) == (1, 1) @test subblockbandwidths(2 * (A + B)) == (1, 1) + + @test krontrav(Eye(∞), Eye(∞)) isa Eye end @testset "BlockTridiagonal" begin