Skip to content

Commit bfb948f

Browse files
committed
_Rec template
1 parent 7b510c8 commit bfb948f

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

nn/rec.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,40 @@
22
Basic RNNs.
33
"""
44

5+
from typing import Optional, Union, Dict, Tuple, Any
56
from .. import nn
6-
from ._generated_layers import _Rec
7+
8+
9+
class _Rec(nn.Module):
10+
def __init__(self, *, out_dim: nn.Dim, unit: str, unit_opts: Optional[Dict[str, Any]] = None):
11+
super().__init__()
12+
self.out_dim = out_dim
13+
self.unit = unit
14+
self.unit_opts = unit_opts
15+
16+
def __call__(self, source: nn.LayerRef, *,
17+
in_dim: Optional[nn.Dim] = None,
18+
axis: nn.Dim,
19+
state: Optional[Union[nn.LayerRef, Dict[str, nn.LayerRef], nn.NotSpecified]] = nn.NotSpecified,
20+
initial_state: Optional[Union[nn.LayerRef, Dict[str, nn.LayerRef], nn.NotSpecified]] = nn.NotSpecified,
21+
) -> Tuple[nn.Layer, nn.LayerState]:
22+
pass # TODO
723

824

925
class LSTM(_Rec):
1026
"""
1127
LSTM operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
1228
"""
13-
def __init__(self, out_dim: nn.Dim, **kwargs):
14-
super().__init__(out_dim=out_dim, unit="nativelstm2", **kwargs)
29+
def __init__(self, out_dim: nn.Dim):
30+
super().__init__(out_dim=out_dim, unit="nativelstm2")
1531

1632

1733
class ZoneoutLSTM(_Rec):
1834
"""
1935
LSTM with zoneout operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
2036
"""
21-
def __init__(self, n_out: int, zoneout_factor_cell: float = 0., zoneout_factor_output: float = 0., **kwargs):
37+
def __init__(self, out_dim: nn.Dim, zoneout_factor_cell: float = 0., zoneout_factor_output: float = 0.):
2238
super().__init__(
23-
n_out=n_out, unit="zoneoutlstm",
39+
out_dim=out_dim, unit="zoneoutlstm",
2440
unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output},
25-
**kwargs)
41+
)

0 commit comments

Comments
 (0)