Skip to content

Commit 6748bf8

Browse files
Merge pull request #191 from JuliaDiff/os/finite_difference_jvp`
add `finite_difference_jvp`
2 parents 74c16b0 + 59b1ebb commit 6748bf8

File tree

3 files changed

+237
-10
lines changed

3 files changed

+237
-10
lines changed

src/FiniteDiff.jl

+1
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,6 @@ include("derivatives.jl")
4040
include("gradients.jl")
4141
include("jacobians.jl")
4242
include("hessians.jl")
43+
include("jvp.jl")
4344

4445
end # module

src/jvp.jl

+196
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
mutable struct JVPCache{X1, FX1, FDType}
2+
x1 :: X1
3+
fx1 :: FX1
4+
end
5+
6+
"""
7+
FiniteDiff.JVPCache(
8+
x,
9+
fdtype :: Type{T1} = Val{:forward})
10+
11+
Allocating Cache Constructor.
12+
"""
13+
function JVPCache(
14+
x,
15+
fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD}
16+
fdtype isa Type && (fdtype = fdtype())
17+
JVPCache{typeof(x), typeof(x), fdtype}(copy(x), copy(x))
18+
end
19+
20+
"""
21+
FiniteDiff.JVPCache(
22+
x,
23+
fx1,
24+
fdtype :: Type{T1} = Val{:forward},
25+
26+
Non-Allocating Cache Constructor.
27+
"""
28+
function JVPCache(
29+
x,
30+
fx,
31+
fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD}
32+
fdtype isa Type && (fdtype = fdtype())
33+
JVPCache{typeof(x), typeof(fx), fdtype}(x,fx)
34+
end
35+
36+
"""
37+
FiniteDiff.finite_difference_jvp(
38+
f,
39+
x :: AbstractArray{<:Number},
40+
v :: AbstractArray{<:Number},
41+
fdtype :: Type{T1}=Val{:central},
42+
relstep=default_relstep(fdtype, eltype(x)),
43+
absstep=relstep)
44+
45+
Cache-less.
46+
"""
47+
function finite_difference_jvp(f, x, v,
48+
fdtype = Val(:forward),
49+
f_in = nothing;
50+
relstep=default_relstep(fdtype, eltype(x)),
51+
absstep=relstep,
52+
dir=true)
53+
54+
if f_in isa Nothing
55+
fx = f(x)
56+
else
57+
fx = f_in
58+
end
59+
cache = JVPCache(x, fx, fdtype)
60+
finite_difference_jvp(f, x, v, cache, fx; relstep, absstep, dir)
61+
end
62+
63+
"""
64+
FiniteDiff.finite_difference_jvp(
65+
f,
66+
x,
67+
v,
68+
cache::JVPCache;
69+
relstep=default_relstep(fdtype, eltype(x)),
70+
absstep=relstep,
71+
72+
Cached.
73+
"""
74+
function finite_difference_jvp(
75+
f,
76+
x,
77+
v,
78+
cache::JVPCache{X1, FX1, fdtype},
79+
f_in=nothing;
80+
relstep=default_relstep(fdtype, eltype(x)),
81+
absstep=relstep,
82+
dir=true) where {X1, FX1, fdtype}
83+
84+
if fdtype == Val(:complex)
85+
ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff")
86+
end
87+
88+
tmp = sqrt(abs(dot(_vec(x), _vec(v))))
89+
epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir)
90+
if fdtype == Val(:forward)
91+
fx = f_in isa Nothing ? f(x) : f_in
92+
x1 = @. x + epsilon * v
93+
fx1 = f(x1)
94+
fx1 = @. (fx1-fx)/epsilon
95+
elseif fdtype == Val(:central)
96+
x1 = @. x + epsilon * v
97+
fx1 = f(x1)
98+
x1 = @. x - epsilon * v
99+
fx = f(x1)
100+
fx1 = @. (fx1-fx)/(2epsilon)
101+
else
102+
fdtype_error(eltype(x))
103+
end
104+
fx1
105+
end
106+
107+
"""
108+
finite_difference_jvp!(
109+
jvp::AbstractArray{<:Number},
110+
f,
111+
x::AbstractArray{<:Number},
112+
v::AbstractArray{<:Number},
113+
fdtype :: Type{T1}=Val{:forward},
114+
returntype :: Type{T2}=eltype(x),
115+
f_in :: Union{T2,Nothing}=nothing;
116+
relstep=default_relstep(fdtype, eltype(x)),
117+
absstep=relstep)
118+
119+
Cache-less.
120+
"""
121+
function finite_difference_jvp!(jvp,
122+
f,
123+
x,
124+
v,
125+
fdtype = Val(:forward),
126+
f_in = nothing;
127+
relstep=default_relstep(fdtype, eltype(x)),
128+
absstep=relstep)
129+
if !isnothing(f_in)
130+
cache = JVPCache(x, f_in, fdtype)
131+
elseif fdtype == Val(:forward)
132+
fx = zero(x)
133+
f(fx,x)
134+
cache = JVPCache(x, fx, fdtype)
135+
else
136+
cache = JVPCache(x, fdtype)
137+
end
138+
finite_difference_jvp!(jvp, f, x, v, cache, cache.fx1; relstep, absstep)
139+
end
140+
141+
"""
142+
FiniteDiff.finite_difference_jvp!(
143+
jvp::AbstractArray{<:Number},
144+
f,
145+
x::AbstractArray{<:Number},
146+
v::AbstractArray{<:Number},
147+
cache::JVPCache;
148+
relstep=default_relstep(fdtype, eltype(x)),
149+
absstep=relstep,)
150+
151+
Cached.
152+
"""
153+
function finite_difference_jvp!(
154+
jvp,
155+
f,
156+
x,
157+
v,
158+
cache::JVPCache{X1, FX1, fdtype},
159+
f_in = nothing;
160+
relstep = default_relstep(fdtype, eltype(x)),
161+
absstep = relstep,
162+
dir = true) where {X1, FX1, fdtype}
163+
164+
if fdtype == Val(:complex)
165+
ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff")
166+
end
167+
168+
(;x1, fx1) = cache
169+
tmp = sqrt(abs(dot(_vec(x), _vec(v))))
170+
epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir)
171+
if fdtype == Val(:forward)
172+
if f_in isa Nothing
173+
f(fx1, x)
174+
else
175+
fx1 = f_in
176+
end
177+
@. x1 = x + epsilon * v
178+
f(jvp, x1)
179+
@. jvp = (jvp-fx1)/epsilon
180+
elseif fdtype == Val(:central)
181+
@. x1 = x - epsilon * v
182+
f(fx1, x1)
183+
@. x1 = x + epsilon * v
184+
f(jvp, x1)
185+
@. jvp = (jvp-fx1)/(2epsilon)
186+
else
187+
fdtype_error(eltype(x))
188+
end
189+
nothing
190+
end
191+
192+
function resize!(cache::JVPCache, i::Int)
193+
resize!(cache.x1, i)
194+
cache.fx1 !== nothing && resize!(cache.fx1, i)
195+
nothing
196+
end

test/finitedifftests.jl

+40-10
Original file line numberDiff line numberDiff line change
@@ -382,38 +382,68 @@ df = zero(x)
382382
df_ref = diag(J_ref)
383383
epsilon = zero(x)
384384
forward_cache = FiniteDiff.JacobianCache(x, Val{:forward}, eltype(x))
385+
forward_jvp_cache = FiniteDiff.JVPCache(x, Val{:forward})
385386
@test forward_cache.colorvec == 1:length(x)
386387
central_cache = FiniteDiff.JacobianCache(x, Val{:central}, eltype(x))
388+
central_jvp_cache = FiniteDiff.JVPCache(x, Val{:central})
387389
complex_cache = FiniteDiff.JacobianCache(x, Val{:complex}, eltype(x))
388390
f_in = copy(y)
391+
vdir = rand(2)
392+
jvp_ref = J_ref*vdir
389393

390394
@time @testset "Out-of-Place Jacobian StridedArray real-valued tests" begin
391-
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache), J_ref) < 1e-4
392-
@test err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache, dir=-1), J_ref) < 1e-4
393-
@test_throws Any err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache), J_ref) < 1e-4
394-
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-4
395-
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, f_in), J_ref) < 1e-4
395+
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache), J_ref) < 1e-6
396+
@test err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache, dir=-1), J_ref) < 1e-6
397+
@test_throws Any err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache), J_ref) < 1e-6
398+
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-6
399+
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, f_in), J_ref) < 1e-6
396400
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, central_cache), J_ref) < 1e-8
397401
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, Val{:central}), J_ref) < 1e-8
398402
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, complex_cache), J_ref) < 1e-14
399403
end
400404

405+
@time @testset "Out-of-Place JVP StridedArray real-valued tests" begin
406+
@test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6
407+
@test err_func(FiniteDiff.finite_difference_jvp(oopff, x, vdir, forward_jvp_cache, dir=-1), jvp_ref) < 1e-6
408+
@test_throws Any err_func(FiniteDiff.finite_difference_jvp(oopff, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6
409+
@test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache, relstep=sqrt(eps())), jvp_ref) < 1e-6
410+
@test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache, f_in), jvp_ref) < 1e-6
411+
@test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, central_jvp_cache), jvp_ref) < 1e-8
412+
@test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, Val{:central}), jvp_ref) < 1e-8
413+
end
414+
401415
function test_iipJac(J_ref, args...; kwargs...)
402416
_J = zero(J_ref)
403417
FiniteDiff.finite_difference_jacobian!(_J, args...; kwargs...)
404418
_J
405419
end
406420
@time @testset "inPlace Jacobian StridedArray real-valued tests" begin
407-
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache), J_ref) < 1e-4
408-
@test err_func(test_iipJac(J_ref, iipff, x, forward_cache, dir=-1), J_ref) < 1e-4
409-
@test_throws Any err_func(test_iipJac(J_ref, iipff, x, forward_cache), J_ref) < 1e-4
410-
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-4
411-
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache, f_in), J_ref) < 1e-4
421+
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache), J_ref) < 1e-6
422+
@test err_func(test_iipJac(J_ref, iipff, x, forward_cache, dir=-1), J_ref) < 1e-6
423+
@test_throws Any err_func(test_iipJac(J_ref, iipff, x, forward_cache), J_ref) < 1e-6
424+
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-6
425+
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache, f_in), J_ref) < 1e-6
412426
@test err_func(test_iipJac(J_ref, iipf, x, central_cache), J_ref) < 1e-8
413427
@test err_func(test_iipJac(J_ref, iipf, x, Val{:central}), J_ref) < 1e-8
414428
@test err_func(test_iipJac(J_ref, iipf, x, complex_cache), J_ref) < 1e-14
415429
end
416430

431+
function test_iipJVP(jvp_ref, args...; kwargs...)
432+
_jvp = zero(jvp_ref)
433+
FiniteDiff.finite_difference_jvp!(_jvp, args...; kwargs...)
434+
_jvp
435+
end
436+
437+
@time @testset "inPlace JVP StridedArray real-valued tests" begin
438+
@test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6
439+
@test err_func(test_iipJVP(jvp_ref, iipff, x, vdir, forward_jvp_cache, dir=-1), jvp_ref) < 1e-6
440+
@test_throws Any err_func(test_iipJVP(jvp_ref, iipff, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6
441+
@test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache, relstep=sqrt(eps())), jvp_ref) < 1e-6
442+
@test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache, f_in), jvp_ref) < 1e-6
443+
@test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, central_jvp_cache), jvp_ref) < 1e-8
444+
@test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, Val{:central}), jvp_ref) < 1e-8
445+
end
446+
417447
function iipf(fvec, x)
418448
fvec[1] = (im * x[1] + 3) * (x[2]^3 - 7) + 18
419449
fvec[2] = sin(x[2] * exp(x[1]) - 1)

0 commit comments

Comments
 (0)