Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
- uses: actions/cache@v4
env:
cache-name: cache-artifacts
with:
Expand Down
15 changes: 9 additions & 6 deletions src/HssMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,28 @@ module HssMatrices
noversampling::Int
stepsize::Int
verbose::Bool
multithreaded::Bool
end

# set default values
atol = 1e-9
rtol = 1e-9
leafsize = 64
noversampling = 10
stepsize = 20
verbose = false
multithreaded = false

# struct to pass around with options
function HssOptions(::Type{T}; args...) where T
opts = HssOptions(atol, rtol, leafsize, noversampling, stepsize, verbose)
function HssOptions(::Type{T}; args...) where {T}
opts = HssOptions(atol, rtol, leafsize, noversampling, stepsize, verbose, multithreaded)
for (key, value) in args
setfield!(opts, key, value)
end
opts
end
HssOptions(; args...) = HssOptions(Float64; args...)

function copy(opts::HssOptions; args...)
opts_ = HssOptions()
for field in fieldnames(typeof(opts))
Expand All @@ -86,7 +88,7 @@ module HssMatrices
end
opts_
end

function chkopts!(opts::HssOptions)
opts.atol ≥ 0. || throw(ArgumentError("atol"))
opts.rtol ≥ 0. || throw(ArgumentError("rtol"))
Expand All @@ -104,9 +106,10 @@ module HssMatrices
global noversampling = opts.noversampling
global stepsize = opts.stepsize
global verbose = opts.verbose
global multithreaded = opts.multithreaded
return opts
end

include("recursiontools.jl")
include("binarytree.jl")
include("clustertree.jl")
include("linearmap.jl")
Expand Down
44 changes: 26 additions & 18 deletions src/compression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ function _compress!(A::Matrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, rcl::Cluster
R, Brow = _compress_block!(Brow, atol, rtol)
R1 = R[1:rm1, :]
R2 = R[rm1+1:end, :]

X = copy(Bcol')
W, Bcol = _compress_block!(copy(Bcol'), atol, rtol);
Bcol = copy(Bcol')
W1 = W[1:rn1, :]
W2 = W[rn1+1:end, :]

hssA = HssMatrix(A11, A22, B12, B21, R1, W1, R2, W2, rootnode)
else
hssA = HssMatrix(A11, A22, B12, B21, rootnode)
Expand All @@ -148,7 +148,9 @@ julia> hssA = recompress!(hssA, atol=1e-3, rtol=1e-3)
function recompress!(hssA::HssMatrix{T}, opts::HssOptions=HssOptions(T); args...) where T
opts = copy(opts; args...)
chkopts!(opts)
rtol = opts.rtol; atol = opts.atol;
rtol = opts.rtol
atol = opts.atol
multithreaded = opts.multithreaded

if isleaf(hssA); return hssA; end

Expand Down Expand Up @@ -192,17 +194,18 @@ function recompress!(hssA::HssMatrix{T}, opts::HssOptions=HssOptions(T); args...
end

# call recompression recursively
context = RecursionTools.RecursionContext(multithreaded)
if isbranch(hssA.A11)
_recompress!(hssA.A11, hssA.B12, copy(hssA.B21'), atol, rtol)
_recompress!(hssA.A11, hssA.B12, copy(hssA.B21'), atol, rtol, context)
end
if isbranch(hssA.A22)
_recompress!(hssA.A22, hssA.B21, copy(hssA.B12'), atol, rtol)
_recompress!(hssA.A22, hssA.B21, copy(hssA.B12'), atol, rtol, context)
end

return hssA
end

function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol, rtol) where T
function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol, rtol, context) where {T}
# compress B12
Brow1 = [hssA.B12 hssA.R1*Brow]
Bcol2 = [hssA.B12' hssA.W2*Bcol]
Expand Down Expand Up @@ -245,20 +248,25 @@ function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol
end

# call recompression recursively
if isbranch(hssA.A11)
Brow1 = hcat(hssA.B12, S2[:,rn2+1:end])
Bcol1 = hcat(hssA.B21', T2[:,rm2+1:end])
_recompress!(hssA.A11, Brow1, Bcol1, atol, rtol)
end
newContext = RecursionTools.updateAfterSpawn(context)
task = RecursionTools.spawn(_recursive_compression_A11, (hssA, S2, T2, rn2, rm2, atol, rtol, newContext), context)
if isbranch(hssA.A22)
Brow2 = hcat(hssA.B21, S1[:,rn1+1:end])
Bcol2 = hcat(hssA.B12', T1[:,rm1+1:end])
_recompress!(hssA.A22, Brow2, Bcol2, atol, rtol)
Brow2 = hcat(hssA.B21, S1[:, rn1+1:end])
Bcol2 = hcat(hssA.B12', T1[:, rm1+1:end])
_recompress!(hssA.A22, Brow2, Bcol2, atol, rtol, newContext)
end

RecursionTools.wait(task)
return hssA
end

function _recursive_compression_A11(hssA, S2, T2, rn2, rm2, atol, rtol, context)
if isbranch(hssA.A11)
Brow1 = hcat(hssA.B12, S2[:, rn2+1:end])
Bcol1 = hcat(hssA.B21', T2[:, rm2+1:end])
_recompress!(hssA.A11, Brow1, Bcol1, atol, rtol, context)
end
end

"""
randcompress(A, rcl, ccl, kest; args...)

Expand Down Expand Up @@ -379,7 +387,7 @@ function _randcompress!(hssA::HssMatrix, A::AbstractMatOrLinOp, Scol::Matrix, Sr
m1, n1 = hssA.sz1; m2, n2 = hssA.sz2
hssA.A11, Scol1, Srow1, Ωcol1, Ωrow1, Jcol1, Jrow1, U1, V1 = _randcompress!(hssA.A11, A, Scol[1:m1, :], Srow[1:n1, :], Ωcol[1:n1, :], Ωrow[1:m1, :], ro, co, atol, rtol)
hssA.A22, Scol2, Srow2, Ωcol2, Ωrow2, Jcol2, Jrow2, U2, V2 = _randcompress!(hssA.A22, A, Scol[m1+1:end, :], Srow[n1+1:end, :], Ωcol[n1+1:end, :], Ωrow[m1+1:end, :], ro+m1, co+n1, atol, rtol)

# update the sampling matrix based on the extracted generators
Ωcol2 = V2' * Ωcol2
Ωcol1 = V1' * Ωcol1
Expand Down Expand Up @@ -416,7 +424,7 @@ function _randcompress!(hssA::HssMatrix, A::AbstractMatOrLinOp, Scol::Matrix, Sr
Scol = Scol[Jcolloc, :]
Jcol = Jcol[Jcolloc]
U = [hssA.R1; hssA.R2]

# take care of the rows
Xrow, Jrowloc = _interpolate(Srow', atol, rtol)
hssA.W1 = Xrow[:, 1:size(Srow1, 1)]'
Expand Down Expand Up @@ -447,7 +455,7 @@ end
function hss_blkdiag(A::AbstractMatOrLinOp{T}, rcl::ClusterTree, ccl::ClusterTree; rootnode=true) where T
m = length(rcl.data); n = length(ccl.data)
if isleaf(rcl) # only check row cluster as we have already checked cluster equality

D = convert(Matrix{T}, A[rcl.data, ccl.data])
if rootnode
hssA = HssMatrix(D)
Expand Down
49 changes: 30 additions & 19 deletions src/generators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,33 +61,21 @@ end
#offdiag(hssA::HssNode, ::Val{:upper}) = generators(hssA.A11)[1]*hssA.B12*generators(hssA.A11)[2]'

## orthogonalize generators
function orthonormalize_generators!(hssA::HssMatrix{T}) where T
function orthonormalize_generators!(hssA::HssMatrix{T}; multithreaded::Bool = HssOptions().multithreaded) where T
return orthonormalize_generators!(hssA, RecursionTools.RecursionContext(multithreaded) )
end
function orthonormalize_generators!(hssA::HssMatrix{T}, context::RecursionTools.RecursionContext) where T
if isleaf(hssA)
U1 = pqrfact!(hssA.A11.U, sketch=:none); hssA.U = Matrix(U1.Q)
V1 = pqrfact!(hssA.A11.V, sketch=:none); hssA.V = Matrix(V1.Q)
else
if isleaf(hssA.A11)
U1 = pqrfact!(hssA.A11.U, sketch=:none); hssA.A11.U = Matrix(U1.Q)
V1 = pqrfact!(hssA.A11.V, sketch=:none); hssA.A11.V = Matrix(V1.Q)
else
hssA.A11 = orthonormalize_generators!(hssA.A11)
U1 = pqrfact!([hssA.A11.R1; hssA.A11.R2], sketch=:none)
V1 = pqrfact!([hssA.A11.W1; hssA.A11.W2], sketch=:none)
rm1 = size(hssA.A11.R1, 1)
R = Matrix(U1.Q)
hssA.A11.R1 = R[1:rm1,:]
hssA.A11.R2 = R[rm1+1:end,:]
rn1 = size(hssA.A11.W1, 1)
W = Matrix(V1.Q)
hssA.A11.W1 = W[1:rn1,:]
hssA.A11.W2 = W[rn1+1:end,:]
end

newContext = RecursionTools.updateAfterSpawn(context)
task = RecursionTools.spawn(_orthogonalise_A11!, (hssA,newContext), context)
if isleaf(hssA.A22)
U2 = pqrfact!(hssA.A22.U, sketch=:none); hssA.A22.U = Matrix(U2.Q)
V2 = pqrfact!(hssA.A22.V, sketch=:none); hssA.A22.V = Matrix(V2.Q)
else
hssA.A22 = orthonormalize_generators!(hssA.A22)
hssA.A22 = orthonormalize_generators!(hssA.A22, newContext)
U2 = pqrfact!([hssA.A22.R1; hssA.A22.R2], sketch=:none)
V2 = pqrfact!([hssA.A22.W1; hssA.A22.W2], sketch=:none)
rm1 = size(hssA.A22.R1, 1)
Expand All @@ -99,6 +87,7 @@ function orthonormalize_generators!(hssA::HssMatrix{T}) where T
hssA.A22.W1 = W[1:rn1,:]
hssA.A22.W2 = W[rn1+1:end,:]
end
U1, V1 = RecursionTools.fetch(task)
ipU1 = invperm(U1.p); ipV1 = invperm(V1.p)
ipU2 = invperm(U2.p); ipV2 = invperm(V2.p)

Expand All @@ -111,4 +100,26 @@ function orthonormalize_generators!(hssA::HssMatrix{T}) where T
hssA.W2 = V2.R[:, ipV2]*hssA.W2
end
return hssA
end

function _orthogonalise_A11!(hssA, context)
if isleaf(hssA.A11)
U1 = pqrfact!(hssA.A11.U, sketch=:none)
hssA.A11.U = Matrix(U1.Q)
V1 = pqrfact!(hssA.A11.V, sketch=:none)
hssA.A11.V = Matrix(V1.Q)
else
hssA.A11 = orthonormalize_generators!(hssA.A11, context)
U1 = pqrfact!([hssA.A11.R1; hssA.A11.R2], sketch=:none)
V1 = pqrfact!([hssA.A11.W1; hssA.A11.W2], sketch=:none)
rm1 = size(hssA.A11.R1, 1)
R = Matrix(U1.Q)
hssA.A11.R1 = R[1:rm1, :]
hssA.A11.R2 = R[rm1+1:end, :]
rn1 = size(hssA.A11.W1, 1)
W = Matrix(V1.Q)
hssA.A11.W1 = W[1:rn1, :]
hssA.A11.W2 = W[rn1+1:end, :]
end
return U1, V1
end
8 changes: 6 additions & 2 deletions src/hssmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,18 @@ for op in (:+,:-)
$op(a::Bool, hssA::HssMatrix{Bool}) = error("Not callable")
$op(L::HssMatrix{Bool}, a::Bool) = error("Not callable")

function $op(hssA::HssMatrix, hssB::HssMatrix)
function $op(hssA::HssMatrix, hssB::HssMatrix,
context::RecursionTools.RecursionContext = RecursionTools.RecursionContext(HssOptions().multithreaded))
size(hssA) == size(hssB) || throw(DimensionMismatch("A has dimensions $(size(hssA)) but B has dimensions $(size(hssB))"))
if isleaf(hssA) && isleaf(hssB)
return HssMatrix($op(hssA.D, hssB.D), [hssA.U hssB.U], [hssA.V hssB.V], isroot(hssA) && isroot(hssB))
elseif isbranch(hssA) && isbranch(hssB)
hssA.sz1 == hssB.sz1 || throw(DimensionMismatch("A11 has dimensions $(hssA.sz1) but B11 has dimensions $(hssB.sz1)"))
hssA.sz2 == hssB.sz2 || throw(DimensionMismatch("A22 has dimensions $(hssA.sz2) but B22 has dimensions $(hssA.sz2)"))
return HssMatrix($op(hssA.A11, hssB.A11), $op(hssA.A22, hssB.A22), blkdiagm(hssA.B12, $op(hssB.B12)), blkdiagm(hssA.B21, $op(hssB.B21)),blkdiagm(hssA.R1, hssB.R1), blkdiagm(hssA.W1, hssB.W1), blkdiagm(hssA.R2, hssB.R2), blkdiagm(hssA.W2, hssB.W2), isroot(hssA) && isroot(hssB))
newContext = RecursionTools.updateAfterSpawn(context)
taskA11 = RecursionTools.spawn($op, (hssA.A11, hssB.A11, newContext), context)
taskA22 = RecursionTools.spawn($op, (hssA.A22, hssB.A22, newContext), context)
return HssMatrix(RecursionTools.fetch(taskA11), RecursionTools.fetch(taskA22), blkdiagm(hssA.B12, $op(hssB.B12)), blkdiagm(hssA.B21, $op(hssB.B21)),blkdiagm(hssA.R1, hssB.R1), blkdiagm(hssA.W1, hssB.W1), blkdiagm(hssA.R2, hssB.R2), blkdiagm(hssA.W2, hssB.W2), isroot(hssA) && isroot(hssB))
else
throw(ArgumentError("Cannot add leaves and branches."))
end
Expand Down
51 changes: 32 additions & 19 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,37 @@
*(hssA::HssMatrix, x::AbstractVector) = reshape(hssA * reshape(x, length(x), 1), length(x))

# implement low-level mul! routines
function mul!(C::AbstractMatrix, hssA::HssMatrix, B::AbstractMatrix, α::Number, β::Number)
function mul!(C::AbstractMatrix, hssA::HssMatrix, B::AbstractMatrix, α::Number, β::Number; multithreaded::Bool = HssOptions().multithreaded)
size(hssA,2) == size(B,1) || throw(DimensionMismatch("First dimension of B does not match second dimension of A. Expected $(size(A, 2)), got $(size(B, 1))"))
size(C) == (size(hssA,1), size(B,2)) || throw(DimensionMismatch("Dimensions of C don't match up with A and B."))
if isleaf(hssA)
return mul!(C, hssA.D, B, α, β)
else
context = RecursionTools.RecursionContext(multithreaded)
hssA = rooted(hssA) # this is to avoid top-generators getting in the way if this is called on a sub-block of the HSS matrix
Z = _matmatup(hssA, B) # saves intermediate steps of multiplication in a binary tree structure
return _matmatdown!(C, hssA, B, Z, nothing, α, β)
Z = _matmatup(hssA, B, context) # saves intermediate steps of multiplication in a binary tree structure
return _matmatdown!(C, hssA, B, Z, nothing, α, β, context)
end
end

## auxiliary functions for the fast multiplication algorithm
# post-ordered step of mat-vec
function _matmatup(hssA::HssMatrix, B::AbstractMatrix)
function _matmatup(hssA::HssMatrix, B::AbstractMatrix, context)
if isleaf(hssA)
return BinaryNode(hssA.V' * B)
else
n1 = hssA.sz1[2]
Z1 = _matmatup(hssA.A11, B[1:n1,:])
Z2 = _matmatup(hssA.A22, B[n1+1:end,:])
newContext = RecursionTools.updateAfterSpawn(context)
task = RecursionTools.spawn( _matmatup, (hssA.A11, B[1:n1,:],newContext), context)
Z2 = _matmatup(hssA.A22, B[n1+1:end,:],newContext)
Z1 = RecursionTools.fetch(task)
Z = BinaryNode(hssA.W1'*Z1.data .+ hssA.W2'*Z2.data, Z1, Z2)
return Z
end
end

function _matmatdown!(C::AbstractMatrix{T}, hssA::HssMatrix{T}, B::AbstractMatrix{T}, Z::BinaryNode{AT}, F::Union{AbstractMatrix{T}, Nothing}, α::Number, β::Number) where {T, AT<:AbstractArray{T}}
function _matmatdown!(C::AbstractMatrix{T}, hssA::HssMatrix{T}, B::AbstractMatrix{T}, Z::BinaryNode{AT}, F::Union{AbstractMatrix{T}, Nothing},
α::Number, β::Number, context) where {T, AT<:AbstractArray{T}}
if isleaf(hssA)
mul!(C, hssA.D, B , α, β)
if !isnothing(F); mul!(C, hssA.U, F , α, 1.); end
Expand All @@ -55,45 +59,52 @@ function _matmatdown!(C::AbstractMatrix{T}, hssA::HssMatrix{T}, B::AbstractMatri
F1 = hssA.B12 * Z.right.data
F2 = hssA.B21 * Z.left.data
end
_matmatdown!(@view(C[1:m1,:]), hssA.A11, @view(B[1:n1,:]), Z.left, F1, α, β)
_matmatdown!(@view(C[m1+1:end,:]), hssA.A22, @view(B[n1+1:end,:]), Z.right, F2, α, β)
newContext = RecursionTools.updateAfterSpawn(context)
task = RecursionTools.spawn(_matmatdown!, (@view(C[1:m1,:]), hssA.A11, @view(B[1:n1,:]), Z.left, F1, α, β, newContext),context)
_matmatdown!(@view(C[m1+1:end,:]), hssA.A22, @view(B[n1+1:end,:]), Z.right, F2, α, β, context)
RecursionTools.wait(task)
return C
end
end

## multiplication of two HSS matrices
function *(hssA::HssMatrix, hssB::HssMatrix)
function *(hssA::HssMatrix, hssB::HssMatrix, multithreaded::Bool = HssOptions().multithreaded)
if isleaf(hssA) && isleaf(hssA)
hssC = HssMatrix(hssA.D*hssB.D, hssA.U, hssB.V, true)
elseif isbranch(hssA) && isbranch(hssA)
hssA = rooted(hssA); hssB = rooted(hssB)
# implememnt cluster equality checks
#if cluster(hssA,2) != cluster(hssB,1); throw(DimensionMismatch("clusters of hssA and hssB must be matching")) end
Z = _matmatup(hssA, hssB) # saves intermediate steps of multiplication in a binary tree structure
context = RecursionTools.RecursionContext(multithreaded)
Z = _matmatup(hssA, hssB, context ) # saves intermediate steps of multiplication in a binary tree structure
F1 = hssA.B12 * Z.right.data * hssB.B21
F2 = hssA.B21 * Z.left.data * hssB.B12
B12 = blkdiagm(hssA.B12, hssB.B12)
B21 = blkdiagm(hssA.B21, hssB.B21)
A11 = _matmatdown!(hssA.A11, hssB.A11, Z.left, F1)
A22 = _matmatdown!(hssA.A22, hssB.A22, Z.right, F2)
newContext = RecursionTools.updateAfterSpawn(context)
task = RecursionTools.spawn(_matmatdown!,(hssA.A22, hssB.A22, Z.right, F2, newContext),context)
A11 = _matmatdown!(hssA.A11, hssB.A11, Z.left, F1, newContext)
A22 = RecursionTools.fetch(task)
hssC = HssMatrix(A11, A22, B12, B21, true)
else
error("Clusters don't seem to match")
end
return hssC
end

function _matmatup(hssA::HssMatrix{T}, hssB::HssMatrix{T}) where T<:Number
function _matmatup(hssA::HssMatrix{T}, hssB::HssMatrix{T}, context) where T<:Number
if isleaf(hssA) & isleaf(hssB)
return BinaryNode{Matrix{T}}(hssA.V' * hssB.U)
elseif isbranch(hssA) && isbranch(hssB)
Z1 = _matmatup(hssA.A11, hssB.A11)
Z2 = _matmatup(hssA.A22, hssB.A22)
newContext = RecursionTools.updateAfterSpawn(context)
task = RecursionTools.spawn(_matmatup, (hssA.A22, hssB.A22, newContext), context)
Z1 = _matmatup(hssA.A11, hssB.A11, context)
Z2 = fetch(task)
return BinaryNode(hssA.W1' * Z1.data * hssB.R1 + hssA.W2' * Z2.data * hssB.R2, Z1, Z2)
end
end

function _matmatdown!(hssA::HssMatrix{T}, hssB::HssMatrix{T}, Z::BinaryNode{Matrix{T}}, F::Matrix{T}) where T
function _matmatdown!(hssA::HssMatrix{T}, hssB::HssMatrix{T}, Z::BinaryNode{Matrix{T}}, F::Matrix{T}, context) where T
if isleaf(hssA) & isleaf(hssB)
D = hssA.D * hssB.D + hssA.U * F * hssB.V'
U = [hssA.U hssA.D * hssB.U]
Expand All @@ -109,8 +120,10 @@ function _matmatdown!(hssA::HssMatrix{T}, hssB::HssMatrix{T}, Z::BinaryNode{Matr
W1 = [hssA.W1 zeros(T,size(hssA.W1,1), size(hssB.W1,2)); hssB.B21' * Z.right.data' * hssA.W2 hssB.W1];
R2 = [hssA.R2 hssA.B21 * Z.left.data * hssB.R1; zeros(T,size(hssB.R2,1), size(hssA.R2,2)) hssB.R2];
W2 = [hssA.W2 zeros(T,size(hssA.W2,1), size(hssB.W2,2)); hssB.B12' * Z.left.data' * hssA.W1 hssB.W2];
A11 = _matmatdown!(hssA.A11, hssB.A11, Z.left, F1)
A22 = _matmatdown!(hssA.A22, hssB.A22, Z.right, F2)
newContext = RecursionTools.updateAfterSpawn(context)
task = RecursionTools.spawn(_matmatdown!,(hssA.A11, hssB.A11, Z.left, F1, newContext),context)
A22 = _matmatdown!(hssA.A22, hssB.A22, Z.right, F2, newContext)
A11 = RecursionTools.fetch(task)
return HssMatrix(A11, A22, B12, B21, R1, W1, R2, W2, false)
else
error("Clusters don't seem to match")
Expand Down
Loading
Loading