Skip to content

Commit 3914543

Browse files
committed
also, don't allow bias to be wider type
1 parent 438db81 commit 3914543

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -514,14 +514,14 @@ to the constructor's keyword `bias=bias`.
514514
* `bias == true` creates a trainable array of the given size, of the same type as `weights`, initialised to zero.
515515
* `bias == false` returns `false`, which is understood by AD to be non-differentiable.
516516
* `bias::AbstractArray` uses the array provided, provided it has the correct size.
517-
It does not at present correct the `eltype` to match that of `weights`.
517+
It will also correct the `eltype` to match that of `weights`.
518518
"""
519519
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
520520
bias ? fill!(similar(weights, dims...), 0) : false
521521
end
522522
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
523523
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
524-
bias
524+
convert(AbstractArray{eltype(weights)}, bias)
525525
end
526526

527527

test/layers/basic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ import Flux: activations
5858
@test Dense(rand(100,10), false, tanh).σ == tanh
5959
@test Dense(rand(100,10), rand(100)).σ == identity
6060
@test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type
61-
@test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
61+
@test Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
6262

6363
@test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
64-
@test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
64+
@test Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
6565

6666
@test_throws MethodError Dense(10, 10.5)
6767
@test_throws MethodError Dense(10, 10.5, tanh)

test/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ end
299299

300300
c3 = ConvTranspose((3,), 2=>4, relu)
301301
@test c3(x) isa Array{Float32, 3}
302-
if VERSION >= "v1.8"
302+
if VERSION >= v"1.8"
303303
@test (@inferred c3(x); true) # fails on 1.6
304304
end
305305
end

0 commit comments

Comments
 (0)