21
21
add_tf_flow_to_base_flow ,
22
22
)
23
23
from i6_core .util import MultiPath , MultiOutputPath
24
+ from i6_core .mm import CreateDummyMixturesJob
25
+ from i6_core .returnn import ReturnnComputePriorJobV2
24
26
25
- from .nn_system import NnSystem
27
+ from .hybrid_decoder import HybridDecoder
28
+ from .nn_system import NnSystem , returnn_training
26
29
27
30
from .util import (
28
31
RasrInitArgs ,
29
32
ReturnnRasrDataInput ,
30
- OggZipHdfDataInput ,
31
33
HybridArgs ,
32
34
NnRecogArgs ,
33
35
RasrSteps ,
34
36
NnForcedAlignArgs ,
37
+ ReturnnTrainingJobArgs ,
38
+ AllowedReturnnTrainingDataInput ,
35
39
)
36
40
37
41
# -------------------- Init --------------------
@@ -90,11 +94,13 @@ def __init__(
90
94
self .cv_corpora = []
91
95
self .devtrain_corpora = []
92
96
93
- self .train_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
94
- self .cv_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
95
- self .devtrain_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
96
- self .dev_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97
- self .test_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97
+ self .train_input_data : Optional [Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]] = None
98
+ self .cv_input_data : Optional [Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]] = None
99
+ self .devtrain_input_data : Optional [
100
+ Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]
101
+ ] = None
102
+ self .dev_input_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None
103
+ self .test_input_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None
98
104
99
105
self .train_cv_pairing = None
100
106
@@ -128,9 +134,9 @@ def _add_output_alias_for_train_job(
128
134
def init_system (
129
135
self ,
130
136
rasr_init_args : RasrInitArgs ,
131
- train_data : Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]],
132
- cv_data : Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]],
133
- devtrain_data : Optional [Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]]] = None ,
137
+ train_data : Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]],
138
+ cv_data : Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]],
139
+ devtrain_data : Optional [Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]] = None ,
134
140
dev_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None ,
135
141
test_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None ,
136
142
train_cv_pairing : Optional [List [Tuple [str , ...]]] = None , # List[Tuple[trn_c, cv_c, name, dvtr_c]]
@@ -211,27 +217,29 @@ def generate_lattices(self):
211
217
212
218
def returnn_training (
213
219
self ,
214
- name ,
215
- returnn_config ,
216
- nn_train_args ,
220
+ name : str ,
221
+ returnn_config : returnn . ReturnnConfig ,
222
+ nn_train_args : Union [ Dict , ReturnnTrainingJobArgs ] ,
217
223
train_corpus_key ,
218
224
cv_corpus_key ,
219
225
devtrain_corpus_key = None ,
220
- ):
221
- assert isinstance (returnn_config , returnn .ReturnnConfig )
222
-
223
- returnn_config .config ["train" ] = self .train_input_data [train_corpus_key ].get_data_dict ()
224
- returnn_config .config ["dev" ] = self .cv_input_data [cv_corpus_key ].get_data_dict ()
225
- if devtrain_corpus_key is not None :
226
- returnn_config .config ["eval_datasets" ] = {
227
- "devtrain" : self .devtrain_input_data [devtrain_corpus_key ].get_data_dict ()
228
- }
229
-
230
- train_job = returnn .ReturnnTrainingJob (
226
+ ) -> returnn .ReturnnTrainingJob :
227
+ if isinstance (nn_train_args , ReturnnTrainingJobArgs ):
228
+ if nn_train_args .returnn_root is None :
229
+ nn_train_args .returnn_root = self .returnn_root
230
+ if nn_train_args .returnn_python_exe is None :
231
+ nn_train_args .returnn_python_exe = self .returnn_python_exe
232
+
233
+ train_job = returnn_training (
234
+ name = name ,
231
235
returnn_config = returnn_config ,
232
- returnn_root = self .returnn_root ,
233
- returnn_python_exe = self .returnn_python_exe ,
234
- ** nn_train_args ,
236
+ training_args = nn_train_args ,
237
+ train_data = self .train_input_data [train_corpus_key ],
238
+ cv_data = self .cv_input_data [cv_corpus_key ],
239
+ additional_data = {"devtrain" : self .devtrain_input_data [devtrain_corpus_key ]}
240
+ if devtrain_corpus_key is not None
241
+ else None ,
242
+ register_output = False ,
235
243
)
236
244
self ._add_output_alias_for_train_job (
237
245
train_job = train_job ,
@@ -346,7 +354,9 @@ def nn_recognition(
346
354
name : str ,
347
355
returnn_config : returnn .ReturnnConfig ,
348
356
checkpoints : Dict [int , returnn .Checkpoint ],
349
- acoustic_mixture_path : tk .Path , # TODO maybe Optional if prior file provided -> automatically construct dummy file
357
+ acoustic_mixture_path : Optional [
358
+ tk .Path
359
+ ], # TODO maybe Optional if prior file provided -> automatically construct dummy file
350
360
prior_scales : List [float ],
351
361
pronunciation_scales : List [float ],
352
362
lm_scales : List [float ],
@@ -362,6 +372,7 @@ def nn_recognition(
362
372
use_epoch_for_compile = False ,
363
373
forward_output_layer = "output" ,
364
374
native_ops : Optional [List [str ]] = None ,
375
+ train_job : Optional [Union [returnn .ReturnnTrainingJob , returnn .ReturnnRasrTrainingJob ]] = None ,
365
376
** kwargs ,
366
377
):
367
378
with tk .block (f"{ name } _recognition" ):
@@ -383,17 +394,37 @@ def nn_recognition(
383
394
epochs = epochs if epochs is not None else list (checkpoints .keys ())
384
395
385
396
for pron , lm , prior , epoch in itertools .product (pronunciation_scales , lm_scales , prior_scales , epochs ):
386
- assert epoch in checkpoints .keys ()
387
- assert acoustic_mixture_path is not None
388
-
389
- if use_epoch_for_compile :
390
- tf_graph = self .nn_compile_graph (name , returnn_config , epoch = epoch )
391
397
398
+ assert epoch in checkpoints .keys ()
399
+ prior_file = None
400
+ lmgc_scorer = None
401
+ if acoustic_mixture_path is None :
402
+ assert train_job is not None , "Need ReturnnTrainingJob for computation of priors"
403
+ tmp_acoustic_mixture_path = CreateDummyMixturesJob (
404
+ num_mixtures = returnn_config .config ["extern_data" ]["classes" ]["dim" ],
405
+ num_features = returnn_config .config ["extern_data" ]["data" ]["dim" ],
406
+ ).out_mixtures
407
+ lmgc_scorer = rasr .GMMFeatureScorer (tmp_acoustic_mixture_path )
408
+ prior_job = ReturnnComputePriorJobV2 (
409
+ model_checkpoint = checkpoints [epoch ],
410
+ returnn_config = train_job .returnn_config ,
411
+ returnn_python_exe = train_job .returnn_python_exe ,
412
+ returnn_root = train_job .returnn_root ,
413
+ log_verbosity = train_job .returnn_config .post_config ["log_verbosity" ],
414
+ )
415
+ prior_job .add_alias ("extract_nn_prior/" + name )
416
+ prior_file = prior_job .out_prior_xml_file
417
+ else :
418
+ tmp_acoustic_mixture_path = acoustic_mixture_path
392
419
scorer = rasr .PrecomputedHybridFeatureScorer (
393
- prior_mixtures = acoustic_mixture_path ,
420
+ prior_mixtures = tmp_acoustic_mixture_path , # This needs to be a new variable otherwise nesting causes undesired behavior
394
421
priori_scale = prior ,
422
+ prior_file = prior_file ,
395
423
)
396
424
425
+ if use_epoch_for_compile :
426
+ tf_graph = self .nn_compile_graph (name , returnn_config , epoch = epoch )
427
+
397
428
tf_flow = make_precomputed_hybrid_tf_feature_flow (
398
429
tf_checkpoint = checkpoints [epoch ],
399
430
tf_graph = tf_graph ,
@@ -419,6 +450,8 @@ def nn_recognition(
419
450
parallelize_conversion = parallelize_conversion ,
420
451
rtf = rtf ,
421
452
mem = mem ,
453
+ lmgc_alias = f"lmgc/{ name } /{ recognition_corpus_key } -{ recog_name } " ,
454
+ lmgc_scorer = lmgc_scorer ,
422
455
** kwargs ,
423
456
)
424
457
@@ -429,14 +462,21 @@ def nn_recog(
429
462
returnn_config : Path ,
430
463
checkpoints : Dict [int , returnn .Checkpoint ],
431
464
step_args : HybridArgs ,
465
+ train_job : Union [returnn .ReturnnTrainingJob , returnn .ReturnnRasrTrainingJob ],
432
466
):
433
467
for recog_name , recog_args in step_args .recognition_args .items ():
468
+ recog_args = copy .deepcopy (recog_args )
469
+ whitelist = recog_args .pop ("training_whitelist" , None )
470
+ if whitelist :
471
+ if train_name not in whitelist :
472
+ continue
434
473
for dev_c in self .dev_corpora :
435
474
self .nn_recognition (
436
475
name = f"{ train_corpus_key } -{ train_name } -{ recog_name } " ,
437
476
returnn_config = returnn_config ,
438
477
checkpoints = checkpoints ,
439
478
acoustic_mixture_path = self .train_input_data [train_corpus_key ].acoustic_mixtures ,
479
+ train_job = train_job ,
440
480
recognition_corpus_key = dev_c ,
441
481
** recog_args ,
442
482
)
@@ -452,6 +492,7 @@ def nn_recog(
452
492
returnn_config = returnn_config ,
453
493
checkpoints = checkpoints ,
454
494
acoustic_mixture_path = self .train_input_data [train_corpus_key ].acoustic_mixtures ,
495
+ train_job = train_job ,
455
496
recognition_corpus_key = tst_c ,
456
497
** r_args ,
457
498
)
@@ -472,7 +513,7 @@ def nn_compile_graph(
472
513
:return: the TF graph
473
514
"""
474
515
graph_compile_job = returnn .CompileTFGraphJob (
475
- returnn_config ,
516
+ returnn_config = returnn_config ,
476
517
epoch = epoch ,
477
518
returnn_root = self .returnn_root ,
478
519
returnn_python_exe = self .returnn_python_exe ,
@@ -509,7 +550,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
509
550
train_corpus_key = trn_c ,
510
551
cv_corpus_key = cv_c ,
511
552
)
512
- else :
553
+ elif isinstance ( self . train_input_data [ trn_c ], AllowedReturnnTrainingDataInput ) :
513
554
returnn_train_job = self .returnn_training (
514
555
name = name ,
515
556
returnn_config = step_args .returnn_training_configs [name ],
@@ -518,6 +559,8 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
518
559
cv_corpus_key = cv_c ,
519
560
devtrain_corpus_key = dvtr_c ,
520
561
)
562
+ else :
563
+ raise NotImplementedError
521
564
522
565
returnn_recog_config = step_args .returnn_recognition_configs .get (
523
566
name , step_args .returnn_training_configs [name ]
@@ -529,6 +572,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
529
572
returnn_config = returnn_recog_config ,
530
573
checkpoints = returnn_train_job .out_checkpoints ,
531
574
step_args = step_args ,
575
+ train_job = returnn_train_job ,
532
576
)
533
577
534
578
def run_nn_recog_step (self , step_args : NnRecogArgs ):
0 commit comments