Skip to content

Commit d6df0e2

Browse files
Merge pull request #124 from sjdaines/fast_sparse_path
Add fast path for _colorediteration! with sparse J
2 parents 0c79fe7 + 96670eb commit d6df0e2

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

src/iteration_utils.jl

+17
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,27 @@ end
1616
end
1717
end
1818

19+
# fast version for the case where J and sparsity have the same sparsity pattern
20+
"""
    _colorediteration!(Jsparsity::SparseMatrixCSC, vfx, colorvec, color_i, ncols)

Fast path that writes directly into the stored values of `Jsparsity`.

For every column `col_index in 1:ncols` with `colorvec[col_index] == color_i`, each
stored entry of that column is overwritten with `vfx[row]`, where `row` is the row index
of that stored entry. Columns of any other color are left untouched.

Returns `nothing`.

# Throws
- `DimensionMismatch` if `ncols` exceeds the number of columns of `Jsparsity` or the
  length of `colorvec`, or if `vfx` has fewer elements than `Jsparsity` has rows.
"""
@inline function _colorediteration!(
    Jsparsity::SparseMatrixCSC, vfx, colorvec, color_i, ncols
)
    nrows, ncols_J = size(Jsparsity)
    # Validate once up front so the `@inbounds` loop below cannot read out of bounds.
    ncols <= ncols_J || throw(DimensionMismatch(
        "ncols ($ncols) exceeds the number of columns of J ($ncols_J)"))
    ncols <= length(colorvec) || throw(DimensionMismatch(
        "ncols ($ncols) exceeds length(colorvec) ($(length(colorvec)))"))
    length(vfx) >= nrows || throw(DimensionMismatch(
        "vfx has $(length(vfx)) elements but J has $nrows rows"))

    rowval = Jsparsity.rowval
    nzval = Jsparsity.nzval
    # A single `@inbounds` covers the nested loop as well.
    @inbounds for col_index in 1:ncols
        colorvec[col_index] == color_i || continue
        for spidx in nzrange(Jsparsity, col_index)
            nzval[spidx] = vfx[rowval[spidx]]
        end
    end
    return nothing
end
30+
1931
#override default setting of using findstructralnz
2032
# Generic fallback: defer to ArrayInterface's report of whether `sparsity` has a
# sparse structure.
function _use_findstructralnz(sparsity)
    return ArrayInterface.has_sparsestruct(sparsity)
end

# `SparseMatrixCSC` always opts out of the `findstructralnz` route.
function _use_findstructralnz(::SparseMatrixCSC)
    return false
end
2234

35+
# test if J, sparsity are both SparseMatrixCSC and have the same sparsity pattern of stored values
36+
"""
    _use_sparseCSC_common_sparsity(J, sparsity) -> Bool

Return `true` only when `J` and `sparsity` are both `SparseMatrixCSC` and have identical
`colptr` and `rowval` arrays, i.e. the same pattern of stored values. Any other
combination of argument types returns `false`.
"""
_use_sparseCSC_common_sparsity(J, sparsity) = false
function _use_sparseCSC_common_sparsity(J::SparseMatrixCSC, sparsity::SparseMatrixCSC)
    # Same object means same pattern; this skips two O(nnz) array comparisons.
    J === sparsity && return true
    return J.colptr == sparsity.colptr && J.rowval == sparsity.rowval
end
39+
2340
function __init__()
2441
@require BlockBandedMatrices="ffab5731-97b5-5995-9138-79e8c1846df0" begin
2542
@require BlockArrays="8e7c35d0-a365-5155-bbbb-fb81a777f24e" begin

src/jacobians.jl

+20-5
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,9 @@ function finite_difference_jacobian!(
348348
fill!(J,false)
349349
end
350350

351+
# fast path if J and sparsity are both SparseMatrixCSC and have the same sparsity pattern
352+
sparseCSC_common_sparsity = _use_sparseCSC_common_sparsity(J, sparsity)
353+
351354
if fdtype == Val(:forward)
352355
vfx1 = _vec(fx1)
353356

@@ -378,7 +381,11 @@ function finite_difference_jacobian!(
378381
# J is a sparse matrix, so decompress on the fly
379382
@. vfx1 = (vfx1 - vfx) / epsilon
380383
if ArrayInterface.fast_scalar_indexing(x1)
381-
_colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n)
384+
if sparseCSC_common_sparsity
385+
_colorediteration!(J,vfx1,colorvec,color_i,n)
386+
else
387+
_colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n)
388+
end
382389
else
383390
#=
384391
J.nzval[rows_index] .+= (colorvec[cols_index] .== color_i) .* vfx1[rows_index]
@@ -417,8 +424,12 @@ function finite_difference_jacobian!(
417424
f(fx1, x1)
418425
f(fx, x)
419426
@. vfx1 = (vfx1 - vfx) / 2epsilon
420-
if ArrayInterface.fast_scalar_indexing(x1)
421-
_colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n)
427+
if ArrayInterface.fast_scalar_indexing(x1)
428+
if sparseCSC_common_sparsity
429+
_colorediteration!(J,vfx1,colorvec,color_i,n)
430+
else
431+
_colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n)
432+
end
422433
else
423434
if J isa SparseMatrixCSC
424435
@. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index)
@@ -443,8 +454,12 @@ function finite_difference_jacobian!(
443454
@. x1 = x1 + im * epsilon * (_color == color_i)
444455
f(fx,x1)
445456
@. vfx = imag(vfx) / epsilon
446-
if ArrayInterface.fast_scalar_indexing(x1)
447-
_colorediteration!(J,sparsity,rows_index,cols_index,vfx,colorvec,color_i,n)
457+
if ArrayInterface.fast_scalar_indexing(x1)
458+
if sparseCSC_common_sparsity
459+
_colorediteration!(J,vfx,colorvec,color_i,n)
460+
else
461+
_colorediteration!(J,sparsity,rows_index,cols_index,vfx,colorvec,color_i,n)
462+
end
448463
else
449464
if J isa SparseMatrixCSC
450465
@. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,),rows_index), rows_index)

0 commit comments

Comments
 (0)