From 284454ce37ac789d821fab006de6f975fa219b92 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 21 May 2024 17:07:01 +0200 Subject: [PATCH] Define matrix solves for `AbstractQ` (#54531) --- stdlib/LinearAlgebra/src/abstractq.jl | 43 ++++++++++++++++++++++++++ stdlib/LinearAlgebra/test/abstractq.jl | 26 +++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/abstractq.jl b/stdlib/LinearAlgebra/src/abstractq.jl index bf4064c907a2d..69561ef664c03 100644 --- a/stdlib/LinearAlgebra/src/abstractq.jl +++ b/stdlib/LinearAlgebra/src/abstractq.jl @@ -235,6 +235,7 @@ end ### division \(Q::AbstractQ, A::AbstractVecOrMat) = Q'*A /(A::AbstractVecOrMat, Q::AbstractQ) = A*Q' +/(Q::AbstractQ, A::AbstractVecOrMat) = Matrix(Q) / A ldiv!(Q::AbstractQ, A::AbstractVecOrMat) = lmul!(Q', A) ldiv!(C::AbstractVecOrMat, Q::AbstractQ, A::AbstractVecOrMat) = mul!(C, Q', A) rdiv!(A::AbstractVecOrMat, Q::AbstractQ) = rmul!(A, Q') @@ -522,6 +523,27 @@ rmul!(X::Adjoint{T,<:StridedVecOrMat{T}}, Q::HessenbergQ{T}) where {T} = lmul!(Q lmul!(adjQ::AdjointQ{<:Any,<:HessenbergQ{T}}, X::Adjoint{T,<:StridedVecOrMat{T}}) where {T} = rmul!(X', adjQ')' rmul!(X::Adjoint{T,<:StridedVecOrMat{T}}, adjQ::AdjointQ{<:Any,<:HessenbergQ{T}}) where {T} = lmul!(adjQ', X')' +# division by a matrix +function /(Q::Union{QRPackedQ,QRCompactWYQ,HessenbergQ}, B::AbstractVecOrMat) + size(B, 2) in size(Q.factors) || + throw(DimensionMismatch(lazy"second dimension of B, $(size(B,2)), must equal one of the dimensions of Q, $(size(Q.factors))")) + if size(B, 2) == size(Q.factors, 2) + return Matrix(Q) / B + else + return collect(Q) / B + end +end +function \(A::AbstractVecOrMat, adjQ::AdjointQ{<:Any,<:Union{QRPackedQ,QRCompactWYQ,HessenbergQ}}) + Q = adjQ.Q + size(A, 1) in size(Q.factors) || + throw(DimensionMismatch(lazy"first dimension of A, $(size(A,1)), must equal one of the dimensions of Q, $(size(Q.factors))")) + if size(A, 1) == size(Q.factors, 2) + return A \ Matrix(Q)' + else + return A \ collect(Q)' + end +end + # flexible left-multiplication (and adjoint right-multiplication) qsize_check(Q::Union{QRPackedQ,QRCompactWYQ,HessenbergQ}, B::AbstractVecOrMat) = size(B, 1) in size(Q.factors) || @@ -588,6 +610,27 @@ lmul!(adjA::AdjointQ{<:Any,<:LQPackedQ{T}}, B::StridedVecOrMat{T}) where {T<:Bla lmul!(adjA::AdjointQ{<:Any,<:LQPackedQ{T}}, B::StridedVecOrMat{T}) where {T<:BlasComplex} = (A = adjA.Q; LAPACK.ormlq!('L', 'C', A.factors, A.τ, B)) +# division by a matrix +function /(adjQ::AdjointQ{<:Any,<:LQPackedQ}, B::AbstractVecOrMat) + Q = adjQ.Q + size(B, 2) in size(Q.factors) || + throw(DimensionMismatch(lazy"second dimension of B, $(size(B,2)), must equal one of the dimensions of Q, $(size(Q.factors))")) + if size(B, 2) == size(Q.factors, 1) + return Matrix(Q)' / B + else + return collect(Q)' / B + end +end +function \(A::AbstractVecOrMat, Q::LQPackedQ) + size(A, 1) in size(Q.factors) || + throw(DimensionMismatch(lazy"first dimension of A, $(size(A,1)), must equal one of the dimensions of Q, $(size(Q.factors))")) + if size(A, 1) == size(Q.factors, 1) + return A \ Matrix(Q) + else + return A \ collect(Q) + end +end + # In LQ factorization, `Q` is expressed as the product of the adjoint of the # reflectors. Thus, `det` has to be conjugated. det(Q::LQPackedQ) = conj(_det_tau(Q.τ)) diff --git a/stdlib/LinearAlgebra/test/abstractq.jl b/stdlib/LinearAlgebra/test/abstractq.jl index 0eb88324e8c20..9ac522cf3917f 100644 --- a/stdlib/LinearAlgebra/test/abstractq.jl +++ b/stdlib/LinearAlgebra/test/abstractq.jl @@ -91,7 +91,7 @@ n = 5 @test Q * x ≈ Q.Q * x @test Q' * x ≈ Q.Q' * x end - A = rand(Float64, 5, 3) + A = randn(Float64, 5, 3) F = qr(A) Q = MyQ(F.Q) Prect = Matrix(F.Q) @@ -102,6 +102,30 @@ n = 5 @test Q ≈ Prect @test Q ≈ Psquare @test Q ≈ F.Q*I + # matrix division + q, r = F + R = randn(Float64, 5, 5) + @test q / r ≈ Matrix(q) / r + @test_throws DimensionMismatch MyQ(q) / r # doesn't have size flexibility + @test q / R ≈ collect(q) / R + @test copy(r') \ q' ≈ (q / r)' + @test_throws DimensionMismatch copy(r') \ MyQ(q') + @test r \ q' ≈ r \ Matrix(q)' + @test R \ q' ≈ R \ MyQ(q') ≈ R \ collect(q') + @test R \ q ≈ R \ MyQ(q) ≈ R \ collect(q) + B = copy(A') + G = lq(B) + l, q = G + L = R + @test l \ q ≈ l \ Matrix(q) + @test_throws DimensionMismatch l \ MyQ(q) + @test L \ q ≈ L \ collect(q) + @test q' / copy(l') ≈ (l \ q)' + @test_throws DimensionMismatch MyQ(q') / copy(l') + @test q' / l ≈ Matrix(q)' / l + @test q' / L ≈ MyQ(q') / L ≈ collect(q)' / L + @test q / L ≈ Matrix(q) / L + @test MyQ(q) / L ≈ collect(q) / L end end # module