16
16
"""
17
17
18
18
from __future__ import annotations
19
- from typing import Union , Tuple
19
+ from typing import Union , Optional , Tuple
20
20
from enum import Enum
21
21
import dataclasses
22
22
from ... import nn
@@ -36,14 +36,14 @@ class Decoder(nn.Module):
36
36
"""
37
37
Generic decoder, for attention-based encoder-decoder or transducer.
38
38
Can use label-sync label topology, or time-sync (RNA/CTC), or with vertical transitions (RNN-T).
39
- The label emitted in the current step is referred to as alignment-label (or step-label),
39
+ The label emitted in the current (align) step is referred to as alignment-label (or step-label),
40
40
and can include blank in case this is not label-sync.
41
41
42
42
None of this is really enforced here, and what mainly defines the interfaces
43
43
is the dependency graph.
44
44
The returned shapes and time axes could be anything,
45
45
as long as it fits together.
46
- The step_sync_rnn could also return a 4D tensor with both time-axis and label-axis.
46
+ The predictor could also return a 4D tensor with both time-axis and label-axis.
47
47
48
48
Dependency graph:
49
49
@@ -59,19 +59,129 @@ class Decoder(nn.Module):
59
59
60
60
def __init__ (self , * ,
61
61
label_topology : LabelTopology ,
62
- label_sync_rnn : TDecoderLabelSync ,
63
- joint_net_log_prob : TDecoderJointNetLogProb ,
62
+ label_predict_enc : Optional [TDecoderLabelSync ],
63
+ predictor : TDecoderJointNetLogProb ,
64
+ target_dim : nn .Dim ,
65
+ target_bos_symbol : int = 0 ,
66
+ target_eos_symbol : int = 0 ,
64
67
):
65
68
super ().__init__ ()
66
69
self .label_topology = label_topology
67
- self .label_sync_rnn = label_sync_rnn # earlier: slow_rnn
68
- self .joint_net_log_prob = joint_net_log_prob # earlier: fast_rnn + readout
70
+ self .label_predict_enc = label_predict_enc # earlier: slow_rnn. label-sync. incl (nb) label embedding
71
+ self .predictor = predictor # earlier: fast_rnn + readout. align-sync or matrix time * label. predicts align label
72
+ self .target_dim = target_dim # includes blank if not label-sync
73
+ self .target_bos_symbol = target_bos_symbol
74
+ self .target_eos_symbol = target_eos_symbol
69
75
70
- def __call__ (self , encoder : nn .Tensor ) -> nn .Tensor :
76
+ def __call__ (self , * ,
77
+ encoder : nn .Tensor ,
78
+ encoder_spatial_axis : nn .Dim ,
79
+ target : Optional [Union [nn .Tensor , nn .SearchFuncInterface ]] = None ,
80
+ axis : Optional [nn .Dim ] = None ,
81
+ state : Optional [nn .LayerState ] = None ,
82
+ ) -> Tuple [nn .Tensor , nn .LayerState ]:
71
83
"""
72
84
Make one decoder step (train and/or recognition).
73
85
"""
74
86
# TODO ...
87
+ search = None
88
+ if isinstance (target , nn .SearchFuncInterface ):
89
+ search = target
90
+ target = None
91
+ if target is not None :
92
+ assert axis , f"{ self } : Target spatial axis must be specified when target is given"
93
+ loop = nn .Loop (axis = axis )
94
+ loop .state = state if state else self .default_initial_state ()
95
+ with loop :
96
+
97
+ if self .label_predict_enc is None :
98
+ label_predict_enc = None
99
+ elif isinstance (self .label_predict_enc , IDecoderLabelSyncRnn ):
100
+ label_predict_enc , loop .state .label_predict_enc = self .label_predict_enc (
101
+ prev_label = loop .state .label_nb ,
102
+ encoder_seq = encoder ,
103
+ state = loop .state .label_predict_enc )
104
+ elif isinstance (self .label_predict_enc , IDecoderLabelSyncLabelsOnlyRnn ):
105
+ label_predict_enc , loop .state .label_predict_enc = self .label_predict_enc (
106
+ prev_label = loop .state .label_nb ,
107
+ state = loop .state .label_predict_enc )
108
+ elif isinstance (self .label_predict_enc , IDecoderLabelSyncAlignDepRnn ):
109
+ encoder_frame = ... # TODO align dep. or unstack if time-sync
110
+ label_predict_enc , loop .state .label_predict_enc = self .label_predict_enc (
111
+ prev_label = loop .state .label_nb ,
112
+ encoder_seq = encoder ,
113
+ encoder_frame = encoder_frame ,
114
+ state = loop .state .label_predict_enc )
115
+ else :
116
+ raise TypeError (f"{ self } : Unsupported label_predict_enc type { type (self .label_predict_enc )} " )
117
+
118
+ if isinstance (self .predictor , IDecoderLabelSyncLogits ):
119
+ assert self .label_topology == LabelTopology .LABEL_SYNC , f"{ self } : Label topology must be label-sync"
120
+ assert label_predict_enc is not None , f"{ self } : Label predict encoder must be specified"
121
+ probs = self .predictor (label_sync_in = label_predict_enc )
122
+ probs_type = "logits"
123
+ elif isinstance (self .predictor , IDecoderJointNoStateLogProb ):
124
+ assert self .label_topology != LabelTopology .LABEL_SYNC , f"{ self } : Label topology must not be label-sync"
125
+ assert label_predict_enc is not None , f"{ self } : Label predict encoder must be specified"
126
+ encoder_frame = ... # TODO share with above
127
+ predictor_out = self .predictor (time_sync_in = encoder_frame , label_sync_in = label_predict_enc )
128
+ probs = predictor_out .prob_like_wb
129
+ probs_type = predictor_out .prob_like_type
130
+ elif isinstance (self .predictor , IDecoderJointAlignStateLogProb ):
131
+ assert self .label_topology != LabelTopology .LABEL_SYNC , f"{ self } : Label topology must not be label-sync"
132
+ assert label_predict_enc is not None , f"{ self } : Label predict encoder must be specified"
133
+ encoder_frame = ... # TODO share with above
134
+ predictor_out , loop .state .predictor = self .predictor (
135
+ time_sync_in = encoder_frame ,
136
+ label_sync_in = label_predict_enc ,
137
+ prev_align_label = loop .state .label_wb ,
138
+ state = loop .state .predictor )
139
+ probs = predictor_out .prob_like_wb
140
+ probs_type = predictor_out .prob_like_type
141
+ elif isinstance (self .predictor , IDecoderJointNoCtxLogProb ):
142
+ assert self .label_topology != LabelTopology .LABEL_SYNC , f"{ self } : Label topology must not be label-sync"
143
+ assert label_predict_enc is None , f"{ self } : Label predict encoder not used"
144
+ encoder_frame = ... # TODO share with above
145
+ predictor_out = self .predictor (time_sync_in = encoder_frame )
146
+ probs = predictor_out .prob_like_wb
147
+ probs_type = predictor_out .prob_like_type
148
+ elif isinstance (self .predictor , IDecoderAlignStateLogProb ):
149
+ assert self .label_topology != LabelTopology .LABEL_SYNC , f"{ self } : Label topology must not be label-sync"
150
+ assert label_predict_enc is None , f"{ self } : Label predict encoder not used"
151
+ encoder_frame = ... # TODO share with above
152
+ predictor_out , loop .state .predictor = self .predictor (
153
+ time_sync_in = encoder_frame ,
154
+ prev_align_label = loop .state .label_wb ,
155
+ state = loop .state .predictor )
156
+ probs = predictor_out .prob_like_wb
157
+ probs_type = predictor_out .prob_like_type
158
+ else :
159
+ raise TypeError (f"{ self } : Unsupported predictor type { type (self .predictor )} " )
160
+
161
+ # TODO loss handling here? in that case, cleverly do the most efficient?
162
+ # TODO logits instead of log probs?
163
+ # TODO see below, related is whether and we output
164
+
165
+ target = loop .unstack (target ) if target is not None else None
166
+ if search :
167
+ search .apply_loop (loop )
168
+ align_label = search .choice (probs = probs , probs_type = probs_type )
169
+ else :
170
+ assert target is not None
171
+ align_label = target
172
+ if self .label_topology == LabelTopology .LABEL_SYNC :
173
+ loop .state .label_nb = align_label
174
+ loop .end (loop .state .label_nb == self .target_eos_symbol , include_eos = False )
175
+ else :
176
+ loop .state .label_wb = align_label
177
+
178
+ out_labels = loop .stack (align_label ) if target is None else None
179
+ # TODO? out_logits = loop.stack(logits) # TODO not necessarily logits...
180
+
181
+ return out_labels , loop .state
182
+
183
+ def default_initial_state (self ) -> Optional [nn .LayerState ]:
184
+ """default init state"""
75
185
76
186
77
187
# TODO enc ctx module
@@ -90,6 +200,20 @@ def blank_idx(self) -> int:
90
200
"""
91
201
raise NotImplementedError
92
202
203
+ @property
204
+ def prob_like_wb (self ) -> nn .Tensor :
205
+ """
206
+ :return: logits if possible, else log probs. see prob_like_type
207
+ """
208
+ return self .log_prob_wb
209
+
210
+ @property
211
+ def prob_like_type (self ) -> str :
212
+ """
213
+ :return: type of prob_like_wb. "logits" or "log_prob"
214
+ """
215
+ return "log_prob"
216
+
93
217
@property
94
218
def log_prob_wb (self ) -> nn .Tensor :
95
219
"""
@@ -152,12 +276,12 @@ class DecoderJointLogProbSeparatedOutput(IDecoderJointLogProbOutput):
152
276
log_prob_not_blank : nn .Tensor # log(-expm1(log_prob_blank)) but you maybe could calc it more directly
153
277
154
278
155
- class IDecoderLabelSyncLogProb (nn .Module ):
279
+ class IDecoderLabelSyncLogits (nn .Module ):
156
280
"""
157
281
For simple (maybe attention-based) encoder-decoder models,
158
282
getting input from some label-sync encoding (TDecoderLabelSync).
159
283
160
- This will produce log probs for non-blank labels.
284
+ This will produce logits (non-normalized log probs) for non-blank labels.
161
285
There is no blank in this concept.
162
286
"""
163
287
@@ -167,7 +291,7 @@ def __call__(self, *, label_sync_in: nn.Tensor) -> nn.Tensor:
167
291
168
292
class IDecoderJointNoStateLogProb (nn .Module ):
169
293
"""
170
- Joint network for transducer-like models:
294
+ Joint network for transducer-like models (e.g. the original RNN-T) :
171
295
172
296
Getting in time-sync inputs, label-sync inputs,
173
297
producing probabilities for labels + blank.
@@ -186,7 +310,7 @@ def __call__(self, *, time_sync_in: nn.Tensor, label_sync_in: nn.Tensor) -> IDec
186
310
187
311
class IDecoderJointAlignStateLogProb (nn .Module ):
188
312
"""
189
- Joint network for transducer-like models:
313
+ Joint network for transducer-like models (specifically the extended transducer model) :
190
314
191
315
Getting in time-sync inputs, label-sync inputs,
192
316
producing probabilities for labels + blank.
@@ -196,12 +320,42 @@ def __call__(self, *,
196
320
time_sync_in : nn .Tensor ,
197
321
label_sync_in : nn .Tensor ,
198
322
prev_align_label : nn .Tensor ,
199
- state : nn .LayerState ,
323
+ state : nn .LayerState , # align-sync
324
+ ) -> Tuple [IDecoderJointLogProbOutput , nn .LayerState ]:
325
+ raise NotImplementedError
326
+
327
+
328
+ class IDecoderJointNoCtxLogProb (nn .Module ):
329
+ """
330
+ Joint network for CTC-like models, having no dependence on the label context:
331
+
332
+ Getting in time-sync inputs,
333
+ producing probabilities for labels + blank.
334
+ """
335
+
336
+ def __call__ (self , * , time_sync_in : nn .Tensor ) -> IDecoderJointLogProbOutput :
337
+ raise NotImplementedError
338
+
339
+
340
+ class IDecoderAlignStateLogProb (nn .Module ):
341
+ """
342
+ Joint network for transducer-like models, no explicit nb label dep, only align-label (like RNA):
343
+
344
+ Getting in time-sync inputs,
345
+ producing probabilities for labels + blank.
346
+ """
347
+
348
+ def __call__ (self , * ,
349
+ time_sync_in : nn .Tensor ,
350
+ prev_align_label : nn .Tensor ,
351
+ state : nn .LayerState , # align-sync
200
352
) -> Tuple [IDecoderJointLogProbOutput , nn .LayerState ]:
201
353
raise NotImplementedError
202
354
203
355
204
- TDecoderJointNetLogProb = Union [IDecoderLabelSyncLogProb , IDecoderJointNoStateLogProb , IDecoderJointAlignStateLogProb ]
356
+ TDecoderJointNetLogProb = Union [
357
+ IDecoderLabelSyncLogits , IDecoderJointNoStateLogProb , IDecoderJointAlignStateLogProb ,
358
+ IDecoderJointNoCtxLogProb , IDecoderAlignStateLogProb ]
205
359
206
360
207
361
class IDecoderLabelSyncRnn (nn .Module ):
@@ -243,7 +397,6 @@ class IDecoderLabelSyncRnn(nn.Module):
243
397
def __call__ (self , * ,
244
398
prev_label : nn .Tensor ,
245
399
encoder_seq : nn .Tensor ,
246
- encoder_frame : nn .Tensor ,
247
400
state : nn .LayerState ,
248
401
) -> Tuple [nn .Tensor , nn .LayerState ]:
249
402
raise NotImplementedError
@@ -301,7 +454,7 @@ def __call__(self, *,
301
454
class IDecoderStepSyncRnn (nn .Module ):
302
455
"""
303
456
Represents FastRNN in Transducer.
304
- Otherwise in general this runs step-synchronous,
457
+ Otherwise, in general this runs step-synchronous,
305
458
which is alignment-synchronous or time-synchronous for RNN-T/RNA/CTC,
306
459
or label-synchronous for att-enc-dec.
307
460
"""
0 commit comments