Skip to content

Commit c1582fd

Browse files
authored
Optimized row/col getindex (#373)
1 parent 3a37d96 commit c1582fd

File tree

2 files changed

+65
-15
lines changed

2 files changed

+65
-15
lines changed

src/banded/BandedMatrix.jl

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# a_21 a_22 a_23
55
# a_31 a_32 a_33 a_34
66
# a_42 a_43 a_44 ]
7-
# ordering the data like (cobbndsmns first)
7+
# ordering the data like (columns first)
88
# [ * a_12 a_23 a_34
99
# a_11 a_22 a_33 a_44
1010
# a_21 a_32 a_43 *
@@ -379,13 +379,45 @@ diagzero(D::Diagonal{B}, i, j) where B<:BandedMatrix =
379379
end
380380

381381

382-
# scalar - integer - integer
383-
@inline function getindex(A::BandedMatrix, k::Integer, j::Integer)
382+
# Int - Int
383+
@inline function getindex(A::BandedMatrix, k::Int, j::Int)
384384
@boundscheck checkbounds(A, k, j)
385385
@inbounds r = banded_getindex(A.data, A.l, A.u, k, j)
386386
r
387387
end
388388

389+
# BandRange - Int
390+
@propagate_inbounds function getindex(A::BandedMatrix, ::BandRangeType, j::Int)
391+
@boundscheck checkbounds(A, colrange(A, j), j)
392+
A.data[data_colrange(A,j)]
393+
end
394+
395+
# Colon - Int
396+
@propagate_inbounds function getindex(A::BandedMatrix, ::Colon, j::Int)
397+
@boundscheck checkbounds(A, axes(A,1), j)
398+
r = similar(A, axes(A,1))
399+
r[firstindex(r):colstart(A,j)-1] .= zero(eltype(r))
400+
r[colrange(A,j)] = @view A.data[data_colrange(A,j)]
401+
r[colstop(A,j)+1:end] .= zero(eltype(r))
402+
return r
403+
end
404+
405+
# Int - BandRange
406+
@propagate_inbounds function getindex(A::BandedMatrix, k::Int, j::BandRangeType)
407+
@boundscheck checkbounds(A, k, rowrange(A, k))
408+
A.data[data_rowrange(A,k)]
409+
end
410+
411+
# Int - Colon
412+
@propagate_inbounds function getindex(A::BandedMatrix, k::Int, ::Colon)
413+
@boundscheck checkbounds(A, k, axes(A,2))
414+
r = similar(A, axes(A,2))
415+
r[firstindex(r):rowstart(A,k)-1] .= zero(eltype(r))
416+
r[rowrange(A,k)] = @view A.data[data_rowrange(A,k)]
417+
r[rowstop(A,k)+1:end] .= zero(eltype(r))
418+
return r
419+
end
420+
389421
# ~ indexing along a band
390422
# we reduce it to converting a View
391423

test/test_indexing.jl

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,21 @@ import BandedMatrices: rowstart, rowstop, colstart, colstop,
228228
0 7 10 13 15 0 0;
229229
0 0 11 14 16 17 0]
230230

231-
@test a[BandRange, 1] == [1, 2, 3]
232-
@test a[BandRange, 2] == [4, 5, 6, 7]
233-
@test a[BandRange, 3] == [8, 9, 10, 11]
234-
@test a[BandRange, 4] == [12, 13, 14]
235-
@test a[BandRange, 5] == [15, 16]
236-
@test a[BandRange, 6] == [17]
231+
@test a[BandRange, 1] == @view(a[BandRange, 1]) == [1, 2, 3]
232+
@test a[BandRange, 2] == @view(a[BandRange, 2]) == [4, 5, 6, 7]
233+
@test a[BandRange, 3] == @view(a[BandRange, 3]) == [8, 9, 10, 11]
234+
@test a[BandRange, 4] == @view(a[BandRange, 4]) == [12, 13, 14]
235+
@test a[BandRange, 5] == @view(a[BandRange, 5]) == [15, 16]
236+
@test a[BandRange, 6] == @view(a[BandRange, 6]) == [17]
237+
@test a[BandRange, 7] == @view(a[BandRange, 7]) == Int[]
238+
239+
@test a[:, 1] == view(a, :, 1) == [1,2,3,0,0]
240+
@test a[:, 2] == view(a, :, 2) == [4,5,6,7,0]
241+
@test a[:, 3] == view(a, :, 3) == [0,8,9,10,11]
242+
@test a[:, 4] == view(a, :, 4) == [0,0,12,13,14]
243+
@test a[:, 5] == view(a, :, 5) == [0,0,0,15,16]
244+
@test a[:, 6] == view(a, :, 6) == [0,0,0,0,17]
245+
@test a[:, 7] == view(a, :, 7) == [0,0,0,0,0]
237246

238247
@test_throws BoundsError a[:, 0] = [1, 2, 3]
239248
@test_throws DimensionMismatch a[:, 1] = [1, 2, 3]
@@ -310,12 +319,21 @@ import BandedMatrices: rowstart, rowstop, colstart, colstop,
310319
0 7 10 13 15 0 0;
311320
0 0 11 14 16 17 0]'
312321

313-
@test a[1, BandRange] == [1, 2, 3]
314-
@test a[2, BandRange] == [4, 5, 6, 7]
315-
@test a[3, BandRange] == [8, 9, 10, 11]
316-
@test a[4, BandRange] == [12, 13, 14]
317-
@test a[5, BandRange] == [15, 16]
318-
@test a[6, BandRange] == [17]
322+
@test a[1, BandRange] == @view(a[1, BandRange]) == [1, 2, 3]
323+
@test a[2, BandRange] == @view(a[2, BandRange]) == [4, 5, 6, 7]
324+
@test a[3, BandRange] == @view(a[3, BandRange]) == [8, 9, 10, 11]
325+
@test a[4, BandRange] == @view(a[4, BandRange]) == [12, 13, 14]
326+
@test a[5, BandRange] == @view(a[5, BandRange]) == [15, 16]
327+
@test a[6, BandRange] == @view(a[6, BandRange]) == [17]
328+
@test a[7, BandRange] == @view(a[7, BandRange]) == Int[]
329+
330+
@test a[1, :] == @view(a[1, :]) == [1,2,3,0,0]
331+
@test a[2, :] == @view(a[2, :]) == [4,5,6,7,0]
332+
@test a[3, :] == @view(a[3, :]) == [0,8,9,10,11]
333+
@test a[4, :] == @view(a[4, :]) == [0,0,12,13,14]
334+
@test a[5, :] == @view(a[5, :]) == [0,0,0,15,16]
335+
@test a[6, :] == @view(a[6, :]) == [0,0,0,0,17]
336+
@test a[7, :] == @view(a[7, :]) == [0,0,0,0,0]
319337

320338
@test_throws BoundsError a[0, :] = [1, 2, 3]
321339
@test_throws DimensionMismatch a[1, :] = [1, 2, 3]

0 commit comments

Comments
 (0)