Skip to content

Commit 0b90bf4

Browse files
Merge pull request #133 from dingraha/master
Make sparse non-square Jacobians work
2 parents 9945e7d + c106d2a commit 0b90bf4

File tree

3 files changed

+49
-7
lines changed

3 files changed

+49
-7
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FiniteDiff"
22
uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
3-
version = "2.11.0"
3+
version = "2.11.1"
44

55
[deps]
66
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/jacobians.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mutable struct JacobianCache{CacheType1,CacheType2,CacheType3,ColorType,SparsityType,fdtype,returntype}
22
x1 :: CacheType1
3+
x2 :: CacheType1
34
fx :: CacheType2
45
fx1 :: CacheType3
56
colorvec :: ColorType
@@ -96,7 +97,8 @@ function JacobianCache(
9697
@assert eltype(fx1) == T2
9798
_fx = fx
9899
end
99-
JacobianCache{typeof(_x1),typeof(_fx),typeof(fx1),typeof(colorvec),typeof(sparsity),fdtype,returntype}(_x1,_fx,fx1,colorvec,sparsity)
100+
_x2 = similar(_x1)
101+
JacobianCache{typeof(_x1),typeof(_fx),typeof(fx1),typeof(colorvec),typeof(sparsity),fdtype,returntype}(_x1,_x2,_fx,fx1,colorvec,sparsity)
100102
end
101103

102104
function _make_Ji(::SparseMatrixCSC, rows_index,cols_index,dx,colorvec,color_i,nrows,ncols)
@@ -334,7 +336,7 @@ function finite_difference_jacobian!(
334336
m, n = size(J)
335337
_color = reshape(colorvec, axes(x)...)
336338

337-
x1, fx, fx1 = cache.x1, cache.fx, cache.fx1
339+
x1, x2, fx, fx1 = cache.x1, cache.x2, cache.fx, cache.fx1
338340
copyto!(x1, x)
339341
vfx = _vec(fx)
340342

@@ -377,8 +379,8 @@ function finite_difference_jacobian!(
377379
# Now return x1 back to its original value
378380
ArrayInterface.allowed_setindex!(x1, x1_save, color_i)
379381
else # Perturb along the colorvec vector
380-
@. fx1 = x1 * (_color == color_i)
381-
tmp = norm(fx1)
382+
@. x2 = x1 * (_color == color_i)
383+
tmp = norm(x2)
382384
epsilon = compute_epsilon(Val(:forward), sqrt(tmp), relstep, absstep, dir)
383385
@. x1 = x1 + epsilon * (_color == color_i)
384386
f(fx1, x1)
@@ -420,8 +422,8 @@ function finite_difference_jacobian!(
420422
@. J[:,color_i] = (vfx1 - vfx) / 2epsilon
421423
ArrayInterface.allowed_setindex!(x1, x_save, color_i)
422424
else # Perturb along the colorvec vector
423-
@. fx1 = x1 * (_color == color_i)
424-
tmp = norm(fx1)
425+
@. x2 = x1 * (_color == color_i)
426+
tmp = norm(x2)
425427
epsilon = compute_epsilon(Val(:central), sqrt(tmp), relstep, absstep, dir)
426428
@. x1 = x1 + epsilon * (_color == color_i)
427429
@. x = x - epsilon * (_color == color_i)

test/coloring_tests.jl

+40
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,43 @@ Jbb = BlockBandedMatrix(similar(Jsparse),fill(100, 100), fill(100, 100),(1,1));
117117
colorsbb = ArrayInterface.matrix_colors(Jbb)
118118
FiniteDiff.finite_difference_jacobian!(Jbb, f, x, colorvec=colorsbb)
119119
@test Jbb Jsparse
120+
121+
122+
# Non-square Jacobian test.
123+
# The Jacobian of f_nonsquare! has size (n, 2*n).
124+
function f_nonsquare!(y, x)
125+
global fcalls += 1
126+
@assert length(x) == 2*length(y)
127+
n = length(x) ÷ 2
128+
x1 = @view x[1:n]
129+
x2 = @view x[n+1:end]
130+
131+
@. y = (x1 .- 3).^2 .+ x1.*x2 .+ (x2 .+ 4).^2 .- 3
132+
return nothing
133+
end
134+
135+
n = 4
136+
x0 = vcat(ones(n).*(1:n) .+ 0.5, ones(n).*(1:n) .+ 1.5)
137+
y0 = zeros(n)
138+
rows = vcat([i for i in 1:n], [i for i in 1:n])
139+
cols = vcat([i for i in 1:n], [i+n for i in 1:n])
140+
sparsity = sparse(rows, cols, ones(length(rows)))
141+
colorvec = vcat(fill(1, n), fill(2, n))
142+
143+
J_nonsquare1 = zeros(size(sparsity))
144+
FiniteDiff.finite_difference_jacobian!(J_nonsquare1, f_nonsquare!, x0)
145+
146+
J_nonsquare2 = similar(sparsity)
147+
for method in [Val(:forward), Val(:central), Val(:complex)]
148+
cache = FiniteDiff.JacobianCache(copy(x0), copy(y0), copy(y0), method; sparsity, colorvec)
149+
global fcalls = 0
150+
FiniteDiff.finite_difference_jacobian!(J_nonsquare2, f_nonsquare!, x0, cache)
151+
if method == Val(:central)
152+
@test fcalls == 2*maximum(colorvec)
153+
elseif method == Val(:complex)
154+
@test fcalls == maximum(colorvec)
155+
else
156+
@test fcalls == maximum(colorvec) + 1
157+
end
158+
@test isapprox(J_nonsquare2, J_nonsquare1; rtol=1e-6)
159+
end

0 commit comments

Comments
 (0)