Commit a9618af
use gather with onehot input
1 parent e54440b · commit a9618af

2 files changed: +22 −9 lines

src/layers/basic.jl (+7 −2)
@@ -473,11 +473,16 @@ end
 
 Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
 
-(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
-(m::Embedding)(x::Integer) = m([x])
+
+(m::Embedding)(x::Integer) = m.weight[:, x]
 (m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
 (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
 
+function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
+  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
+  return m(onecold(x))
+end
+
 function Base.show(io::IO, m::Embedding)
   print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
 end
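In effect, integer and array indices now go through NNlib.gather, and one-hot inputs are decoded with onecold and routed through that same path instead of a matrix product. A minimal sketch of the new behaviour — not part of the commit, assuming Flux with this patch applied (sizes mirror the tests below):

using Flux
using Flux: OneHotMatrix

m = Flux.Embedding(5, 2)        # vocabulary of 5 tokens, 2-dimensional embeddings
x = OneHotMatrix([1, 2, 2], 5)  # three tokens, index 2 repeated

# The new method checks L against size(m.weight, 2), decodes the one-hot
# input with onecold, and gathers the corresponding weight columns, so the
# result matches plain column indexing and the old m.weight * x product.
@assert m(x) == m.weight[:, [1, 2, 2]]
@assert m(x) ≈ m.weight * x

# A one-hot input whose length does not match the vocabulary now fails fast:
# m(OneHotMatrix([1, 2], 4))    # throws DimensionMismatch

One upshot of the single gather path is that repeated indices are handled by gradient accumulation in gather's pullback, which is what the new "repeated indices" test cases below exercise.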

test/cuda/layers.jl (+15 −7)
@@ -51,13 +51,21 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
   # test
   if test_cpu
     @test y_gpu ≈ y_cpu rtol=1f-3 atol=1f-3
-    @test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
+    if isnothing(xg_cpu)
+      @test isnothing(xg_gpu)
+    else
+      @test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
+    end
   end
   @test gs_gpu isa Flux.Zygote.Grads
   for (p_cpu, p_gpu) in zip(ps_cpu, ps_gpu)
-    @test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
-    if test_cpu
-      @test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
+    if isnothing(xg_cpu)
+      @test isnothing(xg_gpu)
+    else
+      @test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
+      if test_cpu
+        @test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
+      end
     end
   end
 end
@@ -114,14 +122,14 @@ pixelshuffle = [PixelShuffle]
 gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
 gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
 
-embedding = [Embedding]
+embedding = [Flux.Embedding]
 gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
 gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
 gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
 gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
 gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
 gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
-gpu_gradtest("Embedding OneHotMatrix repeated indices", OneHotMatrix([1,2,2], 5), 5, 2)
+gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
 
 @testset "function layers" begin
   x = rand(Float32, 3,3)
@@ -144,7 +152,7 @@ end
 end
 
 @testset "Dense with Zeros bias" begin
-  l = Dense(ones(Float32, 4,3), Flux.Zeros()) |> gpu
+  l = Dense(ones(Float32, 4, 3), Flux.Zeros()) |> gpu
   ip = zeros(Float32, 3, 7) |> gpu
 
   @test sum(l(ip)) ≈ 0.f0
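The new isnothing branches in gpu_gradtest account for non-differentiable inputs: when a test is driven with integer indices, as in the Embedding cases above, Zygote returns `nothing` for the gradient with respect to the input rather than an array. A minimal sketch of that behaviour — not part of the commit, assuming Flux with this patch applied; the variable names are illustrative only:

using Flux

m = Flux.Embedding(5, 2)

# The gradient w.r.t. a floating-point input is an array ...
xg_dense, = Flux.gradient(x -> sum(m.weight * x), rand(Float32, 5, 3))
@assert xg_dense isa AbstractArray

# ... but w.r.t. integer indices Zygote reports no gradient at all,
# which is the case the rewritten assertions now handle.
xg_idx, = Flux.gradient(x -> sum(m(x)), [1, 3, 5])
@assert xg_idx === nothing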
