
Commit 9931730

Merge #1516

1516: add Embedding layer r=DhairyaLGandhi a=CarloLucibello

Basic implementation. Maybe could be improved when FluxML/NNlib.jl#255 lands.

### PR Checklist
- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
2 parents 1a0b519 + 397cabd commit 9931730
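
For orientation, here is a minimal usage sketch of the layer this PR adds. It is a hedged example written against the constructor and call methods shown in the `src/layers/basic.jl` diff below (the names `emb`, `idxs`, and `x` are illustrative, not part of the commit):

```julia
using Flux

# Embedding(vocab_size, embed_size) stores one embed_size-dimensional column per token.
emb = Flux.Embedding(10, 4)          # weight is a 4×10 Float32 matrix
@assert size(emb.weight) == (4, 10)

idxs = [1, 5, 5, 2]                  # integer vocabulary indices
@assert size(emb(idxs)) == (4, 4)    # one column per index, looked up via NNlib.gather

x = Flux.OneHotMatrix(idxs, 10)      # the onehot encoding gives the same result
@assert emb(x) == emb(idxs)
```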

13 files changed (+162 -42 lines)

NEWS.md (+9)

@@ -1,5 +1,14 @@
 # Flux Release Notes
 
+## v0.12.4
+* Implemented an [`Embedding layer`](https://github.com/FluxML/Flux.jl/pull/1516)
+  based on `NNlib.gather` and `NNlib.scatter`.
+
+## v0.12.1 - v0.12.3
+
+* CUDA.jl 3.0 support
+* Bug fixes and optimizations.
+
 ## v0.12.0
 
 * Add [identity_init](https://github.com/FluxML/Flux.jl/pull/1524).

Project.toml (+2 -2)

@@ -1,6 +1,6 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.12.4"
+version = "0.12.5"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -37,7 +37,7 @@ Colors = "0.12"
 Functors = "0.2.1"
 Juno = "0.8"
 MacroTools = "0.5"
-NNlib = "0.7.14"
+NNlib = "0.7.24"
 NNlibCUDA = "0.1"
 Reexport = "0.2, 1.0"
 StatsBase = "0.33"

docs/src/gpu.md (+1 -1)

@@ -30,7 +30,7 @@ If you define a structured model, like a `Dense` layer or `Chain`, you just need
 ```julia
 d = Dense(10, 5, σ)
 d = fmap(cu, d)
-d.W # CuArray
+d.weight # CuArray
 d(cu(rand(10))) # CuArray output
 
 m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

docs/src/models/advanced.md (+1 -1)

@@ -68,7 +68,7 @@ by simply deleting it from `ps`:
 
 ```julia
 ps = params(m)
-delete!(ps, m[2].b)
+delete!(ps, m[2].bias)
 ```
 
 ## Custom multiple input or output layer
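
As a quick illustration of the renamed field in that snippet (a hedged sketch, not part of the diff): deleting `m[2].bias` from the `Params` collection keeps that bias fixed during training, and membership can be checked with `in`:

```julia
using Flux

m = Chain(Dense(10, 5), Dense(5, 2))
ps = Flux.params(m)          # weights and biases of both Dense layers
delete!(ps, m[2].bias)       # exclude the second layer's bias from training
@assert !(m[2].bias in ps)   # no longer among the trainable parameters
@assert m[2].weight in ps    # the weight is still trained
```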

docs/src/models/layers.md (+1)

@@ -58,6 +58,7 @@ SkipConnection
 Parallel
 Flux.Bilinear
 Flux.Diagonal
+Flux.Embedding
 ```
 
 ## Normalisation & Regularisation

docs/src/models/nnlib.md (+7)

@@ -67,3 +67,10 @@ NNlib.batched_mul!
 NNlib.batched_adjoint
 NNlib.batched_transpose
 ```
+
+## Gather and Scatter
+
+```@docs
+NNlib.gather
+NNlib.scatter
+```
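
Since the new layer's forward pass is `NNlib.gather` (with `NNlib.scatter` accumulating the gradient, per the NEWS entry), a brief sketch of what `gather` does on a matrix of column embeddings may help. This example is illustrative and not part of the diff:

```julia
using NNlib

W = Float32[1 2 3; 4 5 6]         # 2×3 table: one column per vocabulary index

# gather slices `src` along its last dimension at the given indices,
# so for a matrix it selects columns — exactly W[:, [3, 1, 1]] here.
Y = NNlib.gather(W, [3, 1, 1])
@assert Y == Float32[3 1 1; 6 4 4]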

docs/src/models/overview.md (+19 -19)

@@ -15,7 +15,7 @@ Here's how you'd use Flux to build and train the most basic of models, step by s
 
 This example will predict the output of the function `4x + 2`. First, import `Flux` and define the function we want to simulate:
 
-```
+```julia
 julia> using Flux
 
 julia> actual(x) = 4x + 2
@@ -28,7 +28,7 @@ This example will build a model to approximate the `actual` function.
 
 Use the `actual` function to build sets of data for training and verification:
 
-```
+```julia
 julia> x_train, x_test = hcat(0:5...), hcat(6:10...)
 ([0 1 4 5], [6 7 9 10])
 
@@ -42,38 +42,38 @@ Normally, your training and test data come from real world observations, but thi
 
 Now, build a model to make predictions with `1` input and `1` output:
 
-```
+```julia
 julia> model = Dense(1, 1)
 Dense(1, 1)
 
-julia> model.W
-1-element Array{Float64,1}:
- -0.99009055
+julia> model.weight
+1×1 Matrix{Float32}:
+ -1.4925033
 
-julia> model.b
-1-element Array{Float64,1}:
+julia> model.bias
+1-element Vector{Float32}:
 0.0
 ```
 
-Under the hood, a dense layer is a struct with fields `W` and `b`. `W` represents a weight and `b` represents a bias. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:
+Under the hood, a dense layer is a struct with fields `weight` and `bias`. `weight` represents a weights' matrix and `bias` represents a bias vector. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:
 
-```
+```julia
 julia> predict = Dense(1, 1)
 ```
 
 `Dense(1, 1)` also implements the function `σ(Wx+b)` where `W` and `b` are the weights and biases. `σ` is an activation function (more on activations later). Our model has one weight and one bias, but typical models will have many more. Think of weights and biases as knobs and levers Flux can use to tune predictions. Activation functions are transformations that tailor models to your needs.
 
 This model will already make predictions, though not accurate ones yet:
 
-```
+```julia
 julia> predict(x_train)
-1×6 Array{Float32,2}:
- -1.98018  -5.94054  -9.90091  -13.8613  -17.8216  -21.782
+1×6 Matrix{Float32}:
+ 0.0  -1.4925  -2.98501  -4.47751  -5.97001  -7.46252
 ```
 
 In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.
 
-```
+```julia
 julia> loss(x, y) = Flux.Losses.mse(predict(x), y)
 loss (generic function with 1 method)
 
@@ -87,7 +87,7 @@ More accurate predictions will yield a lower loss. You can write your own loss f
 
 Under the hood, the Flux [`train!`](@ref) function uses *a loss function* and *training data* to improve the *parameters* of your model based on a pluggable [`optimiser`](../training/optimisers.md):
 
-```
+```julia
 julia> using Flux: train!
 
 julia> opt = Descent()
@@ -100,12 +100,12 @@ julia> data = [(x_train, y_train)]
 
 Now, we have the optimiser and data we'll pass to `train!`. All that remains are the parameters of the model. Remember, each model is a Julia struct with a function and configurable parameters. Remember, the dense layer has weights and biases that depend on the dimensions of the inputs and outputs:
 
-```
-julia> predict.W
+```julia
+julia> predict.weight
 1-element Array{Float64,1}:
 -0.99009055
 
-julia> predict.b
+julia> predict.bias
 1-element Array{Float64,1}:
 0.0
 ```
@@ -120,7 +120,7 @@ Params([[-0.99009055], [0.0]])
 These are the parameters Flux will change, one step at a time, to improve predictions. Each of the parameters comes from the `predict` model:
 
 ```
-julia> predict.W in parameters, predict.b in parameters
+julia> predict.weight in parameters, predict.bias in parameters
 (true, true)
 
 ```

docs/src/models/regularisation.md (+2 -2)

@@ -13,10 +13,10 @@ m = Dense(10, 5)
 loss(x, y) = logitcrossentropy(m(x), y)
 ```
 
-We can apply L2 regularisation by taking the squared norm of the parameters , `m.W` and `m.b`.
+We can apply L2 regularisation by taking the squared norm of the parameters , `m.weight` and `m.bias`.
 
 ```julia
-penalty() = sum(abs2, m.W) + sum(abs2, m.b)
+penalty() = sum(abs2, m.weight) + sum(abs2, m.bias)
 loss(x, y) = logitcrossentropy(m(x), y) + penalty()
 ```

src/layers/basic.jl (+58)

@@ -8,6 +8,7 @@ on a given input.
 `m[1:3](x)` will calculate the output of the first three layers.
 
 # Examples
+
 ```jldoctest
 julia> m = Chain(x -> x^2, x -> x+1);
 
@@ -428,3 +429,60 @@ function Base.show(io::IO, m::Parallel)
   join(io, m.layers, ", ")
   print(io, ")")
 end
+
+"""
+    Embedding(in, out; init=randn)
+
+A lookup table that stores embeddings of dimension `out`
+for a vocabulary of size `in`.
+
+This layer is often used to store word embeddings and retrieve them using indices.
+The input to the layer can be either a vector of indexes
+or the corresponding [onehot encoding](@ref Flux.OneHotArray).
+
+# Examples
+
+```julia-repl
+julia> using Flux: Embedding
+
+julia> vocab_size, embed_size = 1000, 4;
+
+julia> model = Embedding(vocab_size, embed_size)
+Embedding(1000, 4)
+
+julia> vocab_idxs = [1, 722, 53, 220, 3];
+
+julia> x = OneHotMatrix(vocab_idxs, vocab_size);
+
+julia> model(x)
+4×5 Matrix{Float32}:
+  0.91139    0.670462   0.463217   0.670462   0.110932
+  0.247225  -0.0823874  0.698694  -0.0823874  0.945958
+ -0.393626  -0.590136  -0.545422  -0.590136   0.77743
+ -0.497621   0.87595   -0.870251   0.87595   -0.772696
+
+julia> model(vocab_idxs) == model(x)
+true
+```
+"""
+struct Embedding{W}
+  weight::W
+end
+
+@functor Embedding
+
+Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
+
+(m::Embedding)(x::Integer) = m.weight[:, x]
+(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
+(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
+
+function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
+  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
+  return m(onecold(x))
+end
+
+function Base.show(io::IO, m::Embedding)
+  print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
+end
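
A small sketch of how the `AbstractArray` method above behaves for batched index arrays, using the definitions in this diff (the names `emb` and `batch` are illustrative): it flattens the indices, gathers, and reshapes so the embedding dimension comes first.

```julia
using Flux

emb = Flux.Embedding(10, 4)

batch = [1 2 3; 4 5 6]         # 2×3 array of vocabulary indices
y = emb(batch)                 # reshape(emb(vec(batch)), :, 2, 3)
@assert size(y) == (4, 2, 3)   # embedding dimension first, then the index dims
@assert y[:, 1, 2] == emb.weight[:, batch[1, 2]]
```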

src/utils.jl (+3 -1)

@@ -15,7 +15,7 @@ This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@r
 ```jldoctest
 julia> layer = Dense(10, 20);
 
-julia> Flux.nfan(size(layer.W))
+julia> Flux.nfan(size(layer.weight))
 (10, 20)
 
 julia> layer = Conv((3, 3), 2=>10);
@@ -368,6 +368,8 @@ identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identit
 
 ones32(dims...) = Base.ones(Float32, dims...)
 zeros32(dims...) = Base.zeros(Float32, dims...)
+rand32(dims...) = Base.rand(Float32, dims...)
+randn32(dims...) = Base.randn(Float32, dims...)
 
 """
     create_bias(weights, bias, length)
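
The two helpers added here mirror `ones32`/`zeros32` and give the `Embedding` constructor its Float32 default init. A quick illustrative check (not part of the diff):

```julia
using Flux

W = Flux.randn32(4, 10)    # standard-normal entries, eltype Float32
@assert eltype(W) == Float32 && size(W) == (4, 10)

# Embedding(10, 4) calls init(out, in), i.e. randn32(4, 10), by default.
emb = Flux.Embedding(10, 4; init = Flux.randn32)
@assert eltype(emb.weight) == Float32
```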

test/cuda/layers.jl (+23 -6)

@@ -51,13 +51,21 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
   # test
   if test_cpu
     @test y_gpu ≈ y_cpu rtol=1f-3 atol=1f-3
-    @test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
+    if isnothing(xg_cpu)
+      @test isnothing(xg_gpu)
+    else
+      @test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
+    end
   end
   @test gs_gpu isa Flux.Zygote.Grads
   for (p_cpu, p_gpu) in zip(ps_cpu, ps_gpu)
-    @test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
-    if test_cpu
-      @test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
+    if isnothing(gs_cpu[p_cpu])
+      @test isnothing(gs_gpu[p_gpu])
+    else
+      @test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
+      if test_cpu
+        @test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
+      end
     end
   end
 end
@@ -114,6 +122,15 @@ pixelshuffle = [PixelShuffle]
 gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
 gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
 
+embedding = [Flux.Embedding]
+gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
+gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
+gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
+gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
+gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
+gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
+gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
+
 @testset "function layers" begin
   x = rand(Float32, 3,3)
   gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
@@ -135,12 +152,12 @@ end
 end
 
 @testset "Dense with Zeros bias" begin
-  l = Dense(ones(Float32, 4,3), Flux.Zeros()) |> gpu
+  l = Dense(ones(Float32, 4, 3), Flux.Zeros()) |> gpu
  ip = zeros(Float32, 3, 7) |> gpu
 
   @test sum(l(ip)) ≈ 0.f0
   gs = gradient(() -> sum(l(ip)), Flux.params(l))
-  @test l.b ∉ gs.params
+  @test l.bias ∉ gs.params
 end
 
 @testset "Extended BatchNorm" begin

test/layers/basic.jl (+25)

@@ -191,4 +191,29 @@ import Flux: activations
       @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
     end
   end
+
+  @testset "Embedding" begin
+    vocab_size, embed_size = 10, 4
+    m = Flux.Embedding(vocab_size, embed_size)
+    @test size(m.weight) == (embed_size, vocab_size)
+
+    x = rand(1:vocab_size, 3)
+    y = m(x)
+    @test y isa Matrix{Float32}
+    @test y ≈ m.weight[:,x]
+    x2 = OneHotMatrix(x, vocab_size)
+    y2 = m(x2)
+    @test y2 isa Matrix{Float32}
+    @test y2 ≈ y
+    @test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
+
+    x = rand(1:vocab_size, 3, 4)
+    y = m(x)
+    @test y isa Array{Float32, 3}
+    @test size(y) == (embed_size, 3, 4)
+
+    @test m(2) ≈ m.weight[:,2]
+    @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
+    @test_throws DimensionMismatch m(OneHotVector(3, 1000))
+  end
 end
