Skip to content

Commit 1c61da8

Browse files
manikyabardmcabbottdarsnack
committed
Apply suggestions from code review
Co-authored-by: Michael Abbott <[email protected]> Co-authored-by: Kyle Daruwalla <[email protected]>
1 parent 7cc77a0 commit 1c61da8

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

src/layers/basic.jl

+15-20
Original file line numberDiff line numberDiff line change
@@ -442,30 +442,25 @@ or the corresponding [onehot encoding](@ref Flux.OneHotArray).
442442
443443
# Examples
444444
445-
```julia-repl
446-
julia> using Flux: Embedding
447-
448-
julia> vocab_size, embed_size = 1000, 4;
449-
450-
julia> model = Embedding(vocab_size, embed_size)
451-
Embedding(1000, 4)
452-
453-
julia> vocab_idxs = [1, 722, 53, 220, 3]
445+
```jldoctest
446+
julia> m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0)
447+
Embedding(26 => 2)
454448
455-
julia> x = OneHotMatrix(vocab_idxs, vocab_size);
449+
julia> m(5) # embedding vector for 5th element
450+
2-element Vector{Float32}:
451+
2.01
452+
3.01
456453
457-
julia> model(x)
458-
4×5 Matrix{Float32}:
459-
0.91139 0.670462 0.463217 0.670462 0.110932
460-
0.247225 -0.0823874 0.698694 -0.0823874 0.945958
461-
-0.393626 -0.590136 -0.545422 -0.590136 0.77743
462-
-0.497621 0.87595 -0.870251 0.87595 -0.772696
463-
```
454+
julia> m([6, 15, 15]) # applied to a batch
455+
2×3 Matrix{Float32}:
456+
4.01 22.01 22.01
457+
5.01 23.01 23.01
464458
465-
julia> model(vocab_idxs) == model(x)
459+
julia> ans == m(Flux.OneHotMatrix([6, 15, 15], 26))
466460
true
461+
```
467462
"""
468-
struct Embedding{W}
463+
struct Embedding{W <: AbstractMatrix}
469464
weight::W
470465
end
471466

@@ -479,5 +474,5 @@ Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
479474
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
480475

481476
function Base.show(io::IO, m::Embedding)
482-
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
477+
print(io, "Embedding($(size(m.weight, 2)) => $(size(m.weight, 1)))")
483478
end

0 commit comments

Comments
 (0)