Skip to content

Commit 0a75474

Browse files
Dhairya GandhiCarloLucibello
Dhairya Gandhi
authored andcommitted
test with regular layers
1 parent acbf243 commit 0a75474

File tree

2 files changed

+17
-23
lines changed

2 files changed

+17
-23
lines changed

src/layers/basic.jl

-10
Original file line numberDiff line numberDiff line change
@@ -143,22 +143,13 @@ end
143143

144144
@functor Dense
145145

146-
<<<<<<< HEAD
147146
function (a::Dense)(x::AbstractVecOrMat)
148147
W, b, σ = a.weight, a.bias, a.σ
149148
return σ.(W*x .+ b)
150149
end
151150

152151
(a::Dense)(x::AbstractArray) =
153152
reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
154-
=======
155-
function (a::Dense)(x::Union{AbstractVector, AbstractMatrix})
156-
W, b, σ = a.W, a.b, a.σ
157-
return σ.(W*x .+ b)
158-
end
159-
160-
(a::Dense)(x::AbstractArray) = reshape(a(mat(x)), :, size(x)[2:end]...)
161-
>>>>>>> 017acdf9 (extend to generic arrays; add cuda tests)
162153

163154
function Base.show(io::IO, l::Dense)
164155
print(io, "Dense(", size(l.weight, 2), ", ", size(l.weight, 1))
@@ -484,4 +475,3 @@ end
484475
function Base.show(io::IO, m::Embedding)
485476
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
486477
end
487-
>>>>>>> b22cd2dc (cl/embed)

test/cuda/layers.jl

+17-13
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ pixelshuffle = [PixelShuffle]
114114
gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
115115
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
116116

117+
embedding = [Embedding]
118+
gpu_gradtest("Embedding", embedding, rand(1:10, 3), 10, 4)
119+
117120
@testset "function layers" begin
118121
x = rand(Float32, 3,3)
119122
gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
@@ -258,18 +261,19 @@ end
258261
@test gs_cpu[pcpu] gs_gpu[pgpu]
259262
end
260263
end
264+
end
261265

262-
@testset "Embedding" begin
263-
vocab_size, embed_size = 10, 4
264-
m = Embedding(vocab_size, embed_size)
265-
x = rand(1:vocab_size, 3)
266-
y = m(x)
267-
m_g = m |> gpu
268-
x_g = x |> gpu
269-
y_g = m_g(x_g)
270-
@test collect(y_g) == y
271-
gs = gradient(() -> sum(tanh.(m(x))), params(m))
272-
gs_g = gradient(() -> sum(tanh.(m_g(x_g))), params(m_g))
273-
@test collect(gs_g[m_g.weight]) gs[m.weight]
274-
end
266+
@testset "Embedding" begin
267+
vocab_size, embed_size = 10, 4
268+
m = Embedding(vocab_size, embed_size)
269+
x = rand(1:vocab_size, 3)
270+
y = m(x)
271+
m_g = m |> gpu
272+
x_g = x |> gpu
273+
y_g = m_g(x_g)
274+
@test collect(y_g) == y
275+
gs = gradient(() -> sum(tanh.(m(x))), params(m))
276+
gs_g = gradient(() -> sum(tanh.(m_g(x_g))), params(m_g))
277+
@test collect(gs_g[m_g.weight]) gs[m.weight]
275278
end
279+

0 commit comments

Comments
 (0)