Skip to content

Commit 70488bf

Browse files
Merge pull request #151 from JuliaDiff/setinde
Piracy-free setindex
2 parents 6ff1755 + b303e46 commit 70488bf

File tree

3 files changed

+43
-23
lines changed

3 files changed

+43
-23
lines changed

src/FiniteDiff.jl

+20
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,26 @@ _mat(x::AbstractMatrix) = x
1212
_mat(x::StaticVector) = reshape(x, (axes(x, 1), SOneTo(1)))
1313
_mat(x::AbstractVector) = reshape(x, (axes(x, 1), Base.OneTo(1)))
1414

15+
# Setindex overloads without piracy
16+
setindex(x...) = Base.setindex(x...)
17+
setindex(x::StaticArray, v, i::Int...) = StaticArrays.setindex(x, v, i...)
18+
19+
function setindex(x::AbstractArray, v, i...)
20+
_x = Base.copymutable(x)
21+
_x[i...] = v
22+
return _x
23+
end
24+
25+
function setindex(x::AbstractVector, v, i::Int)
26+
n = length(x)
27+
x .* (i .!== 1:n) .+ v .* (i .== 1:n)
28+
end
29+
30+
function setindex(x::AbstractMatrix, v, i::Int, j::Int)
31+
n, m = Base.size(x)
32+
x .* (i .!== 1:n) .* (j .!== i:m)' .+ v .* (i .== 1:n) .* (j .== i:m)'
33+
end
34+
1535
include("iteration_utils.jl")
1636
include("epsilons.jl")
1737
include("derivatives.jl")

src/hessians.jl

+14-14
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ function finite_difference_hessian!(H,f,x,
7575
ArrayInterfaceCore.allowed_setindex!(xpp,xi + epsilon,i)
7676
ArrayInterfaceCore.allowed_setindex!(xmm,xi - epsilon,i)
7777
else
78-
_xpp = Base.setindex(xpp,xi + epsilon, i)
79-
_xmm = Base.setindex(xmm,xi - epsilon, i)
78+
_xpp = setindex(xpp,xi + epsilon, i)
79+
_xmm = setindex(xmm,xi - epsilon, i)
8080
end
8181

8282
ArrayInterfaceCore.allowed_setindex!(H,(f(_xpp) - 2*fx + f(_xmm)) / epsilon^2,i,i)
@@ -90,10 +90,10 @@ function finite_difference_hessian!(H,f,x,
9090
ArrayInterfaceCore.allowed_setindex!(xmp,xm,i)
9191
ArrayInterfaceCore.allowed_setindex!(xmm,xm,i)
9292
else
93-
_xpp = Base.setindex(xpp,xp,i)
94-
_xpm = Base.setindex(xpm,xp,i)
95-
_xmp = Base.setindex(xmp,xm,i)
96-
_xmm = Base.setindex(xmm,xm,i)
93+
_xpp = setindex(xpp,xp,i)
94+
_xpm = setindex(xpm,xp,i)
95+
_xmp = setindex(xmp,xm,i)
96+
_xmm = setindex(xmm,xm,i)
9797
end
9898

9999
for j = i+1:n
@@ -108,10 +108,10 @@ function finite_difference_hessian!(H,f,x,
108108
ArrayInterfaceCore.allowed_setindex!(xmp,xp,j)
109109
ArrayInterfaceCore.allowed_setindex!(xmm,xm,j)
110110
else
111-
_xpp = Base.setindex(_xpp,xp,j)
112-
_xpm = Base.setindex(_xpm,xm,j)
113-
_xmp = Base.setindex(_xmp,xp,j)
114-
_xmm = Base.setindex(_xmm,xm,j)
111+
_xpp = setindex(_xpp,xp,j)
112+
_xpm = setindex(_xpm,xm,j)
113+
_xmp = setindex(_xmp,xp,j)
114+
_xmm = setindex(_xmm,xm,j)
115115
end
116116

117117
ArrayInterfaceCore.allowed_setindex!(H,(f(_xpp) - f(_xpm) - f(_xmp) + f(_xmm))/(4*epsiloni*epsilonj),i,j)
@@ -122,10 +122,10 @@ function finite_difference_hessian!(H,f,x,
122122
ArrayInterfaceCore.allowed_setindex!(xmp,xj,j)
123123
ArrayInterfaceCore.allowed_setindex!(xmm,xj,j)
124124
else
125-
_xpp = Base.setindex(_xpp,xj,j)
126-
_xpm = Base.setindex(_xpm,xj,j)
127-
_xmp = Base.setindex(_xmp,xj,j)
128-
_xmm = Base.setindex(_xmm,xj,j)
125+
_xpp = setindex(_xpp,xj,j)
126+
_xpm = setindex(_xpm,xj,j)
127+
_xmp = setindex(_xmp,xj,j)
128+
_xmm = setindex(_xmm,xj,j)
129129
end
130130
end
131131

src/jacobians.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ function finite_difference_jacobian(
193193
function calculate_Ji_forward(i)
194194
x_save = ArrayInterfaceCore.allowed_getindex(vecx, i)
195195
epsilon = compute_epsilon(Val(:forward), x_save, relstep, absstep, dir)
196-
_vecx1 = Base.setindex(vecx, x_save+epsilon, i)
196+
_vecx1 = setindex(vecx, x_save+epsilon, i)
197197
_x1 = reshape(_vecx1, axes(x))
198198
vecfx1 = _vec(f(_x1))
199199
dx = (vecfx1-vecfx) / epsilon
@@ -226,8 +226,8 @@ function finite_difference_jacobian(
226226
x1_save = ArrayInterfaceCore.allowed_getindex(vecx1,i)
227227
x_save = ArrayInterfaceCore.allowed_getindex(vecx,i)
228228
epsilon = compute_epsilon(Val(:forward), x1_save, relstep, absstep, dir)
229-
_vecx1 = Base.setindex(vecx1,x1_save+epsilon,i)
230-
_vecx = Base.setindex(vecx,x_save-epsilon,i)
229+
_vecx1 = setindex(vecx1,x1_save+epsilon,i)
230+
_vecx = setindex(vecx,x_save-epsilon,i)
231231
_x1 = reshape(_vecx1, axes(x))
232232
_x = reshape(_vecx, axes(x))
233233
vecfx1 = _vec(f(_x1))
@@ -264,7 +264,7 @@ function finite_difference_jacobian(
264264

265265
function calculate_Ji_complex(i)
266266
x_save = ArrayInterfaceCore.allowed_getindex(vecx,i)
267-
_vecx = Base.setindex(complex.(vecx),x_save+im*epsilon,i)
267+
_vecx = setindex(complex.(vecx),x_save+im*epsilon,i)
268268
_x = reshape(_vecx, axes(x))
269269
vecfx = _vec(f(_x))
270270
dx = imag(vecfx)/epsilon
@@ -325,18 +325,18 @@ function _findstructralnz(A::DenseMatrix)
325325
numnz = count(A .≠ 0)
326326
I = Vector{Int64}(undef, numnz)
327327
J = Vector{Int64}(undef, numnz)
328-
idx = 1
328+
idx = 1
329329
for j in axes(A, 2)
330330
for i in axes(A, 1)
331331
if A[i, j] 0
332-
I[idx] = i
332+
I[idx] = i
333333
J[idx] = j
334-
idx += 1
334+
idx += 1
335335
end
336336
end
337337
end
338338
I, J
339-
end
339+
end
340340

341341
function finite_difference_jacobian!(
342342
J,
@@ -361,7 +361,7 @@ function finite_difference_jacobian!(
361361
cols_index = nothing
362362
if _use_findstructralnz(sparsity)
363363
rows_index, cols_index = ArrayInterfaceCore.findstructralnz(sparsity)
364-
elseif sparsity isa DenseMatrix
364+
elseif sparsity isa DenseMatrix
365365
rows_index, cols_index = FiniteDiff._findstructralnz(sparsity)
366366
end
367367

0 commit comments

Comments
 (0)