Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve stack-overflow in Diagonal*Bidiagonal and (Sym)Tridiagonal #242

Merged
merged 8 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ArrayLayouts"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
authors = ["Sheehan Olver <[email protected]>"]
version = "1.10.3"
version = "1.10.4"

[deps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand Down
36 changes: 30 additions & 6 deletions src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,36 @@


## bi/tridiagonal copy
copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout}) = convert(Bidiagonal, M.A) * M.B
copy(M::Lmul{<:DiagonalLayout,<:BidiagonalLayout}) = M.A * convert(Bidiagonal, M.B)
copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout}) = convert(Tridiagonal, M.A) * M.B
copy(M::Lmul{<:DiagonalLayout,<:TridiagonalLayout}) = M.A * convert(Tridiagonal, M.B)
copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout}) = convert(SymTridiagonal, M.A) * M.B
copy(M::Lmul{<:DiagonalLayout,<:SymTridiagonalLayout}) = M.A * convert(SymTridiagonal, M.B)
# hack around the fact that a SymTridiagonal isn't fully mutable
_similar(A) = similar(A)
_similar(A::SymTridiagonal) = similar(Tridiagonal(A.ev, A.dv, A.ev))
_copy_diag(M::T, ::T) where {T<:Rmul} = copyto!(_similar(M.A), M)
_copy_diag(M::T, ::T) where {T<:Lmul} = copyto!(_similar(M.B), M)
_copy_diag(M, _) = copy(M)
function copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout})

Check warning on line 66 in src/diagonal.jl

View check run for this annotation

Codecov / codecov/patch

src/diagonal.jl#L65-L66

Added lines #L65 - L66 were not covered by tests
A = convert(Bidiagonal, M.A)
_copy_diag(Rmul(A, M.B), M)
end
function copy(M::Lmul{<:DiagonalLayout,<:BidiagonalLayout})

Check warning on line 70 in src/diagonal.jl

View check run for this annotation

Codecov / codecov/patch

src/diagonal.jl#L70

Added line #L70 was not covered by tests
B = convert(Bidiagonal, M.B)
_copy_diag(Lmul(M.A, B), M)
end
function copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout})

Check warning on line 74 in src/diagonal.jl

View check run for this annotation

Codecov / codecov/patch

src/diagonal.jl#L74

Added line #L74 was not covered by tests
A = convert(Tridiagonal, M.A)
_copy_diag(Rmul(A, M.B), M)
end
function copy(M::Lmul{<:DiagonalLayout,<:TridiagonalLayout})

Check warning on line 78 in src/diagonal.jl

View check run for this annotation

Codecov / codecov/patch

src/diagonal.jl#L78

Added line #L78 was not covered by tests
B = convert(Tridiagonal, M.B)
_copy_diag(Lmul(M.A, B), M)
end
function copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout})

Check warning on line 82 in src/diagonal.jl

View check run for this annotation

Codecov / codecov/patch

src/diagonal.jl#L82

Added line #L82 was not covered by tests
A = convert(SymTridiagonal, M.A)
_copy_diag(Rmul(A, M.B), M)
end
function copy(M::Lmul{<:DiagonalLayout,<:SymTridiagonalLayout})

Check warning on line 86 in src/diagonal.jl

View check run for this annotation

Codecov / codecov/patch

src/diagonal.jl#L86

Added line #L86 was not covered by tests
B = convert(SymTridiagonal, M.B)
_copy_diag(Lmul(M.A, B), M)
end

copy(M::Lmul{DiagonalLayout{OnesLayout}}) = _copy_oftype(M.B, eltype(M))
copy(M::Lmul{DiagonalLayout{OnesLayout},<:DiagonalLayout}) = Diagonal(_copy_oftype(diagonaldata(M.B), eltype(M)))
Expand Down
83 changes: 54 additions & 29 deletions test/test_layoutarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,40 @@ module TestLayoutArray
using ArrayLayouts, LinearAlgebra, FillArrays, Test, SparseArrays, Random
using ArrayLayouts: sub_materialize, MemoryLayout, ColumnNorm, RowMaximum, CRowMaximum, @_layoutlmul, Mul
import ArrayLayouts: triangulardata
import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal

struct MyMatrix <: LayoutMatrix{Float64}
A::Matrix{Float64}
struct MyMatrix{T,M<:AbstractMatrix{T}} <: LayoutMatrix{T}
A::M
end

Base.getindex(A::MyMatrix, k::Int, j::Int) = A.A[k,j]
Base.setindex!(A::MyMatrix, v, k::Int, j::Int) = setindex!(A.A, v, k, j)
Base.size(A::MyMatrix) = size(A.A)
Base.strides(A::MyMatrix) = strides(A.A)
Base.elsize(::Type{MyMatrix}) = sizeof(Float64)
Base.cconvert(::Type{Ptr{Float64}}, A::MyMatrix) = A.A
Base.unsafe_convert(::Type{Ptr{Float64}}, A::MyMatrix) = Base.unsafe_convert(Ptr{Float64}, A.A)
MemoryLayout(::Type{MyMatrix}) = DenseColumnMajor()
Base.elsize(::Type{<:MyMatrix{T}}) where {T} = sizeof(T)
Base.cconvert(::Type{Ptr{T}}, A::MyMatrix{T}) where {T} = Base.cconvert(Ptr{T}, A.A)
Base.unsafe_convert(::Type{Ptr{T}}, A::MyMatrix{T}) where {T} = Base.unsafe_convert(Ptr{T}, A.A)
MemoryLayout(::Type{MyMatrix{T,M}}) where {T,M} = MemoryLayout(M)
Base.copy(A::MyMatrix) = MyMatrix(copy(A.A))
ArrayLayouts.bidiagonaluplo(M::MyMatrix) = ArrayLayouts.bidiagonaluplo(M.A)
for MT in (:Diagonal, :Bidiagonal, :Tridiagonal, :SymTridiagonal)
@eval $MT(M::MyMatrix) = $MT(M.A)
end

struct MyVector{T} <: LayoutVector{T}
A::Vector{T}
struct MyVector{T,V<:AbstractVector{T}} <: LayoutVector{T}
A::V
end

MyVector(M::MyVector) = MyVector(M.A)
Base.getindex(A::MyVector, k::Int) = A.A[k]
Base.setindex!(A::MyVector, v, k::Int) = setindex!(A.A, v, k)
Base.size(A::MyVector) = size(A.A)
Base.strides(A::MyVector) = strides(A.A)
Base.elsize(::Type{MyVector}) = sizeof(Float64)
Base.cconvert(::Type{Ptr{T}}, A::MyVector{T}) where {T} = A.A
Base.elsize(::Type{<:MyVector{T}}) where {T} = sizeof(T)
Base.cconvert(::Type{Ptr{T}}, A::MyVector{T}) where {T} = Base.cconvert(Ptr{T}, A.A)
Base.unsafe_convert(::Type{Ptr{T}}, A::MyVector{T}) where T = Base.unsafe_convert(Ptr{T}, A.A)
MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
MemoryLayout(::Type{MyVector{T,V}}) where {T,V} = MemoryLayout(V)
Base.copy(A::MyVector) = MyVector(copy(A.A))

# These need to test dispatch reduces to ArrayLayouts.mul, etc.
@testset "LayoutArray" begin
Expand All @@ -42,7 +49,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
@test a[1:3] == a.A[1:3]
@test a[:] == a
@test (a')[1,:] == (a')[1,1:3] == a
@test sprint(show, "text/plain", a) == "3-element $MyVector{Float64}:\n 1.0\n 2.0\n 3.0"
@test sprint(show, "text/plain", a) == "$(summary(a)):\n 1.0\n 2.0\n 3.0"
@test B*a ≈ B*a.A
@test B'*a ≈ B'*a.A
@test transpose(B)*a ≈ transpose(B)*a.A
Expand Down Expand Up @@ -104,8 +111,8 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
@test_throws ErrorException qr!(A)
@test lu!(copy(A)).factors ≈ lu(A.A).factors
b = randn(5)
@test A \ b == A.A \ b == A.A \ MyVector(b) == ldiv!(lu(A.A), copy(MyVector(b)))
@test A \ b == ldiv!(lu(A), copy(MyVector(b))) == ldiv!(lu(A), copy(b))
@test A \ b == A.A \ b == A.A \ MyVector(b) == ldiv!(lu(A.A), copy(b))
@test A \ b == ldiv!(lu(A), copy(b))
@test lu(A).L == lu(A.A).L
@test lu(A).U == lu(A.A).U
@test lu(A).p == lu(A.A).p
Expand All @@ -120,7 +127,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
@test cholesky!(deepcopy(S), CRowMaximum()).U ≈ cholesky(Matrix(S), CRowMaximum()).U
@test cholesky(S) \ b ≈ cholesky(Matrix(S)) \ b ≈ cholesky(Matrix(S)) \ MyVector(b)
@test cholesky(S, CRowMaximum()) \ b ≈ cholesky(Matrix(S), CRowMaximum()) \ b
@test cholesky(S, CRowMaximum()) \ b ≈ ldiv!(cholesky(Matrix(S), CRowMaximum()), copy(MyVector(b)))
@test cholesky(S, CRowMaximum()) \ b ≈ ldiv!(cholesky(Matrix(S), CRowMaximum()), copy(b))
@test cholesky(S) \ b ≈ Matrix(S) \ b ≈ Symmetric(Matrix(S)) \ b
@test cholesky(S) \ b ≈ Symmetric(Matrix(S)) \ MyVector(b)
if VERSION >= v"1.9"
Expand All @@ -140,18 +147,9 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()

@testset "ldiv!" begin
c = MyVector(randn(5))
if VERSION < v"1.9"
@test_broken ldiv!(lu(A), MyVector(copy(c))) ≈ A \ c
else
@test ldiv!(lu(A), MyVector(copy(c))) ≈ A \ c
end
if VERSION < v"1.9" || VERSION >= v"1.10-"
@test_throws ErrorException ldiv!(qr(A), MyVector(copy(c)))
else
@test_throws MethodError ldiv!(qr(A), MyVector(copy(c)))
end
@test ldiv!(lu(A), MyVector(copy(c))) ≈ A \ c
@test_throws ErrorException ldiv!(eigen(randn(5,5)), c)
@test ArrayLayouts.ldiv!(svd(A.A), copy(c)) ≈ ArrayLayouts.ldiv!(similar(c), svd(A.A), c) ≈ A \ c
@test ArrayLayouts.ldiv!(svd(A.A), Vector(c)) ≈ ArrayLayouts.ldiv!(similar(c), svd(A.A), c) ≈ A \ c
if VERSION ≥ v"1.8"
@test ArrayLayouts.ldiv!(similar(c), transpose(lu(A.A)), copy(c)) ≈ A'\c
end
Expand Down Expand Up @@ -213,8 +211,8 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
@test B == Ones(5,5)*A + 2.0Bin
end

C = MyMatrix([1 2; 3 4])
@test sprint(show, "text/plain", C) == "2×2 $MyMatrix:\n 1.0 2.0\n 3.0 4.0"
C = MyMatrix(Float64[1 2; 3 4])
@test sprint(show, "text/plain", C) == "$(summary(C)):\n 1.0 2.0\n 3.0 4.0"

@testset "layoutldiv" begin
A = MyMatrix(randn(5,5))
Expand Down Expand Up @@ -337,6 +335,33 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
@test B̃\D ≈ B̃\Matrix(D)
@test D\D̃ ≈ D̃\D
@test B̃/D ≈ B̃/Matrix(D)

@testset "Diagonal * Bidiagonal/Tridiagonal with structured diags" begin
n = size(D,1)
B = Bidiagonal(map(MyVector, (rand(n), rand(n-1)))..., :U)
MB = MyMatrix(B)
S = SymTridiagonal(map(MyVector, (rand(n), rand(n-1)))...)
MS = MyMatrix(S)
T = Tridiagonal(map(MyVector, (rand(n-1), rand(n), rand(n-1)))...)
MT = MyMatrix(T)
DA, BA, SA, TA = map(Array, (D, B, S, T))
if VERSION >= v"1.11"
@test D * B ≈ DA * BA
@test B * D ≈ BA * DA
@test D * MB ≈ DA * BA
@test MB * D ≈ BA * DA
end
if VERSION >= v"1.12.0-DEV.824"
@test D * S ≈ DA * SA
@test D * MS ≈ DA * SA
@test D * T ≈ DA * TA
@test D * MT ≈ DA * TA
@test S * D ≈ SA * DA
@test MS * D ≈ SA * DA
@test T * D ≈ TA * DA
@test MT * D ≈ TA * DA
end
end
end

@testset "Adj/Trans" begin
Expand Down Expand Up @@ -400,7 +425,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()

@testset "dot" begin
a = MyVector(randn(5))
@test dot(a, Zeros(5)) ≡ dot(Zeros(5), a) 0.0
@test dot(a, Zeros(5)) ≡ dot(Zeros(5), a) == 0.0
end

@testset "layout_getindex scalar" begin
Expand Down
Loading