Skip to content

Commit a2649ce

Browse files
implement bandwidths for OneElement (#447)
* implement bandwidths for OneElement * make improvements * fix sparse(::SparseMatrixCSC) * fix bandwidths for SparseMatrixCSC, add for SparseVector * add bandwidths(::Zeros) behaviour for empty sparse structures * add unit tests * cleanup bandwidths * Update interfaceimpl.jl --------- Co-authored-by: Sheehan Olver <[email protected]>
1 parent ea616cc commit a2649ce

File tree

5 files changed

+73
-15
lines changed

5 files changed

+73
-15
lines changed

ext/BandedMatricesSparseArraysExt.jl

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,53 @@ module BandedMatricesSparseArraysExt
22

33
using BandedMatrices
44
using BandedMatrices: _banded_rowval, _banded_colval, _banded_nzval
5-
using SparseArrays
5+
using SparseArrays, FillArrays
66
import SparseArrays: sparse
77

88
function sparse(B::BandedMatrix)
99
sparse(_banded_rowval(B), _banded_colval(B), _banded_nzval(B), size(B)...)
1010
end
1111

1212
function BandedMatrices.bandwidths(A::SparseMatrixCSC)
13-
l,u = -size(A,1),-size(A,2)
14-
15-
m,n = size(A)
13+
l = u = -max(size(A,1),size(A,2))
14+
n = size(A)[2]
1615
rows = rowvals(A)
1716
vals = nonzeros(A)
17+
18+
if isempty(vals)
19+
return bandwidths(Zeros(1))
20+
end
21+
1822
for j = 1:n
1923
for ind in nzrange(A, j)
2024
i = rows[ind]
2125
# We skip non-structural zeros when computing the
2226
# bandwidths.
2327
iszero(vals[ind]) && continue
24-
ij = abs(i-j)
25-
if i j
26-
l = max(l, ij)
27-
u = max(u, -ij)
28-
elseif i < j
29-
l = max(l, -ij)
30-
u = max(u, ij)
31-
end
28+
u = max(u, j-i)
29+
l = max(l, i-j)
3230
end
3331
end
3432

3533
l,u
3634
end
3735

36+
#Treat as n x 1 matrix
37+
function BandedMatrices.bandwidths(A::SparseVector)
38+
l = u = -size(A,1)
39+
rows = rowvals(A)
40+
41+
if isempty(rows)
42+
return bandwidths(Zeros(1))
43+
end
44+
45+
for i in rows
46+
iszero(i) && continue
47+
u = max(u, 1-i)
48+
l = max(l, i-1)
49+
end
50+
51+
l,u
52+
end
53+
3854
end

src/BandedMatrices.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import ArrayLayouts: AbstractTridiagonalLayout, BidiagonalLayout, BlasMatLdivVec
3434
symmetricuplo, transposelayout, triangulardata, triangularlayout, zero!,
3535
QRPackedQLayout, AdjQRPackedQLayout
3636

37-
import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal
37+
import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal, OneElementMatrix, OneElementVector
3838

3939
const libblas = LinearAlgebra.BLAS.libblas
4040
const liblapack = LinearAlgebra.BLAS.liblapack

src/interfaceimpl.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,27 @@ bandwidths(::Tridiagonal) = (1,1)
5656
sublayout(::AbstractTridiagonalLayout, ::Type{<:Tuple{AbstractUnitRange{Int},AbstractUnitRange{Int}}}) =
5757
BandedLayout()
5858

59+
#Implement bandwidths for OneElement structure
60+
function bandwidths(o::OneElementVector)
61+
k = FillArrays.nzind(o)[1] # index of non-zero
62+
n = length(o)
63+
if k > n || k < 1
64+
bandwidths(Zeros(o))
65+
else
66+
(k-1, 1-k)
67+
end
68+
end
69+
70+
function bandwidths(o::OneElementMatrix)
71+
n,m = size(o)
72+
k,j = Tuple(FillArrays.nzind(o)) # indices of non-zero entries
73+
if k > n || j > m || k < 1 || j < 1
74+
bandwidths(Zeros(o))
75+
else
76+
(k-j,j-k)
77+
end
78+
end
79+
5980
###
6081
# rot180
6182
###

test/test_interface.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
module TestInterface
22

3-
using BandedMatrices, LinearAlgebra, ArrayLayouts, FillArrays, Test
3+
using BandedMatrices, LinearAlgebra, ArrayLayouts, FillArrays, Test, Random
44
import BandedMatrices: isbanded, AbstractBandedLayout, BandedStyle,
55
BandedColumns, bandeddata
66
import ArrayLayouts: OnesLayout, UnknownLayout
7-
using InfiniteArrays
7+
using InfiniteArrays, SparseArrays
88

99
struct PseudoBandedMatrix{T} <: AbstractMatrix{T}
1010
data::Array{T}
@@ -310,6 +310,18 @@ end
310310
@test layout_getindex(T,1:10,1:10) isa BandedMatrix
311311
end
312312

313+
@testset "OneElement" begin
314+
o = OneElement(1, 3, 5)
315+
@test bandwidths(o) == (2,-2)
316+
n,m = rand(1:10,2)
317+
o = OneElement(1, (rand(1:n),rand(1:m)), (n, m))
318+
@test bandwidths(o) == bandwidths(sparse(o))
319+
o = OneElement(1, (n+1,m+1), (n, m))
320+
@test bandwidths(o) == bandwidths(Zeros(o))
321+
o = OneElement(1, 6, 5)
322+
@test bandwidths(o) == bandwidths(Zeros(o))
323+
end
324+
313325
@testset "rot180" begin
314326
A = brand(5,5,1,2)
315327
R = rot180(A)

test/test_miscs.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,17 @@ import BandedMatrices: _BandedMatrix, DefaultBandedMatrix
5050
@test bA isa BandedMatrix
5151
@test bA == A
5252
@test bandwidths(bA) == min.((l,u),9)
53+
v = sparsevec(brand(10, 1, l, u))
54+
@test bandwidths(v) == (l, min(0, u))
5355
end
5456

57+
l, u = -1, 0
58+
A = brand(10, 10, l, u)
59+
sA = sparse(A)
60+
@test bandwidths(sA) == bandwidths(Zeros(1))
61+
v = sparsevec(brand(10, 1, l, u))
62+
@test bandwidths(v) == bandwidths(Zeros(1))
63+
5564
for diags = [(-1 => ones(Int, 5),),
5665
(-2 => ones(Int, 5),),
5766
(2 => ones(Int, 5),),

0 commit comments

Comments
 (0)