@@ -442,30 +442,25 @@ or the corresponding [onehot encoding](@ref Flux.OneHotArray).
442
442
443
443
# Examples
444
444
445
- ```julia-repl
446
- julia> using Flux: Embedding
447
-
448
- julia> vocab_size, embed_size = 1000, 4;
449
-
450
- julia> model = Embedding(vocab_size, embed_size)
451
- Embedding(1000, 4)
452
-
453
- julia> vocab_idxs = [1, 722, 53, 220, 3]
445
+ ```jldoctest
446
+ julia> m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0)
447
+ Embedding(26 => 2)
454
448
455
- julia> x = OneHotMatrix(vocab_idxs, vocab_size);
449
+ julia> m(5) # embedding vector for 5th element
450
+ 2-element Vector{Float32}:
451
+ 2.01
452
+ 3.01
456
453
457
- julia> model(x)
458
- 4×5 Matrix{Float32}:
459
- 0.91139 0.670462 0.463217 0.670462 0.110932
460
- 0.247225 -0.0823874 0.698694 -0.0823874 0.945958
461
- -0.393626 -0.590136 -0.545422 -0.590136 0.77743
462
- -0.497621 0.87595 -0.870251 0.87595 -0.772696
463
- ```
454
+ julia> m([6, 15, 15]) # applied to a batch
455
+ 2×3 Matrix{Float32}:
456
+ 4.01 22.01 22.01
457
+ 5.01 23.01 23.01
464
458
465
- julia> model(vocab_idxs) == model(x )
459
+ julia> ans == m(Flux.OneHotMatrix([6, 15, 15], 26) )
466
460
true
461
+ ```
467
462
"""
468
- struct Embedding{W}
463
+ struct Embedding{W <: AbstractMatrix }
469
464
weight:: W
470
465
end
471
466
@@ -479,5 +474,5 @@ Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
479
474
(m:: Embedding )(x:: AbstractArray ) = reshape (m (vec (x)), :, size (x)... )
480
475
481
476
function Base. show (io:: IO , m:: Embedding )
482
- print (io, " Embedding($(size (m. weight, 2 )) , $(size (m. weight, 1 )) )" )
477
+ print (io, " Embedding($(size (m. weight, 2 )) => $(size (m. weight, 1 )) )" )
483
478
end
0 commit comments