From 865439bf7c73ae46663e906399765759b0dc7a4a Mon Sep 17 00:00:00 2001 From: Patrick Jaap Date: Wed, 6 Aug 2025 12:31:39 +0200 Subject: [PATCH 1/4] Implement first demo of Lagrange Restrictions --- Project.toml | 7 +- examples/Example212_PeriodicBoundary2D.jl | 13 ++- examples/Example312_PeriodicBoundary3D.jl | 15 ++- src/ExtendableFEM.jl | 9 +- .../coupled_dofs_restriction.jl | 26 +++++ src/problemdescription.jl | 13 ++- src/restrictions.jl | 22 ++++ src/solvers.jl | 109 ++++++++++++++---- 8 files changed, 182 insertions(+), 32 deletions(-) create mode 100644 src/common_restrictions/coupled_dofs_restriction.jl create mode 100644 src/restrictions.jl diff --git a/Project.toml b/Project.toml index e932f510..6b6b1453 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,18 @@ name = "ExtendableFEM" uuid = "a722555e-65e0-4074-a036-ca7ce79a4aed" -version = "1.4" +version = "1.4.0" authors = ["Christian Merdon ", "Patrick Jaap "] [deps] +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ExtendableFEMBase = "12fb9182-3d4c-4424-8fd1-727a0899810c" ExtendableGrids = "cfc395e8-590f-11e8-1f13-43a2532b2fa8" ExtendableSparse = "95c220a8-a1cf-11e9-0c77-dbfce5f500b3" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GridVisualize = "5eed8a63-0fb0-45eb-886d-8d5a387d12b8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -24,6 +27,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" [compat] Aqua = "0.8" +BlockArrays = "1.7.0" CommonSolve = "0.2" DiffResults = "1" DocStringExtensions = "0.8,0.9" @@ -32,6 +36,7 @@ ExplicitImports = "1" ExtendableFEMBase = "1.3.0" ExtendableGrids = "1.10.3" ExtendableSparse = "1.5.3" +FillArrays = "1.13.0" ForwardDiff = "0.10.35,1" GridVisualize = "1.8.1" IncompleteLU = "0.2.1" diff --git a/examples/Example212_PeriodicBoundary2D.jl b/examples/Example212_PeriodicBoundary2D.jl index 24bd7bd9..7af5d718 100644 --- a/examples/Example212_PeriodicBoundary2D.jl +++ b/examples/Example212_PeriodicBoundary2D.jl @@ -116,6 +116,7 @@ end function main(; order = 1, periodic = true, + use_LM_restrictions = true, Plotter = nothing, force = 10.0, h = 1.0e-2, @@ -154,8 +155,12 @@ function main(; end coupling_matrix = get_periodic_coupling_matrix(FES, reg_left, reg_right, give_opposite!) - display(coupling_matrix) - assign_operator!(PD, CombineDofs(u, u, coupling_matrix; kwargs...)) + # display(coupling_matrix) + if use_LM_restrictions + assign_restriction!(PD, CoupledDofsRestriction(coupling_matrix)) + else + assign_operator!(PD, CombineDofs(u, u, coupling_matrix; kwargs...)) + end end sol = solve(PD, FES) @@ -172,8 +177,10 @@ end generateplots = ExtendableFEM.default_generateplots(Example212_PeriodicBoundary2D, "example212.png") #hide function runtests() #hide - sol, plt = main() #hide + sol, plt = main(use_LM_restrictions = true) #hide @test abs(maximum(view(sol[1])) - 1.3447465095618172) < 1.0e-3 #hide + sol2, plt = main(use_LM_restrictions = false) #hide + @test sol.entries ≈ sol2.entries #hide return nothing #hide end #hide diff --git a/examples/Example312_PeriodicBoundary3D.jl b/examples/Example312_PeriodicBoundary3D.jl index 6915e633..637b0d26 100644 --- a/examples/Example312_PeriodicBoundary3D.jl +++ b/examples/Example312_PeriodicBoundary3D.jl @@ -130,6 +130,7 @@ end function main(; order = 1, periodic = true, + use_LM_restrictions = true, Plotter = nothing, force = 1.0, h = 1.0e-4, @@ -169,8 +170,12 @@ function main(; return nothing end coupling_matrix = get_periodic_coupling_matrix(FES, reg_left, reg_right, give_opposite!) - display(coupling_matrix) - assign_operator!(PD, CombineDofs(u, u, coupling_matrix; kwargs...)) + # display(coupling_matrix) + if use_LM_restrictions + assign_restriction!(PD, CoupledDofsRestriction(coupling_matrix)) + else + assign_operator!(PD, CombineDofs(u, u, coupling_matrix; kwargs...)) + end end ## solve @@ -185,8 +190,10 @@ end generateplots = ExtendableFEM.default_generateplots(Example312_PeriodicBoundary3D, "example312.png") #hide function runtests() #hide - sol, plt = main() #hide - @test abs(maximum(view(sol[1])) - 1.8004602502175202) < 2.0e-3 #hide + sol, plt = main(use_LM_restrictions = true) #hide + @test abs(maximum(view(sol[1])) - 1.8004602502175202) < 2.0e-3 #hide + sol2, plt = main(use_LM_restrictions = false) #hide + @test sol.entries ≈ sol2.entries #hide return nothing #hide end #hide diff --git a/src/ExtendableFEM.jl b/src/ExtendableFEM.jl index f7091e1a..ef79014c 100644 --- a/src/ExtendableFEM.jl +++ b/src/ExtendableFEM.jl @@ -5,6 +5,7 @@ $(read(joinpath(@__DIR__, "..", "README.md"), String)) """ module ExtendableFEM +using BlockArrays: BlockMatrix, BlockVector, Block, BlockedMatrix, BlockedVector using CommonSolve: CommonSolve using DiffResults: DiffResults using DocStringExtensions: DocStringExtensions, TYPEDEF, TYPEDSIGNATURES @@ -59,6 +60,7 @@ using ExtendableGrids: ExtendableGrids, AT_NODES, AbstractElementGeometry, using ExtendableSparse: ExtendableSparse, ExtendableSparseMatrix, flush!, MTExtendableSparseMatrixCSC, findindex, rawupdateindex! +using FillArrays: Zeros using ForwardDiff: ForwardDiff using GridVisualize: GridVisualize, GridVisualizer, gridplot!, reveal, save, scalarplot!, vectorplot! @@ -67,7 +69,7 @@ using LinearSolve: LinearSolve, LinearProblem, UMFPACKFactorization, deleteat!, init, solve using Printf: Printf, @printf, @sprintf using SparseArrays: SparseArrays, AbstractSparseArray, SparseMatrixCSC, findnz, nnz, - nzrange, rowvals, sparse, SparseVector + nzrange, rowvals, sparse, SparseVector, spzeros using SparseDiffTools: SparseDiffTools, ForwardColorJacCache, forwarddiff_color_jacobian!, matrix_colors using Symbolics: Symbolics @@ -121,11 +123,16 @@ include("common_operators/reduction_operator.jl") #export AbstractReductionOperator #export FixbyInterpolation +include("restrictions.jl") +include("common_restrictions/coupled_dofs_restriction.jl") +export CoupledDofsRestriction + include("problemdescription.jl") export ProblemDescription export assign_unknown! export assign_operator! export replace_operator! +export assign_restriction! include("helper_functions.jl") export get_periodic_coupling_info diff --git a/src/common_restrictions/coupled_dofs_restriction.jl b/src/common_restrictions/coupled_dofs_restriction.jl new file mode 100644 index 00000000..751e051d --- /dev/null +++ b/src/common_restrictions/coupled_dofs_restriction.jl @@ -0,0 +1,26 @@ +struct CoupledDofsRestriction{TM} <: AbstractRestriction + coupling_matrix::TM + parameters::Dict{Symbol, Any} +end + +function CoupledDofsRestriction(matrix::AbstractMatrix) + return CoupledDofsRestriction(matrix, Dict{Symbol, Any}(:name => "CoupledDofsRestriction")) +end + + +function assemble!(R::CoupledDofsRestriction, A, b, sol, SC; kwargs...) + + # extract all col indices + _, J, _ = findnz(R.coupling_matrix) + + # remove duplicates + unique_cols = unique(J) + + # subtract diagonal and shrink matrix to non-empty cols + B = (R.coupling_matrix - LinearAlgebra.I)[:, unique_cols] + + R.parameters[:matrix] = B + return R.parameters[:rhs] = Zeros(length(unique_cols)) + + return nothing +end diff --git a/src/problemdescription.jl b/src/problemdescription.jl index 12aaf291..36c91819 100644 --- a/src/problemdescription.jl +++ b/src/problemdescription.jl @@ -26,6 +26,11 @@ mutable struct ProblemDescription """ operators::Array{AbstractOperator, 1} #reduction_operators::Array{AbstractReductionOperator,1} + + """ + A vector of Lagrange restrictions that are involved in the problem. + """ + restrictions::Vector{AbstractRestriction} end """ @@ -41,7 +46,7 @@ Create an empty `ProblemDescription` with the given name. """ function ProblemDescription(name = "My problem") - return ProblemDescription(name, Array{Unknown, 1}(undef, 0), Array{AbstractOperator, 1}(undef, 0)) + return ProblemDescription(name, [], [], []) end @@ -92,6 +97,12 @@ function assign_operator!(PD::ProblemDescription, o::AbstractOperator) end +function assign_restriction!(PD::ProblemDescription, r::AbstractRestriction) + push!(PD.restrictions, r) + return length(PD.restrictions) +end + + """ $(TYPEDSIGNATURES) diff --git a/src/restrictions.jl b/src/restrictions.jl new file mode 100644 index 00000000..72a88a46 --- /dev/null +++ b/src/restrictions.jl @@ -0,0 +1,22 @@ +""" + AbstractRestriction + +Root type for all restrictions +""" +abstract type AbstractRestriction end + + +function Base.show(io::IO, R::AbstractRestriction) + print(io, "AbstractRestriction") + return nothing +end + +# informs solver when operator needs reassembly in a time dependent setting +function is_timedependent(R::AbstractRestriction) + return false +end + +function assemble!(R::AbstractRestriction, A, b, sol, SC; kwargs...) + ## assembles internal restriction matrix in R + return nothing +end diff --git a/src/solvers.jl b/src/solvers.jl index 6def518e..b88b5397 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -221,27 +221,92 @@ Solves the linear system and updates the solution vector. This includes: - Computing the residual - Updating the solution with optional damping """ -function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer) - # Update system matrix if needed - if !SC.parameters[:constant_matrix] || !SC.parameters[:initialized] +function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer, kwargs...) + + @timeit timer "Lagrange restrictions" begin + ## assemble restrctions + if !SC.parameters[:initialized] + for restriction in PD.restrictions + @timeit timer "$(restriction.parameters[:name])" assemble!(restriction, A, b, sol, SC; kwargs...) + end + end + end + + @timeit timer "linear solver" begin + ## start with the assembled matrix containing all assembled operators + if !SC.parameters[:constant_matrix] || !SC.parameters[:initialized] + if length(freedofs) > 0 + A_unrestricted = A.entries.cscmatrix[freedofs, freedofs] + else + A_unrestricted = A.entries.cscmatrix + end + end + + # we solve for A Δx = r + # and update x = sol - Δx if length(freedofs) > 0 - linsolve.A = A.entries.cscmatrix[freedofs, freedofs] + b_unrestricted = residual.entries[freedofs] else - linsolve.A = A.entries.cscmatrix + b_unrestricted = residual.entries end - end - # Set right-hand side - if length(freedofs) > 0 - linsolve.b = residual.entries[freedofs] - else - linsolve.b = residual.entries - end - SC.parameters[:initialized] = true + @timeit timer "LM restrictions" begin + ## add possible Lagrange restrictions + @timeit timer "prepare" begin + restriction_matrices = [length(freedofs) > 0 ? re.parameters[:matrix][freedofs, :] : re.parameters[:matrix] for re in PD.restrictions ] + restriction_rhs = [length(freedofs) > 0 ? re.parameters[:rhs][freedofs] : re.parameters[:rhs] for re in PD.restrictions ] + + ## we need to add the (initial) solution to the rhs, since we work with the residual equation + for (B, rhs) in zip(restriction_matrices, restriction_rhs) + rhs .+= B'sol.entries + end + end + + + @timeit timer "compute blocks" begin + # block sizes for the block matrix + block_sizes = [size(A_unrestricted, 2), [ size(B, 2) for B in restriction_matrices ]...] - # Solve linear system - push!(stats[:matrix_nnz], nnz(linsolve.A)) - @timeit timer "solve! call" Δx = LinearSolve.solve!(linsolve) + total_size = sum(block_sizes) + Tv = eltype(A_unrestricted) + + ## create block matrix + A_block = BlockMatrix(spzeros(Tv, total_size, total_size), block_sizes, block_sizes) + A_block[Block(1, 1)] = A_unrestricted + + b_block = BlockVector(zeros(Tv, total_size), block_sizes) + b_block[Block(1)] = b_unrestricted + + u_unrestricted = linsolve.u + u_block = BlockVector(zeros(Tv, total_size), block_sizes) + u_block[Block(1)] = u_unrestricted + + for i in eachindex(PD.restrictions) + A_block[Block(1, i + 1)] = restriction_matrices[i] + A_block[Block(i + 1, 1)] = transpose(restriction_matrices[i]) + b_block[Block(i + 1)] = restriction_rhs[i] + + end + end + + @timeit timer "convert" begin + + linsolve.A = sparse(A_block) # convert to CSC Matrix + linsolve.b = Vector(b_block) # convert to dense vector + linsolve.u = Vector(u_block) # convert to dense vector + + end + end + + SC.parameters[:initialized] = true + + # Solve linear system + push!(stats[:matrix_nnz], nnz(linsolve.A)) + @timeit timer "solve! call" blocked_Δx = LinearSolve.solve!(linsolve) + + # extract the solution / dismiss the lagrange multipliers + @views Δx = blocked_Δx[1:length(b_unrestricted)] + end # Compute solution update @timeit timer "update solution" begin @@ -570,12 +635,12 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{ linsolve = SC.linsolver - @timeit timer "linear solver" begin - linres = solve_linear_system!( - A, b, sol, soltemp, residual, linsolve, unknowns, - freedofs, damping, PD, SC, stats, is_linear, timer - ) - end + + linres = solve_linear_system!( + A, b, sol, soltemp, residual, linsolve, unknowns, + freedofs, damping, PD, SC, stats, is_linear, timer + ) + if SC.parameters[:verbosity] > -1 if is_linear @printf "%.3e\t" linres From ca22e4026e1b628317ffd5511ef97481ce03d962 Mon Sep 17 00:00:00 2001 From: Patrick Jaap Date: Wed, 6 Aug 2025 13:34:09 +0200 Subject: [PATCH 2/4] Remove unused packages --- Project.toml | 3 --- src/ExtendableFEM.jl | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 6b6b1453..853978ae 100644 --- a/Project.toml +++ b/Project.toml @@ -5,14 +5,12 @@ authors = ["Christian Merdon ", "Patrick Jaap Date: Wed, 6 Aug 2025 13:41:46 +0200 Subject: [PATCH 3/4] Adapt signatures and add docs --- .../coupled_dofs_restriction.jl | 14 +++++++++++++- src/problemdescription.jl | 11 +++++++++++ src/restrictions.jl | 2 +- src/solvers.jl | 2 +- 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/common_restrictions/coupled_dofs_restriction.jl b/src/common_restrictions/coupled_dofs_restriction.jl index 751e051d..f8e1f44b 100644 --- a/src/common_restrictions/coupled_dofs_restriction.jl +++ b/src/common_restrictions/coupled_dofs_restriction.jl @@ -3,12 +3,24 @@ struct CoupledDofsRestriction{TM} <: AbstractRestriction parameters::Dict{Symbol, Any} end + +""" + CoupledDofsRestriction(matrix::AbstractMatrix) + + Creates an restriction that couples dofs together. + + The coupling is stored in a CSC Matrix `matrix`, s.t., + + dofᵢ = Σⱼ Aⱼᵢ dofⱼ (col-wise) + + The matrix can be obtained from, e.g., `get_periodic_coupling_matrix`. +""" function CoupledDofsRestriction(matrix::AbstractMatrix) return CoupledDofsRestriction(matrix, Dict{Symbol, Any}(:name => "CoupledDofsRestriction")) end -function assemble!(R::CoupledDofsRestriction, A, b, sol, SC; kwargs...) +function assemble!(R::CoupledDofsRestriction, SC; kwargs...) # extract all col indices _, J, _ = findnz(R.coupling_matrix) diff --git a/src/problemdescription.jl b/src/problemdescription.jl index 36c91819..8c78cae9 100644 --- a/src/problemdescription.jl +++ b/src/problemdescription.jl @@ -96,7 +96,18 @@ function assign_operator!(PD::ProblemDescription, o::AbstractOperator) return length(PD.operators) end +""" +$(TYPEDSIGNATURES) + +Adds a restrction to a `ProblemDescription`. + +# Arguments +- `PD::ProblemDescription`: The problem description to which the operator should be added. +- `r::AbstractRestriction`: The restriction to add. +# Returns +- The index (position) of the restriction in the `restrictions` array of `PD`. +""" function assign_restriction!(PD::ProblemDescription, r::AbstractRestriction) push!(PD.restrictions, r) return length(PD.restrictions) diff --git a/src/restrictions.jl b/src/restrictions.jl index 72a88a46..52c164ff 100644 --- a/src/restrictions.jl +++ b/src/restrictions.jl @@ -16,7 +16,7 @@ function is_timedependent(R::AbstractRestriction) return false end -function assemble!(R::AbstractRestriction, A, b, sol, SC; kwargs...) +function assemble!(R::AbstractRestriction, SC; kwargs...) ## assembles internal restriction matrix in R return nothing end diff --git a/src/solvers.jl b/src/solvers.jl index b88b5397..42a82ca2 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -227,7 +227,7 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns, ## assemble restrctions if !SC.parameters[:initialized] for restriction in PD.restrictions - @timeit timer "$(restriction.parameters[:name])" assemble!(restriction, A, b, sol, SC; kwargs...) + @timeit timer "$(restriction.parameters[:name])" assemble!(restriction, SC; kwargs...) end end end From 92b3a0050438a97305f9518b4866909c12332af2 Mon Sep 17 00:00:00 2001 From: Patrick Jaap Date: Wed, 6 Aug 2025 16:15:03 +0200 Subject: [PATCH 4/4] Use manual construction of sparse matrix --- Project.toml | 2 + src/ExtendableFEM.jl | 3 +- .../coupled_dofs_restriction.jl | 2 +- src/solvers.jl | 87 ++++++++++++------- 4 files changed, 59 insertions(+), 35 deletions(-) diff --git a/Project.toml b/Project.toml index 853978ae..2ad4d3c3 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ExtendableFEMBase = "12fb9182-3d4c-4424-8fd1-727a0899810c" ExtendableGrids = "cfc395e8-590f-11e8-1f13-43a2532b2fa8" ExtendableSparse = "95c220a8-a1cf-11e9-0c77-dbfce5f500b3" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GridVisualize = "5eed8a63-0fb0-45eb-886d-8d5a387d12b8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -34,6 +35,7 @@ ExplicitImports = "1" ExtendableFEMBase = "1.3.0" ExtendableGrids = "1.10.3" ExtendableSparse = "1.5.3" +FillArrays = "1.13.0" ForwardDiff = "0.10.35,1" GridVisualize = "1.8.1" IncompleteLU = "0.2.1" diff --git a/src/ExtendableFEM.jl b/src/ExtendableFEM.jl index e047305b..05c34806 100644 --- a/src/ExtendableFEM.jl +++ b/src/ExtendableFEM.jl @@ -5,7 +5,7 @@ $(read(joinpath(@__DIR__, "..", "README.md"), String)) """ module ExtendableFEM -using BlockArrays: BlockMatrix, BlockVector, Block +using BlockArrays: BlockMatrix, BlockVector, Block, BlockedMatrix, blocks, axes using CommonSolve: CommonSolve using DiffResults: DiffResults using DocStringExtensions: DocStringExtensions, TYPEDEF, TYPEDSIGNATURES @@ -60,6 +60,7 @@ using ExtendableGrids: ExtendableGrids, AT_NODES, AbstractElementGeometry, using ExtendableSparse: ExtendableSparse, ExtendableSparseMatrix, flush!, MTExtendableSparseMatrixCSC, findindex, rawupdateindex! +using FillArrays: Zeros using ForwardDiff: ForwardDiff using GridVisualize: GridVisualize, GridVisualizer, gridplot!, reveal, save, scalarplot!, vectorplot! diff --git a/src/common_restrictions/coupled_dofs_restriction.jl b/src/common_restrictions/coupled_dofs_restriction.jl index f8e1f44b..33b96690 100644 --- a/src/common_restrictions/coupled_dofs_restriction.jl +++ b/src/common_restrictions/coupled_dofs_restriction.jl @@ -32,7 +32,7 @@ function assemble!(R::CoupledDofsRestriction, SC; kwargs...) B = (R.coupling_matrix - LinearAlgebra.I)[:, unique_cols] R.parameters[:matrix] = B - return R.parameters[:rhs] = Zeros(length(unique_cols)) + R.parameters[:rhs] = Zeros(length(unique_cols)) return nothing end diff --git a/src/solvers.jl b/src/solvers.jl index 42a82ca2..c5588430 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -250,51 +250,72 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns, b_unrestricted = residual.entries end - @timeit timer "LM restrictions" begin - ## add possible Lagrange restrictions - @timeit timer "prepare" begin - restriction_matrices = [length(freedofs) > 0 ? re.parameters[:matrix][freedofs, :] : re.parameters[:matrix] for re in PD.restrictions ] - restriction_rhs = [length(freedofs) > 0 ? re.parameters[:rhs][freedofs] : re.parameters[:rhs] for re in PD.restrictions ] - - ## we need to add the (initial) solution to the rhs, since we work with the residual equation - for (B, rhs) in zip(restriction_matrices, restriction_rhs) - rhs .+= B'sol.entries - end - end + if length(PD.restrictions) == 0 + linsolve.A = A_unrestricted + linsolve.b = b_unrestricted + else + @timeit timer "LM restrictions" begin + ## add possible Lagrange restrictions + @timeit timer "prepare" begin + restriction_matrices = [length(freedofs) > 0 ? re.parameters[:matrix][freedofs, :] : re.parameters[:matrix] for re in PD.restrictions ] + restriction_rhs = [length(freedofs) > 0 ? re.parameters[:rhs][freedofs] : re.parameters[:rhs] for re in PD.restrictions ] - @timeit timer "compute blocks" begin - # block sizes for the block matrix - block_sizes = [size(A_unrestricted, 2), [ size(B, 2) for B in restriction_matrices ]...] + ## we need to add the (initial) solution to the rhs, since we work with the residual equation + for (B, rhs) in zip(restriction_matrices, restriction_rhs) + rhs .+= B'sol.entries + end + end + + @timeit timer "set blocks" begin + # block sizes for the block matrix + block_sizes = [size(A_unrestricted, 2), [ size(B, 2) for B in restriction_matrices ]...] - total_size = sum(block_sizes) - Tv = eltype(A_unrestricted) + total_size = sum(block_sizes) + Tv = eltype(A_unrestricted) - ## create block matrix - A_block = BlockMatrix(spzeros(Tv, total_size, total_size), block_sizes, block_sizes) - A_block[Block(1, 1)] = A_unrestricted + ## create block matrix + A_block = BlockMatrix(spzeros(Tv, total_size, total_size), block_sizes, block_sizes) + A_block[Block(1, 1)] = A_unrestricted - b_block = BlockVector(zeros(Tv, total_size), block_sizes) - b_block[Block(1)] = b_unrestricted + b_block = BlockVector(zeros(Tv, total_size), block_sizes) + b_block[Block(1)] = b_unrestricted - u_unrestricted = linsolve.u - u_block = BlockVector(zeros(Tv, total_size), block_sizes) - u_block[Block(1)] = u_unrestricted + u_unrestricted = linsolve.u + u_block = BlockVector(zeros(Tv, total_size), block_sizes) + u_block[Block(1)] = u_unrestricted - for i in eachindex(PD.restrictions) - A_block[Block(1, i + 1)] = restriction_matrices[i] - A_block[Block(i + 1, 1)] = transpose(restriction_matrices[i]) - b_block[Block(i + 1)] = restriction_rhs[i] + for i in eachindex(PD.restrictions) + A_block[Block(1, i + 1)] = restriction_matrices[i] + A_block[Block(i + 1, 1)] = transpose(restriction_matrices[i]) + b_block[Block(i + 1)] = restriction_rhs[i] + end end - end - @timeit timer "convert" begin + @timeit timer "convert" begin + + linsolve.b = Vector(b_block) # convert to dense vector + linsolve.u = Vector(u_block) # convert to dense vector + + # linsolve.A = sparse(A_block) # convert to CSC Matrix is very slow https://github.com/JuliaArrays/BlockArrays.jl/issues/78 + # do it manually: + A_flat = spzeros(size(A_block)) + + lasts_i = [0, axes(A_block)[1].lasts...] # add zero + lasts_j = [0, axes(A_block)[2].lasts...] # add zero - linsolve.A = sparse(A_block) # convert to CSC Matrix - linsolve.b = Vector(b_block) # convert to dense vector - linsolve.u = Vector(u_block) # convert to dense vector + (ni, nj) = size(blocks(A_block)) + for i in 1:ni, j in 1:nj + range_i = (lasts_i[i] + 1):lasts_i[i + 1] + range_j = (lasts_j[j] + 1):lasts_j[j + 1] + + # write each block directly in the resulting matrix + A_flat[range_i, range_j] = A_block[Block(i, j)] + end + linsolve.A = A_flat + end end end