Commit
cleanup
CarloLucibello committed Jan 6, 2025
1 parent 5c3d72d commit 363a043
Showing 3 changed files with 17 additions and 12 deletions.
14 changes: 7 additions & 7 deletions test/ext_enzyme/enzyme.jl
@@ -17,8 +17,8 @@ using Enzyme: Enzyme, Duplicated, Const, Active
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
(Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
(ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
- # (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
- # (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
+ (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
+ (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
(first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"),
]

@@ -36,11 +36,11 @@ end
end

models_xs = [
- # (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
- # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
- # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
- # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
- # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
+ (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
+ (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
+ (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
+ (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
+ (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
]

for (model, x, name) in models_xs
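For context, a minimal sketch of how (model, x, name) tuples like the ones above can be exercised with Enzyme; the scalar loss and the explicit autodiff call are illustrative assumptions, not the repository's actual harness (which goes through test_gradients):

    using Flux, Enzyme, Test

    # Hypothetical scalar loss standing in for whatever the suite differentiates.
    loss(model, x) = sum(abs2, model(x))

    for (model, x, name) in models_xs
        @testset "$name" begin
            # Duplicated pairs the model with a zeroed shadow into which
            # Enzyme accumulates the gradient during the reverse pass.
            dmodel = Enzyme.Duplicated(model, Enzyme.make_zero(model))
            Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Active,
                            dmodel, Enzyme.Const(x))
            # After the call, the gradient with respect to the model's
            # parameters lives in dmodel.dval.
        end
    end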
12 changes: 7 additions & 5 deletions test/runtests.jl
@@ -35,11 +35,6 @@ if FLUX_TEST_ENZYME || FLUX_TEST_REACTANT
using Enzyme: Enzyme
end

- if FLUX_TEST_REACTANT
-     Pkg.add("Reactant")
-     using Reactant: Reactant
- end

include("test_utils.jl") # for test_gradients

Random.seed!(0)
@@ -185,6 +180,13 @@ end
end

if FLUX_TEST_REACTANT
+ ## This Pkg.add has to be done after Pkg.add("CUDA"), otherwise CUDA.jl
+ ## will not be functional and will complain with:
+ # ┌ Error: CUDA.jl could not find an appropriate CUDA runtime to use.
+ #
+ # │ CUDA.jl's JLLs were precompiled without an NVIDIA driver present.
+ Pkg.add("Reactant")
+ using Reactant: Reactant
@testset "Reactant" begin
include("ext_reactant/test_utils_reactant.jl")
include("ext_reactant/reactant.jl")
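The ordering constraint described in the comment above can be sketched as follows; the FLUX_TEST_CUDA flag and the surrounding layout are assumptions for illustration, not the file's exact contents:

    using Pkg

    # CUDA.jl must be added first so that its JLLs precompile while the
    # NVIDIA driver is visible; adding Reactant ahead of it triggers the
    # error quoted in the diff above.
    FLUX_TEST_CUDA && Pkg.add("CUDA")

    if FLUX_TEST_REACTANT
        Pkg.add("Reactant")      # safe here: runs after Pkg.add("CUDA")
        using Reactant: Reactant
    end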
3 changes: 3 additions & 0 deletions test/test_utils.jl
@@ -37,6 +37,7 @@ end

function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
fmapstructure_with_path(a, b) do kp, x, y
+ # @show kp
if x isa AbstractArray
@test x ≈ y rtol=rtol atol=atol
elseif x isa Number
@@ -66,6 +67,8 @@ function test_gradients(
error("You should either compare numerical gradients methods or CPU vs GPU.")
end

+ Flux.trainmode!(f) # for layers like BatchNorm

## Let's make sure first that the forward pass works.
l = loss(f, xs...)
@test l isa Number
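As a standalone illustration of the leaf-by-leaf comparison performed by check_equal_leaves, here is a minimal sketch; the two named tuples are made-up data, while fmapstructure_with_path (from Functors.jl) walks both structures in lockstep, handing the callback the key path kp together with each pair of leaves:

    using Test
    using Functors: fmapstructure_with_path

    a = (w = [1.0, 2.0], b = 3.0)          # hypothetical reference leaves
    b = (w = [1.0, 2.0 + 1e-6], b = 3.0)   # hypothetical computed leaves

    fmapstructure_with_path(a, b) do kp, x, y
        # kp identifies the current leaf, e.g. KeyPath(:w).
        if x isa AbstractArray || x isa Number
            @test x ≈ y rtol=1e-4 atol=1e-4
        end
    end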
