-
Notifications
You must be signed in to change notification settings - Fork 5
Feature/limited multithreading #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 13 commits
56d9715
9bd0506
f362297
50bcd94
3807377
ab9d49e
cbb8db0
26229c1
39fc7f6
c02e8c0
4b7ac1d
67f96be
2b9ec03
9179484
9ab79fa
db9f5ed
2b3f626
da19cbe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -118,13 +118,13 @@ function _compress!(A::Matrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, rcl::Cluster | |
| R, Brow = _compress_block!(Brow, atol, rtol) | ||
| R1 = R[1:rm1, :] | ||
| R2 = R[rm1+1:end, :] | ||
|
|
||
| X = copy(Bcol') | ||
| W, Bcol = _compress_block!(copy(Bcol'), atol, rtol); | ||
| Bcol = copy(Bcol') | ||
| W1 = W[1:rn1, :] | ||
| W2 = W[rn1+1:end, :] | ||
|
|
||
| hssA = HssMatrix(A11, A22, B12, B21, R1, W1, R2, W2, rootnode) | ||
| else | ||
| hssA = HssMatrix(A11, A22, B12, B21, rootnode) | ||
|
|
@@ -148,7 +148,9 @@ julia> hssA = recompress!(hssA, atol=1e-3, rtol=1e-3) | |
| function recompress!(hssA::HssMatrix{T}, opts::HssOptions=HssOptions(T); args...) where T | ||
| opts = copy(opts; args...) | ||
| chkopts!(opts) | ||
| rtol = opts.rtol; atol = opts.atol; | ||
| rtol = opts.rtol | ||
| atol = opts.atol | ||
| multithreaded = opts.multithreaded | ||
|
|
||
| if isleaf(hssA); return hssA; end | ||
|
|
||
|
|
@@ -192,17 +194,18 @@ function recompress!(hssA::HssMatrix{T}, opts::HssOptions=HssOptions(T); args... | |
| end | ||
|
|
||
| # call recompression recursively | ||
| context = RecursionTools.RecursionContext(multithreaded) | ||
| if isbranch(hssA.A11) | ||
| _recompress!(hssA.A11, hssA.B12, copy(hssA.B21'), atol, rtol) | ||
| _recompress!(hssA.A11, hssA.B12, copy(hssA.B21'), atol, rtol, context) | ||
| end | ||
| if isbranch(hssA.A22) | ||
| _recompress!(hssA.A22, hssA.B21, copy(hssA.B12'), atol, rtol) | ||
| _recompress!(hssA.A22, hssA.B21, copy(hssA.B12'), atol, rtol, context) | ||
| end | ||
|
|
||
| return hssA | ||
| end | ||
|
|
||
| function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol, rtol) where T | ||
| function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol, rtol, context) where {T} | ||
| # compress B12 | ||
| Brow1 = [hssA.B12 hssA.R1*Brow] | ||
| Bcol2 = [hssA.B12' hssA.W2*Bcol] | ||
|
|
@@ -245,20 +248,25 @@ function _recompress!(hssA::HssMatrix{T}, Brow::Matrix{T}, Bcol::Matrix{T}, atol | |
| end | ||
|
|
||
| # call recompression recursively | ||
| if isbranch(hssA.A11) | ||
| Brow1 = hcat(hssA.B12, S2[:,rn2+1:end]) | ||
| Bcol1 = hcat(hssA.B21', T2[:,rm2+1:end]) | ||
| _recompress!(hssA.A11, Brow1, Bcol1, atol, rtol) | ||
| end | ||
| newContext = RecursionTools.updateAfterSpawn(context) | ||
| task = RecursionTools.spawn(_recursive_compression_A11, (hssA, S2, T2, rn2, rm2, atol, rtol, newContext), context) | ||
| if isbranch(hssA.A22) | ||
| Brow2 = hcat(hssA.B21, S1[:,rn1+1:end]) | ||
| Bcol2 = hcat(hssA.B12', T1[:,rm1+1:end]) | ||
| _recompress!(hssA.A22, Brow2, Bcol2, atol, rtol) | ||
| Brow2 = hcat(hssA.B21, S1[:, rn1+1:end]) | ||
| Bcol2 = hcat(hssA.B12', T1[:, rm1+1:end]) | ||
| _recompress!(hssA.A22, Brow2, Bcol2, atol, rtol, newContext) | ||
| end | ||
|
|
||
| RecursionTools.wait(task) | ||
| return hssA | ||
| end | ||
|
|
||
| function _recursive_compression_A11(hssA, S2, T2, rn2, rm2, atol, rtol, context) | ||
| if isbranch(hssA.A11) | ||
| Brow1 = hcat(hssA.B12, S2[:, rn2+1:end]) | ||
| Bcol1 = hcat(hssA.B21', T2[:, rm2+1:end]) | ||
| _recompress!(hssA.A11, Brow1, Bcol1, atol, rtol, context) | ||
| end | ||
| end | ||
|
|
||
| """ | ||
| randcompress(A, rcl, ccl, kest; args...) | ||
|
|
||
|
|
@@ -285,8 +293,8 @@ function randcompress(A::AbstractMatOrLinOp{T}, rcl::ClusterTree, ccl::ClusterTr | |
|
|
||
| # compute initial sampling | ||
| k = kest; r = opts.noversampling; | ||
| Ωcol = randn(n, k+r) | ||
| Ωrow = randn(m, k+r) | ||
| Ωcol = randn(T,n, k+r) | ||
|
||
| Ωrow = randn(T,m, k+r) | ||
| Scol = A*Ωcol # this should invoke the magic of the linearoperator.jl type | ||
| Srow = A'*Ωrow | ||
| hssA, _, _, _, _, _, _, _, _ = _randcompress!(hssA, A, Scol, Srow, Ωcol, Ωrow, 0, 0, opts.atol, opts.rtol; rootnode=true) | ||
|
|
@@ -379,7 +387,7 @@ function _randcompress!(hssA::HssMatrix, A::AbstractMatOrLinOp, Scol::Matrix, Sr | |
| m1, n1 = hssA.sz1; m2, n2 = hssA.sz2 | ||
| hssA.A11, Scol1, Srow1, Ωcol1, Ωrow1, Jcol1, Jrow1, U1, V1 = _randcompress!(hssA.A11, A, Scol[1:m1, :], Srow[1:n1, :], Ωcol[1:n1, :], Ωrow[1:m1, :], ro, co, atol, rtol) | ||
| hssA.A22, Scol2, Srow2, Ωcol2, Ωrow2, Jcol2, Jrow2, U2, V2 = _randcompress!(hssA.A22, A, Scol[m1+1:end, :], Srow[n1+1:end, :], Ωcol[n1+1:end, :], Ωrow[m1+1:end, :], ro+m1, co+n1, atol, rtol) | ||
|
|
||
| # update the sampling matrix based on the extracted generators | ||
| Ωcol2 = V2' * Ωcol2 | ||
| Ωcol1 = V1' * Ωcol1 | ||
|
|
@@ -416,7 +424,7 @@ function _randcompress!(hssA::HssMatrix, A::AbstractMatOrLinOp, Scol::Matrix, Sr | |
| Scol = Scol[Jcolloc, :] | ||
| Jcol = Jcol[Jcolloc] | ||
| U = [hssA.R1; hssA.R2] | ||
|
|
||
| # take care of the rows | ||
| Xrow, Jrowloc = _interpolate(Srow', atol, rtol) | ||
| hssA.W1 = Xrow[:, 1:size(Srow1, 1)]' | ||
|
|
@@ -447,7 +455,7 @@ end | |
| function hss_blkdiag(A::AbstractMatOrLinOp{T}, rcl::ClusterTree, ccl::ClusterTree; rootnode=true) where T | ||
| m = length(rcl.data); n = length(ccl.data) | ||
| if isleaf(rcl) # only check row cluster as we have already checked cluster equality | ||
|
|
||
| D = convert(Matrix{T}, A[rcl.data, ccl.data]) | ||
| if rootnode | ||
| hssA = HssMatrix(D) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.