From 363a043e154fd0dbf07ed1d1ebdfd58c4be75792 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Mon, 6 Jan 2025 07:36:52 +0100
Subject: [PATCH] cleanup

---
 test/ext_enzyme/enzyme.jl | 14 +++++++-------
 test/runtests.jl          | 12 +++++++-----
 test/test_utils.jl        |  3 +++
 3 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl
index edff72f3fe..b3a1cf8f8c 100644
--- a/test/ext_enzyme/enzyme.jl
+++ b/test/ext_enzyme/enzyme.jl
@@ -17,8 +17,8 @@ using Enzyme: Enzyme, Duplicated, Const, Active
     (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
     (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
     (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
-    # (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
-    # (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
+    (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
+    (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
     (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"),
 ]
 
@@ -36,11 +36,11 @@ end
 end
 
 models_xs = [
-    # (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
-    # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
-    # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
-    # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
-    # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
+    (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
+    (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
+    (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
+    (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
+    (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
 ]
 
 for (model, x, name) in models_xs
diff --git a/test/runtests.jl b/test/runtests.jl
index 3a111d0f04..8c375a3f22 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -35,11 +35,6 @@ if FLUX_TEST_ENZYME || FLUX_TEST_REACTANT
     using Enzyme: Enzyme
 end
 
-if FLUX_TEST_REACTANT
-    Pkg.add("Reactant")
-    using Reactant: Reactant
-end
-
 include("test_utils.jl") # for test_gradients
 
 Random.seed!(0)
@@ -185,6 +180,13 @@ end
 end
 
 if FLUX_TEST_REACTANT
+    ## This Pkg.add has to be done after Pkg.add("CUDA"), otherwise CUDA.jl
+    ## will not be functional and will complain with:
+    # ┌ Error: CUDA.jl could not find an appropriate CUDA runtime to use.
+    # │
+    # │ CUDA.jl's JLLs were precompiled without an NVIDIA driver present.
+    Pkg.add("Reactant")
+    using Reactant: Reactant
     @testset "Reactant" begin
         include("ext_reactant/test_utils_reactant.jl")
         include("ext_reactant/reactant.jl")
diff --git a/test/test_utils.jl b/test/test_utils.jl
index b8f4fec717..d8e10f3bf6 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -37,6 +37,7 @@ end
 
 function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
     fmapstructure_with_path(a, b) do kp, x, y
+        # @show kp
         if x isa AbstractArray
             @test x ≈ y rtol=rtol atol=atol
         elseif x isa Number
@@ -66,6 +67,8 @@ function test_gradients(
         error("You should either compare numerical gradients methods or CPU vs GPU.")
     end
 
+    Flux.trainmode!(f) # for layers like BatchNorm
+
     ## Let's make sure first that the forward pass works.
     l = loss(f, xs...)
     @test l isa Number
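
For context: a minimal sketch, assuming the usual Enzyme `autodiff(Reverse, ...)` pattern with `Duplicated`/`Const` (the annotations imported at the top of test/ext_enzyme/enzyme.jl), of the kind of gradient call the re-enabled BatchNorm entry exercises. The names `model`, `x`, and `loss` below are illustrative and not defined in these test files.

    using Flux, Enzyme

    model = BatchNorm(2)          # one of the re-enabled test layers
    x = randn(Float32, 2, 10)
    Flux.trainmode!(model)        # mirrors the trainmode! call added to test_gradients
    loss(m, x) = sum(abs2, m(x))

    # The shadow model accumulates the gradient; the input is held constant.
    dmodel = Enzyme.make_zero(model)
    Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Active, Duplicated(model, dmodel), Const(x))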