Derived type for matrix factorizations #1040

@loiseaujc

Motivation

Hi,

While working on LightConvex on the side, I have been experimenting with a high-level representation of factorized matrices using derived types. Basically, a derived type encapsulates all of the factors, together with a set of functions/subroutines for standard BLAS-2 and BLAS-3 operations (matrix-vector and matrix-matrix products). Here is an example with the QR factorization.

   !> Matrix Factorization.
   type, abstract :: AbstractMatrixFactorization
   end type AbstractMatrixFactorization

   !> QR factorization.
   type, extends(AbstractMatrixFactorization), public :: qr_type
      private
      real(dp), allocatable :: data(:, :)
      !! Lapack storage of the QR factorization.
      real(dp), allocatable :: t(:)
      !! Data structure used to represent Q in geqr.
      integer(ilp) :: tsize
      !! Dimension of the array T.
   end type

   interface
      pure type(qr_type) module function qrfact(A) result(F)
         implicit none(external)
         real(dp), intent(in) :: A(:, :)
         !! Matrix to be factorized.
      end function qrfact
   end interface

   interface famv
      pure module subroutine qrmv(A, x, y, trans)
         implicit none(external)
         type(qr_type), intent(in) :: A
         real(dp), intent(in) :: x(:)
         real(dp), intent(out), target, contiguous :: y(:)
         character(len=*), intent(in), optional :: trans
      end subroutine qrmv
   end interface famv

and the corresponding procedure implementations:

   module procedure qrfact
   integer(ilp) :: m, n, lda, lwork, info, tsize
   real(dp), allocatable :: t(:), work(:)

   !> Matrix dimension.
   m = size(A, 1); n = size(A, 2); lda = m

   !> Initialize derived-type.
   F%data = A

   !> Workspace query (geqr requires T to hold at least 5 entries, even in query mode).
   lwork = -1; tsize = -1
   allocate (work(1), source=0.0_dp); allocate (t(5), source=0.0_dp)
   call geqr(m, n, F%data, lda, t, tsize, work, lwork, info)

   !> QR factorization.
   lwork = int(work(1), ilp); deallocate (work); allocate (work(lwork), source=0.0_dp)
   tsize = int(t(1), ilp); deallocate (t); allocate (t(tsize), source=0.0_dp)
   call geqr(m, n, F%data, lda, t, tsize, work, lwork, info)

   !> Store additional data in derived type.
   F%t = t; F%tsize = tsize
   end procedure qrfact

   module procedure qrmv
   character(len=1), parameter :: side = "L"
   integer(ilp), parameter :: nc = 1
   character(len=1) :: trans_
   integer(ilp) :: ma, na, mc
   integer(ilp) :: lda, ldc, lwork, info
   real(dp), allocatable :: work(:)
   real(dp), pointer :: ymat(:, :)

   ! By default A is not transposed.
   trans_ = optval(trans, "N")

   ma = size(A%data, 1); na = size(A%data, 2); mc = size(y)

   select case (trans_)
   case ("N")
      !> y = [R @ x; 0]: y has ma entries, trmv only sets the first na.
      y(:na) = x(:na); y(na + 1:) = 0.0_dp
      call trmv("u", trans_, "n", na, A%data, ma, y, 1)
      !----- y = Q @ (R @ x) -----
      !> Pointer trick: view the full y as an ma x 1 matrix for gemqr.
      ymat(1:ma, 1:1) => y(:ma)
      !> Workspace query (K = na elementary reflectors, A%data is ma x na).
      lwork = -1; allocate (work(1), source=0.0_dp)
      call gemqr(side, trans_, ma, nc, na, A%data, ma, A%t, A%tsize, ymat, mc, work, lwork, info)
      !> Actual matrix-vector product.
      lwork = int(work(1), ilp); deallocate (work); allocate (work(lwork), source=0.0_dp)
      call gemqr(side, trans_, ma, nc, na, A%data, ma, A%t, A%tsize, ymat, mc, work, lwork, info)
   case ("T")
      block
         real(dp), allocatable, target :: x_tmp(:)
         real(dp), pointer :: xmat(:, :)
         !----- y = Q.T @ x -----
         !> Pointer trick.
         x_tmp = x; xmat(1:ma, 1:1) => x_tmp
         !> Workspace query.
         lwork = -1; allocate (work(1), source=0.0_dp)
         call gemqr(side, trans_, ma, nc, na, A%data, ma, A%t, A%tsize, xmat, ma, work, lwork, info)
         !> Actual matrix-vector product.
         lwork = int(work(1), ilp); deallocate (work); allocate (work(lwork), source=0.0_dp)
         call gemqr(side, trans_, ma, nc, na, A%data, ma, A%t, A%tsize, xmat, ma, work, lwork, info)
         !----- y = R.T @ (Q.T @ x) -----
         y = x_tmp(:na)
         call trmv("u", trans_, "n", na, A%data, ma, y, 1)
      end block
   end select
   end procedure
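The `qrmv` procedure above never forms Q explicitly. Instead, it applies the stored reflectors through `gemqr`. For readers without LAPACK at hand, the same logic (y = Q (R x) for `"N"`, y = R^T (Q^T x) for `"T"`) can be sketched in plain Fortran. This is a minimal, non-authoritative sketch that rests on several assumptions:

- It swaps `geqr`'s opaque `t`/`tsize` layout for the simpler `geqrf`-style convention: R on and above the diagonal, Householder vectors with an implicit unit first entry below it, and the reflector scalars in a `tau` array.
- It uses an unblocked factorization.
- It assumes m >= n.

All names (`householder_qr_sketch`, `hqr_type`, `hqrfact`, `hqrmv`, `apply_q`, `reflect`) are hypothetical and not part of stdlib.

```fortran
! Hypothetical, LAPACK-free sketch: geqrf-style Householder QR (assumes m >= n).
module householder_qr_sketch
   use, intrinsic :: iso_fortran_env, only: dp => real64
   implicit none
   private
   public :: hqr_type, hqrfact, hqrmv

   !> R on and above the diagonal, Householder vectors (implicit unit
   !> first entry) below it, reflector scalars in tau. Components are
   !> public only to keep the sketch short.
   type :: hqr_type
      real(dp), allocatable :: data(:, :)
      real(dp), allocatable :: tau(:)
   end type hqr_type

contains

   !> c <- (I - tau v v^T) c, with v = [1; vtail].
   pure subroutine reflect(vtail, tau, c)
      real(dp), intent(in) :: vtail(:), tau
      real(dp), intent(inout) :: c(:)
      real(dp) :: s
      s = tau * (c(1) + dot_product(vtail, c(2:)))
      c(1) = c(1) - s
      c(2:) = c(2:) - s * vtail
   end subroutine reflect

   !> Unblocked Householder QR, reflectors built as in LAPACK's larfg.
   pure function hqrfact(A) result(F)
      real(dp), intent(in) :: A(:, :)
      type(hqr_type) :: F
      real(dp) :: alpha, beta, xnorm
      integer :: j, k, m, n
      m = size(A, 1); n = size(A, 2)
      F%data = A
      allocate (F%tau(n), source=0.0_dp)
      do j = 1, n
         alpha = F%data(j, j)
         xnorm = norm2(F%data(j + 1:m, j))
         if (xnorm == 0.0_dp) cycle               ! H(j) = I, tau(j) = 0
         beta = -sign(hypot(alpha, xnorm), alpha)
         F%tau(j) = (beta - alpha) / beta
         F%data(j + 1:m, j) = F%data(j + 1:m, j) / (alpha - beta)
         F%data(j, j) = beta                      ! diagonal entry of R
         do k = j + 1, n                          ! update trailing columns
            call reflect(F%data(j + 1:m, j), F%tau(j), F%data(j:m, k))
         end do
      end do
   end function hqrfact

   !> c <- Q c ("N") or c <- Q^T c ("T"), with Q = H(1) H(2) ... H(n).
   pure subroutine apply_q(F, c, trans)
      type(hqr_type), intent(in) :: F
      real(dp), intent(inout) :: c(:)
      character(len=1), intent(in) :: trans
      integer :: j, m, n
      m = size(F%data, 1); n = size(F%data, 2)
      if (trans == "T") then
         do j = 1, n                   ! Q^T c = H(n) ... H(1) c
            call reflect(F%data(j + 1:m, j), F%tau(j), c(j:m))
         end do
      else
         do j = n, 1, -1               ! Q c = H(1) ... H(n) c
            call reflect(F%data(j + 1:m, j), F%tau(j), c(j:m))
         end do
      end if
   end subroutine apply_q

   !> y = A x ("N", size(y) = m) or y = A^T x ("T", size(y) = n).
   pure subroutine hqrmv(F, x, y, trans)
      type(hqr_type), intent(in) :: F
      real(dp), intent(in) :: x(:)
      real(dp), intent(out) :: y(:)
      character(len=1), intent(in) :: trans
      real(dp), allocatable :: tmp(:)
      integer :: i, n
      n = size(F%data, 2)
      if (trans == "T") then
         tmp = x
         call apply_q(F, tmp, "T")     ! tmp = Q^T x
         do i = 1, n                   ! y = R^T tmp(1:n)
            y(i) = dot_product(F%data(1:i, i), tmp(1:i))
         end do
      else
         y = 0.0_dp
         do i = 1, n                   ! y = [R x; 0]
            y(i) = dot_product(F%data(i, i:n), x(i:n))
         end do
         call apply_q(F, y, "N")       ! y = Q [R x; 0]
      end if
   end subroutine hqrmv
end module householder_qr_sketch
```

A call such as `call hqrmv(hqrfact(A), x, y, "N")` should reproduce `matmul(A, x)` up to rounding. As in `qrmv`, each product costs O(mn) operations because the n reflectors are applied one at a time instead of multiplying by a dense Q.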

Additionally, type-bound procedures could be defined to extract the individual factors, e.g.

   ! Factorize matrix.
   F = qrfact(A)
   ! Extract the Q factor.
   Q = F%Q()
   ! Extract the R factor.
   R = F%R()
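One possible shape for these type-bound procedures is sketched below. It is hypothetical and not stdlib's API. To stay LAPACK-free, it assumes `geqrf`-style storage (Householder vectors below the diagonal plus a `tau` array) rather than the `t`/`tsize` layout of `geqr`. It also assumes m >= n. `R()` copies the upper triangle. `Q()` builds the thin m x n factor by applying the reflectors, last to first, to the leading columns of the identity, similar in spirit to LAPACK's `orgqr`. The names `qr_tbp_sketch`, `qr_tbp_type`, `get_q` and `get_r` are illustrative.

```fortran
! Hypothetical sketch of type-bound Q()/R() accessors (assumes m >= n).
module qr_tbp_sketch
   use, intrinsic :: iso_fortran_env, only: dp => real64
   implicit none
   private
   public :: qr_tbp_type

   !> geqrf-style storage (not geqr's): R on and above the diagonal,
   !> Householder vectors (unit first entry implied) below it.
   !> Components are public only so the sketch can be exercised directly.
   type :: qr_tbp_type
      real(dp), allocatable :: data(:, :)
      real(dp), allocatable :: tau(:)
   contains
      procedure :: Q => get_q   ! thin Q, size(data,1) x size(data,2)
      procedure :: R => get_r   ! R, size(data,2) x size(data,2)
   end type qr_tbp_type

contains

   pure function get_r(self) result(R)
      class(qr_tbp_type), intent(in) :: self
      real(dp), allocatable :: R(:, :)
      integer :: j, n
      n = size(self%data, 2)
      allocate (R(n, n), source=0.0_dp)
      do j = 1, n
         R(1:j, j) = self%data(1:j, j)   ! copy the upper triangle only
      end do
   end function get_r

   pure function get_q(self) result(Q)
      class(qr_tbp_type), intent(in) :: self
      real(dp), allocatable :: Q(:, :)
      real(dp) :: s
      integer :: i, j, m, n
      m = size(self%data, 1); n = size(self%data, 2)
      allocate (Q(m, n), source=0.0_dp)
      do j = 1, n
         Q(j, j) = 1.0_dp               ! first n columns of the identity
      end do
      ! Q = H(1) ... H(n) I(:, 1:n): apply the reflectors from last to first.
      do i = n, 1, -1
         do j = 1, n
            s = self%tau(i) * (Q(i, j) + dot_product(self%data(i + 1:m, i), Q(i + 1:m, j)))
            Q(i, j) = Q(i, j) - s
            Q(i + 1:m, j) = Q(i + 1:m, j) - s * self%data(i + 1:m, i)
         end do
      end do
   end function get_q
end module qr_tbp_sketch
```

With such bindings, `Q = F%Q()` and `R = F%R()` behave as in the snippet above, and `matmul(F%Q(), F%R())` should reproduce A up to rounding. Since `Q()` materializes a dense matrix, it is mainly a convenience. Products with Q are cheaper through the compact form.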

Note that I am not storing the actual Q matrix inside the derived type. LAPACK has its own efficient compact format for QR, together with specialized subroutines (e.g. `gemqr`) that use this format directly for matrix-vector and matrix-matrix operations.

Prior Art

Julia defines similar derived types (`struct`s) to store matrix factorizations. As far as I know, there is no equivalent in Python. I do not know the C and C++ ecosystems well enough to give meaningful references here.

Additional Information

Ping: @perazz, @jvdp1, @jalvesz

Metadata

Assignees: No one assigned
Labels: idea (Proposition of an idea and opening an issue to discuss it)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
