|
2 | 2 | Basic RNNs.
|
3 | 3 | """
|
4 | 4 |
|
| 5 | +from typing import Optional, Union, Dict, Tuple, Any |
5 | 6 | from .. import nn
|
6 |
| -from ._generated_layers import _Rec |
| 7 | + |
| 8 | + |
| 9 | +class _Rec(nn.Module): |
| 10 | + def __init__(self, *, out_dim: nn.Dim, unit: str, unit_opts: Optional[Dict[str, Any]] = None): |
| 11 | + super().__init__() |
| 12 | + self.out_dim = out_dim |
| 13 | + self.unit = unit |
| 14 | + self.unit_opts = unit_opts |
| 15 | + |
| 16 | + def __call__(self, source: nn.LayerRef, *, |
| 17 | + in_dim: Optional[nn.Dim] = None, |
| 18 | + axis: nn.Dim, |
| 19 | + state: Optional[Union[nn.LayerRef, Dict[str, nn.LayerRef], nn.NotSpecified]] = nn.NotSpecified, |
| 20 | + initial_state: Optional[Union[nn.LayerRef, Dict[str, nn.LayerRef], nn.NotSpecified]] = nn.NotSpecified, |
| 21 | + ) -> Tuple[nn.Layer, nn.LayerState]: |
| 22 | + pass # TODO |
7 | 23 |
|
8 | 24 |
|
9 | 25 | class LSTM(_Rec):
|
10 | 26 | """
|
11 | 27 | LSTM operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
|
12 | 28 | """
|
13 |
| - def __init__(self, out_dim: nn.Dim, **kwargs): |
14 |
| - super().__init__(out_dim=out_dim, unit="nativelstm2", **kwargs) |
| 29 | + def __init__(self, out_dim: nn.Dim): |
| 30 | + super().__init__(out_dim=out_dim, unit="nativelstm2") |
15 | 31 |
|
16 | 32 |
|
17 | 33 | class ZoneoutLSTM(_Rec):
|
18 | 34 | """
|
19 | 35 | LSTM with zoneout operating on a sequence. returns (output, final_state) tuple, where final_state is (h,c).
|
20 | 36 | """
|
21 |
| - def __init__(self, n_out: int, zoneout_factor_cell: float = 0., zoneout_factor_output: float = 0., **kwargs): |
| 37 | + def __init__(self, out_dim: nn.Dim, zoneout_factor_cell: float = 0., zoneout_factor_output: float = 0.): |
22 | 38 | super().__init__(
|
23 |
| - n_out=n_out, unit="zoneoutlstm", |
| 39 | + out_dim=out_dim, unit="zoneoutlstm", |
24 | 40 | unit_opts={'zoneout_factor_cell': zoneout_factor_cell, 'zoneout_factor_output': zoneout_factor_output},
|
25 |
| - **kwargs) |
| 41 | + ) |
0 commit comments