Description
Motivation
Hej,
While working on LightConvex on the side, I've been experimenting with a high-level representation of factorized matrices using derived types: basically, a derived type encapsulating all of the different factors, plus a set of functions/subroutines for standard BLAS-2 and BLAS-3 operations. Here is an example with the QR factorization.
!> Matrix Factorization.
type, abstract :: AbstractMatrixFactorization
end type AbstractMatrixFactorization
!> QR factorization.
type, extends(AbstractMatrixFactorization), public :: qr_type
private
real(dp), allocatable :: data(:, :)
!! Lapack storage of the QR factorization.
real(dp), allocatable :: t(:)
!! Data structure used to represent Q in geqr.
integer(ilp) :: tsize
!! Dimension of the array T.
end type
pure type(qr_type) module function qrfact(A) result(F)
implicit none(external)
real(dp), intent(in) :: A(:, :)
!! Matrix to be factorized.
end function qrfact
interface famv
pure module subroutine qrmv(A, x, y, trans)
implicit none(external)
type(qr_type), intent(in) :: A
real(dp), intent(in) :: x(:)
real(dp), intent(out), target, contiguous :: y(:)
character(len=*), intent(in), optional :: trans
end subroutine qrmv
end interface famv
and the procedures
module procedure qrfact
integer(ilp) :: m, n, lda, lwork, info, tsize
real(dp), allocatable :: t(:), work(:)
!> Matrix dimension.
m = size(A, 1); n = size(A, 2); lda = m
!> Initialize derived-type.
F%data = A
!> Workspace query.
lwork = -1; tsize = -1
allocate (work(1), source=0.0_dp); allocate (t(1), source=0.0_dp)
call geqr(m, n, F%data, m, t, tsize, work, lwork, info)
!> QR factorization.
lwork = work(1); deallocate (work); allocate (work(lwork), source=0.0_dp)
tsize = t(1); deallocate (t); allocate (t(tsize), source=0.0_dp)
call geqr(m, n, F%data, m, t, tsize, work, lwork, info)
!> Store additional data in derived type.
F%t = t; F%tsize = tsize
end procedure qrfact
module procedure qrmv
character(len=1), parameter :: side = "L"
integer(ilp), parameter :: nc = 1
character(len=1) :: trans_
integer(ilp) :: ma, na, mc
integer(ilp) :: lda, ldc, lwork, info
real(dp), allocatable :: work(:)
real(dp), pointer :: ymat(:, :)
! By default A is not transposed.
trans_ = optval(trans, "N")
ma = size(A%data, 1); na = size(A%data, 2); mc = size(y)
select case (trans_)
case ("N")
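!> Zero-pad y: entries na+1:mc must be defined before Q is applied.
y = 0.0_dp; y(:na) = x(:na)
!> y = R @ x
call trmv("u", trans_, "n", na, A%data, ma, y, 1)
!----- y = Q @ (R @ x) -----
!> Pointer trick: view the whole of y as an mc-by-1 matrix.
ymat(1:mc, 1:1) => y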
!> Workspace query. Q is defined by k = na reflectors (assuming ma >= na).
lwork = -1; allocate (work(1), source=0.0_dp)
call gemqr(side, trans_, ma, nc, na, A%data, ma, A%t, A%tsize, ymat, mc, work, lwork, info)
!> Actual matrix-vector product.
lwork = work(1); deallocate (work); allocate (work(lwork), source=0.0_dp)
call gemqr(side, trans_, ma, nc, na, A%data, ma, A%t, A%tsize, ymat, mc, work, lwork, info)
case ("T")
block
real(dp), allocatable, target :: x_tmp(:)
real(dp), pointer :: xmat(:, :)
!----- y = Q.T @ x -----
!> Pointer trick.
x_tmp = x; xmat(1:ma, 1:1) => x_tmp
!> Workspace query. Q is defined by k = na reflectors (assuming ma >= na).
lwork = -1; allocate (work(1), source=0.0_dp)
call gemqr(side, trans_, ma, nc, na, A%data, ma, A%t, A%tsize, xmat, ma, work, lwork, info)
!> Actual matrix-vector product.
lwork = work(1); deallocate (work); allocate (work(lwork), source=0.0_dp)
call gemqr(side, trans_, ma, nc, na, A%data, ma, A%t, A%tsize, xmat, ma, work, lwork, info)
!----- y = R.T @ (Q.T @ x) -----
y = x_tmp(:na)
call trmv("u", trans_, "n", na, A%data, ma, y, 1)
end block
end select
end procedure
Additionally, type-bound procedures could also be defined to extract the different factors, e.g.
! Factorize matrix.
F = qrfact(A)
! Extract the Q factor.
Q = F%Q()
! Extract the R matrix.
R = F%R()
Note that I am not storing the actual Q matrix inside the derived type, as LAPACK has its own efficient format for QR and specialized subroutines that leverage this format for matrix-vector and matrix-matrix operations.
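As a rough illustration of what such a type-bound extractor might look like, here is a minimal sketch of an `R()` binding that copies the upper triangle out of the packed LAPACK-style storage. The module name `qr_extract_sketch`, the constructor `qr_from_packed`, and the implementation name `get_r` are hypothetical and exist only for this example; the type is a stripped-down stand-in for the `qr_type` above, and the code assumes a tall matrix (rows >= columns).

```fortran
module qr_extract_sketch
   implicit none
   private
   public :: qr_type, qr_from_packed

   integer, parameter :: dp = kind(1.0d0)

   !> Hypothetical, stripped-down stand-in for the qr_type of the proposal.
   type :: qr_type
      private
      !> LAPACK-style packed storage: R sits in the upper triangle,
      !> the data describing Q sits below the diagonal.
      real(dp), allocatable :: data(:, :)
   contains
      !> Type-bound extractor, used as R = F%R().
      procedure :: R => get_r
   end type qr_type

contains

   !> Hypothetical constructor wrapping an already-factorized array
   !> (in the proposal this role is played by qrfact).
   pure function qr_from_packed(packed) result(F)
      real(dp), intent(in) :: packed(:, :)
      type(qr_type) :: F
      F%data = packed
   end function qr_from_packed

   !> Return the n-by-n upper-triangular factor (assumes m >= n).
   pure function get_r(self) result(R)
      class(qr_type), intent(in) :: self
      real(dp), allocatable :: R(:, :)
      integer :: i, j, n
      n = size(self%data, 2)
      allocate (R(n, n), source=0.0_dp)
      ! Copy only the entries on and above the diagonal.
      do j = 1, n
         do i = 1, j
            R(i, j) = self%data(i, j)
         end do
      end do
   end function get_r

end module qr_extract_sketch
```

A `Q()` counterpart is left out because it would need LAPACK; one possible approach would be to apply `gemqr` to an identity matrix, but that is an assumption rather than part of the proposal.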
Prior Art
Julia defines similar types (structs) to store matrix factorizations. As far as I know, there is no equivalent in Python. I do not know enough about the C and C++ ecosystems to give meaningful references here.