Skip to content

Commit 0a35a80

Browse files
authored
Merge pull request #718 from JuliaDiff/dw/backport_nanmath
Backport #717 to 0.10
2 parents 228d40d + eb5ddeb commit 0a35a80

File tree

4 files changed

+27
-4
lines changed

4 files changed

+27
-4
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ForwardDiff"
22
uuid = "f6369f11-7733-5829-9624-2563aa707210"
3-
version = "0.10.37"
3+
version = "0.10.38"
44

55
[deps]
66
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"

src/dual.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ end
519519
# exponentiation #
520520
#----------------#
521521

522-
for f in (:(Base.:^), :(NaNMath.pow))
522+
for (f, log) in ((:(Base.:^), :(Base.log)), (:(NaNMath.pow), :(NaNMath.log)))
523523
@eval begin
524524
@define_binary_dual_op(
525525
$f,
@@ -532,7 +532,7 @@ for f in (:(Base.:^), :(NaNMath.pow))
532532
elseif iszero(vx) && vy > 0
533533
logval = zero(vx)
534534
else
535-
logval = expv * log(vx)
535+
logval = expv * ($log)(vx)
536536
end
537537
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
538538
return Dual{Txy}(expv, new_partials)
@@ -550,7 +550,7 @@ for f in (:(Base.:^), :(NaNMath.pow))
550550
begin
551551
v = value(y)
552552
expv = ($f)(x, v)
553-
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x)
553+
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*($log)(x)
554554
return Dual{Ty}(expv, deriv * partials(y))
555555
end
556556
)

test/DerivativeTest.jl

+12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DerivativeTest
22

33
import Calculus
4+
import NaNMath
45

56
using Test
67
using Random
@@ -93,6 +94,17 @@ end
9394
@test (x -> ForwardDiff.derivative(y -> x^y, 1.5))(0.0) === 0.0
9495
end
9596

97+
@testset "exponentiation with NaNMath" begin
98+
@test isnan(ForwardDiff.derivative(x -> NaNMath.pow(NaN, x), 1.0))
99+
@test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x,NaN), 1.0))
100+
@test !isnan(ForwardDiff.derivative(x -> NaNMath.pow(1.0, x),1.0))
101+
@test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x,0.5), -1.0))
102+
103+
@test isnan(ForwardDiff.derivative(x -> x^NaN, 2.0))
104+
@test ForwardDiff.derivative(x -> x^2.0,2.0) == 4.0
105+
@test_throws DomainError ForwardDiff.derivative(x -> x^0.5, -1.0)
106+
end
107+
96108
@testset "dimension error for derivative" begin
97109
@test_throws DimensionMismatch ForwardDiff.derivative(sum, fill(2pi, 3))
98110
end

test/GradientTest.jl

+11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module GradientTest
22

33
import Calculus
4+
import NaNMath
45

56
using Test
67
using ForwardDiff
@@ -168,4 +169,14 @@ end
168169
@test ForwardDiff.gradient(f, [0.3, 25.0]) == [3486.0, 0.0]
169170
end
170171

172+
@testset "gradient for exponential with NaNMath" begin
173+
@test isnan(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[1]), [NaN, 1.0])[1])
174+
@test ForwardDiff.gradient(x -> NaNMath.pow(x[1], x[2]), [1.0, 1.0]) == [1.0, 0.0]
175+
@test isnan(ForwardDiff.gradient((x) -> NaNMath.pow(x[1], x[2]), [-1.0, 0.5])[1])
176+
177+
@test isnan(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0])[1])
178+
@test ForwardDiff.gradient(x -> x[1]^x[2], [1.0, 1.0]) == [1.0, 0.0]
179+
@test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5])
180+
end
181+
171182
end # module

0 commit comments

Comments
 (0)