@@ -8,6 +8,7 @@ on a given input.
8
8
`m[1:3](x)` will calculate the output of the first three layers.
9
9
10
10
# Examples
11
+
11
12
```jldoctest
12
13
julia> m = Chain(x -> x^2, x -> x+1);
13
14
@@ -421,4 +422,56 @@ function Base.show(io::IO, m::Parallel)
421
422
print (io, " Parallel(" , m. connection, " , " )
422
423
join (io, m. layers, " , " )
423
424
print (io, " )" )
424
- end
425
+ end
426
+
427
+ """
428
+ Embedding(in, out; init=randn)
429
+
430
+ A lookup table that stores embeddings of dimension `out`
431
+ for a vocabulary of size `in`.
432
+
433
+ This layers is often used to store word embeddings and retrieve them using indices.
434
+ The input to the layer can be either a vector of indexes
435
+ or the corresponding [onehot encoding](@ref Flux.OneHotArray).
436
+
437
+ # Examples
438
+
439
+ ```julia-repl
440
+ julia> vocab_size, embed_size = 1000, 4;
441
+
442
+ julia> model = Embedding(vocab_size, embed_size)
443
+ Embedding(1000, 4)
444
+
445
+ julia> vocab_idxs = [1, 722, 53, 220, 3]
446
+
447
+ julia> x = OneHotMatrix(vocab_idxs, vocab_size);
448
+
449
+ julia> model(x)
450
+ 4×5 Matrix{Float32}:
451
+ 0.91139 0.670462 0.463217 0.670462 0.110932
452
+ 0.247225 -0.0823874 0.698694 -0.0823874 0.945958
453
+ -0.393626 -0.590136 -0.545422 -0.590136 0.77743
454
+ -0.497621 0.87595 -0.870251 0.87595 -0.772696
455
+ ```
456
+
457
+ julia> model(vocab_idxs) == model(x)
458
+ true
459
+ """
460
+ struct Embedding{W}
461
+ weight:: W
462
+ end
463
+
464
+ @functor Embedding
465
+
466
+ function Embedding (in:: Integer , out:: Integer ;
467
+ init = (i... ) -> randn (Float32, i... ))
468
+ return Embedding (init (out, in))
469
+ end
470
+
471
+ (m:: Embedding )(x:: Union{OneHotVector, OneHotMatrix} ) = m. weight * x # equivalent to m.weight[:,onecold(x)]
472
+ (m:: Embedding )(x:: Union{Int,AbstractVector} ) = m. weight[:, x]
473
+ (m:: Embedding )(x:: AbstractArray ) = reshape (m (vec (x)), :, size (x)... )
474
+
475
+ function Base. show (io:: IO , m:: Embedding )
476
+ print (io, " Embedding($(size (m. weight, 2 )) , $(size (m. weight, 1 )) )" )
477
+ end
0 commit comments