Skip to content

Commit 69cd23c

Browse files
committed
encoder interfaces
1 parent 83286fd commit 69cd23c

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

nn/encoder/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
"""
22
Code to create encoders (for hybrid, the encoder of encoder-decoder-attention, or also transducer).
33
"""
4+
5+
from .base import *

nn/encoder/base.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,55 @@
1313
you only care about some encoded vector of type :class:`Tensor`.
1414
"""
1515

16+
from typing import Tuple
17+
from abc import ABC
1618
from ... import nn
1719

1820

19-
class IEncoder(nn.Module):
21+
class IEncoder(nn.Module, ABC):
2022
"""
2123
Generic encoder interface
24+
25+
The encoder is a function x -> y.
26+
The input can potentially be sparse or dense.
27+
The output is dense with feature dim `out_dim`.
2228
"""
2329

30+
out_dim: nn.Dim
31+
2432
@nn.scoped
2533
def __call__(self, source: nn.Tensor) -> nn.Tensor:
2634
"""
2735
Encode the input
2836
"""
2937
raise NotImplementedError
38+
39+
40+
class ISeqFramewiseEncoder(nn.Module, ABC):
41+
"""
42+
This specializes IEncoder that it operates on a sequence.
43+
The output sequence length here is the same as the input.
44+
"""
45+
46+
out_dim: nn.Dim
47+
48+
@nn.scoped
49+
def __call__(self, source: nn.Tensor, *, spatial_dim: nn.Dim) -> nn.Tensor:
50+
raise NotImplementedError
51+
52+
53+
class ISeqDownsamplingEncoder(nn.Module, ABC):
54+
"""
55+
This is more specific than IEncoder in that it operates on a sequence.
56+
The output sequence length here is shorter than the input.
57+
58+
This is a common scenario for speech recognition
59+
where the input might be on 10ms/frame
60+
and the output might cover 30ms/frame or 60ms/frame or so.
61+
"""
62+
63+
out_dim: nn.Dim
64+
65+
@nn.scoped
66+
def __call__(self, source: nn.Tensor, *, in_spatial_dim: nn.Dim) -> Tuple[nn.Tensor, nn.Dim]:
67+
raise NotImplementedError

0 commit comments

Comments
 (0)