Skip to content

Commit ac1d573

Browse files
gerlero, devmotion, gdalle
authored
Add second-order derivative functions (#122)
* Add second-derivative functions to interface * Add ForwardDiff-specific methods of second-derivative functions * Add tests for second derivatives * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Reformat code * Test RuleConfig backend with second derivatives * Rename second_derivative -> secondderivative * Rename value_and_derivatives -> value_and_derivative_and_second_derivative * Rename secondderivative to second_derivative * Update AbstractDifferentiationForwardDiffExt.jl * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Fix errors --------- Co-authored-by: David Widmann <[email protected]> Co-authored-by: Guillaume Dalle <[email protected]>
1 parent afec712 commit ac1d573

9 files changed

+138
-3
lines changed

docs/src/implementer_guide.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,14 @@ They are just listed here to help readers figure out the code structure:
2929
- `derivative` calls `jacobian`
3030
- `gradient` calls `jacobian`
3131
- `hessian` calls `jacobian` and `gradient`
32+
- `second_derivative` calls `derivative`
3233
- `value_and_jacobian` calls `jacobian`
3334
- `value_and_derivative` calls `value_and_jacobian`
3435
- `value_and_gradient` calls `value_and_jacobian`
3536
- `value_and_hessian` calls `jacobian` and `gradient`
37+
- `value_and_second_derivative` calls `second_derivative`
3638
- `value_gradient_and_hessian` calls `value_and_jacobian` and `gradient`
39+
- `value_derivative_and_second_derivative` calls `value_and_derivative` and `second_derivative`
3740
- `pushforward_function` calls `jacobian`
3841
- `value_and_pushforward_function` calls `pushforward_function`
3942
- `pullback_function` calls `value_and_pullback_function`

docs/src/user_guide.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,24 +53,27 @@ AbstractDifferentiation.HigherOrderBackend
5353

5454
## Derivatives
5555

56-
The following list of functions can be used to request the derivative, gradient, Jacobian or Hessian without the function value.
56+
The following list of functions can be used to request the derivative, gradient, Jacobian, second derivative or Hessian without the function value.
5757

5858
```@docs
5959
AbstractDifferentiation.derivative
6060
AbstractDifferentiation.gradient
6161
AbstractDifferentiation.jacobian
62+
AbstractDifferentiation.second_derivative
6263
AbstractDifferentiation.hessian
6364
```
6465

6566
## Value and derivatives
6667

67-
The following list of functions can be used to request the function value along with its derivative, gradient, Jacobian or Hessian. You can also request the function value, its gradient and Hessian for single-input functions.
68+
The following list of functions can be used to request the function value along with its derivative, gradient, Jacobian, second derivative, or Hessian. You can also request the function value, its derivative (or its gradient) and its second derivative (or Hessian) for single-input functions.
6869

6970
```@docs
7071
AbstractDifferentiation.value_and_derivative
7172
AbstractDifferentiation.value_and_gradient
7273
AbstractDifferentiation.value_and_jacobian
74+
AbstractDifferentiation.value_and_second_derivative
7375
AbstractDifferentiation.value_and_hessian
76+
AbstractDifferentiation.value_derivative_and_second_derivative
7477
AbstractDifferentiation.value_gradient_and_hessian
7578
```
7679

ext/AbstractDifferentiationForwardDiffExt.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,47 @@ function AD.hessian(ba::AD.ForwardDiffBackend, f, x::AbstractArray)
6161
return (ForwardDiff.hessian(f, x, cfg),)
6262
end
6363

64+
function AD.value_and_derivative(::AD.ForwardDiffBackend, f, x::Real)
65+
T = typeof(ForwardDiff.Tag(f, typeof(x)))
66+
ydual = f(ForwardDiff.Dual{T}(x, one(x)))
67+
return ForwardDiff.value(T, ydual), (ForwardDiff.partials(T, ydual, 1),)
68+
end
69+
6470
function AD.value_and_gradient(ba::AD.ForwardDiffBackend, f, x::AbstractArray)
6571
result = DiffResults.GradientResult(x)
6672
cfg = ForwardDiff.GradientConfig(f, x, chunk(ba, x))
6773
ForwardDiff.gradient!(result, f, x, cfg)
6874
return DiffResults.value(result), (DiffResults.derivative(result),)
6975
end
7076

77+
function AD.value_and_second_derivative(ba::AD.ForwardDiffBackend, f, x::Real)
78+
T = typeof(ForwardDiff.Tag(f, typeof(x)))
79+
xdual = ForwardDiff.Dual{T}(x, one(x))
80+
T2 = typeof(ForwardDiff.Tag(f, typeof(xdual)))
81+
ydual = f(ForwardDiff.Dual{T2}(xdual, one(xdual)))
82+
v = ForwardDiff.value(T, ForwardDiff.value(T2, ydual))
83+
d2 = ForwardDiff.partials(T, ForwardDiff.partials(T2, ydual, 1), 1)
84+
return v, (d2,)
85+
end
86+
7187
function AD.value_and_hessian(ba::AD.ForwardDiffBackend, f, x)
7288
result = DiffResults.HessianResult(x)
7389
cfg = ForwardDiff.HessianConfig(f, result, x, chunk(ba, x))
7490
ForwardDiff.hessian!(result, f, x, cfg)
7591
return DiffResults.value(result), (DiffResults.hessian(result),)
7692
end
7793

94+
function AD.value_derivative_and_second_derivative(ba::AD.ForwardDiffBackend, f, x::Real)
95+
T = typeof(ForwardDiff.Tag(f, typeof(x)))
96+
xdual = ForwardDiff.Dual{T}(x, one(x))
97+
T2 = typeof(ForwardDiff.Tag(f, typeof(xdual)))
98+
ydual = f(ForwardDiff.Dual{T2}(xdual, one(xdual)))
99+
v = ForwardDiff.value(T, ForwardDiff.value(T2, ydual))
100+
d = ForwardDiff.partials(T, ForwardDiff.value(T2, ydual), 1)
101+
d2 = ForwardDiff.partials(T, ForwardDiff.partials(T2, ydual, 1), 1)
102+
return v, (d,), (d2,)
103+
end
104+
78105
@inline step_toward(x::Number, v::Number, h) = x + h * v
79106
# support arrays and tuples
80107
@noinline step_toward(x, v, h) = x .+ h .* v

src/AbstractDifferentiation.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,24 @@ function jacobian(ab::HigherOrderBackend, f, xs...)
8585
return jacobian(lowest(ab), f, xs...)
8686
end
8787

88+
"""
89+
AD.second_derivative(ab::AD.AbstractBackend, f, x)
90+
91+
Compute the second derivative of `f` with respect to the input `x` using the backend `ab`.
92+
93+
The function returns a single-element tuple containing the second derivative, because `second_derivative` currently only supports a single input.
94+
"""
95+
function second_derivative(ab::AbstractBackend, f, x)
96+
if x isa Tuple
97+
# only support computation of second derivative for functions with single input argument
98+
x = only(x)
99+
end
100+
return derivative(second_lowest(ab), x -> begin
101+
d = derivative(lowest(ab), f, x)
102+
return d[1] # derivative returns a tuple
103+
end, x)
104+
end
105+
88106
"""
89107
AD.hessian(ab::AD.AbstractBackend, f, x)
90108
@@ -139,12 +157,23 @@ function value_and_jacobian(ab::AbstractBackend, f, xs...)
139157
return value, jacs
140158
end
141159

160+
"""
161+
AD.value_and_second_derivative(ab::AD.AbstractBackend, f, x)
162+
163+
Return the tuple `(v, d2)` of the function value `v = f(x)` and the second derivative `d2 = AD.second_derivative(ab, f, x)`.
164+
165+
See also [`AbstractDifferentiation.second_derivative`](@ref).
166+
"""
167+
function value_and_second_derivative(ab::AbstractBackend, f, x)
168+
return f(x), second_derivative(ab, f, x)
169+
end
170+
142171
"""
143172
AD.value_and_hessian(ab::AD.AbstractBackend, f, x)
144173
145174
Return the tuple `(v, H)` of the function value `v = f(x)` and the Hessian `H = AD.hessian(ab, f, x)`.
146175
147-
See also [`AbstractDifferentiation.hessian`](@ref).
176+
See also [`AbstractDifferentiation.hessian`](@ref).
148177
"""
149178
function value_and_hessian(ab::AbstractBackend, f, x)
150179
if x isa Tuple
@@ -161,6 +190,28 @@ function value_and_hessian(ab::AbstractBackend, f, x)
161190
return value, hess
162191
end
163192

193+
"""
194+
AD.value_derivative_and_second_derivative(ab::AD.AbstractBackend, f, x)
195+
196+
Return the tuple `(v, d, d2)` of the function value `v = f(x)`, the first derivative `d = AD.derivative(ab, f, x)`, and the second derivative `d2 = AD.second_derivative(ab, f, x)`.
197+
"""
198+
function value_derivative_and_second_derivative(ab::AbstractBackend, f, x)
199+
if x isa Tuple
200+
# only support computation of second derivative for functions with single input argument
201+
x = only(x)
202+
end
203+
204+
value = f(x)
205+
deriv, secondderiv = value_and_derivative(
206+
second_lowest(ab), _x -> begin
207+
d = derivative(lowest(ab), f, _x)
208+
return d[1] # derivative returns a tuple
209+
end, x
210+
)
211+
212+
return value, (deriv,), secondderiv
213+
end
214+
164215
"""
165216
AD.value_gradient_and_hessian(ab::AD.AbstractBackend, f, x)
166217

test/finitedifferences.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ using FiniteDifferences
2121
@testset "Jacobian" begin
2222
test_jacobians(backend)
2323
end
24+
@testset "Second derivative" begin
25+
test_second_derivatives(backend)
26+
end
2427
@testset "Hessian" begin
2528
test_hessians(backend)
2629
end

test/forwarddiff.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ using ForwardDiff
1919
@testset "Jacobian" begin
2020
test_jacobians(backend)
2121
end
22+
@testset "Second derivative" begin
23+
test_second_derivatives(backend)
24+
end
2225
@testset "Hessian" begin
2326
test_hessians(backend)
2427
end

test/reversediff.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ using ReverseDiff
1414
@testset "Jacobian" begin
1515
test_jacobians(backend)
1616
end
17+
@testset "Second derivative" begin
18+
test_second_derivatives(backend)
19+
end
1720
@testset "Hessian" begin
1821
test_hessians(backend)
1922
end

test/ruleconfig.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ using Zygote
2121
@testset "j′vp" begin
2222
test_j′vp(backend)
2323
end
24+
@testset "Second derivative" begin
25+
test_second_derivatives(backend)
26+
end
2427
@testset "Lazy Derivative" begin
2528
test_lazy_derivatives(backend)
2629
end

test/test_utils.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Random.seed!(1234)
66
fder(x, y) = exp(y) * x + y * log(x)
77
dfderdx(x, y) = exp(y) + y * 1 / x
88
dfderdy(x, y) = exp(y) * x + log(x)
9+
dfderdxdx(x, y) = -y / x^2
910

1011
fgrad(x, y) = prod(x) + sum(y ./ (1:length(y)))
1112
dfgraddx(x, y) = prod(x) ./ x
@@ -143,6 +144,44 @@ function test_jacobians(backend; multiple_inputs=true, test_types=true)
143144
@test yvec == yvec2
144145
end
145146

147+
function test_second_derivatives(backend; test_types=true)
148+
# explicit test that AbstractDifferentiation throws an error
149+
# because tuples of inputs are not supported for second derivatives
150+
@test_throws ArgumentError AD.second_derivative(
151+
backend, x -> fder(x, yscalar), (xscalar, yscalar)
152+
)
153+
@test_throws MethodError AD.second_derivative(
154+
backend, x -> fder(x, yscalar), xscalar, yscalar
155+
)
156+
157+
# test if single input (no tuple) works
158+
dder1 = AD.second_derivative(backend, x -> fder(x, yscalar), xscalar)
159+
if test_types
160+
@test only(dder1) isa Float64
161+
end
162+
@test dfderdxdx(xscalar, yscalar) ≈ only(dder1) atol = 1e-8
163+
valscalar, dder2 = AD.value_and_second_derivative(
164+
backend, x -> fder(x, yscalar), xscalar
165+
)
166+
if test_types
167+
@test valscalar isa Float64
168+
@test only(dder2) isa Float64
169+
end
170+
@test valscalar == fder(xscalar, yscalar)
171+
@test dder2 == dder1
172+
valscalar, der, dder3 = AD.value_derivative_and_second_derivative(
173+
backend, x -> fder(x, yscalar), xscalar
174+
)
175+
if test_types
176+
@test valscalar isa Float64
177+
@test only(der) isa Float64
178+
@test only(dder3) isa Float64
179+
end
180+
@test valscalar == fder(xscalar, yscalar)
181+
@test der == AD.derivative(backend, x -> fder(x, yscalar), xscalar)
182+
@test dder3 == dder1
183+
end
184+
146185
function test_hessians(backend; multiple_inputs=false, test_types=true)
147186
if multiple_inputs
148187
# ... but

0 commit comments

Comments
 (0)