Skip to content

Commit d511d7a

Browse files
authoredFeb 8, 2023
Match layer output to weights (#2156)
* _match_eltype * fix tests * skip a test on 1.6 * also, don't allow bias to be wider type * use rand32 etc * fixup
1 parent c5a691a commit d511d7a

14 files changed

+172
-53
lines changed
 

‎docs/src/tutorials/linear_regression.md

+7-5
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ Let's start by initializing our dataset. We will be using the [`BostonHousing`](
272272
julia> dataset = BostonHousing();
273273
274274
julia> x, y = BostonHousing(as_df=false)[:];
275+
276+
julia> x, y = Float32.(x), Float32.(y);
275277
```
276278

277279
We can now split the obtained data into training and testing data -
@@ -287,7 +289,7 @@ This data contains a diverse number of features, which means that the features h
287289

288290
```jldoctest linear_regression_complex; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
289291
julia> std(x_train)
290-
134.06784844377117
292+
134.06786f0
291293
```
292294

293295
The data is indeed not normalised. We can use the [`Flux.normalise`](@ref) function to normalise the training data.
@@ -296,7 +298,7 @@ The data is indeed not normalised. We can use the [`Flux.normalise`](@ref) funct
296298
julia> x_train_n = Flux.normalise(x_train);
297299
298300
julia> std(x_train_n)
299-
1.0000843694328236
301+
1.0000844f0
300302
```
301303

302304
The standard deviation is now close to one! Our data is ready!
@@ -318,7 +320,7 @@ julia> function loss(model, x, y)
318320
end;
319321
320322
julia> loss(model, x_train_n, y_train)
321-
676.165591625047
323+
676.1656f0
322324
```
323325

324326
We can now proceed to the training phase!
@@ -363,7 +365,7 @@ Let's have a look at the loss -
363365

364366
```jldoctest linear_regression_complex; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
365367
julia> loss(model, x_train_n, y_train)
366-
27.127200028562164
368+
27.1272f0
367369
```
368370

369371
The loss went down significantly! It can be minimized further by choosing an even smaller `δ`.
@@ -376,7 +378,7 @@ The last step of this tutorial would be to test our model using the testing data
376378
julia> x_test_n = Flux.normalise(x_test);
377379
378380
julia> loss(model, x_test_n, y_test)
379-
66.91014769713368
381+
66.91015f0
380382
```
381383

382384
The loss is not as small as the loss of the training data, but it looks good! This also shows that our model is not overfitting!

‎src/layers/basic.jl

+9-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ true
1616
1717
julia> m = Chain(Dense(10 => 5, tanh), Dense(5 => 2));
1818
19-
julia> x = rand(10, 32);
19+
julia> x = rand32(10, 32);
2020
2121
julia> m(x) == m[2](m[1](x))
2222
true
@@ -132,11 +132,11 @@ The weight matrix and/or the bias vector (of length `out`) may also be provided
132132
julia> d = Dense(5 => 2)
133133
Dense(5 => 2) # 12 parameters
134134
135-
julia> d(rand(Float32, 5, 64)) |> size
135+
julia> d(rand32(5, 64)) |> size
136136
(2, 64)
137137
138-
julia> d(rand(Float32, 5, 1, 1, 64)) |> size # treated as three batch dimensions
139-
(2, 1, 1, 64)
138+
julia> d(rand32(5, 6, 4, 64)) |> size # treated as three batch dimensions
139+
(2, 6, 4, 64)
140140
141141
julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix
142142
Dense(5 => 2, tanh; bias=false) # 10 parameters
@@ -169,7 +169,8 @@ end
169169

170170
function (a::Dense)(x::AbstractVecOrMat)
171171
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
172-
return σ.(a.weight * x .+ a.bias)
172+
xT = _match_eltype(a, x) # fixes Float64 input, etc.
173+
return σ.(a.weight * xT .+ a.bias)
173174
end
174175

175176
(a::Dense)(x::AbstractArray) =
@@ -475,7 +476,7 @@ julia> model = Chain(Dense(3 => 5),
475476
Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))),
476477
Dense(8 => 17));
477478
478-
julia> model(rand(3)) |> size
479+
julia> model(rand32(3)) |> size
479480
(17,)
480481
481482
julia> model2 = Parallel(+; α = Dense(10, 2, tanh), β = Dense(5, 2))
@@ -485,10 +486,10 @@ Parallel(
485486
β = Dense(5 => 2), # 12 parameters
486487
) # Total: 4 arrays, 34 parameters, 392 bytes.
487488
488-
julia> model2(rand(10), rand(5)) |> size
489+
julia> model2(rand32(10), rand32(5)) |> size
489490
(2,)
490491
491-
julia> model2[:α](rand(10)) |> size
492+
julia> model2[:α](rand32(10)) |> size
492493
(2,)
493494
494495
julia> model2[:β] == model2[2]

‎src/layers/conv.jl

+9-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ See also [`Conv`](@ref), [`MaxPool`](@ref).
2222
2323
# Examples
2424
```jldoctest
25-
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images
25+
julia> xs = rand32(100, 100, 3, 50); # a batch of images
2626
2727
julia> layer = Conv((2,2), 3 => 7, pad=SamePad())
2828
Conv((2, 2), 3 => 7, pad=(1, 0, 1, 0)) # 91 parameters
@@ -96,7 +96,7 @@ See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref).
9696
9797
# Examples
9898
```jldoctest
99-
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images
99+
julia> xs = rand32(100, 100, 3, 50); # a batch of 50 RGB images
100100
101101
julia> layer = Conv((5,5), 3 => 7, relu; bias = false)
102102
Conv((5, 5), 3 => 7, relu, bias=false) # 525 parameters
@@ -197,7 +197,8 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
197197
function (c::Conv)(x::AbstractArray)
198198
σ = NNlib.fast_act(c.σ, x)
199199
cdims = conv_dims(c, x)
200-
σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
200+
xT = _match_eltype(c, x)
201+
σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
201202
end
202203

203204
_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -237,7 +238,7 @@ See also [`Conv`](@ref) for more detailed description of keywords.
237238
238239
# Examples
239240
```jldoctest
240-
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images
241+
julia> xs = rand32(100, 100, 3, 50); # a batch of 50 RGB images
241242
242243
julia> layer = ConvTranspose((5,5), 3 => 7, relu)
243244
ConvTranspose((5, 5), 3 => 7, relu) # 532 parameters
@@ -330,7 +331,8 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
330331
function (c::ConvTranspose)(x::AbstractArray)
331332
σ = NNlib.fast_act(c.σ, x)
332333
cdims = conv_transpose_dims(c, x)
333-
σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
334+
xT = _match_eltype(c, x)
335+
σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
334336
end
335337

336338
function Base.show(io::IO, l::ConvTranspose)
@@ -468,7 +470,8 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
468470
function (c::CrossCor)(x::AbstractArray)
469471
σ = NNlib.fast_act(c.σ, x)
470472
cdims = crosscor_dims(c, x)
471-
σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
473+
xT = _match_eltype(c, x)
474+
σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
472475
end
473476

474477
function Base.show(io::IO, l::CrossCor)

‎src/layers/normalise.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ julia> m(ones(2, 7)) # test mode, no effect
3636
2.0 2.0 2.0 2.0 2.0 2.0 2.0
3737
2.0 2.0 2.0 2.0 2.0 2.0 2.0
3838
39-
julia> Flux.trainmode!(m); # would happen within gradient
39+
julia> Flux.trainmode!(m); # equivalent to use within gradient
4040
4141
julia> m(ones(2, 7))
4242
3×7 Matrix{Float64}:
@@ -48,11 +48,11 @@ julia> y = m(ones(2, 10_000));
4848
4949
julia> using Statistics
5050
51-
julia> mean(y) # is about 2.0, as for test mode
52-
1.9892222222222182
51+
julia> mean(y) # is about 2.0, same as in test mode
52+
1.9989999999999961
5353
5454
julia> mean(iszero, y) # is about 0.4
55-
0.40323333333333333
55+
0.4003
5656
```
5757
"""
5858
mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
@@ -96,7 +96,7 @@ Does nothing to the input once [`testmode!`](@ref) is true.
9696
```jldoctest
9797
julia> using Statistics
9898
99-
julia> x = randn(1000,1);
99+
julia> x = randn32(1000,1);
100100
101101
julia> m = Chain(Dense(1000 => 1000, selu), AlphaDropout(0.2));
102102

‎src/layers/recurrent.jl

+12-8
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,11 @@ end
200200
RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
201201
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
202202

203-
function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,I,H,V,T}
203+
function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T}
204204
Wi, Wh, b = m.Wi, m.Wh, m.b
205205
σ = NNlib.fast_act(m.σ, x)
206-
h = σ.(Wi*x .+ Wh*h .+ b)
206+
xT = _match_eltype(m, T, x)
207+
h = σ.(Wi*xT .+ Wh*h .+ b)
207208
return h, reshape_cell_output(h, x)
208209
end
209210

@@ -305,9 +306,10 @@ function LSTMCell((in, out)::Pair;
305306
return cell
306307
end
307308

308-
function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}
309+
function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
309310
b, o = m.b, size(h, 1)
310-
g = muladd(m.Wi, x, muladd(m.Wh, h, b))
311+
xT = _match_eltype(m, T, x)
312+
g = muladd(m.Wi, xT, muladd(m.Wh, h, b))
311313
input, forget, cell, output = multigate(g, o, Val(4))
312314
c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
313315
h′ = @. sigmoid_fast(output) * tanh_fast(c′)
@@ -376,9 +378,10 @@ end
376378
GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
377379
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
378380

379-
function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}
381+
function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
380382
Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
381-
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
383+
xT = _match_eltype(m, T, x)
384+
gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
382385
r, z = _gru_output(gxs, ghs, bs)
383386
= @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
384387
h′ = @. (1 - z) *+ z * h
@@ -444,9 +447,10 @@ GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state =
444447
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
445448
init(out, out), init_state(out,1))
446449

447-
function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,HH,T}
450+
function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,HH,T}
448451
Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
449-
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
452+
xT = _match_eltype(m, T, x)
453+
gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
450454
r, z = _gru_output(gxs, ghs, bs)
451455
= tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
452456
h′ = @. (1 - z) *+ z * h

‎src/layers/stateless.jl

+44
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,47 @@ true
5757
σ = std(x, dims=dims, mean=μ, corrected=false)
5858
return @. (x - μ) /+ ϵ)
5959
end
60+
61+
"""
62+
_match_eltype(layer, ::Type{T}, x)
63+
_match_eltype(layer, x)
64+
65+
This internal function corrects most layer input to match the type of the weights.
66+
The second method uses `T = eltype(layer.weight)`.
67+
68+
It solves a common performance bug: Before, accidentally supplying `Float64` input,
69+
or an activation function which produces `Float64`, would silently run the
70+
entire forward pass in this precision.
71+
"""
72+
_match_eltype(layer, ::Type{T}, x::AbstractArray{T}) where {T} = x
73+
74+
# A common mistake, print a friendly warning, and fix it:
75+
function _match_eltype(layer, ::Type{Float32}, x::AbstractArray{Float64})
76+
# This warning is the only reason this needs to take the layer.
77+
@warn "Layer with Float32 parameters got Float64 input.
78+
The input will be converted, but any earlier layers may be very slow." layer summary(x) maxlog=1
79+
convert(AbstractArray{Float32}, x)
80+
end
81+
82+
# Allow OneHot to reach specialisation of * etc:
83+
_match_eltype(layer, ::Type, x::OneHotLike) = x
84+
85+
# Other floats, and integers, silently fix.
86+
function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T}
87+
convert(AbstractArray{T}, x)
88+
end
89+
90+
# Weird types like Nil, Dual, etc, we allow through:
91+
_match_eltype(layer, ::Type, x::AbstractArray) = x
92+
93+
# 2-arg method, for common layers with layer.weight
94+
_match_eltype(layer, x) = _match_eltype(layer, eltype(layer.weight), x)
95+
96+
# Trivial rule:
97+
function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T}
98+
_match_eltype(layer, T, x), dx -> (NoTangent(), ZeroTangent(), NoTangent(), dx) # does not un-thunk dx
99+
end
100+
function ChainRulesCore.rrule(::typeof(_match_eltype), layer, x::AbstractArray)
101+
_match_eltype(layer, x), dx -> (ZeroTangent(), NoTangent(), dx) # does not un-thunk dx
102+
end
103+

‎src/outputsize.jl

+10-1
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,16 @@ for (fn, Dims) in ((:conv, DenseConvDims),)
173173
end
174174
end
175175

176+
# Recurrent layers: just convert to the type they like & convert back.
177+
178+
for Cell in [:RNNCell, :LSTMCell, :GRUCell, :GRUv3Cell]
179+
@eval function (m::Recur{<:$Cell})(x::AbstractArray{Nil})
180+
xT = fill!(similar(m.cell.Wi, size(x)), 0)
181+
_, y = m.cell(m.state, xT) # discard the new state
182+
return similar(x, size(y))
183+
end
184+
end
185+
176186

177187
"""
178188
@autosize (size...,) Chain(Layer(_ => 2), Layer(_), ...)
@@ -229,7 +239,6 @@ Limitations:
229239
* While `@autosize (5, 32) Flux.Bilinear(_ => 7)` is OK, something like `Bilinear((_, _) => 7)` will fail.
230240
* While `Scale(_)` and `LayerNorm(_)` are fine (and use the first dimension), `Scale(_,_)` and `LayerNorm(_,_)`
231241
will fail if `size(x,1) != size(x,2)`.
232-
* RNNs won't work: `@autosize (7, 11) LSTM(_ => 5)` fails, because `outputsize(RNN(3=>7), (3,))` also fails, a known issue.
233242
"""
234243
macro autosize(size, model)
235244
Meta.isexpr(size, :tuple) || error("@autosize's first argument must be a tuple, the size of the input")

‎src/train.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ It differs from `Optimisers.setup` in that it:
2727
2828
# Example
2929
```jldoctest
30-
julia> model = Dense(2=>1, leakyrelu; init=ones32);
30+
julia> model = Dense(2=>1, leakyrelu; init=ones);
3131
3232
julia> opt_state = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state
33-
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ())
33+
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [0.0]), σ = ())
3434
3535
julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps:
3636
@@ -39,11 +39,11 @@ julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt_state) do m, x, y
3939
end
4040
4141
julia> model.bias # was zero, mutated by Flux.train!
42-
1-element Vector{Float32}:
43-
10.190001
42+
1-element Vector{Float64}:
43+
10.19
4444
4545
julia> opt_state # mutated by Flux.train!
46-
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ())
46+
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [-10.09]), σ = ())
4747
```
4848
"""
4949
function setup(rule::Optimisers.AbstractRule, model)

‎src/utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -514,14 +514,14 @@ to the constructor's keyword `bias=bias`.
514514
* `bias == true` creates a trainable array of the given size, of the same type as `weights`, initialised to zero.
515515
* `bias == false` returns `false`, which is understood by AD to be non-differentiable.
516516
* `bias::AbstractArray` uses the array provided, provided it has the correct size.
517-
It does not at present correct the `eltype` to match that of `weights`.
517+
It will also correct the `eltype` to match that of `weights`.
518518
"""
519519
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
520520
bias ? fill!(similar(weights, dims...), 0) : false
521521
end
522522
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
523523
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
524-
bias
524+
convert(AbstractArray{eltype(weights)}, bias)
525525
end
526526

527527

‎test/layers/basic.jl

+19-2
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ import Flux: activations
5858
@test Dense(rand(100,10), false, tanh).σ == tanh
5959
@test Dense(rand(100,10), rand(100)).σ == identity
6060
@test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type
61-
@test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
61+
@test Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
6262

6363
@test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
64-
@test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
64+
@test Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
6565

6666
@test_throws MethodError Dense(10, 10.5)
6767
@test_throws MethodError Dense(10, 10.5, tanh)
@@ -89,6 +89,23 @@ import Flux: activations
8989
@test Dense(10, 2, identity, init = ones)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
9090
@test Dense(10, 2, identity, init = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
9191
end
92+
@testset "type matching" begin
93+
d1 = Dense(2 => 3)
94+
d2 = Dense(d1.weight, false)
95+
x1 = randn(Float32, 2, 4)
96+
@test d1(x1) d2(x1) d1.weight * x1
97+
x2 = Float64.(x1)
98+
@test d1(x2) d2(x2) d1.weight * x2
99+
@test d1(x2) isa Array{Float32} # tests _match_eltype, will print a warning
100+
@test d2(x2) isa Array{Float32}
101+
102+
x3 = rand(-5:5, 2, 4)
103+
@test d1(x3) d2(x3) d1.weight * x3
104+
x4 = rand(Bool, 2, 4)
105+
@test d1(x4) d2(x4) d1.weight * x4
106+
x5 = Flux.onehotbatch(rand(Bool, 5), (true, false))
107+
@test d1(x5) d2(x5) d1.weight * x5
108+
end
92109
end
93110

94111
@testset "Scale" begin

‎test/layers/conv.jl

+17
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,20 @@ end
286286
end
287287
@test_throws DimensionMismatch fun(rand(2,3,4), rand(6))
288288
end
289+
290+
@testset "type matching" begin
291+
x = rand(Float64, 10,2,5)
292+
xi = rand(-3:3, 10,2,5)
293+
c1 = Conv((3,), 2=>4, relu)
294+
@test @inferred(c1(x)) isa Array{Float32, 3}
295+
@test c1(xi) isa Array{Float32, 3}
296+
297+
c2 = CrossCor((3,), 2=>1, relu)
298+
@test @inferred(c2(x)) isa Array{Float32, 3}
299+
300+
c3 = ConvTranspose((3,), 2=>4, relu)
301+
@test c3(x) isa Array{Float32, 3}
302+
if VERSION >= v"1.8"
303+
@test (@inferred c3(x); true) # fails on 1.6
304+
end
305+
end

‎test/layers/recurrent.jl

+24-9
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,6 @@ end
9090
end
9191
end
9292

93-
@testset "RNN-input-state-eltypes" begin
94-
@testset for R in [RNN, GRU, LSTM, GRUv3]
95-
m = R(3 => 5)
96-
x = rand(Float64, 3, 1)
97-
Flux.reset!(m)
98-
@test_throws MethodError m(x)
99-
end
100-
end
101-
10293
@testset "multigate" begin
10394
x = rand(6, 5)
10495
res, (dx,) = Flux.withgradient(x) do x
@@ -169,3 +160,27 @@ end
169160
@test size(m(x3)) == (5, 1, 2)
170161
end
171162
end
163+
164+
@testset "type matching" begin
165+
x = rand(Float64, 2, 4)
166+
m1 = RNN(2=>3)
167+
@test m1(x) isa Matrix{Float32} # uses _match_eltype, may print a warning
168+
@test m1.state isa Matrix{Float32}
169+
@test (@inferred m1(x); true)
170+
@test Flux.outputsize(m1, size(x)) == size(m1(x))
171+
172+
m2 = LSTM(2=>3)
173+
@test m2(x) isa Matrix{Float32}
174+
@test (@inferred m2(x); true)
175+
@test Flux.outputsize(m2, size(x)) == size(m2(x))
176+
177+
m3 = GRU(2=>3)
178+
@test m3(x) isa Matrix{Float32}
179+
@test (@inferred m3(x); true)
180+
@test Flux.outputsize(m3, size(x)) == size(m3(x))
181+
182+
m4 = GRUv3(2=>3)
183+
@test m4(x) isa Matrix{Float32}
184+
@test (@inferred m4(x); true)
185+
@test Flux.outputsize(m4, size(x)) == size(m4(x))
186+
end

‎test/outputsize.jl

+7
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,10 @@ end
257257
# Can't let |> gpu act before the arrays are materialized... so it's an error:
258258
@test_throws ErrorException @eval @autosize (1,2,3) Dense(_=>2) |> f64
259259
end
260+
261+
@testset "type matching" begin
262+
# Check that _match_eltype doesn't replace this with an array of Float32:
263+
@test Flux._match_eltype(Dense(2=>3), fill(Flux.Nil(),2,4)) isa Matrix{Flux.Nil}
264+
# For RNN etc there's a special path:
265+
@test RNN(2=>3)(fill(Flux.Nil(),2,4)) isa Matrix{Flux.Nil}
266+
end

‎test/utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,8 @@ end
290290
x32 = rand(Float32, 10)
291291
@test eltype(m[1].weight) == Float32
292292
@test eltype(m(x32)) == Float32
293-
@test eltype(m(x64)) == Float64
294-
@test eltype(f64(m)(x32)) == Float64
293+
@test eltype(m(x64)) == Float32 # fixed by _match_eltype
294+
@test eltype(f64(m)(x32)) == Float64 # _match_eltype promotes, Julia would too
295295
@test eltype(f64(m)(x64)) == Float64
296296
@test eltype(f64(m)[1].weight) == Float64
297297
@test eltype(f32(f64(m))[1].weight) == Float32

0 commit comments

Comments
 (0)
Please sign in to comment.