1
- struct BlockMap{T,As<: Tuple{Vararg{LinearMap}} ,Rs<: Tuple{Vararg{Int}} } <: LinearMap{T}
1
+ struct BlockMap{T,As<: Tuple{Vararg{LinearMap}} ,Rs<: Tuple{Vararg{Int}} ,Rranges <: Tuple{Vararg{UnitRange{Int}}} ,Cranges <: Tuple{Vararg{UnitRange{Int}}} } <: LinearMap{T}
2
2
maps:: As
3
3
rows:: Rs
4
- rowranges:: Vector{UnitRange{Int}}
5
- colranges:: Vector{UnitRange{Int}}
4
+ rowranges:: Rranges
5
+ colranges:: Cranges
6
6
function BlockMap {T,R,S} (maps:: R , rows:: S ) where {T, R<: Tuple{Vararg{LinearMap}} , S<: Tuple{Vararg{Int}} }
7
7
for A in maps
8
8
promote_type (T, eltype (A)) == T || throw (InexactError ())
9
9
end
10
10
rowranges, colranges = rowcolranges (maps, rows)
11
- return new {T,R,S} (maps, rows, rowranges, colranges)
11
+ return new {T,R,S,typeof(rowranges),typeof(colranges) } (maps, rows, rowranges, colranges)
12
12
end
13
13
end
14
14
@@ -28,28 +28,28 @@ Determines the range of rows for each block row and the range of columns for eac
28
28
map in `maps`, according to its position in a virtual matrix representation of the
29
29
block linear map obtained from `hvcat(rows, maps...)`.
30
30
"""
31
- function rowcolranges (maps, rows):: Tuple{Vector{UnitRange{Int}},Vector{UnitRange{Int}}}
32
- rowranges = Vector {UnitRange{Int}} (undef, length (rows) )
33
- colranges = Vector {UnitRange{Int}} (undef, length (maps) )
31
+ function rowcolranges (maps, rows)
32
+ rowranges = ( )
33
+ colranges = ( )
34
34
mapind = 0
35
35
rowstart = 1
36
- for rowind in 1 : length ( rows)
37
- xinds = vcat (1 , map (a -> size (a, 2 ), maps[mapind+ 1 : mapind+ rows[rowind] ])... )
36
+ for row in rows
37
+ xinds = vcat (1 , map (a -> size (a, 2 ), maps[mapind+ 1 : mapind+ row ])... )
38
38
cumsum! (xinds, xinds)
39
39
mapind += 1
40
40
rowend = rowstart + size (maps[mapind], 1 ) - 1
41
- rowranges[rowind] = rowstart: rowend
42
- colranges[mapind] = xinds[1 ]: xinds[2 ]- 1
43
- for colind in 2 : rows[rowind]
41
+ rowranges = (rowranges ... , rowstart: rowend)
42
+ colranges = (colranges ... , xinds[1 ]: xinds[2 ]- 1 )
43
+ for colind in 2 : row
44
44
mapind += 1
45
- colranges[mapind] = xinds[colind]: xinds[colind+ 1 ]- 1
45
+ colranges = (colranges ... , xinds[colind]: xinds[colind+ 1 ]- 1 )
46
46
end
47
47
rowstart = rowend + 1
48
48
end
49
- return rowranges, colranges
49
+ return rowranges:: NTuple{length(rows), UnitRange{Int}} , colranges:: NTuple{length(maps), UnitRange{Int}}
50
50
end
51
51
52
- Base. size (A:: BlockMap ) = (last (A. rowranges[ end ]) , last (A. colranges[ end ] ))
52
+ Base. size (A:: BlockMap ) = (last (last ( A. rowranges)) , last (last ( A. colranges) ))
53
53
54
54
# ###########
55
55
# concatenation
@@ -299,75 +299,82 @@ LinearAlgebra.transpose(A::BlockMap) = TransposeMap(A)
299
299
LinearAlgebra. adjoint (A:: BlockMap ) = AdjointMap (A)
300
300
301
301
# ###########
302
- # multiplication with vectors
302
+ # multiplication helper functions
303
303
# ###########
304
304
305
- function A_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector )
306
- require_one_based_indexing (y, x)
307
- m, n = size (A)
308
- @boundscheck (m == length (y) && n == length (x)) || throw (DimensionMismatch (" A_mul_B!" ))
305
+ @inline function _blockmul! (y, A:: BlockMap , x, α, β)
309
306
maps, rows, yinds, xinds = A. maps, A. rows, A. rowranges, A. colranges
310
307
mapind = 0
311
- @views @inbounds for rowind in 1 : length (rows)
312
- yrow = y[yinds[rowind]]
308
+ @views @inbounds for (row, yi) in zip (rows, yinds )
309
+ yrow = selectdim (y, 1 , yi)
313
310
mapind += 1
314
- A_mul_B ! (yrow, maps[mapind], x[ xinds[mapind]] )
315
- for colind in 2 : rows[rowind]
311
+ mul ! (yrow, maps[mapind], selectdim (x, 1 , xinds[mapind]), α, β )
312
+ for _ in 2 : row
316
313
mapind += 1
317
- mul! (yrow, maps[mapind], x[ xinds[mapind]], true , true )
314
+ mul! (yrow, maps[mapind], selectdim (x, 1 , xinds[mapind]), α , true )
318
315
end
319
316
end
320
317
return y
321
318
end
322
319
323
- function At_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector )
324
- require_one_based_indexing (y, x)
325
- m, n = size (A)
326
- @boundscheck (n == length (y) && m == length (x)) || throw (DimensionMismatch (" At_mul_B!" ))
320
+ @inline function _transblockmul! (y, A:: BlockMap , x, α, β, transform)
327
321
maps, rows, xinds, yinds = A. maps, A. rows, A. rowranges, A. colranges
328
- mapind = 0
329
- # first block row (rowind = 1) of A, meaning first block column of A', fill all of y
330
322
@views @inbounds begin
331
- xcol = x[xinds[ 1 ]]
332
- for colind in 1 : rows[ 1 ]
333
- mapind += 1
334
- A_mul_B! (y[ yinds[mapind]], transpose (maps[mapind ]), xcol)
323
+ # first block row (rowind = 1) of A, meaning first block column of A', fill all of y
324
+ xcol = selectdim (x, 1 , first (xinds))
325
+ for rowind in 1 : first (rows)
326
+ mul! ( selectdim (y, 1 , yinds[rowind]), transform (maps[rowind ]), xcol, α, β )
335
327
end
336
- # subsequent block rows of A, add results to corresponding parts of y
337
- for rowind in 2 : length (rows)
338
- xcol = x[xinds[rowind]]
339
- for colind in 1 : rows[rowind]
328
+ mapind = first (rows)
329
+ # subsequent block rows of A (block columns of A'),
330
+ # add results to corresponding parts of y
331
+ # TODO : think about multithreading
332
+ for (row, xi) in zip (Base. tail (rows), Base. tail (xinds))
333
+ xcol = selectdim (x, 1 , xi)
334
+ for _ in 1 : row
340
335
mapind += 1
341
- mul! (y[ yinds[mapind]], transpose (maps[mapind]), xcol, true , true )
336
+ mul! (selectdim (y, 1 , yinds[mapind]), transform (maps[mapind]), xcol, α , true )
342
337
end
343
338
end
344
339
end
345
340
return y
346
341
end
347
342
348
- function Ac_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector )
349
- require_one_based_indexing (y, x)
350
- m, n = size (A)
351
- @boundscheck (n == length (y) && m == length (x)) || throw (DimensionMismatch (" At_mul_B!" ))
352
- maps, rows, xinds, yinds = A. maps, A. rows, A. rowranges, A. colranges
353
- mapind = 0
354
- # first block row (rowind = 1) of A, fill all of y
355
- @views @inbounds begin
356
- xcol = x[xinds[1 ]]
357
- for colind in 1 : rows[1 ]
358
- mapind += 1
359
- A_mul_B! (y[yinds[mapind]], adjoint (maps[mapind]), xcol)
360
- end
361
- # subsequent block rows of A, add results to corresponding parts of y
362
- for rowind in 2 : length (rows)
363
- xcol = x[xinds[rowind]]
364
- for colind in 1 : rows[rowind]
365
- mapind += 1
366
- mul! (y[yinds[mapind]], adjoint (maps[mapind]), xcol, true , true )
367
- end
343
+ # ###########
344
+ # multiplication with vectors & matrices
345
+ # ###########
346
+
347
+ Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector ) =
348
+ mul! (y, A, x)
349
+
350
+ Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: TransposeMap{<:Any,<:BlockMap} , x:: AbstractVector ) =
351
+ mul! (y, A, x)
352
+
353
+ Base. @propagate_inbounds At_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector ) =
354
+ mul! (y, transpose (A), x)
355
+
356
+ Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: AdjointMap{<:Any,<:BlockMap} , x:: AbstractVector ) =
357
+ mul! (y, A, x)
358
+
359
+ Base. @propagate_inbounds Ac_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector ) =
360
+ mul! (y, adjoint (A), x)
361
+
362
+ for Atype in (AbstractVector, AbstractMatrix)
363
+ @eval Base. @propagate_inbounds function LinearAlgebra. mul! (y:: $Atype , A:: BlockMap , x:: $Atype ,
364
+ α:: Number = true , β:: Number = false )
365
+ require_one_based_indexing (y, x)
366
+ @boundscheck check_dim_mul (y, A, x)
367
+ return _blockmul! (y, A, x, α, β)
368
+ end
369
+
370
+ for (maptype, transform) in ((:(TransposeMap{<: Any ,<: BlockMap }), :transpose ), (:(AdjointMap{<: Any ,<: BlockMap }), :adjoint ))
371
+ @eval Base. @propagate_inbounds function LinearAlgebra. mul! (y:: $Atype , wrapA:: $maptype , x:: $Atype ,
372
+ α:: Number = true , β:: Number = false )
373
+ require_one_based_indexing (y, x)
374
+ @boundscheck check_dim_mul (y, wrapA, x)
375
+ return _transblockmul! (y, wrapA. lmap, x, α, β, $ transform)
368
376
end
369
377
end
370
- return y
371
378
end
372
379
373
380
# ###########
388
395
# show(io, T)
389
396
# print(io, '}')
390
397
# end
398
+
399
+ # ###########
400
+ # BlockDiagonalMap
401
+ # ###########
402
+
403
+ struct BlockDiagonalMap{T,As<: Tuple{Vararg{LinearMap}} ,Ranges<: Tuple{Vararg{UnitRange{Int}}} } <: LinearMap{T}
404
+ maps:: As
405
+ rowranges:: Ranges
406
+ colranges:: Ranges
407
+ function BlockDiagonalMap {T,As} (maps:: As ) where {T, As<: Tuple{Vararg{LinearMap}} }
408
+ for A in maps
409
+ promote_type (T, eltype (A)) == T || throw (InexactError ())
410
+ end
411
+ # row ranges
412
+ inds = vcat (1 , size .(maps, 1 )... )
413
+ cumsum! (inds, inds)
414
+ rowranges = ntuple (i -> inds[i]: inds[i+ 1 ]- 1 , Val (length (maps)))
415
+ # column ranges
416
+ inds[2 : end ] .= size .(maps, 2 )
417
+ cumsum! (inds, inds)
418
+ colranges = ntuple (i -> inds[i]: inds[i+ 1 ]- 1 , Val (length (maps)))
419
+ return new {T,As,typeof(rowranges)} (maps, rowranges, colranges)
420
+ end
421
+ end
422
+
423
+ BlockDiagonalMap {T} (maps:: As ) where {T,As<: Tuple{Vararg{LinearMap}} } =
424
+ BlockDiagonalMap {T,As} (maps)
425
+ BlockDiagonalMap (maps:: LinearMap... ) =
426
+ BlockDiagonalMap {promote_type(map(eltype, maps)...)} (maps)
427
+
428
+ for k in 1 : 8 # is 8 sufficient?
429
+ Is = ntuple (n-> :($ (Symbol (:A ,n)):: AbstractMatrix ), Val (k- 1 ))
430
+ # yields (:A1, :A2, :A3, ..., :A(k-1))
431
+ L = :($ (Symbol (:A ,k)):: LinearMap )
432
+ # yields :Ak
433
+ mapargs = ntuple (n -> :(LinearMap ($ (Symbol (:A ,n)))), Val (k- 1 ))
434
+ # yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1)))
435
+
436
+ @eval begin
437
+ SparseArrays. blockdiag ($ (Is... ), $ L, As:: Union{LinearMap,AbstractMatrix} ...) =
438
+ BlockDiagonalMap ($ (mapargs... ), $ (Symbol (:A ,k)), convert_to_lmaps (As... )... )
439
+ function Base. cat ($ (Is... ), $ L, As:: Union{LinearMap,AbstractMatrix} ...; dims:: Dims{2} )
440
+ if dims == (1 ,2 )
441
+ return BlockDiagonalMap ($ (mapargs... ), $ (Symbol (:A ,k)), convert_to_lmaps (As... )... )
442
+ else
443
+ throw (ArgumentError (" dims keyword in cat of LinearMaps must be (1,2)" ))
444
+ end
445
+ end
446
+ end
447
+ end
448
+
449
+ Base. size (A:: BlockDiagonalMap ) = (last (A. rowranges[end ]), last (A. colranges[end ]))
450
+
451
+ LinearAlgebra. issymmetric (A:: BlockDiagonalMap ) = all (issymmetric, A. maps)
452
+ LinearAlgebra. ishermitian (A:: BlockDiagonalMap{<:Real} ) = all (issymmetric, A. maps)
453
+ LinearAlgebra. ishermitian (A:: BlockDiagonalMap ) = all (ishermitian, A. maps)
454
+
455
+ LinearAlgebra. adjoint (A:: BlockDiagonalMap{T} ) where {T} = BlockDiagonalMap {T} (map (adjoint, A. maps))
456
+ LinearAlgebra. transpose (A:: BlockDiagonalMap{T} ) where {T} = BlockDiagonalMap {T} (map (transpose, A. maps))
457
+
458
+ Base.:(== )(A:: BlockDiagonalMap , B:: BlockDiagonalMap ) = (eltype (A) == eltype (B) && A. maps == B. maps)
459
+
460
+ Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: BlockDiagonalMap , x:: AbstractVector ) =
461
+ mul! (y, A, x, true , false )
462
+
463
+ Base. @propagate_inbounds At_mul_B! (y:: AbstractVector , A:: BlockDiagonalMap , x:: AbstractVector ) =
464
+ mul! (y, transpose (A), x, true , false )
465
+
466
+ Base. @propagate_inbounds Ac_mul_B! (y:: AbstractVector , A:: BlockDiagonalMap , x:: AbstractVector ) =
467
+ mul! (y, adjoint (A), x, true , false )
468
+
469
+ for Atype in (AbstractVector, AbstractMatrix)
470
+ @eval Base. @propagate_inbounds function LinearAlgebra. mul! (y:: $Atype , A:: BlockDiagonalMap , x:: $Atype ,
471
+ α:: Number = true , β:: Number = false )
472
+ require_one_based_indexing (y, x)
473
+ @boundscheck check_dim_mul (y, A, x)
474
+ return _blockscaling! (y, A, x, α, β)
475
+ end
476
+ end
477
+
478
+ @inline function _blockscaling! (y, A:: BlockDiagonalMap , x, α, β)
479
+ maps, yinds, xinds = A. maps, A. rowranges, A. colranges
480
+ # TODO : think about multi-threading here
481
+ @views @inbounds for i in eachindex (yinds, maps, xinds)
482
+ mul! (selectdim (y, 1 , yinds[i]), maps[i], selectdim (x, 1 , xinds[i]), α, β)
483
+ end
484
+ return y
485
+ end
0 commit comments