Skip to content

Commit 7b04b15

Browse files
committed
more tests
1 parent 791531a commit 7b04b15

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

Diff for: test/bias_act.jl

+15-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NNlib, Zygote, Test
1+
using NNlib, Zygote, ChainRulesCore, Test
22
using Zygote: ForwardDiff
33

44
ACTIVATION_FUNCTIONS =
@@ -14,14 +14,21 @@ ACTIVATION_FUNCTIONS =
1414
@test @inferred(bias_act!(tanh, copy(x), false)) tanh.(x)
1515

1616
# Check that it does overwrite:
17-
x32 = rand(Float32, 3, 4)
18-
x32copy = copy(x32)
17+
x32 = rand(Float32, 3, 4); x32copy = copy(x32)
1918
@test @inferred(bias_act!(cbrt, x32, b)) cbrt.(x32copy .+ b)
20-
@test x32 cbrt.(x32copy .+ b)
21-
x32 = rand(Float32, 3, 4)
22-
x32copy = copy(x32)
19+
@test x32 cbrt.(x32copy .+ b)
20+
21+
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias
2322
@test @inferred(bias_act!(tanh, x32, false)) tanh.(x32copy)
24-
@test x32 tanh.(x32copy)
23+
@test x32 tanh.(x32copy)
24+
25+
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule
26+
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b)
27+
@test y x32 relu.(x32copy .+ b)
28+
29+
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias
30+
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false)
31+
@test y x32 relu.(x32copy)
2532

2633
# Check that it doesn't try to overwrite non-float arrays:
2734
xint = rand(-3:3, 3, 4)
@@ -78,7 +85,7 @@ ACTIVATION_FUNCTIONS =
7885
g2 = ForwardDiff.gradient(x) do x
7986
sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
8087
end
81-
@test_broken gx Zygote.gradient(x) do x
88+
@test_skip gx Zygote.gradient(x) do x # Here global variable b causes an error
8289
sum(abs2,Zygote. gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
8390
end
8491
# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).

Diff for: test/runtests.jl

-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ end
101101
@info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them"
102102
end
103103

104-
<<<<<<< HEAD
105104
if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true"
106105
import Pkg
107106
test_info = Pkg.project()

0 commit comments

Comments
 (0)