@@ -11,14 +11,14 @@ class LSTM(_Rec):
11
11
"""
12
12
LSTM operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
13
13
"""
14
- def __init__ (self , n_out : int , ** kwargs ):
15
- super ().__init__ (n_out = n_out , unit = "nativelstm2" , ** kwargs )
14
+ def __init__ (self , out_dim : nn . Dim , ** kwargs ):
15
+ super ().__init__ (out_dim = out_dim , unit = "nativelstm2" , ** kwargs )
16
16
17
17
# noinspection PyMethodOverriding
18
18
def make_layer_dict (
19
- self , source : nn .LayerRef , * , initial_state : Optional [nn .LayerState ] = None ) -> nn .LayerDictRaw :
19
+ self , source : nn .LayerRef , * , axis : nn . Dim , initial_state : Optional [nn .LayerState ] = None ) -> nn .LayerDictRaw :
20
20
"""make layer"""
21
- return super ().make_layer_dict (source , initial_state = initial_state )
21
+ return super ().make_layer_dict (source , axis = axis , initial_state = initial_state )
22
22
23
23
24
24
class LSTMStep (_Rec ):
@@ -27,11 +27,12 @@ class LSTMStep(_Rec):
27
27
"""
28
28
default_name = "lstm" # make consistent to LSTM
29
29
30
- def __init__ (self , n_out : int , ** kwargs ):
31
- super ().__init__ (n_out = n_out , unit = "nativelstm2" , ** kwargs )
30
+ def __init__ (self , out_dim : nn . Dim , ** kwargs ):
31
+ super ().__init__ (out_dim = out_dim , unit = "nativelstm2" , ** kwargs )
32
32
33
33
# noinspection PyMethodOverriding
34
34
def make_layer_dict (
35
35
self , source : nn .LayerRef , * , state : nn .LayerState ) -> nn .LayerDictRaw :
36
36
"""make layer"""
37
- return super ().make_layer_dict (source , state = state )
37
+ # TODO specify per-step, how? this should also work without rec loop, when there is no time dim.
38
+ return super ().make_layer_dict (source , state = state , axis = None )
0 commit comments