Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0eb7ed9

Browse files
committed Jul 14, 2021
update Embedding constructor
Updated Embedding constructor to use `=>` and added OneHotLikeVector and OneHotLikeMatrix consts.
1 parent: 6a7688a · commit: 0eb7ed9

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed
 

‎src/layers/basic.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ function Base.show(io::IO, m::Parallel)
431431
end
432432

433433
"""
434-
Embedding(in, out; init=randn)
434+
Embedding(in => out; init=randn)
435435
436436
A lookup table that stores embeddings of dimension `out`
437437
for a vocabulary of size `in`.
@@ -466,7 +466,7 @@ end
466466

467467
@functor Embedding
468468

469-
Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
469+
Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(last(dims), first(dims)))
470470

471471

472472
(m::Embedding)(x::Integer) = m.weight[:, x]

‎src/onehot.jl

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ const OneHotLike{T, L, N, var"N+1", I} =
3232
Union{OneHotArray{T, L, N, var"N+1", I},
3333
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
3434

35+
const OneHotLikeVector{T, L} = OneHotLike{T, L, 0, 1, T}
36+
const OneHotLikeMatrix{T, L, I} = OneHotLike{T, L, 1, 2, I}
37+
3538
_isonehot(x::OneHotArray) = true
3639
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
3740

‎test/cuda/layers.jl

+7-7
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,13 @@ gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
123123
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
124124

125125
embedding = [Flux.Embedding]
126-
gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
127-
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
128-
gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
129-
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
130-
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
131-
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
132-
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
126+
gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2)
127+
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5 => 2)
128+
gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2)
129+
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2)
130+
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2)
131+
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2)
132+
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5 => 2)
133133

134134
@testset "function layers" begin
135135
x = rand(Float32, 3,3)

‎test/layers/basic.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ import Flux: activations
194194

195195
@testset "Embedding" begin
196196
vocab_size, embed_size = 10, 4
197-
m = Flux.Embedding(vocab_size, embed_size)
197+
m = Flux.Embedding(vocab_size => embed_size)
198198
@test size(m.weight) == (embed_size, vocab_size)
199199

200200
x = rand(1:vocab_size, 3)

0 commit comments

Comments (0)
Please sign in to comment.