Skip to content

Commit db4bb28

Browse files
committed
also, don't allow bias to be wider type
1 parent 438db81 commit db4bb28

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -514,14 +514,14 @@ to the constructor's keyword `bias=bias`.
514514
* `bias == true` creates a trainable array of the given size, of the same type as `weights`, initialised to zero.
515515
* `bias == false` returns `false`, which is understood by AD to be non-differentiable.
516516
* `bias::AbstractArray` uses the array provided, provided it has the correct size.
517-
It does not at present correct the `eltype` to match that of `weights`.
517+
It will also correct the `eltype` to match that of `weights`.
518518
"""
519519
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
520520
bias ? fill!(similar(weights, dims...), 0) : false
521521
end
522522
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
523523
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
524-
bias
524+
convert(AbstractArray{eltype(weights)}, bias)
525525
end
526526

527527

test/layers/basic.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ import Flux: activations
5858
@test Dense(rand(100,10), false, tanh).σ == tanh
5959
@test Dense(rand(100,10), rand(100)).σ == identity
6060
@test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type
61-
@test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
61+
@test Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
6262

6363
@test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
64-
@test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
64+
@test Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
6565

6666
@test_throws MethodError Dense(10, 10.5)
6767
@test_throws MethodError Dense(10, 10.5, tanh)

0 commit comments

Comments
 (0)