Skip to content

Commit d265d2b

Browse files
committed
added explicit parameters
1 parent f1dd15f commit d265d2b

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

nn/rec.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ class ZoneoutLSTM(_Rec):
4141
"""
4242
LSTM with zoneout operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
4343
"""
44-
def __init__(self, n_out: int, **kwargs):
45-
super().__init__(n_out=n_out, unit="zoneoutlstm", **kwargs)
44+
def __init__(self, n_out: int, zoneout_factor_cell: int = 0., zoneout_factor_output: int = 0., **kwargs):
45+
super().__init__(
46+
unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output},
47+
n_out=n_out, unit="zoneoutlstm", **kwargs)
4648

4749
# noinspection PyMethodOverriding
4850
def make_layer_dict(
@@ -57,8 +59,10 @@ class ZoneoutLSTMStep(_Rec):
5759
"""
5860
default_name = "zoneoutlstm" # make consistent to ZoneoutLSTM
5961

60-
def __init__(self, n_out: int, **kwargs):
61-
super().__init__(n_out=n_out, unit="zoneoutlstm", **kwargs)
62+
def __init__(self, n_out: int, zoneout_factor_cell: int = 0., zoneout_factor_output: int = 0., **kwargs):
63+
super().__init__(
64+
unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output},
65+
n_out=n_out, unit="zoneoutlstm", **kwargs)
6266

6367
# noinspection PyMethodOverriding
6468
def make_layer_dict(

0 commit comments

Comments
 (0)