1
- __all__ = ["HybridArgs" , " HybridSystem" ]
1
+ __all__ = ["HybridSystem" ]
2
2
3
3
import copy
4
4
import itertools
21
21
add_tf_flow_to_base_flow ,
22
22
)
23
23
from i6_core .util import MultiPath , MultiOutputPath
24
+ from i6_core .mm import CreateDummyMixturesJob
25
+ from i6_core .returnn import ReturnnComputePriorJobV2
24
26
25
27
from .nn_system import NnSystem
28
+ from .hybrid_decoder import HybridDecoder
26
29
27
30
from .util import (
28
31
RasrInitArgs ,
29
32
ReturnnRasrDataInput ,
30
- OggZipHdfDataInput ,
31
33
HybridArgs ,
32
34
NnRecogArgs ,
33
35
RasrSteps ,
34
36
NnForcedAlignArgs ,
37
+ ReturnnTrainingJobArgs ,
38
+ AllowedReturnnTrainingDataInput ,
35
39
)
36
40
37
41
# -------------------- Init --------------------
@@ -90,9 +94,15 @@ def __init__(
90
94
self .cv_corpora = []
91
95
self .devtrain_corpora = []
92
96
93
- self .train_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
94
- self .cv_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
95
- self .devtrain_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97
+ self .train_input_data = (
98
+ None
99
+ ) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
100
+ self .cv_input_data = (
101
+ None
102
+ ) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
103
+ self .devtrain_input_data = (
104
+ None
105
+ ) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
96
106
self .dev_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97
107
self .test_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
98
108
@@ -128,9 +138,9 @@ def _add_output_alias_for_train_job(
128
138
def init_system (
129
139
self ,
130
140
rasr_init_args : RasrInitArgs ,
131
- train_data : Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]],
132
- cv_data : Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]],
133
- devtrain_data : Optional [Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]]] = None ,
141
+ train_data : Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]],
142
+ cv_data : Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]],
143
+ devtrain_data : Optional [Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]] = None ,
134
144
dev_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None ,
135
145
test_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None ,
136
146
train_cv_pairing : Optional [List [Tuple [str , ...]]] = None , # List[Tuple[trn_c, cv_c, name, dvtr_c]]
@@ -211,21 +221,17 @@ def generate_lattices(self):
211
221
212
222
def returnn_training (
213
223
self ,
214
- name ,
215
- returnn_config ,
216
- nn_train_args ,
224
+ name : str ,
225
+ returnn_config : returnn . ReturnnConfig ,
226
+ nn_train_args : Union [ Dict , ReturnnTrainingJobArgs ] ,
217
227
train_corpus_key ,
218
228
cv_corpus_key ,
219
229
devtrain_corpus_key = None ,
220
- ):
221
- assert isinstance (returnn_config , returnn .ReturnnConfig )
222
-
223
- returnn_config .config ["train" ] = self .train_input_data [train_corpus_key ].get_data_dict ()
224
- returnn_config .config ["dev" ] = self .cv_input_data [cv_corpus_key ].get_data_dict ()
225
- if devtrain_corpus_key is not None :
226
- returnn_config .config ["eval_datasets" ] = {
227
- "devtrain" : self .devtrain_input_data [devtrain_corpus_key ].get_data_dict ()
228
- }
230
+ ) -> returnn .ReturnnTrainingJob :
231
+ if nn_train_args .returnn_root is None :
232
+ nn_train_args .returnn_root = self .returnn_root
233
+ if nn_train_args .returnn_python_exe is None :
234
+ nn_train_args .returnn_python_exe = self .returnn_python_exe
229
235
230
236
train_job = returnn .ReturnnTrainingJob (
231
237
returnn_config = returnn_config ,
@@ -346,7 +352,7 @@ def nn_recognition(
346
352
name : str ,
347
353
returnn_config : returnn .ReturnnConfig ,
348
354
checkpoints : Dict [int , returnn .Checkpoint ],
349
- acoustic_mixture_path : tk . Path , # TODO maybe Optional if prior file provided -> automatically construct dummy file
355
+ train_job : Union [ returnn . ReturnnTrainingJob , returnn . ReturnnRasrTrainingJob ],
350
356
prior_scales : List [float ],
351
357
pronunciation_scales : List [float ],
352
358
lm_scales : List [float ],
@@ -384,15 +390,31 @@ def nn_recognition(
384
390
385
391
for pron , lm , prior , epoch in itertools .product (pronunciation_scales , lm_scales , prior_scales , epochs ):
386
392
assert epoch in checkpoints .keys ()
387
- assert acoustic_mixture_path is not None
388
-
389
- if use_epoch_for_compile :
390
- tf_graph = self .nn_compile_graph (name , returnn_config , epoch = epoch )
393
+ acoustic_mixture_path = CreateDummyMixturesJob (
394
+ num_mixtures = returnn_config .config ["extern_data" ]["classes" ]["dim" ],
395
+ num_features = returnn_config .config ["extern_data" ]["data" ]["dim" ],
396
+ ).out_mixtures
397
+ lmgc_scorer = rasr .GMMFeatureScorer (acoustic_mixture_path )
398
+ prior_job = ReturnnComputePriorJobV2 (
399
+ model_checkpoint = checkpoints [epoch ],
400
+ returnn_config = train_job .returnn_config ,
401
+ returnn_python_exe = train_job .returnn_python_exe ,
402
+ returnn_root = train_job .returnn_root ,
403
+ log_verbosity = train_job .returnn_config .post_config ["log_verbosity" ],
404
+ )
391
405
406
+ prior_job .add_alias ("extract_nn_prior/" + name )
407
+ prior_file = prior_job .out_prior_xml_file
408
+ assert prior_file is not None
392
409
scorer = rasr .PrecomputedHybridFeatureScorer (
393
410
prior_mixtures = acoustic_mixture_path ,
394
411
priori_scale = prior ,
412
+ prior_file = prior_file ,
395
413
)
414
+ assert acoustic_mixture_path is not None
415
+
416
+ if use_epoch_for_compile :
417
+ tf_graph = self .nn_compile_graph (name , returnn_config , epoch = epoch )
396
418
397
419
tf_flow = make_precomputed_hybrid_tf_feature_flow (
398
420
tf_checkpoint = checkpoints [epoch ],
@@ -419,6 +441,8 @@ def nn_recognition(
419
441
parallelize_conversion = parallelize_conversion ,
420
442
rtf = rtf ,
421
443
mem = mem ,
444
+ lmgc_alias = f"lmgc/{ name } /{ recognition_corpus_key } -{ recog_name } " ,
445
+ lmgc_scorer = lmgc_scorer ,
422
446
** kwargs ,
423
447
)
424
448
@@ -429,15 +453,22 @@ def nn_recog(
429
453
returnn_config : Path ,
430
454
checkpoints : Dict [int , returnn .Checkpoint ],
431
455
step_args : HybridArgs ,
456
+ train_job : Union [returnn .ReturnnTrainingJob , returnn .ReturnnRasrTrainingJob ],
432
457
):
433
458
for recog_name , recog_args in step_args .recognition_args .items ():
459
+ recog_args = copy .deepcopy (recog_args )
460
+ whitelist = recog_args .pop ("training_whitelist" , None )
461
+ if whitelist :
462
+ if train_name not in whitelist :
463
+ continue
434
464
for dev_c in self .dev_corpora :
435
465
self .nn_recognition (
436
466
name = f"{ train_corpus_key } -{ train_name } -{ recog_name } " ,
437
467
returnn_config = returnn_config ,
438
468
checkpoints = checkpoints ,
439
- acoustic_mixture_path = self . train_input_data [ train_corpus_key ]. acoustic_mixtures ,
469
+ train_job = train_job ,
440
470
recognition_corpus_key = dev_c ,
471
+ acoustic_mixture_path = self .train_input_data [train_corpus_key ].acoustic_mixtures ,
441
472
** recog_args ,
442
473
)
443
474
@@ -451,8 +482,9 @@ def nn_recog(
451
482
name = f"{ train_name } -{ recog_name } " ,
452
483
returnn_config = returnn_config ,
453
484
checkpoints = checkpoints ,
454
- acoustic_mixture_path = self . train_input_data [ train_corpus_key ]. acoustic_mixtures ,
485
+ train_job = train_job ,
455
486
recognition_corpus_key = tst_c ,
487
+ acoustic_mixture_path = self .train_input_data [train_corpus_key ].acoustic_mixtures ,
456
488
** r_args ,
457
489
)
458
490
@@ -509,7 +541,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
509
541
train_corpus_key = trn_c ,
510
542
cv_corpus_key = cv_c ,
511
543
)
512
- else :
544
+ elif isinstance ( self . train_input_data [ trn_c ], AllowedReturnnTrainingDataInput ) :
513
545
returnn_train_job = self .returnn_training (
514
546
name = name ,
515
547
returnn_config = step_args .returnn_training_configs [name ],
@@ -518,6 +550,8 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
518
550
cv_corpus_key = cv_c ,
519
551
devtrain_corpus_key = dvtr_c ,
520
552
)
553
+ else :
554
+ raise NotImplementedError
521
555
522
556
returnn_recog_config = step_args .returnn_recognition_configs .get (
523
557
name , step_args .returnn_training_configs [name ]
@@ -529,6 +563,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
529
563
returnn_config = returnn_recog_config ,
530
564
checkpoints = returnn_train_job .out_checkpoints ,
531
565
step_args = step_args ,
566
+ train_job = returnn_train_job ,
532
567
)
533
568
534
569
def run_nn_recog_step (self , step_args : NnRecogArgs ):
0 commit comments