@@ -37,3 +37,37 @@ def make_layer_dict(
37
37
# TODO specify per-step, how? this should also work without rec loop, when there is no time dim.
38
38
# https://github.com/rwth-i6/returnn/issues/847
39
39
return super ().make_layer_dict (source , state = state , axis = None )
40
+
41
+
42
class ZoneoutLSTM(_Rec):
    """
    LSTM with zoneout operating on a sequence.

    Returns an (output, final_state) tuple, where final_state is (h, c).
    """

    def __init__(self, n_out: int, zoneout_factor_cell: float = 0., zoneout_factor_output: float = 0., **kwargs):
        """
        :param n_out: forwarded to the base class as ``n_out``
        :param zoneout_factor_cell: stored under ``"zoneout_factor_cell"`` in the ``unit_opts``
            handed to the base class (presumably the zoneout rate for the cell state c -- confirm
            against the "zoneoutlstm" unit)
        :param zoneout_factor_output: stored under ``"zoneout_factor_output"`` in ``unit_opts``
            (presumably the zoneout rate for the output h -- confirm against the "zoneoutlstm" unit)
        :param kwargs: forwarded unchanged to the base class constructor
        """
        # Both factors default to the float 0.0, hence the ``float`` annotations.
        unit_opts = {
            'zoneout_factor_cell': zoneout_factor_cell,
            'zoneout_factor_output': zoneout_factor_output,
        }
        super().__init__(n_out=n_out, unit="zoneoutlstm", unit_opts=unit_opts, **kwargs)

    # noinspection PyMethodOverriding
    def make_layer_dict(
            self, source: nn.LayerRef, *, initial_state: Optional[nn.LayerState] = None) -> nn.LayerDictRaw:
        """
        Make the layer dict by delegating to the base class.

        :param source: input layer, passed positionally to the base implementation
        :param initial_state: optional initial state, forwarded as ``initial_state``
        :return: whatever the base class ``make_layer_dict`` returns
        """
        return super().make_layer_dict(source, initial_state=initial_state)
56
+
57
+
58
class ZoneoutLSTMStep(_Rec):
    """
    LSTM with zoneout operating one step.

    Returns an (output, state) tuple, where state is (h, c).
    """
    default_name = "zoneoutlstm"  # make consistent to ZoneoutLSTM

    def __init__(self, n_out: int, zoneout_factor_cell: float = 0., zoneout_factor_output: float = 0., **kwargs):
        """
        :param n_out: forwarded to the base class as ``n_out``
        :param zoneout_factor_cell: stored under ``"zoneout_factor_cell"`` in the ``unit_opts``
            handed to the base class (presumably the zoneout rate for the cell state c -- confirm
            against the "zoneoutlstm" unit)
        :param zoneout_factor_output: stored under ``"zoneout_factor_output"`` in ``unit_opts``
            (presumably the zoneout rate for the output h -- confirm against the "zoneoutlstm" unit)
        :param kwargs: forwarded unchanged to the base class constructor
        """
        # Both factors default to the float 0.0, hence the ``float`` annotations.
        unit_opts = {
            'zoneout_factor_cell': zoneout_factor_cell,
            'zoneout_factor_output': zoneout_factor_output,
        }
        super().__init__(n_out=n_out, unit="zoneoutlstm", unit_opts=unit_opts, **kwargs)

    # noinspection PyMethodOverriding
    def make_layer_dict(
            self, source: nn.LayerRef, *, state: nn.LayerState) -> nn.LayerDictRaw:
        """
        Make the layer dict by delegating to the base class.

        :param source: input layer, passed positionally to the base implementation
        :param state: state for this step (required), forwarded as ``state``
        :return: whatever the base class ``make_layer_dict`` returns
        """
        # NOTE(review): the make_layer_dict directly above ZoneoutLSTM in this file forwards
        # ``axis=None`` together with ``state``; confirm whether this call needs it as well.
        return super().make_layer_dict(source, state=state)
0 commit comments