Skip to content

Commit da04cf1

Browse files
Refactor RASR system classes (#182)
* small changes for hybrid decoder
* cleanup
* update hybrid system
* remove acoustic mixtures: the reason for this is that prior estimation should always be done, for performance reasons (WER and RTF)
  Co-authored-by: Benedikt Hilmes <[email protected]>
* remove report generation
* fix

---------

Co-authored-by: Benedikt Hilmes <[email protected]>
1 parent 46f7f5a commit da04cf1

File tree

3 files changed

+68
-47
lines changed

3 files changed

+68
-47
lines changed

common/setups/rasr/hybrid_decoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
CombineLmRasrConfig,
2424
)
2525
from .util.decode import (
26+
DevRecognitionParameters,
2627
RecognitionParameters,
2728
SearchJobArgs,
2829
Lattice2CtmArgs,
@@ -47,7 +48,7 @@ class HybridDecoder(BaseDecoder):
4748
def __init__(
4849
self,
4950
rasr_binary_path: tk.Path,
50-
rasr_arch: "str" = "linux-x86_64-standard",
51+
rasr_arch: str = "linux-x86_64-standard",
5152
compress: bool = False,
5253
append: bool = False,
5354
unbuffered: bool = False,
@@ -155,8 +156,9 @@ def recognition(
155156
tf_fwd_input_name: str = "tf-fwd-input",
156157
):
157158
"""
158-
run the recognitino, consisting of search, lattice to ctm, and scoring
159+
run the recognition, consisting of search, lattice to ctm, and scoring
159160
161+
:param name: decoding name
160162
:param returnn_config: RETURNN config for recognition
161163
:param checkpoints: epoch to model checkpoint mapping
162164
:param recognition_parameters: keys are the corpus keys so that recog params can be set for specific eval sets.

common/setups/rasr/hybrid_system.py

Lines changed: 63 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__all__ = ["HybridArgs", "HybridSystem"]
1+
__all__ = ["HybridSystem"]
22

33
import copy
44
import itertools
@@ -21,17 +21,21 @@
2121
add_tf_flow_to_base_flow,
2222
)
2323
from i6_core.util import MultiPath, MultiOutputPath
24+
from i6_core.mm import CreateDummyMixturesJob
25+
from i6_core.returnn import ReturnnComputePriorJobV2
2426

2527
from .nn_system import NnSystem
28+
from .hybrid_decoder import HybridDecoder
2629

2730
from .util import (
2831
RasrInitArgs,
2932
ReturnnRasrDataInput,
30-
OggZipHdfDataInput,
3133
HybridArgs,
3234
NnRecogArgs,
3335
RasrSteps,
3436
NnForcedAlignArgs,
37+
ReturnnTrainingJobArgs,
38+
AllowedReturnnTrainingDataInput,
3539
)
3640

3741
# -------------------- Init --------------------
@@ -90,9 +94,15 @@ def __init__(
9094
self.cv_corpora = []
9195
self.devtrain_corpora = []
9296

93-
self.train_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
94-
self.cv_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
95-
self.devtrain_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97+
self.train_input_data = (
98+
None
99+
) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
100+
self.cv_input_data = (
101+
None
102+
) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
103+
self.devtrain_input_data = (
104+
None
105+
) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
96106
self.dev_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97107
self.test_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
98108

@@ -128,9 +138,9 @@ def _add_output_alias_for_train_job(
128138
def init_system(
129139
self,
130140
rasr_init_args: RasrInitArgs,
131-
train_data: Dict[str, Union[ReturnnRasrDataInput, OggZipHdfDataInput]],
132-
cv_data: Dict[str, Union[ReturnnRasrDataInput, OggZipHdfDataInput]],
133-
devtrain_data: Optional[Dict[str, Union[ReturnnRasrDataInput, OggZipHdfDataInput]]] = None,
141+
train_data: Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]],
142+
cv_data: Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]],
143+
devtrain_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None,
134144
dev_data: Optional[Dict[str, ReturnnRasrDataInput]] = None,
135145
test_data: Optional[Dict[str, ReturnnRasrDataInput]] = None,
136146
train_cv_pairing: Optional[List[Tuple[str, ...]]] = None, # List[Tuple[trn_c, cv_c, name, dvtr_c]]
@@ -211,21 +221,17 @@ def generate_lattices(self):
211221

212222
def returnn_training(
213223
self,
214-
name,
215-
returnn_config,
216-
nn_train_args,
224+
name: str,
225+
returnn_config: returnn.ReturnnConfig,
226+
nn_train_args: Union[Dict, ReturnnTrainingJobArgs],
217227
train_corpus_key,
218228
cv_corpus_key,
219229
devtrain_corpus_key=None,
220-
):
221-
assert isinstance(returnn_config, returnn.ReturnnConfig)
222-
223-
returnn_config.config["train"] = self.train_input_data[train_corpus_key].get_data_dict()
224-
returnn_config.config["dev"] = self.cv_input_data[cv_corpus_key].get_data_dict()
225-
if devtrain_corpus_key is not None:
226-
returnn_config.config["eval_datasets"] = {
227-
"devtrain": self.devtrain_input_data[devtrain_corpus_key].get_data_dict()
228-
}
230+
) -> returnn.ReturnnTrainingJob:
231+
if nn_train_args.returnn_root is None:
232+
nn_train_args.returnn_root = self.returnn_root
233+
if nn_train_args.returnn_python_exe is None:
234+
nn_train_args.returnn_python_exe = self.returnn_python_exe
229235

230236
train_job = returnn.ReturnnTrainingJob(
231237
returnn_config=returnn_config,
@@ -346,7 +352,7 @@ def nn_recognition(
346352
name: str,
347353
returnn_config: returnn.ReturnnConfig,
348354
checkpoints: Dict[int, returnn.Checkpoint],
349-
acoustic_mixture_path: tk.Path, # TODO maybe Optional if prior file provided -> automatically construct dummy file
355+
train_job: Union[returnn.ReturnnTrainingJob, returnn.ReturnnRasrTrainingJob],
350356
prior_scales: List[float],
351357
pronunciation_scales: List[float],
352358
lm_scales: List[float],
@@ -384,15 +390,31 @@ def nn_recognition(
384390

385391
for pron, lm, prior, epoch in itertools.product(pronunciation_scales, lm_scales, prior_scales, epochs):
386392
assert epoch in checkpoints.keys()
387-
assert acoustic_mixture_path is not None
388-
389-
if use_epoch_for_compile:
390-
tf_graph = self.nn_compile_graph(name, returnn_config, epoch=epoch)
393+
acoustic_mixture_path = CreateDummyMixturesJob(
394+
num_mixtures=returnn_config.config["extern_data"]["classes"]["dim"],
395+
num_features=returnn_config.config["extern_data"]["data"]["dim"],
396+
).out_mixtures
397+
lmgc_scorer = rasr.GMMFeatureScorer(acoustic_mixture_path)
398+
prior_job = ReturnnComputePriorJobV2(
399+
model_checkpoint=checkpoints[epoch],
400+
returnn_config=train_job.returnn_config,
401+
returnn_python_exe=train_job.returnn_python_exe,
402+
returnn_root=train_job.returnn_root,
403+
log_verbosity=train_job.returnn_config.post_config["log_verbosity"],
404+
)
391405

406+
prior_job.add_alias("extract_nn_prior/" + name)
407+
prior_file = prior_job.out_prior_xml_file
408+
assert prior_file is not None
392409
scorer = rasr.PrecomputedHybridFeatureScorer(
393410
prior_mixtures=acoustic_mixture_path,
394411
priori_scale=prior,
412+
prior_file=prior_file,
395413
)
414+
assert acoustic_mixture_path is not None
415+
416+
if use_epoch_for_compile:
417+
tf_graph = self.nn_compile_graph(name, returnn_config, epoch=epoch)
396418

397419
tf_flow = make_precomputed_hybrid_tf_feature_flow(
398420
tf_checkpoint=checkpoints[epoch],
@@ -419,6 +441,8 @@ def nn_recognition(
419441
parallelize_conversion=parallelize_conversion,
420442
rtf=rtf,
421443
mem=mem,
444+
lmgc_alias=f"lmgc/{name}/{recognition_corpus_key}-{recog_name}",
445+
lmgc_scorer=lmgc_scorer,
422446
**kwargs,
423447
)
424448

@@ -429,15 +453,22 @@ def nn_recog(
429453
returnn_config: Path,
430454
checkpoints: Dict[int, returnn.Checkpoint],
431455
step_args: HybridArgs,
456+
train_job: Union[returnn.ReturnnTrainingJob, returnn.ReturnnRasrTrainingJob],
432457
):
433458
for recog_name, recog_args in step_args.recognition_args.items():
459+
recog_args = copy.deepcopy(recog_args)
460+
whitelist = recog_args.pop("training_whitelist", None)
461+
if whitelist:
462+
if train_name not in whitelist:
463+
continue
434464
for dev_c in self.dev_corpora:
435465
self.nn_recognition(
436466
name=f"{train_corpus_key}-{train_name}-{recog_name}",
437467
returnn_config=returnn_config,
438468
checkpoints=checkpoints,
439-
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
469+
train_job=train_job,
440470
recognition_corpus_key=dev_c,
471+
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
441472
**recog_args,
442473
)
443474

@@ -451,8 +482,9 @@ def nn_recog(
451482
name=f"{train_name}-{recog_name}",
452483
returnn_config=returnn_config,
453484
checkpoints=checkpoints,
454-
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
485+
train_job=train_job,
455486
recognition_corpus_key=tst_c,
487+
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
456488
**r_args,
457489
)
458490

@@ -509,7 +541,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
509541
train_corpus_key=trn_c,
510542
cv_corpus_key=cv_c,
511543
)
512-
else:
544+
elif isinstance(self.train_input_data[trn_c], AllowedReturnnTrainingDataInput):
513545
returnn_train_job = self.returnn_training(
514546
name=name,
515547
returnn_config=step_args.returnn_training_configs[name],
@@ -518,6 +550,8 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
518550
cv_corpus_key=cv_c,
519551
devtrain_corpus_key=dvtr_c,
520552
)
553+
else:
554+
raise NotImplementedError
521555

522556
returnn_recog_config = step_args.returnn_recognition_configs.get(
523557
name, step_args.returnn_training_configs[name]
@@ -529,6 +563,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
529563
returnn_config=returnn_recog_config,
530564
checkpoints=returnn_train_job.out_checkpoints,
531565
step_args=step_args,
566+
train_job=returnn_train_job,
532567
)
533568

534569
def run_nn_recog_step(self, step_args: NnRecogArgs):

common/setups/rasr/nn_system.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,20 @@
11
__all__ = ["NnSystem"]
22

33
import copy
4-
import itertools
5-
import sys
64
from dataclasses import asdict
7-
from typing import Dict, List, Optional, Tuple, Union
5+
from typing import Dict, List, Optional, Union
86

97
# -------------------- Sisyphus --------------------
108

119
import sisyphus.toolkit as tk
1210
import sisyphus.global_settings as gs
1311

14-
from sisyphus.delayed_ops import DelayedFormat
15-
1612
# -------------------- Recipes --------------------
1713

18-
import i6_core.features as features
19-
import i6_core.rasr as rasr
2014
import i6_core.returnn as returnn
2115

22-
from i6_core.util import MultiPath, MultiOutputPath
23-
2416
from .rasr_system import RasrSystem
2517

26-
from .util import (
27-
RasrInitArgs,
28-
ReturnnRasrDataInput,
29-
OggZipHdfDataInput,
30-
HybridArgs,
31-
NnRecogArgs,
32-
RasrSteps,
33-
)
3418

3519
# -------------------- Init --------------------
3620

0 commit comments

Comments
 (0)