Skip to content

Commit f1dd15f

Browse files
committed
added zoneout to rec
1 parent dc1f302 commit f1dd15f

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

nn/rec.py

+30
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,33 @@ def make_layer_dict(
3535
self, source: nn.LayerRef, *, state: nn.LayerState) -> nn.LayerDictRaw:
3636
"""make layer"""
3737
return super().make_layer_dict(source, state=state)
38+
39+
40+
class ZoneoutLSTM(_Rec):
41+
"""
42+
LSTM with zoneout operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
43+
"""
44+
def __init__(self, n_out: int, **kwargs):
45+
super().__init__(n_out=n_out, unit="zoneoutlstm", **kwargs)
46+
47+
# noinspection PyMethodOverriding
48+
def make_layer_dict(
49+
self, source: nn.LayerRef, *, initial_state: Optional[nn.LayerState] = None) -> nn.LayerDictRaw:
50+
"""make layer"""
51+
return super().make_layer_dict(source, initial_state=initial_state)
52+
53+
54+
class ZoneoutLSTMStep(_Rec):
55+
"""
56+
LSTM with zoneout operating one step. returns (output, state) tuple, where state is (h,c).
57+
"""
58+
default_name = "zoneoutlstm" # make consistent to ZoneoutLSTM
59+
60+
def __init__(self, n_out: int, **kwargs):
61+
super().__init__(n_out=n_out, unit="zoneoutlstm", **kwargs)
62+
63+
# noinspection PyMethodOverriding
64+
def make_layer_dict(
65+
self, source: nn.LayerRef, *, state: nn.LayerState) -> nn.LayerDictRaw:
66+
"""make layer"""
67+
return super().make_layer_dict(source, state=state)

0 commit comments

Comments
 (0)