Commit 1a0b519

bors[bot] and mcabbott authored
Merge #1660
1660: Printing & docstrings for `onehot` / `onehotbatch` r=mcabbott a=mcabbott

Right now, the printing of `OneHotArray` lies about its type parameters. That's pretty confusing. The standard way to hide messy details, the way `view` and `adjoint` do, is via `Base.showarg`, so I did that. Then I also re-used the dots which LinearAlgebra's sparse matrix printing uses:

```
julia> Flux.onehotbatch(collect("foo"), 'a':'z')  # before
26×3 Flux.OneHotArray{26,2,Vector{UInt32}}:
 0  0  0
 0  0  0
 0  0  0
 0  0  0
 0  0  0
 1  0  0
 0  0  0
 ...

julia> typeof(ans)
Flux.OneHotArray{UInt32, 26, 1, 2, Vector{UInt32}}

julia> Flux.onehotbatch(collect("foo"), 'a':'z')  # after
26×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 1  ⋅  ⋅
 ⋅  ⋅  ⋅
```

I've also tried to tidy up things I thought were unclear in the docstrings. E.g. it looked as if `unk...` was indicating that you could specify multiple defaults, but in fact the splat is just an implementation trick. And, following `Base.get`, perhaps it's better called `default`.

Should have no functional changes at all.

Co-authored-by: Michael Abbott <[email protected]>
2 parents 1a14301 + 499cf27 commit 1a0b519
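The display style adopted in this commit (structural zeros shown as centered dots, stored ones as `1`) can be mimicked outside Julia. This toy Python renderer is an illustration only, not how `Base.replace_in_print_matrix` actually works; the name `render_onehot` is mine:

```python
# Toy rendering of a one-hot matrix in the style this commit adopts:
# structural zeros print as centered dots, the single stored 1 per column as "1".

def render_onehot(hot_rows, nrows, ncols):
    """hot_rows[k] is the hot row index of column k (the only data stored)."""
    lines = []
    for i in range(nrows):
        cells = ["1" if hot_rows[j] == i else "⋅" for j in range(ncols)]
        lines.append(" " + "  ".join(cells))
    return "\n".join(lines)

# Same data as onehotbatch([:b, :a, :b], [:a, :b, :c]) in the docs below.
print(render_onehot([1, 0, 1], nrows=3, ncols=3))
```

Note that, as in the real printing, only the hot indices are consulted; no dense Bool matrix is materialised.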

File tree

3 files changed: +123 −86 lines changed

docs/src/data/onehot.md
+16 −16

````diff
@@ -6,15 +6,15 @@ It's common to encode categorical variables (like `true`, `false` or `cat`, `dog
 julia> using Flux: onehot, onecold
 
 julia> onehot(:b, [:a, :b, :c])
-3-element Flux.OneHotVector{3,UInt32}:
- 0
+3-element OneHotVector(::UInt32) with eltype Bool:
+ ⋅
  1
- 0
+ ⋅
 
 julia> onehot(:c, [:a, :b, :c])
-3-element Flux.OneHotVector{3,UInt32}:
- 0
- 0
+3-element OneHotVector(::UInt32) with eltype Bool:
+ ⋅
+ ⋅
  1
 ```
 
@@ -44,16 +44,16 @@ Flux.onecold
 julia> using Flux: onehotbatch
 
 julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
-3×3 Flux.OneHotArray{3,2,Vector{UInt32}}:
- 0  1  0
- 1  0  1
- 0  0  0
-
-julia> onecold(ans, [:a, :b, :c])
-3-element Vector{Symbol}:
- :b
- :a
- :b
+3×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
+ ⋅  1  ⋅
+ 1  ⋅  1
+ ⋅  ⋅  ⋅
+
+julia> onecold(ans, [:a, :b, :c])
+3-element Vector{Symbol}:
+ :b
+ :a
+ :b
 ```
 
 Note that these operations returned `OneHotVector` and `OneHotMatrix` rather than `Array`s. `OneHotVector`s behave like normal vectors but avoid any unnecessary cost compared to using an integer index directly. For example, multiplying a matrix with a one-hot vector simply slices out the relevant row of the matrix under the hood.
````

src/losses/functions.jl
+14 −14

````diff
@@ -115,9 +115,9 @@ of label smoothing to binary distributions encoded in a single number.
 # Example
 ```jldoctest
 julia> y = Flux.onehotbatch([1, 1, 1, 0, 1, 0], 0:1)
-2×6 Flux.OneHotArray{2,2,Vector{UInt32}}:
- 0  0  0  1  0  1
- 1  1  1  0  1  0
+2×6 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
+ ⋅  ⋅  ⋅  1  ⋅  1
+ 1  1  1  ⋅  1  ⋅
 
 julia> y_smoothed = Flux.label_smoothing(y, 0.2f0)
 2×6 Matrix{Float32}:
@@ -180,10 +180,10 @@ See also: [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref), [`logitbina
 # Example
 ```jldoctest
 julia> y_label = Flux.onehotbatch([0, 1, 2, 1, 0], 0:2)
-3×5 Flux.OneHotArray{3,2,Vector{UInt32}}:
- 1  0  0  0  1
- 0  1  0  1  0
- 0  0  1  0  0
+3×5 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
+ 1  ⋅  ⋅  ⋅  1
+ ⋅  1  ⋅  1  ⋅
+ ⋅  ⋅  1  ⋅  ⋅
 
 julia> y_model = softmax(reshape(-7:7, 3, 5) .* 1f0)
 3×5 Matrix{Float32}:
@@ -232,10 +232,10 @@ See also: [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref), [`lab
 # Example
 ```jldoctest
 julia> y_label = Flux.onehotbatch(collect("abcabaa"), 'a':'c')
-3×7 Flux.OneHotArray{3,2,Vector{UInt32}}:
- 1  0  0  1  0  1  1
- 0  1  0  0  1  0  0
- 0  0  1  0  0  0  0
+3×7 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
+ 1  ⋅  ⋅  1  ⋅  1  1
+ ⋅  1  ⋅  ⋅  1  ⋅  ⋅
+ ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅
 
 julia> y_model = reshape(vcat(-9:0, 0:9, 7.5f0), 3, 7)
 3×7 Matrix{Float32}:
@@ -291,9 +291,9 @@ julia> all(p -> 0 < p < 1, y_prob[2,:]) # else DomainError
 true
 
 julia> y_hot = Flux.onehotbatch(y_bin, 0:1)
-2×3 Flux.OneHotArray{2,2,Vector{UInt32}}:
- 0  1  0
- 1  0  1
+2×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
+ ⋅  1  ⋅
+ 1  ⋅  1
 
 julia> Flux.crossentropy(y_prob, y_hot)
 0.43989f0
````
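As a side note on the doctests above: with one-hot targets, the cross-entropy sum collapses to picking out the log-probability of the hot class in each column. A rough Python sketch of that reduction (my own simplification, not Flux's actual implementation):

```python
import math

# With one-hot labels, crossentropy(probs, labels) averages the negative
# log-probability assigned to the true class of each column.

def crossentropy_onehot(columns, hot_indices):
    """columns: list of probability distributions; hot_indices: true class per column."""
    return -sum(math.log(col[i]) for col, i in zip(columns, hot_indices)) / len(columns)

cols = [[0.2, 0.8], [0.7, 0.3]]   # two columns over classes 0 and 1
hot = [1, 0]                      # true class per column
loss = crossentropy_onehot(cols, hot)
```

This is why storing only the hot indices (a `Vector{UInt32}`) loses nothing for this loss.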

src/onehot.jl
+93 −56

````diff
@@ -1,6 +1,12 @@
 import Adapt
 import .CUDA
 
+"""
+    OneHotArray{T,L,N,M,I} <: AbstractArray{Bool,M}
+
+These are constructed by [`onehot`](@ref) and [`onehotbatch`](@ref).
+Parameter `I` is the type of the underlying storage, and `T` its eltype.
+"""
 struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"}
   indices::I
 end
@@ -15,31 +21,11 @@ _indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) =
 const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
 const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}
 
+@doc @doc(OneHotArray)
 OneHotVector(idx, L) = OneHotArray(idx, L)
+@doc @doc(OneHotArray)
 OneHotMatrix(indices, L) = OneHotArray(indices, L)
 
-function _show_elements(x::OneHotArray)
-  xbool = convert(Array{Bool}, cpu(x))
-  xrepr = join(split(repr(MIME("text/plain"), xbool; context= :limit => true), "\n")[2:end], "\n")
-
-  return xrepr
-end
-
-function Base.show(io::IO, ::MIME"text/plain", x::OneHotArray{<:Any, L, <:Any, N, I}) where {L, N, I}
-  join(io, string.(size(x)), "×")
-  print(io, " Flux.OneHotArray{")
-  join(io, string.([L, N, I]), ",")
-  println(io, "}:")
-  print(io, _show_elements(x))
-end
-
-function Base.show(io::IO, ::MIME"text/plain", x::OneHotVector{T, L}) where {T, L}
-  print(io, string.(length(x)))
-  print(io, "-element Flux.OneHotVector{")
-  join(io, string.([L, T]), ",")
-  println(io, "}:")
-  print(io, _show_elements(x))
-end
-
 # use this type so reshaped arrays hit fast paths
 # e.g. argmax
 const OneHotLike{T, L, N, var"N+1", I} =
@@ -61,6 +47,25 @@ Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.i
 Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
 Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]
 
+function Base.showarg(io::IO, x::OneHotArray, toplevel)
+  print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")
+  Base.showarg(io, x.indices, false)
+  print(io, ')')
+  toplevel && print(io, " with eltype Bool")
+  return nothing
+end
+
+# this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots:
+function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString)
+  x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
+end
+
+# copy CuArray versions back before trying to print them:
+Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
+  Base.print_array(io, cpu(X))
+Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
+  Base.print_array(io, cpu(X))
+
 _onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
 _onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
 
@@ -95,72 +100,104 @@ Base.argmax(x::OneHotLike; dims = Colon()) =
   invoke(argmax, Tuple{AbstractArray}, x; dims = dims)
 
 """
-    onehot(l, labels[, unk])
+    onehot(x, labels, [default])
+
+Return a `OneHotVector` which is roughly a sparse representation of `x .== labels`.
 
-Return a `OneHotVector` where only first occourence of `l` in `labels` is `1` and
-all other elements are `0`.
+Instead of storing say `Vector{Bool}`, it stores the index of the first occurrence
+of `x` in `labels`. If `x` is not found in `labels`, then it either returns `onehot(default, labels)`,
+or gives an error if no default is given.
 
-If `l` is not found in labels and `unk` is present, the function returns
-`onehot(unk, labels)`; otherwise the function raises an error.
+See also [`onehotbatch`](@ref) to apply this to many `x`s,
+and [`onecold`](@ref) to reverse either of these, as well as to generalise `argmax`.
 
 # Examples
 ```jldoctest
-julia> Flux.onehot(:b, [:a, :b, :c])
-3-element Flux.OneHotVector{3,UInt32}:
- 0
+julia> β = Flux.onehot(:b, [:a, :b, :c])
+3-element OneHotVector(::UInt32) with eltype Bool:
+ ⋅
  1
- 0
+ ⋅
 
-julia> Flux.onehot(:c, [:a, :b, :c])
-3-element Flux.OneHotVector{3,UInt32}:
- 0
- 0
- 1
+julia> αβγ = (Flux.onehot(0, 0:2), β, Flux.onehot(:z, [:a, :b, :c], :c)) # uses default
+(Bool[1, 0, 0], Bool[0, 1, 0], Bool[0, 0, 1])
+
+julia> hcat(αβγ...) # preserves sparsity
+3×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
+ 1  ⋅  ⋅
+ ⋅  1  ⋅
+ ⋅  ⋅  1
 ```
 """
-function onehot(l, labels)
-  i = something(findfirst(isequal(l), labels), 0)
-  i > 0 || error("Value $l is not in labels")
+function onehot(x, labels)
+  i = something(findfirst(isequal(x), labels), 0)
+  i > 0 || error("Value $x is not in labels")
   OneHotVector{UInt32, length(labels)}(i)
 end
 
-function onehot(l, labels, unk)
-  i = something(findfirst(isequal(l), labels), 0)
-  i > 0 || return onehot(unk, labels)
+function onehot(x, labels, default)
+  i = something(findfirst(isequal(x), labels), 0)
+  i > 0 || return onehot(default, labels)
   OneHotVector{UInt32, length(labels)}(i)
 end
 
 """
-    onehotbatch(ls, labels[, unk...])
+    onehotbatch(xs, labels, [default])
+
+Returns a `OneHotMatrix` where `k`th column of the matrix is [`onehot(xs[k], labels)`](@ref onehot).
+This is a sparse matrix, which stores just a `Vector{UInt32}` containing the indices of the
+nonzero elements.
 
-Return a `OneHotMatrix` where `k`th column of the matrix is `onehot(ls[k], labels)`.
+If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)`
+if `default` is given, else an error.
 
-If one of the input labels `ls` is not found in `labels` and `unk` is given,
-return [`onehot(unk, labels)`](@ref) ; otherwise the function will raise an error.
+If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
+`AbstractArray{Bool, M+1}` which is one-hot along the first dimension,
+i.e. `result[:, k...] == onehot(xs[k...], labels)`.
 
 # Examples
 ```jldoctest
-julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c])
-3×3 Flux.OneHotArray{3,2,Vector{UInt32}}:
- 0  1  0
- 1  0  1
- 0  0  0
+julia> oh = Flux.onehotbatch(collect("abracadabra"), 'a':'e', 'e')
+5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
+ 1  ⋅  ⋅  1  ⋅  1  ⋅  1  ⋅  ⋅  1
+ ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅
+ ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
+ ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅
+ ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
+
+julia> reshape(1:15, 3, 5) * oh  # this matrix multiplication is done efficiently
+3×11 Matrix{Int64}:
+ 1  4  13  1  7  1  10  1  4  13  1
+ 2  5  14  2  8  2  11  2  5  14  2
+ 3  6  15  3  9  3  12  3  6  15  3
 ```
 """
-onehotbatch(ls, labels, unk...) = batch([onehot(l, labels, unk...) for l in ls])
+onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l in ls])
 
 """
-    onecold(y[, labels = 1:length(y)])
+    onecold(y::AbstractArray, labels = 1:size(y,1))
 
-Inverse operations of [`onehot`](@ref).
+Roughly the inverse operation of [`onehot`](@ref) or [`onehotbatch`](@ref):
+This finds the index of the largest element of `y`, or each column of `y`,
+and looks them up in `labels`.
+
+If `labels` are not specified, the default is integers `1:size(y,1)` --
+the same operation as `argmax(y, dims=1)` but sometimes a different return type.
 
 # Examples
 ```jldoctest
-julia> Flux.onecold([true, false, false], [:a, :b, :c])
-:a
+julia> Flux.onecold([false, true, false])
+2
 
 julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
 :c
+
+julia> Flux.onecold([ 1 0 0 1 0 1 0 1 0 0 1
+                      0 1 0 0 0 0 0 0 1 0 0
+                      0 0 0 0 1 0 0 0 0 0 0
+                      0 0 0 0 0 0 1 0 0 0 0
+                      0 0 1 0 0 0 0 0 0 1 0 ], 'a':'e') |> String
+"abeacadabea"
 ```
 """
 onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)]
````
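The `onehot` / `onecold` round trip described in those docstrings can be sketched in Python (an illustration of the semantics only; Flux's real types store `UInt32` indices rather than dense Bool arrays, and these helper names are mine):

```python
# onehotbatch: rows indexed by labels, one column per input, 1 where they match.
# onecold: per column, take the argmax and look the index up in labels.

def onehotbatch(xs, labels):
    return [[int(l == x) for x in xs] for l in labels]

def onecold(matrix, labels):
    ncols = len(matrix[0])
    out = []
    for k in range(ncols):
        col = [row[k] for row in matrix]
        out.append(labels[col.index(max(col))])  # argmax, then label lookup
    return out

oh = onehotbatch("foo", "fo")        # labels 'f' and 'o'
print(onecold(oh, "fo"))             # recovers the inputs
```

As the docstring says, the inverse is only "rough": `onecold` happily accepts any numeric array, not just one-hot data, which is what lets it generalise `argmax`.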
