@@ -35,3 +35,33 @@ def make_layer_dict(
35
35
self , source : nn .LayerRef , * , state : nn .LayerState ) -> nn .LayerDictRaw :
36
36
"""make layer"""
37
37
return super ().make_layer_dict (source , state = state )
38
+
39
+
40
+ class ZoneoutLSTM (_Rec ):
41
+ """
42
+ LSTM with zoneout operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
43
+ """
44
+ def __init__ (self , n_out : int , ** kwargs ):
45
+ super ().__init__ (n_out = n_out , unit = "zoneoutlstm" , ** kwargs )
46
+
47
+ # noinspection PyMethodOverriding
48
+ def make_layer_dict (
49
+ self , source : nn .LayerRef , * , initial_state : Optional [nn .LayerState ] = None ) -> nn .LayerDictRaw :
50
+ """make layer"""
51
+ return super ().make_layer_dict (source , initial_state = initial_state )
52
+
53
+
54
+ class ZoneoutLSTMStep (_Rec ):
55
+ """
56
+ LSTM with zoneout operating one step. returns (output, state) tuple, where state is (h,c).
57
+ """
58
+ default_name = "zoneoutlstm" # make consistent to ZoneoutLSTM
59
+
60
+ def __init__ (self , n_out : int , ** kwargs ):
61
+ super ().__init__ (n_out = n_out , unit = "zoneoutlstm" , ** kwargs )
62
+
63
+ # noinspection PyMethodOverriding
64
+ def make_layer_dict (
65
+ self , source : nn .LayerRef , * , state : nn .LayerState ) -> nn .LayerDictRaw :
66
+ """make layer"""
67
+ return super ().make_layer_dict (source , state = state )
0 commit comments