@@ -442,30 +442,25 @@ or the corresponding [onehot encoding](@ref Flux.OneHotArray).
442
442
443
443
# Examples
444
444
445
- ```julia-repl
446
- julia> using Flux: Embedding
447
-
448
- julia> vocab_size, embed_size = 1000, 4;
449
-
450
- julia> model = Embedding(vocab_size, embed_size)
451
- Embedding(1000, 4)
452
-
453
- julia> vocab_idxs = [1, 722, 53, 220, 3]
445
+ ```jldoctest
446
+ julia> m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0)
447
+ Embedding(26 => 2)
454
448
455
- julia> x = OneHotMatrix(vocab_idxs, vocab_size);
449
+ julia> m(5) # embedding vector for 5th element
450
+ 2-element Vector{Float32}:
451
+ 2.01
452
+ 3.01
456
453
457
- julia> model(x)
458
- 4×5 Matrix{Float32}:
459
- 0.91139 0.670462 0.463217 0.670462 0.110932
460
- 0.247225 -0.0823874 0.698694 -0.0823874 0.945958
461
- -0.393626 -0.590136 -0.545422 -0.590136 0.77743
462
- -0.497621 0.87595 -0.870251 0.87595 -0.772696
463
- ```
454
+ julia> m([6, 15, 15]) # applied to a batch
455
+ 2×3 Matrix{Float32}:
456
+ 4.01 22.01 22.01
457
+ 5.01 23.01 23.01
464
458
465
- julia> model(vocab_idxs) == model(x )
459
+ julia> ans == m(Flux.OneHotMatrix([6, 15, 15], 26) )
466
460
true
461
+ ```
467
462
"""
468
- struct Embedding{W}
463
+ struct Embedding{W <: AbstractMatrix }
469
464
weight:: W
470
465
end
471
466
@@ -484,5 +479,5 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
484
479
end
485
480
486
481
function Base. show (io:: IO , m:: Embedding )
487
- print (io, " Embedding($(size (m. weight, 2 )) , $(size (m. weight, 1 )) )" )
482
+ print (io, " Embedding($(size (m. weight, 2 )) => $(size (m. weight, 1 )) )" )
488
483
end
0 commit comments