Commit 55c64c3

Also give a non-random example whose vectors can be printed. Do it without 5 named variables, and show that the point of onehot is handling values which aren't already in 1:n. Also show the result of higher-rank input.
1 parent 2cbf391 commit 55c64c3

File tree

1 file changed: +22 −15 lines changed


src/layers/basic.jl

````diff
@@ -654,25 +654,32 @@ The input to the layer can be a vocabulary index in `1:in`, an array of indices,
 or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
 
 For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
-For one-hot `x`, the result is of size `(out, size(x)[2:end]...)`.
+For one-hot `ohx`, the result is of size `(out, size(ohx)[2:end]...)`.
 
 # Examples
 ```jldoctest
-julia> vocab_size, embed_size = 1000, 4;
-
-julia> model = Flux.Embedding(vocab_size => embed_size)
-Embedding(1000 => 4)  # 4_000 parameters
-
-julia> vocab_idxs = [1, 722, 53, 220, 3];
-
-julia> x = Flux.onehotbatch(vocab_idxs, 1:vocab_size); summary(x)
-"1000×5 OneHotMatrix(::Vector{UInt32}) with eltype Bool"
-
-julia> model(x) |> summary
-"4×5 Matrix{Float32}"
-
-julia> model(vocab_idxs) == model(x)
+julia> emb = Embedding(26 => 4, init=Flux.identity_init(gain=22))
+Embedding(26 => 4)  # 104 parameters
+
+julia> emb(2)  # one column of emb.weight (here not random!)
+4-element Vector{Float32}:
+  0.0
+ 22.0
+  0.0
+  0.0
+
+julia> emb([3, 1, 20, 14, 4, 15, 7])  # vocabulary indices, in 1:26
+4×7 Matrix{Float32}:
+  0.0  22.0  0.0  0.0   0.0  0.0  0.0
+  0.0   0.0  0.0  0.0   0.0  0.0  0.0
+ 22.0   0.0  0.0  0.0   0.0  0.0  0.0
+  0.0   0.0  0.0  0.0  22.0  0.0  0.0
+
+julia> ans == emb(Flux.onehotbatch("cat&dog", 'a':'z', 'n'))
 true
+
+julia> emb(rand(1:26, (10, 1, 12))) |> size  # three batch dimensions
+(4, 10, 1, 12)
 ```
 """
 struct Embedding{W}
````
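For readers skimming the diff: the new doctest works because `Flux.onehotbatch` translates arbitrary labels (here characters, with `'n'` supplied as the default for labels outside `'a':'z'`) into one-hot columns that `Embedding` can consume directly. A minimal sketch of that equivalence, assuming a recent Flux where `onehotbatch` is available under the `Flux` namespace (`letters` and `codes` are illustrative names, not part of the commit):

```julia
using Flux

# An Embedding is column selection from its weight matrix, so labels that
# are not already integers in 1:n must first be mapped to indices -- either
# by hand, or via onehotbatch.
emb = Flux.Embedding(26 => 4)             # one column per letter in 'a':'z'
letters = collect("flux")                 # labels that are not in 1:26
codes = [c - 'a' + 1 for c in letters]    # manual encoding: [6, 12, 21, 24]

# onehotbatch does the label -> index translation automatically:
emb(Flux.onehotbatch(letters, 'a':'z')) == emb(codes)  # true
```

The weight matrix here is random, which is why the commit switches the doctest to `identity_init`: with a deterministic weight matrix the printed output is stable, so the full vectors can appear in the docstring.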
