diff --git a/NEWS.md b/NEWS.md
index db05c18067..f7f9327599 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -4,7 +4,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
 
 ## v0.15.0
 * Recurrent layers have undergone a complete redesign in [PR 2500](https://github.com/FluxML/Flux.jl/pull/2500).
-  * `RNNCell`, `LSTMCell`, and `GRUCell` are now exported and provide functionality for single time-step processing: `rnncell(x_t, h_t) -> h_{t+1}`.
+  * `RNNCell`, `LSTMCell`, and `GRUCell` are now exported and provide functionality for single time-step processing: `rnncell(h_t, x_t) -> h_{t+1}`.
   * `RNN`, `LSTM`, and `GRU` no longer store the hidden state internally, it has to be explicitely passed to the layer. Moreover, they now process entire sequences at once, rather than one element at a time: `rnn(x, h) -> h′`.
   * The `Recur` wrapper has been deprecated and removed.
   * The `reset!` function has also been removed; state management is now entirely up to the user.
diff --git a/docs/src/guide/models/recurrence.md b/docs/src/guide/models/recurrence.md
index 5b2e70f095..4b7ddf10b2 100644
--- a/docs/src/guide/models/recurrence.md
+++ b/docs/src/guide/models/recurrence.md
@@ -19,7 +19,7 @@ Wxh = randn(Float32, output_size, input_size)
 Whh = randn(Float32, output_size, output_size)
 b = zeros(Float32, output_size)
 
-function rnn_cell(x, h)
+function rnn_cell(h, x)
     h = tanh.(Wxh * x .+ Whh * h .+ b)
     return h
 end
@@ -33,12 +33,12 @@ h0 = zeros(Float32, output_size)
 y = []
 ht = h0
 for xt in x
-    ht = rnn_cell(xt, ht)
+    ht = rnn_cell(ht, xt)
     y = [y; [ht]] # concatenate in non-mutating (AD friendly) way
 end
 ```
 
-Notice how the above is essentially a `Dense` layer that acts on two inputs, `xt` and `ht`.
+Notice how the above is essentially a `Dense` layer that acts on two inputs, `ht` and `xt`.
 
 The output at each time step, called the hidden state, is used as the input to the next time step and is also the output of the model.
 
@@ -58,7 +58,7 @@ rnn_cell = Flux.RNNCell(input_size => output_size)
 y = []
 ht = h0
 for xt in x
-    ht = rnn_cell(xt, ht)
+    ht = rnn_cell(ht, xt)
     y = [y; [ht]]
 end
 ```
@@ -78,7 +78,7 @@ struct RecurrentCellModel{H,C,D}
 end
 
 # we choose to not train the initial hidden state
-Flux.@layer RecurrentCellModel trainable=(cell,dense)
+Flux.@layer RecurrentCellModel trainable = (cell, dense)
 
 function RecurrentCellModel(input_size::Int, hidden_size::Int)
     return RecurrentCellModel(
@@ -91,7 +91,7 @@ function (m::RecurrentCellModel)(x)
     z = []
     ht = m.h0
     for xt in x
-        ht = m.cell(xt, ht)
+        ht = m.cell(ht, xt)
         z = [z; [ht]]
     end
     z = stack(z, dims=2) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
@@ -151,7 +151,7 @@ function RecurrentModel(input_size::Int, hidden_size::Int)
 end
 
 function (m::RecurrentModel)(x)
-    z = m.rnn(x, m.h0) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
+    z = m.rnn(m.h0, x) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
     ŷ = m.dense(z) # [1, seq_len, batch_size] or [1, seq_len]
     return ŷ
 end
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 7c69b80103..e6377892af 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -24,13 +24,13 @@ See [`RNN`](@ref) for a layer that processes entire sequences.
 
 # Forward
 
-    rnncell(x, [h])
+    rnncell([h,] x)
 
 The arguments of the forward pass are:
 
-- `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`.
 - `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix
   of size `out x batch_size`. If not provided, it is assumed to be a vector of zeros.
+- `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`.
 
 # Examples
 
@@ -48,7 +48,7 @@ h = zeros(Float32, 5)
 ŷ = []
 
 for x_t in x
-    h = r(x_t, h)
+    h = r(h, x_t)
     ŷ = [ŷ..., h]   # Cannot use `push!(ŷ, h)` here since mutation
                     # is not automatic differentiation friendly yet.
                     # Can use `y = vcat(y, [h])` as an alternative.
@@ -74,9 +74,9 @@ function RNNCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
     return RNNCell(σ, Wi, Wh, b)
 end
 
-(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
+(m::RNNCell)(x::AbstractVecOrMat) = m(zeros_like(x, size(m.Wh, 1)), x)
 
-function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
+function (m::RNNCell)(h::AbstractVecOrMat, x::AbstractVecOrMat)
     _size_check(m, x, 1 => size(m.Wi,2))
     σ = NNlib.fast_act(m.σ, x)
     h = σ.(m.Wi*x .+ m.Wh*h .+ m.bias)
@@ -113,7 +113,7 @@ See [`RNNCell`](@ref) for a layer that processes a single time step.
 
 # Forward
 
-    rnn(x, h)
+    rnn(h, x)
 
 The arguments of the forward pass are:
 
@@ -136,7 +136,7 @@ RNN(
   RNNCell(4 => 6, tanh),  # 66 parameters
 )                  # Total: 3 arrays, 66 parameters, 424 bytes.
 
-julia> y = rnn(x, h); # [y] = [d_out, len, batch_size]
+julia> y = rnn(h, x); # [y] = [d_out, len, batch_size]
 ```
 
 Sometimes, the initial hidden state is a learnable parameter.
@@ -150,7 +150,7 @@ end
 
 Flux.@layer :expand Model
 
-(m::Model)(x) = m.rnn(x, m.h0)
+(m::Model)(x) = m.rnn(m.h0, x)
 
 model = Model(RNN(32 => 64), zeros(Float32, 64))
 ```
@@ -166,15 +166,15 @@ function RNN((in, out)::Pair, σ = tanh; bias = true, init = glorot_uniform)
     return RNN(cell)
 end
 
-(m::RNN)(x) = m(x, zeros_like(x, size(m.cell.Wh, 1)))
+(m::RNN)(x) = m(zeros_like(x, size(m.cell.Wh, 1)), x)
 
-function (m::RNN)(x, h)
+function (m::RNN)(h, x)
     @assert ndims(x) == 2 || ndims(x) == 3
     # [x] = [in, L] or [in, L, B]
     # [h] = [out] or [out, B]
     y = []
     for x_t in eachslice(x, dims=2)
-        h = m.cell(x_t, h)
+        h = m.cell(h, x_t)
         # y = [y..., h]
         y = vcat(y, [h])
     end
@@ -210,7 +210,7 @@ See also [`LSTM`](@ref) for a layer that processes entire sequences.
 
 # Forward
 
-    lstmcell(x, (h, c))
+    lstmcell((h, c), x)
     lstmcell(x)
 
 The arguments of the forward pass are:
@@ -233,7 +233,7 @@ julia> c = zeros(Float32, 5); # cell state
 
 julia> x = rand(Float32, 3, 4); # in x batch_size
 
-julia> h′, c′ = l(x, (h, c));
+julia> h′, c′ = l((h, c), x);
 
 julia> size(h′) # out x batch_size
 (5, 4)
@@ -258,10 +258,10 @@ end
 function (m::LSTMCell)(x::AbstractVecOrMat)
     h = zeros_like(x, size(m.Wh, 2))
     c = zeros_like(h)
-    return m(x, (h, c))
+    return m((h, c), x)
 end
 
-function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
+function (m::LSTMCell)((h, c), x::AbstractVecOrMat)
     _size_check(m, x, 1 => size(m.Wi, 2))
     b = m.bias
     g = m.Wi * x .+ m.Wh * h .+ b
@@ -304,7 +304,7 @@ See [`LSTMCell`](@ref) for a layer that processes a single time step.
 
 # Forward
 
-    lstm(x, (h, c))
+    lstm((h, c), x)
     lstm(x)
 
 The arguments of the forward pass are:
@@ -327,7 +327,7 @@ end
 
 Flux.@layer :expand Model
 
-(m::Model)(x) = m.lstm(x, (m.h0, m.c0))
+(m::Model)(x) = m.lstm((m.h0, m.c0), x)
 
 d_in, d_out, len, batch_size = 2, 3, 4, 5
 x = rand(Float32, (d_in, len, batch_size))
@@ -350,15 +350,15 @@ end
 
 function (m::LSTM)(x)
     h = zeros_like(x, size(m.cell.Wh, 1))
     c = zeros_like(h)
-    return m(x, (h, c))
+    return m((h, c), x)
 end
 
-function (m::LSTM)(x, (h, c))
+function (m::LSTM)((h, c), x)
     @assert ndims(x) == 2 || ndims(x) == 3
     h′ = []
     c′ = []
     for x_t in eachslice(x, dims=2)
-        h, c = m.cell(x_t, (h, c))
+        h, c = m.cell((h, c), x_t)
         h′ = vcat(h′, [h])
         c′ = vcat(c′, [c])
     end
@@ -393,7 +393,7 @@ See also [`GRU`](@ref) for a layer that processes entire sequences.
 
 # Forward
 
-    grucell(x, h)
+    grucell(h, x)
     grucell(x)
 
 The arguments of the forward pass are:
@@ -413,7 +413,7 @@ julia> h = zeros(Float32, 5); # hidden state
 
 julia> x = rand(Float32, 3, 4); # in x batch_size
 
-julia> h′ = g(x, h);
+julia> h′ = g(h, x);
 ```
 """
 struct GRUCell{I,H,V}
@@ -431,9 +431,9 @@ function GRUCell((in, out)::Pair; init = glorot_uniform, bias = true)
     return GRUCell(Wi, Wh, b)
 end
 
-(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
+(m::GRUCell)(x::AbstractVecOrMat) = m(zeros_like(x, size(m.Wh, 2)), x)
 
-function (m::GRUCell)(x::AbstractVecOrMat, h)
+function (m::GRUCell)(h, x::AbstractVecOrMat)
     _size_check(m, x, 1 => size(m.Wi,2))
     gxs = chunk(m.Wi * x, 3, dims=1)
     ghs = chunk(m.Wh * h, 3, dims=1)
@@ -472,7 +472,7 @@ See [`GRUCell`](@ref) for a layer that processes a single time step.
 
 # Forward
 
-    gru(x, h)
+    gru(h, x)
     gru(x)
 
 The arguments of the forward pass are:
@@ -506,15 +506,15 @@ end
 
 function (m::GRU)(x)
     h = zeros_like(x, size(m.cell.Wh, 2))
-    return m(x, h)
+    return m(h, x)
 end
 
-function (m::GRU)(x, h)
+function (m::GRU)(h, x)
     @assert ndims(x) == 2 || ndims(x) == 3
     h′ = []
     # [x] = [in, L] or [in, L, B]
     for x_t in eachslice(x, dims=2)
-        h = m.cell(x_t, h)
+        h = m.cell(h, x_t)
         h′ = vcat(h′, [h])
     end
     return stack(h′, dims=2)
@@ -548,7 +548,7 @@ See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.
 
 # Forward
 
-    gruv3cell(x, h)
+    gruv3cell(h, x)
     gruv3cell(x)
 
 The arguments of the forward pass are:
@@ -575,9 +575,9 @@ function GRUv3Cell((in, out)::Pair; init = glorot_uniform, bias = true)
     return GRUv3Cell(Wi, Wh, b, Wh_h̃)
 end
 
-(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
+(m::GRUv3Cell)(x::AbstractVecOrMat) = m(zeros_like(x, size(m.Wh, 2)), x)
 
-function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
+function (m::GRUv3Cell)(h, x::AbstractVecOrMat)
     _size_check(m, x, 1 => size(m.Wi,2))
     gxs = chunk(m.Wi * x, 3, dims=1)
     ghs = chunk(m.Wh * h, 3, dims=1)
@@ -629,14 +629,14 @@ end
 
 function (m::GRUv3)(x)
     h = zeros_like(x, size(m.cell.Wh, 2))
-    return m(x, h)
+    return m(h, x)
 end
 
-function (m::GRUv3)(x, h)
+function (m::GRUv3)(h, x)
     @assert ndims(x) == 2 || ndims(x) == 3
     h′ = []
     for x_t in eachslice(x, dims=2)
-        h = m.cell(x_t, h)
+        h = m.cell(h, x_t)
         h′ = vcat(h′, [h])
     end
     return stack(h′, dims=2)
diff --git a/test/ext_common/recurrent_gpu_ad.jl b/test/ext_common/recurrent_gpu_ad.jl
index d2ef3fe34b..979f74cafe 100644
--- a/test/ext_common/recurrent_gpu_ad.jl
+++ b/test/ext_common/recurrent_gpu_ad.jl
@@ -1,9 +1,9 @@
 @testset "RNNCell GPU AD" begin
 
-    function loss(r, x, h)
+    function loss(r, h, x)
         y = []
         for x_t in x
-            h = r(x_t, h)
+            h = r(h, x_t)
             y = vcat(y, [h])
         end
         # return mean(h)
@@ -18,7 +18,7 @@
     # Single Step
     @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS
     # Multiple Steps
-    @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
+    @test test_gradients(r, h, x; test_gpu=true, compare_finite_diff=false, loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
 end
 
 @testset "RNN GPU AD" begin
@@ -29,7 +29,7 @@ end
 
     Flux.@layer :expand ModelRNN
 
-    (m::ModelRNN)(x) = m.rnn(x, m.h0)
+    (m::ModelRNN)(x) = m.rnn(m.h0, x)
 
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     model = ModelRNN(RNN(d_in => d_out), zeros(Float32, d_out))
@@ -41,12 +41,12 @@ end
 
 @testset "LSTMCell" begin
 
-    function loss(r, x, hc)
+    function loss(r, hc, x)
         h, c = hc
         h′ = []
         c′ = []
         for x_t in x
-            h, c = r(x_t, (h, c))
+            h, c = r((h, c), x_t)
             h′ = vcat(h′, [h])
             c′ = [c′..., c]
         end
@@ -62,9 +62,9 @@ end
     c = zeros(Float32, d_out)
     # Single Step
     @test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
-        loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single ∈ BROKEN_TESTS
+        loss = (m, (h, c), x) -> mean(m((h, c), x)[1])) broken = :lstmcell_single ∈ BROKEN_TESTS
     # Multiple Steps
-    @test test_gradients(cell, x, (h, c); test_gpu=true, compare_finite_diff=false, loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
+    @test test_gradients(cell, (h, c), x; test_gpu=true, compare_finite_diff=false, loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
 end
 
 @testset "LSTM" begin
@@ -89,10 +89,10 @@ end
 end
 
 @testset "GRUCell" begin
-    function loss(r, x, h)
+    function loss(r, h, x)
         y = []
         for x_t in x
-            h = r(x_t, h)
+            h = r(h, x_t)
             y = vcat(y, [h])
         end
         y = stack(y, dims=2) # [D, L] or [D, L, B]
@@ -104,7 +104,7 @@ end
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
     @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single ∈ BROKEN_TESTS
-    @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :grucell_multiple ∈ BROKEN_TESTS
+    @test test_gradients(r, h, x; test_gpu=true, compare_finite_diff=false, loss) broken = :grucell_multiple ∈ BROKEN_TESTS
 end
 
 @testset "GRU GPU AD" begin
@@ -115,7 +115,7 @@ end
 
     Flux.@layer :expand ModelGRU
 
-    (m::ModelGRU)(x) = m.gru(x, m.h0)
+    (m::ModelGRU)(x) = m.gru(m.h0, x)
 
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     model = ModelGRU(GRU(d_in => d_out), zeros(Float32, d_out))
@@ -126,10 +126,10 @@ end
 
 @testset "GRUv3Cell GPU AD" begin
 
-    function loss(r, x, h)
+    function loss(r, h, x)
         y = []
         for x_t in x
-            h = r(x_t, h)
+            h = r(h, x_t)
             y = vcat(y, [h])
         end
         y = stack(y, dims=2) # [D, L] or [D, L, B]
@@ -141,7 +141,7 @@ end
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
     @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS
-    @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
+    @test test_gradients(r, h, x; test_gpu=true, compare_finite_diff=false, loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
 end
 
 @testset "GRUv3 GPU AD" begin
@@ -152,7 +152,7 @@ end
 
     Flux.@layer :expand ModelGRUv3
 
-    (m::ModelGRUv3)(x) = m.gru(x, m.h0)
+    (m::ModelGRUv3)(x) = m.gru(m.h0, x)
 
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     model = ModelGRUv3(GRUv3(d_in => d_out), zeros(Float32, d_out))
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 98e072cdb1..9c7184ccac 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -1,30 +1,30 @@
 @testset "RNNCell" begin
 
-    function loss1(r, x, h)
+    function loss1(r, h, x)
         for x_t in x
-            h = r(x_t, h)
+            h = r(h, x_t)
         end
         return mean(h.^2)
     end
 
-    function loss2(r, x, h)
-        y = [r(x_t, h) for x_t in x]
+    function loss2(r, h, x)
+        y = [r(h, x_t) for x_t in x]
         return sum(mean, y)
     end
 
-    function loss3(r, x, h)
+    function loss3(r, h, x)
         y = []
         for x_t in x
-            h = r(x_t, h)
+            h = r(h, x_t)
             y = [y..., h]
         end
         return sum(mean, y)
     end
 
-    function loss4(r, x, h)
+    function loss4(r, h, x)
         y = []
         for x_t in x
-            h = r(x_t, h)
+            h = r(h, x_t)
             y = vcat(y, [h])
         end
         y = stack(y, dims=2) # [D, L] or [D, L, B]
@@ -38,28 +38,28 @@
     # Initial State is a single vector
     h = randn(Float32, 5)
-    test_gradients(r, x, h, loss=loss1) # for loop
-    test_gradients(r, x, h, loss=loss2) # comprehension
-    test_gradients(r, x, h, loss=loss3) # splat
-    test_gradients(r, x, h, loss=loss4) # vcat and stack
+    test_gradients(r, h, x, loss=loss1) # for loop
+    test_gradients(r, h, x, loss=loss2) # comprehension
+    test_gradients(r, h, x, loss=loss3) # splat
+    test_gradients(r, h, x, loss=loss4) # vcat and stack
 
     # no initial state same as zero initial state
-    @test r(x[1]) ≈ r(x[1], zeros(Float32, 5))
+    @test r(x[1]) ≈ r(zeros(Float32, 5), x[1])
 
     # Now initial state has a batch dimension.
     h = randn(Float32, 5, 4)
-    test_gradients(r, x, h, loss=loss4)
+    test_gradients(r, h, x, loss=loss4)
 
     # The input sequence has no batch dimension.
     x = [rand(Float32, 3) for _ in 1:6]
     h = randn(Float32, 5)
-    test_gradients(r, x, h, loss=loss4)
+    test_gradients(r, h, x, loss=loss4)
 
     # No Bias
     r = RNNCell(3 => 5, bias=false)
     @test length(Flux.trainables(r)) == 2
-    test_gradients(r, x, h, loss=loss4)
+    test_gradients(r, h, x, loss=loss4)
 end
 
 @testset "RNN" begin
@@ -70,7 +70,7 @@ end
 
     Flux.@layer :expand ModelRNN
 
-    (m::ModelRNN)(x) = m.rnn(x, m.h0)
+    (m::ModelRNN)(x) = m.rnn(m.h0, x)
 
     model = ModelRNN(RNN(2 => 4), zeros(Float32, 4))
 
@@ -82,7 +82,7 @@ end
 
     # no initial state same as zero initial state
     rnn = model.rnn
-    @test rnn(x) ≈ rnn(x, zeros(Float32, 4))
+    @test rnn(x) ≈ rnn(zeros(Float32, 4), x)
 
     x = rand(Float32, 2, 3)
     y = model(x)
@@ -93,12 +93,12 @@ end
 
 @testset "LSTMCell" begin
 
-    function loss(r, x, hc)
+    function loss(r, hc, x)
         h, c = hc
         h′ = []
         c′ = []
         for x_t in x
-            h, c = r(x_t, (h, c))
+            h, c = r((h, c), x_t)
             h′ = vcat(h′, [h])
             c′ = [c′..., c]
         end
@@ -112,17 +112,17 @@ end
     x = [rand(Float32, 3, 4) for _ in 1:6]
     h = zeros(Float32, 5, 4)
     c = zeros(Float32, 5, 4)
 
-    hnew, cnew = cell(x[1], (h, c))
+    hnew, cnew = cell((h, c), x[1])
     @test hnew isa Matrix{Float32}
     @test cnew isa Matrix{Float32}
     @test size(hnew) == (5, 4)
     @test size(cnew) == (5, 4)
-    test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1]))
-    test_gradients(cell, x, (h, c), loss = loss)
+    test_gradients(cell, (h, c), x[1], loss = (m, hc, x) -> mean(m(hc, x)[1]))
+    test_gradients(cell, (h, c), x, loss = loss)
 
     # no initial state same as zero initial state
     hnew1, cnew1 = cell(x[1])
-    hnew2, cnew2 = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5)))
+    hnew2, cnew2 = cell((zeros(Float32, 5), zeros(Float32, 5)), x[1])
     @test hnew1 ≈ hnew2
     @test cnew1 ≈ cnew2
@@ -140,7 +140,7 @@ end
 
     Flux.@layer :expand ModelLSTM
 
-    (m::ModelLSTM)(x) = m.lstm(x, (m.h0, m.c0))
+    (m::ModelLSTM)(x) = m.lstm((m.h0, m.c0), x)
 
     model = ModelLSTM(LSTM(2 => 4), zeros(Float32, 4), zeros(Float32, 4))
 
@@ -162,10 +162,10 @@ end
 end
 
 @testset "GRUCell" begin
-    function loss(r, x, h)
+    function loss(r, h, x)
         y = []
         for x_t in x
-            h = r(x_t, h)
+            h = r(h, x_t)
             y = vcat(y, [h])
         end
         y = stack(y, dims=2) # [D, L] or [D, L, B]
@@ -179,19 +179,19 @@ end
 
     # Initial State is a single vector
     h = randn(Float32, 5)
-    test_gradients(r, x, h; loss)
+    test_gradients(r, h, x; loss)
 
     # no initial state same as zero initial state
-    @test r(x[1]) ≈ r(x[1], zeros(Float32, 5))
+    @test r(x[1]) ≈ r(zeros(Float32, 5), x[1])
 
     # Now initial state has a batch dimension.
     h = randn(Float32, 5, 4)
-    test_gradients(r, x, h; loss)
+    test_gradients(r, h, x; loss)
 
     # The input sequence has no batch dimension.
     x = [rand(Float32, 3) for _ in 1:6]
     h = randn(Float32, 5)
-    test_gradients(r, x, h; loss)
+    test_gradients(r, h, x; loss)
 
     # No Bias
     r = GRUCell(3 => 5, bias=false)
@@ -206,7 +206,7 @@ end
 
     Flux.@layer :expand ModelGRU
 
-    (m::ModelGRU)(x) = m.gru(x, m.h0)
+    (m::ModelGRU)(x) = m.gru(m.h0, x)
 
    model = ModelGRU(GRU(2 => 4), zeros(Float32, 4))
 
@@ -218,7 +218,7 @@ end
 
     # no initial state same as zero initial state
     gru = model.gru
-    @test gru(x) ≈ gru(x, zeros(Float32, 4))
+    @test gru(x) ≈ gru(zeros(Float32, 4), x)
 
     # No Bias
     gru = GRU(2 => 4, bias=false)
@@ -233,19 +233,19 @@ end
 
     # Initial State is a single vector
    h = randn(Float32, 5)
-    test_gradients(r, x, h)
+    test_gradients(r, h, x)
 
    # no initial state same as zero initial state
-    @test r(x) ≈ r(x, zeros(Float32, 5))
+    @test r(x) ≈ r(zeros(Float32, 5), x)
 
    # Now initial state has a batch dimension.
    h = randn(Float32, 5, 4)
-    test_gradients(r, x, h)
+    test_gradients(r, h, x)
 
    # The input sequence has no batch dimension.
    x = rand(Float32, 3)
    h = randn(Float32, 5)
-    test_gradients(r, x, h)
+    test_gradients(r, h, x)
 end
 
 @testset "GRUv3" begin
@@ -256,7 +256,7 @@ end
 
    Flux.@layer :expand ModelGRUv3
 
-    (m::ModelGRUv3)(x) = m.gru(x, m.h0)
+    (m::ModelGRUv3)(x) = m.gru(m.h0, x)
 
    model = ModelGRUv3(GRUv3(2 => 4), zeros(Float32, 4))
 
@@ -268,5 +268,5 @@ end
 
    # no initial state same as zero initial state
    gru = model.gru
-    @test gru(x) ≈ gru(x, zeros(Float32, 4))
+    @test gru(x) ≈ gru(zeros(Float32, 4), x)
 end
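For reviewers, here is a minimal usage sketch of the new argument order, assembled from the updated docstrings above. The layer sizes and sequence lengths are illustrative, not taken from the PR.

```julia
using Flux

# Single step with a cell: the hidden state now comes first, `rnncell(h, x)`.
rnncell = Flux.RNNCell(3 => 5)
h = zeros(Float32, 5)       # hidden state, size `out`
x_t = rand(Float32, 3, 4)   # input, size `in x batch_size`
h = rnncell(h, x_t)         # previously `rnncell(x_t, h)`

# Whole sequences with the layer: `rnn(h, x)`.
d_in, d_out, len, batch_size = 4, 6, 3, 5
rnn = RNN(d_in => d_out)
x = rand(Float32, d_in, len, batch_size)
h0 = zeros(Float32, d_out)
y = rnn(h0, x)              # [y] = [d_out, len, batch_size]

# LSTM cells take the state tuple first: `lstmcell((h, c), x)`.
lstmcell = Flux.LSTMCell(3 => 5)
h, c = zeros(Float32, 5), zeros(Float32, 5)
h′, c′ = lstmcell((h, c), rand(Float32, 3, 4))
```

Omitting the state argument still defaults it to zeros, as in `rnncell(x_t)` or `lstmcell(x_t)`.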