Skip to content

Commit 4532ec6

Browse files
committed
update Embedding constructor
Updated Embedding constructor to use `=>` and added OneHotLikeVector and OneHotLikeMatrix consts.
1 parent 0ae379e commit 4532ec6

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

src/layers/basic.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ function Base.show(io::IO, m::Parallel)
374374
end
375375

376376
"""
377-
Embedding(in, out; init=randn)
377+
Embedding(in => out; init=randn)
378378
379379
A lookup table that stores embeddings of dimension `out`
380380
for a vocabulary of size `in`.
@@ -409,9 +409,9 @@ end
409409

410410
@functor Embedding
411411

412-
function Embedding(in::Integer, out::Integer;
412+
function Embedding(dims::Pair{<:Integer, <:Integer};
413413
init = (i...) -> randn(Float32, i...))
414-
return Embedding(init(out, in))
414+
return Embedding(init(last(dims), first(dims)))
415415
end
416416

417417
(m::Embedding)(x::Union{OneHotLikeVector, OneHotLikeMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]

src/onehot.jl

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ const OneHotLike{T, L, N, var"N+1", I} =
2424
Union{OneHotArray{T, L, N, var"N+1", I},
2525
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
2626

27+
const OneHotLikeVector{T, L} = OneHotLike{T, L, 0, 1, T}
28+
const OneHotLikeMatrix{T, L, I} = OneHotLike{T, L, 1, 2, I}
29+
2730
_isonehot(x::OneHotArray) = true
2831
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
2932

test/cuda/layers.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ end
221221

222222
@testset "Embedding" begin
223223
vocab_size, embed_size = 10, 4
224-
m = Embedding(vocab_size, embed_size)
224+
m = Embedding(vocab_size => embed_size)
225225
x = rand(1:vocab_size, 3)
226226
y = m(x)
227227
m_g = m |> gpu

test/layers/basic.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ import Flux: activations
156156

157157
@testset "Embedding" begin
158158
vocab_size, embed_size = 10, 4
159-
m = Embedding(vocab_size, embed_size)
159+
m = Embedding(vocab_size => embed_size)
160160
@test size(m.weight) == (embed_size, vocab_size)
161161

162162
x = rand(1:vocab_size, 3)

0 commit comments

Comments
 (0)