@@ -41,8 +41,10 @@ class ZoneoutLSTM(_Rec):
41
41
"""
42
42
LSTM with zoneout operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
43
43
"""
44
- def __init__ (self , n_out : int , ** kwargs ):
45
- super ().__init__ (n_out = n_out , unit = "zoneoutlstm" , ** kwargs )
44
+ def __init__ (self , n_out : int , zoneout_factor_cell : int = 0. , zoneout_factor_output : int = 0. , ** kwargs ):
45
+ super ().__init__ (
46
+ unit_opts = {'zoneout_factor_cell' : zoneout_factor_cell , 'zoneout_factor_output' : zoneout_factor_output },
47
+ n_out = n_out , unit = "zoneoutlstm" , ** kwargs )
46
48
47
49
# noinspection PyMethodOverriding
48
50
def make_layer_dict (
@@ -57,8 +59,10 @@ class ZoneoutLSTMStep(_Rec):
57
59
"""
58
60
default_name = "zoneoutlstm" # make consistent to ZoneoutLSTM
59
61
60
- def __init__ (self , n_out : int , ** kwargs ):
61
- super ().__init__ (n_out = n_out , unit = "zoneoutlstm" , ** kwargs )
62
+ def __init__ (self , n_out : int , zoneout_factor_cell : int = 0. , zoneout_factor_output : int = 0. , ** kwargs ):
63
+ super ().__init__ (
64
+ unit_opts = {'zoneout_factor_cell' : zoneout_factor_cell , 'zoneout_factor_output' : zoneout_factor_output },
65
+ n_out = n_out , unit = "zoneoutlstm" , ** kwargs )
62
66
63
67
# noinspection PyMethodOverriding
64
68
def make_layer_dict (
0 commit comments