Skip to content

Commit

Permalink
Define matrix solves for AbstractQ (JuliaLang#54531)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored and lazarusA committed Jul 12, 2024
1 parent 3bc4850 commit 284454c
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
43 changes: 43 additions & 0 deletions stdlib/LinearAlgebra/src/abstractq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ end
### division
\(Q::AbstractQ, A::AbstractVecOrMat) = Q'*A
/(A::AbstractVecOrMat, Q::AbstractQ) = A*Q'
/(Q::AbstractQ, A::AbstractVecOrMat) = Matrix(Q) / A
ldiv!(Q::AbstractQ, A::AbstractVecOrMat) = lmul!(Q', A)
ldiv!(C::AbstractVecOrMat, Q::AbstractQ, A::AbstractVecOrMat) = mul!(C, Q', A)
rdiv!(A::AbstractVecOrMat, Q::AbstractQ) = rmul!(A, Q')
Expand Down Expand Up @@ -522,6 +523,27 @@ rmul!(X::Adjoint{T,<:StridedVecOrMat{T}}, Q::HessenbergQ{T}) where {T} = lmul!(Q
lmul!(adjQ::AdjointQ{<:Any,<:HessenbergQ{T}}, X::Adjoint{T,<:StridedVecOrMat{T}}) where {T} = rmul!(X', adjQ')'
rmul!(X::Adjoint{T,<:StridedVecOrMat{T}}, adjQ::AdjointQ{<:Any,<:HessenbergQ{T}}) where {T} = lmul!(adjQ', X')'

# division by a matrix
function /(Q::Union{QRPackedQ,QRCompactWYQ,HessenbergQ}, B::AbstractVecOrMat)
size(B, 2) in size(Q.factors) ||
throw(DimensionMismatch(lazy"second dimension of B, $(size(B,2)), must equal one of the dimensions of Q, $(size(Q.factors))"))
if size(B, 2) == size(Q.factors, 2)
return Matrix(Q) / B
else
return collect(Q) / B
end
end
function \(A::AbstractVecOrMat, adjQ::AdjointQ{<:Any,<:Union{QRPackedQ,QRCompactWYQ,HessenbergQ}})
Q = adjQ.Q
size(A, 1) in size(Q.factors) ||
throw(DimensionMismatch(lazy"first dimension of A, $(size(A,1)), must equal one of the dimensions of Q, $(size(Q.factors))"))
if size(A, 1) == size(Q.factors, 2)
return A \ Matrix(Q)'
else
return A \ collect(Q)'
end
end

# flexible left-multiplication (and adjoint right-multiplication)
qsize_check(Q::Union{QRPackedQ,QRCompactWYQ,HessenbergQ}, B::AbstractVecOrMat) =
size(B, 1) in size(Q.factors) ||
Expand Down Expand Up @@ -588,6 +610,27 @@ lmul!(adjA::AdjointQ{<:Any,<:LQPackedQ{T}}, B::StridedVecOrMat{T}) where {T<:Bla
lmul!(adjA::AdjointQ{<:Any,<:LQPackedQ{T}}, B::StridedVecOrMat{T}) where {T<:BlasComplex} =
(A = adjA.Q; LAPACK.ormlq!('L', 'C', A.factors, A.τ, B))

# division by a matrix
function /(adjQ::AdjointQ{<:Any,<:LQPackedQ}, B::AbstractVecOrMat)
Q = adjQ.Q
size(B, 2) in size(Q.factors) ||
throw(DimensionMismatch(lazy"second dimension of B, $(size(B,2)), must equal one of the dimensions of Q, $(size(Q.factors))"))
if size(B, 2) == size(Q.factors, 1)
return Matrix(Q)' / B
else
return collect(Q)' / B
end
end
function \(A::AbstractVecOrMat, Q::LQPackedQ)
size(A, 1) in size(Q.factors) ||
throw(DimensionMismatch(lazy"first dimension of A, $(size(A,1)), must equal one of the dimensions of Q, $(size(Q.factors))"))
if size(A, 1) == size(Q.factors, 1)
return A \ Matrix(Q)
else
return A \ collect(Q)
end
end

# In LQ factorization, `Q` is expressed as the product of the adjoint of the
# reflectors. Thus, `det` has to be conjugated.
det(Q::LQPackedQ) = conj(_det_tau(Q.τ))
26 changes: 25 additions & 1 deletion stdlib/LinearAlgebra/test/abstractq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ n = 5
@test Q * x Q.Q * x
@test Q' * x Q.Q' * x
end
A = rand(Float64, 5, 3)
A = randn(Float64, 5, 3)
F = qr(A)
Q = MyQ(F.Q)
Prect = Matrix(F.Q)
Expand All @@ -102,6 +102,30 @@ n = 5
@test Q Prect
@test Q Psquare
@test Q F.Q*I
# matrix division
q, r = F
R = randn(Float64, 5, 5)
@test q / r Matrix(q) / r
@test_throws DimensionMismatch MyQ(q) / r # doesn't have size flexibility
@test q / R collect(q) / R
@test copy(r') \ q' (q / r)'
@test_throws DimensionMismatch copy(r') \ MyQ(q')
@test r \ q' r \ Matrix(q)'
@test R \ q' R \ MyQ(q') R \ collect(q')
@test R \ q R \ MyQ(q) R \ collect(q)
B = copy(A')
G = lq(B)
l, q = G
L = R
@test l \ q l \ Matrix(q)
@test_throws DimensionMismatch l \ MyQ(q)
@test L \ q L \ collect(q)
@test q' / copy(l') (l \ q)'
@test_throws DimensionMismatch MyQ(q') / copy(l')
@test q' / l Matrix(q)' / l
@test q' / L MyQ(q') / L collect(q)' / L
@test q / L Matrix(q) / L
@test MyQ(q) / L collect(q) / L
end

end # module

0 comments on commit 284454c

Please sign in to comment.