
Commit 4d3944c

add embedding layer
Co-authored-by: Kyle Daruwalla <[email protected]>
Parent: 9410677

5 files changed: +96 −2 lines

docs/src/models/layers.md (+1)
````diff
@@ -58,6 +58,7 @@ SkipConnection
 Parallel
 Flux.Bilinear
 Flux.Diagonal
+Flux.Embedding
 ```

 ## Normalisation & Regularisation
````

src/Flux.jl (+1 −1)
```diff
@@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient

 export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
-       RNN, LSTM, GRU,
+       RNN, LSTM, GRU, Embedding,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
        Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
```

src/layers/basic.jl (+54 −1)
````diff
@@ -8,6 +8,7 @@ on a given input.
 `m[1:3](x)` will calculate the output of the first three layers.

 # Examples
+
 ```jldoctest
 julia> m = Chain(x -> x^2, x -> x+1);

@@ -421,4 +422,56 @@ function Base.show(io::IO, m::Parallel)
   print(io, "Parallel(", m.connection, ", ")
   join(io, m.layers, ", ")
   print(io, ")")
-end
+end
+
+"""
+    Embedding(in, out; init=randn)
+
+A lookup table that stores embeddings of dimension `out`
+for a vocabulary of size `in`.
+
+This layer is often used to store word embeddings and retrieve them using indices.
+The input to the layer can be either a vector of indices
+or the corresponding [onehot encoding](@ref Flux.OneHotArray).
+
+# Examples
+
+```julia-repl
+julia> vocab_size, embed_size = 1000, 4;
+
+julia> model = Embedding(vocab_size, embed_size)
+Embedding(1000, 4)
+
+julia> vocab_idxs = [1, 722, 53, 220, 3];
+
+julia> x = OneHotMatrix(vocab_idxs, vocab_size);
+
+julia> model(x)
+4×5 Matrix{Float32}:
+  0.91139    0.670462    0.463217   0.670462    0.110932
+  0.247225  -0.0823874   0.698694  -0.0823874   0.945958
+ -0.393626  -0.590136   -0.545422  -0.590136    0.77743
+ -0.497621   0.87595    -0.870251   0.87595    -0.772696
+
+julia> model(vocab_idxs) == model(x)
+true
+```
+"""
+struct Embedding{W}
+  weight::W
+end
+
+@functor Embedding
+
+function Embedding(in::Integer, out::Integer;
+                   init = (i...) -> randn(Float32, i...))
+  return Embedding(init(out, in))
+end
+
+(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:, onecold(x)]
+(m::Embedding)(x::Union{Int, AbstractVector}) = m.weight[:, x]
+(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
+
+function Base.show(io::IO, m::Embedding)
+  print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
+end
````
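The one-hot method multiplies by the weight matrix, which selects the same columns that integer indexing does (hence the `onecold` comment in the diff). A minimal sketch of that equivalence, assuming this commit is applied; the sizes and the `onehotbatch` call are illustrative choices, not part of the diff:

```julia
using Flux

# Hypothetical sizes: a 5-word vocabulary embedded in 3 dimensions.
m = Embedding(5, 3)

x = [4, 1, 4]                  # integer indices
xh = Flux.onehotbatch(x, 1:5)  # the same indices, one-hot encoded

m(x) == m.weight[:, x]  # true: plain column indexing
m(xh) ≈ m(x)            # true: m.weight * xh picks out the same columns
```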

test/cuda/layers.jl (+15)
```diff
@@ -259,3 +259,18 @@ end
     end
   end
 end
+
+@testset "Embedding" begin
+  vocab_size, embed_size = 10, 4
+  m = Embedding(vocab_size, embed_size)
+  x = rand(1:vocab_size, 3)
+  y = m(x)
+  m_g = m |> gpu
+  x_g = x |> gpu
+  y_g = m_g(x_g)
+  @test collect(y_g) == y
+  gs = gradient(() -> sum(tanh.(m(x))), params(m))
+  gs_g = gradient(() -> sum(tanh.(m_g(x_g))), params(m_g))
+  @test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
+end
```
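The test above checks that CPU and GPU gradients of the embedding table agree. For context, a minimal CPU sketch of how such gradients are typically consumed, assuming the implicit-params optimiser API of this Flux version; the toy loss and `Descent` step size are illustrative:

```julia
using Flux

m = Embedding(10, 4)
x = rand(1:10, 3)

# Gradient of a toy loss w.r.t. the embedding table; only the
# columns indexed by `x` receive non-zero gradient.
gs = gradient(() -> sum(abs2, m(x)), Flux.params(m))

# One gradient-descent step on the table.
Flux.Optimise.update!(Descent(0.1), Flux.params(m), gs)
```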

test/layers/basic.jl (+25)
```diff
@@ -191,4 +191,29 @@ import Flux: activations
       @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
     end
   end
+
+  @testset "Embedding" begin
+    vocab_size, embed_size = 10, 4
+    m = Embedding(vocab_size, embed_size)
+    @test size(m.weight) == (embed_size, vocab_size)
+
+    x = rand(1:vocab_size, 3)
+    y = m(x)
+    @test y isa Matrix{Float32}
+    @test y ≈ m.weight[:,x]
+    x2 = OneHotMatrix(x, vocab_size)
+    y2 = m(x2)
+    @test y2 isa Matrix{Float32}
+    @test y2 ≈ y
+    @test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
+
+    x = rand(1:vocab_size, 3, 4)
+    y = m(x)
+    @test y isa Array{Float32, 3}
+    @test size(y) == (embed_size, 3, 4)
+
+    @test m(2) ≈ m.weight[:,2]
+    @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
+    @test_throws DimensionMismatch m(OneHotVector(3, 1000))
+  end
 end
```
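The docstring names word embeddings as the typical use. A hedged sketch of how the new layer might slot into a small classifier; all sizes, the bag-of-words pooling, and the model shape are illustrative assumptions, not part of the commit:

```julia
using Flux

vocab_size, embed_size, nclasses = 1000, 16, 2

model = Chain(
    Embedding(vocab_size, embed_size),  # token indices -> (embed_size, seq_len)
    x -> sum(x, dims = 2),              # crude bag-of-words pooling
    vec,                                # drop the singleton sequence dim
    Dense(embed_size, nclasses),
    softmax,
)

tokens = rand(1:vocab_size, 7)  # a 7-token "sentence"
model(tokens)                   # 2-element probability vector
```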
