1- using NNlib, Zygote, Test
1+ using NNlib, Zygote, ChainRulesCore, Test
22using Zygote: ForwardDiff
33
44ACTIVATION_FUNCTIONS =
@@ -14,14 +14,21 @@ ACTIVATION_FUNCTIONS =
1414 @test @inferred (bias_act! (tanh, copy (x), false )) ≈ tanh .(x)
1515
1616 # Check that it does overwrite:
17- x32 = rand (Float32, 3 , 4 )
18- x32copy = copy (x32)
17+ x32 = rand (Float32, 3 , 4 ); x32copy = copy (x32)
1918 @test @inferred (bias_act! (cbrt, x32, b)) ≈ cbrt .(x32copy .+ b)
20- @test x32 ≈ cbrt .(x32copy .+ b)
21- x32 = rand (Float32, 3 , 4 )
22- x32copy = copy (x32)
19+ @test x32 ≈ cbrt .(x32copy .+ b)
20+
21+ x32 = rand (Float32, 3 , 4 ); x32copy = copy (x32) # without bias
2322 @test @inferred (bias_act! (tanh, x32, false )) ≈ tanh .(x32copy)
24- @test x32 ≈ tanh .(x32copy)
23+ @test x32 ≈ tanh .(x32copy)
24+
25+ x32 = rand (Float32, 3 , 4 ); x32copy = copy (x32) # now check gradient rule
26+ y, back = rrule (Zygote. ZygoteRuleConfig (), bias_act!, relu, x32, b)
27+ @test y ≈ x32 ≈ relu .(x32copy .+ b)
28+
29+ x32 = rand (Float32, 3 , 4 ); x32copy = copy (x32) # without bias
30+ y, back = rrule (Zygote. ZygoteRuleConfig (), bias_act!, relu, x32, false )
31+ @test y ≈ x32 ≈ relu .(x32copy)
2532
2633 # Check that it doesn't try to overwrite non-float arrays:
2734 xint = rand (- 3 : 3 , 3 , 4 )
@@ -78,7 +85,7 @@ ACTIVATION_FUNCTIONS =
7885 g2 = ForwardDiff. gradient (x) do x
7986 sum (abs2, Zygote. gradient (x -> sum (abs2, bias_act! (relu, copy (x), b)), x)[1 ])
8087 end
81- @test_broken gx ≈ Zygote. gradient (x) do x
88+ @test_skip gx ≈ Zygote. gradient (x) do x # Here global variable b causes an error
8289 sum (abs2,Zygote. gradient (x -> sum (abs2, bias_act! (relu, copy (x), b)), x)[1 ])
8390 end
8491 # Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).
0 commit comments