Skip to content

Commit 6d26350

Browse files
committed
fix, todo
1 parent 0c2504a commit 6d26350

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

nn/rec.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ class LSTM(_Rec):
1111
"""
1212
LSTM operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
1313
"""
14-
def __init__(self, n_out: int, **kwargs):
15-
super().__init__(n_out=n_out, unit="nativelstm2", **kwargs)
14+
def __init__(self, out_dim: nn.Dim, **kwargs):
15+
super().__init__(out_dim=out_dim, unit="nativelstm2", **kwargs)
1616

1717
# noinspection PyMethodOverriding
1818
def make_layer_dict(
19-
self, source: nn.LayerRef, *, initial_state: Optional[nn.LayerState] = None) -> nn.LayerDictRaw:
19+
self, source: nn.LayerRef, *, axis: nn.Dim, initial_state: Optional[nn.LayerState] = None) -> nn.LayerDictRaw:
2020
"""make layer"""
21-
return super().make_layer_dict(source, initial_state=initial_state)
21+
return super().make_layer_dict(source, axis=axis, initial_state=initial_state)
2222

2323

2424
class LSTMStep(_Rec):
@@ -27,11 +27,12 @@ class LSTMStep(_Rec):
2727
"""
2828
default_name = "lstm" # make consistent to LSTM
2929

30-
def __init__(self, n_out: int, **kwargs):
31-
super().__init__(n_out=n_out, unit="nativelstm2", **kwargs)
30+
def __init__(self, out_dim: nn.Dim, **kwargs):
31+
super().__init__(out_dim=out_dim, unit="nativelstm2", **kwargs)
3232

3333
# noinspection PyMethodOverriding
3434
def make_layer_dict(
3535
self, source: nn.LayerRef, *, state: nn.LayerState) -> nn.LayerDictRaw:
3636
"""make layer"""
37-
return super().make_layer_dict(source, state=state)
37+
# TODO specify per-step, how? this should also work without rec loop, when there is no time dim.
38+
return super().make_layer_dict(source, state=state, axis=None)

0 commit comments

Comments
 (0)