Skip to content

Commit b7c588a

Browse files
committed
update Embedding constructor
Updated Embedding constructor to use `=>` and added OneHotLikeVector and OneHotLikeMatrix consts.
1 parent 1c61da8 commit b7c588a

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

src/layers/basic.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ function Base.show(io::IO, m::Parallel)
431431
end
432432

433433
"""
434-
Embedding(in, out; init=randn)
434+
Embedding(in => out; init=randn)
435435
436436
A lookup table that stores embeddings of dimension `out`
437437
for a vocabulary of size `in`.
@@ -466,7 +466,7 @@ end
466466

467467
@functor Embedding
468468

469-
Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
469+
Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(last(dims), first(dims)))
470470

471471
(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
472472
(m::Embedding)(x::Integer) = m([x])

src/onehot.jl

+3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ const OneHotLike{T, L, N, var"N+1", I} =
4646
Union{OneHotArray{T, L, N, var"N+1", I},
4747
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
4848

49+
const OneHotLikeVector{T, L} = OneHotLike{T, L, 0, 1, T}
50+
const OneHotLikeMatrix{T, L, I} = OneHotLike{T, L, 1, 2, I}
51+
4952
_isonehot(x::OneHotArray) = true
5053
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
5154

test/cuda/layers.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ end
262262

263263
@testset "Embedding" begin
264264
vocab_size, embed_size = 5, 2
265-
m = Flux.Embedding(vocab_size, embed_size)
266-
265+
m = Flux.Embedding(vocab_size => embed_size)
266+
267267
x = [1, 3, 5]
268268
y = m(x)
269269
m_g = m |> gpu

test/layers/basic.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ import Flux: activations
194194

195195
@testset "Embedding" begin
196196
vocab_size, embed_size = 10, 4
197-
m = Flux.Embedding(vocab_size, embed_size)
197+
m = Flux.Embedding(vocab_size => embed_size)
198198
@test size(m.weight) == (embed_size, vocab_size)
199199

200200
x = rand(1:vocab_size, 3)

0 commit comments

Comments
 (0)