1660: Printing & docstrings for `onehot` / `onehotbatch` r=mcabbott a=mcabbott
Right now, the printing of OneHotArray lies about its type parameters. That's pretty confusing. The standard way to hide messy details the way `view` and `adjoint` do is via `Base.showarg`, so I did that. Then I also re-used the dots which LinearAlgebra's sparse matrix printing uses:
```
julia> Flux.onehotbatch(collect("foo"), 'a':'z') # before
26×3 Flux.OneHotArray{26,2,Vector{UInt32}}:
0 0 0
0 0 0
0 0 0
0 0 0
0 0 0
1 0 0
0 0 0
...
julia> typeof(ans)
Flux.OneHotArray{UInt32, 26, 1, 2, Vector{UInt32}}
julia> Flux.onehotbatch(collect("foo"), 'a':'z') # after
26×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
⋅ ⋅ ⋅
⋅ ⋅ ⋅
⋅ ⋅ ⋅
⋅ ⋅ ⋅
⋅ ⋅ ⋅
1 ⋅ ⋅
⋅ ⋅ ⋅
```
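For readers unfamiliar with the mechanism, the following is a minimal sketch of how overloading `Base.showarg` changes an array's summary line. `ToyOneHot` is a made-up type for illustration, not Flux's `OneHotArray`. `Base.replace_with_centered_mark` is an unexported Base helper, so relying on it is an assumption about internals rather than a stable API.

```julia
# Toy one-hot vector: stores only the index of the hot entry and the length.
# `ToyOneHot` is a hypothetical name, not part of Flux.
struct ToyOneHot{I<:Integer} <: AbstractVector{Bool}
    index::I
    len::Int
end

Base.size(v::ToyOneHot) = (v.len,)

function Base.getindex(v::ToyOneHot, i::Int)
    @boundscheck checkbounds(v, i)
    return i == v.index
end

# The array summary line is built by calling `showarg(io, x, true)`, so
# overloading it replaces the full parametric type name in the header.
function Base.showarg(io::IO, v::ToyOneHot, toplevel::Bool)
    print(io, "ToyOneHot(::", typeof(v.index), ")")
    toplevel && print(io, " with eltype Bool")
end

# Print structural zeros as centered dots (assumption: uses an unexported
# Base helper, as LinearAlgebra's structured matrices do).
Base.replace_in_print_matrix(v::ToyOneHot, i::Integer, j::Integer, s::AbstractString) =
    v[i] ? s : Base.replace_with_centered_mark(s)

v = ToyOneHot(UInt32(2), 3)
println(summary(v))   # header line only, without the entries
```

The struct keeps its real type parameters; only the printed header hides them, which is the same approach `view` and `adjoint` take.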
I've also tried to tidy up things I thought were unclear in the docstrings. E.g. it looked like `unk...` was indicating that you could specify multiple defaults, but in fact the splat is just an implementation trick. And following `Base.get`, perhaps it's better called `default`.
Should have no functional changes at all.
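To illustrate the point about the splat, here is a hedged sketch of a lookup that accepts zero or one fallback value through a trailing `default...`. The function `toy_index` is hypothetical and is not Flux's implementation; it only shows why the splat does not mean "several defaults".

```julia
# Hypothetical helper (not Flux code): position of `x` in `labels`,
# falling back to the position of `default` when one is supplied.
function toy_index(x, labels, default...)
    i = findfirst(isequal(x), labels)
    i !== nothing && return i
    # The splat only distinguishes "no default" from "one default".
    isempty(default) && error("value $x is not in labels")
    length(default) == 1 || throw(ArgumentError("at most one default is accepted"))
    # Look up the default itself; this errors if it is also missing.
    return toy_index(default[1], labels)
end

labels = [:a, :b, :c]
toy_index(:b, labels)        # found directly
toy_index(:z, labels, :a)    # not found, so the default :a is looked up
```

Calling `toy_index(:z, labels)` with no default throws an error, mirroring how `Base.get` requires an explicit fallback.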
Co-authored-by: Michael Abbott <[email protected]>
docs/src/data/onehot.md (+16 -16)
@@ -6,15 +6,15 @@ It's common to encode categorical variables (like `true`, `false` or `cat`, `dog
 julia> using Flux: onehot, onecold

 julia> onehot(:b, [:a, :b, :c])
-3-element Flux.OneHotVector{3,UInt32}:
- 0
+3-element OneHotVector(::UInt32) with eltype Bool:
+ ⋅
  1
- 0
+ ⋅

 julia> onehot(:c, [:a, :b, :c])
-3-element Flux.OneHotVector{3,UInt32}:
- 0
- 0
+3-element OneHotVector(::UInt32) with eltype Bool:
+ ⋅
+ ⋅
  1
 ```

@@ -44,16 +44,16 @@ Flux.onecold
 julia> using Flux: onehotbatch

 julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
-3×3 Flux.OneHotArray{3,2,Vector{UInt32}}:
- 0 1 0
- 1 0 1
- 0 0 0
-
-julia> onecold(ans, [:a, :b, :c])
-3-element Vector{Symbol}:
- :b
- :a
- :b
+3×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
+ ⋅ 1 ⋅
+ 1 ⋅ 1
+ ⋅ ⋅ ⋅
+
+julia> onecold(ans, [:a, :b, :c])
+3-element Vector{Symbol}:
+ :b
+ :a
+ :b
 ```

Note that these operations returned `OneHotVector` and `OneHotMatrix` rather than `Array`s. `OneHotVector`s behave like normal vectors but avoid any unnecessary cost compared to using an integer index directly. For example, multiplying a matrix with a one-hot vector simply slices out the relevant column of the matrix under the hood.
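The slicing claim can be checked in plain Julia without Flux. This is a small sketch using a dense `Bool` vector as a stand-in for a one-hot vector; the names `W`, `v`, and `i` are illustrative. For the product `W * v`, the slice that results is `W[:, i]`.

```julia
W = [1 2 3;
     4 5 6]              # 2×3 matrix

i = 2
v = zeros(Bool, 3)       # dense stand-in for a one-hot vector of length 3
v[i] = true              # hot entry at position i

# The full product and the direct slice agree.
@assert W * v == W[:, i]
```

A dedicated one-hot type can skip the multiplication entirely and return the slice, which is where the saving comes from.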