|
13 | 13 | you only care about some encoded vector of type :class:`Tensor`.
|
14 | 14 | """
|
15 | 15 |
|
| 16 | +from typing import Tuple |
| 17 | +from abc import ABC |
16 | 18 | from ... import nn
|
17 | 19 |
|
18 | 20 |
|
19 |
class IEncoder(nn.Module, ABC):
    """
    Generic encoder interface.

    An encoder is a function x -> y.
    The input may be either sparse or dense.
    The output is dense, and its feature dim is `out_dim`.
    """

    # Feature dim of the encoder output.
    # Only declared here; presumably assigned by concrete subclasses.
    out_dim: nn.Dim

    @nn.scoped
    def __call__(self, source: nn.Tensor) -> nn.Tensor:
        """
        Encode the input.

        :param source: the input to encode
        :return: the encoded output
        :raises NotImplementedError: always, in this base interface
        """
        raise NotImplementedError()
|
| 38 | + |
| 39 | + |
class ISeqFramewiseEncoder(nn.Module, ABC):
    """
    Specialization of the generic encoder interface in that it operates on a sequence.
    The output sequence length here is the same as the input sequence length.
    """

    # Feature dim of the encoder output.
    # Only declared here; presumably assigned by concrete subclasses.
    out_dim: nn.Dim

    @nn.scoped
    def __call__(self, source: nn.Tensor, *, spatial_dim: nn.Dim) -> nn.Tensor:
        """
        Encode the input sequence frame by frame.

        :param source: the input sequence to encode
        :param spatial_dim: presumably the sequence (time) dim of `source` -- TODO confirm
        :return: the encoded output
        :raises NotImplementedError: always, in this base interface
        """
        raise NotImplementedError()
| 51 | + |
| 52 | + |
class ISeqDownsamplingEncoder(nn.Module, ABC):
    """
    More specific than the generic encoder interface in that it operates on a sequence.
    The output sequence length here is shorter than the input sequence length.

    This is a common scenario for speech recognition,
    where the input might be at 10ms/frame
    and the output might cover 30ms/frame or 60ms/frame or so.
    """

    # Feature dim of the encoder output.
    # Only declared here; presumably assigned by concrete subclasses.
    out_dim: nn.Dim

    @nn.scoped
    def __call__(self, source: nn.Tensor, *, in_spatial_dim: nn.Dim) -> Tuple[nn.Tensor, nn.Dim]:
        """
        Encode the input sequence into a shorter output sequence.

        :param source: the input sequence to encode
        :param in_spatial_dim: presumably the sequence (time) dim of `source` -- TODO confirm
        :return: a (tensor, dim) pair; presumably the encoded output
            and its new, shorter spatial dim -- TODO confirm
        :raises NotImplementedError: always, in this base interface
        """
        raise NotImplementedError()
0 commit comments