-
Notifications
You must be signed in to change notification settings - Fork 42
BlockMap enhancements #72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5f31819
227fd96
284aeef
1cf0818
8ca39c1
44e0b52
06f2d87
7028d5d
68dd8db
f4194ac
60d4da7
d20cbbc
4625cb5
46a4606
7f4a162
f8c2916
c08d6e3
3fc2c5f
e3c6d95
70442bc
0857943
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,14 @@ | ||
struct BlockMap{T,As<:Tuple{Vararg{LinearMap}},Rs<:Tuple{Vararg{Int}}} <: LinearMap{T} | ||
struct BlockMap{T,As<:Tuple{Vararg{LinearMap}},Rs<:Tuple{Vararg{Int}},Rranges<:Tuple{Vararg{UnitRange{Int}}},Cranges<:Tuple{Vararg{UnitRange{Int}}}} <: LinearMap{T} | ||
maps::As | ||
rows::Rs | ||
rowranges::Vector{UnitRange{Int}} | ||
colranges::Vector{UnitRange{Int}} | ||
rowranges::Rranges | ||
colranges::Cranges | ||
function BlockMap{T,R,S}(maps::R, rows::S) where {T, R<:Tuple{Vararg{LinearMap}}, S<:Tuple{Vararg{Int}}} | ||
for A in maps | ||
promote_type(T, eltype(A)) == T || throw(InexactError()) | ||
end | ||
rowranges, colranges = rowcolranges(maps, rows) | ||
return new{T,R,S}(maps, rows, rowranges, colranges) | ||
return new{T,R,S,typeof(rowranges),typeof(colranges)}(maps, rows, rowranges, colranges) | ||
end | ||
end | ||
|
||
|
@@ -28,28 +28,28 @@ Determines the range of rows for each block row and the range of columns for eac | |
map in `maps`, according to its position in a virtual matrix representation of the | ||
block linear map obtained from `hvcat(rows, maps...)`. | ||
""" | ||
function rowcolranges(maps, rows)::Tuple{Vector{UnitRange{Int}},Vector{UnitRange{Int}}} | ||
rowranges = Vector{UnitRange{Int}}(undef, length(rows)) | ||
colranges = Vector{UnitRange{Int}}(undef, length(maps)) | ||
function rowcolranges(maps, rows) | ||
rowranges = () | ||
colranges = () | ||
mapind = 0 | ||
rowstart = 1 | ||
for rowind in 1:length(rows) | ||
xinds = vcat(1, map(a -> size(a, 2), maps[mapind+1:mapind+rows[rowind]])...) | ||
for row in rows | ||
xinds = vcat(1, map(a -> size(a, 2), maps[mapind+1:mapind+row])...) | ||
cumsum!(xinds, xinds) | ||
mapind += 1 | ||
rowend = rowstart + size(maps[mapind], 1) - 1 | ||
rowranges[rowind] = rowstart:rowend | ||
colranges[mapind] = xinds[1]:xinds[2]-1 | ||
for colind in 2:rows[rowind] | ||
rowranges = (rowranges..., rowstart:rowend) | ||
colranges = (colranges..., xinds[1]:xinds[2]-1) | ||
for colind in 2:row | ||
mapind +=1 | ||
colranges[mapind] = xinds[colind]:xinds[colind+1]-1 | ||
colranges = (colranges..., xinds[colind]:xinds[colind+1]-1) | ||
end | ||
rowstart = rowend + 1 | ||
end | ||
return rowranges, colranges | ||
return rowranges::NTuple{length(rows), UnitRange{Int}}, colranges::NTuple{length(maps), UnitRange{Int}} | ||
end | ||
|
||
Base.size(A::BlockMap) = (last(A.rowranges[end]), last(A.colranges[end])) | ||
Base.size(A::BlockMap) = (last(last(A.rowranges)), last(last(A.colranges))) | ||
|
||
############ | ||
# concatenation | ||
|
@@ -299,75 +299,82 @@ LinearAlgebra.transpose(A::BlockMap) = TransposeMap(A) | |
LinearAlgebra.adjoint(A::BlockMap) = AdjointMap(A) | ||
|
||
############ | ||
# multiplication with vectors | ||
# multiplication helper functions | ||
############ | ||
|
||
function A_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) | ||
require_one_based_indexing(y, x) | ||
m, n = size(A) | ||
@boundscheck (m == length(y) && n == length(x)) || throw(DimensionMismatch("A_mul_B!")) | ||
@inline function _blockmul!(y, A::BlockMap, x, α, β) | ||
maps, rows, yinds, xinds = A.maps, A.rows, A.rowranges, A.colranges | ||
mapind = 0 | ||
@views @inbounds for rowind in 1:length(rows) | ||
yrow = y[yinds[rowind]] | ||
@views @inbounds for (row, yi) in zip(rows, yinds) | ||
yrow = selectdim(y, 1, yi) | ||
mapind += 1 | ||
A_mul_B!(yrow, maps[mapind], x[xinds[mapind]]) | ||
for colind in 2:rows[rowind] | ||
mul!(yrow, maps[mapind], selectdim(x, 1, xinds[mapind]), α, β) | ||
for _ in 2:row | ||
mapind +=1 | ||
mul!(yrow, maps[mapind], x[xinds[mapind]], true, true) | ||
mul!(yrow, maps[mapind], selectdim(x, 1, xinds[mapind]), α, true) | ||
end | ||
end | ||
return y | ||
end | ||
|
||
function At_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) | ||
require_one_based_indexing(y, x) | ||
m, n = size(A) | ||
@boundscheck (n == length(y) && m == length(x)) || throw(DimensionMismatch("At_mul_B!")) | ||
@inline function _transblockmul!(y, A::BlockMap, x, α, β, transform) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is corresponding multiplication code we used to have twice, once for |
||
maps, rows, xinds, yinds = A.maps, A.rows, A.rowranges, A.colranges | ||
mapind = 0 | ||
# first block row (rowind = 1) of A, meaning first block column of A', fill all of y | ||
@views @inbounds begin | ||
xcol = x[xinds[1]] | ||
for colind in 1:rows[1] | ||
mapind +=1 | ||
A_mul_B!(y[yinds[mapind]], transpose(maps[mapind]), xcol) | ||
# first block row (rowind = 1) of A, meaning first block column of A', fill all of y | ||
xcol = selectdim(x, 1, first(xinds)) | ||
for rowind in 1:first(rows) | ||
mul!(selectdim(y, 1, yinds[rowind]), transform(maps[rowind]), xcol, α, β) | ||
end | ||
# subsequent block rows of A, add results to corresponding parts of y | ||
for rowind in 2:length(rows) | ||
xcol = x[xinds[rowind]] | ||
for colind in 1:rows[rowind] | ||
mapind = first(rows) | ||
# subsequent block rows of A (block columns of A'), | ||
# add results to corresponding parts of y | ||
# TODO: think about multithreading | ||
for (row, xi) in zip(Base.tail(rows), Base.tail(xinds)) | ||
xcol = selectdim(x, 1, xi) | ||
for _ in 1:row | ||
mapind +=1 | ||
mul!(y[yinds[mapind]], transpose(maps[mapind]), xcol, true, true) | ||
mul!(selectdim(y, 1, yinds[mapind]), transform(maps[mapind]), xcol, α, true) | ||
end | ||
end | ||
end | ||
return y | ||
end | ||
|
||
function Ac_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) | ||
require_one_based_indexing(y, x) | ||
m, n = size(A) | ||
@boundscheck (n == length(y) && m == length(x)) || throw(DimensionMismatch("At_mul_B!")) | ||
maps, rows, xinds, yinds = A.maps, A.rows, A.rowranges, A.colranges | ||
mapind = 0 | ||
# first block row (rowind = 1) of A, fill all of y | ||
@views @inbounds begin | ||
xcol = x[xinds[1]] | ||
for colind in 1:rows[1] | ||
mapind +=1 | ||
A_mul_B!(y[yinds[mapind]], adjoint(maps[mapind]), xcol) | ||
end | ||
# subsequent block rows of A, add results to corresponding parts of y | ||
for rowind in 2:length(rows) | ||
xcol = x[xinds[rowind]] | ||
for colind in 1:rows[rowind] | ||
mapind +=1 | ||
mul!(y[yinds[mapind]], adjoint(maps[mapind]), xcol, true, true) | ||
end | ||
############ | ||
# multiplication with vectors & matrices | ||
############ | ||
|
||
Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) = | ||
mul!(y, A, x) | ||
|
||
Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::TransposeMap{<:Any,<:BlockMap}, x::AbstractVector) = | ||
mul!(y, A, x) | ||
|
||
Base.@propagate_inbounds At_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) = | ||
mul!(y, transpose(A), x) | ||
|
||
Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::AdjointMap{<:Any,<:BlockMap}, x::AbstractVector) = | ||
mul!(y, A, x) | ||
|
||
Base.@propagate_inbounds Ac_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) = | ||
mul!(y, adjoint(A), x) | ||
Comment on lines
+347
to
+360
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have multiplication handled by |
||
|
||
for Atype in (AbstractVector, AbstractMatrix) | ||
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, A::BlockMap, x::$Atype, | ||
α::Number=true, β::Number=false) | ||
require_one_based_indexing(y, x) | ||
@boundscheck check_dim_mul(y, A, x) | ||
return _blockmul!(y, A, x, α, β) | ||
end | ||
|
||
for (maptype, transform) in ((:(TransposeMap{<:Any,<:BlockMap}), :transpose), (:(AdjointMap{<:Any,<:BlockMap}), :adjoint)) | ||
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, wrapA::$maptype, x::$Atype, | ||
α::Number=true, β::Number=false) | ||
require_one_based_indexing(y, x) | ||
@boundscheck check_dim_mul(y, wrapA, x) | ||
return _transblockmul!(y, wrapA.lmap, x, α, β, $transform) | ||
end | ||
end | ||
return y | ||
end | ||
|
||
############ | ||
|
@@ -388,3 +395,91 @@ end | |
# show(io, T) | ||
# print(io, '}') | ||
# end | ||
|
||
############ | ||
# BlockDiagonalMap | ||
############ | ||
|
||
struct BlockDiagonalMap{T,As<:Tuple{Vararg{LinearMap}},Ranges<:Tuple{Vararg{UnitRange{Int}}}} <: LinearMap{T} | ||
maps::As | ||
rowranges::Ranges | ||
colranges::Ranges | ||
function BlockDiagonalMap{T,As}(maps::As) where {T, As<:Tuple{Vararg{LinearMap}}} | ||
for A in maps | ||
promote_type(T, eltype(A)) == T || throw(InexactError()) | ||
end | ||
# row ranges | ||
inds = vcat(1, size.(maps, 1)...) | ||
cumsum!(inds, inds) | ||
rowranges = ntuple(i -> inds[i]:inds[i+1]-1, Val(length(maps))) | ||
# column ranges | ||
inds[2:end] .= size.(maps, 2) | ||
cumsum!(inds, inds) | ||
colranges = ntuple(i -> inds[i]:inds[i+1]-1, Val(length(maps))) | ||
return new{T,As,typeof(rowranges)}(maps, rowranges, colranges) | ||
end | ||
end | ||
|
||
BlockDiagonalMap{T}(maps::As) where {T,As<:Tuple{Vararg{LinearMap}}} = | ||
BlockDiagonalMap{T,As}(maps) | ||
BlockDiagonalMap(maps::LinearMap...) = | ||
BlockDiagonalMap{promote_type(map(eltype, maps)...)}(maps) | ||
|
||
for k in 1:8 # is 8 sufficient? | ||
Is = ntuple(n->:($(Symbol(:A,n))::AbstractMatrix), Val(k-1)) | ||
# yields (:A1, :A2, :A3, ..., :A(k-1)) | ||
L = :($(Symbol(:A,k))::LinearMap) | ||
# yields :Ak | ||
mapargs = ntuple(n -> :(LinearMap($(Symbol(:A,n)))), Val(k-1)) | ||
# yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1))) | ||
|
||
@eval begin | ||
SparseArrays.blockdiag($(Is...), $L, As::Union{LinearMap,AbstractMatrix}...) = | ||
BlockDiagonalMap($(mapargs...), $(Symbol(:A,k)), convert_to_lmaps(As...)...) | ||
function Base.cat($(Is...), $L, As::Union{LinearMap,AbstractMatrix}...; dims::Dims{2}) | ||
if dims == (1,2) | ||
return BlockDiagonalMap($(mapargs...), $(Symbol(:A,k)), convert_to_lmaps(As...)...) | ||
else | ||
throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)")) | ||
end | ||
end | ||
end | ||
end | ||
|
||
Base.size(A::BlockDiagonalMap) = (last(A.rowranges[end]), last(A.colranges[end])) | ||
|
||
LinearAlgebra.issymmetric(A::BlockDiagonalMap) = all(issymmetric, A.maps) | ||
LinearAlgebra.ishermitian(A::BlockDiagonalMap{<:Real}) = all(issymmetric, A.maps) | ||
LinearAlgebra.ishermitian(A::BlockDiagonalMap) = all(ishermitian, A.maps) | ||
|
||
LinearAlgebra.adjoint(A::BlockDiagonalMap{T}) where {T} = BlockDiagonalMap{T}(map(adjoint, A.maps)) | ||
LinearAlgebra.transpose(A::BlockDiagonalMap{T}) where {T} = BlockDiagonalMap{T}(map(transpose, A.maps)) | ||
|
||
Base.:(==)(A::BlockDiagonalMap, B::BlockDiagonalMap) = (eltype(A) == eltype(B) && A.maps == B.maps) | ||
|
||
Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::BlockDiagonalMap, x::AbstractVector) = | ||
mul!(y, A, x, true, false) | ||
|
||
Base.@propagate_inbounds At_mul_B!(y::AbstractVector, A::BlockDiagonalMap, x::AbstractVector) = | ||
mul!(y, transpose(A), x, true, false) | ||
|
||
Base.@propagate_inbounds Ac_mul_B!(y::AbstractVector, A::BlockDiagonalMap, x::AbstractVector) = | ||
mul!(y, adjoint(A), x, true, false) | ||
|
||
for Atype in (AbstractVector, AbstractMatrix) | ||
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, A::BlockDiagonalMap, x::$Atype, | ||
α::Number=true, β::Number=false) | ||
require_one_based_indexing(y, x) | ||
@boundscheck check_dim_mul(y, A, x) | ||
return _blockscaling!(y, A, x, α, β) | ||
end | ||
end | ||
|
||
@inline function _blockscaling!(y, A::BlockDiagonalMap, x, α, β) | ||
maps, yinds, xinds = A.maps, A.rowranges, A.colranges | ||
# TODO: think about multi-threading here | ||
@views @inbounds for i in eachindex(yinds, maps, xinds) | ||
mul!(selectdim(y, 1, yinds[i]), maps[i], selectdim(x, 1, xinds[i]), α, β) | ||
end | ||
return y | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This used to be the
A_mul_B!
code, which can be easily generalized to work withα
,β
, and matrices instead of vectors, so I factored it out. The generic version of indexing is thenselectdim(y, 1, ...)
.