Skip to content

Commit 9410677

Browse files
bors[bot] and mcabbott authored
Merge #1635
1635: Fixup `Dataloader`'s docstring r=DhairyaLGandhi a=mcabbott This makes the Dataloader example into a `jldoctest`, and adds a friendly error if you mess up the sizes. DataLoader was exported from Data, but accessible only as (deep breath) `Flux.Data.DataLoader` because the module was not imported. I take it that's a mistake, because of the export statement. So I fixed it. Note that `julia> Flux.DataL<tab>` still doesn't find it. Perhaps we can murder some sub-modules once the datasets currently residing in Data are removed. (Test failure on nightly ought to be fixed by FluxML/NNlibCUDA.jl#15, but CI gets the old version?) Co-authored-by: Michael Abbott <[email protected]>
2 parents b454024 + 7ecc114 commit 9410677

File tree

3 files changed

+51
-42
lines changed

3 files changed

+51
-42
lines changed

src/Flux.jl

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ include("layers/upsample.jl")
4747
include("outputsize.jl")
4848

4949
include("data/Data.jl")
50+
using .Data
5051

5152
include("losses/Losses.jl")
5253
using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12

src/data/Data.jl

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ using Base: @propagate_inbounds
66
include("dataloader.jl")
77
export DataLoader
88

9-
109
## TODO for v0.13: remove everything below ##############
1110
## Also remove the following deps:
1211
## AbstractTrees, ZipFiles, CodecZLib

src/data/dataloader.jl

+50-41
Original file line numberDiff line numberDiff line change
@@ -13,59 +13,66 @@ struct DataLoader{D,R<:AbstractRNG}
1313
end
1414

1515
"""
16-
DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
16+
Flux.DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
1717
18-
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
18+
An object that iterates over mini-batches of `data`,
19+
each mini-batch containing `batchsize` observations
1920
(except possibly the last one).
2021
2122
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
22-
The last dimension in each tensor is considered to be the observation dimension.
23+
The last dimension in each tensor is the observation dimension, i.e. the one
24+
divided into mini-batches.
2325
24-
If `shuffle=true`, shuffles the observations each time iterations are re-started.
25-
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
26+
If `shuffle=true`, it shuffles the observations each time iterations are re-started.
27+
If `partial=false` and the number of observations is not divisible by the batchsize,
28+
then the last mini-batch is dropped.
2629
2730
The original data is preserved in the `data` field of the DataLoader.
2831
29-
Usage example:
32+
# Examples
33+
```jldoctest
34+
julia> Xtrain = rand(10, 100);
3035
31-
Xtrain = rand(10, 100)
32-
train_loader = DataLoader(Xtrain, batchsize=2)
33-
# iterate over 50 mini-batches of size 2
34-
for x in train_loader
35-
@assert size(x) == (10, 2)
36-
...
37-
end
36+
julia> array_loader = Flux.DataLoader(Xtrain, batchsize=2);
3837
39-
train_loader.data # original dataset
38+
julia> for x in array_loader
39+
@assert size(x) == (10, 2)
40+
# do something with x, 50 times
41+
end
4042
41-
# similar, but yielding tuples
42-
train_loader = DataLoader((Xtrain,), batchsize=2)
43-
for (x,) in train_loader
44-
@assert size(x) == (10, 2)
45-
...
46-
end
43+
julia> array_loader.data === Xtrain
44+
true
4745
48-
Xtrain = rand(10, 100)
49-
Ytrain = rand(100)
50-
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
51-
for epoch in 1:100
52-
for (x, y) in train_loader
53-
@assert size(x) == (10, 2)
54-
@assert size(y) == (2,)
55-
...
56-
end
57-
end
46+
julia> tuple_loader = Flux.DataLoader((Xtrain,), batchsize=2); # similar, but yielding 1-element tuples
5847
59-
# train for 10 epochs
60-
using IterTools: ncycle
61-
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
48+
julia> for x in tuple_loader
49+
@assert x isa Tuple{Matrix}
50+
@assert size(x[1]) == (10, 2)
51+
end
6252
63-
# can use NamedTuple to name tensors
64-
train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
65-
for datum in train_loader
66-
@assert size(datum.images) == (10, 2)
67-
@assert size(datum.labels) == (2,)
68-
end
53+
julia> Ytrain = rand('a':'z', 100); # now make a DataLoader yielding 2-element named tuples
54+
55+
julia> train_loader = Flux.DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);
56+
57+
julia> for epoch in 1:100
58+
for (x, y) in train_loader # access via tuple destructuring
59+
@assert size(x) == (10, 5)
60+
@assert size(y) == (5,)
61+
# loss += f(x, y) # etc, runs 100 * 20 times
62+
end
63+
end
64+
65+
julia> first(train_loader).label isa Vector{Char} # access via property name
66+
true
67+
68+
julia> first(train_loader).label == Ytrain[1:5] # because of shuffle=true
69+
false
70+
71+
julia> foreach(println∘summary, Flux.DataLoader(rand(Int8, 10, 64), batchsize=30)) # partial=false would omit last
72+
10×30 Matrix{Int8}
73+
10×30 Matrix{Int8}
74+
10×4 Matrix{Int8}
75+
```
6976
"""
7077
function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
7178
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
@@ -100,8 +107,10 @@ _nobs(data::AbstractArray) = size(data)[end]
100107
function _nobs(data::Union{Tuple, NamedTuple})
101108
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
102109
n = _nobs(data[1])
103-
if !all(x -> _nobs(x) == n, Base.tail(data))
104-
throw(DimensionMismatch("All data should contain same number of observations"))
110+
for i in keys(data)
111+
ni = _nobs(data[i])
112+
n == ni || throw(DimensionMismatch("All data inputs should have the same number of observations, i.e. size in the last dimension. " *
113+
"But data[$(repr(first(keys(data))))] ($(summary(data[1]))) has $n, while data[$(repr(i))] ($(summary(data[i]))) has $ni."))
105114
end
106115
return n
107116
end

0 commit comments

Comments (0)