Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 106 additions & 103 deletions src/HssMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,117 +5,120 @@
__precompile__()
module HssMatrices

# dependencies, trying to keep this list to a minimum if possible
using LinearAlgebra
using SparseArrays # introduce custom constructors from sparse matrices
using AbstractTrees
using Plots # in the future, move to RecipesBase instead
using LowRankApprox # optional for now , to replace prrqr with an efficient implementation
# dependencies, trying to keep this list to a minimum if possible
using LinearAlgebra
using SparseArrays # introduce custom constructors from sparse matrices
using AbstractTrees
using Plots # in the future, move to RecipesBase instead
using LowRankApprox # optional for now , to replace prrqr with an efficient implementation

# load BLAS/LAPACK routines only used within the library# load efficient BLAS and LAPACK routines for factorizations
import LinearAlgebra.LAPACK.geqlf!
import LinearAlgebra.LAPACK.gelqf!
import LinearAlgebra.LAPACK.ormql!
import LinearAlgebra.LAPACK.ormlq!
import LinearAlgebra.BLAS.ger!
import LinearAlgebra.BLAS.trsm
import LinearAlgebra.ishermitian
# load BLAS/LAPACK routines only used within the library# load efficient BLAS and LAPACK routines for factorizations
import LinearAlgebra.LAPACK.geqlf!
import LinearAlgebra.LAPACK.gelqf!
import LinearAlgebra.LAPACK.ormql!
import LinearAlgebra.LAPACK.ormlq!
import LinearAlgebra.BLAS.ger!
import LinearAlgebra.BLAS.trsm
import LinearAlgebra.ishermitian

# using InvertedIndices, DataStructures
import Base: +, -, *, \, /, ^, ==, !=, copy, size, getindex, adjoint, transpose, convert, Matrix
import LinearAlgebra: ldiv!, rdiv!, mul!
# using InvertedIndices, DataStructures
import Base: +, -, *, \, /, ^, ==, !=, copy, size, getindex, adjoint, transpose, convert, Matrix
import LinearAlgebra: ldiv!, rdiv!, mul!

# HssMatrices.jl
export HssOptions
# binarytree.jl
export BinaryNode, leftchild, rightchild, isleaf, isbranch, depth, descendants, compatible, nleaves, prune_leaves!
# clustertree.jl
export bisection_cluster, cluster
# linearmap.jl
export LinearMap, HermitianLinearMap
# hssmatrix.jl
export HssMatrix, hss, isleaf, isbranch, ishss, hssrank, full, checkdims, prune_leaves!, blkdiagm
# compression.jl
export compress, randcompress, randcompress_adaptive, recompress!, hss_blkdiag
# generators.jl
export generators, orthonormalize_generators!
# matmul.jl
# ulvfactor.jl
export ulvfactsolve
# ulvdivide.jl
export ldiv!, rdiv!
# constructors.jl
export lowrank2hss
# visualization.jl
export plotranks, pcolor
# HssMatrices.jl
export HssOptions
# binarytree.jl
export BinaryNode, leftchild, rightchild, isleaf, isbranch, depth, descendants, compatible, nleaves, prune_leaves!
# clustertree.jl
export bisection_cluster, cluster
# linearmap.jl
export LinearMap, HermitianLinearMap
# hssmatrix.jl
export HssMatrix, hss, isleaf, isbranch, ishss, hssrank, full, checkdims, prune_leaves!, blkdiagm
# compression.jl
export compress, randcompress, randcompress_adaptive, recompress!, hss_blkdiag
# generators.jl
export generators, orthonormalize_generators!
# matmul.jl
# ulvfactor.jl
export ulvfactsolve
# ulvdivide.jl
export ldiv!, rdiv!
# constructors.jl
export lowrank2hss
# visualization.jl
export plotranks, pcolor

mutable struct HssOptions
atol::Float64
rtol::Float64
leafsize::Int
noversampling::Int
stepsize::Int
verbose::Bool
end

# set default values
atol = 1e-9
rtol = 1e-9
leafsize = 64
noversampling = 10
stepsize = 20
verbose = false
mutable struct HssOptions
atol::Float64
rtol::Float64
leafsize::Int
noversampling::Int
stepsize::Int
verbose::Bool
multithreaded::Bool
end

# set default values
atol = 1e-9
rtol = 1e-9
leafsize = 64
noversampling = 10
stepsize = 20
verbose = false
multithreaded = false

# struct to pass around with options
function HssOptions(::Type{T}; args...) where T
opts = HssOptions(atol, rtol, leafsize, noversampling, stepsize, verbose)
for (key, value) in args
setfield!(opts, key, value)
end
opts
# struct to pass around with options
function HssOptions(::Type{T}; args...) where {T}
opts = HssOptions(atol, rtol, leafsize, noversampling, stepsize, verbose, multithreaded)
for (key, value) in args
setfield!(opts, key, value)
end
HssOptions(; args...) = HssOptions(Float64; args...)

function copy(opts::HssOptions; args...)
opts_ = HssOptions()
for field in fieldnames(typeof(opts))
setfield!(opts_, field, getfield(opts, field))
end
for (key, value) in args
setfield!(opts_, key, value)
end
opts_
opts
end
HssOptions(; args...) = HssOptions(Float64; args...)

function copy(opts::HssOptions; args...)
opts_ = HssOptions()
for field in fieldnames(typeof(opts))
setfield!(opts_, field, getfield(opts, field))
end

function chkopts!(opts::HssOptions)
opts.atol ≥ 0. || throw(ArgumentError("atol"))
opts.rtol ≥ 0. || throw(ArgumentError("rtol"))
opts.leafsize ≥ 1 || throw(ArgumentError("leafsize"))
opts.stepsize ≥ 1 || throw(ArgumentError("stepsize"))
opts.noversampling ≥ 1 || throw(ArgumentError("noversampling"))
for (key, value) in args
setfield!(opts_, key, value)
end
opts_
end

function setopts(opts=HssOptions(Float64); args...)
opts = copy(opts; args...)
chkopts!(opts)
global atol = opts.atol
global rtol = opts.rtol
global leafsize = opts.leafsize
global noversampling = opts.noversampling
global stepsize = opts.stepsize
global verbose = opts.verbose
return opts
end
function chkopts!(opts::HssOptions)
opts.atol ≥ 0. || throw(ArgumentError("atol"))
opts.rtol ≥ 0. || throw(ArgumentError("rtol"))
opts.leafsize ≥ 1 || throw(ArgumentError("leafsize"))
opts.stepsize ≥ 1 || throw(ArgumentError("stepsize"))
opts.noversampling ≥ 1 || throw(ArgumentError("noversampling"))
end

include("binarytree.jl")
include("clustertree.jl")
include("linearmap.jl")
include("hssmatrix.jl")
include("compression.jl")
include("generators.jl")
include("matmul.jl")
include("ulvfactor.jl")
include("ulvdivide.jl")
include("constructors.jl")
include("visualization.jl")
function setopts(opts=HssOptions(Float64); args...)
opts = copy(opts; args...)
chkopts!(opts)
global atol = opts.atol
global rtol = opts.rtol
global leafsize = opts.leafsize
global noversampling = opts.noversampling
global stepsize = opts.stepsize
global verbose = opts.verbose
global multithreaded = opts.multithreaded
return opts
end
include("recursiontools.jl")
include("binarytree.jl")
include("clustertree.jl")
include("linearmap.jl")
include("hssmatrix.jl")
include("compression.jl")
include("generators.jl")
include("matmul.jl")
include("ulvfactor.jl")
include("ulvdivide.jl")
include("constructors.jl")
include("visualization.jl")
end
48 changes: 28 additions & 20 deletions src/compression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ function _compress!(A::Matrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, rcl::Cluster
R, Brow = _compress_block!(Brow, atol, rtol)
R1 = R[1:rm1, :]
R2 = R[rm1+1:end, :]

X = copy(Bcol')
W, Bcol = _compress_block!(copy(Bcol'), atol, rtol);
Bcol = copy(Bcol')
W1 = W[1:rn1, :]
W2 = W[rn1+1:end, :]

hssA = HssMatrix(A11, A22, B12, B21, R1, W1, R2, W2, rootnode)
else
hssA = HssMatrix(A11, A22, B12, B21, rootnode)
Expand All @@ -148,7 +148,9 @@ julia> hssA = recompress!(hssA, atol=1e-3, rtol=1e-3)
function recompress!(hssA::HssMatrix{T}, opts::HssOptions=HssOptions(T); args...) where T
opts = copy(opts; args...)
chkopts!(opts)
rtol = opts.rtol; atol = opts.atol;
rtol = opts.rtol
atol = opts.atol
multithreaded = opts.multithreaded

if isleaf(hssA); return hssA; end

Expand Down Expand Up @@ -192,17 +194,18 @@ function recompress!(hssA::HssMatrix{T}, opts::HssOptions=HssOptions(T); args...
end

# call recompression recursively
context = RecursionTools.RecursionContext(multithreaded)
if isbranch(hssA.A11)
_recompress!(hssA.A11, hssA.B12, copy(hssA.B21'), atol, rtol)
_recompress!(hssA.A11, hssA.B12, copy(hssA.B21'), atol, rtol, context)
end
if isbranch(hssA.A22)
_recompress!(hssA.A22, hssA.B21, copy(hssA.B12'), atol, rtol)
_recompress!(hssA.A22, hssA.B21, copy(hssA.B12'), atol, rtol, context)
end

return hssA
end

function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol, rtol) where T
function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol, rtol, context) where {T}
# compress B12
Brow1 = [hssA.B12 hssA.R1*Brow]
Bcol2 = [hssA.B12' hssA.W2*Bcol]
Expand Down Expand Up @@ -245,20 +248,25 @@ function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol
end

# call recompression recursively
if isbranch(hssA.A11)
Brow1 = hcat(hssA.B12, S2[:,rn2+1:end])
Bcol1 = hcat(hssA.B21', T2[:,rm2+1:end])
_recompress!(hssA.A11, Brow1, Bcol1, atol, rtol)
end
newContext = RecursionTools.updateAfterSpawn(context)
task = RecursionTools.spawn(_recursive_compression_A11, (hssA, S2, T2, rn2, rm2, atol, rtol, newContext), context)
if isbranch(hssA.A22)
Brow2 = hcat(hssA.B21, S1[:,rn1+1:end])
Bcol2 = hcat(hssA.B12', T1[:,rm1+1:end])
_recompress!(hssA.A22, Brow2, Bcol2, atol, rtol)
Brow2 = hcat(hssA.B21, S1[:, rn1+1:end])
Bcol2 = hcat(hssA.B12', T1[:, rm1+1:end])
_recompress!(hssA.A22, Brow2, Bcol2, atol, rtol, newContext)
end

RecursionTools.wait(task)
return hssA
end

function _recursive_compression_A11(hssA, S2, T2, rn2, rm2, atol, rtol, context)
if isbranch(hssA.A11)
Brow1 = hcat(hssA.B12, S2[:, rn2+1:end])
Bcol1 = hcat(hssA.B21', T2[:, rm2+1:end])
_recompress!(hssA.A11, Brow1, Bcol1, atol, rtol, context)
end
end

"""
randcompress(A, rcl, ccl, kest; args...)

Expand All @@ -285,8 +293,8 @@ function randcompress(A::AbstractMatOrLinOp{T}, rcl::ClusterTree, ccl::ClusterTr

# compute initial sampling
k = kest; r = opts.noversampling;
Ωcol = randn(n, k+r)
Ωrow = randn(m, k+r)
Ωcol = randn(T,n, k+r)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the inclusion of T here correct? I am not sure how the sampling matrix in randomized SVD needs to be for complex numbers?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I will check that.

Ωrow = randn(T,m, k+r)
Scol = A*Ωcol # this should invoke the magic of the linearoperator.jl type
Srow = A'*Ωrow
hssA, _, _, _, _, _, _, _, _ = _randcompress!(hssA, A, Scol, Srow, Ωcol, Ωrow, 0, 0, opts.atol, opts.rtol; rootnode=true)
Expand Down Expand Up @@ -379,7 +387,7 @@ function _randcompress!(hssA::HssMatrix, A::AbstractMatOrLinOp, Scol::Matrix, Sr
m1, n1 = hssA.sz1; m2, n2 = hssA.sz2
hssA.A11, Scol1, Srow1, Ωcol1, Ωrow1, Jcol1, Jrow1, U1, V1 = _randcompress!(hssA.A11, A, Scol[1:m1, :], Srow[1:n1, :], Ωcol[1:n1, :], Ωrow[1:m1, :], ro, co, atol, rtol)
hssA.A22, Scol2, Srow2, Ωcol2, Ωrow2, Jcol2, Jrow2, U2, V2 = _randcompress!(hssA.A22, A, Scol[m1+1:end, :], Srow[n1+1:end, :], Ωcol[n1+1:end, :], Ωrow[m1+1:end, :], ro+m1, co+n1, atol, rtol)

# update the sampling matrix based on the extracted generators
Ωcol2 = V2' * Ωcol2
Ωcol1 = V1' * Ωcol1
Expand Down Expand Up @@ -416,7 +424,7 @@ function _randcompress!(hssA::HssMatrix, A::AbstractMatOrLinOp, Scol::Matrix, Sr
Scol = Scol[Jcolloc, :]
Jcol = Jcol[Jcolloc]
U = [hssA.R1; hssA.R2]

# take care of the rows
Xrow, Jrowloc = _interpolate(Srow', atol, rtol)
hssA.W1 = Xrow[:, 1:size(Srow1, 1)]'
Expand Down Expand Up @@ -447,7 +455,7 @@ end
function hss_blkdiag(A::AbstractMatOrLinOp{T}, rcl::ClusterTree, ccl::ClusterTree; rootnode=true) where T
m = length(rcl.data); n = length(ccl.data)
if isleaf(rcl) # only check row cluster as we have already checked cluster equality

D = convert(Matrix{T}, A[rcl.data, ccl.data])
if rootnode
hssA = HssMatrix(D)
Expand Down
Loading
Loading