1
- using NNlib, Zygote, Test
1
+ using NNlib, Zygote, ChainRulesCore, Test
2
2
using Zygote: ForwardDiff
3
3
4
4
ACTIVATION_FUNCTIONS =
@@ -14,14 +14,21 @@ ACTIVATION_FUNCTIONS =
14
14
@test @inferred (bias_act! (tanh, copy (x), false )) ≈ tanh .(x)
15
15
16
16
# Check that it does overwrite:
17
- x32 = rand (Float32, 3 , 4 )
18
- x32copy = copy (x32)
17
+ x32 = rand (Float32, 3 , 4 ); x32copy = copy (x32)
19
18
@test @inferred (bias_act! (cbrt, x32, b)) ≈ cbrt .(x32copy .+ b)
20
- @test x32 ≈ cbrt .(x32copy .+ b)
21
- x32 = rand (Float32, 3 , 4 )
22
- x32copy = copy (x32)
19
+ @test x32 ≈ cbrt .(x32copy .+ b)
20
+
21
+ x32 = rand (Float32, 3 , 4 ); x32copy = copy (x32) # without bias
23
22
@test @inferred (bias_act! (tanh, x32, false )) ≈ tanh .(x32copy)
24
- @test x32 ≈ tanh .(x32copy)
23
+ @test x32 ≈ tanh .(x32copy)
24
+
25
+ x32 = rand (Float32, 3 , 4 ); x32copy = copy (x32) # now check gradient rule
26
+ y, back = rrule (Zygote. ZygoteRuleConfig (), bias_act!, relu, x32, b)
27
+ @test y ≈ x32 ≈ relu .(x32copy .+ b)
28
+
29
+ x32 = rand (Float32, 3 , 4 ); x32copy = copy (x32) # without bias
30
+ y, back = rrule (Zygote. ZygoteRuleConfig (), bias_act!, relu, x32, false )
31
+ @test y ≈ x32 ≈ relu .(x32copy)
25
32
26
33
# Check that it doesn't try to overwrite non-float arrays:
27
34
xint = rand (- 3 : 3 , 3 , 4 )
@@ -78,7 +85,7 @@ ACTIVATION_FUNCTIONS =
78
85
g2 = ForwardDiff. gradient (x) do x
79
86
sum (abs2, Zygote. gradient (x -> sum (abs2, bias_act! (relu, copy (x), b)), x)[1 ])
80
87
end
81
- @test_broken gx ≈ Zygote. gradient (x) do x
88
+ @test_skip gx ≈ Zygote. gradient (x) do x # Here global variable b causes an error
82
89
sum (abs2,Zygote. gradient (x -> sum (abs2, bias_act! (relu, copy (x), b)), x)[1 ])
83
90
end
84
91
# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).
0 commit comments