diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d278723..bc9e5a4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,7 +29,7 @@ jobs: with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: diff --git a/src/HssMatrices.jl b/src/HssMatrices.jl index e235dbd..116a56d 100644 --- a/src/HssMatrices.jl +++ b/src/HssMatrices.jl @@ -56,8 +56,9 @@ module HssMatrices noversampling::Int stepsize::Int verbose::Bool + multithreaded::Bool end - + # set default values atol = 1e-9 rtol = 1e-9 @@ -65,17 +66,18 @@ module HssMatrices noversampling = 10 stepsize = 20 verbose = false + multithreaded = false # struct to pass around with options - function HssOptions(::Type{T}; args...) where T - opts = HssOptions(atol, rtol, leafsize, noversampling, stepsize, verbose) + function HssOptions(::Type{T}; args...) where {T} + opts = HssOptions(atol, rtol, leafsize, noversampling, stepsize, verbose, multithreaded) for (key, value) in args setfield!(opts, key, value) end opts end HssOptions(; args...) = HssOptions(Float64; args...) - + function copy(opts::HssOptions; args...) opts_ = HssOptions() for field in fieldnames(typeof(opts)) @@ -86,7 +88,7 @@ module HssMatrices end opts_ end - + function chkopts!(opts::HssOptions) opts.atol ≥ 0. || throw(ArgumentError("atol")) opts.rtol ≥ 0. || throw(ArgumentError("rtol")) @@ -104,9 +106,10 @@ module HssMatrices global noversampling = opts.noversampling global stepsize = opts.stepsize global verbose = opts.verbose + global multithreaded = opts.multithreaded return opts end - + include("recursiontools.jl") include("binarytree.jl") include("clustertree.jl") include("linearmap.jl") diff --git a/src/compression.jl b/src/compression.jl index 0723d7d..51324ad 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -118,13 +118,13 @@ function _compress!(A::Matrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, rcl::Cluster R, Brow = _compress_block!(Brow, atol, rtol) R1 = R[1:rm1, :] R2 = R[rm1+1:end, :] - + X = copy(Bcol') W, Bcol = _compress_block!(copy(Bcol'), atol, rtol); Bcol = copy(Bcol') W1 = W[1:rn1, :] W2 = W[rn1+1:end, :] - + hssA = HssMatrix(A11, A22, B12, B21, R1, W1, R2, W2, rootnode) else hssA = HssMatrix(A11, A22, B12, B21, rootnode) @@ -148,7 +148,9 @@ julia> hssA = recompress!(hssA, atol=1e-3, rtol=1e-3) function recompress!(hssA::HssMatrix{T}, opts::HssOptions=HssOptions(T); args...) where T opts = copy(opts; args...) chkopts!(opts) - rtol = opts.rtol; atol = opts.atol; + rtol = opts.rtol + atol = opts.atol + multithreaded = opts.multithreaded if isleaf(hssA); return hssA; end @@ -192,17 +194,18 @@ function recompress!(hssA::HssMatrix{T}, opts::HssOptions=HssOptions(T); args... end # call recompression recursively + context = RecursionTools.RecursionContext(multithreaded) if isbranch(hssA.A11) - _recompress!(hssA.A11, hssA.B12, copy(hssA.B21'), atol, rtol) + _recompress!(hssA.A11, hssA.B12, copy(hssA.B21'), atol, rtol, context) end if isbranch(hssA.A22) - _recompress!(hssA.A22, hssA.B21, copy(hssA.B12'), atol, rtol) + _recompress!(hssA.A22, hssA.B21, copy(hssA.B12'), atol, rtol, context) end return hssA end -function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol, rtol) where T +function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol, rtol, context) where {T} # compress B12 Brow1 = [hssA.B12 hssA.R1*Brow] Bcol2 = [hssA.B12' hssA.W2*Bcol] @@ -245,20 +248,25 @@ function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol end # call recompression recursively - if isbranch(hssA.A11) - Brow1 = hcat(hssA.B12, S2[:,rn2+1:end]) - Bcol1 = hcat(hssA.B21', T2[:,rm2+1:end]) - _recompress!(hssA.A11, Brow1, Bcol1, atol, rtol) - end + newContext = RecursionTools.updateAfterSpawn(context) + task = RecursionTools.spawn(_recursive_compression_A11, (hssA, S2, T2, rn2, rm2, atol, rtol, newContext), context) if isbranch(hssA.A22) - Brow2 = hcat(hssA.B21, S1[:,rn1+1:end]) - Bcol2 = hcat(hssA.B12', T1[:,rm1+1:end]) - _recompress!(hssA.A22, Brow2, Bcol2, atol, rtol) + Brow2 = hcat(hssA.B21, S1[:, rn1+1:end]) + Bcol2 = hcat(hssA.B12', T1[:, rm1+1:end]) + _recompress!(hssA.A22, Brow2, Bcol2, atol, rtol, newContext) end - + RecursionTools.wait(task) return hssA end +function _recursive_compression_A11(hssA, S2, T2, rn2, rm2, atol, rtol, context) + if isbranch(hssA.A11) + Brow1 = hcat(hssA.B12, S2[:, rn2+1:end]) + Bcol1 = hcat(hssA.B21', T2[:, rm2+1:end]) + _recompress!(hssA.A11, Brow1, Bcol1, atol, rtol, context) + end +end + """ randcompress(A, rcl, ccl, kest; args...) @@ -379,7 +387,7 @@ function _randcompress!(hssA::HssMatrix, A::AbstractMatOrLinOp, Scol::Matrix, Sr m1, n1 = hssA.sz1; m2, n2 = hssA.sz2 hssA.A11, Scol1, Srow1, Ωcol1, Ωrow1, Jcol1, Jrow1, U1, V1 = _randcompress!(hssA.A11, A, Scol[1:m1, :], Srow[1:n1, :], Ωcol[1:n1, :], Ωrow[1:m1, :], ro, co, atol, rtol) hssA.A22, Scol2, Srow2, Ωcol2, Ωrow2, Jcol2, Jrow2, U2, V2 = _randcompress!(hssA.A22, A, Scol[m1+1:end, :], Srow[n1+1:end, :], Ωcol[n1+1:end, :], Ωrow[m1+1:end, :], ro+m1, co+n1, atol, rtol) - + # update the sampling matrix based on the extracted generators Ωcol2 = V2' * Ωcol2 Ωcol1 = V1' * Ωcol1 @@ -416,7 +424,7 @@ function _randcompress!(hssA::HssMatrix, A::AbstractMatOrLinOp, Scol::Matrix, Sr Scol = Scol[Jcolloc, :] Jcol = Jcol[Jcolloc] U = [hssA.R1; hssA.R2] - + # take care of the rows Xrow, Jrowloc = _interpolate(Srow', atol, rtol) hssA.W1 = Xrow[:, 1:size(Srow1, 1)]' @@ -447,7 +455,7 @@ end function hss_blkdiag(A::AbstractMatOrLinOp{T}, rcl::ClusterTree, ccl::ClusterTree; rootnode=true) where T m = length(rcl.data); n = length(ccl.data) if isleaf(rcl) # only check row cluster as we have already checked cluster equality - + D = convert(Matrix{T}, A[rcl.data, ccl.data]) if rootnode hssA = HssMatrix(D) diff --git a/src/generators.jl b/src/generators.jl index 7b92f6c..66ee876 100644 --- a/src/generators.jl +++ b/src/generators.jl @@ -61,33 +61,21 @@ end #offdiag(hssA::HssNode, ::Val{:upper}) = generators(hssA.A11)[1]*hssA.B12*generators(hssA.A11)[2]' ## orthogonalize generators -function orthonormalize_generators!(hssA::HssMatrix{T}) where T +function orthonormalize_generators!(hssA::HssMatrix{T}; multithreaded::Bool = HssOptions().multithreaded) where T + return orthonormalize_generators!(hssA, RecursionTools.RecursionContext(multithreaded) ) +end +function orthonormalize_generators!(hssA::HssMatrix{T}, context::RecursionTools.RecursionContext) where T if isleaf(hssA) U1 = pqrfact!(hssA.A11.U, sketch=:none); hssA.U = Matrix(U1.Q) V1 = pqrfact!(hssA.A11.V, sketch=:none); hssA.V = Matrix(V1.Q) else - if isleaf(hssA.A11) - U1 = pqrfact!(hssA.A11.U, sketch=:none); hssA.A11.U = Matrix(U1.Q) - V1 = pqrfact!(hssA.A11.V, sketch=:none); hssA.A11.V = Matrix(V1.Q) - else - hssA.A11 = orthonormalize_generators!(hssA.A11) - U1 = pqrfact!([hssA.A11.R1; hssA.A11.R2], sketch=:none) - V1 = pqrfact!([hssA.A11.W1; hssA.A11.W2], sketch=:none) - rm1 = size(hssA.A11.R1, 1) - R = Matrix(U1.Q) - hssA.A11.R1 = R[1:rm1,:] - hssA.A11.R2 = R[rm1+1:end,:] - rn1 = size(hssA.A11.W1, 1) - W = Matrix(V1.Q) - hssA.A11.W1 = W[1:rn1,:] - hssA.A11.W2 = W[rn1+1:end,:] - end - + newContext = RecursionTools.updateAfterSpawn(context) + task = RecursionTools.spawn(_orthogonalise_A11!, (hssA,newContext), context) if isleaf(hssA.A22) U2 = pqrfact!(hssA.A22.U, sketch=:none); hssA.A22.U = Matrix(U2.Q) V2 = pqrfact!(hssA.A22.V, sketch=:none); hssA.A22.V = Matrix(V2.Q) else - hssA.A22 = orthonormalize_generators!(hssA.A22) + hssA.A22 = orthonormalize_generators!(hssA.A22, newContext) U2 = pqrfact!([hssA.A22.R1; hssA.A22.R2], sketch=:none) V2 = pqrfact!([hssA.A22.W1; hssA.A22.W2], sketch=:none) rm1 = size(hssA.A22.R1, 1) @@ -99,6 +87,7 @@ function orthonormalize_generators!(hssA::HssMatrix{T}) where T hssA.A22.W1 = W[1:rn1,:] hssA.A22.W2 = W[rn1+1:end,:] end + U1, V1 = RecursionTools.fetch(task) ipU1 = invperm(U1.p); ipV1 = invperm(V1.p) ipU2 = invperm(U2.p); ipV2 = invperm(V2.p) @@ -111,4 +100,26 @@ function orthonormalize_generators!(hssA::HssMatrix{T}) where T hssA.W2 = V2.R[:, ipV2]*hssA.W2 end return hssA +end + +function _orthogonalise_A11!(hssA, context) + if isleaf(hssA.A11) + U1 = pqrfact!(hssA.A11.U, sketch=:none) + hssA.A11.U = Matrix(U1.Q) + V1 = pqrfact!(hssA.A11.V, sketch=:none) + hssA.A11.V = Matrix(V1.Q) + else + hssA.A11 = orthonormalize_generators!(hssA.A11, context) + U1 = pqrfact!([hssA.A11.R1; hssA.A11.R2], sketch=:none) + V1 = pqrfact!([hssA.A11.W1; hssA.A11.W2], sketch=:none) + rm1 = size(hssA.A11.R1, 1) + R = Matrix(U1.Q) + hssA.A11.R1 = R[1:rm1, :] + hssA.A11.R2 = R[rm1+1:end, :] + rn1 = size(hssA.A11.W1, 1) + W = Matrix(V1.Q) + hssA.A11.W1 = W[1:rn1, :] + hssA.A11.W2 = W[rn1+1:end, :] + end + return U1, V1 end \ No newline at end of file diff --git a/src/hssmatrix.jl b/src/hssmatrix.jl index 7bed914..49a288d 100644 --- a/src/hssmatrix.jl +++ b/src/hssmatrix.jl @@ -215,14 +215,18 @@ for op in (:+,:-) $op(a::Bool, hssA::HssMatrix{Bool}) = error("Not callable") $op(L::HssMatrix{Bool}, a::Bool) = error("Not callable") - function $op(hssA::HssMatrix, hssB::HssMatrix) + function $op(hssA::HssMatrix, hssB::HssMatrix, + context::RecursionTools.RecursionContext = RecursionTools.RecursionContext(HssOptions().multithreaded)) size(hssA) == size(hssB) || throw(DimensionMismatch("A has dimensions $(size(hssA)) but B has dimensions $(size(hssB))")) if isleaf(hssA) && isleaf(hssB) return HssMatrix($op(hssA.D, hssB.D), [hssA.U hssB.U], [hssA.V hssB.V], isroot(hssA) && isroot(hssB)) elseif isbranch(hssA) && isbranch(hssB) hssA.sz1 == hssB.sz1 || throw(DimensionMismatch("A11 has dimensions $(hssA.sz1) but B11 has dimensions $(hssB.sz1)")) hssA.sz2 == hssB.sz2 || throw(DimensionMismatch("A22 has dimensions $(hssA.sz2) but B22 has dimensions $(hssA.sz2)")) - return HssMatrix($op(hssA.A11, hssB.A11), $op(hssA.A22, hssB.A22), blkdiagm(hssA.B12, $op(hssB.B12)), blkdiagm(hssA.B21, $op(hssB.B21)),blkdiagm(hssA.R1, hssB.R1), blkdiagm(hssA.W1, hssB.W1), blkdiagm(hssA.R2, hssB.R2), blkdiagm(hssA.W2, hssB.W2), isroot(hssA) && isroot(hssB)) + newContext = RecursionTools.updateAfterSpawn(context) + taskA11 = RecursionTools.spawn($op, (hssA.A11, hssB.A11, newContext), context) + taskA22 = RecursionTools.spawn($op, (hssA.A22, hssB.A22, newContext), context) + return HssMatrix(RecursionTools.fetch(taskA11), RecursionTools.fetch(taskA22), blkdiagm(hssA.B12, $op(hssB.B12)), blkdiagm(hssA.B21, $op(hssB.B21)),blkdiagm(hssA.R1, hssB.R1), blkdiagm(hssA.W1, hssB.W1), blkdiagm(hssA.R2, hssB.R2), blkdiagm(hssA.W2, hssB.W2), isroot(hssA) && isroot(hssB)) else throw(ArgumentError("Cannot add leaves and branches.")) end diff --git a/src/matmul.jl b/src/matmul.jl index 86c22a0..759ce54 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -15,33 +15,37 @@ *(hssA::HssMatrix, x::AbstractVector) = reshape(hssA * reshape(x, length(x), 1), length(x)) # implement low-level mul! routines -function mul!(C::AbstractMatrix, hssA::HssMatrix, B::AbstractMatrix, α::Number, β::Number) +function mul!(C::AbstractMatrix, hssA::HssMatrix, B::AbstractMatrix, α::Number, β::Number; multithreaded::Bool = HssOptions().multithreaded) size(hssA,2) == size(B,1) || throw(DimensionMismatch("First dimension of B does not match second dimension of A. Expected $(size(A, 2)), got $(size(B, 1))")) size(C) == (size(hssA,1), size(B,2)) || throw(DimensionMismatch("Dimensions of C don't match up with A and B.")) if isleaf(hssA) return mul!(C, hssA.D, B, α, β) else + context = RecursionTools.RecursionContext(multithreaded) hssA = rooted(hssA) # this is to avoid top-generators getting in the way if this is called on a sub-block of the HSS matrix - Z = _matmatup(hssA, B) # saves intermediate steps of multiplication in a binary tree structure - return _matmatdown!(C, hssA, B, Z, nothing, α, β) + Z = _matmatup(hssA, B, context) # saves intermediate steps of multiplication in a binary tree structure + return _matmatdown!(C, hssA, B, Z, nothing, α, β, context) end end ## auxiliary functions for the fast multiplication algorithm # post-ordered step of mat-vec -function _matmatup(hssA::HssMatrix, B::AbstractMatrix) +function _matmatup(hssA::HssMatrix, B::AbstractMatrix, context) if isleaf(hssA) return BinaryNode(hssA.V' * B) else n1 = hssA.sz1[2] - Z1 = _matmatup(hssA.A11, B[1:n1,:]) - Z2 = _matmatup(hssA.A22, B[n1+1:end,:]) + newContext = RecursionTools.updateAfterSpawn(context) + task = RecursionTools.spawn( _matmatup, (hssA.A11, B[1:n1,:],newContext), context) + Z2 = _matmatup(hssA.A22, B[n1+1:end,:],newContext) + Z1 = RecursionTools.fetch(task) Z = BinaryNode(hssA.W1'*Z1.data .+ hssA.W2'*Z2.data, Z1, Z2) return Z end end -function _matmatdown!(C::AbstractMatrix{T}, hssA::HssMatrix{T}, B::AbstractMatrix{T}, Z::BinaryNode{AT}, F::Union{AbstractMatrix{T}, Nothing}, α::Number, β::Number) where {T, AT<:AbstractArray{T}} +function _matmatdown!(C::AbstractMatrix{T}, hssA::HssMatrix{T}, B::AbstractMatrix{T}, Z::BinaryNode{AT}, F::Union{AbstractMatrix{T}, Nothing}, + α::Number, β::Number, context) where {T, AT<:AbstractArray{T}} if isleaf(hssA) mul!(C, hssA.D, B , α, β) if !isnothing(F); mul!(C, hssA.U, F , α, 1.); end @@ -55,27 +59,32 @@ function _matmatdown!(C::AbstractMatrix{T}, hssA::HssMatrix{T}, B::AbstractMatri F1 = hssA.B12 * Z.right.data F2 = hssA.B21 * Z.left.data end - _matmatdown!(@view(C[1:m1,:]), hssA.A11, @view(B[1:n1,:]), Z.left, F1, α, β) - _matmatdown!(@view(C[m1+1:end,:]), hssA.A22, @view(B[n1+1:end,:]), Z.right, F2, α, β) + newContext = RecursionTools.updateAfterSpawn(context) + task = RecursionTools.spawn(_matmatdown!, (@view(C[1:m1,:]), hssA.A11, @view(B[1:n1,:]), Z.left, F1, α, β, newContext),context) + _matmatdown!(@view(C[m1+1:end,:]), hssA.A22, @view(B[n1+1:end,:]), Z.right, F2, α, β, context) + RecursionTools.wait(task) return C end end ## multiplication of two HSS matrices -function *(hssA::HssMatrix, hssB::HssMatrix) +function *(hssA::HssMatrix, hssB::HssMatrix, multithreaded::Bool = HssOptions().multithreaded) if isleaf(hssA) && isleaf(hssA) hssC = HssMatrix(hssA.D*hssB.D, hssA.U, hssB.V, true) elseif isbranch(hssA) && isbranch(hssA) hssA = rooted(hssA); hssB = rooted(hssB) # implememnt cluster equality checks #if cluster(hssA,2) != cluster(hssB,1); throw(DimensionMismatch("clusters of hssA and hssB must be matching")) end - Z = _matmatup(hssA, hssB) # saves intermediate steps of multiplication in a binary tree structure + context = RecursionTools.RecursionContext(multithreaded) + Z = _matmatup(hssA, hssB, context ) # saves intermediate steps of multiplication in a binary tree structure F1 = hssA.B12 * Z.right.data * hssB.B21 F2 = hssA.B21 * Z.left.data * hssB.B12 B12 = blkdiagm(hssA.B12, hssB.B12) B21 = blkdiagm(hssA.B21, hssB.B21) - A11 = _matmatdown!(hssA.A11, hssB.A11, Z.left, F1) - A22 = _matmatdown!(hssA.A22, hssB.A22, Z.right, F2) + newContext = RecursionTools.updateAfterSpawn(context) + task = RecursionTools.spawn(_matmatdown!,(hssA.A22, hssB.A22, Z.right, F2, newContext),context) + A11 = _matmatdown!(hssA.A11, hssB.A11, Z.left, F1, newContext) + A22 = RecursionTools.fetch(task) hssC = HssMatrix(A11, A22, B12, B21, true) else error("Clusters don't seem to match") @@ -83,17 +92,19 @@ function *(hssA::HssMatrix, hssB::HssMatrix) return hssC end -function _matmatup(hssA::HssMatrix{T}, hssB::HssMatrix{T}) where T<:Number +function _matmatup(hssA::HssMatrix{T}, hssB::HssMatrix{T}, context) where T<:Number if isleaf(hssA) & isleaf(hssB) return BinaryNode{Matrix{T}}(hssA.V' * hssB.U) elseif isbranch(hssA) && isbranch(hssB) - Z1 = _matmatup(hssA.A11, hssB.A11) - Z2 = _matmatup(hssA.A22, hssB.A22) + newContext = RecursionTools.updateAfterSpawn(context) + task = RecursionTools.spawn(_matmatup, (hssA.A22, hssB.A22, newContext), context) + Z1 = _matmatup(hssA.A11, hssB.A11, context) + Z2 = fetch(task) return BinaryNode(hssA.W1' * Z1.data * hssB.R1 + hssA.W2' * Z2.data * hssB.R2, Z1, Z2) end end -function _matmatdown!(hssA::HssMatrix{T}, hssB::HssMatrix{T}, Z::BinaryNode{Matrix{T}}, F::Matrix{T}) where T +function _matmatdown!(hssA::HssMatrix{T}, hssB::HssMatrix{T}, Z::BinaryNode{Matrix{T}}, F::Matrix{T}, context) where T if isleaf(hssA) & isleaf(hssB) D = hssA.D * hssB.D + hssA.U * F * hssB.V' U = [hssA.U hssA.D * hssB.U] @@ -109,8 +120,10 @@ function _matmatdown!(hssA::HssMatrix{T}, hssB::HssMatrix{T}, Z::BinaryNode{Matr W1 = [hssA.W1 zeros(T,size(hssA.W1,1), size(hssB.W1,2)); hssB.B21' * Z.right.data' * hssA.W2 hssB.W1]; R2 = [hssA.R2 hssA.B21 * Z.left.data * hssB.R1; zeros(T,size(hssB.R2,1), size(hssA.R2,2)) hssB.R2]; W2 = [hssA.W2 zeros(T,size(hssA.W2,1), size(hssB.W2,2)); hssB.B12' * Z.left.data' * hssA.W1 hssB.W2]; - A11 = _matmatdown!(hssA.A11, hssB.A11, Z.left, F1) - A22 = _matmatdown!(hssA.A22, hssB.A22, Z.right, F2) + newContext = RecursionTools.updateAfterSpawn(context) + task = RecursionTools.spawn(_matmatdown!,(hssA.A11, hssB.A11, Z.left, F1, newContext),context) + A22 = _matmatdown!(hssA.A22, hssB.A22, Z.right, F2, newContext) + A11 = RecursionTools.fetch(task) return HssMatrix(A11, A22, B12, B21, R1, W1, R2, W2, false) else error("Clusters don't seem to match") diff --git a/src/recursiontools.jl b/src/recursiontools.jl new file mode 100644 index 0000000..9a23de6 --- /dev/null +++ b/src/recursiontools.jl @@ -0,0 +1,72 @@ +module RecursionTools + +import Base.Threads: fetch +import Base.Threads: wait + +export spawn +export wait +export fetch + +abstract type AbstractRecursion end + +struct ThreadedRecursion{T} <: AbstractRecursion + task::T +end + +struct SimpleRecursion{T,G} <: AbstractRecursion + fun::T + args::G +end + +struct RecursionContext + maxSplittingDepth::Int + depth::Int +end + +function RecursionContext(multithreaded::Bool) + if multithreaded + maxSplittingDepth = Int(round(log2(Threads.nthreads())))+1 + else + maxSplittingDepth = 0 + end + return RecursionContext(maxSplittingDepth,0) +end + +function hasToSpawn(context::RecursionContext) + return context.depth= 0.001 ? 1 / (x - y) : 10000.0 + A = [K(x, y) for x = -1:0.001:1, y = -1:0.001:1] + b = randn(size(A, 2), 5) + + # test the simple implementation of cluster trees + m, n = size(A) + #lsz = 64; + lsz = 64 + rcl = bisection_cluster(1:m, leafsize=lsz) + ccl = bisection_cluster(1:n, leafsize=lsz) + + hssA = compress(A, rcl, ccl) + + # time conversion to dense + println("Benchmarking full...") + @btime full($hssA) + + # time access + println("Benchmarking getindex...") + ii = randperm(100) + jj = randperm(100) + @btime Aij = $hssA[$ii, $jj] + + # time compression + println("Benchmarking compression...") + @btime hssA = compress($A, $rcl, $ccl) + + # time compression + println("Benchmarking randomized compression...") + @btime hssA = randcompress_adaptive($A, $rcl, $ccl) + + println("Benchmarking re-compression...") + hssB = copy(hssA) + @btime hssA = recompress!($hssB; atol=1e-3, rtol=1e-3) + + println("Benchmarking proper...") + hssB = hssA + hssA + @btime orthonormalize_generators!($hssB) + + println("Benchmarking addition...") + @btime hssC = $hssA + $hssA + + # time matvec + println("Benchmarking matvec...") + x = randn(size(A, 2), 10) + @btime y = $hssA * $x + + + println("Benchmarking matrix products...") + @btime y = $hssA * $hssB + + # time ulvfactsolve + println("Benchmarking ulvfactsolve...") + b = randn(size(A, 2), 10) + @btime x = ulvfactsolve($hssA, $b) + + # time hssldivide + println("Benchmarking hssldivide...") + hssX = compress(1.0 * Matrix(I, n, n), ccl, ccl, atol=1e-3, rtol=1e-3) # this should probably be one constructor + @btime hssC = ldiv!(copy($hssA), $hssX) +end \ No newline at end of file diff --git a/test/benchmarks_large_matrices.jl b/test/benchmarks_large_matrices.jl new file mode 100644 index 0000000..aec0d41 --- /dev/null +++ b/test/benchmarks_large_matrices.jl @@ -0,0 +1,125 @@ +#using Revise +using HssMatrices +using LinearAlgebra +using BenchmarkTools +using Random +using DataFrames +using SparseArrays +using Plots +Random.seed!(123) + +function benchmark_full(A) + return @benchmark full($A) +end + +function benchmark_getindex(A) + ii = randperm(100) + jj = randperm(100) + return @benchmark Aij = $A[$ii, $jj] +end + +function benchmark_randcompress(A, rcl, ccl) + return @benchmark randcompress_adaptive($A, $rcl, $ccl) +end + +function benchmark_recompress(hssA) + hssB = copy(hssA) + return @benchmark hssA = recompress!($hssB; atol=1e-3, rtol=1e-3) +end + +function benchmark_proper(hssA) + hssB = hssA + hssA + return @benchmark orthonormalize_generators!($hssB) +end + +function benchmark_addition(hssA) + return @benchmark hssC = $hssA + $hssA +end + +function benchmark_multiplication(hssA) + x = randn(size(hssA, 2), 10) + return @benchmark y = $hssA * $x +end + +function benchmark_ulvfactsolve(hssA) + b = randn(size(hssA, 2), 10) + return @benchmark x = ulvfactsolve($hssA, $b) +end + +function benchmark_hssldivide(hssA, ccl) + n = size(hssA, 1) + hssX = randcompress_adaptive(1.0 * SparseMatrixCSC{Float64}(I, n, n), ccl, ccl, atol=1e-3, rtol=1e-3) + return @benchmark hssC = ldiv!(copy($hssA), $hssX) +end + +function construct_test_matrix() # Define the test matrix and clusters + invA = spdiagm((k => [Float64(k) for i in 1:10000-abs(k)] for k = -10:10)...) + m, n = size(invA) + lsz = 64 + rcl = bisection_cluster(1:m, leafsize=lsz) + ccl = bisection_cluster(1:n, leafsize=lsz) + hssInvA = randcompress_adaptive(invA, rcl, ccl) + hssId = randcompress_adaptive(1.0 * SparseMatrixCSC{Float64}(I, n, n), ccl, ccl, atol=1e-6, rtol=1e-6) + hssA = ldiv!(copy(hssInvA), hssId) + return hssA, invA, rcl, ccl +end + +# Function to run all benchmarks +function run_benchmarks(multithreaded, blas_threads) + @show multithreaded + @show blas_threads + if multithreaded + @show Threads.nthreads() + end + BLAS.set_num_threads(blas_threads) + HssMatrices.setopts(multithreaded=multithreaded) + hssA, invA, rcl, ccl = construct_test_matrix() + + # Run benchmarks + results = Dict() + println("Running benchmark: getindex") + results[:getindex] = benchmark_getindex(hssA) + println("Running benchmark: randcompress") + results[:randcompress] = benchmark_randcompress(invA, rcl, ccl) + println("Running benchmark: recompress") + results[:recompress] = benchmark_recompress(hssA) + println("Running benchmark: proper") + results[:proper] = benchmark_proper(hssA) + println("Running benchmark: addition") + results[:addition] = benchmark_addition(hssA) + println("Running benchmark: multiplication") + results[:multiplication] = benchmark_multiplication(hssA) + println("Running benchmark: ulvfactsolve") + results[:ulvfactsolve] = benchmark_ulvfactsolve(hssA) + println("Running benchmark: hssldivide") + results[:hssldivide] = benchmark_hssldivide(hssA, ccl) + + return results +end + +#First evaluate the nominal single-threaded performance +reference_results = run_benchmarks(false, Sys.CPU_THREADS) +#Compute the performance for the multithreaded scenario using only 1 openBlas thread +multithreaded_results = run_benchmarks(true, 1) +speedup_results = Dict() +for eachKey in keys(reference_results) + speedup_results[eachKey] = median(reference_results[eachKey]).time / median(multithreaded_results[eachKey]).time +end +speedup_plot = bar( + string.(keys(speedup_results)), + collect(values(speedup_results)), + xlabel="Benchmark", + ylabel="Speedup", + title="Julia threads compared to OpenBLAS threads", + size=(800, 600); # Increase the size of the plot + legend=false +) +savefig("speedup_plot.png") + +hssA = construct_test_matrix()[1] +plt_right = plotranks(hssA) +savefig("hssranks_hssA.png") + + + + diff --git a/test/runtests.jl b/test/runtests.jl index 4b79c7e..7435a9e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using Test, LinearAlgebra, HssMatrices +BLAS.set_num_threads(2) @testset "options" begin HssMatrices.setopts(atol=1e-6) @@ -11,70 +12,92 @@ using Test, LinearAlgebra, HssMatrices @test HssOptions().noversampling == 5 HssMatrices.setopts(stepsize=10) @test HssOptions().stepsize == 10 + HssMatrices.setopts(multithreaded = true) + @test HssOptions().multithreaded end -@testset for T in [Float32, Float64, ComplexF32,ComplexF64] - - # generate Cauchy matrix - K(x,y) = (x-y) > 0 ? 0.001/(x-y) : 2. - A = [ T(K(x,y)) for x=-1:0.001:1, y=-1:0.001:1]; - if T <: Complex - A = A + 1im .* A +@testset "recursion tools" begin + using HssMatrices.RecursionTools + function foo(x,y) + return x + y + end + for threaded in (false, true) + a,b = 2,3 + recursion = spawn(foo,(a,b), threaded) + @test fetch(recursion) == foo(a,b) end - m, n = size(A) - U = randn(T,n,3); V = randn(T,n,3) - # "safety" factor - c = 50. - tol = 1E-6 - HssMatrices.setopts(atol=tol) - HssMatrices.setopts(rtol=tol) +end - rcl = bisection_cluster(1:m) - ccl = bisection_cluster(1:n) +@testset for T in [Float32, Float64, ComplexF32, ComplexF64] + @testset for multithreaded = (false, true) + # generate Cauchy matrix + K(x,y) = abs(x-y) > 0 ? 0.001/(x-y) : 2. + n = 2000 + dt = 2/n + A = [ T(K(x,y)) for x = -1:dt:1, y = -1:dt:1]; + if T <: Complex + A = A + 1im .* A + end + m, n = size(A) + U = randn(T,n,3); V = randn(T,n,3) + # "safety" factor + c = 50. + tol = max(10*eps(real(T)), 1E-6) + HssMatrices.setopts(atol=tol) + HssMatrices.setopts(rtol=tol) + HssMatrices.setopts(multithreaded=multithreaded) + @test HssOptions().atol == tol + @test HssOptions().rtol == tol + @test HssOptions().multithreaded == multithreaded - @testset "compression" begin - hssA = compress(A, rcl, ccl); - @test norm(A - full(hssA))/norm(A) ≤ c*HssOptions().rtol || norm(A - full(hssA)) ≤ c*HssOptions().atol - hssB = randcompress_adaptive(A, rcl, ccl); rk = hssrank(hssB) - @test norm(A - full(hssB))/norm(A) ≤ c*HssOptions().rtol || norm(A - full(hssB)) ≤ c*HssOptions().atol - hssB = recompress!(hssB) - @test norm(A - full(hssB))/norm(A) ≤ c*HssOptions().rtol || norm(A - full(hssB)) ≤ c*HssOptions().atol - @test hssrank(hssB) ≤ rk - hssC = lowrank2hss(U, V, ccl, ccl) - @test norm(U*V' - full(hssC))/norm(U*V') ≤ c*tol - @test hssrank(hssC) == 3 - end; + rcl = bisection_cluster(1:m) + ccl = bisection_cluster(1:n) - @testset "random access" begin - I = rand(1:size(A,1), 10) - J = rand(1:size(A,1), 10) - testAccuracy(expected, result) = @test norm(expected - result)/norm(expected) ≤ c*HssOptions().rtol || - norm(expected - result) ≤ c*HssOptions().atol - testAccuracy(A[I,J], compress(A, rcl, ccl)[I,J]) - testAccuracy(A[I,:], compress(A, rcl, ccl)[I,:]) - testAccuracy(A[:,J], compress(A, rcl, ccl)[:,J]) - end; - @testset "arithmetic" begin - hssA = compress(A, rcl, ccl); - hssC = lowrank2hss(U, V, ccl, ccl) - @test norm(A' - full(hssA'))/norm(A) ≤ c*HssOptions().rtol || norm(A' - full(hssA')) ≤ c*HssOptions().atol - x = randn(T, n, 5); - @test norm(A*x - hssA*x)/norm(A*x) ≤ c*HssOptions().rtol || norm(A*x - hssA*x) ≤ c*HssOptions().atol - @test norm(full(hssA*hssC) - (A*U)*V')/norm((A*U)*V') ≤ c*HssOptions().rtol || norm(A*x - hssA*x) ≤ c*HssOptions().atol - rhs = randn(T,n, 5); x = hssA\rhs; x0 = A\rhs; - @test norm(x0 - x)/norm(x0) ≤ c*HssOptions().rtol || norm(x0 - x) ≤ c*HssOptions().atol - Id(i,j) = Matrix{T}(i.*ones(length(j))' .== ones(length(i)).*j') - IdOp = LinearMap{T}(n, n, (y,_,x) -> x, (y,_,x) -> x, (i,j) -> Id(i,j)) - hssI = randcompress(IdOp, ccl, ccl, 0) - @test norm(full(hssA*hssI) - full(hssA))/norm(full(hssA)) ≤ c*tol - Ainv = inv(A) - @test norm(Ainv - full(hssA\hssI))/norm(Ainv) ≤ c*HssOptions().rtol || norm(Ainv - full(hssA\hssI)) ≤ c*HssOptions().atol - @test norm(Ainv - full(hssI/hssA))/norm(Ainv) ≤ c*HssOptions().rtol || norm(Ainv - full(hssI/hssA)) ≤ c*HssOptions().atol - hssA.A11 = prune_leaves!(hssA.A11) - hssI.A11 = prune_leaves!(hssI.A11) - @test norm(Ainv - full(hssA\hssI))/norm(Ainv) ≤ c*HssOptions().rtol || norm(Ainv - full(hssA\hssI)) ≤ c*HssOptions().atol - @test norm(Ainv - full(hssI/hssA))/norm(Ainv) ≤ c*HssOptions().rtol || norm(Ainv - full(hssI/hssA)) ≤ c*HssOptions().atol + @testset "compression" begin + hssA = compress(A, rcl, ccl); + @test norm(A - full(hssA))/norm(A) ≤ c*HssOptions().rtol || norm(A - full(hssA)) ≤ c*HssOptions().atol + hssB = randcompress_adaptive(A, rcl, ccl); rk = hssrank(hssB) + @test norm(A - full(hssB))/norm(A) ≤ c*HssOptions().rtol || norm(A - full(hssB)) ≤ c*HssOptions().atol + hssB = recompress!(hssB) + @test norm(A - full(hssB))/norm(A) ≤ c*HssOptions().rtol || norm(A - full(hssB)) ≤ c*HssOptions().atol + @test hssrank(hssB) ≤ rk + hssC = lowrank2hss(U, V, ccl, ccl) + @test norm(U*V' - full(hssC))/norm(U*V') ≤ c*tol + @test hssrank(hssC) == 3 + end + + @testset "random access" begin + I = rand(1:size(A,1), 10) + J = rand(1:size(A,1), 10) + testAccuracy(expected, result) = @test norm(expected - result)/norm(expected) ≤ c*HssOptions().rtol || + norm(expected - result) ≤ c*HssOptions().atol + testAccuracy(A[I,J], compress(A, rcl, ccl)[I,J]) + testAccuracy(A[I,:], compress(A, rcl, ccl)[I,:]) + testAccuracy(A[:,J], compress(A, rcl, ccl)[:,J]) + end; + + @testset "arithmetic" begin + hssA = compress(A, rcl, ccl); + hssC = lowrank2hss(U, V, ccl, ccl) + @test norm(A' - full(hssA'))/norm(A) ≤ c*HssOptions().rtol || norm(A' - full(hssA')) ≤ c*HssOptions().atol + x = randn(T, n, 5); + @test norm(A*x - hssA*x)/norm(A*x) ≤ c*HssOptions().rtol || norm(A*x - hssA*x) ≤ c*HssOptions().atol + @test norm(full(hssA*hssC) - (A*U)*V')/norm((A*U)*V') ≤ c*HssOptions().rtol || norm(A*x - hssA*x) ≤ c*HssOptions().atol + rhs = randn(T,n, 5); x = hssA\rhs; x0 = A\rhs; + @test norm(x0 - x)/norm(x0) ≤ c*HssOptions().rtol || norm(x0 - x) ≤ c*HssOptions().atol + Id(i,j) = Matrix{T}(i.*ones(length(j))' .== ones(length(i)).*j') + IdOp = LinearMap{T}(n, n, (y,_,x) -> x, (y,_,x) -> x, (i,j) -> Id(i,j)) + hssI = randcompress(IdOp, ccl, ccl, 0) + @test norm(full(hssA*hssI) - full(hssA))/norm(full(hssA)) ≤ c*tol + Ainv = inv(A) + @test norm(Ainv - full(hssA\hssI))/norm(Ainv) ≤ c*HssOptions().rtol || norm(Ainv - full(hssA\hssI)) ≤ c*HssOptions().atol + @test norm(Ainv - full(hssI/hssA))/norm(Ainv) ≤ c*HssOptions().rtol || norm(Ainv - full(hssI/hssA)) ≤ c*HssOptions().atol + hssA.A11 = prune_leaves!(hssA.A11) + hssI.A11 = prune_leaves!(hssI.A11) + @test norm(Ainv - full(hssA\hssI))/norm(Ainv) ≤ c*HssOptions().rtol || norm(Ainv - full(hssA\hssI)) ≤ c*HssOptions().atol + @test norm(Ainv - full(hssI/hssA))/norm(Ainv) ≤ c*HssOptions().rtol || norm(Ainv - full(hssI/hssA)) ≤ c*HssOptions().atol + end end -end \ No newline at end of file +end