Skip to content

Commit 10685d7

Browse files
committed
updates
1 parent e85d515 commit 10685d7

File tree

2 files changed

+116
-57
lines changed

2 files changed

+116
-57
lines changed

common/setups/rasr/hybrid_system.py

Lines changed: 80 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,21 @@
2121
add_tf_flow_to_base_flow,
2222
)
2323
from i6_core.util import MultiPath, MultiOutputPath
24+
from i6_core.mm import CreateDummyMixturesJob
25+
from i6_core.returnn import ReturnnComputePriorJobV2
2426

25-
from .nn_system import NnSystem
27+
from .hybrid_decoder import HybridDecoder
28+
from .nn_system import NnSystem, returnn_training
2629

2730
from .util import (
2831
RasrInitArgs,
2932
ReturnnRasrDataInput,
30-
OggZipHdfDataInput,
3133
HybridArgs,
3234
NnRecogArgs,
3335
RasrSteps,
3436
NnForcedAlignArgs,
37+
ReturnnTrainingJobArgs,
38+
AllowedReturnnTrainingDataInput,
3539
)
3640

3741
# -------------------- Init --------------------
@@ -90,11 +94,13 @@ def __init__(
9094
self.cv_corpora = []
9195
self.devtrain_corpora = []
9296

93-
self.train_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
94-
self.cv_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
95-
self.devtrain_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
96-
self.dev_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97-
self.test_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97+
self.train_input_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None
98+
self.cv_input_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None
99+
self.devtrain_input_data: Optional[
100+
Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]
101+
] = None
102+
self.dev_input_data: Optional[Dict[str, ReturnnRasrDataInput]] = None
103+
self.test_input_data: Optional[Dict[str, ReturnnRasrDataInput]] = None
98104

99105
self.train_cv_pairing = None
100106

@@ -128,9 +134,9 @@ def _add_output_alias_for_train_job(
128134
def init_system(
129135
self,
130136
rasr_init_args: RasrInitArgs,
131-
train_data: Dict[str, Union[ReturnnRasrDataInput, OggZipHdfDataInput]],
132-
cv_data: Dict[str, Union[ReturnnRasrDataInput, OggZipHdfDataInput]],
133-
devtrain_data: Optional[Dict[str, Union[ReturnnRasrDataInput, OggZipHdfDataInput]]] = None,
137+
train_data: Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]],
138+
cv_data: Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]],
139+
devtrain_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None,
134140
dev_data: Optional[Dict[str, ReturnnRasrDataInput]] = None,
135141
test_data: Optional[Dict[str, ReturnnRasrDataInput]] = None,
136142
train_cv_pairing: Optional[List[Tuple[str, ...]]] = None, # List[Tuple[trn_c, cv_c, name, dvtr_c]]
@@ -211,27 +217,29 @@ def generate_lattices(self):
211217

212218
def returnn_training(
213219
self,
214-
name,
215-
returnn_config,
216-
nn_train_args,
220+
name: str,
221+
returnn_config: returnn.ReturnnConfig,
222+
nn_train_args: Union[Dict, ReturnnTrainingJobArgs],
217223
train_corpus_key,
218224
cv_corpus_key,
219225
devtrain_corpus_key=None,
220-
):
221-
assert isinstance(returnn_config, returnn.ReturnnConfig)
222-
223-
returnn_config.config["train"] = self.train_input_data[train_corpus_key].get_data_dict()
224-
returnn_config.config["dev"] = self.cv_input_data[cv_corpus_key].get_data_dict()
225-
if devtrain_corpus_key is not None:
226-
returnn_config.config["eval_datasets"] = {
227-
"devtrain": self.devtrain_input_data[devtrain_corpus_key].get_data_dict()
228-
}
229-
230-
train_job = returnn.ReturnnTrainingJob(
226+
) -> returnn.ReturnnTrainingJob:
227+
if isinstance(nn_train_args, ReturnnTrainingJobArgs):
228+
if nn_train_args.returnn_root is None:
229+
nn_train_args.returnn_root = self.returnn_root
230+
if nn_train_args.returnn_python_exe is None:
231+
nn_train_args.returnn_python_exe = self.returnn_python_exe
232+
233+
train_job = returnn_training(
234+
name=name,
231235
returnn_config=returnn_config,
232-
returnn_root=self.returnn_root,
233-
returnn_python_exe=self.returnn_python_exe,
234-
**nn_train_args,
236+
training_args=nn_train_args,
237+
train_data=self.train_input_data[train_corpus_key],
238+
cv_data=self.cv_input_data[cv_corpus_key],
239+
additional_data={"devtrain": self.devtrain_input_data[devtrain_corpus_key]}
240+
if devtrain_corpus_key is not None
241+
else None,
242+
register_output=False,
235243
)
236244
self._add_output_alias_for_train_job(
237245
train_job=train_job,
@@ -346,7 +354,9 @@ def nn_recognition(
346354
name: str,
347355
returnn_config: returnn.ReturnnConfig,
348356
checkpoints: Dict[int, returnn.Checkpoint],
349-
acoustic_mixture_path: tk.Path, # TODO maybe Optional if prior file provided -> automatically construct dummy file
357+
acoustic_mixture_path: Optional[
358+
tk.Path
359+
], # TODO maybe Optional if prior file provided -> automatically construct dummy file
350360
prior_scales: List[float],
351361
pronunciation_scales: List[float],
352362
lm_scales: List[float],
@@ -362,6 +372,7 @@ def nn_recognition(
362372
use_epoch_for_compile=False,
363373
forward_output_layer="output",
364374
native_ops: Optional[List[str]] = None,
375+
train_job: Optional[Union[returnn.ReturnnTrainingJob, returnn.ReturnnRasrTrainingJob]] = None,
365376
**kwargs,
366377
):
367378
with tk.block(f"{name}_recognition"):
@@ -383,17 +394,37 @@ def nn_recognition(
383394
epochs = epochs if epochs is not None else list(checkpoints.keys())
384395

385396
for pron, lm, prior, epoch in itertools.product(pronunciation_scales, lm_scales, prior_scales, epochs):
386-
assert epoch in checkpoints.keys()
387-
assert acoustic_mixture_path is not None
388-
389-
if use_epoch_for_compile:
390-
tf_graph = self.nn_compile_graph(name, returnn_config, epoch=epoch)
391397

398+
assert epoch in checkpoints.keys()
399+
prior_file = None
400+
lmgc_scorer = None
401+
if acoustic_mixture_path is None:
402+
assert train_job is not None, "Need ReturnnTrainingJob for computation of priors"
403+
tmp_acoustic_mixture_path = CreateDummyMixturesJob(
404+
num_mixtures=returnn_config.config["extern_data"]["classes"]["dim"],
405+
num_features=returnn_config.config["extern_data"]["data"]["dim"],
406+
).out_mixtures
407+
lmgc_scorer = rasr.GMMFeatureScorer(tmp_acoustic_mixture_path)
408+
prior_job = ReturnnComputePriorJobV2(
409+
model_checkpoint=checkpoints[epoch],
410+
returnn_config=train_job.returnn_config,
411+
returnn_python_exe=train_job.returnn_python_exe,
412+
returnn_root=train_job.returnn_root,
413+
log_verbosity=train_job.returnn_config.post_config["log_verbosity"],
414+
)
415+
prior_job.add_alias("extract_nn_prior/" + name)
416+
prior_file = prior_job.out_prior_xml_file
417+
else:
418+
tmp_acoustic_mixture_path = acoustic_mixture_path
392419
scorer = rasr.PrecomputedHybridFeatureScorer(
393-
prior_mixtures=acoustic_mixture_path,
420+
prior_mixtures=tmp_acoustic_mixture_path, # This needs to be a new variable otherwise nesting causes undesired behavior
394421
priori_scale=prior,
422+
prior_file=prior_file,
395423
)
396424

425+
if use_epoch_for_compile:
426+
tf_graph = self.nn_compile_graph(name, returnn_config, epoch=epoch)
427+
397428
tf_flow = make_precomputed_hybrid_tf_feature_flow(
398429
tf_checkpoint=checkpoints[epoch],
399430
tf_graph=tf_graph,
@@ -419,6 +450,8 @@ def nn_recognition(
419450
parallelize_conversion=parallelize_conversion,
420451
rtf=rtf,
421452
mem=mem,
453+
lmgc_alias=f"lmgc/{name}/{recognition_corpus_key}-{recog_name}",
454+
lmgc_scorer=lmgc_scorer,
422455
**kwargs,
423456
)
424457

@@ -429,14 +462,21 @@ def nn_recog(
429462
returnn_config: Path,
430463
checkpoints: Dict[int, returnn.Checkpoint],
431464
step_args: HybridArgs,
465+
train_job: Union[returnn.ReturnnTrainingJob, returnn.ReturnnRasrTrainingJob],
432466
):
433467
for recog_name, recog_args in step_args.recognition_args.items():
468+
recog_args = copy.deepcopy(recog_args)
469+
whitelist = recog_args.pop("training_whitelist", None)
470+
if whitelist:
471+
if train_name not in whitelist:
472+
continue
434473
for dev_c in self.dev_corpora:
435474
self.nn_recognition(
436475
name=f"{train_corpus_key}-{train_name}-{recog_name}",
437476
returnn_config=returnn_config,
438477
checkpoints=checkpoints,
439478
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
479+
train_job=train_job,
440480
recognition_corpus_key=dev_c,
441481
**recog_args,
442482
)
@@ -452,6 +492,7 @@ def nn_recog(
452492
returnn_config=returnn_config,
453493
checkpoints=checkpoints,
454494
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
495+
train_job=train_job,
455496
recognition_corpus_key=tst_c,
456497
**r_args,
457498
)
@@ -472,7 +513,7 @@ def nn_compile_graph(
472513
:return: the TF graph
473514
"""
474515
graph_compile_job = returnn.CompileTFGraphJob(
475-
returnn_config,
516+
returnn_config=returnn_config,
476517
epoch=epoch,
477518
returnn_root=self.returnn_root,
478519
returnn_python_exe=self.returnn_python_exe,
@@ -509,7 +550,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
509550
train_corpus_key=trn_c,
510551
cv_corpus_key=cv_c,
511552
)
512-
else:
553+
elif isinstance(self.train_input_data[trn_c], AllowedReturnnTrainingDataInput):
513554
returnn_train_job = self.returnn_training(
514555
name=name,
515556
returnn_config=step_args.returnn_training_configs[name],
@@ -518,6 +559,8 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
518559
cv_corpus_key=cv_c,
519560
devtrain_corpus_key=dvtr_c,
520561
)
562+
else:
563+
raise NotImplementedError
521564

522565
returnn_recog_config = step_args.returnn_recognition_configs.get(
523566
name, step_args.returnn_training_configs[name]
@@ -529,6 +572,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
529572
returnn_config=returnn_recog_config,
530573
checkpoints=returnn_train_job.out_checkpoints,
531574
step_args=step_args,
575+
train_job=returnn_train_job,
532576
)
533577

534578
def run_nn_recog_step(self, step_args: NnRecogArgs):

common/setups/rasr/nn_system.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,19 @@
1-
__all__ = ["NnSystem"]
1+
__all__ = ["NnSystem", "returnn_training"]
22

33
import copy
4-
import itertools
5-
import sys
64
from dataclasses import asdict
7-
from typing import Dict, List, Optional, Tuple, Union
5+
from typing import Dict, List, Optional, Union
86

97
# -------------------- Sisyphus --------------------
108

11-
import sisyphus.toolkit as tk
12-
import sisyphus.global_settings as gs
13-
14-
from sisyphus.delayed_ops import DelayedFormat
9+
from sisyphus import tk, gs
1510

1611
# -------------------- Recipes --------------------
1712

18-
import i6_core.features as features
19-
import i6_core.rasr as rasr
2013
import i6_core.returnn as returnn
2114

22-
from i6_core.util import MultiPath, MultiOutputPath
23-
2415
from .rasr_system import RasrSystem
25-
26-
from .util import (
27-
RasrInitArgs,
28-
ReturnnRasrDataInput,
29-
OggZipHdfDataInput,
30-
HybridArgs,
31-
NnRecogArgs,
32-
RasrSteps,
33-
)
16+
from .util import ReturnnTrainingJobArgs, AllowedReturnnTrainingDataInput
3417

3518
# -------------------- Init --------------------
3619

@@ -95,3 +78,35 @@ def get_native_ops(self, op_names: Optional[List[str]]) -> Optional[List[tk.Path
9578
if op_name not in self.native_ops.keys():
9679
self.compile_native_op(op_name)
9780
return [self.native_ops[op_name] for op_name in op_names]
81+
82+
83+
def returnn_training(
84+
name: str,
85+
returnn_config: returnn.ReturnnConfig,
86+
training_args: Union[Dict, ReturnnTrainingJobArgs],
87+
train_data: AllowedReturnnTrainingDataInput,
88+
*,
89+
cv_data: Optional[AllowedReturnnTrainingDataInput] = None,
90+
additional_data: Optional[Dict[str, AllowedReturnnTrainingDataInput]] = None,
91+
register_output: bool = True,
92+
) -> returnn.ReturnnTrainingJob:
93+
assert isinstance(returnn_config, returnn.ReturnnConfig)
94+
95+
config = copy.deepcopy(returnn_config)
96+
97+
config.config["train"] = train_data if isinstance(train_data, Dict) else train_data.get_data_dict()
98+
if cv_data is not None:
99+
config.config["dev"] = cv_data if isinstance(cv_data, Dict) else cv_data.get_data_dict()
100+
if additional_data is not None:
101+
config.config["eval_datasets"] = {}
102+
for name, data in additional_data.items():
103+
config.config["eval_datasets"][name] = data if isinstance(data, Dict) else data.get_data_dict()
104+
returnn_training_job = returnn.ReturnnTrainingJob(
105+
returnn_config=config,
106+
**asdict(training_args) if isinstance(training_args, ReturnnTrainingJobArgs) else training_args,
107+
)
108+
if register_output:
109+
returnn_training_job.add_alias(f"nn_train/{name}")
110+
tk.register_output(f"nn_train/{name}_learning_rates.png", returnn_training_job.out_plot_lr)
111+
112+
return returnn_training_job

0 commit comments

Comments
 (0)