Skip to content

Commit adb3ae6

Browse files
committed
Add ForwardDiff-specific methods of second-derivative functions
1 parent 8f95f02 commit adb3ae6

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

ext/AbstractDifferentiationForwardDiffExt.jl

+18
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,38 @@ function AD.hessian(ba::AD.ForwardDiffBackend, f, x::AbstractArray)
6161
return (ForwardDiff.hessian(f, x, cfg),)
6262
end
6363

64+
function AD.value_and_derivative(::AD.ForwardDiffBackend, f, x::Real)
65+
T = typeof(ForwardDiff.Tag(f, typeof(x)))
66+
ydual = f(ForwardDiff.Dual{T}(x, one(x)))
67+
return ForwardDiff.value(T, ydual), (ForwardDiff.extract_derivative(T, ydual),)
68+
end
69+
6470
function AD.value_and_gradient(ba::AD.ForwardDiffBackend, f, x::AbstractArray)
6571
result = DiffResults.GradientResult(x)
6672
cfg = ForwardDiff.GradientConfig(f, x, chunk(ba, x))
6773
ForwardDiff.gradient!(result, f, x, cfg)
6874
return DiffResults.value(result), (DiffResults.derivative(result),)
6975
end
7076

77+
function AD.value_and_second_derivative(ba::AD.ForwardDiffBackend, f, x::Real)
78+
T = typeof(ForwardDiff.Tag(f, typeof(x)))
79+
ydual, ddual = AD.value_and_derivative(ba, f, ForwardDiff.Dual{T}(x, one(x)))
80+
return value(T, ydual), (extract_derivative(T, ddual[1]),)
81+
end
82+
7183
function AD.value_and_hessian(ba::AD.ForwardDiffBackend, f, x)
7284
result = DiffResults.HessianResult(x)
7385
cfg = ForwardDiff.HessianConfig(f, result, x, chunk(ba, x))
7486
ForwardDiff.hessian!(result, f, x, cfg)
7587
return DiffResults.value(result), (DiffResults.hessian(result),)
7688
end
7789

90+
function AD.value_and_derivatives(ba::AD.ForwardDiffBackend, f, x::Real)
91+
T = typeof(ForwardDiff.Tag(f, typeof(x)))
92+
ydual, ddual = AD.value_and_derivative(ba, f, ForwardDiff.Dual{T}(x, one(x)))
93+
return ForwardDiff.value(T, ydual), (ForwardDiff.value(T, ddual[1]),), (ForwardDiff.extract_derivative(T, ddual[1]),)
94+
end
95+
7896
@inline step_toward(x::Number, v::Number, h) = x + h * v
7997
# support arrays and tuples
8098
@noinline step_toward(x, v, h) = x .+ h .* v

0 commit comments

Comments
 (0)