Skip to content

Commit 28e4921

Browse files
authored
added zoneout to rec.py (#86)
1 parent 0a83cdd commit 28e4921

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

nn/rec.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,37 @@ def make_layer_dict(
3737
# TODO specify per-step, how? this should also work without rec loop, when there is no time dim.
3838
# https://github.com/rwth-i6/returnn/issues/847
3939
return super().make_layer_dict(source, state=state, axis=None)
40+
41+
42+
class ZoneoutLSTM(_Rec):
43+
"""
44+
LSTM with zoneout operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
45+
"""
46+
def __init__(self, n_out: int, zoneout_factor_cell: int = 0., zoneout_factor_output: int = 0., **kwargs):
47+
super().__init__(
48+
n_out=n_out, unit="zoneoutlstm",
49+
unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output}, **kwargs)
50+
51+
# noinspection PyMethodOverriding
52+
def make_layer_dict(
53+
self, source: nn.LayerRef, *, initial_state: Optional[nn.LayerState] = None) -> nn.LayerDictRaw:
54+
"""make layer"""
55+
return super().make_layer_dict(source, initial_state=initial_state)
56+
57+
58+
class ZoneoutLSTMStep(_Rec):
59+
"""
60+
LSTM with zoneout operating one step. returns (output, state) tuple, where state is (h,c).
61+
"""
62+
default_name = "zoneoutlstm" # make consistent to ZoneoutLSTM
63+
64+
def __init__(self, n_out: int, zoneout_factor_cell: int = 0., zoneout_factor_output: int = 0., **kwargs):
65+
super().__init__(
66+
n_out=n_out, unit="zoneoutlstm",
67+
unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output}, **kwargs)
68+
69+
# noinspection PyMethodOverriding
70+
def make_layer_dict(
71+
self, source: nn.LayerRef, *, state: nn.LayerState) -> nn.LayerDictRaw:
72+
"""make layer"""
73+
return super().make_layer_dict(source, state=state)

0 commit comments

Comments
 (0)