Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding initialstates function to RNNs #2541

Merged
merged 5 commits into from
Dec 10, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@

@layer RNNCell

"""
    initialstates(rnn::RNNCell)

Return the default state for `rnn`: the result of calling `zeros_like` with `rnn.Wh`
as the template array and `size(rnn.Wh, 2)` as the size argument.
"""
function initialstates(rnn::RNNCell)
    hidden_len = size(rnn.Wh, 2)
    return zeros_like(rnn.Wh, hidden_len)
end

Check warning on line 72 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L72

Added line #L72 was not covered by tests
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

function RNNCell(
(in, out)::Pair,
σ = tanh;
Expand All @@ -82,7 +84,10 @@
return RNNCell(σ, Wi, Wh, b)
end

(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
# Call `rnn` on input `x` when the caller supplies no state: fetch the default
# state from `initialstates(rnn)` and delegate to the two-argument method
# `rnn(x, state)`.
function (rnn::RNNCell)(x::AbstractVecOrMat)
state = initialstates(rnn)
rnn(x, state)  # no explicit `return`; the value of this call is the result

Check warning on line 89 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L87-L89

Added lines #L87 - L89 were not covered by tests
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
end

function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down Expand Up @@ -261,6 +266,10 @@

@layer LSTMCell

# Default state for an `LSTMCell`: a 2-tuple whose elements come from two separate
# `zeros_like(lstm.Wh, size(lstm.Wh, 2))` calls, so the two states are distinct
# objects rather than aliases of one another. The one-argument call method
# destructures this tuple and forwards it as the `(h, c)` pair.
# NOTE(review): the signature has a stray space after `::` and the return line is
# long; both are cosmetic only.
function initialstates(lstm:: LSTMCell)
return zeros_like(lstm.Wh, size(lstm.Wh, 2)), zeros_like(lstm.Wh, size(lstm.Wh, 2))

Check warning on line 270 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L269-L270

Added lines #L269 - L270 were not covered by tests
end

function LSTMCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
Expand All @@ -274,10 +283,9 @@
return cell
end

function (m::LSTMCell)(x::AbstractVecOrMat)
h = zeros_like(x, size(m.Wh, 2))
c = zeros_like(h)
return m(x, (h, c))
# Call `lstm` on input `x` when the caller supplies no state: take the pair
# returned by `initialstates(lstm)` and delegate to the two-argument method,
# passing the pair as a tuple.
function (lstm::LSTMCell)(x::AbstractVecOrMat)
state, cstate = initialstates(lstm)  # destructure the 2-tuple of default states
return lstm(x, (state, cstate))

Check warning on line 288 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L286-L288

Added lines #L286 - L288 were not covered by tests
end

function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
Expand Down Expand Up @@ -447,6 +455,8 @@

@layer GRUCell

"""
    initialstates(gru::GRUCell)

Return the default state for `gru`, obtained from `zeros_like` using `gru.Wh` as the
template array and `size(gru.Wh, 2)` as the size argument.
"""
function initialstates(gru::GRUCell)
    Wh = gru.Wh
    return zeros_like(Wh, size(Wh, 2))
end

Check warning on line 458 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L458

Added line #L458 was not covered by tests

function GRUCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
Expand All @@ -459,7 +469,10 @@
return GRUCell(Wi, Wh, b)
end

(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
# Call `gru` on input `x` when the caller supplies no state: fetch the default
# state from `initialstates(gru)` and delegate to the two-argument method
# `gru(x, state)`.
function (gru::GRUCell)(x::AbstractVecOrMat)
state = initialstates(gru)
return gru(x, state)

Check warning on line 474 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L472-L474

Added lines #L472 - L474 were not covered by tests
end

function (m::GRUCell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down Expand Up @@ -603,6 +616,8 @@

@layer GRUv3Cell

"""
    initialstates(gru::GRUv3Cell)

Return the default state for `gru`: `zeros_like` called with `gru.Wh` as the template
array and the second dimension of `gru.Wh` as the size argument.
"""
function initialstates(gru::GRUv3Cell)
    recurrent = gru.Wh
    ncols = size(recurrent, 2)
    return zeros_like(recurrent, ncols)
end

Check warning on line 619 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L619

Added line #L619 was not covered by tests

function GRUv3Cell(
(in, out)::Pair;
init_kernel = glorot_uniform,
Expand All @@ -616,7 +631,10 @@
return GRUv3Cell(Wi, Wh, b, Wh_h̃)
end

(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
# Call `gru` on input `x` when the caller supplies no state: fetch the default
# state from `initialstates(gru)` and delegate to the two-argument method
# `gru(x, state)`.
function (gru::GRUv3Cell)(x::AbstractVecOrMat)
state = initialstates(gru)
return gru(x, state)

Check warning on line 636 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L634-L636

Added lines #L634 - L636 were not covered by tests
end

function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down
Loading