@@ -30,9 +30,9 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
30
30
"""
31
31
# https://github.com/rwth-i6/returnn_common/issues/16
32
32
with nn .Loop () as loop :
33
- x_ = loop .unstack (x , axis = "T" , declare_rec_time = True )
33
+ x_ = loop .unstack (x , axis = "T" , declare_rec_time = True ) # TODO how to get axis?
34
34
loop .state .h = nn .State (initial = 0 ) # TODO proper initial...
35
- loop .state .h = self .rec_linear (nn .concat ((x_ , "F" ), (loop .state .h , "F" )))
35
+ loop .state .h = self .rec_linear (nn .concat ((x_ , "F" ), (loop .state .h , self . rec_linear . out_dim ))) # TODO dim?
36
36
y = loop .stack (loop .state .h )
37
37
return y
38
38
@@ -59,7 +59,9 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
59
59
60
60
def test_lstm_default_name ():
61
61
assert_equal (nn .LSTM (nn .FeatureDim ("out" , 3 )).get_default_name (), "lstm" )
62
- assert_equal (nn .LSTMStep (nn .FeatureDim ("out" , 3 )).get_default_name (), "lstm" )
62
+ # no LSTMStep anymore, so nothing really to test here.
63
+ # https://github.com/rwth-i6/returnn_common/issues/81
64
+ # assert_equal(nn.LSTMStep(nn.FeatureDim("out", 3)).get_default_name(), "lstm")
63
65
64
66
65
67
def test_rec_inner_lstm ():
@@ -74,7 +76,7 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
74
76
Forward
75
77
"""
76
78
with nn .Loop () as loop :
77
- x_ = loop .unstack (x , axis = "T" , declare_rec_time = True )
79
+ x_ = loop .unstack (x , axis = "T" , declare_rec_time = True ) # TODO how to get axis?
78
80
loop .state .lstm = nn .State (initial = self .lstm .default_initial_state ())
79
81
y_ , loop .state .lstm = self .lstm (x_ , state = loop .state .lstm )
80
82
y = loop .stack (y_ )
@@ -98,7 +100,7 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
98
100
loop .state .i = nn .State (initial = 0. )
99
101
loop .state .i = loop .state .i + 1.
100
102
loop .end (loop .state .i >= 5. , include_eos = True )
101
- y = loop .stack (loop .state .i * nn .reduce (x , mode = "mean" , axis = "T" ))
103
+ y = loop .stack (loop .state .i * nn .reduce (x , mode = "mean" , axis = "T" )) # TODO axis
102
104
return y
103
105
104
106
net = _Net ()
@@ -118,9 +120,9 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
118
120
"""
119
121
Forward
120
122
"""
121
- y , state = self .lstm (x )
122
- y_ = nn .reduce ( y , mode = "mean" , axis = "T" ) # TODO just because concat allow_broadcast=True does not work yet...
123
- res = nn . concat (( y_ , "F" ), (state .h , "F" ), (state .c , "F" ) )
123
+ y , state = self .lstm (x ) # TODO axis
124
+ res = nn .concat (
125
+ ( y , self . lstm . out_dim ), (state .h , self . lstm . out_dim ), (state .c , self . lstm . out_dim ), allow_broadcast = True )
124
126
return res
125
127
126
128
net = _Net ()
@@ -144,7 +146,7 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
144
146
y = self .linear (x )
145
147
state = None
146
148
for _ in range_ (3 ):
147
- y , state = self .lstm (y , initial_state = state )
149
+ y , state = self .lstm (y , initial_state = state ) # TODO axis?
148
150
return y
149
151
150
152
net = _Net ()
0 commit comments