diff --git a/docs/src/user/api.md b/docs/src/user/api.md index 25289ed7..fe19c434 100644 --- a/docs/src/user/api.md +++ b/docs/src/user/api.md @@ -9,6 +9,8 @@ CurrentModule = ForwardDiff ```@docs ForwardDiff.derivative ForwardDiff.derivative! +ForwardDiff.value_and_derivative +ForwardDiff.value_and_derivatives ``` ## Gradients of `f(x::AbstractArray)::Real` diff --git a/src/derivative.jl b/src/derivative.jl index b39e2a48..2042b836 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -75,6 +75,32 @@ end derivative(f, x::AbstractArray) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number. Perhaps you meant gradient(f, x)?")) derivative(f, x::Complex) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number (does not support Wirtinger derivatives). Separate real and imaginary parts of the input.")) +""" + ForwardDiff.value_and_derivative(f, x::Real) + +Return `f(x)` and `df/dx` evaluated at `x` in a single pass, assuming `f` is called as `f(x)`. + +This method assumes that `isa(f(x), Union{Real,AbstractArray})`. +""" +@inline function value_and_derivative(f::F, x::R) where {F,R<:Real} + T = typeof(Tag(f, R)) + ydual = f(Dual{T}(x, one(x))) + return value(T, ydual), extract_derivative(T, ydual) +end + +""" + ForwardDiff.value_and_derivatives(f, x::Real) + +Return `f(x)` and its first and second derivatives evaluated at `x` in a single pass, assuming `f` is called as `f(x)`. + +This method assumes that `isa(f(x), Union{Real,AbstractArray})`. +""" +@inline function value_and_derivatives(f::F, x::R) where {F,R<:Real} + T = typeof(Tag(f, typeof(x))) + ydual, ddual = value_and_derivative(f, Dual{T}(x, one(x))) + return value(T, ydual), value(T, ddual), extract_derivative(T, ddual) +end + ##################### # result extraction # ##################### diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl index dfdd8ed2..d6697f3e 100644 --- a/test/DerivativeTest.jl +++ b/test/DerivativeTest.jl @@ -22,11 +22,22 @@ for f in DiffTests.NUMBER_TO_NUMBER_FUNCS v = f(x) d = ForwardDiff.derivative(f, x) @test isapprox(d, Calculus.derivative(f, x), atol=FINITEDIFF_ERROR) + d2 = ForwardDiff.derivative(x -> ForwardDiff.derivative(f, x), x) out = DiffResults.DiffResult(zero(v), zero(v)) out = ForwardDiff.derivative!(out, f, x) @test isapprox(DiffResults.value(out), v) @test isapprox(DiffResults.derivative(out), d) + + out = ForwardDiff.value_and_derivative(f, x) + @test length(out) == 2 + @test isapprox(out[1], v) + @test isapprox(out[2], d) + + out = ForwardDiff.value_and_derivatives(f, x) + @test isapprox(out[1], v) + @test isapprox(out[2], d) + @test isapprox(out[3], d2) end for f in DiffTests.NUMBER_TO_ARRAY_FUNCS