File tree 2 files changed +4
-5
lines changed
2 files changed +4
-5
lines changed Original file line number Diff line number Diff line change 471
471
472
472
@functor Embedding
473
473
474
- function Embedding (in:: Integer , out:: Integer ;
475
- init = (i... ) -> randn (Float32, i... ))
476
- return Embedding (init (out, in))
477
- end
478
-
474
+ Embedding (in:: Integer , out:: Integer ; init = randn32) = Embedding (init (out, in))
475
+
479
476
(m:: Embedding )(x:: Union{OneHotVector, OneHotMatrix} ) = m. weight * x # equivalent to m.weight[:,onecold(x)]
480
477
(m:: Embedding )(x:: Integer ) = m ([x])
481
478
(m:: Embedding )(x:: AbstractVector ) = NNlib. gather (m. weight, x)
Original file line number Diff line number Diff line change @@ -368,6 +368,8 @@ identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identit
368
368
369
369
ones32 (dims... ) = Base. ones (Float32, dims... )
370
370
zeros32 (dims... ) = Base. zeros (Float32, dims... )
371
+ rand32 (dims... ) = Base. rand (Float32, dims... )
372
+ randn32 (dims... ) = Base. randn (Float32, dims... )
371
373
372
374
"""
373
375
create_bias(weights, bias, length)
You can’t perform that action at this time.
0 commit comments