Skip to content

Commit dfb390d

Browse files
use randn32
1 parent 058a4a0 commit dfb390d

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

src/layers/basic.jl

+2-5
Original file line numberDiff line numberDiff line change
@@ -471,11 +471,8 @@ end
471471

472472
@functor Embedding
473473

474-
function Embedding(in::Integer, out::Integer;
475-
init = (i...) -> randn(Float32, i...))
476-
return Embedding(init(out, in))
477-
end
478-
474+
Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
475+
479476
(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
480477
(m::Embedding)(x::Integer) = m([x])
481478
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)

src/utils.jl

+2
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,8 @@ identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identit
368368

369369
ones32(dims...) = Base.ones(Float32, dims...)
370370
zeros32(dims...) = Base.zeros(Float32, dims...)
371+
rand32(dims...) = Base.rand(Float32, dims...)
372+
randn32(dims...) = Base.randn(Float32, dims...)
371373

372374
"""
373375
create_bias(weights, bias, length)

0 commit comments

Comments
 (0)