@@ -644,35 +644,45 @@ function Base.show(io::IO, m::PairwiseFusion)
end

"""
- Embedding(in => out; init=randn)
+ Embedding(in => out; init=randn32)

A lookup table that stores embeddings of dimension `out`
- for a vocabulary of size `in`.
+ for a vocabulary of size `in`, as a trainable matrix.

This layer is often used to store word embeddings and retrieve them using indices.
- The input to the layer can be either a vector of indexes
- or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
+ The input to the layer can be a vocabulary index in `1:in`, an array of indices,
+ or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
+
+ For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
+ For one-hot `ohx`, the result is of size `(out, size(ohx)[2:end]...)`.

# Examples
```jldoctest
- julia> vocab_size, embed_size = 1000, 4;
-
- julia> model = Flux.Embedding(vocab_size => embed_size)
- Embedding(1000 => 4)  # 4_000 parameters
-
- julia> vocab_idxs = [1, 722, 53, 220, 3];
-
- julia> x = Flux.onehotbatch(vocab_idxs, 1:vocab_size); summary(x)
- "1000×5 OneHotMatrix(::Vector{UInt32}) with eltype Bool"
-
- julia> model(x) |> summary
- "4×5 Matrix{Float32}"
-
- julia> model(vocab_idxs) == model(x)
+ julia> emb = Embedding(26 => 4, init=Flux.identity_init(gain=22))
+ Embedding(26 => 4)  # 104 parameters
+
+ julia> emb(2)  # one column of emb.weight (here not random!)
+ 4-element Vector{Float32}:
+   0.0
+  22.0
+   0.0
+   0.0
+
+ julia> emb([3, 1, 20, 14, 4, 15, 7])  # vocabulary indices, in 1:26
+ 4×7 Matrix{Float32}:
+   0.0  22.0  0.0  0.0   0.0  0.0  0.0
+   0.0   0.0  0.0  0.0   0.0  0.0  0.0
+  22.0   0.0  0.0  0.0   0.0  0.0  0.0
+   0.0   0.0  0.0  0.0  22.0  0.0  0.0
+
+ julia> ans == emb(Flux.onehotbatch("cat&dog", 'a':'z', 'n'))
true
+
+ julia> emb(rand(1:26, (10, 1, 12))) |> size  # three batch dimensions
+ (4, 10, 1, 12)
```
"""
- struct Embedding{W}
+ struct Embedding{W<:AbstractMatrix}
  weight::W
end
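
As an aside before the second hunk: a quick REPL sketch of the shape rule the revised docstring states. This is illustrative only, assuming a Flux build with this patch applied; the layer `e` and the 10-entry vocabulary are invented:

```julia
julia> e = Flux.Embedding(10 => 3);   # weight is a trainable 3×10 matrix

julia> e(7) |> size                   # a single index in 1:10 selects one column
(3,)

julia> e([1 2; 3 4]) |> size          # an index array keeps its batch dimensions
(3, 2, 2)
```
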
@@ -684,10 +694,9 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)

- function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
-   size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
-   return m(onecold(x))
- end
+ (m::Embedding)(x::AbstractVector{Bool}) = m.weight * x  # usually OneHotVector
+ (m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x  # usually OneHotMatrix
+ (m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x,1), :)), :, size(x)[2:end]...)

function Base.show(io::IO, m::Embedding)
  print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
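
Finally, a sanity check on why the replacement methods reproduce the old one-hot behavior: multiplying the weight matrix by a one-hot column selects exactly one column (each output entry sums a single weight), so the Bool path agrees exactly with the index path. A minimal sketch under the same assumptions as above, using `Flux.onehotbatch` as the diff's docstring does; `e` and the sizes are invented:

```julia
julia> e = Flux.Embedding(26 => 4);

julia> idx = [3, 1, 20];

julia> e(Flux.onehotbatch(idx, 1:26)) == e(idx)   # weight * one-hot == gather
true

julia> e(Flux.onehotbatch([1 2; 3 4], 1:26)) |> size   # 26×2×2 Bool array in, batch dims kept
(4, 2, 2)
```
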