@@ -200,10 +200,11 @@ end
200
200
RNNCell ((in, out):: Pair , σ= tanh; init= Flux. glorot_uniform, initb= zeros32, init_state= zeros32) =
201
201
RNNCell (σ, init (out, in), init (out, out), initb (out), init_state (out,1 ))
202
202
203
- function (m:: RNNCell{F,I,H,V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T },OneHotArray} ) where {F,I,H,V,T}
203
+ function (m:: RNNCell{F,I,H,V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{<:AbstractFloat },OneHotArray} ) where {F,I,H,V,T}
204
204
Wi, Wh, b = m. Wi, m. Wh, m. b
205
205
σ = NNlib. fast_act (m. σ, x)
206
- h = σ .(Wi* x .+ Wh* h .+ b)
206
+ xT = _match_eltype (m, T, x)
207
+ h = σ .(Wi* xT .+ Wh* h .+ b)
207
208
return h, reshape_cell_output (h, x)
208
209
end
209
210
@@ -305,9 +306,10 @@ function LSTMCell((in, out)::Pair;
305
306
return cell
306
307
end
307
308
308
- function (m:: LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}} )((h, c), x:: Union{AbstractVecOrMat{T },OneHotArray} ) where {I,H,V,T}
309
+ function (m:: LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}} )((h, c), x:: Union{AbstractVecOrMat{<:AbstractFloat },OneHotArray} ) where {I,H,V,T}
309
310
b, o = m. b, size (h, 1 )
310
- g = muladd (m. Wi, x, muladd (m. Wh, h, b))
311
+ xT = _match_eltype (m, T, x)
312
+ g = muladd (m. Wi, xT, muladd (m. Wh, h, b))
311
313
input, forget, cell, output = multigate (g, o, Val (4 ))
312
314
c′ = @. sigmoid_fast (forget) * c + sigmoid_fast (input) * tanh_fast (cell)
313
315
h′ = @. sigmoid_fast (output) * tanh_fast (c′)
376
378
GRUCell ((in, out):: Pair ; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
377
379
GRUCell (init (out * 3 , in), init (out * 3 , out), initb (out * 3 ), init_state (out,1 ))
378
380
379
- function (m:: GRUCell{I,H,V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T },OneHotArray} ) where {I,H,V,T}
381
+ function (m:: GRUCell{I,H,V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{<:AbstractFloat },OneHotArray} ) where {I,H,V,T}
380
382
Wi, Wh, b, o = m. Wi, m. Wh, m. b, size (h, 1 )
381
- gxs, ghs, bs = multigate (Wi* x, o, Val (3 )), multigate (Wh* h, o, Val (3 )), multigate (b, o, Val (3 ))
383
+ xT = _match_eltype (m, T, x)
384
+ gxs, ghs, bs = multigate (Wi* xT, o, Val (3 )), multigate (Wh* h, o, Val (3 )), multigate (b, o, Val (3 ))
382
385
r, z = _gru_output (gxs, ghs, bs)
383
386
h̃ = @. tanh_fast (gxs[3 ] + r * ghs[3 ] + bs[3 ])
384
387
h′ = @. (1 - z) * h̃ + z * h
@@ -444,9 +447,10 @@ GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state =
444
447
GRUv3Cell (init (out * 3 , in), init (out * 2 , out), initb (out * 3 ),
445
448
init (out, out), init_state (out,1 ))
446
449
447
- function (m:: GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T },OneHotArray} ) where {I,H,V,HH,T}
450
+ function (m:: GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{<:AbstractFloat },OneHotArray} ) where {I,H,V,HH,T}
448
451
Wi, Wh, b, Wh_h̃, o = m. Wi, m. Wh, m. b, m. Wh_h̃, size (h, 1 )
449
- gxs, ghs, bs = multigate (Wi* x, o, Val (3 )), multigate (Wh* h, o, Val (2 )), multigate (b, o, Val (3 ))
452
+ xT = _match_eltype (m, T, x)
453
+ gxs, ghs, bs = multigate (Wi* xT, o, Val (3 )), multigate (Wh* h, o, Val (2 )), multigate (b, o, Val (3 ))
450
454
r, z = _gru_output (gxs, ghs, bs)
451
455
h̃ = tanh_fast .(gxs[3 ] .+ (Wh_h̃ * (r .* h)) .+ bs[3 ])
452
456
h′ = @. (1 - z) * h̃ + z * h
0 commit comments