@@ -7,9 +7,33 @@ ACTIVATION_FUNCTIONS =
77@testset " bias_act!" begin
88 x = randn (3 ,4 )
99 b = randn (3 )
10- @test bias_act! (identity, copy (x), b) ≈ (x .+ b)
11- @test bias_act! (relu, copy (x), b) ≈ relu .(x .+ b)
12- @test bias_act! (tanh, copy (x), b) ≈ tanh .(x .+ b)
10+ @test @inferred (bias_act! (identity, x, false )) === x # pass-through
11+ @test @inferred (bias_act! (identity, copy (x), b)) ≈ (x .+ b)
12+ @test @inferred (bias_act! (relu, copy (x), b)) ≈ relu .(x .+ b)
13+ @test @inferred (bias_act! (tanh, copy (x), b)) ≈ tanh .(x .+ b)
14+ @test @inferred (bias_act! (tanh, copy (x), false )) ≈ tanh .(x)
15+
16+ # Check that it does overwrite:
17+ x32 = rand (Float32, 3 , 4 )
18+ x32copy = copy (x32)
19+ @test @inferred (bias_act! (cbrt, x32, b)) ≈ cbrt .(x32copy .+ b)
20+ @test x32 ≈ cbrt .(x32copy .+ b)
21+ x32 = rand (Float32, 3 , 4 )
22+ x32copy = copy (x32)
23+ @test @inferred (bias_act! (tanh, x32, false )) ≈ tanh .(x32copy)
24+ @test x32 ≈ tanh .(x32copy)
25+
26+ # Check that it doesn't try to overwrite non-float arrays:
27+ xint = rand (- 3 : 3 , 3 , 4 )
28+ bint = rand (- 2 : 2 , 3 )
29+ @test bias_act! (identity, copy (xint), bint) ≈ xint .+ bint
30+ @test bias_act! (tanh, copy (xint), bint) ≈ tanh .(xint .+ bint)
31+ @test bias_act! (tanh, copy (xint), false ) ≈ tanh .(xint)
32+
33+ # Reject bias===true so that Bool means one thing:
34+ @test_throws Exception bias_act! (identity, rand (3 ), true )
35+ @test_throws Exception bias_act! (cbrt, rand (3 ), true )
36+ @test_throws Exception bias_act! (cbrt, rand (1 : 3 , 3 ), true )
1337
1438 @testset " gradient with $fun " for fun in vcat ([identity, tanh, cbrt],
1539 ACTIVATION_FUNCTIONS,
@@ -21,9 +45,21 @@ ACTIVATION_FUNCTIONS =
2145 @test bias_act! (fun, copy (x), false ) ≈ fun .(x)
2246
2347 gx = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), b)), x)
48+ gxplus = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), b)), x .+ eps ())
49+ gxminus = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), b)), x .- eps ())
50+ if ! (gx ≈ gxplus ≈ gxminus)
51+ @warn " skipping gradient tests due to discontinuity" fun x b
52+ continue
53+ end
2454 @test gx ≈ Zygote. gradient (x -> sum (bias_act! (fun, copy (x), b)), x)[1 ]
2555
2656 gx2 = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), false )), x)
57+ gx2plus = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), false )), x .- eps ())
58+ gx2minus = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), false )), x .- eps ())
59+ if ! (gx2 ≈ gx2plus ≈ gx2minus)
60+ @warn " skipping gradient tests due to discontinuity" fun x
61+ continue
62+ end
2763 @test gx2 ≈ Zygote. gradient (x -> sum (bias_act! (fun, copy (x), false )), x)[1 ]
2864
2965 gb = ForwardDiff. gradient (b -> sum (bias_act! (fun, copy (x), b)), b)
0 commit comments