Commit 9553267

use gather; fix outdated docs
Co-authored-by: Manikya <[email protected]>
1 parent 4d3944c commit 9553267

10 files changed: +72 -64 lines

Manifest.toml

+28-28
@@ -51,15 +51,15 @@ version = "3.2.1"
 
 [[ChainRules]]
 deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "720fa9a9ce61ff18842a40f501d6a1f8ba771c64"
+git-tree-sha1 = "85c579fa131b5545eef874a5b413bb3b783e21c6"
 uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "0.8.6"
+version = "0.8.21"
 
 [[ChainRulesCore]]
 deps = ["Compat", "LinearAlgebra", "SparseArrays"]
-git-tree-sha1 = "8b31cc69cbc38c5c826aaa1c890c694be3622d99"
+git-tree-sha1 = "dcc25ff085cf548bc8befad5ce048391a7c07d40"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "0.10.3"
+version = "0.10.11"
 
 [[CodecZlib]]
 deps = ["TranscodingStreams", "Zlib_jll"]
@@ -87,18 +87,18 @@ version = "0.3.0"
 
 [[Compat]]
 deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
+git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "3.30.0"
+version = "3.31.0"
 
 [[CompilerSupportLibraries_jll]]
 deps = ["Artifacts", "Libdl"]
 uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
 
 [[DataAPI]]
-git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d"
+git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385"
 uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
-version = "1.6.0"
+version = "1.7.0"
 
 [[DataStructures]]
 deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -141,15 +141,15 @@ deps = ["ArgTools", "LibCURL", "NetworkOptions"]
 uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
 
 [[ExprTools]]
-git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e"
+git-tree-sha1 = "b7e3d17636b348f005f11040025ae8c6f645fe92"
 uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
-version = "0.1.3"
+version = "0.1.6"
 
 [[FillArrays]]
 deps = ["LinearAlgebra", "Random", "SparseArrays"]
-git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
+git-tree-sha1 = "693210145367e7685d8604aee33d9bfb85db8b31"
 uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
-version = "0.11.7"
+version = "0.11.9"
 
 [[FixedPointNumbers]]
 deps = ["Statistics"]
@@ -183,9 +183,9 @@ version = "0.11.5"
 
 [[IRTools]]
 deps = ["InteractiveUtils", "MacroTools", "Test"]
-git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510"
+git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c"
 uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
-version = "0.4.2"
+version = "0.4.3"
 
 [[IfElse]]
 git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
@@ -210,9 +210,9 @@ version = "0.8.4"
 
 [[LLVM]]
 deps = ["CEnum", "Libdl", "Printf", "Unicode"]
-git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691"
+git-tree-sha1 = "f57ac3fd2045b50d3db081663837ac5b4096947e"
 uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
-version = "3.7.1"
+version = "3.9.0"
 
 [[LazyArtifacts]]
 deps = ["Artifacts", "Pkg"]
@@ -243,9 +243,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 
 [[LogExpFunctions]]
 deps = ["DocStringExtensions", "LinearAlgebra"]
-git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885"
+git-tree-sha1 = "7bd5f6565d80b6bf753738d2bc40a5dfea072070"
 uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
-version = "0.2.4"
+version = "0.2.5"
 
 [[Logging]]
 uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -290,9 +290,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
 
 [[NNlib]]
 deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
-git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c"
+git-tree-sha1 = "7e6f31cfa39b1ff1c541cc8580b14b0ff4ba22d0"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.7.21"
+version = "0.7.23"
 
 [[NNlibCUDA]]
 deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
@@ -347,9 +347,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [[Random123]]
 deps = ["Libdl", "Random", "RandomNumbers"]
-git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5"
+git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3"
 uuid = "74087812-796a-5b5d-8853-05524746bad3"
-version = "1.3.1"
+version = "1.4.2"
 
 [[RandomNumbers]]
 deps = ["Random", "Requires"]
@@ -389,9 +389,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
 
 [[SortingAlgorithms]]
 deps = ["DataStructures"]
-git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96"
+git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508"
 uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
-version = "1.0.0"
+version = "1.0.1"
 
 [[SparseArrays]]
 deps = ["LinearAlgebra", "Random"]
@@ -411,9 +411,9 @@ version = "0.2.5"
 
 [[StaticArrays]]
 deps = ["LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668"
+git-tree-sha1 = "a43a7b58a6e7dc933b2fa2e0ca653ccf8bb8fd0e"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.2.2"
+version = "1.2.6"
 
 [[Statistics]]
 deps = ["LinearAlgebra", "SparseArrays"]
@@ -444,9 +444,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [[TimerOutputs]]
 deps = ["ExprTools", "Printf"]
-git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
+git-tree-sha1 = "209a8326c4f955e2442c07b56029e88bb48299c7"
 uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.9"
+version = "0.5.12"
 
 [[TranscodingStreams]]
 deps = ["Random", "Test"]

docs/src/gpu.md

+1-1
@@ -30,7 +30,7 @@ If you define a structured model, like a `Dense` layer or `Chain`, you just need
 ```julia
 d = Dense(10, 5, σ)
 d = fmap(cu, d)
-d.W # CuArray
+d.weight # CuArray
 d(cu(rand(10))) # CuArray output
 
 m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

docs/src/models/advanced.md

+1-1
@@ -68,7 +68,7 @@ by simply deleting it from `ps`:
 
 ```julia
 ps = params(m)
-delete!(ps, m[2].b)
+delete!(ps, m[2].bias)
 ```
 
 ## Custom multiple input or output layer

docs/src/models/nnlib.md

+7
@@ -67,3 +67,10 @@ NNlib.batched_mul!
 NNlib.batched_adjoint
 NNlib.batched_transpose
 ```
+
+## Gather and Scatter
+
+```@docs
+NNlib.gather
+NNlib.scatter
+```
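
For reference, a minimal sketch (not part of this commit) of what the newly documented functions do for a matrix source and integer indices, assuming the NNlib version pinned in the Manifest above:

```julia
using NNlib

W = Float32[1 2 3;
            4 5 6]                  # 2×3 source matrix

# gather selects slices along the last dimension:
# column k of the result is W[:, idx[k]].
NNlib.gather(W, [3, 1, 1])          # == W[:, [3, 1, 1]]

# scatter is the rough inverse: it accumulates columns of W into a
# destination at the given indices, combining collisions with `+`.
NNlib.scatter(+, W, [2, 2, 1])      # column 1 gets W[:, 3]; column 2 gets W[:, 1] + W[:, 2]
```

Roughly speaking, the gradient of a `gather` is a `scatter` (and vice versa), which is what makes the `Embedding` change further below efficient to differentiate.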

docs/src/models/overview.md

+19-19
@@ -15,7 +15,7 @@ Here's how you'd use Flux to build and train the most basic of models, step by s
 
 This example will predict the output of the function `4x + 2`. First, import `Flux` and define the function we want to simulate:
 
-```
+```julia
 julia> using Flux
 
 julia> actual(x) = 4x + 2
@@ -28,7 +28,7 @@ This example will build a model to approximate the `actual` function.
 
 Use the `actual` function to build sets of data for training and verification:
 
-```
+```julia
 julia> x_train, x_test = hcat(0:5...), hcat(6:10...)
 ([0 1 … 4 5], [6 7 … 9 10])
 
@@ -42,38 +42,38 @@ Normally, your training and test data come from real world observations, but thi
 
 Now, build a model to make predictions with `1` input and `1` output:
 
-```
+```julia
 julia> model = Dense(1, 1)
 Dense(1, 1)
 
-julia> model.W
-1-element Array{Float64,1}:
- -0.99009055
+julia> model.weight
+1×1 Matrix{Float32}:
+ -1.4925033
 
-julia> model.b
-1-element Array{Float64,1}:
+julia> model.bias
+1-element Vector{Float32}:
 0.0
 ```
 
-Under the hood, a dense layer is a struct with fields `W` and `b`. `W` represents a weight and `b` represents a bias. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:
+Under the hood, a dense layer is a struct with fields `weight` and `bias`. `weight` represents a weights' matrix and `bias` represents a bias vector. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:
 
-```
+```julia
 julia> predict = Dense(1, 1)
 ```
 
 `Dense(1, 1)` also implements the function `σ(Wx+b)` where `W` and `b` are the weights and biases. `σ` is an activation function (more on activations later). Our model has one weight and one bias, but typical models will have many more. Think of weights and biases as knobs and levers Flux can use to tune predictions. Activation functions are transformations that tailor models to your needs.
 
 This model will already make predictions, though not accurate ones yet:
 
-```
+```julia
 julia> predict(x_train)
-1×6 Array{Float32,2}:
- -1.98018 -5.94054 -9.90091 -13.8613 -17.8216 -21.782
+1×6 Matrix{Float32}:
+ 0.0 -1.4925 -2.98501 -4.47751 -5.97001 -7.46252
 ```
 
 In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.
 
-```
+```julia
 julia> loss(x, y) = Flux.Losses.mse(predict(x), y)
 loss (generic function with 1 method)
 
@@ -87,7 +87,7 @@ More accurate predictions will yield a lower loss. You can write your own loss f
 
 Under the hood, the Flux [`train!`](@ref) function uses *a loss function* and *training data* to improve the *parameters* of your model based on a pluggable [`optimiser`](../training/optimisers.md):
 
-```
+```julia
 julia> using Flux: train!
 
 julia> opt = Descent()
@@ -100,12 +100,12 @@ julia> data = [(x_train, y_train)]
 
 Now, we have the optimiser and data we'll pass to `train!`. All that remains are the parameters of the model. Remember, each model is a Julia struct with a function and configurable parameters. Remember, the dense layer has weights and biases that depend on the dimensions of the inputs and outputs:
 
-```
-julia> predict.W
+```julia
+julia> predict.weight
 1-element Array{Float64,1}:
 -0.99009055
 
-julia> predict.b
+julia> predict.bias
 1-element Array{Float64,1}:
 0.0
 ```
@@ -120,7 +120,7 @@ Params([[-0.99009055], [0.0]])
 These are the parameters Flux will change, one step at a time, to improve predictions. Each of the parameters comes from the `predict` model:
 
 ```
-julia> predict.W in parameters, predict.b in parameters
+julia> predict.weight in parameters, predict.bias in parameters
 (true, true)
 
 ```
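
Most of the edits above track the Flux 0.12 rename of the `Dense` fields `W`/`b` to `weight`/`bias`. A quick sketch of the renamed accessors (not part of this commit; shapes assume the default Float32 initialisation):

```julia
using Flux

model = Dense(1, 1)
size(model.weight)                  # (1, 1): the weight matrix
size(model.bias)                    # (1,):   the bias vector
model.weight in Flux.params(model)  # true
model.bias in Flux.params(model)    # true
```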

docs/src/models/regularisation.md

+2-2
@@ -13,10 +13,10 @@ m = Dense(10, 5)
 loss(x, y) = logitcrossentropy(m(x), y)
 ```
 
-We can apply L2 regularisation by taking the squared norm of the parameters , `m.W` and `m.b`.
+We can apply L2 regularisation by taking the squared norm of the parameters , `m.weight` and `m.bias`.
 
 ```julia
-penalty() = sum(abs2, m.W) + sum(abs2, m.b)
+penalty() = sum(abs2, m.weight) + sum(abs2, m.bias)
 loss(x, y) = logitcrossentropy(m(x), y) + penalty()
 ```

src/layers/basic.jl

+2-1
@@ -469,7 +469,8 @@ function Embedding(in::Integer, out::Integer;
 end
 
 (m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
-(m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
+(m::Embedding)(x::Integer) = m([x])
+(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
 (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
 
 function Base.show(io::IO, m::Embedding)
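
A rough sketch (not part of this commit) of what the rewritten `Embedding` forward pass computes, assuming the NNlib ≥ 0.7.23 pinned above:

```julia
using Flux, NNlib

emb = Flux.Embedding(10, 4)             # vocabulary of 10, 4-dimensional embeddings
x   = [3, 1, 3, 7]                      # token indices

# NNlib.gather picks columns of the weight matrix, like emb.weight[:, x],
# but through a kernel that also has GPU and gradient support.
emb(x) == NNlib.gather(emb.weight, x)   # true
emb(3) == emb([3])                      # true: an Integer index is wrapped in a vector
```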

src/utils.jl

+1-1
@@ -16,7 +16,7 @@ This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@r
 julia> layer = Dense(10, 20)
 Dense(10, 20)
 
-julia> Flux.nfan(size(layer.W))
+julia> Flux.nfan(size(layer.weight))
 (10, 20)
 
 julia> layer = Conv((3, 3), 2=>10)

test/cuda/layers.jl

+1-1
@@ -140,7 +140,7 @@ end
 
 @test sum(l(ip)) ≈ 0.f0
 gs = gradient(() -> sum(l(ip)), Flux.params(l))
-@test l.b ∉ gs.params
+@test l.bias ∉ gs.params
 end
 
 @testset "Extended BatchNorm" begin

test/utils.jl

+10-10
@@ -226,19 +226,19 @@ end
 m = Chain(Dense(10, 5, relu), Dense(5, 2))
 x64 = rand(Float64, 10)
 x32 = rand(Float32, 10)
-@test eltype(m[1].W) == Float32
+@test eltype(m[1].weight) == Float32
 @test eltype(m(x32)) == Float32
 @test eltype(m(x64)) == Float64
 @test eltype(f64(m)(x32)) == Float64
 @test eltype(f64(m)(x64)) == Float64
-@test eltype(f64(m)[1].W) == Float64
-@test eltype(f32(f64(m))[1].W) == Float32
+@test eltype(f64(m)[1].weight) == Float64
+@test eltype(f32(f64(m))[1].weight) == Float32
 end
 
 @testset "Zeros" begin
 m = Dense(3,2; bias=false)
-@test f64(m).b === m.b === Zeros()
-@test f32(m).b === m.b === Zeros()
+@test f64(m).bias === m.bias === Zeros()
+@test f32(m).bias === m.bias === Zeros()
 
 @testset "Gradients for broadcasted $op with sizes $s" for op in (+,-,*), s in ((1,), (2,3))
 o = ones(s)
@@ -340,19 +340,19 @@ end
 
 nobias(n) = Zeros()
 testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
-@test l1.W == l2.W
-@test l1.b == l2.b
-@test_skip typeof(l1.b) === typeof(l2.b)
+@test l1.weight == l2.weight
+@test l1.bias == l2.bias
+@test_skip typeof(l1.bias) === typeof(l2.bias)
 end
 
 @testset "loadparams!" begin
 import Flux: loadparams!
 pars(w, b) = [w, b]
 import Flux: loadparams!, Zeros
 pars(w, b::Zeros) = [w, Flux.zeros(size(w,1))]
-pars(l) = pars(l.W, l.b)
+pars(l) = pars(l.weight, l.bias)
 pararray(m) = mapreduce(pars, vcat, m)
-weights(m) = mapreduce(l -> [l.W], vcat, m)
+weights(m) = mapreduce(l -> [l.weight], vcat, m)
 @testset "Bias type $bt" for bt in (Flux.zeros, nobias)
 m = dm(bt)
 loadparams!(m, params(m))
