Skip to content

Commit 039e35f

Browse files
authored
BlockMap enhancements (#72)
1 parent a0f5016 commit 039e35f

File tree

6 files changed

+433
-136
lines changed

6 files changed

+433
-136
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LinearMaps"
22
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
3-
version = "2.5.2"
3+
version = "2.6"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

README.md

Lines changed: 186 additions & 43 deletions
Large diffs are not rendered by default.

src/LinearMaps.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,17 @@ Base.length(A::LinearMap) = size(A)[1] * size(A)[2]
4848

4949
# check dimension consistency for y = A*x and Y = A*X
5050
function check_dim_mul(y::AbstractVector, A::LinearMap, x::AbstractVector)
51-
m, n = size(A)
52-
(m == length(y) && n == length(x)) || throw(DimensionMismatch("mul!"))
53-
return nothing
54-
end
55-
function check_dim_mul(Y::AbstractMatrix, A::LinearMap, X::AbstractMatrix)
56-
m, n = size(A)
57-
(m == size(Y, 1) && n == size(X, 1) && size(Y, 2) == size(X, 2)) || throw(DimensionMismatch("mul!"))
58-
return nothing
59-
end
51+
# @info "checked vector dimensions" # uncomment for testing
52+
m, n = size(A)
53+
(m == length(y) && n == length(x)) || throw(DimensionMismatch("mul!"))
54+
return nothing
55+
end
56+
function check_dim_mul(Y::AbstractMatrix, A::LinearMap, X::AbstractMatrix)
57+
# @info "checked matrix dimensions" # uncomment for testing
58+
m, n = size(A)
59+
(m == size(Y, 1) && n == size(X, 1) && size(Y, 2) == size(X, 2)) || throw(DimensionMismatch("mul!"))
60+
return nothing
61+
end
6062

6163
# conversion of AbstractMatrix to LinearMap
6264
convert_to_lmaps_(A::AbstractMatrix) = LinearMap(A)
@@ -66,9 +68,12 @@ convert_to_lmaps(A) = (convert_to_lmaps_(A),)
6668
@inline convert_to_lmaps(A, B, Cs...) =
6769
(convert_to_lmaps_(A), convert_to_lmaps_(B), convert_to_lmaps(Cs...)...)
6870

69-
Base.:(*)(A::LinearMap, x::AbstractVector) = mul!(similar(x, promote_type(eltype(A), eltype(x)), size(A, 1)), A, x)
71+
function Base.:(*)(A::LinearMap, x::AbstractVector)
72+
size(A, 2) == length(x) || throw(DimensionMismatch("mul!"))
73+
return @inbounds mul!(similar(x, promote_type(eltype(A), eltype(x)), size(A, 1)), A, x)
74+
end
7075
function LinearAlgebra.mul!(y::AbstractVector, A::LinearMap, x::AbstractVector, α::Number=true, β::Number=false)
71-
length(y) == size(A, 1) || throw(DimensionMismatch("mul!"))
76+
@boundscheck check_dim_mul(y, A, x)
7277
if isone(α)
7378
iszero(β) && (A_mul_B!(y, A, x); return y)
7479
isone(β) && (y .+= A * x; return y)
@@ -92,8 +97,8 @@ function LinearAlgebra.mul!(y::AbstractVector, A::LinearMap, x::AbstractVector,
9297
end
9398
end
9499
# the following is of interest in, e.g., subspace-iteration methods
95-
function LinearAlgebra.mul!(Y::AbstractMatrix, A::LinearMap, X::AbstractMatrix, α::Number=true, β::Number=false)
96-
(size(Y, 1) == size(A, 1) && size(X, 1) == size(A, 2) && size(Y, 2) == size(X, 2)) || throw(DimensionMismatch("mul!"))
100+
Base.@propagate_inbounds function LinearAlgebra.mul!(Y::AbstractMatrix, A::LinearMap, X::AbstractMatrix, α::Number=true, β::Number=false)
101+
@boundscheck check_dim_mul(Y, A, X)
97102
@inbounds @views for i = 1:size(X, 2)
98103
mul!(Y[:, i], A, X[:, i], α, β)
99104
end

src/blockmap.jl

Lines changed: 156 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
struct BlockMap{T,As<:Tuple{Vararg{LinearMap}},Rs<:Tuple{Vararg{Int}}} <: LinearMap{T}
1+
struct BlockMap{T,As<:Tuple{Vararg{LinearMap}},Rs<:Tuple{Vararg{Int}},Rranges<:Tuple{Vararg{UnitRange{Int}}},Cranges<:Tuple{Vararg{UnitRange{Int}}}} <: LinearMap{T}
22
maps::As
33
rows::Rs
4-
rowranges::Vector{UnitRange{Int}}
5-
colranges::Vector{UnitRange{Int}}
4+
rowranges::Rranges
5+
colranges::Cranges
66
function BlockMap{T,R,S}(maps::R, rows::S) where {T, R<:Tuple{Vararg{LinearMap}}, S<:Tuple{Vararg{Int}}}
77
for A in maps
88
promote_type(T, eltype(A)) == T || throw(InexactError())
99
end
1010
rowranges, colranges = rowcolranges(maps, rows)
11-
return new{T,R,S}(maps, rows, rowranges, colranges)
11+
return new{T,R,S,typeof(rowranges),typeof(colranges)}(maps, rows, rowranges, colranges)
1212
end
1313
end
1414

@@ -28,28 +28,28 @@ Determines the range of rows for each block row and the range of columns for eac
2828
map in `maps`, according to its position in a virtual matrix representation of the
2929
block linear map obtained from `hvcat(rows, maps...)`.
3030
"""
31-
function rowcolranges(maps, rows)::Tuple{Vector{UnitRange{Int}},Vector{UnitRange{Int}}}
32-
rowranges = Vector{UnitRange{Int}}(undef, length(rows))
33-
colranges = Vector{UnitRange{Int}}(undef, length(maps))
31+
function rowcolranges(maps, rows)
32+
rowranges = ()
33+
colranges = ()
3434
mapind = 0
3535
rowstart = 1
36-
for rowind in 1:length(rows)
37-
xinds = vcat(1, map(a -> size(a, 2), maps[mapind+1:mapind+rows[rowind]])...)
36+
for row in rows
37+
xinds = vcat(1, map(a -> size(a, 2), maps[mapind+1:mapind+row])...)
3838
cumsum!(xinds, xinds)
3939
mapind += 1
4040
rowend = rowstart + size(maps[mapind], 1) - 1
41-
rowranges[rowind] = rowstart:rowend
42-
colranges[mapind] = xinds[1]:xinds[2]-1
43-
for colind in 2:rows[rowind]
41+
rowranges = (rowranges..., rowstart:rowend)
42+
colranges = (colranges..., xinds[1]:xinds[2]-1)
43+
for colind in 2:row
4444
mapind +=1
45-
colranges[mapind] = xinds[colind]:xinds[colind+1]-1
45+
colranges = (colranges..., xinds[colind]:xinds[colind+1]-1)
4646
end
4747
rowstart = rowend + 1
4848
end
49-
return rowranges, colranges
49+
return rowranges::NTuple{length(rows), UnitRange{Int}}, colranges::NTuple{length(maps), UnitRange{Int}}
5050
end
5151

52-
Base.size(A::BlockMap) = (last(A.rowranges[end]), last(A.colranges[end]))
52+
Base.size(A::BlockMap) = (last(last(A.rowranges)), last(last(A.colranges)))
5353

5454
############
5555
# concatenation
@@ -299,75 +299,82 @@ LinearAlgebra.transpose(A::BlockMap) = TransposeMap(A)
299299
LinearAlgebra.adjoint(A::BlockMap) = AdjointMap(A)
300300

301301
############
302-
# multiplication with vectors
302+
# multiplication helper functions
303303
############
304304

305-
function A_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector)
306-
require_one_based_indexing(y, x)
307-
m, n = size(A)
308-
@boundscheck (m == length(y) && n == length(x)) || throw(DimensionMismatch("A_mul_B!"))
305+
@inline function _blockmul!(y, A::BlockMap, x, α, β)
309306
maps, rows, yinds, xinds = A.maps, A.rows, A.rowranges, A.colranges
310307
mapind = 0
311-
@views @inbounds for rowind in 1:length(rows)
312-
yrow = y[yinds[rowind]]
308+
@views @inbounds for (row, yi) in zip(rows, yinds)
309+
yrow = selectdim(y, 1, yi)
313310
mapind += 1
314-
A_mul_B!(yrow, maps[mapind], x[xinds[mapind]])
315-
for colind in 2:rows[rowind]
311+
mul!(yrow, maps[mapind], selectdim(x, 1, xinds[mapind]), α, β)
312+
for _ in 2:row
316313
mapind +=1
317-
mul!(yrow, maps[mapind], x[xinds[mapind]], true, true)
314+
mul!(yrow, maps[mapind], selectdim(x, 1, xinds[mapind]), α, true)
318315
end
319316
end
320317
return y
321318
end
322319

323-
function At_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector)
324-
require_one_based_indexing(y, x)
325-
m, n = size(A)
326-
@boundscheck (n == length(y) && m == length(x)) || throw(DimensionMismatch("At_mul_B!"))
320+
@inline function _transblockmul!(y, A::BlockMap, x, α, β, transform)
327321
maps, rows, xinds, yinds = A.maps, A.rows, A.rowranges, A.colranges
328-
mapind = 0
329-
# first block row (rowind = 1) of A, meaning first block column of A', fill all of y
330322
@views @inbounds begin
331-
xcol = x[xinds[1]]
332-
for colind in 1:rows[1]
333-
mapind +=1
334-
A_mul_B!(y[yinds[mapind]], transpose(maps[mapind]), xcol)
323+
# first block row (rowind = 1) of A, meaning first block column of A', fill all of y
324+
xcol = selectdim(x, 1, first(xinds))
325+
for rowind in 1:first(rows)
326+
mul!(selectdim(y, 1, yinds[rowind]), transform(maps[rowind]), xcol, α, β)
335327
end
336-
# subsequent block rows of A, add results to corresponding parts of y
337-
for rowind in 2:length(rows)
338-
xcol = x[xinds[rowind]]
339-
for colind in 1:rows[rowind]
328+
mapind = first(rows)
329+
# subsequent block rows of A (block columns of A'),
330+
# add results to corresponding parts of y
331+
# TODO: think about multithreading
332+
for (row, xi) in zip(Base.tail(rows), Base.tail(xinds))
333+
xcol = selectdim(x, 1, xi)
334+
for _ in 1:row
340335
mapind +=1
341-
mul!(y[yinds[mapind]], transpose(maps[mapind]), xcol, true, true)
336+
mul!(selectdim(y, 1, yinds[mapind]), transform(maps[mapind]), xcol, α, true)
342337
end
343338
end
344339
end
345340
return y
346341
end
347342

348-
function Ac_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector)
349-
require_one_based_indexing(y, x)
350-
m, n = size(A)
351-
@boundscheck (n == length(y) && m == length(x)) || throw(DimensionMismatch("At_mul_B!"))
352-
maps, rows, xinds, yinds = A.maps, A.rows, A.rowranges, A.colranges
353-
mapind = 0
354-
# first block row (rowind = 1) of A, fill all of y
355-
@views @inbounds begin
356-
xcol = x[xinds[1]]
357-
for colind in 1:rows[1]
358-
mapind +=1
359-
A_mul_B!(y[yinds[mapind]], adjoint(maps[mapind]), xcol)
360-
end
361-
# subsequent block rows of A, add results to corresponding parts of y
362-
for rowind in 2:length(rows)
363-
xcol = x[xinds[rowind]]
364-
for colind in 1:rows[rowind]
365-
mapind +=1
366-
mul!(y[yinds[mapind]], adjoint(maps[mapind]), xcol, true, true)
367-
end
343+
############
344+
# multiplication with vectors & matrices
345+
############
346+
347+
Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) =
348+
mul!(y, A, x)
349+
350+
Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::TransposeMap{<:Any,<:BlockMap}, x::AbstractVector) =
351+
mul!(y, A, x)
352+
353+
Base.@propagate_inbounds At_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) =
354+
mul!(y, transpose(A), x)
355+
356+
Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::AdjointMap{<:Any,<:BlockMap}, x::AbstractVector) =
357+
mul!(y, A, x)
358+
359+
Base.@propagate_inbounds Ac_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) =
360+
mul!(y, adjoint(A), x)
361+
362+
for Atype in (AbstractVector, AbstractMatrix)
363+
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, A::BlockMap, x::$Atype,
364+
α::Number=true, β::Number=false)
365+
require_one_based_indexing(y, x)
366+
@boundscheck check_dim_mul(y, A, x)
367+
return _blockmul!(y, A, x, α, β)
368+
end
369+
370+
for (maptype, transform) in ((:(TransposeMap{<:Any,<:BlockMap}), :transpose), (:(AdjointMap{<:Any,<:BlockMap}), :adjoint))
371+
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, wrapA::$maptype, x::$Atype,
372+
α::Number=true, β::Number=false)
373+
require_one_based_indexing(y, x)
374+
@boundscheck check_dim_mul(y, wrapA, x)
375+
return _transblockmul!(y, wrapA.lmap, x, α, β, $transform)
368376
end
369377
end
370-
return y
371378
end
372379

373380
############
@@ -388,3 +395,91 @@ end
388395
# show(io, T)
389396
# print(io, '}')
390397
# end
398+
399+
############
400+
# BlockDiagonalMap
401+
############
402+
403+
struct BlockDiagonalMap{T,As<:Tuple{Vararg{LinearMap}},Ranges<:Tuple{Vararg{UnitRange{Int}}}} <: LinearMap{T}
404+
maps::As
405+
rowranges::Ranges
406+
colranges::Ranges
407+
function BlockDiagonalMap{T,As}(maps::As) where {T, As<:Tuple{Vararg{LinearMap}}}
408+
for A in maps
409+
promote_type(T, eltype(A)) == T || throw(InexactError())
410+
end
411+
# row ranges
412+
inds = vcat(1, size.(maps, 1)...)
413+
cumsum!(inds, inds)
414+
rowranges = ntuple(i -> inds[i]:inds[i+1]-1, Val(length(maps)))
415+
# column ranges
416+
inds[2:end] .= size.(maps, 2)
417+
cumsum!(inds, inds)
418+
colranges = ntuple(i -> inds[i]:inds[i+1]-1, Val(length(maps)))
419+
return new{T,As,typeof(rowranges)}(maps, rowranges, colranges)
420+
end
421+
end
422+
423+
BlockDiagonalMap{T}(maps::As) where {T,As<:Tuple{Vararg{LinearMap}}} =
424+
BlockDiagonalMap{T,As}(maps)
425+
BlockDiagonalMap(maps::LinearMap...) =
426+
BlockDiagonalMap{promote_type(map(eltype, maps)...)}(maps)
427+
428+
for k in 1:8 # is 8 sufficient?
429+
Is = ntuple(n->:($(Symbol(:A,n))::AbstractMatrix), Val(k-1))
430+
# yields (:A1, :A2, :A3, ..., :A(k-1))
431+
L = :($(Symbol(:A,k))::LinearMap)
432+
# yields :Ak
433+
mapargs = ntuple(n -> :(LinearMap($(Symbol(:A,n)))), Val(k-1))
434+
# yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1)))
435+
436+
@eval begin
437+
SparseArrays.blockdiag($(Is...), $L, As::Union{LinearMap,AbstractMatrix}...) =
438+
BlockDiagonalMap($(mapargs...), $(Symbol(:A,k)), convert_to_lmaps(As...)...)
439+
function Base.cat($(Is...), $L, As::Union{LinearMap,AbstractMatrix}...; dims::Dims{2})
440+
if dims == (1,2)
441+
return BlockDiagonalMap($(mapargs...), $(Symbol(:A,k)), convert_to_lmaps(As...)...)
442+
else
443+
throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)"))
444+
end
445+
end
446+
end
447+
end
448+
449+
Base.size(A::BlockDiagonalMap) = (last(A.rowranges[end]), last(A.colranges[end]))
450+
451+
LinearAlgebra.issymmetric(A::BlockDiagonalMap) = all(issymmetric, A.maps)
452+
LinearAlgebra.ishermitian(A::BlockDiagonalMap{<:Real}) = all(issymmetric, A.maps)
453+
LinearAlgebra.ishermitian(A::BlockDiagonalMap) = all(ishermitian, A.maps)
454+
455+
LinearAlgebra.adjoint(A::BlockDiagonalMap{T}) where {T} = BlockDiagonalMap{T}(map(adjoint, A.maps))
456+
LinearAlgebra.transpose(A::BlockDiagonalMap{T}) where {T} = BlockDiagonalMap{T}(map(transpose, A.maps))
457+
458+
Base.:(==)(A::BlockDiagonalMap, B::BlockDiagonalMap) = (eltype(A) == eltype(B) && A.maps == B.maps)
459+
460+
Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::BlockDiagonalMap, x::AbstractVector) =
461+
mul!(y, A, x, true, false)
462+
463+
Base.@propagate_inbounds At_mul_B!(y::AbstractVector, A::BlockDiagonalMap, x::AbstractVector) =
464+
mul!(y, transpose(A), x, true, false)
465+
466+
Base.@propagate_inbounds Ac_mul_B!(y::AbstractVector, A::BlockDiagonalMap, x::AbstractVector) =
467+
mul!(y, adjoint(A), x, true, false)
468+
469+
for Atype in (AbstractVector, AbstractMatrix)
470+
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, A::BlockDiagonalMap, x::$Atype,
471+
α::Number=true, β::Number=false)
472+
require_one_based_indexing(y, x)
473+
@boundscheck check_dim_mul(y, A, x)
474+
return _blockscaling!(y, A, x, α, β)
475+
end
476+
end
477+
478+
@inline function _blockscaling!(y, A::BlockDiagonalMap, x, α, β)
479+
maps, yinds, xinds = A.maps, A.rowranges, A.colranges
480+
# TODO: think about multi-threading here
481+
@views @inbounds for i in eachindex(yinds, maps, xinds)
482+
mul!(selectdim(y, 1, yinds[i]), maps[i], selectdim(x, 1, xinds[i]), α, β)
483+
end
484+
return y
485+
end

src/wrappedmap.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,13 @@ Base.:(==)(A::MatrixMap, B::MatrixMap) =
3030
(eltype(A)==eltype(B) && A.lmap==B.lmap && A._issymmetric==B._issymmetric &&
3131
A._ishermitian==B._ishermitian && A._isposdef==B._isposdef)
3232

33-
if VERSION v"1.3.0-alpha.115"
34-
35-
LinearAlgebra.mul!(y::AbstractVector, A::WrappedMap, x::AbstractVector, α::Number=true, β::Number=false) =
36-
mul!(y, A.lmap, x, α, β)
37-
38-
LinearAlgebra.mul!(Y::AbstractMatrix, A::MatrixMap, X::AbstractMatrix, α::Number=true, β::Number=false) =
39-
mul!(Y, A.lmap, X, α, β)
40-
41-
else
42-
43-
LinearAlgebra.mul!(Y::AbstractMatrix, A::MatrixMap, X::AbstractMatrix) =
44-
mul!(Y, A.lmap, X)
45-
46-
end # VERSION
47-
4833
# properties
4934
Base.size(A::WrappedMap) = size(A.lmap)
5035
LinearAlgebra.issymmetric(A::WrappedMap) = A._issymmetric
5136
LinearAlgebra.ishermitian(A::WrappedMap) = A._ishermitian
5237
LinearAlgebra.isposdef(A::WrappedMap) = A._isposdef
5338

54-
# multiplication with vector
39+
# multiplication with vectors & matrices
5540
A_mul_B!(y::AbstractVector, A::WrappedMap, x::AbstractVector) = A_mul_B!(y, A.lmap, x)
5641
Base.:(*)(A::WrappedMap, x::AbstractVector) = *(A.lmap, x)
5742

@@ -61,6 +46,23 @@ At_mul_B!(y::AbstractVector, A::WrappedMap, x::AbstractVector) =
6146
Ac_mul_B!(y::AbstractVector, A::WrappedMap, x::AbstractVector) =
6247
ishermitian(A) ? A_mul_B!(y, A.lmap, x) : Ac_mul_B!(y, A.lmap, x)
6348

49+
if VERSION v"1.3.0-alpha.115"
50+
for Atype in (AbstractVector, AbstractMatrix)
51+
@eval Base.@propagate_inbounds LinearAlgebra.mul!(y::$Atype, A::WrappedMap, x::$Atype,
52+
α::Number=true, β::Number=false) =
53+
mul!(y, A.lmap, x, α, β)
54+
end
55+
else
56+
# This is somewhat suboptimal, because the absence of 5-arg mul! for MatrixMaps
57+
# doesn't allow to define a 5-arg mul! for WrappedMaps which do have a 5-arg mul!
58+
# I'd assume, however, that 5-arg mul! becomes standard in Julia v≥1.3 anyway
59+
# the idea is to let the fallback handle 5-arg calls
60+
for Atype in (AbstractVector, AbstractMatrix)
61+
@eval Base.@propagate_inbounds LinearAlgebra.mul!(Y::$Atype, A::WrappedMap, X::$Atype) =
62+
mul!(Y, A.lmap, X)
63+
end
64+
end # VERSION
65+
6466
# combine LinearMap and Matrix objects: linear combinations and map composition
6567
Base.:(+)(A₁::LinearMap, A₂::AbstractMatrix) = +(A₁, WrappedMap(A₂))
6668
Base.:(+)(A₁::AbstractMatrix, A₂::LinearMap) = +(WrappedMap(A₁), A₂)

0 commit comments

Comments
 (0)