116 changes: 80 additions & 36 deletions common/setups/rasr/hybrid_system.py
@@ -21,17 +21,21 @@
add_tf_flow_to_base_flow,
)
from i6_core.util import MultiPath, MultiOutputPath
from i6_core.mm import CreateDummyMixturesJob
from i6_core.returnn import ReturnnComputePriorJobV2

from .nn_system import NnSystem
from .hybrid_decoder import HybridDecoder
from .nn_system import NnSystem, returnn_training

from .util import (
RasrInitArgs,
ReturnnRasrDataInput,
OggZipHdfDataInput,
HybridArgs,
NnRecogArgs,
RasrSteps,
NnForcedAlignArgs,
ReturnnTrainingJobArgs,
AllowedReturnnTrainingDataInput,
)

# -------------------- Init --------------------
@@ -90,11 +94,13 @@ def __init__(
self.cv_corpora = []
self.devtrain_corpora = []

self.train_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
self.cv_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
self.devtrain_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
self.dev_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
self.test_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
self.train_input_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None
self.cv_input_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None
self.devtrain_input_data: Optional[
Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]
] = None
self.dev_input_data: Optional[Dict[str, ReturnnRasrDataInput]] = None
self.test_input_data: Optional[Dict[str, ReturnnRasrDataInput]] = None

self.train_cv_pairing = None

@@ -128,9 +134,9 @@ def _add_output_alias_for_train_job(
def init_system(
self,
rasr_init_args: RasrInitArgs,
train_data: Dict[str, Union[ReturnnRasrDataInput, OggZipHdfDataInput]],
cv_data: Dict[str, Union[ReturnnRasrDataInput, OggZipHdfDataInput]],
devtrain_data: Optional[Dict[str, Union[ReturnnRasrDataInput, OggZipHdfDataInput]]] = None,
train_data: Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]],
cv_data: Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]],
devtrain_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None,
dev_data: Optional[Dict[str, ReturnnRasrDataInput]] = None,
test_data: Optional[Dict[str, ReturnnRasrDataInput]] = None,
train_cv_pairing: Optional[List[Tuple[str, ...]]] = None, # List[Tuple[trn_c, cv_c, name, dvtr_c]]
@@ -211,27 +217,29 @@ def generate_lattices(self):

def returnn_training(
self,
name,
returnn_config,
nn_train_args,
name: str,
returnn_config: returnn.ReturnnConfig,
nn_train_args: Union[Dict, ReturnnTrainingJobArgs],
train_corpus_key,
cv_corpus_key,
devtrain_corpus_key=None,
):
assert isinstance(returnn_config, returnn.ReturnnConfig)

returnn_config.config["train"] = self.train_input_data[train_corpus_key].get_data_dict()
returnn_config.config["dev"] = self.cv_input_data[cv_corpus_key].get_data_dict()
if devtrain_corpus_key is not None:
returnn_config.config["eval_datasets"] = {
"devtrain": self.devtrain_input_data[devtrain_corpus_key].get_data_dict()
}

train_job = returnn.ReturnnTrainingJob(
) -> returnn.ReturnnTrainingJob:
if isinstance(nn_train_args, ReturnnTrainingJobArgs):
if nn_train_args.returnn_root is None:
nn_train_args.returnn_root = self.returnn_root
if nn_train_args.returnn_python_exe is None:
nn_train_args.returnn_python_exe = self.returnn_python_exe

train_job = returnn_training(
name=name,
returnn_config=returnn_config,
returnn_root=self.returnn_root,
returnn_python_exe=self.returnn_python_exe,
**nn_train_args,
training_args=nn_train_args,
train_data=self.train_input_data[train_corpus_key],
cv_data=self.cv_input_data[cv_corpus_key],
additional_data={"devtrain": self.devtrain_input_data[devtrain_corpus_key]}
if devtrain_corpus_key is not None
else None,
register_output=False,
)
self._add_output_alias_for_train_job(
train_job=train_job,
@@ -346,7 +354,9 @@ def nn_recognition(
name: str,
returnn_config: returnn.ReturnnConfig,
checkpoints: Dict[int, returnn.Checkpoint],
acoustic_mixture_path: tk.Path, # TODO maybe Optional if prior file provided -> automatically construct dummy file
acoustic_mixture_path: Optional[
tk.Path
], # TODO maybe Optional if prior file provided -> automatically construct dummy file
prior_scales: List[float],
pronunciation_scales: List[float],
lm_scales: List[float],
@@ -362,6 +372,7 @@
use_epoch_for_compile=False,
forward_output_layer="output",
native_ops: Optional[List[str]] = None,
train_job: Optional[Union[returnn.ReturnnTrainingJob, returnn.ReturnnRasrTrainingJob]] = None,
**kwargs,
):
with tk.block(f"{name}_recognition"):
@@ -383,17 +394,37 @@
epochs = epochs if epochs is not None else list(checkpoints.keys())

for pron, lm, prior, epoch in itertools.product(pronunciation_scales, lm_scales, prior_scales, epochs):
assert epoch in checkpoints.keys()
assert acoustic_mixture_path is not None

if use_epoch_for_compile:
tf_graph = self.nn_compile_graph(name, returnn_config, epoch=epoch)

assert epoch in checkpoints.keys()
prior_file = None
lmgc_scorer = None
if acoustic_mixture_path is None:
assert train_job is not None, "Need ReturnnTrainingJob for computation of priors"
tmp_acoustic_mixture_path = CreateDummyMixturesJob(
num_mixtures=returnn_config.config["extern_data"]["classes"]["dim"],
num_features=returnn_config.config["extern_data"]["data"]["dim"],
).out_mixtures
lmgc_scorer = rasr.GMMFeatureScorer(tmp_acoustic_mixture_path)
prior_job = ReturnnComputePriorJobV2(
model_checkpoint=checkpoints[epoch],
returnn_config=train_job.returnn_config,
returnn_python_exe=train_job.returnn_python_exe,
returnn_root=train_job.returnn_root,
log_verbosity=train_job.returnn_config.post_config["log_verbosity"],
)
prior_job.add_alias("extract_nn_prior/" + name)
prior_file = prior_job.out_prior_xml_file
else:
tmp_acoustic_mixture_path = acoustic_mixture_path
scorer = rasr.PrecomputedHybridFeatureScorer(
prior_mixtures=acoustic_mixture_path,
prior_mixtures=tmp_acoustic_mixture_path, # separate variable on purpose: reassigning acoustic_mixture_path inside the loop would leak into later iterations
priori_scale=prior,
prior_file=prior_file,
)

if use_epoch_for_compile:
tf_graph = self.nn_compile_graph(name, returnn_config, epoch=epoch)

tf_flow = make_precomputed_hybrid_tf_feature_flow(
tf_checkpoint=checkpoints[epoch],
tf_graph=tf_graph,
@@ -419,6 +450,8 @@
parallelize_conversion=parallelize_conversion,
rtf=rtf,
mem=mem,
lmgc_alias=f"lmgc/{name}/{recognition_corpus_key}-{recog_name}",
lmgc_scorer=lmgc_scorer,
**kwargs,
)

@@ -429,14 +462,21 @@ def nn_recog(
returnn_config: Path,
checkpoints: Dict[int, returnn.Checkpoint],
step_args: HybridArgs,
train_job: Union[returnn.ReturnnTrainingJob, returnn.ReturnnRasrTrainingJob],
):
for recog_name, recog_args in step_args.recognition_args.items():
recog_args = copy.deepcopy(recog_args)
whitelist = recog_args.pop("training_whitelist", None)
if whitelist:
if train_name not in whitelist:
continue
for dev_c in self.dev_corpora:
self.nn_recognition(
name=f"{train_corpus_key}-{train_name}-{recog_name}",
returnn_config=returnn_config,
checkpoints=checkpoints,
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
train_job=train_job,
recognition_corpus_key=dev_c,
**recog_args,
)
@@ -452,6 +492,7 @@ def nn_recog(
returnn_config=returnn_config,
checkpoints=checkpoints,
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
train_job=train_job,
recognition_corpus_key=tst_c,
**r_args,
)
Expand All @@ -472,7 +513,7 @@ def nn_compile_graph(
:return: the TF graph
"""
graph_compile_job = returnn.CompileTFGraphJob(
returnn_config,
returnn_config=returnn_config,
epoch=epoch,
returnn_root=self.returnn_root,
returnn_python_exe=self.returnn_python_exe,
@@ -509,7 +550,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
train_corpus_key=trn_c,
cv_corpus_key=cv_c,
)
else:
elif isinstance(self.train_input_data[trn_c], AllowedReturnnTrainingDataInput):
returnn_train_job = self.returnn_training(
name=name,
returnn_config=step_args.returnn_training_configs[name],
Expand All @@ -518,6 +559,8 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
cv_corpus_key=cv_c,
devtrain_corpus_key=dvtr_c,
)
else:
raise NotImplementedError

returnn_recog_config = step_args.returnn_recognition_configs.get(
name, step_args.returnn_training_configs[name]
@@ -529,6 +572,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
returnn_config=returnn_recog_config,
checkpoints=returnn_train_job.out_checkpoints,
step_args=step_args,
train_job=returnn_train_job,
)

def run_nn_recog_step(self, step_args: NnRecogArgs):
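
A minimal usage sketch of the changed recognition entry point (not part of the diff; system, recog_config, train_job, and recog_args are assumed placeholders from the surrounding setup): with acoustic_mixture_path=None, nn_recognition now builds a dummy mixture file via CreateDummyMixturesJob and computes the label prior with ReturnnComputePriorJobV2 from the passed train_job, so precomputed GMM mixtures are no longer required.

# Sketch only: corpus key and all right-hand-side values are placeholders.
system.nn_recognition(
    name="hybrid-blstm-dev-other",
    returnn_config=recog_config,            # assumed recognition ReturnnConfig with extern_data dims
    checkpoints=train_job.out_checkpoints,
    acoustic_mixture_path=None,             # triggers CreateDummyMixturesJob + ReturnnComputePriorJobV2
    train_job=train_job,                    # required whenever no mixture path is given
    recognition_corpus_key="dev-other",
    **recog_args,                           # assumed dict carrying the remaining decoding parameters (scales, epochs, search settings, ...)
)
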
57 changes: 36 additions & 21 deletions common/setups/rasr/nn_system.py
@@ -1,36 +1,19 @@
__all__ = ["NnSystem"]
__all__ = ["NnSystem", "returnn_training"]

import copy
import itertools
import sys
from dataclasses import asdict
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union

# -------------------- Sisyphus --------------------

import sisyphus.toolkit as tk
import sisyphus.global_settings as gs

from sisyphus.delayed_ops import DelayedFormat
from sisyphus import tk, gs

# -------------------- Recipes --------------------

import i6_core.features as features
import i6_core.rasr as rasr
import i6_core.returnn as returnn

from i6_core.util import MultiPath, MultiOutputPath

from .rasr_system import RasrSystem

from .util import (
RasrInitArgs,
ReturnnRasrDataInput,
OggZipHdfDataInput,
HybridArgs,
NnRecogArgs,
RasrSteps,
)
from .util import ReturnnTrainingJobArgs, AllowedReturnnTrainingDataInput

# -------------------- Init --------------------

@@ -95,3 +78,35 @@ def get_native_ops(self, op_names: Optional[List[str]]) -> Optional[List[tk.Path
if op_name not in self.native_ops.keys():
self.compile_native_op(op_name)
return [self.native_ops[op_name] for op_name in op_names]


def returnn_training(
name: str,
returnn_config: returnn.ReturnnConfig,
training_args: Union[Dict, ReturnnTrainingJobArgs],
train_data: AllowedReturnnTrainingDataInput,
*,
cv_data: Optional[AllowedReturnnTrainingDataInput] = None,
additional_data: Optional[Dict[str, AllowedReturnnTrainingDataInput]] = None,
register_output: bool = True,
) -> returnn.ReturnnTrainingJob:
assert isinstance(returnn_config, returnn.ReturnnConfig)

config = copy.deepcopy(returnn_config)

config.config["train"] = train_data if isinstance(train_data, Dict) else train_data.get_data_dict()
if cv_data is not None:
config.config["dev"] = cv_data if isinstance(cv_data, Dict) else cv_data.get_data_dict()
if additional_data is not None:
config.config["eval_datasets"] = {}
for data_name, data in additional_data.items():  # distinct loop variable so the training `name` used for the alias below is not shadowed
config.config["eval_datasets"][data_name] = data if isinstance(data, Dict) else data.get_data_dict()
returnn_training_job = returnn.ReturnnTrainingJob(
returnn_config=config,
**asdict(training_args) if isinstance(training_args, ReturnnTrainingJobArgs) else training_args,
)
if register_output:
returnn_training_job.add_alias(f"nn_train/{name}")
tk.register_output(f"nn_train/{name}_learning_rates.png", returnn_training_job.out_plot_lr)

return returnn_training_job
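
A minimal usage sketch of the new free-standing returnn_training helper (not part of the diff; the config, dataset dicts, and paths are placeholders): training_args may be a ReturnnTrainingJobArgs dataclass or a plain kwargs dict for ReturnnTrainingJob, and train_data/cv_data may be plain RETURNN dataset dicts or any AllowedReturnnTrainingDataInput object providing get_data_dict().

# Sketch only: train_config, returnn_root, returnn_python_exe, train_hdf, cv_hdf are placeholders.
train_job = returnn_training(
    name="blstm_hybrid_base",
    returnn_config=train_config,        # assumed i6_core ReturnnConfig built elsewhere
    training_args={                     # plain dict is forwarded directly as ReturnnTrainingJob kwargs
        "num_epochs": 125,
        "log_verbosity": 5,
        "returnn_root": returnn_root,
        "returnn_python_exe": returnn_python_exe,
    },
    train_data={"class": "HDFDataset", "files": [train_hdf]},   # placeholder RETURNN dataset dict
    cv_data={"class": "HDFDataset", "files": [cv_hdf]},
    register_output=True,               # registers the nn_train alias and the learning-rate plot
)
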