Skip to content

Commit dfd4549

Browse files
authored Oct 17, 2022
Simplify Embedding (#2084)
1 parent 74bd04b commit dfd4549

File tree

2 files changed

+47
-28
lines changed

2 files changed

+47
-28
lines changed
 

‎src/layers/basic.jl

+32-23
Original file line numberDiff line numberDiff line change
@@ -644,35 +644,45 @@ function Base.show(io::IO, m::PairwiseFusion)
644644
end
645645

646646
"""
647-
Embedding(in => out; init=randn)
647+
Embedding(in => out; init=randn32)
648648
649649
A lookup table that stores embeddings of dimension `out`
650-
for a vocabulary of size `in`.
650+
for a vocabulary of size `in`, as a trainable matrix.
651651
652652
This layer is often used to store word embeddings and retrieve them using indices.
653-
The input to the layer can be either a vector of indexes
654-
or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
653+
The input to the layer can be a vocabulary index in `1:in`, an array of indices,
654+
or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
655+
656+
For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
657+
For one-hot `ohx`, the result is of size `(out, size(ohx)[2:end]...)`.
655658
656659
# Examples
657660
```jldoctest
658-
julia> vocab_size, embed_size = 1000, 4;
659-
660-
julia> model = Flux.Embedding(vocab_size => embed_size)
661-
Embedding(1000 => 4) # 4_000 parameters
662-
663-
julia> vocab_idxs = [1, 722, 53, 220, 3];
664-
665-
julia> x = Flux.onehotbatch(vocab_idxs, 1:vocab_size); summary(x)
666-
"1000×5 OneHotMatrix(::Vector{UInt32}) with eltype Bool"
667-
668-
julia> model(x) |> summary
669-
"4×5 Matrix{Float32}"
670-
671-
julia> model(vocab_idxs) == model(x)
661+
julia> emb = Embedding(26 => 4, init=Flux.identity_init(gain=22))
662+
Embedding(26 => 4) # 104 parameters
663+
664+
julia> emb(2) # one column of emb.weight (here not random!)
665+
4-element Vector{Float32}:
666+
0.0
667+
22.0
668+
0.0
669+
0.0
670+
671+
julia> emb([3, 1, 20, 14, 4, 15, 7]) # vocabulary indices, in 1:26
672+
4×7 Matrix{Float32}:
673+
0.0 22.0 0.0 0.0 0.0 0.0 0.0
674+
0.0 0.0 0.0 0.0 0.0 0.0 0.0
675+
22.0 0.0 0.0 0.0 0.0 0.0 0.0
676+
0.0 0.0 0.0 0.0 22.0 0.0 0.0
677+
678+
julia> ans == emb(Flux.onehotbatch("cat&dog", 'a':'z', 'n'))
672679
true
680+
681+
julia> emb(rand(1:26, (10, 1, 12))) |> size # three batch dimensions
682+
(4, 10, 1, 12)
673683
```
674684
"""
675-
struct Embedding{W}
685+
struct Embedding{W<:AbstractMatrix}
676686
weight::W
677687
end
678688

@@ -684,10 +694,9 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
684694
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
685695
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
686696

687-
function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
688-
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
689-
return m(onecold(x))
690-
end
697+
(m::Embedding)(x::AbstractVector{Bool}) = m.weight * x # usually OneHotVector
698+
(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x # usually OneHotMatrix
699+
(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
691700

692701
function Base.show(io::IO, m::Embedding)
693702
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")

‎test/layers/basic.jl

+15-5
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,17 @@ import Flux: activations
289289

290290
@testset "Embedding" begin
291291
vocab_size, embed_size = 10, 4
292-
m = Flux.Embedding(vocab_size, embed_size)
292+
m = Embedding(vocab_size, embed_size)
293293
@test size(m.weight) == (embed_size, vocab_size)
294+
295+
# one index
296+
@test m(1) isa Vector{Float32}
297+
@test m(2) ≈ m.weight[:,2]
298+
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
299+
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
300+
@test m(4) ≈ m((1:vocab_size) .== 4)
294301

302+
# a batch of indices
295303
x = rand(1:vocab_size, 3)
296304
y = m(x)
297305
@test y isa Matrix{Float32}
@@ -301,15 +309,17 @@ import Flux: activations
301309
@test y2 isa Matrix{Float32}
302310
@test y2 ≈ y
303311
@test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
312+
@test y ≈ m(x' .== (1:vocab_size))
304313

314+
# more dimensions via reshape
305315
x = rand(1:vocab_size, 3, 4)
306316
y = m(x)
307317
@test y isa Array{Float32, 3}
308318
@test size(y) == (embed_size, 3, 4)
309-
310-
@test m(2) ≈ m.weight[:,2]
311-
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
312-
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
319+
x3 = onehotbatch(x, 1:1:vocab_size)
320+
@test size(x3) == (vocab_size, 3, 4)
321+
y3 = m(x3)
322+
@test size(y3) == (embed_size, 3, 4)
313323
end
314324
end
315325

0 commit comments

Comments
 (0)
Please sign in to comment.