@@ -51,13 +51,21 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
51
51
# test
52
52
if test_cpu
53
53
@test y_gpu ≈ y_cpu rtol= 1f-3 atol= 1f-3
54
- @test Array (xg_gpu) ≈ xg_cpu rtol= 1f-3 atol= 1f-3
54
+ if isnothing (xg_cpu)
55
+ @test isnothing (xg_gpu)
56
+ else
57
+ @test Array (xg_gpu) ≈ xg_cpu rtol= 1f-3 atol= 1f-3
58
+ end
55
59
end
56
60
@test gs_gpu isa Flux. Zygote. Grads
57
61
for (p_cpu, p_gpu) in zip (ps_cpu, ps_gpu)
58
- @test gs_gpu[p_gpu] isa Flux. CUDA. CuArray
59
- if test_cpu
60
- @test Array (gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol= 1f-3 atol= 1f-3
62
+ if isnothing (xg_cpu)
63
+ @test isnothing (xg_gpu)
64
+ else
65
+ @test gs_gpu[p_gpu] isa Flux. CUDA. CuArray
66
+ if test_cpu
67
+ @test Array (gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol= 1f-3 atol= 1f-3
68
+ end
61
69
end
62
70
end
63
71
end
@@ -114,14 +122,14 @@ pixelshuffle = [PixelShuffle]
114
122
gpu_gradtest (" PixelShuffle 2d" , pixelshuffle, rand (Float32, 3 , 4 , 18 , 3 ), 3 )
115
123
gpu_gradtest (" PixelShuffle 1d" , pixelshuffle, rand (Float32, 3 , 18 , 3 ), 3 )
116
124
117
- embedding = [Embedding]
125
+ embedding = [Flux . Embedding]
118
126
gpu_gradtest (" Embedding" , embedding, [1 ,3 ,5 ], 5 , 2 )
119
127
gpu_gradtest (" Embedding repeated indices" , embedding, [1 ,3 ,5 ,3 ], 5 , 2 )
120
128
gpu_gradtest (" Embedding integer index" , embedding, 1 , 5 , 2 )
121
129
gpu_gradtest (" Embedding 2d index" , embedding, [1 2 ; 3 4 ], 5 , 2 )
122
130
gpu_gradtest (" Embedding OneHotVec index" , embedding, OneHotVector (1 , 5 ), 5 , 2 )
123
131
gpu_gradtest (" Embedding OneHotMatrix index" , embedding, OneHotMatrix ([1 ,2 ,3 ], 5 ), 5 , 2 )
124
- gpu_gradtest (" Embedding OneHotMatrix repeated indices" , OneHotMatrix ([1 ,2 ,2 ], 5 ), 5 , 2 )
132
+ gpu_gradtest (" Embedding OneHotMatrix repeated indices" , embedding, OneHotMatrix ([1 ,2 ,2 ], 5 ), 5 , 2 )
125
133
126
134
@testset " function layers" begin
127
135
x = rand (Float32, 3 ,3 )
144
152
end
145
153
146
154
@testset " Dense with Zeros bias" begin
147
- l = Dense (ones (Float32, 4 ,3 ), Flux. Zeros ()) |> gpu
155
+ l = Dense (ones (Float32, 4 , 3 ), Flux. Zeros ()) |> gpu
148
156
ip = zeros (Float32, 3 , 7 ) |> gpu
149
157
150
158
@test sum (l (ip)) ≈ 0.f0
0 commit comments