@@ -13,59 +13,66 @@ struct DataLoader{D,R<:AbstractRNG}
13
13
end
14
14
15
15
"""
16
- DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
16
+ Flux. DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
17
17
18
- An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
18
+ An object that iterates over mini-batches of `data`,
19
+ each mini-batch containing `batchsize` observations
19
20
(except possibly the last one).
20
21
21
22
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
22
- The last dimension in each tensor is considered to be the observation dimension.
23
+ The last dimension in each tensor is the observation dimension, i.e. the one
24
+ divided into mini-batches.
23
25
24
- If `shuffle=true`, shuffles the observations each time iterations are re-started.
25
- If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
26
+ If `shuffle=true`, it shuffles the observations each time iterations are re-started.
27
+ If `partial=false` and the number of observations is not divisible by the batchsize,
28
+ then the last mini-batch is dropped.
26
29
27
30
The original data is preserved in the `data` field of the DataLoader.
28
31
29
- Usage example:
32
+ # Examples
33
+ ```jldoctest
34
+ julia> Xtrain = rand(10, 100);
30
35
31
- Xtrain = rand(10, 100)
32
- train_loader = DataLoader(Xtrain, batchsize=2)
33
- # iterate over 50 mini-batches of size 2
34
- for x in train_loader
35
- @assert size(x) == (10, 2)
36
- ...
37
- end
36
+ julia> array_loader = Flux.DataLoader(Xtrain, batchsize=2);
38
37
39
- train_loader.data # original dataset
38
+ julia> for x in array_loader
39
+ @assert size(x) == (10, 2)
40
+ # do something with x, 50 times
41
+ end
40
42
41
- # similar, but yielding tuples
42
- train_loader = DataLoader((Xtrain,), batchsize=2)
43
- for (x,) in train_loader
44
- @assert size(x) == (10, 2)
45
- ...
46
- end
43
+ julia> array_loader.data === Xtrain
44
+ true
47
45
48
- Xtrain = rand(10, 100)
49
- Ytrain = rand(100)
50
- train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
51
- for epoch in 1:100
52
- for (x, y) in train_loader
53
- @assert size(x) == (10, 2)
54
- @assert size(y) == (2,)
55
- ...
56
- end
57
- end
46
+ julia> tuple_loader = Flux.DataLoader((Xtrain,), batchsize=2); # similar, but yielding 1-element tuples
58
47
59
- # train for 10 epochs
60
- using IterTools: ncycle
61
- Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
48
+ julia> for x in tuple_loader
49
+ @assert x isa Tuple{Matrix}
50
+ @assert size(x[1]) == (10, 2)
51
+ end
62
52
63
- # can use NamedTuple to name tensors
64
- train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
65
- for datum in train_loader
66
- @assert size(datum.images) == (10, 2)
67
- @assert size(datum.labels) == (2,)
68
- end
53
+ julia> Ytrain = rand('a':'z', 100); # now make a DataLoader yielding 2-element named tuples
54
+
55
+ julia> train_loader = Flux.DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);
56
+
57
+ julia> for epoch in 1:100
58
+ for (x, y) in train_loader # access via tuple destructuring
59
+ @assert size(x) == (10, 5)
60
+ @assert size(y) == (5,)
61
+ # loss += f(x, y) # etc, runs 100 * 20 times
62
+ end
63
+ end
64
+
65
+ julia> first(train_loader).label isa Vector{Char} # access via property name
66
+ true
67
+
68
+ julia> first(train_loader).label == Ytrain[1:5] # because of shuffle=true
69
+ false
70
+
71
+ julia> foreach(println∘summary, Flux.DataLoader(rand(Int8, 10, 64), batchsize=30)) # partial=false would omit last
72
+ 10×30 Matrix{Int8}
73
+ 10×30 Matrix{Int8}
74
+ 10×4 Matrix{Int8}
75
+ ```
69
76
"""
70
77
function DataLoader (data; batchsize= 1 , shuffle= false , partial= true , rng= GLOBAL_RNG)
71
78
batchsize > 0 || throw (ArgumentError (" Need positive batchsize" ))
@@ -100,8 +107,10 @@ _nobs(data::AbstractArray) = size(data)[end]
100
107
function _nobs (data:: Union{Tuple, NamedTuple} )
101
108
length (data) > 0 || throw (ArgumentError (" Need at least one data input" ))
102
109
n = _nobs (data[1 ])
103
- if ! all (x -> _nobs (x) == n, Base. tail (data))
104
- throw (DimensionMismatch (" All data should contain same number of observations" ))
110
+ for i in keys (data)
111
+ ni = _nobs (data[i])
112
+ n == ni || throw (DimensionMismatch (" All data inputs should have the same number of observations, i.e. size in the last dimension. " *
113
+ " But data[$(repr (first (keys (data)))) ] ($(summary (data[1 ])) ) has $n , while data[$(repr (i)) ] ($(summary (data[i])) ) has $ni ." ))
105
114
end
106
115
return n
107
116
end
0 commit comments