Skip to content

Commit ca493ba

Browse files
committed
decoder interface more
#49
1 parent 3a498ee commit ca493ba

File tree

1 file changed

+169
-16
lines changed

1 file changed

+169
-16
lines changed

nn/decoder/base.py

+169-16
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""
1717

1818
from __future__ import annotations
19-
from typing import Union, Tuple
19+
from typing import Union, Optional, Tuple
2020
from enum import Enum
2121
import dataclasses
2222
from ... import nn
@@ -36,14 +36,14 @@ class Decoder(nn.Module):
3636
"""
3737
Generic decoder, for attention-based encoder-decoder or transducer.
3838
Can use label-sync label topology, or time-sync (RNA/CTC), or with vertical transitions (RNN-T).
39-
The label emitted in the current step is referred to as alignment-label (or step-label),
39+
The label emitted in the current (align) step is referred to as alignment-label (or step-label),
4040
and can include blank in case this is not label-sync.
4141
4242
None of this is really enforced here, and what mainly defines the interfaces
4343
is the dependency graph.
4444
The returned shapes and time axes could be anything,
4545
as long as they fit together.
46-
The step_sync_rnn could also return a 4D tensor with both time-axis and label-axis.
46+
The predictor could also return a 4D tensor with both time-axis and label-axis.
4747
4848
Dependency graph:
4949
@@ -59,19 +59,129 @@ class Decoder(nn.Module):
5959

6060
def __init__(self, *,
6161
label_topology: LabelTopology,
62-
label_sync_rnn: TDecoderLabelSync,
63-
joint_net_log_prob: TDecoderJointNetLogProb,
62+
label_predict_enc: Optional[TDecoderLabelSync],
63+
predictor: TDecoderJointNetLogProb,
64+
target_dim: nn.Dim,
65+
target_bos_symbol: int = 0,
66+
target_eos_symbol: int = 0,
6467
):
6568
super().__init__()
6669
self.label_topology = label_topology
67-
self.label_sync_rnn = label_sync_rnn # earlier: slow_rnn
68-
self.joint_net_log_prob = joint_net_log_prob # earlier: fast_rnn + readout
70+
self.label_predict_enc = label_predict_enc # earlier: slow_rnn. label-sync. incl (nb) label embedding
71+
self.predictor = predictor # earlier: fast_rnn + readout. align-sync or matrix time * label. predicts align label
72+
self.target_dim = target_dim # includes blank if not label-sync
73+
self.target_bos_symbol = target_bos_symbol
74+
self.target_eos_symbol = target_eos_symbol
6975

70-
def __call__(self, encoder: nn.Tensor) -> nn.Tensor:
76+
def __call__(self, *,
77+
encoder: nn.Tensor,
78+
encoder_spatial_axis: nn.Dim,
79+
target: Optional[Union[nn.Tensor, nn.SearchFuncInterface]] = None,
80+
axis: Optional[nn.Dim] = None,
81+
state: Optional[nn.LayerState] = None,
82+
) -> Tuple[nn.Tensor, nn.LayerState]:
7183
"""
7284
Make one decoder step (train and/or recognition).
7385
"""
7486
# TODO ...
87+
search = None
88+
if isinstance(target, nn.SearchFuncInterface):
89+
search = target
90+
target = None
91+
if target is not None:
92+
assert axis, f"{self}: Target spatial axis must be specified when target is given"
93+
loop = nn.Loop(axis=axis)
94+
loop.state = state if state else self.default_initial_state()
95+
with loop:
96+
97+
if self.label_predict_enc is None:
98+
label_predict_enc = None
99+
elif isinstance(self.label_predict_enc, IDecoderLabelSyncRnn):
100+
label_predict_enc, loop.state.label_predict_enc = self.label_predict_enc(
101+
prev_label=loop.state.label_nb,
102+
encoder_seq=encoder,
103+
state=loop.state.label_predict_enc)
104+
elif isinstance(self.label_predict_enc, IDecoderLabelSyncLabelsOnlyRnn):
105+
label_predict_enc, loop.state.label_predict_enc = self.label_predict_enc(
106+
prev_label=loop.state.label_nb,
107+
state=loop.state.label_predict_enc)
108+
elif isinstance(self.label_predict_enc, IDecoderLabelSyncAlignDepRnn):
109+
encoder_frame = ... # TODO align dep. or unstack if time-sync
110+
label_predict_enc, loop.state.label_predict_enc = self.label_predict_enc(
111+
prev_label=loop.state.label_nb,
112+
encoder_seq=encoder,
113+
encoder_frame=encoder_frame,
114+
state=loop.state.label_predict_enc)
115+
else:
116+
raise TypeError(f"{self}: Unsupported label_predict_enc type {type(self.label_predict_enc)}")
117+
118+
if isinstance(self.predictor, IDecoderLabelSyncLogits):
119+
assert self.label_topology == LabelTopology.LABEL_SYNC, f"{self}: Label topology must be label-sync"
120+
assert label_predict_enc is not None, f"{self}: Label predict encoder must be specified"
121+
probs = self.predictor(label_sync_in=label_predict_enc)
122+
probs_type = "logits"
123+
elif isinstance(self.predictor, IDecoderJointNoStateLogProb):
124+
assert self.label_topology != LabelTopology.LABEL_SYNC, f"{self}: Label topology must not be label-sync"
125+
assert label_predict_enc is not None, f"{self}: Label predict encoder must be specified"
126+
encoder_frame = ... # TODO share with above
127+
predictor_out = self.predictor(time_sync_in=encoder_frame, label_sync_in=label_predict_enc)
128+
probs = predictor_out.prob_like_wb
129+
probs_type = predictor_out.prob_like_type
130+
elif isinstance(self.predictor, IDecoderJointAlignStateLogProb):
131+
assert self.label_topology != LabelTopology.LABEL_SYNC, f"{self}: Label topology must not be label-sync"
132+
assert label_predict_enc is not None, f"{self}: Label predict encoder must be specified"
133+
encoder_frame = ... # TODO share with above
134+
predictor_out, loop.state.predictor = self.predictor(
135+
time_sync_in=encoder_frame,
136+
label_sync_in=label_predict_enc,
137+
prev_align_label=loop.state.label_wb,
138+
state=loop.state.predictor)
139+
probs = predictor_out.prob_like_wb
140+
probs_type = predictor_out.prob_like_type
141+
elif isinstance(self.predictor, IDecoderJointNoCtxLogProb):
142+
assert self.label_topology != LabelTopology.LABEL_SYNC, f"{self}: Label topology must not be label-sync"
143+
assert label_predict_enc is None, f"{self}: Label predict encoder not used"
144+
encoder_frame = ... # TODO share with above
145+
predictor_out = self.predictor(time_sync_in=encoder_frame)
146+
probs = predictor_out.prob_like_wb
147+
probs_type = predictor_out.prob_like_type
148+
elif isinstance(self.predictor, IDecoderAlignStateLogProb):
149+
assert self.label_topology != LabelTopology.LABEL_SYNC, f"{self}: Label topology must not be label-sync"
150+
assert label_predict_enc is None, f"{self}: Label predict encoder not used"
151+
encoder_frame = ... # TODO share with above
152+
predictor_out, loop.state.predictor = self.predictor(
153+
time_sync_in=encoder_frame,
154+
prev_align_label=loop.state.label_wb,
155+
state=loop.state.predictor)
156+
probs = predictor_out.prob_like_wb
157+
probs_type = predictor_out.prob_like_type
158+
else:
159+
raise TypeError(f"{self}: Unsupported predictor type {type(self.predictor)}")
160+
161+
# TODO loss handling here? in that case, cleverly do the most efficient?
162+
# TODO logits instead of log probs?
163+
# TODO see below; related is whether (and what) we output
164+
165+
target = loop.unstack(target) if target is not None else None
166+
if search:
167+
search.apply_loop(loop)
168+
align_label = search.choice(probs=probs, probs_type=probs_type)
169+
else:
170+
assert target is not None
171+
align_label = target
172+
if self.label_topology == LabelTopology.LABEL_SYNC:
173+
loop.state.label_nb = align_label
174+
loop.end(loop.state.label_nb == self.target_eos_symbol, include_eos=False)
175+
else:
176+
loop.state.label_wb = align_label
177+
178+
out_labels = loop.stack(align_label) if target is None else None
179+
# TODO? out_logits = loop.stack(logits) # TODO not necessarily logits...
180+
181+
return out_labels, loop.state
182+
183+
def default_initial_state(self) -> Optional[nn.LayerState]:
184+
"""default init state"""
75185

76186

77187
# TODO enc ctx module
@@ -90,6 +200,20 @@ def blank_idx(self) -> int:
90200
"""
91201
raise NotImplementedError
92202

203+
@property
204+
def prob_like_wb(self) -> nn.Tensor:
205+
"""
206+
:return: logits if possible, else log probs. see prob_like_type
207+
"""
208+
return self.log_prob_wb
209+
210+
@property
211+
def prob_like_type(self) -> str:
212+
"""
213+
:return: type of prob_like_wb. "logits" or "log_prob"
214+
"""
215+
return "log_prob"
216+
93217
@property
94218
def log_prob_wb(self) -> nn.Tensor:
95219
"""
@@ -152,12 +276,12 @@ class DecoderJointLogProbSeparatedOutput(IDecoderJointLogProbOutput):
152276
log_prob_not_blank: nn.Tensor # log(-expm1(log_prob_blank)) but you maybe could calc it more directly
153277

154278

155-
class IDecoderLabelSyncLogProb(nn.Module):
279+
class IDecoderLabelSyncLogits(nn.Module):
156280
"""
157281
For simple (maybe attention-based) encoder-decoder models,
158282
getting input from some label-sync encoding (TDecoderLabelSync).
159283
160-
This will produce log probs for non-blank labels.
284+
This will produce logits (non-normalized log probs) for non-blank labels.
161285
There is no blank in this concept.
162286
"""
163287

@@ -167,7 +291,7 @@ def __call__(self, *, label_sync_in: nn.Tensor) -> nn.Tensor:
167291

168292
class IDecoderJointNoStateLogProb(nn.Module):
169293
"""
170-
Joint network for transducer-like models:
294+
Joint network for transducer-like models (e.g. the original RNN-T):
171295
172296
Getting in time-sync inputs, label-sync inputs,
173297
producing probabilities for labels + blank.
@@ -186,7 +310,7 @@ def __call__(self, *, time_sync_in: nn.Tensor, label_sync_in: nn.Tensor) -> IDec
186310

187311
class IDecoderJointAlignStateLogProb(nn.Module):
188312
"""
189-
Joint network for transducer-like models:
313+
Joint network for transducer-like models (specifically the extended transducer model):
190314
191315
Getting in time-sync inputs, label-sync inputs,
192316
producing probabilities for labels + blank.
@@ -196,12 +320,42 @@ def __call__(self, *,
196320
time_sync_in: nn.Tensor,
197321
label_sync_in: nn.Tensor,
198322
prev_align_label: nn.Tensor,
199-
state: nn.LayerState,
323+
state: nn.LayerState, # align-sync
324+
) -> Tuple[IDecoderJointLogProbOutput, nn.LayerState]:
325+
raise NotImplementedError
326+
327+
328+
class IDecoderJointNoCtxLogProb(nn.Module):
329+
"""
330+
Joint network for CTC-like models, having no dependence on the label context:
331+
332+
Getting in time-sync inputs,
333+
producing probabilities for labels + blank.
334+
"""
335+
336+
def __call__(self, *, time_sync_in: nn.Tensor) -> IDecoderJointLogProbOutput:
337+
raise NotImplementedError
338+
339+
340+
class IDecoderAlignStateLogProb(nn.Module):
341+
"""
342+
Joint network for transducer-like models with no explicit non-blank (nb) label dependency, depending only on the align-label (like RNA):
343+
344+
Getting in time-sync inputs,
345+
producing probabilities for labels + blank.
346+
"""
347+
348+
def __call__(self, *,
349+
time_sync_in: nn.Tensor,
350+
prev_align_label: nn.Tensor,
351+
state: nn.LayerState, # align-sync
200352
) -> Tuple[IDecoderJointLogProbOutput, nn.LayerState]:
201353
raise NotImplementedError
202354

203355

204-
TDecoderJointNetLogProb = Union[IDecoderLabelSyncLogProb, IDecoderJointNoStateLogProb, IDecoderJointAlignStateLogProb]
356+
TDecoderJointNetLogProb = Union[
357+
IDecoderLabelSyncLogits, IDecoderJointNoStateLogProb, IDecoderJointAlignStateLogProb,
358+
IDecoderJointNoCtxLogProb, IDecoderAlignStateLogProb]
205359

206360

207361
class IDecoderLabelSyncRnn(nn.Module):
@@ -243,7 +397,6 @@ class IDecoderLabelSyncRnn(nn.Module):
243397
def __call__(self, *,
244398
prev_label: nn.Tensor,
245399
encoder_seq: nn.Tensor,
246-
encoder_frame: nn.Tensor,
247400
state: nn.LayerState,
248401
) -> Tuple[nn.Tensor, nn.LayerState]:
249402
raise NotImplementedError
@@ -301,7 +454,7 @@ def __call__(self, *,
301454
class IDecoderStepSyncRnn(nn.Module):
302455
"""
303456
Represents FastRNN in Transducer.
304-
Otherwise in general this runs step-synchronous,
457+
Otherwise, in general this runs step-synchronous,
305458
which is alignment-synchronous or time-synchronous for RNN-T/RNA/CTC,
306459
or label-synchronous for att-enc-dec.
307460
"""

0 commit comments

Comments
 (0)