diff --git a/src/dual.jl b/src/dual.jl index 7e8ec110..b29fe475 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -238,13 +238,30 @@ function unary_dual_definition(M, f) val = $Mf(x) deriv = $(DiffRules.diffrule(M, f, :x)) end) - return quote + real_diff_exp = quote @inline function $M.$f(d::$FD.Dual{T}) where T x = $FD.value(d) $work return $FD.dual_definition_retval(Val{T}(), val, deriv, $FD.partials(d)) end end + if (M, f) in ((:Base, :abs), (:Base, :abs2), (:Base, :inv)) + return real_diff_exp + else + complex_diff_expr = quote + @inline function $M.$f(d::Complex{<:$FD.Dual{T}}) where{T} + x = complex($FD.value(real(d)), $FD.value(imag(d))) + $work + re_deriv, im_deriv = reim(deriv) + re_partials = $FD.partials(real(d)) + im_partials = $FD.partials(imag(d)) + re_retval = $FD.dual_definition_retval(Val{T}(), real(val), re_deriv, re_partials, -im_deriv, im_partials) + im_retval = $FD.dual_definition_retval(Val{T}(), imag(val), im_deriv, re_partials, re_deriv, im_partials) + return complex(re_retval, im_retval) + end + end + return Expr(:block, real_diff_exp, complex_diff_expr) + end end function binary_dual_definition(M, f) @@ -709,6 +726,22 @@ end Dual{Tz}(muladd(x, y, value(z)), partials(z)) # z_body ) +# inv(Complex) # +#--------# + +function Base.inv(d::Complex{<:Dual{T}}) where{T} + FD = ForwardDiff + x = complex(FD.value(real(d)), FD.value(imag(d))) + val = inv(x) + deriv = - val * val + re_deriv, im_deriv = reim(deriv) + re_partials = FD.partials(real(d)) + im_partials = FD.partials(imag(d)) + re_retval = FD.dual_definition_retval(Val{T}(), real(val), re_deriv, re_partials, -im_deriv, im_partials) + im_retval = FD.dual_definition_retval(Val{T}(), imag(val), im_deriv, re_partials, re_deriv, im_partials) + return complex(re_retval, im_retval) +end + # sin/cos # #--------# @@ -727,6 +760,32 @@ end return (Dual{T}(sd, cd * partials(d)), Dual{T}(cd, -sd * partials(d))) end +function Base.sin(d::Complex{<:Dual{T}}) where{T} + FD = ForwardDiff + x = complex(FD.value(real(d)), FD.value(imag(d))) + val = sin(x) + deriv = cos(x) + re_deriv, im_deriv = reim(deriv) + re_partials = FD.partials(real(d)) + im_partials = FD.partials(imag(d)) + re_retval = FD.dual_definition_retval(Val{T}(), real(val), re_deriv, re_partials, -im_deriv, im_partials) + im_retval = FD.dual_definition_retval(Val{T}(), imag(val), im_deriv, re_partials, re_deriv, im_partials) + return complex(re_retval, im_retval) +end + +function Base.cos(d::Complex{<:Dual{T}}) where{T} + FD = ForwardDiff + x = complex(FD.value(real(d)), FD.value(imag(d))) + val = cos(x) + deriv = -sin(x) + re_deriv, im_deriv = reim(deriv) + re_partials = FD.partials(real(d)) + im_partials = FD.partials(imag(d)) + re_retval = FD.dual_definition_retval(Val{T}(), real(val), re_deriv, re_partials, -im_deriv, im_partials) + im_retval = FD.dual_definition_retval(Val{T}(), imag(val), im_deriv, re_partials, re_deriv, im_partials) + return complex(re_retval, im_retval) +end + # sincospi # #----------# diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl index 4de1a6de..a105f084 100644 --- a/test/DerivativeTest.jl +++ b/test/DerivativeTest.jl @@ -113,4 +113,11 @@ end @test ForwardDiff.derivative(x -> (1+im)*x, 0) == (1+im) end +@testset "analytic functions" begin + dexp(x) = ForwardDiff.derivative(y -> exp(complex(0, y)), x) + @test ForwardDiff.derivative(dexp, 0.0) ≈ -1 + @test ForwardDiff.derivative(x -> exp(1im*x), 0.7) ≈ im * cis(0.7) + @test ForwardDiff.derivative(x -> sqrt(im + (1+im) * x), 1.23) ≈ (1+im) / (2 * sqrt(im + (1+im)*1.23)) +end + end # module