@@ -37,3 +37,37 @@ def make_layer_dict(
3737 # TODO specify per-step, how? this should also work without rec loop, when there is no time dim.
3838 # https://github.com/rwth-i6/returnn/issues/847
3939 return super ().make_layer_dict (source , state = state , axis = None )
40+
41+
class ZoneoutLSTM(_Rec):
  """
  LSTM with zoneout operating on a sequence. Returns (output, final_state) tuple, where final_state is (h,c).
  """

  def __init__(self, n_out: int, zoneout_factor_cell: float = 0., zoneout_factor_output: float = 0., **kwargs):
    """
    :param n_out: output dimension, forwarded to the base class as ``n_out``
    :param zoneout_factor_cell: forwarded as ``unit_opts["zoneout_factor_cell"]``.
      Presumably the zoneout rate for the cell state c -- confirm against the "zoneoutlstm" unit.
    :param zoneout_factor_output: forwarded as ``unit_opts["zoneout_factor_output"]``.
      Presumably the zoneout rate for the output h -- confirm against the "zoneoutlstm" unit.
    :param kwargs: passed through unchanged to the base class constructor
    """
    # The factors default to the float literal ``0.``, hence the ``float`` annotation.
    super().__init__(
      n_out=n_out, unit="zoneoutlstm",
      unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output},
      **kwargs)

  # noinspection PyMethodOverriding
  def make_layer_dict(
        self, source: nn.LayerRef, *, initial_state: Optional[nn.LayerState] = None) -> nn.LayerDictRaw:
    """
    Make the layer dict by delegating to the base class.

    :param source: input layer
    :param initial_state: optional initial state; forwarded as ``initial_state`` to the base implementation
    :return: whatever the base class ``make_layer_dict`` returns
    """
    return super().make_layer_dict(source, initial_state=initial_state)
56+
57+
class ZoneoutLSTMStep(_Rec):
  """
  LSTM with zoneout operating one step. Returns (output, state) tuple, where state is (h,c).
  """
  default_name = "zoneoutlstm"  # make consistent to ZoneoutLSTM

  def __init__(self, n_out: int, zoneout_factor_cell: float = 0., zoneout_factor_output: float = 0., **kwargs):
    """
    :param n_out: output dimension, forwarded to the base class as ``n_out``
    :param zoneout_factor_cell: forwarded as ``unit_opts["zoneout_factor_cell"]``.
      Presumably the zoneout rate for the cell state c -- confirm against the "zoneoutlstm" unit.
    :param zoneout_factor_output: forwarded as ``unit_opts["zoneout_factor_output"]``.
      Presumably the zoneout rate for the output h -- confirm against the "zoneoutlstm" unit.
    :param kwargs: passed through unchanged to the base class constructor
    """
    # The factors default to the float literal ``0.``, hence the ``float`` annotation.
    super().__init__(
      n_out=n_out, unit="zoneoutlstm",
      unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output},
      **kwargs)

  # noinspection PyMethodOverriding
  def make_layer_dict(
        self, source: nn.LayerRef, *, state: nn.LayerState) -> nn.LayerDictRaw:
    """
    Make the layer dict for a single step by delegating to the base class.

    :param source: input layer
    :param state: state (required, keyword-only); forwarded as ``state`` to the base implementation
    :return: whatever the base class ``make_layer_dict`` returns
    """
    return super().make_layer_dict(source, state=state)
0 commit comments