Skip to content

Commit 438d0e4

Browse files
committed
cleanup
No rec ...Step anymore. Fix #81.
1 parent de03751 commit 438d0e4

File tree

2 files changed

+14
-52
lines changed

2 files changed

+14
-52
lines changed

nn/rec.py

+3-43
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Basic RNNs.
33
"""
44

5-
from typing import Optional
65
from .. import nn
76
from ._generated_layers import _Rec
87

@@ -14,52 +13,13 @@ class LSTM(_Rec):
1413
def __init__(self, out_dim: nn.Dim, **kwargs):
1514
super().__init__(out_dim=out_dim, unit="nativelstm2", **kwargs)
1615

17-
def __call__(
18-
self, source: nn.LayerRef, *, axis: nn.Dim, initial_state: Optional[nn.LayerState] = None) -> nn.Layer:
19-
"""make layer"""
20-
return super()(source, axis=axis, initial_state=initial_state)
21-
22-
23-
class LSTMStep(_Rec):
24-
"""
25-
LSTM operating one step. returns (output, state) tuple, where state is (h,c).
26-
"""
27-
default_name = "lstm" # make consistent to LSTM
28-
29-
def __init__(self, out_dim: nn.Dim, **kwargs):
30-
super().__init__(out_dim=out_dim, unit="nativelstm2", **kwargs)
31-
32-
def __call__(self, source: nn.LayerRef, *, state: nn.LayerState) -> nn.Layer:
33-
"""make layer"""
34-
return super()(source, state=state, axis=nn.single_step_dim)
35-
3616

3717
class ZoneoutLSTM(_Rec):
3818
"""
3919
LSTM with zoneout operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
4020
"""
41-
def __init__(self, n_out: int, zoneout_factor_cell: int = 0., zoneout_factor_output: int = 0., **kwargs):
21+
def __init__(self, n_out: int, zoneout_factor_cell: float = 0., zoneout_factor_output: float = 0., **kwargs):
4222
super().__init__(
4323
n_out=n_out, unit="zoneoutlstm",
44-
unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output}, **kwargs)
45-
46-
def __call__(
47-
self, source: nn.LayerRef, *, axis: nn.Dim, initial_state: Optional[nn.LayerState] = None) -> nn.Layer:
48-
"""make layer"""
49-
return super()(source, axis=axis, initial_state=initial_state)
50-
51-
52-
class ZoneoutLSTMStep(_Rec):
53-
"""
54-
LSTM with zoneout operating one step. returns (output, state) tuple, where state is (h,c).
55-
"""
56-
default_name = "zoneoutlstm" # make consistent to ZoneoutLSTM
57-
58-
def __init__(self, n_out: int, zoneout_factor_cell: int = 0., zoneout_factor_output: int = 0., **kwargs):
59-
super().__init__(
60-
n_out=n_out, unit="zoneoutlstm",
61-
unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output}, **kwargs)
62-
63-
def __call__(self, source: nn.LayerRef, *, state: nn.LayerState) -> nn.Layer:
64-
"""make layer"""
65-
return super()(source, state=state, axis=nn.single_step_dim)
24+
unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output},
25+
**kwargs)

tests/test_models_rec.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
3030
"""
3131
# https://github.com/rwth-i6/returnn_common/issues/16
3232
with nn.Loop() as loop:
33-
x_ = loop.unstack(x, axis="T", declare_rec_time=True)
33+
x_ = loop.unstack(x, axis="T", declare_rec_time=True) # TODO how to get axis?
3434
loop.state.h = nn.State(initial=0) # TODO proper initial...
35-
loop.state.h = self.rec_linear(nn.concat((x_, "F"), (loop.state.h, "F")))
35+
loop.state.h = self.rec_linear(nn.concat((x_, "F"), (loop.state.h, self.rec_linear.out_dim))) # TODO dim?
3636
y = loop.stack(loop.state.h)
3737
return y
3838

@@ -59,7 +59,9 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
5959

6060
def test_lstm_default_name():
6161
assert_equal(nn.LSTM(nn.FeatureDim("out", 3)).get_default_name(), "lstm")
62-
assert_equal(nn.LSTMStep(nn.FeatureDim("out", 3)).get_default_name(), "lstm")
62+
# no LSTMStep anymore, so nothing really to test here.
63+
# https://github.com/rwth-i6/returnn_common/issues/81
64+
# assert_equal(nn.LSTMStep(nn.FeatureDim("out", 3)).get_default_name(), "lstm")
6365

6466

6567
def test_rec_inner_lstm():
@@ -74,7 +76,7 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
7476
Forward
7577
"""
7678
with nn.Loop() as loop:
77-
x_ = loop.unstack(x, axis="T", declare_rec_time=True)
79+
x_ = loop.unstack(x, axis="T", declare_rec_time=True) # TODO how to get axis?
7880
loop.state.lstm = nn.State(initial=self.lstm.default_initial_state())
7981
y_, loop.state.lstm = self.lstm(x_, state=loop.state.lstm)
8082
y = loop.stack(y_)
@@ -98,7 +100,7 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
98100
loop.state.i = nn.State(initial=0.)
99101
loop.state.i = loop.state.i + 1.
100102
loop.end(loop.state.i >= 5., include_eos=True)
101-
y = loop.stack(loop.state.i * nn.reduce(x, mode="mean", axis="T"))
103+
y = loop.stack(loop.state.i * nn.reduce(x, mode="mean", axis="T")) # TODO axis
102104
return y
103105

104106
net = _Net()
@@ -118,9 +120,9 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
118120
"""
119121
Forward
120122
"""
121-
y, state = self.lstm(x)
122-
y_ = nn.reduce(y, mode="mean", axis="T") # TODO just because concat allow_broadcast=True does not work yet...
123-
res = nn.concat((y_, "F"), (state.h, "F"), (state.c, "F"))
123+
y, state = self.lstm(x) # TODO axis
124+
res = nn.concat(
125+
(y, self.lstm.out_dim), (state.h, self.lstm.out_dim), (state.c, self.lstm.out_dim), allow_broadcast=True)
124126
return res
125127

126128
net = _Net()
@@ -144,7 +146,7 @@ def __call__(self, x: nn.LayerRef) -> nn.LayerRef:
144146
y = self.linear(x)
145147
state = None
146148
for _ in range_(3):
147-
y, state = self.lstm(y, initial_state=state)
149+
y, state = self.lstm(y, initial_state=state) # TODO axis?
148150
return y
149151

150152
net = _Net()

0 commit comments

Comments
 (0)