From 59edbeb0fad2a8562b8d76501413edc9c3d14886 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Wed, 1 Feb 2023 17:48:35 +0100 Subject: [PATCH 01/47] add SMS-WSJ RETURNN datasets --- common/datasets/sms_wsj/returnn_datasets.py | 411 ++++++++++++++++++++ 1 file changed, 411 insertions(+) create mode 100644 common/datasets/sms_wsj/returnn_datasets.py diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py new file mode 100644 index 000000000..9c50a4670 --- /dev/null +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -0,0 +1,411 @@ +""" +Dataset wrapper to allow using the SMS-WSJ dataset in RETURNN. +""" + +import functools +import json +import numpy as np +import os.path +import subprocess as sp +from typing import Dict, Tuple, Any, Optional +# noinspection PyUnresolvedReferences +from returnn.datasets.basic import DatasetSeq +# noinspection PyUnresolvedReferences +from returnn.datasets.hdf import HDFDataset +# noinspection PyUnresolvedReferences +from returnn.datasets.map import MapDatasetBase, MapDatasetWrapper +# noinspection PyUnresolvedReferences +from returnn.log import log +# noinspection PyUnresolvedReferences +from returnn.util.basic import OptionalNotImplementedError, NumbersDict +# noinspection PyUnresolvedReferences +from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn + + +class SmsWsjBase(MapDatasetBase): + """ + Base class to wrap the SMS-WSJ dataset. This is not the dataset that is used in the RETURNN config, see + ``SmsWsjWrapper`` and derived classes for that. + """ + + def __init__( + self, dataset_name, json_path, pre_batch_transform, buffer=True, zip_cache=None, scenario_map_args=None, **kwargs + ): + """ + :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" + :param str json_path: path to SMS-WSJ json file + :param function pre_batch_transform: function which processes raw SMS-WSJ data + :param bool buffer: if True, use SMS-WSJ dataset prefetching and store sequences in buffer + :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir + :param Optional[Dict] scenario_map_args: optional kwargs for sms_wsj scenario_map_fn + """ + super(SmsWsjBase, self).__init__(**kwargs) + + if zip_cache is not None: + self._cache_zipped_audio(zip_cache, json_path, dataset_name) + + db = SmsWsj(json_path=json_path) + ds = db.get_dataset(dataset_name) + ds = ds.map(AudioReader(("original_source", "rir"))) + + scenario_map_args = { + "add_speech_image": False, + "add_speech_reverberation_early": False, + "add_speech_reverberation_tail": False, + "add_noise_image": False, + **(scenario_map_args or {})} + ds = ds.map(functools.partial(scenario_map_fn, **scenario_map_args)) + ds = ds.map(functools.partial(pre_batch_transform)) + + self._ds = ds + self._ds_iterator = iter(self._ds) + + self._use_buffer = buffer + if self._use_buffer: + self._ds = self._ds.prefetch(4, 8).copy(freeze=True) + self._buffer = {} # type Dict[int,[Dict[str,np.array]]] + self._buffer_size = 40 + + def __len__(self) -> int: + return len(self._ds) + + def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: + return self._get_seq_by_idx(seq_idx) + + def _get_seq_by_idx(self, seq_idx: int) -> Dict[str, np.array]: + """ + Returns data for sequence index. + """ + if self._use_buffer: + assert seq_idx in self._buffer, f"seq_idx {seq_idx} not in buffer. 
Available keys are {self._buffer.keys()}" + return self._buffer[seq_idx] + else: + return self._ds[seq_idx] + + def get_seq_tag(self, seq_idx: int) -> str: + """ + Returns tag for the sequence of the given index, default is 'seq-{seq_idx}'. + """ + if "seq_tag" in self._get_seq_by_idx(seq_idx): + return str(self._get_seq_by_idx(seq_idx)["seq_tag"]) + else: + return "seq-%i" % seq_idx + + def get_seq_len(self, seq_idx: int) -> int: + """ + Returns length of the sequence of the given index + """ + if "seq_len" in self._get_seq_by_idx(seq_idx): + return int(self._get_seq_by_idx(seq_idx)["seq_len"]) + else: + raise OptionalNotImplementedError + + def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: + """ + Returns sequence length for all data/target keys. + """ + data = self[seq_idx] + d = {k: v.size for k, v in data.items()} + for update_key in ["data", "target_signals"]: + if update_key in d.keys() and "seq_len" in data: + d[update_key] = int(data["seq_len"]) + return NumbersDict(d) + + def update_buffer(self, seq_idx: int, pop_seqs: bool = True): + """ + :param int seq_idx: + :param bool pop_seqs: if True, pop sequences from buffer that are outside buffer range + """ + if not self._use_buffer: + return + + # debugging information + keys = list(self._buffer.keys()) or [0] + if not (min(keys) <= seq_idx <= max(keys)): + print(f"WARNING: seq_idx {seq_idx} outside range of keys: {self._buffer.keys()}") + + # add sequences + for idx in range(seq_idx, min(seq_idx + self._buffer_size // 2, len(self))): + if idx not in self._buffer and idx < len(self): + try: + self._buffer[idx] = next(self._ds_iterator) + except StopIteration: + print(f"StopIteration for seq_idx {seq_idx}") + print(f"Dataset: {self} with SMS-WSJ {self._ds} of len {len(self)}") + print(f"Indices in buffer: {self._buffer.keys()}") + # raise + print(f"WARNING: Ignoring this, reset iterator and continue") + self._ds_iterator = iter(self._ds) + self._buffer[idx] = next(self._ds_iterator) + if idx == len(self) - 1 and 0 not in self._buffer: + print(f"Reached end of dataset, reset iterator") + try: + next(self._ds_iterator) + except StopIteration: + pass + else: + print( + "WARNING: reached final index of dataset, but iterator has more sequences. " + "Maybe the training was restarted from an epoch > 1?") + print(f"Current buffer indices: {self._buffer.keys()}") + self._ds_iterator = iter(self._ds) + for idx_ in range(self._buffer_size // 2): + if idx_ not in self._buffer and idx_ < len(self): + self._buffer[idx_] = next(self._ds_iterator) + print(f"After adding start of dataset to buffer indices: {self._buffer.keys()}") + + # remove sequences + if pop_seqs: + for idx in list(self._buffer): + if not (seq_idx - self._buffer_size // 2 <= idx <= seq_idx + self._buffer_size // 2): + if max(self._buffer.keys()) == len(self) - 1 and idx < self._buffer_size // 2: + # newly added sequences starting from 0 + continue + self._buffer.pop(idx) + + @staticmethod + def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): + """ + Caches and unzips a given archive with SMS-WSJ data which will then be used as data dir. + This is done because caching of the single files takes extremely long. 
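# Illustrative sketch (not part of the patch) of the caching flow implemented in
# this method: the single zip archive is cached once via the cluster's `cf` cache
# manager and unpacked locally, instead of caching thousands of small audio files.
# The archive path is a placeholder.
import os.path
import subprocess as sp

zip_cache = "/path/to/sms_wsj.zip"  # placeholder
cached = sp.check_output(["cf", zip_cache]).strip().decode("utf8")  # one copy to local disk
local_dir = os.path.dirname(cached)
sp.check_call(["unzip", "-q", "-n", cached, "-d", local_dir])  # -n: keep files on restart
# the SMS-WSJ json is then rewritten so that all audio paths point into local_dir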
+ """ + print(f"Cache and unzip SMS-WSJ data from {zip_cache}") + + # cache and unzip + try: + zip_cache_cached = sp.check_output(["cf", zip_cache]).strip().decode("utf8") + assert zip_cache_cached != zip_cache, "cached and original file have the same path" + local_unzipped_dir = os.path.dirname(zip_cache_cached) + sp.check_call(["unzip", "-q", "-n", zip_cache_cached, "-d", local_unzipped_dir]) + except sp.CalledProcessError: + print(f"Cache manager: Error occurred when caching and unzipping {zip_cache}") + raise + + # modify json and check if all data is available + with open(json_path, "r") as f: + json_dict = json.loads(f.read()) + original_dir = next(iter(json_dict["datasets"][dataset_name].values()))["audio_path"]["original_source"][0] + while not original_dir.endswith(os.path.basename(local_unzipped_dir)) and len(original_dir) > 1: + original_dir = os.path.dirname(original_dir) + for seq in json_dict["datasets"][dataset_name]: + for audio_key in ["original_source", "rir"]: + for seq_idx in range(len(json_dict["datasets"][dataset_name][seq]["audio_path"][audio_key])): + path = json_dict["datasets"][dataset_name][seq]["audio_path"][audio_key][seq_idx] + path = path.replace(original_dir, local_unzipped_dir) + json_dict["datasets"][dataset_name][seq]["audio_path"][audio_key][seq_idx] = path + assert path.startswith(local_unzipped_dir), ( + f"Audio file {path} was expected to start with {local_unzipped_dir}") + assert os.path.exists(path), f"Audio file {path} does not exist" + + json_path = os.path.join(local_unzipped_dir, "sms_wsj.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump(json_dict, f, ensure_ascii=False, indent=4) + + print(f"Finished preparation of zip cache data, use json in {json_path}") + + +class SmsWsjBaseWithRasrClasses(SmsWsjBase): + """ + Base class to wrap the SMS-WSJ dataset and combine it with RASR alignments in an hdf dataset. 
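# Illustrative sketch of the segment-name mapping this class relies on (values are
# hypothetical): each SMS-WSJ mixture has one RASR alignment per speaker, stored in
# the HDF file under a derived segment name.
seq_tag = "0_4k0c030t_4k6c030u"  # hypothetical SMS-WSJ example id
rasr_segment_prefix = "sms_wsj/"  # hypothetical prefix
rasr_seq_tags = [f"{rasr_segment_prefix}{seq_tag}_{speaker}" for speaker in range(2)]
# -> ["sms_wsj/0_4k0c030t_4k6c030u_0", "sms_wsj/0_4k0c030t_4k6c030u_1"]
# each tag is then looked up via HDFDataset.get_data_by_seq_tag(tag, "classes"),
# as done in __getitem__ below.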
+ """ + + def __init__( + self, rasr_classes_hdf=None, rasr_corpus=None, rasr_segment_prefix="", rasr_segment_postfix="", **kwargs + ): + """ + :param Optional[str] rasr_classes_hdf: hdf file with dumped RASR class labels + :param Optional[str] rasr_corpus: RASR corpus file for reading segment start and end times for padding + :param str rasr_segment_prefix: prefix to map SMS-WSJ segment name to RASR segment name + :param str rasr_segment_postfix: postfix to map SMS-WSJ segment name to RASR segment name + :param kwargs: + """ + super(SmsWsjBaseWithRasrClasses, self).__init__(**kwargs) + + # self.data_types = {"target_signals": {"dim": 2, "shape": (None, 2)}} + self._rasr_classes_hdf = None + if rasr_classes_hdf is not None: + self._rasr_classes_hdf = HDFDataset([rasr_classes_hdf], use_cache_manager=True) + # self.data_types["target_rasr"] = {"sparse": True, "dim": 9001, "shape": (None, 2)} + self._rasr_segment_start_end = {} # type: Dict[str, Tuple[float, float]] + if rasr_corpus is not None: + from i6_core.lib.corpus import Corpus + corpus = Corpus() + corpus.load(rasr_corpus) + for seg in corpus.segments(): + self._rasr_segment_start_end[seg.fullname()] = (seg.start, seg.end) + self.rasr_segment_prefix = rasr_segment_prefix + self.rasr_segment_postfix = rasr_segment_postfix + + def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: + d = self._get_seq_by_idx(seq_idx) + if self._rasr_classes_hdf is not None: + rasr_seq_tags = [ + f"{self.rasr_segment_prefix}{d['seq_tag']}_{speaker}{self.rasr_segment_postfix}" + for speaker in range(d["target_signals"].shape[1])] + rasr_targets = [] + for idx, rasr_seq_tag in enumerate(rasr_seq_tags): + rasr_targets.append(self._rasr_classes_hdf.get_data_by_seq_tag(rasr_seq_tag, "classes")) + start_end_times = [self._rasr_segment_start_end.get(rasr_seq_tag, (0, 0)) for rasr_seq_tag in rasr_seq_tags] + padded_len_sec = max(start_end[1] for start_end in start_end_times) + padded_len_frames = max(rasr_target.shape[0] for rasr_target in rasr_targets) + for speaker_idx in range(len(rasr_targets)): + pad_start = 0 + if padded_len_sec > 0: + pad_start = round(start_end_times[speaker_idx][0] / padded_len_sec * padded_len_frames) + pad_end = padded_len_frames - rasr_targets[speaker_idx].shape[0] - pad_start + if pad_end < 0: + pad_start += pad_end + assert pad_start >= 0 + pad_end = 0 + rasr_targets[speaker_idx] = np.concatenate([ + 9000 * np.ones(pad_start), rasr_targets[speaker_idx], 9000 * np.ones(pad_end)]) + d["target_rasr"] = np.stack(rasr_targets).T + d["target_rasr_len"] = np.array(padded_len_frames) + return d + + def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: + """ + Returns sequence length for all data/target keys. + """ + d = super(SmsWsjBaseWithRasrClasses, self).get_seq_length_for_keys(seq_idx) + data = self[seq_idx] + d["target_rasr"] = int(data["target_rasr_len"]) + return NumbersDict(d) + + +class SmsWsjWrapper(MapDatasetWrapper): + """ + Base class for datasets that can be used in RETURNN config. + """ + + def __init__( + self, sms_wsj_base, **kwargs + ): + """ + :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this + """ + if "seq_ordering" not in kwargs: + print("Warning: no shuffling is enabled by default", file=log.v) + super(SmsWsjWrapper, self).__init__(sms_wsj_base, **kwargs) + # self.num_outputs = ... # needs to be set in derived classes + + def _get_seq_length(seq_idx: int) -> NumbersDict: + """ + Returns sequence length for all data/target keys. 
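# Illustrative example of the returned per-key lengths (numbers are made up):
# waveform-like keys report the sample count from "seq_len", all other keys their
# own array size, so RETURNN can batch on a per-key basis.
from returnn.util.basic import NumbersDict

lengths = NumbersDict({
    "data": 64000,            # e.g. 4 s of 16 kHz audio, taken from "seq_len"
    "target_signals": 64000,  # shares the time axis with "data"
    "seq_tag": 1,             # scalar/object entries just report their size
})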
+ """ + corpus_seq_idx = self.get_corpus_seq_idx(seq_idx) + return sms_wsj_base.get_seq_length_for_keys(corpus_seq_idx) + + self.get_seq_length = _get_seq_length + + @staticmethod + def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: + """ + Used to process raw SMS-WSJ data + :param inputs: input as coming from SMS-WSJ + """ + return_dict = { + "seq_tag": np.array(inputs["example_id"], dtype=object), + "source_id": np.array(inputs["source_id"], dtype=object), + "seq_len": np.array(inputs["num_samples"]["observation"]), + } + return return_dict + + def _collect_single_seq(self, seq_idx: int) -> DatasetSeq: + """ + :param seq_idx: sorted seq idx + """ + corpus_seq_idx = self.get_corpus_seq_idx(seq_idx) + self._dataset.update_buffer(corpus_seq_idx) + data = self._dataset[corpus_seq_idx] + assert "seq_tag" in data + return DatasetSeq(seq_idx, features=data, seq_tag=data["seq_tag"]) + + def init_seq_order(self, epoch: Optional[int] = None, **kwargs) -> bool: + """ + Override this in order to update the buffer. get_seq_length is often called before _collect_single_seq, therefore + the buffer does not contain the initial indices when continuing the training from an epoch > 0. + """ + out = super(SmsWsjWrapper, self).init_seq_order(epoch=epoch, **kwargs) + buffer_index = ((epoch or 1) - 1) * self.num_seqs % len(self._dataset) + self._dataset.update_buffer(buffer_index, pop_seqs=False) + return out + + +class SmsWsjMixtureEarlyDataset(SmsWsjWrapper): + """ + Dataset with audio mixture and early signals as target. + """ + + def __init__( + self, dataset_name, json_path, num_outputs=None, zip_cache=None, sms_wsj_base=None, **kwargs + ): + """ + :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" + :param str json_path: path to SMS-WSJ json file + :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset + :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir + :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this + """ + if sms_wsj_base is None: + sms_wsj_base = SmsWsjBase( + dataset_name=dataset_name, + json_path=json_path, + pre_batch_transform=self._pre_batch_transform, + scenario_map_args={"add_speech_reverberation_early": True}, + zip_cache=zip_cache) + super(SmsWsjMixtureEarlyDataset, self).__init__(sms_wsj_base, **kwargs) + # typically data is raw waveform so 1-D and dense, target signals are 2-D (one for each speaker) and dense + self.num_outputs = num_outputs or {"data": [1, 2], "target_signals": [2, 2]} + + @staticmethod + def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: + """ + Used to process raw SMS-WSJ data + :param inputs: input as coming from SMS-WSJ + """ + return_dict = SmsWsjWrapper._pre_batch_transform(inputs) + return_dict.update({ + "data": inputs["audio_data"]["observation"][:1, :].T, # take first of 6 channels: (T, 1) + "target_signals": inputs["audio_data"]["speech_reverberation_early"][:, 0, :].T, # first of 6 channels: (T, S) + }) + return return_dict + + +class SmsWsjMixtureEarlyAlignmentDataset(SmsWsjMixtureEarlyDataset): + """ + Dataset with audio mixture, target early signals and target RASR alignments. 
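# Hypothetical RETURNN config entry for this dataset (paths, prefix and label count
# are placeholders; assumes the class is importable from the config so that RETURNN
# can resolve the "class" key):
train = {
    "class": "SmsWsjMixtureEarlyAlignmentDataset",
    "dataset_name": "train_si284",
    "json_path": "/path/to/sms_wsj.json",
    "rasr_classes_hdf": "/path/to/alignment.hdf",
    "rasr_num_outputs": 9001,
    "rasr_segment_prefix": "sms_wsj/",
    "seq_ordering": "laplace:.1000",
}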
+ """ + + def __init__( + self, dataset_name, json_path, num_outputs=None, rasr_num_outputs=None, rasr_segment_prefix="", + rasr_segment_postfix="", rasr_classes_hdf=None, rasr_corpus=None, zip_cache=None, **kwargs + ): + """ + :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" + :param str json_path: path to SMS-WSJ json file + :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset + :param Optional[int] rasr_num_outputs: number of output labels for RASR alignment, e.g. 9001 for that CART size + :param str rasr_segment_prefix: prefix to map SMS-WSJ segment name to RASR segment name + :param str rasr_segment_postfix: postfix to map SMS-WSJ segment name to RASR segment name + :param str rasr_classes_hdf: hdf file with dumped RASR class labels + :param str rasr_corpus: RASR corpus file for reading segment start and end times for padding + :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir + """ + sms_wsj_base = SmsWsjBaseWithRasrClasses( + dataset_name=dataset_name, + json_path=json_path, + pre_batch_transform=self._pre_batch_transform, + scenario_map_args={"add_speech_reverberation_early": True}, + rasr_classes_hdf=rasr_classes_hdf, + rasr_corpus=rasr_corpus, + rasr_segment_prefix=rasr_segment_prefix, + rasr_segment_postfix=rasr_segment_postfix, + zip_cache=zip_cache) + super(SmsWsjMixtureEarlyAlignmentDataset, self).__init__( + dataset_name, json_path, num_outputs=num_outputs, zip_cache=zip_cache, sms_wsj_base=sms_wsj_base, **kwargs) + if num_outputs is not None: + self.num_outputs = num_outputs + else: + assert rasr_num_outputs is not None, "either num_outputs or rasr_num_outputs has to be given" + self.num_outputs["target_rasr"] = [rasr_num_outputs, 1] # target alignments are sparse with the given dim From d95f3357b816901c362adfd7f605544dc78496d0 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 2 Feb 2023 15:37:06 +0100 Subject: [PATCH 02/47] black formatting --- common/datasets/sms_wsj/returnn_datasets.py | 843 +++++++++++--------- 1 file changed, 476 insertions(+), 367 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 9c50a4670..b6701e443 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -8,404 +8,513 @@ import os.path import subprocess as sp from typing import Dict, Tuple, Any, Optional + # noinspection PyUnresolvedReferences from returnn.datasets.basic import DatasetSeq + # noinspection PyUnresolvedReferences from returnn.datasets.hdf import HDFDataset + # noinspection PyUnresolvedReferences from returnn.datasets.map import MapDatasetBase, MapDatasetWrapper + # noinspection PyUnresolvedReferences from returnn.log import log + # noinspection PyUnresolvedReferences from returnn.util.basic import OptionalNotImplementedError, NumbersDict + # noinspection PyUnresolvedReferences from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn class SmsWsjBase(MapDatasetBase): - """ - Base class to wrap the SMS-WSJ dataset. This is not the dataset that is used in the RETURNN config, see - ``SmsWsjWrapper`` and derived classes for that. 
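# Condensed, illustrative version of the pipeline this base class builds in
# __init__ (the json path and dataset name are placeholders):
import functools
from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn

db = SmsWsj(json_path="/path/to/sms_wsj.json")
ds = db.get_dataset("train_si284")
ds = ds.map(AudioReader(("original_source", "rir")))  # load source wavs and RIRs
ds = ds.map(functools.partial(scenario_map_fn, add_speech_reverberation_early=True))
# a pre_batch_transform map then turns each example dict into RETURNN-ready arrays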
- """ - - def __init__( - self, dataset_name, json_path, pre_batch_transform, buffer=True, zip_cache=None, scenario_map_args=None, **kwargs - ): """ - :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" - :param str json_path: path to SMS-WSJ json file - :param function pre_batch_transform: function which processes raw SMS-WSJ data - :param bool buffer: if True, use SMS-WSJ dataset prefetching and store sequences in buffer - :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir - :param Optional[Dict] scenario_map_args: optional kwargs for sms_wsj scenario_map_fn - """ - super(SmsWsjBase, self).__init__(**kwargs) - - if zip_cache is not None: - self._cache_zipped_audio(zip_cache, json_path, dataset_name) - - db = SmsWsj(json_path=json_path) - ds = db.get_dataset(dataset_name) - ds = ds.map(AudioReader(("original_source", "rir"))) - - scenario_map_args = { - "add_speech_image": False, - "add_speech_reverberation_early": False, - "add_speech_reverberation_tail": False, - "add_noise_image": False, - **(scenario_map_args or {})} - ds = ds.map(functools.partial(scenario_map_fn, **scenario_map_args)) - ds = ds.map(functools.partial(pre_batch_transform)) - - self._ds = ds - self._ds_iterator = iter(self._ds) - - self._use_buffer = buffer - if self._use_buffer: - self._ds = self._ds.prefetch(4, 8).copy(freeze=True) - self._buffer = {} # type Dict[int,[Dict[str,np.array]]] - self._buffer_size = 40 - - def __len__(self) -> int: - return len(self._ds) - - def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: - return self._get_seq_by_idx(seq_idx) - - def _get_seq_by_idx(self, seq_idx: int) -> Dict[str, np.array]: - """ - Returns data for sequence index. - """ - if self._use_buffer: - assert seq_idx in self._buffer, f"seq_idx {seq_idx} not in buffer. Available keys are {self._buffer.keys()}" - return self._buffer[seq_idx] - else: - return self._ds[seq_idx] - - def get_seq_tag(self, seq_idx: int) -> str: - """ - Returns tag for the sequence of the given index, default is 'seq-{seq_idx}'. - """ - if "seq_tag" in self._get_seq_by_idx(seq_idx): - return str(self._get_seq_by_idx(seq_idx)["seq_tag"]) - else: - return "seq-%i" % seq_idx - - def get_seq_len(self, seq_idx: int) -> int: - """ - Returns length of the sequence of the given index - """ - if "seq_len" in self._get_seq_by_idx(seq_idx): - return int(self._get_seq_by_idx(seq_idx)["seq_len"]) - else: - raise OptionalNotImplementedError - - def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: - """ - Returns sequence length for all data/target keys. - """ - data = self[seq_idx] - d = {k: v.size for k, v in data.items()} - for update_key in ["data", "target_signals"]: - if update_key in d.keys() and "seq_len" in data: - d[update_key] = int(data["seq_len"]) - return NumbersDict(d) - - def update_buffer(self, seq_idx: int, pop_seqs: bool = True): - """ - :param int seq_idx: - :param bool pop_seqs: if True, pop sequences from buffer that are outside buffer range - """ - if not self._use_buffer: - return - - # debugging information - keys = list(self._buffer.keys()) or [0] - if not (min(keys) <= seq_idx <= max(keys)): - print(f"WARNING: seq_idx {seq_idx} outside range of keys: {self._buffer.keys()}") + Base class to wrap the SMS-WSJ dataset. This is not the dataset that is used in the RETURNN config, see + ``SmsWsjWrapper`` and derived classes for that. 
+ """ + + def __init__( + self, + dataset_name, + json_path, + pre_batch_transform, + buffer=True, + zip_cache=None, + scenario_map_args=None, + **kwargs, + ): + """ + :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" + :param str json_path: path to SMS-WSJ json file + :param function pre_batch_transform: function which processes raw SMS-WSJ data + :param bool buffer: if True, use SMS-WSJ dataset prefetching and store sequences in buffer + :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir + :param Optional[Dict] scenario_map_args: optional kwargs for sms_wsj scenario_map_fn + """ + super(SmsWsjBase, self).__init__(**kwargs) + + if zip_cache is not None: + self._cache_zipped_audio(zip_cache, json_path, dataset_name) + + db = SmsWsj(json_path=json_path) + ds = db.get_dataset(dataset_name) + ds = ds.map(AudioReader(("original_source", "rir"))) + + scenario_map_args = { + "add_speech_image": False, + "add_speech_reverberation_early": False, + "add_speech_reverberation_tail": False, + "add_noise_image": False, + **(scenario_map_args or {}), + } + ds = ds.map(functools.partial(scenario_map_fn, **scenario_map_args)) + ds = ds.map(functools.partial(pre_batch_transform)) + + self._ds = ds + self._ds_iterator = iter(self._ds) - # add sequences - for idx in range(seq_idx, min(seq_idx + self._buffer_size // 2, len(self))): - if idx not in self._buffer and idx < len(self): - try: - self._buffer[idx] = next(self._ds_iterator) - except StopIteration: - print(f"StopIteration for seq_idx {seq_idx}") - print(f"Dataset: {self} with SMS-WSJ {self._ds} of len {len(self)}") - print(f"Indices in buffer: {self._buffer.keys()}") - # raise - print(f"WARNING: Ignoring this, reset iterator and continue") - self._ds_iterator = iter(self._ds) - self._buffer[idx] = next(self._ds_iterator) - if idx == len(self) - 1 and 0 not in self._buffer: - print(f"Reached end of dataset, reset iterator") - try: - next(self._ds_iterator) - except StopIteration: - pass + self._use_buffer = buffer + if self._use_buffer: + self._ds = self._ds.prefetch(4, 8).copy(freeze=True) + self._buffer = {} # type Dict[int,[Dict[str,np.array]]] + self._buffer_size = 40 + + def __len__(self) -> int: + return len(self._ds) + + def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: + return self._get_seq_by_idx(seq_idx) + + def _get_seq_by_idx(self, seq_idx: int) -> Dict[str, np.array]: + """ + Returns data for sequence index. + """ + if self._use_buffer: + assert ( + seq_idx in self._buffer + ), f"seq_idx {seq_idx} not in buffer. Available keys are {self._buffer.keys()}" + return self._buffer[seq_idx] else: - print( - "WARNING: reached final index of dataset, but iterator has more sequences. 
" - "Maybe the training was restarted from an epoch > 1?") - print(f"Current buffer indices: {self._buffer.keys()}") - self._ds_iterator = iter(self._ds) - for idx_ in range(self._buffer_size // 2): - if idx_ not in self._buffer and idx_ < len(self): - self._buffer[idx_] = next(self._ds_iterator) - print(f"After adding start of dataset to buffer indices: {self._buffer.keys()}") - - # remove sequences - if pop_seqs: - for idx in list(self._buffer): - if not (seq_idx - self._buffer_size // 2 <= idx <= seq_idx + self._buffer_size // 2): - if max(self._buffer.keys()) == len(self) - 1 and idx < self._buffer_size // 2: - # newly added sequences starting from 0 - continue - self._buffer.pop(idx) - - @staticmethod - def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): - """ - Caches and unzips a given archive with SMS-WSJ data which will then be used as data dir. - This is done because caching of the single files takes extremely long. - """ - print(f"Cache and unzip SMS-WSJ data from {zip_cache}") - - # cache and unzip - try: - zip_cache_cached = sp.check_output(["cf", zip_cache]).strip().decode("utf8") - assert zip_cache_cached != zip_cache, "cached and original file have the same path" - local_unzipped_dir = os.path.dirname(zip_cache_cached) - sp.check_call(["unzip", "-q", "-n", zip_cache_cached, "-d", local_unzipped_dir]) - except sp.CalledProcessError: - print(f"Cache manager: Error occurred when caching and unzipping {zip_cache}") - raise - - # modify json and check if all data is available - with open(json_path, "r") as f: - json_dict = json.loads(f.read()) - original_dir = next(iter(json_dict["datasets"][dataset_name].values()))["audio_path"]["original_source"][0] - while not original_dir.endswith(os.path.basename(local_unzipped_dir)) and len(original_dir) > 1: - original_dir = os.path.dirname(original_dir) - for seq in json_dict["datasets"][dataset_name]: - for audio_key in ["original_source", "rir"]: - for seq_idx in range(len(json_dict["datasets"][dataset_name][seq]["audio_path"][audio_key])): - path = json_dict["datasets"][dataset_name][seq]["audio_path"][audio_key][seq_idx] - path = path.replace(original_dir, local_unzipped_dir) - json_dict["datasets"][dataset_name][seq]["audio_path"][audio_key][seq_idx] = path - assert path.startswith(local_unzipped_dir), ( - f"Audio file {path} was expected to start with {local_unzipped_dir}") - assert os.path.exists(path), f"Audio file {path} does not exist" - - json_path = os.path.join(local_unzipped_dir, "sms_wsj.json") - with open(json_path, "w", encoding="utf-8") as f: - json.dump(json_dict, f, ensure_ascii=False, indent=4) - - print(f"Finished preparation of zip cache data, use json in {json_path}") + return self._ds[seq_idx] + + def get_seq_tag(self, seq_idx: int) -> str: + """ + Returns tag for the sequence of the given index, default is 'seq-{seq_idx}'. + """ + if "seq_tag" in self._get_seq_by_idx(seq_idx): + return str(self._get_seq_by_idx(seq_idx)["seq_tag"]) + else: + return "seq-%i" % seq_idx + + def get_seq_len(self, seq_idx: int) -> int: + """ + Returns length of the sequence of the given index + """ + if "seq_len" in self._get_seq_by_idx(seq_idx): + return int(self._get_seq_by_idx(seq_idx)["seq_len"]) + else: + raise OptionalNotImplementedError + + def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: + """ + Returns sequence length for all data/target keys. 
+ """ + data = self[seq_idx] + d = {k: v.size for k, v in data.items()} + for update_key in ["data", "target_signals"]: + if update_key in d.keys() and "seq_len" in data: + d[update_key] = int(data["seq_len"]) + return NumbersDict(d) + + def update_buffer(self, seq_idx: int, pop_seqs: bool = True): + """ + :param int seq_idx: + :param bool pop_seqs: if True, pop sequences from buffer that are outside buffer range + """ + if not self._use_buffer: + return + + # debugging information + keys = list(self._buffer.keys()) or [0] + if not (min(keys) <= seq_idx <= max(keys)): + print( + f"WARNING: seq_idx {seq_idx} outside range of keys: {self._buffer.keys()}" + ) + + # add sequences + for idx in range(seq_idx, min(seq_idx + self._buffer_size // 2, len(self))): + if idx not in self._buffer and idx < len(self): + try: + self._buffer[idx] = next(self._ds_iterator) + except StopIteration: + print(f"StopIteration for seq_idx {seq_idx}") + print(f"Dataset: {self} with SMS-WSJ {self._ds} of len {len(self)}") + print(f"Indices in buffer: {self._buffer.keys()}") + # raise + print(f"WARNING: Ignoring this, reset iterator and continue") + self._ds_iterator = iter(self._ds) + self._buffer[idx] = next(self._ds_iterator) + if idx == len(self) - 1 and 0 not in self._buffer: + print(f"Reached end of dataset, reset iterator") + try: + next(self._ds_iterator) + except StopIteration: + pass + else: + print( + "WARNING: reached final index of dataset, but iterator has more sequences. " + "Maybe the training was restarted from an epoch > 1?" + ) + print(f"Current buffer indices: {self._buffer.keys()}") + self._ds_iterator = iter(self._ds) + for idx_ in range(self._buffer_size // 2): + if idx_ not in self._buffer and idx_ < len(self): + self._buffer[idx_] = next(self._ds_iterator) + print( + f"After adding start of dataset to buffer indices: {self._buffer.keys()}" + ) + + # remove sequences + if pop_seqs: + for idx in list(self._buffer): + if not ( + seq_idx - self._buffer_size // 2 + <= idx + <= seq_idx + self._buffer_size // 2 + ): + if ( + max(self._buffer.keys()) == len(self) - 1 + and idx < self._buffer_size // 2 + ): + # newly added sequences starting from 0 + continue + self._buffer.pop(idx) + + @staticmethod + def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): + """ + Caches and unzips a given archive with SMS-WSJ data which will then be used as data dir. + This is done because caching of the single files takes extremely long. 
+ """ + print(f"Cache and unzip SMS-WSJ data from {zip_cache}") + + # cache and unzip + try: + zip_cache_cached = sp.check_output(["cf", zip_cache]).strip().decode("utf8") + assert ( + zip_cache_cached != zip_cache + ), "cached and original file have the same path" + local_unzipped_dir = os.path.dirname(zip_cache_cached) + sp.check_call( + ["unzip", "-q", "-n", zip_cache_cached, "-d", local_unzipped_dir] + ) + except sp.CalledProcessError: + print( + f"Cache manager: Error occurred when caching and unzipping {zip_cache}" + ) + raise + + # modify json and check if all data is available + with open(json_path, "r") as f: + json_dict = json.loads(f.read()) + original_dir = next(iter(json_dict["datasets"][dataset_name].values()))[ + "audio_path" + ]["original_source"][0] + while ( + not original_dir.endswith(os.path.basename(local_unzipped_dir)) + and len(original_dir) > 1 + ): + original_dir = os.path.dirname(original_dir) + for seq in json_dict["datasets"][dataset_name]: + for audio_key in ["original_source", "rir"]: + for seq_idx in range( + len( + json_dict["datasets"][dataset_name][seq]["audio_path"][ + audio_key + ] + ) + ): + path = json_dict["datasets"][dataset_name][seq]["audio_path"][ + audio_key + ][seq_idx] + path = path.replace(original_dir, local_unzipped_dir) + json_dict["datasets"][dataset_name][seq]["audio_path"][audio_key][ + seq_idx + ] = path + assert path.startswith( + local_unzipped_dir + ), f"Audio file {path} was expected to start with {local_unzipped_dir}" + assert os.path.exists(path), f"Audio file {path} does not exist" + + json_path = os.path.join(local_unzipped_dir, "sms_wsj.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump(json_dict, f, ensure_ascii=False, indent=4) + + print(f"Finished preparation of zip cache data, use json in {json_path}") class SmsWsjBaseWithRasrClasses(SmsWsjBase): - """ - Base class to wrap the SMS-WSJ dataset and combine it with RASR alignments in an hdf dataset. 
- """ - - def __init__( - self, rasr_classes_hdf=None, rasr_corpus=None, rasr_segment_prefix="", rasr_segment_postfix="", **kwargs - ): - """ - :param Optional[str] rasr_classes_hdf: hdf file with dumped RASR class labels - :param Optional[str] rasr_corpus: RASR corpus file for reading segment start and end times for padding - :param str rasr_segment_prefix: prefix to map SMS-WSJ segment name to RASR segment name - :param str rasr_segment_postfix: postfix to map SMS-WSJ segment name to RASR segment name - :param kwargs: """ - super(SmsWsjBaseWithRasrClasses, self).__init__(**kwargs) - - # self.data_types = {"target_signals": {"dim": 2, "shape": (None, 2)}} - self._rasr_classes_hdf = None - if rasr_classes_hdf is not None: - self._rasr_classes_hdf = HDFDataset([rasr_classes_hdf], use_cache_manager=True) - # self.data_types["target_rasr"] = {"sparse": True, "dim": 9001, "shape": (None, 2)} - self._rasr_segment_start_end = {} # type: Dict[str, Tuple[float, float]] - if rasr_corpus is not None: - from i6_core.lib.corpus import Corpus - corpus = Corpus() - corpus.load(rasr_corpus) - for seg in corpus.segments(): - self._rasr_segment_start_end[seg.fullname()] = (seg.start, seg.end) - self.rasr_segment_prefix = rasr_segment_prefix - self.rasr_segment_postfix = rasr_segment_postfix - - def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: - d = self._get_seq_by_idx(seq_idx) - if self._rasr_classes_hdf is not None: - rasr_seq_tags = [ - f"{self.rasr_segment_prefix}{d['seq_tag']}_{speaker}{self.rasr_segment_postfix}" - for speaker in range(d["target_signals"].shape[1])] - rasr_targets = [] - for idx, rasr_seq_tag in enumerate(rasr_seq_tags): - rasr_targets.append(self._rasr_classes_hdf.get_data_by_seq_tag(rasr_seq_tag, "classes")) - start_end_times = [self._rasr_segment_start_end.get(rasr_seq_tag, (0, 0)) for rasr_seq_tag in rasr_seq_tags] - padded_len_sec = max(start_end[1] for start_end in start_end_times) - padded_len_frames = max(rasr_target.shape[0] for rasr_target in rasr_targets) - for speaker_idx in range(len(rasr_targets)): - pad_start = 0 - if padded_len_sec > 0: - pad_start = round(start_end_times[speaker_idx][0] / padded_len_sec * padded_len_frames) - pad_end = padded_len_frames - rasr_targets[speaker_idx].shape[0] - pad_start - if pad_end < 0: - pad_start += pad_end - assert pad_start >= 0 - pad_end = 0 - rasr_targets[speaker_idx] = np.concatenate([ - 9000 * np.ones(pad_start), rasr_targets[speaker_idx], 9000 * np.ones(pad_end)]) - d["target_rasr"] = np.stack(rasr_targets).T - d["target_rasr_len"] = np.array(padded_len_frames) - return d - - def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: - """ - Returns sequence length for all data/target keys. - """ - d = super(SmsWsjBaseWithRasrClasses, self).get_seq_length_for_keys(seq_idx) - data = self[seq_idx] - d["target_rasr"] = int(data["target_rasr_len"]) - return NumbersDict(d) + Base class to wrap the SMS-WSJ dataset and combine it with RASR alignments in an hdf dataset. 
+ """ + + def __init__( + self, + rasr_classes_hdf=None, + rasr_corpus=None, + rasr_segment_prefix="", + rasr_segment_postfix="", + **kwargs, + ): + """ + :param Optional[str] rasr_classes_hdf: hdf file with dumped RASR class labels + :param Optional[str] rasr_corpus: RASR corpus file for reading segment start and end times for padding + :param str rasr_segment_prefix: prefix to map SMS-WSJ segment name to RASR segment name + :param str rasr_segment_postfix: postfix to map SMS-WSJ segment name to RASR segment name + :param kwargs: + """ + super(SmsWsjBaseWithRasrClasses, self).__init__(**kwargs) + + # self.data_types = {"target_signals": {"dim": 2, "shape": (None, 2)}} + self._rasr_classes_hdf = None + if rasr_classes_hdf is not None: + self._rasr_classes_hdf = HDFDataset( + [rasr_classes_hdf], use_cache_manager=True + ) + # self.data_types["target_rasr"] = {"sparse": True, "dim": 9001, "shape": (None, 2)} + self._rasr_segment_start_end = {} # type: Dict[str, Tuple[float, float]] + if rasr_corpus is not None: + from i6_core.lib.corpus import Corpus + + corpus = Corpus() + corpus.load(rasr_corpus) + for seg in corpus.segments(): + self._rasr_segment_start_end[seg.fullname()] = (seg.start, seg.end) + self.rasr_segment_prefix = rasr_segment_prefix + self.rasr_segment_postfix = rasr_segment_postfix + + def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: + d = self._get_seq_by_idx(seq_idx) + if self._rasr_classes_hdf is not None: + rasr_seq_tags = [ + f"{self.rasr_segment_prefix}{d['seq_tag']}_{speaker}{self.rasr_segment_postfix}" + for speaker in range(d["target_signals"].shape[1]) + ] + rasr_targets = [] + for idx, rasr_seq_tag in enumerate(rasr_seq_tags): + rasr_targets.append( + self._rasr_classes_hdf.get_data_by_seq_tag(rasr_seq_tag, "classes") + ) + start_end_times = [ + self._rasr_segment_start_end.get(rasr_seq_tag, (0, 0)) + for rasr_seq_tag in rasr_seq_tags + ] + padded_len_sec = max(start_end[1] for start_end in start_end_times) + padded_len_frames = max( + rasr_target.shape[0] for rasr_target in rasr_targets + ) + for speaker_idx in range(len(rasr_targets)): + pad_start = 0 + if padded_len_sec > 0: + pad_start = round( + start_end_times[speaker_idx][0] + / padded_len_sec + * padded_len_frames + ) + pad_end = ( + padded_len_frames - rasr_targets[speaker_idx].shape[0] - pad_start + ) + if pad_end < 0: + pad_start += pad_end + assert pad_start >= 0 + pad_end = 0 + rasr_targets[speaker_idx] = np.concatenate( + [ + 9000 * np.ones(pad_start), + rasr_targets[speaker_idx], + 9000 * np.ones(pad_end), + ] + ) + d["target_rasr"] = np.stack(rasr_targets).T + d["target_rasr_len"] = np.array(padded_len_frames) + return d + + def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: + """ + Returns sequence length for all data/target keys. + """ + d = super(SmsWsjBaseWithRasrClasses, self).get_seq_length_for_keys(seq_idx) + data = self[seq_idx] + d["target_rasr"] = int(data["target_rasr_len"]) + return NumbersDict(d) class SmsWsjWrapper(MapDatasetWrapper): - """ - Base class for datasets that can be used in RETURNN config. - """ - - def __init__( - self, sms_wsj_base, **kwargs - ): """ - :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this - """ - if "seq_ordering" not in kwargs: - print("Warning: no shuffling is enabled by default", file=log.v) - super(SmsWsjWrapper, self).__init__(sms_wsj_base, **kwargs) - # self.num_outputs = ... 
# needs to be set in derived classes - - def _get_seq_length(seq_idx: int) -> NumbersDict: - """ - Returns sequence length for all data/target keys. - """ - corpus_seq_idx = self.get_corpus_seq_idx(seq_idx) - return sms_wsj_base.get_seq_length_for_keys(corpus_seq_idx) - - self.get_seq_length = _get_seq_length - - @staticmethod - def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: - """ - Used to process raw SMS-WSJ data - :param inputs: input as coming from SMS-WSJ - """ - return_dict = { - "seq_tag": np.array(inputs["example_id"], dtype=object), - "source_id": np.array(inputs["source_id"], dtype=object), - "seq_len": np.array(inputs["num_samples"]["observation"]), - } - return return_dict - - def _collect_single_seq(self, seq_idx: int) -> DatasetSeq: - """ - :param seq_idx: sorted seq idx - """ - corpus_seq_idx = self.get_corpus_seq_idx(seq_idx) - self._dataset.update_buffer(corpus_seq_idx) - data = self._dataset[corpus_seq_idx] - assert "seq_tag" in data - return DatasetSeq(seq_idx, features=data, seq_tag=data["seq_tag"]) - - def init_seq_order(self, epoch: Optional[int] = None, **kwargs) -> bool: - """ - Override this in order to update the buffer. get_seq_length is often called before _collect_single_seq, therefore - the buffer does not contain the initial indices when continuing the training from an epoch > 0. - """ - out = super(SmsWsjWrapper, self).init_seq_order(epoch=epoch, **kwargs) - buffer_index = ((epoch or 1) - 1) * self.num_seqs % len(self._dataset) - self._dataset.update_buffer(buffer_index, pop_seqs=False) - return out + Base class for datasets that can be used in RETURNN config. + """ + + def __init__(self, sms_wsj_base, **kwargs): + """ + :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this + """ + if "seq_ordering" not in kwargs: + print("Warning: no shuffling is enabled by default", file=log.v) + super(SmsWsjWrapper, self).__init__(sms_wsj_base, **kwargs) + # self.num_outputs = ... # needs to be set in derived classes + + def _get_seq_length(seq_idx: int) -> NumbersDict: + """ + Returns sequence length for all data/target keys. + """ + corpus_seq_idx = self.get_corpus_seq_idx(seq_idx) + return sms_wsj_base.get_seq_length_for_keys(corpus_seq_idx) + + self.get_seq_length = _get_seq_length + + @staticmethod + def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: + """ + Used to process raw SMS-WSJ data + :param inputs: input as coming from SMS-WSJ + """ + return_dict = { + "seq_tag": np.array(inputs["example_id"], dtype=object), + "source_id": np.array(inputs["source_id"], dtype=object), + "seq_len": np.array(inputs["num_samples"]["observation"]), + } + return return_dict + + def _collect_single_seq(self, seq_idx: int) -> DatasetSeq: + """ + :param seq_idx: sorted seq idx + """ + corpus_seq_idx = self.get_corpus_seq_idx(seq_idx) + self._dataset.update_buffer(corpus_seq_idx) + data = self._dataset[corpus_seq_idx] + assert "seq_tag" in data + return DatasetSeq(seq_idx, features=data, seq_tag=data["seq_tag"]) + + def init_seq_order(self, epoch: Optional[int] = None, **kwargs) -> bool: + """ + Override this in order to update the buffer. get_seq_length is often called before _collect_single_seq, therefore + the buffer does not contain the initial indices when continuing the training from an epoch > 0. 
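# Worked example (made-up numbers) for the restart case handled in init_seq_order
# below: with num_seqs = 250 per epoch, 1000 corpus sequences and a restart at
# epoch 3, buffering starts where epoch 3 actually reads.
epoch, num_seqs, corpus_len = 3, 250, 1000
buffer_index = ((epoch or 1) - 1) * num_seqs % corpus_len  # -> 500
# update_buffer(500, pop_seqs=False) then pre-fills the window from index 500 on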
+ """ + out = super(SmsWsjWrapper, self).init_seq_order(epoch=epoch, **kwargs) + buffer_index = ((epoch or 1) - 1) * self.num_seqs % len(self._dataset) + self._dataset.update_buffer(buffer_index, pop_seqs=False) + return out class SmsWsjMixtureEarlyDataset(SmsWsjWrapper): - """ - Dataset with audio mixture and early signals as target. - """ - - def __init__( - self, dataset_name, json_path, num_outputs=None, zip_cache=None, sms_wsj_base=None, **kwargs - ): - """ - :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" - :param str json_path: path to SMS-WSJ json file - :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset - :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir - :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this - """ - if sms_wsj_base is None: - sms_wsj_base = SmsWsjBase( - dataset_name=dataset_name, - json_path=json_path, - pre_batch_transform=self._pre_batch_transform, - scenario_map_args={"add_speech_reverberation_early": True}, - zip_cache=zip_cache) - super(SmsWsjMixtureEarlyDataset, self).__init__(sms_wsj_base, **kwargs) - # typically data is raw waveform so 1-D and dense, target signals are 2-D (one for each speaker) and dense - self.num_outputs = num_outputs or {"data": [1, 2], "target_signals": [2, 2]} - - @staticmethod - def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: """ - Used to process raw SMS-WSJ data - :param inputs: input as coming from SMS-WSJ - """ - return_dict = SmsWsjWrapper._pre_batch_transform(inputs) - return_dict.update({ - "data": inputs["audio_data"]["observation"][:1, :].T, # take first of 6 channels: (T, 1) - "target_signals": inputs["audio_data"]["speech_reverberation_early"][:, 0, :].T, # first of 6 channels: (T, S) - }) - return return_dict + Dataset with audio mixture and early signals as target. 
+ """ + + def __init__( + self, + dataset_name, + json_path, + num_outputs=None, + zip_cache=None, + sms_wsj_base=None, + **kwargs, + ): + """ + :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" + :param str json_path: path to SMS-WSJ json file + :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset + :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir + :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this + """ + if sms_wsj_base is None: + sms_wsj_base = SmsWsjBase( + dataset_name=dataset_name, + json_path=json_path, + pre_batch_transform=self._pre_batch_transform, + scenario_map_args={"add_speech_reverberation_early": True}, + zip_cache=zip_cache, + ) + super(SmsWsjMixtureEarlyDataset, self).__init__(sms_wsj_base, **kwargs) + # typically data is raw waveform so 1-D and dense, target signals are 2-D (one for each speaker) and dense + self.num_outputs = num_outputs or {"data": [1, 2], "target_signals": [2, 2]} + + @staticmethod + def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: + """ + Used to process raw SMS-WSJ data + :param inputs: input as coming from SMS-WSJ + """ + return_dict = SmsWsjWrapper._pre_batch_transform(inputs) + return_dict.update( + { + "data": inputs["audio_data"]["observation"][ + :1, : + ].T, # take first of 6 channels: (T, 1) + "target_signals": inputs["audio_data"]["speech_reverberation_early"][ + :, 0, : + ].T, # first of 6 channels: (T, S) + } + ) + return return_dict class SmsWsjMixtureEarlyAlignmentDataset(SmsWsjMixtureEarlyDataset): - """ - Dataset with audio mixture, target early signals and target RASR alignments. - """ - - def __init__( - self, dataset_name, json_path, num_outputs=None, rasr_num_outputs=None, rasr_segment_prefix="", - rasr_segment_postfix="", rasr_classes_hdf=None, rasr_corpus=None, zip_cache=None, **kwargs - ): - """ - :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" - :param str json_path: path to SMS-WSJ json file - :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset - :param Optional[int] rasr_num_outputs: number of output labels for RASR alignment, e.g. 
9001 for that CART size - :param str rasr_segment_prefix: prefix to map SMS-WSJ segment name to RASR segment name - :param str rasr_segment_postfix: postfix to map SMS-WSJ segment name to RASR segment name - :param str rasr_classes_hdf: hdf file with dumped RASR class labels - :param str rasr_corpus: RASR corpus file for reading segment start and end times for padding - :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir """ - sms_wsj_base = SmsWsjBaseWithRasrClasses( - dataset_name=dataset_name, - json_path=json_path, - pre_batch_transform=self._pre_batch_transform, - scenario_map_args={"add_speech_reverberation_early": True}, - rasr_classes_hdf=rasr_classes_hdf, - rasr_corpus=rasr_corpus, - rasr_segment_prefix=rasr_segment_prefix, - rasr_segment_postfix=rasr_segment_postfix, - zip_cache=zip_cache) - super(SmsWsjMixtureEarlyAlignmentDataset, self).__init__( - dataset_name, json_path, num_outputs=num_outputs, zip_cache=zip_cache, sms_wsj_base=sms_wsj_base, **kwargs) - if num_outputs is not None: - self.num_outputs = num_outputs - else: - assert rasr_num_outputs is not None, "either num_outputs or rasr_num_outputs has to be given" - self.num_outputs["target_rasr"] = [rasr_num_outputs, 1] # target alignments are sparse with the given dim + Dataset with audio mixture, target early signals and target RASR alignments. + """ + + def __init__( + self, + dataset_name, + json_path, + num_outputs=None, + rasr_num_outputs=None, + rasr_segment_prefix="", + rasr_segment_postfix="", + rasr_classes_hdf=None, + rasr_corpus=None, + zip_cache=None, + **kwargs, + ): + """ + :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" + :param str json_path: path to SMS-WSJ json file + :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset + :param Optional[int] rasr_num_outputs: number of output labels for RASR alignment, e.g. 
9001 for that CART size + :param str rasr_segment_prefix: prefix to map SMS-WSJ segment name to RASR segment name + :param str rasr_segment_postfix: postfix to map SMS-WSJ segment name to RASR segment name + :param str rasr_classes_hdf: hdf file with dumped RASR class labels + :param str rasr_corpus: RASR corpus file for reading segment start and end times for padding + :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir + """ + sms_wsj_base = SmsWsjBaseWithRasrClasses( + dataset_name=dataset_name, + json_path=json_path, + pre_batch_transform=self._pre_batch_transform, + scenario_map_args={"add_speech_reverberation_early": True}, + rasr_classes_hdf=rasr_classes_hdf, + rasr_corpus=rasr_corpus, + rasr_segment_prefix=rasr_segment_prefix, + rasr_segment_postfix=rasr_segment_postfix, + zip_cache=zip_cache, + ) + super(SmsWsjMixtureEarlyAlignmentDataset, self).__init__( + dataset_name, + json_path, + num_outputs=num_outputs, + zip_cache=zip_cache, + sms_wsj_base=sms_wsj_base, + **kwargs, + ) + if num_outputs is not None: + self.num_outputs = num_outputs + else: + assert ( + rasr_num_outputs is not None + ), "either num_outputs or rasr_num_outputs has to be given" + self.num_outputs["target_rasr"] = [ + rasr_num_outputs, + 1, + ] # target alignments are sparse with the given dim From 2da7e59b3f933e9d23c36e7d6010d485ead673e9 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 2 Feb 2023 16:05:59 +0100 Subject: [PATCH 03/47] remove hard coded label for padding --- common/datasets/sms_wsj/returnn_datasets.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index b6701e443..3ee80823d 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -265,6 +265,7 @@ def __init__( rasr_corpus=None, rasr_segment_prefix="", rasr_segment_postfix="", + pad_label=None, **kwargs, ): """ @@ -272,13 +273,16 @@ def __init__( :param Optional[str] rasr_corpus: RASR corpus file for reading segment start and end times for padding :param str rasr_segment_prefix: prefix to map SMS-WSJ segment name to RASR segment name :param str rasr_segment_postfix: postfix to map SMS-WSJ segment name to RASR segment name + :param Optional[int] pad_label: target label assigned to padded areas :param kwargs: """ super(SmsWsjBaseWithRasrClasses, self).__init__(**kwargs) # self.data_types = {"target_signals": {"dim": 2, "shape": (None, 2)}} self._rasr_classes_hdf = None + self.pad_label = pad_label if rasr_classes_hdf is not None: + assert pad_label is not None, "Label for padding is needed" self._rasr_classes_hdf = HDFDataset( [rasr_classes_hdf], use_cache_manager=True ) @@ -331,9 +335,9 @@ def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: pad_end = 0 rasr_targets[speaker_idx] = np.concatenate( [ - 9000 * np.ones(pad_start), + self.pad_label * np.ones(pad_start), rasr_targets[speaker_idx], - 9000 * np.ones(pad_end), + self.pad_label * np.ones(pad_end), ] ) d["target_rasr"] = np.stack(rasr_targets).T @@ -475,6 +479,7 @@ def __init__( rasr_segment_postfix="", rasr_classes_hdf=None, rasr_corpus=None, + pad_label=None, zip_cache=None, **kwargs, ): @@ -487,6 +492,7 @@ def __init__( :param str rasr_segment_postfix: postfix to map SMS-WSJ segment name to RASR segment name :param str rasr_classes_hdf: hdf file with dumped RASR class labels :param str rasr_corpus: RASR corpus file for reading segment start and 
end times for padding + :param Optional[int] pad_label: target label assigned to padded areas :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir """ sms_wsj_base = SmsWsjBaseWithRasrClasses( @@ -498,6 +504,7 @@ def __init__( rasr_corpus=rasr_corpus, rasr_segment_prefix=rasr_segment_prefix, rasr_segment_postfix=rasr_segment_postfix, + pad_label=pad_label, zip_cache=zip_cache, ) super(SmsWsjMixtureEarlyAlignmentDataset, self).__init__( From 7bfa8d4ab76beee21a6c3462058af0c4978288a9 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 2 Feb 2023 17:47:11 +0100 Subject: [PATCH 04/47] add SmsWsjMixtureEarlyBpeDataset --- common/datasets/sms_wsj/returnn_datasets.py | 64 +++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 3ee80823d..bdffdc96d 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -525,3 +525,67 @@ def __init__( rasr_num_outputs, 1, ] # target alignments are sparse with the given dim + + +class SmsWsjMixtureEarlyBpeDataset(SmsWsjMixtureEarlyDataset): + """ + Dataset with audio mixture, target early signals and target BPE labels. + """ + + def __init__( + self, + dataset_name, + json_path, + bpe, + text_proc=None, + num_outputs=None, + zip_cache=None, + **kwargs, + ): + """ + :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" + :param str json_path: path to SMS-WSJ json file + :param Dict[str] bpe: opts for :class:`BytePairEncoding` + :param Optional[Callable] text_proc: function to preprocess the transcriptions before applying BPE + :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset + :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir + """ + sms_wsj_base = SmsWsjBase( + dataset_name=dataset_name, + json_path=json_path, + pre_batch_transform=self._pre_batch_transform, + scenario_map_args={"add_speech_reverberation_early": True}, + zip_cache=zip_cache, + ) + super(SmsWsjMixtureEarlyBpeDataset, self).__init__( + dataset_name, + json_path, + num_outputs=num_outputs, + zip_cache=zip_cache, + sms_wsj_base=sms_wsj_base, + **kwargs, + ) + from returnn.datasets.util.vocabulary import BytePairEncoding + self.bpe = BytePairEncoding(**bpe) + self.text_proc = text_proc or (lambda x: x) + if num_outputs is not None: + self.num_outputs = num_outputs + else: + assert ( + bpe is not None + ), "either num_outputs or bpe has to be given" + self.num_outputs["target_bpe"] = [ + self.bpe.num_labels, + 1, + ] # target BPE labels are sparse with the given dim + + def _pre_batch_transform(self, inputs: Dict[str, Any]) -> Dict[str, np.array]: + """ + Used to process raw SMS-WSJ data + :param inputs: input as coming from SMS-WSJ + """ + return_dict = SmsWsjMixtureEarlyDataset._pre_batch_transform(inputs) + for speaker, orth in enumerate(inputs["kaldi_transcription"]): + return_dict[f"target_bpe_{speaker}"] = np.array(self.bpe.get_seq(self.text_proc(orth)), dtype="int32") + return_dict[f"target_bpe_{speaker}_len"] = np.array(return_dict[f"target_bpe_{speaker}"].size) + return return_dict From 8c539ea3f26585d7b935e89b6497e07966bdf3bf Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 2 Feb 2023 17:49:21 +0100 Subject: [PATCH 05/47] black formatting --- common/datasets/sms_wsj/returnn_datasets.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff 
--git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index bdffdc96d..3a8fbe9fa 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -566,14 +566,13 @@ def __init__( **kwargs, ) from returnn.datasets.util.vocabulary import BytePairEncoding + self.bpe = BytePairEncoding(**bpe) self.text_proc = text_proc or (lambda x: x) if num_outputs is not None: self.num_outputs = num_outputs else: - assert ( - bpe is not None - ), "either num_outputs or bpe has to be given" + assert bpe is not None, "either num_outputs or bpe has to be given" self.num_outputs["target_bpe"] = [ self.bpe.num_labels, 1, @@ -586,6 +585,10 @@ def _pre_batch_transform(self, inputs: Dict[str, Any]) -> Dict[str, np.array]: """ return_dict = SmsWsjMixtureEarlyDataset._pre_batch_transform(inputs) for speaker, orth in enumerate(inputs["kaldi_transcription"]): - return_dict[f"target_bpe_{speaker}"] = np.array(self.bpe.get_seq(self.text_proc(orth)), dtype="int32") - return_dict[f"target_bpe_{speaker}_len"] = np.array(return_dict[f"target_bpe_{speaker}"].size) + return_dict[f"target_bpe_{speaker}"] = np.array( + self.bpe.get_seq(self.text_proc(orth)), dtype="int32" + ) + return_dict[f"target_bpe_{speaker}_len"] = np.array( + return_dict[f"target_bpe_{speaker}"].size + ) return return_dict From 9ad34aa752deb873bd1b615862ff60f6544e1e9d Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Fri, 3 Feb 2023 09:03:19 +0100 Subject: [PATCH 06/47] fix data_types --- common/datasets/sms_wsj/returnn_datasets.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 3a8fbe9fa..78d011510 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -39,6 +39,7 @@ def __init__( dataset_name, json_path, pre_batch_transform, + data_types, buffer=True, zip_cache=None, scenario_map_args=None, @@ -48,12 +49,15 @@ def __init__( :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" :param str json_path: path to SMS-WSJ json file :param function pre_batch_transform: function which processes raw SMS-WSJ data + :param Dict[str] data_types: data types for RETURNN, e.g. 
{"target_signals": {"dim": 2, "shape": (None, 2)}} :param bool buffer: if True, use SMS-WSJ dataset prefetching and store sequences in buffer :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir :param Optional[Dict] scenario_map_args: optional kwargs for sms_wsj scenario_map_fn """ super(SmsWsjBase, self).__init__(**kwargs) + self.data_types = data_types + if zip_cache is not None: self._cache_zipped_audio(zip_cache, json_path, dataset_name) @@ -278,7 +282,6 @@ def __init__( """ super(SmsWsjBaseWithRasrClasses, self).__init__(**kwargs) - # self.data_types = {"target_signals": {"dim": 2, "shape": (None, 2)}} self._rasr_classes_hdf = None self.pad_label = pad_label if rasr_classes_hdf is not None: @@ -286,7 +289,6 @@ def __init__( self._rasr_classes_hdf = HDFDataset( [rasr_classes_hdf], use_cache_manager=True ) - # self.data_types["target_rasr"] = {"sparse": True, "dim": 9001, "shape": (None, 2)} self._rasr_segment_start_end = {} # type: Dict[str, Tuple[float, float]] if rasr_corpus is not None: from i6_core.lib.corpus import Corpus @@ -439,6 +441,7 @@ def __init__( pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, zip_cache=zip_cache, + data_types={"target_signals": {"dim": 2, "shape": (None, 2)}}, ) super(SmsWsjMixtureEarlyDataset, self).__init__(sms_wsj_base, **kwargs) # typically data is raw waveform so 1-D and dense, target signals are 2-D (one for each speaker) and dense @@ -495,6 +498,14 @@ def __init__( :param Optional[int] pad_label: target label assigned to padded areas :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir """ + data_types = { + "target_signals": {"dim": 2, "shape": (None, 2)}, + "target_rasr": { + "sparse": True, + "dim": rasr_num_outputs, + "shape": (None, 2), + }, + } sms_wsj_base = SmsWsjBaseWithRasrClasses( dataset_name=dataset_name, json_path=json_path, @@ -506,6 +517,7 @@ def __init__( rasr_segment_postfix=rasr_segment_postfix, pad_label=pad_label, zip_cache=zip_cache, + data_types=data_types, ) super(SmsWsjMixtureEarlyAlignmentDataset, self).__init__( dataset_name, From ac0718c2095e1f6db4e1b962efaba4d335df3205 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Mon, 6 Feb 2023 10:48:39 +0100 Subject: [PATCH 07/47] bpe add data types --- common/datasets/sms_wsj/returnn_datasets.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 78d011510..eec83054c 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -18,6 +18,9 @@ # noinspection PyUnresolvedReferences from returnn.datasets.map import MapDatasetBase, MapDatasetWrapper +# noinspection PyUnresolvedReferences +from returnn.datasets.util.vocabulary import BytePairEncoding + # noinspection PyUnresolvedReferences from returnn.log import log @@ -404,8 +407,8 @@ def _collect_single_seq(self, seq_idx: int) -> DatasetSeq: def init_seq_order(self, epoch: Optional[int] = None, **kwargs) -> bool: """ - Override this in order to update the buffer. get_seq_length is often called before _collect_single_seq, therefore - the buffer does not contain the initial indices when continuing the training from an epoch > 0. + Override this in order to update the buffer. 
get_seq_length is often called before _collect_single_seq, + therefore the buffer does not contain the initial indices when continuing the training from an epoch > 0. """ out = super(SmsWsjWrapper, self).init_seq_order(epoch=epoch, **kwargs) buffer_index = ((epoch or 1) - 1) * self.num_seqs % len(self._dataset) @@ -562,12 +565,22 @@ def __init__( :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir """ + self.bpe = BytePairEncoding(**bpe) + data_types = { + "target_signals": {"dim": 2, "shape": (None, 2)}, + "target_bpe": { + "sparse": True, + "dim": self.bpe.num_labels, + "shape": (None, 2), + }, + } sms_wsj_base = SmsWsjBase( dataset_name=dataset_name, json_path=json_path, pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, zip_cache=zip_cache, + data_types=data_types, ) super(SmsWsjMixtureEarlyBpeDataset, self).__init__( dataset_name, @@ -577,9 +590,7 @@ def __init__( sms_wsj_base=sms_wsj_base, **kwargs, ) - from returnn.datasets.util.vocabulary import BytePairEncoding - self.bpe = BytePairEncoding(**bpe) self.text_proc = text_proc or (lambda x: x) if num_outputs is not None: self.num_outputs = num_outputs From 5c0ea9ee03fd69bd88d5ee45fd8b638116a41330 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Mon, 6 Feb 2023 17:33:05 +0100 Subject: [PATCH 08/47] allow original alignment --- common/datasets/sms_wsj/returnn_datasets.py | 141 +++++++++----------- 1 file changed, 62 insertions(+), 79 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index eec83054c..9d860ce9e 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -268,85 +268,60 @@ class SmsWsjBaseWithRasrClasses(SmsWsjBase): def __init__( self, - rasr_classes_hdf=None, - rasr_corpus=None, - rasr_segment_prefix="", - rasr_segment_postfix="", + rasr_classes_hdf, + segment_to_rasr, pad_label=None, + hdf_data_key="classes", **kwargs, ): """ - :param Optional[str] rasr_classes_hdf: hdf file with dumped RASR class labels - :param Optional[str] rasr_corpus: RASR corpus file for reading segment start and end times for padding - :param str rasr_segment_prefix: prefix to map SMS-WSJ segment name to RASR segment name - :param str rasr_segment_postfix: postfix to map SMS-WSJ segment name to RASR segment name + :param str rasr_classes_hdf: hdf file with dumped RASR class labels + :param Callable segment_to_rasr: function that maps SMS-WSJ seg. name into list of corresponding RASR seg. 
names :param Optional[int] pad_label: target label assigned to padded areas + :param str hdf_data_key: data key under which the alignment is stored in the hdf, usually "classes" or "data" :param kwargs: """ super(SmsWsjBaseWithRasrClasses, self).__init__(**kwargs) - self._rasr_classes_hdf = None - self.pad_label = pad_label - if rasr_classes_hdf is not None: - assert pad_label is not None, "Label for padding is needed" - self._rasr_classes_hdf = HDFDataset( - [rasr_classes_hdf], use_cache_manager=True - ) - self._rasr_segment_start_end = {} # type: Dict[str, Tuple[float, float]] - if rasr_corpus is not None: - from i6_core.lib.corpus import Corpus - - corpus = Corpus() - corpus.load(rasr_corpus) - for seg in corpus.segments(): - self._rasr_segment_start_end[seg.fullname()] = (seg.start, seg.end) - self.rasr_segment_prefix = rasr_segment_prefix - self.rasr_segment_postfix = rasr_segment_postfix + self._rasr_classes_hdf = HDFDataset( + [rasr_classes_hdf], use_cache_manager=True + ) + self._segment_to_rasr = segment_to_rasr + self._pad_label = pad_label + self._hdf_data_key = hdf_data_key def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: d = self._get_seq_by_idx(seq_idx) - if self._rasr_classes_hdf is not None: - rasr_seq_tags = [ - f"{self.rasr_segment_prefix}{d['seq_tag']}_{speaker}{self.rasr_segment_postfix}" - for speaker in range(d["target_signals"].shape[1]) - ] - rasr_targets = [] - for idx, rasr_seq_tag in enumerate(rasr_seq_tags): - rasr_targets.append( - self._rasr_classes_hdf.get_data_by_seq_tag(rasr_seq_tag, "classes") - ) - start_end_times = [ - self._rasr_segment_start_end.get(rasr_seq_tag, (0, 0)) - for rasr_seq_tag in rasr_seq_tags - ] - padded_len_sec = max(start_end[1] for start_end in start_end_times) - padded_len_frames = max( - rasr_target.shape[0] for rasr_target in rasr_targets + rasr_seq_tags = self._segment_to_rasr(d["seq_tag"]) + assert len(rasr_seq_tags) == d["target_signals"].shape[1], ( + f"got {len(rasr_seq_tags)} segment names, but there are {d['target_signals'].shape[1]} target signals") + rasr_targets = [ + self._rasr_classes_hdf.get_data_by_seq_tag(rasr_seq_tag, self._hdf_data_key) + for rasr_seq_tag in rasr_seq_tags + ] + padded_len = max( + rasr_target.shape[0] for rasr_target in rasr_targets + ) + for speaker_idx in range(len(rasr_targets)): + pad_start = int(round(d["offset"][speaker_idx] / d["seq_len"] * padded_len)) + pad_end = ( + padded_len - rasr_targets[speaker_idx].shape[0] - pad_start ) - for speaker_idx in range(len(rasr_targets)): - pad_start = 0 - if padded_len_sec > 0: - pad_start = round( - start_end_times[speaker_idx][0] - / padded_len_sec - * padded_len_frames - ) - pad_end = ( - padded_len_frames - rasr_targets[speaker_idx].shape[0] - pad_start - ) - if pad_end < 0: - pad_start += pad_end - assert pad_start >= 0 - pad_end = 0 - rasr_targets[speaker_idx] = np.concatenate( - [ - self.pad_label * np.ones(pad_start), - rasr_targets[speaker_idx], - self.pad_label * np.ones(pad_end), - ] - ) - d["target_rasr"] = np.stack(rasr_targets).T - d["target_rasr_len"] = np.array(padded_len_frames) + if pad_end < 0: + pad_start += pad_end + assert pad_start >= 0 + pad_end = 0 + if pad_start or pad_end: + assert self._pad_label is not None, "Label for padding is needed" + rasr_targets[speaker_idx] = np.concatenate( + [ + self._pad_label * np.ones(pad_start), + rasr_targets[speaker_idx], + self._pad_label * np.ones(pad_end), + ] + ) + d["target_rasr"] = np.stack(rasr_targets).T + d["target_rasr_len"] = np.array(padded_len) return d def 
get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: @@ -481,12 +456,11 @@ def __init__( json_path, num_outputs=None, rasr_num_outputs=None, - rasr_segment_prefix="", - rasr_segment_postfix="", + zip_cache=None, rasr_classes_hdf=None, - rasr_corpus=None, + segment_to_rasr=None, pad_label=None, - zip_cache=None, + hdf_data_key="classes", **kwargs, ): """ @@ -494,12 +468,11 @@ def __init__( :param str json_path: path to SMS-WSJ json file :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset :param Optional[int] rasr_num_outputs: number of output labels for RASR alignment, e.g. 9001 for that CART size - :param str rasr_segment_prefix: prefix to map SMS-WSJ segment name to RASR segment name - :param str rasr_segment_postfix: postfix to map SMS-WSJ segment name to RASR segment name + :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir :param str rasr_classes_hdf: hdf file with dumped RASR class labels - :param str rasr_corpus: RASR corpus file for reading segment start and end times for padding + :param Callable segment_to_rasr: function that maps SMS-WSJ seg. name into list of corresponding RASR seg. names :param Optional[int] pad_label: target label assigned to padded areas - :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir + :param str hdf_data_key: data key under which the alignment is stored in the hdf, usually "classes" or "data" """ data_types = { "target_signals": {"dim": 2, "shape": (None, 2)}, @@ -514,13 +487,12 @@ def __init__( json_path=json_path, pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, + data_types=data_types, + zip_cache=zip_cache, rasr_classes_hdf=rasr_classes_hdf, - rasr_corpus=rasr_corpus, - rasr_segment_prefix=rasr_segment_prefix, - rasr_segment_postfix=rasr_segment_postfix, + segment_to_rasr=segment_to_rasr, pad_label=pad_label, - zip_cache=zip_cache, - data_types=data_types, + hdf_data_key=hdf_data_key, ) super(SmsWsjMixtureEarlyAlignmentDataset, self).__init__( dataset_name, @@ -541,6 +513,17 @@ def __init__( 1, ] # target alignments are sparse with the given dim + @staticmethod + def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: + """ + Used to process raw SMS-WSJ data + :param inputs: input as coming from SMS-WSJ + """ + return_dict = SmsWsjMixtureEarlyDataset._pre_batch_transform(inputs) + # we need the padding information here + return_dict["offset"] = np.array(inputs["offset"], dtype="int") + return return_dict + class SmsWsjMixtureEarlyBpeDataset(SmsWsjMixtureEarlyDataset): """ From ff2f0ea8354b1efc8121db3228d855021dc70412 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Mon, 6 Feb 2023 17:37:43 +0100 Subject: [PATCH 09/47] black --- common/datasets/sms_wsj/returnn_datasets.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 9d860ce9e..d3e757504 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -283,9 +283,7 @@ def __init__( """ super(SmsWsjBaseWithRasrClasses, self).__init__(**kwargs) - self._rasr_classes_hdf = HDFDataset( - [rasr_classes_hdf], use_cache_manager=True - ) + self._rasr_classes_hdf = HDFDataset([rasr_classes_hdf], use_cache_manager=True) self._segment_to_rasr = segment_to_rasr self._pad_label = pad_label 
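# A minimal sketch of a segment_to_rasr callable, assuming the per-speaker
# "_{speaker}" naming used by the earlier prefix/postfix variant and a
# hypothetical "sms_wsj/" corpus prefix; the real mapping is supplied by the
# training config:
#
#     def segment_to_rasr(seq_tag: str) -> list:
#         # one RASR segment name per speaker of the two-speaker mixture
#         return [f"sms_wsj/{seq_tag}_{speaker}" for speaker in range(2)]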
self._hdf_data_key = hdf_data_key @@ -293,20 +291,17 @@ def __init__( def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: d = self._get_seq_by_idx(seq_idx) rasr_seq_tags = self._segment_to_rasr(d["seq_tag"]) - assert len(rasr_seq_tags) == d["target_signals"].shape[1], ( - f"got {len(rasr_seq_tags)} segment names, but there are {d['target_signals'].shape[1]} target signals") + assert ( + len(rasr_seq_tags) == d["target_signals"].shape[1] + ), f"got {len(rasr_seq_tags)} segment names, but there are {d['target_signals'].shape[1]} target signals" rasr_targets = [ self._rasr_classes_hdf.get_data_by_seq_tag(rasr_seq_tag, self._hdf_data_key) for rasr_seq_tag in rasr_seq_tags ] - padded_len = max( - rasr_target.shape[0] for rasr_target in rasr_targets - ) + padded_len = max(rasr_target.shape[0] for rasr_target in rasr_targets) for speaker_idx in range(len(rasr_targets)): pad_start = int(round(d["offset"][speaker_idx] / d["seq_len"] * padded_len)) - pad_end = ( - padded_len - rasr_targets[speaker_idx].shape[0] - pad_start - ) + pad_end = padded_len - rasr_targets[speaker_idx].shape[0] - pad_start if pad_end < 0: pad_start += pad_end assert pad_start >= 0 From 44774a84a5b330b24eb0273713dac07c0c30a0ae Mon Sep 17 00:00:00 2001 From: vieting <45091115+vieting@users.noreply.github.com> Date: Tue, 7 Feb 2023 17:13:04 +0100 Subject: [PATCH 10/47] Update common/datasets/sms_wsj/returnn_datasets.py --- common/datasets/sms_wsj/returnn_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index d3e757504..025fcc49a 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -76,7 +76,7 @@ def __init__( **(scenario_map_args or {}), } ds = ds.map(functools.partial(scenario_map_fn, **scenario_map_args)) - ds = ds.map(functools.partial(pre_batch_transform)) + ds = ds.map(pre_batch_transform) self._ds = ds self._ds_iterator = iter(self._ds) From eed5171d1d42c88dd71436a7c0abb4d4c6095461 Mon Sep 17 00:00:00 2001 From: vieting <45091115+vieting@users.noreply.github.com> Date: Tue, 7 Feb 2023 17:15:29 +0100 Subject: [PATCH 11/47] Update common/datasets/sms_wsj/returnn_datasets.py Co-authored-by: SimBe195 <37951951+SimBe195@users.noreply.github.com> --- common/datasets/sms_wsj/returnn_datasets.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 025fcc49a..e83334e9b 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -109,10 +109,7 @@ def get_seq_tag(self, seq_idx: int) -> str: """ Returns tag for the sequence of the given index, default is 'seq-{seq_idx}'. 
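For example (hypothetical index): a sequence at index 7 whose data carries no "seq_tag" entry is reported as "seq-7".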
""" - if "seq_tag" in self._get_seq_by_idx(seq_idx): - return str(self._get_seq_by_idx(seq_idx)["seq_tag"]) - else: - return "seq-%i" % seq_idx + return str(self._get_seq_by_idx(seq_idx).get("seq_tag", f"seq-{seq_idx}")) def get_seq_len(self, seq_idx: int) -> int: """ From 1db97719fdcea4eaff160893553fe8d5311bee6d Mon Sep 17 00:00:00 2001 From: vieting <45091115+vieting@users.noreply.github.com> Date: Tue, 7 Feb 2023 17:16:43 +0100 Subject: [PATCH 12/47] Update common/datasets/sms_wsj/returnn_datasets.py Co-authored-by: SimBe195 <37951951+SimBe195@users.noreply.github.com> --- common/datasets/sms_wsj/returnn_datasets.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index e83334e9b..7f780d2fc 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -115,10 +115,11 @@ def get_seq_len(self, seq_idx: int) -> int: """ Returns length of the sequence of the given index """ - if "seq_len" in self._get_seq_by_idx(seq_idx): + try: return int(self._get_seq_by_idx(seq_idx)["seq_len"]) - else: + except KeyError: raise OptionalNotImplementedError + def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: """ From 10a13cf0213f0b95970efab0ff3bad25062e72be Mon Sep 17 00:00:00 2001 From: vieting <45091115+vieting@users.noreply.github.com> Date: Tue, 7 Feb 2023 17:17:05 +0100 Subject: [PATCH 13/47] Update common/datasets/sms_wsj/returnn_datasets.py Co-authored-by: SimBe195 <37951951+SimBe195@users.noreply.github.com> --- common/datasets/sms_wsj/returnn_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 7f780d2fc..ab423c7ed 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -128,7 +128,7 @@ def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: data = self[seq_idx] d = {k: v.size for k, v in data.items()} for update_key in ["data", "target_signals"]: - if update_key in d.keys() and "seq_len" in data: + if update_key in d and "seq_len" in data: d[update_key] = int(data["seq_len"]) return NumbersDict(d) From f7ef64a91ea9c02e3a583d3bd084d4884a267564 Mon Sep 17 00:00:00 2001 From: vieting <45091115+vieting@users.noreply.github.com> Date: Tue, 7 Feb 2023 17:19:49 +0100 Subject: [PATCH 14/47] Update common/datasets/sms_wsj/returnn_datasets.py Co-authored-by: SimBe195 <37951951+SimBe195@users.noreply.github.com> --- common/datasets/sms_wsj/returnn_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index ab423c7ed..5e2f0e31f 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -414,7 +414,7 @@ def __init__( zip_cache=zip_cache, data_types={"target_signals": {"dim": 2, "shape": (None, 2)}}, ) - super(SmsWsjMixtureEarlyDataset, self).__init__(sms_wsj_base, **kwargs) + super().__init__(sms_wsj_base, **kwargs) # typically data is raw waveform so 1-D and dense, target signals are 2-D (one for each speaker) and dense self.num_outputs = num_outputs or {"data": [1, 2], "target_signals": [2, 2]} From 0f4d36306e1c03d7772fc69b1eec8eb38f26dfdb Mon Sep 17 00:00:00 2001 From: vieting <45091115+vieting@users.noreply.github.com> Date: Tue, 7 Feb 2023 17:20:20 +0100 Subject: [PATCH 15/47] Update 
common/datasets/sms_wsj/returnn_datasets.py Co-authored-by: SimBe195 <37951951+SimBe195@users.noreply.github.com> --- common/datasets/sms_wsj/returnn_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 5e2f0e31f..db2853df1 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -487,7 +487,7 @@ def __init__( pad_label=pad_label, hdf_data_key=hdf_data_key, ) - super(SmsWsjMixtureEarlyAlignmentDataset, self).__init__( + super().__init__( dataset_name, json_path, num_outputs=num_outputs, From c64c6fc91cf0b65bc29ec0143b85b0993a768abd Mon Sep 17 00:00:00 2001 From: vieting <45091115+vieting@users.noreply.github.com> Date: Tue, 7 Feb 2023 17:20:28 +0100 Subject: [PATCH 16/47] Update common/datasets/sms_wsj/returnn_datasets.py Co-authored-by: SimBe195 <37951951+SimBe195@users.noreply.github.com> --- common/datasets/sms_wsj/returnn_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index db2853df1..fcef2adb6 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -558,7 +558,7 @@ def __init__( zip_cache=zip_cache, data_types=data_types, ) - super(SmsWsjMixtureEarlyBpeDataset, self).__init__( + super().__init__( dataset_name, json_path, num_outputs=num_outputs, From f39ebdad5c09d202d0c29cc57f6d7cd8f6227c05 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Tue, 7 Feb 2023 17:29:16 +0100 Subject: [PATCH 17/47] review comments --- common/datasets/sms_wsj/returnn_datasets.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index fcef2adb6..b2b025286 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -46,6 +46,8 @@ def __init__( buffer=True, zip_cache=None, scenario_map_args=None, + prefetch_num_workers=4, + prefetch_buffer_size=40, **kwargs, ): """ @@ -83,9 +85,9 @@ def __init__( self._use_buffer = buffer if self._use_buffer: - self._ds = self._ds.prefetch(4, 8).copy(freeze=True) + self._ds = self._ds.prefetch(prefetch_num_workers, prefetch_buffer_size).copy(freeze=True) self._buffer = {} # type Dict[int,[Dict[str,np.array]]] - self._buffer_size = 40 + self._buffer_size = prefetch_buffer_size def __len__(self) -> int: return len(self._ds) @@ -571,7 +573,6 @@ def __init__( if num_outputs is not None: self.num_outputs = num_outputs else: - assert bpe is not None, "either num_outputs or bpe has to be given" self.num_outputs["target_bpe"] = [ self.bpe.num_labels, 1, From 8b7ad024316d3e07d428fc1a7752d5229db32c7e Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Tue, 7 Feb 2023 17:29:59 +0100 Subject: [PATCH 18/47] black --- common/datasets/sms_wsj/returnn_datasets.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index b2b025286..4f9be092d 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -85,7 +85,9 @@ def __init__( self._use_buffer = buffer if self._use_buffer: - self._ds = self._ds.prefetch(prefetch_num_workers, prefetch_buffer_size).copy(freeze=True) + self._ds = self._ds.prefetch( + prefetch_num_workers, 
prefetch_buffer_size + ).copy(freeze=True) self._buffer = {} # type Dict[int,[Dict[str,np.array]]] self._buffer_size = prefetch_buffer_size @@ -121,7 +123,6 @@ def get_seq_len(self, seq_idx: int) -> int: return int(self._get_seq_by_idx(seq_idx)["seq_len"]) except KeyError: raise OptionalNotImplementedError - def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: """ From f736a507f0e1cd2f95dad7b32bc1812134b48521 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 9 Feb 2023 12:42:46 +0100 Subject: [PATCH 19/47] _segment_to_rasr str input --- common/datasets/sms_wsj/returnn_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 4f9be092d..f927e63e9 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -291,7 +291,7 @@ def __init__( def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: d = self._get_seq_by_idx(seq_idx) - rasr_seq_tags = self._segment_to_rasr(d["seq_tag"]) + rasr_seq_tags = self._segment_to_rasr(str(d["seq_tag"])) assert ( len(rasr_seq_tags) == d["target_signals"].shape[1] ), f"got {len(rasr_seq_tags)} segment names, but there are {d['target_signals'].shape[1]} target signals" From 4445b29c1eee11ac116ff69bb5342edea468e31e Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 9 Feb 2023 11:33:32 +0100 Subject: [PATCH 20/47] sms wsj add init --- common/datasets/sms_wsj/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 common/datasets/sms_wsj/__init__.py diff --git a/common/datasets/sms_wsj/__init__.py b/common/datasets/sms_wsj/__init__.py new file mode 100644 index 000000000..e69de29bb From 19d033e55fab383656baa32045d6b1e6f03d77c2 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 9 Feb 2023 17:55:39 +0100 Subject: [PATCH 21/47] caching: use local files if available --- common/datasets/sms_wsj/returnn_datasets.py | 102 ++++++++++++-------- 1 file changed, 61 insertions(+), 41 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index f927e63e9..5ee55659e 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -64,7 +64,7 @@ def __init__( self.data_types = data_types if zip_cache is not None: - self._cache_zipped_audio(zip_cache, json_path, dataset_name) + json_path = self._cache_zipped_audio(zip_cache, json_path, dataset_name) db = SmsWsj(json_path=json_path) ds = db.get_dataset(dataset_name) @@ -207,59 +207,79 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): """ print(f"Cache and unzip SMS-WSJ data from {zip_cache}") - # cache and unzip + # cache file try: zip_cache_cached = sp.check_output(["cf", zip_cache]).strip().decode("utf8") assert ( zip_cache_cached != zip_cache ), "cached and original file have the same path" local_unzipped_dir = os.path.dirname(zip_cache_cached) - sp.check_call( - ["unzip", "-q", "-n", zip_cache_cached, "-d", local_unzipped_dir] - ) + json_path_cached = sp.check_output(["cf", json_path]).strip().decode("utf8") + assert ( + json_path_cached != json_path + ), "cached and original file have the same path" except sp.CalledProcessError: print( f"Cache manager: Error occurred when caching and unzipping {zip_cache}" ) raise - # modify json and check if all data is available - with open(json_path, "r") as f: - json_dict = json.loads(f.read()) - original_dir = 
next(iter(json_dict["datasets"][dataset_name].values()))[ - "audio_path" - ]["original_source"][0] - while ( - not original_dir.endswith(os.path.basename(local_unzipped_dir)) - and len(original_dir) > 1 - ): - original_dir = os.path.dirname(original_dir) - for seq in json_dict["datasets"][dataset_name]: - for audio_key in ["original_source", "rir"]: - for seq_idx in range( - len( - json_dict["datasets"][dataset_name][seq]["audio_path"][ + # unzip if folder does not yet exist + if not os.path.exists(local_unzipped_dir): + sp.check_call( + ["unzip", "-q", "-n", zip_cache_cached, "-d", local_unzipped_dir] + ) + else: + print(f"Unzipped audio already exists in {local_unzipped_dir}") + + + json_path_cached_mod = json_path_cached.replace(".json", ".mod.json") + original_dir = None + if not os.path.exists(json_path_cached_mod): + with open(json_path_cached, "r") as f: + json_dict = json.loads(f.read()) + # get original dir + original_dir = next(iter(json_dict["datasets"][dataset_name].values()))[ + "audio_path" + ]["original_source"][0] + while ( + not original_dir.endswith(os.path.basename(local_unzipped_dir)) + and len(original_dir) > 1 + ): + original_dir = os.path.dirname(original_dir) + else: + with open(json_path_cached_mod, "r") as f: + json_dict = json.loads(f.read()) + # check if all data is available and create modified json if it does not yet exist + for dataset_name in json_dict["datasets"]: + for seq in json_dict["datasets"][dataset_name]: + for audio_key in ["original_source", "rir"]: + for seq_idx in range( + len( + json_dict["datasets"][dataset_name][seq]["audio_path"][ + audio_key + ] + ) + ): + path = json_dict["datasets"][dataset_name][seq]["audio_path"][ audio_key - ] - ) - ): - path = json_dict["datasets"][dataset_name][seq]["audio_path"][ - audio_key - ][seq_idx] - path = path.replace(original_dir, local_unzipped_dir) - json_dict["datasets"][dataset_name][seq]["audio_path"][audio_key][ - seq_idx - ] = path - assert path.startswith( - local_unzipped_dir - ), f"Audio file {path} was expected to start with {local_unzipped_dir}" - assert os.path.exists(path), f"Audio file {path} does not exist" - - json_path = os.path.join(local_unzipped_dir, "sms_wsj.json") - with open(json_path, "w", encoding="utf-8") as f: - json.dump(json_dict, f, ensure_ascii=False, indent=4) - - print(f"Finished preparation of zip cache data, use json in {json_path}") + ][seq_idx] + if not os.path.exists(json_path_cached_mod): + path = path.replace(original_dir, local_unzipped_dir) + json_dict["datasets"][dataset_name][seq]["audio_path"][audio_key][ + seq_idx + ] = path + assert path.startswith( + local_unzipped_dir + ), f"Audio file {path} was expected to start with {local_unzipped_dir}" + assert os.path.exists(path), f"Audio file {path} does not exist" + + if not os.path.exists(json_path_cached_mod): + with open(json_path_cached_mod, "w", encoding="utf-8") as f: + json.dump(json_dict, f, ensure_ascii=False, indent=4) + + print(f"Finished preparation of zip cache data, use json in {json_path_cached_mod}") + return json_path_cached_mod class SmsWsjBaseWithRasrClasses(SmsWsjBase): From ac70e9dfd9fa2bad041d7077317e435117becde6 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 9 Feb 2023 17:58:14 +0100 Subject: [PATCH 22/47] non standard imports inside classes --- common/datasets/sms_wsj/returnn_datasets.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 5ee55659e..276f3ed07 
100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -18,18 +18,12 @@ # noinspection PyUnresolvedReferences from returnn.datasets.map import MapDatasetBase, MapDatasetWrapper -# noinspection PyUnresolvedReferences -from returnn.datasets.util.vocabulary import BytePairEncoding - # noinspection PyUnresolvedReferences from returnn.log import log # noinspection PyUnresolvedReferences from returnn.util.basic import OptionalNotImplementedError, NumbersDict -# noinspection PyUnresolvedReferences -from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn - class SmsWsjBase(MapDatasetBase): """ @@ -59,6 +53,9 @@ def __init__( :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir :param Optional[Dict] scenario_map_args: optional kwargs for sms_wsj scenario_map_fn """ + # noinspection PyUnresolvedReferences + from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn + super(SmsWsjBase, self).__init__(**kwargs) self.data_types = data_types @@ -564,6 +561,8 @@ def __init__( :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir """ + from returnn.datasets.util.vocabulary import BytePairEncoding + self.bpe = BytePairEncoding(**bpe) data_types = { "target_signals": {"dim": 2, "shape": (None, 2)}, From 30a90f765d46c6817599cc7adb908509358623eb Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 9 Feb 2023 17:58:44 +0100 Subject: [PATCH 23/47] black --- common/datasets/sms_wsj/returnn_datasets.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 276f3ed07..3db68a2ac 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -229,7 +229,6 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): else: print(f"Unzipped audio already exists in {local_unzipped_dir}") - json_path_cached_mod = json_path_cached.replace(".json", ".mod.json") original_dir = None if not os.path.exists(json_path_cached_mod): @@ -263,9 +262,9 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): ][seq_idx] if not os.path.exists(json_path_cached_mod): path = path.replace(original_dir, local_unzipped_dir) - json_dict["datasets"][dataset_name][seq]["audio_path"][audio_key][ - seq_idx - ] = path + json_dict["datasets"][dataset_name][seq]["audio_path"][ + audio_key + ][seq_idx] = path assert path.startswith( local_unzipped_dir ), f"Audio file {path} was expected to start with {local_unzipped_dir}" @@ -275,7 +274,9 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): with open(json_path_cached_mod, "w", encoding="utf-8") as f: json.dump(json_dict, f, ensure_ascii=False, indent=4) - print(f"Finished preparation of zip cache data, use json in {json_path_cached_mod}") + print( + f"Finished preparation of zip cache data, use json in {json_path_cached_mod}" + ) return json_path_cached_mod From dcce15a6e463ddedaed2e3843cc0615f0b91fbde Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 9 Feb 2023 18:04:34 +0100 Subject: [PATCH 24/47] update of super() usage --- common/datasets/sms_wsj/returnn_datasets.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py 
b/common/datasets/sms_wsj/returnn_datasets.py index 3db68a2ac..2e936806a 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -56,7 +56,7 @@ def __init__( # noinspection PyUnresolvedReferences from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn - super(SmsWsjBase, self).__init__(**kwargs) + super().__init__(**kwargs) self.data_types = data_types @@ -300,7 +300,7 @@ def __init__( :param str hdf_data_key: data key under which the alignment is stored in the hdf, usually "classes" or "data" :param kwargs: """ - super(SmsWsjBaseWithRasrClasses, self).__init__(**kwargs) + super().__init__(**kwargs) self._rasr_classes_hdf = HDFDataset([rasr_classes_hdf], use_cache_manager=True) self._segment_to_rasr = segment_to_rasr @@ -342,7 +342,7 @@ def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: """ Returns sequence length for all data/target keys. """ - d = super(SmsWsjBaseWithRasrClasses, self).get_seq_length_for_keys(seq_idx) + d = super().get_seq_length_for_keys(seq_idx) data = self[seq_idx] d["target_rasr"] = int(data["target_rasr_len"]) return NumbersDict(d) @@ -359,7 +359,7 @@ def __init__(self, sms_wsj_base, **kwargs): """ if "seq_ordering" not in kwargs: print("Warning: no shuffling is enabled by default", file=log.v) - super(SmsWsjWrapper, self).__init__(sms_wsj_base, **kwargs) + super().__init__(sms_wsj_base, **kwargs) # self.num_outputs = ... # needs to be set in derived classes def _get_seq_length(seq_idx: int) -> NumbersDict: @@ -399,7 +399,7 @@ def init_seq_order(self, epoch: Optional[int] = None, **kwargs) -> bool: Override this in order to update the buffer. get_seq_length is often called before _collect_single_seq, therefore the buffer does not contain the initial indices when continuing the training from an epoch > 0. 
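For example (hypothetical numbers): restarting at epoch 3 with num_seqs = 100 and len(self._dataset) = 250 gives buffer_index = ((3 or 1) - 1) * 100 % 250 = 200, so the buffer is pre-filled starting from sequence 200 instead of from 0.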
""" - out = super(SmsWsjWrapper, self).init_seq_order(epoch=epoch, **kwargs) + out = super().init_seq_order(epoch=epoch, **kwargs) buffer_index = ((epoch or 1) - 1) * self.num_seqs % len(self._dataset) self._dataset.update_buffer(buffer_index, pop_seqs=False) return out From e80e210fdbd98cbc80a921157dcd29e1ffa612bd Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 9 Feb 2023 18:17:44 +0100 Subject: [PATCH 25/47] simplify buffer logic --- common/datasets/sms_wsj/returnn_datasets.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 2e936806a..7677c2897 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -149,17 +149,8 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): # add sequences for idx in range(seq_idx, min(seq_idx + self._buffer_size // 2, len(self))): - if idx not in self._buffer and idx < len(self): - try: - self._buffer[idx] = next(self._ds_iterator) - except StopIteration: - print(f"StopIteration for seq_idx {seq_idx}") - print(f"Dataset: {self} with SMS-WSJ {self._ds} of len {len(self)}") - print(f"Indices in buffer: {self._buffer.keys()}") - # raise - print(f"WARNING: Ignoring this, reset iterator and continue") - self._ds_iterator = iter(self._ds) - self._buffer[idx] = next(self._ds_iterator) + if idx not in self._buffer: + self._buffer[idx] = next(self._ds_iterator) if idx == len(self) - 1 and 0 not in self._buffer: print(f"Reached end of dataset, reset iterator") try: @@ -173,8 +164,8 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): ) print(f"Current buffer indices: {self._buffer.keys()}") self._ds_iterator = iter(self._ds) - for idx_ in range(self._buffer_size // 2): - if idx_ not in self._buffer and idx_ < len(self): + for idx_ in range(min(self._buffer_size // 2, len(self))): + if idx_ not in self._buffer: self._buffer[idx_] = next(self._ds_iterator) print( f"After adding start of dataset to buffer indices: {self._buffer.keys()}" From f0a0fd8190708d98dbb0c90a06d6a895c5fbc85e Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 9 Feb 2023 18:17:56 +0100 Subject: [PATCH 26/47] print to log --- common/datasets/sms_wsj/returnn_datasets.py | 25 ++++++++++++--------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 7677c2897..7a988fd54 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -144,7 +144,8 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): keys = list(self._buffer.keys()) or [0] if not (min(keys) <= seq_idx <= max(keys)): print( - f"WARNING: seq_idx {seq_idx} outside range of keys: {self._buffer.keys()}" + f"WARNING: seq_idx {seq_idx} outside range of keys: {self._buffer.keys()}", + file=log.v5 ) # add sequences @@ -152,7 +153,7 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): if idx not in self._buffer: self._buffer[idx] = next(self._ds_iterator) if idx == len(self) - 1 and 0 not in self._buffer: - print(f"Reached end of dataset, reset iterator") + print(f"Reached end of dataset, reset iterator", file=log.v4) try: next(self._ds_iterator) except StopIteration: @@ -160,15 +161,17 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): else: print( "WARNING: reached final index of dataset, but iterator has more sequences. 
" - "Maybe the training was restarted from an epoch > 1?" + "Maybe the training was restarted from an epoch > 1?", + file=log.v3 ) - print(f"Current buffer indices: {self._buffer.keys()}") + print(f"Current buffer indices: {self._buffer.keys()}", file=log.v5) self._ds_iterator = iter(self._ds) for idx_ in range(min(self._buffer_size // 2, len(self))): if idx_ not in self._buffer: self._buffer[idx_] = next(self._ds_iterator) print( - f"After adding start of dataset to buffer indices: {self._buffer.keys()}" + f"After adding start of dataset to buffer indices: {self._buffer.keys()}", + file=log.v5 ) # remove sequences @@ -193,7 +196,7 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): Caches and unzips a given archive with SMS-WSJ data which will then be used as data dir. This is done because caching of the single files takes extremely long. """ - print(f"Cache and unzip SMS-WSJ data from {zip_cache}") + print(f"Cache and unzip SMS-WSJ data from {zip_cache}", file=log.v4) # cache file try: @@ -208,7 +211,8 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): ), "cached and original file have the same path" except sp.CalledProcessError: print( - f"Cache manager: Error occurred when caching and unzipping {zip_cache}" + f"Cache manager: Error occurred when caching and unzipping {zip_cache}", + file=log.v2 ) raise @@ -218,7 +222,7 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): ["unzip", "-q", "-n", zip_cache_cached, "-d", local_unzipped_dir] ) else: - print(f"Unzipped audio already exists in {local_unzipped_dir}") + print(f"Unzipped audio already exists in {local_unzipped_dir}", file=log.v4) json_path_cached_mod = json_path_cached.replace(".json", ".mod.json") original_dir = None @@ -266,7 +270,8 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): json.dump(json_dict, f, ensure_ascii=False, indent=4) print( - f"Finished preparation of zip cache data, use json in {json_path_cached_mod}" + f"Finished preparation of zip cache data, use json in {json_path_cached_mod}", + file=log.v4 ) return json_path_cached_mod @@ -349,7 +354,7 @@ def __init__(self, sms_wsj_base, **kwargs): :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this """ if "seq_ordering" not in kwargs: - print("Warning: no shuffling is enabled by default", file=log.v) + print("Warning: no shuffling is enabled by default", file=log.v2) super().__init__(sms_wsj_base, **kwargs) # self.num_outputs = ... # needs to be set in derived classes From f2aac11ae4029430f275b6012890a91da0077c79 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 9 Feb 2023 18:29:24 +0100 Subject: [PATCH 27/47] avoid name RASR --- common/datasets/sms_wsj/returnn_datasets.py | 76 ++++++++++----------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 7a988fd54..c95a50a4e 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -276,62 +276,62 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): return json_path_cached_mod -class SmsWsjBaseWithRasrClasses(SmsWsjBase): +class SmsWsjBaseWithHdfClasses(SmsWsjBase): """ - Base class to wrap the SMS-WSJ dataset and combine it with RASR alignments in an hdf dataset. + Base class to wrap the SMS-WSJ dataset and combine it with alignments from an HDF dataset. 
""" def __init__( self, - rasr_classes_hdf, - segment_to_rasr, + hdf_file, + segment_mapping_fn, pad_label=None, hdf_data_key="classes", **kwargs, ): """ - :param str rasr_classes_hdf: hdf file with dumped RASR class labels - :param Callable segment_to_rasr: function that maps SMS-WSJ seg. name into list of corresponding RASR seg. names + :param str hdf_file: hdf file with dumped class labels (compatible with RETURNN HDFDataset) + :param Callable segment_mapping_fn: function that maps SMS-WSJ seg. name into list of corresp. seg. names in HDF :param Optional[int] pad_label: target label assigned to padded areas :param str hdf_data_key: data key under which the alignment is stored in the hdf, usually "classes" or "data" :param kwargs: """ super().__init__(**kwargs) - self._rasr_classes_hdf = HDFDataset([rasr_classes_hdf], use_cache_manager=True) - self._segment_to_rasr = segment_to_rasr + self._hdf_dataset = HDFDataset([hdf_file], use_cache_manager=True) + self._segment_mapping_fn = segment_mapping_fn self._pad_label = pad_label self._hdf_data_key = hdf_data_key def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: d = self._get_seq_by_idx(seq_idx) - rasr_seq_tags = self._segment_to_rasr(str(d["seq_tag"])) + hdf_seq_tags = self._segment_mapping_fn(str(d["seq_tag"])) assert ( - len(rasr_seq_tags) == d["target_signals"].shape[1] - ), f"got {len(rasr_seq_tags)} segment names, but there are {d['target_signals'].shape[1]} target signals" - rasr_targets = [ - self._rasr_classes_hdf.get_data_by_seq_tag(rasr_seq_tag, self._hdf_data_key) - for rasr_seq_tag in rasr_seq_tags + len(hdf_seq_tags) == d["target_signals"].shape[1] + ), f"got {len(hdf_seq_tags)} segment names, but there are {d['target_signals'].shape[1]} target signals" + hdf_classes = [ + self._hdf_dataset.get_data_by_seq_tag(hdf_seq_tag, self._hdf_data_key) + for hdf_seq_tag in hdf_seq_tags ] - padded_len = max(rasr_target.shape[0] for rasr_target in rasr_targets) - for speaker_idx in range(len(rasr_targets)): + padded_len = max(hdf_classes_.shape[0] for hdf_classes_ in hdf_classes) + for speaker_idx in range(len(hdf_classes)): pad_start = int(round(d["offset"][speaker_idx] / d["seq_len"] * padded_len)) - pad_end = padded_len - rasr_targets[speaker_idx].shape[0] - pad_start + pad_end = padded_len - hdf_classes[speaker_idx].shape[0] - pad_start if pad_end < 0: pad_start += pad_end assert pad_start >= 0 pad_end = 0 if pad_start or pad_end: assert self._pad_label is not None, "Label for padding is needed" - rasr_targets[speaker_idx] = np.concatenate( + hdf_classes[speaker_idx] = np.concatenate( [ self._pad_label * np.ones(pad_start), - rasr_targets[speaker_idx], + hdf_classes[speaker_idx], self._pad_label * np.ones(pad_end), ] ) - d["target_rasr"] = np.stack(rasr_targets).T - d["target_rasr_len"] = np.array(padded_len) + d["target_classes"] = np.stack(hdf_classes).T + d["target_classes_len"] = np.array(padded_len) return d def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: @@ -340,7 +340,7 @@ def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: """ d = super().get_seq_length_for_keys(seq_idx) data = self[seq_idx] - d["target_rasr"] = int(data["target_rasr_len"]) + d["target_classes"] = int(data["target_classes_len"]) return NumbersDict(d) @@ -457,7 +457,7 @@ def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: class SmsWsjMixtureEarlyAlignmentDataset(SmsWsjMixtureEarlyDataset): """ - Dataset with audio mixture, target early signals and target RASR alignments. 
+ Dataset with audio mixture, target early signals and target alignments. """ def __init__( @@ -465,10 +465,10 @@ def __init__( dataset_name, json_path, num_outputs=None, - rasr_num_outputs=None, + classes_num_outputs=None, zip_cache=None, - rasr_classes_hdf=None, - segment_to_rasr=None, + hdf_file=None, + segment_mapping_fn=None, pad_label=None, hdf_data_key="classes", **kwargs, @@ -477,30 +477,30 @@ def __init__( :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" :param str json_path: path to SMS-WSJ json file :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset - :param Optional[int] rasr_num_outputs: number of output labels for RASR alignment, e.g. 9001 for that CART size + :param Optional[int] classes_num_outputs: number of output labels for alignment, e.g. 9001 for that CART size :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir - :param str rasr_classes_hdf: hdf file with dumped RASR class labels - :param Callable segment_to_rasr: function that maps SMS-WSJ seg. name into list of corresponding RASR seg. names + :param str hdf_file: hdf file with dumped class labels (compatible with RETURNN HDFDataset) + :param Callable segment_mapping_fn: function that maps SMS-WSJ seg. name into list of corresp. seg. names in HDF :param Optional[int] pad_label: target label assigned to padded areas :param str hdf_data_key: data key under which the alignment is stored in the hdf, usually "classes" or "data" """ data_types = { "target_signals": {"dim": 2, "shape": (None, 2)}, - "target_rasr": { + "target_classes": { "sparse": True, - "dim": rasr_num_outputs, + "dim": classes_num_outputs, "shape": (None, 2), }, } - sms_wsj_base = SmsWsjBaseWithRasrClasses( + sms_wsj_base = SmsWsjBaseWithHdfClasses( dataset_name=dataset_name, json_path=json_path, pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, data_types=data_types, zip_cache=zip_cache, - rasr_classes_hdf=rasr_classes_hdf, - segment_to_rasr=segment_to_rasr, + hdf_file=hdf_file, + segment_mapping_fn=segment_mapping_fn, pad_label=pad_label, hdf_data_key=hdf_data_key, ) @@ -516,10 +516,10 @@ def __init__( self.num_outputs = num_outputs else: assert ( - rasr_num_outputs is not None - ), "either num_outputs or rasr_num_outputs has to be given" - self.num_outputs["target_rasr"] = [ - rasr_num_outputs, + classes_num_outputs is not None + ), "either num_outputs or classes_num_outputs has to be given" + self.num_outputs["target_classes"] = [ + classes_num_outputs, 1, ] # target alignments are sparse with the given dim From d55bd5a3bf20beb93a1fb4ec103bae1ecab2cfce Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 9 Feb 2023 18:33:11 +0100 Subject: [PATCH 28/47] simons review comments --- common/datasets/sms_wsj/returnn_datasets.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index c95a50a4e..0209c06a4 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -314,13 +314,13 @@ def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: for hdf_seq_tag in hdf_seq_tags ] padded_len = max(hdf_classes_.shape[0] for hdf_classes_ in hdf_classes) - for speaker_idx in range(len(hdf_classes)): - pad_start = int(round(d["offset"][speaker_idx] / d["seq_len"] * padded_len)) - pad_end = padded_len - 
hdf_classes[speaker_idx].shape[0] - pad_start - if pad_end < 0: - pad_start += pad_end - assert pad_start >= 0 - pad_end = 0 + for speaker_idx, rasr_target in enumerate(hdf_classes): + total_pad_frames = padded_len - rasr_target.shape[0] + if total_pad_frames == 0: + continue + pad_start = round(d["offset"][speaker_idx] / d["seq_len"] * padded_len) + pad_start = min(pad_start, total_pad_frames) + pad_end = total_pad_frames - pad_start if pad_start or pad_end: assert self._pad_label is not None, "Label for padding is needed" hdf_classes[speaker_idx] = np.concatenate( From 2f0c4840f62a1dcfd2033b659c23172de0988700 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Fri, 10 Feb 2023 13:37:15 +0100 Subject: [PATCH 29/47] update caching logic --- common/datasets/sms_wsj/returnn_datasets.py | 22 ++++++++++----------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 0209c06a4..5f64e1bb0 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -204,7 +204,7 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): assert ( zip_cache_cached != zip_cache ), "cached and original file have the same path" - local_unzipped_dir = os.path.dirname(zip_cache_cached) + local_base_dir = os.path.dirname(zip_cache_cached) json_path_cached = sp.check_output(["cf", json_path]).strip().decode("utf8") assert ( json_path_cached != json_path @@ -216,13 +216,11 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): ) raise - # unzip if folder does not yet exist - if not os.path.exists(local_unzipped_dir): - sp.check_call( - ["unzip", "-q", "-n", zip_cache_cached, "-d", local_unzipped_dir] - ) - else: - print(f"Unzipped audio already exists in {local_unzipped_dir}", file=log.v4) + # unzip + sp.check_call( + ["unzip", "-q", "-n", zip_cache_cached, "-d", local_base_dir] + ) + sp.check_call(["chmod", "-R", "o+w", local_base_dir]) json_path_cached_mod = json_path_cached.replace(".json", ".mod.json") original_dir = None @@ -234,7 +232,7 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): "audio_path" ]["original_source"][0] while ( - not original_dir.endswith(os.path.basename(local_unzipped_dir)) + not original_dir.endswith(os.path.basename(local_base_dir)) and len(original_dir) > 1 ): original_dir = os.path.dirname(original_dir) @@ -256,13 +254,13 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): audio_key ][seq_idx] if not os.path.exists(json_path_cached_mod): - path = path.replace(original_dir, local_unzipped_dir) + path = path.replace(original_dir, local_base_dir) json_dict["datasets"][dataset_name][seq]["audio_path"][ audio_key ][seq_idx] = path assert path.startswith( - local_unzipped_dir - ), f"Audio file {path} was expected to start with {local_unzipped_dir}" + local_base_dir + ), f"Audio file {path} was expected to start with {local_base_dir}" assert os.path.exists(path), f"Audio file {path} does not exist" if not os.path.exists(json_path_cached_mod): From 4074af4e719493774ff46bbfb107d1f606c18ebc Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Fri, 10 Feb 2023 13:37:43 +0100 Subject: [PATCH 30/47] returnn log --- common/datasets/sms_wsj/returnn_datasets.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 
5f64e1bb0..0caa7a857 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -19,7 +19,7 @@ from returnn.datasets.map import MapDatasetBase, MapDatasetWrapper # noinspection PyUnresolvedReferences -from returnn.log import log +from returnn.log import log as returnn_log # noinspection PyUnresolvedReferences from returnn.util.basic import OptionalNotImplementedError, NumbersDict @@ -145,7 +145,7 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): if not (min(keys) <= seq_idx <= max(keys)): print( f"WARNING: seq_idx {seq_idx} outside range of keys: {self._buffer.keys()}", - file=log.v5 + file=returnn_log.v5 ) # add sequences @@ -153,7 +153,7 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): if idx not in self._buffer: self._buffer[idx] = next(self._ds_iterator) if idx == len(self) - 1 and 0 not in self._buffer: - print(f"Reached end of dataset, reset iterator", file=log.v4) + print(f"Reached end of dataset, reset iterator", file=returnn_log.v4) try: next(self._ds_iterator) except StopIteration: @@ -162,16 +162,16 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): print( "WARNING: reached final index of dataset, but iterator has more sequences. " "Maybe the training was restarted from an epoch > 1?", - file=log.v3 + file=returnn_log.v3 ) - print(f"Current buffer indices: {self._buffer.keys()}", file=log.v5) + print(f"Current buffer indices: {self._buffer.keys()}", file=returnn_log.v5) self._ds_iterator = iter(self._ds) for idx_ in range(min(self._buffer_size // 2, len(self))): if idx_ not in self._buffer: self._buffer[idx_] = next(self._ds_iterator) print( f"After adding start of dataset to buffer indices: {self._buffer.keys()}", - file=log.v5 + file=returnn_log.v5 ) # remove sequences @@ -196,7 +196,7 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): Caches and unzips a given archive with SMS-WSJ data which will then be used as data dir. This is done because caching of the single files takes extremely long. """ - print(f"Cache and unzip SMS-WSJ data from {zip_cache}", file=log.v4) + print(f"Cache and unzip SMS-WSJ data from {zip_cache}", file=returnn_log.v4) # cache file try: @@ -212,7 +212,7 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): except sp.CalledProcessError: print( f"Cache manager: Error occurred when caching and unzipping {zip_cache}", - file=log.v2 + file=returnn_log.v2 ) raise @@ -269,7 +269,7 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): print( f"Finished preparation of zip cache data, use json in {json_path_cached_mod}", - file=log.v4 + file=returnn_log.v4 ) return json_path_cached_mod @@ -352,7 +352,7 @@ def __init__(self, sms_wsj_base, **kwargs): :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this """ if "seq_ordering" not in kwargs: - print("Warning: no shuffling is enabled by default", file=log.v2) + print("Warning: no shuffling is enabled by default", file=returnn_log.v2) super().__init__(sms_wsj_base, **kwargs) # self.num_outputs = ... 
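# e.g. {"data": [1, 2], "target_signals": [2, 2]} as in SmsWsjMixtureEarlyDataset, with [dim, ndim] per key;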
# needs to be set in derived classes From c188b82772f92e52ab5a1c6afdaa076eb713880a Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Fri, 10 Feb 2023 13:38:02 +0100 Subject: [PATCH 31/47] pad start int --- common/datasets/sms_wsj/returnn_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 0caa7a857..eb79f59f4 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -316,7 +316,7 @@ def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: total_pad_frames = padded_len - rasr_target.shape[0] if total_pad_frames == 0: continue - pad_start = round(d["offset"][speaker_idx] / d["seq_len"] * padded_len) + pad_start = int(round(d["offset"][speaker_idx] / d["seq_len"] * padded_len)) pad_start = min(pad_start, total_pad_frames) pad_end = total_pad_frames - pad_start if pad_start or pad_end: From 439da8ffcd8e4ead18ae8c5a7a09d9dd81407242 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Fri, 10 Feb 2023 17:08:44 +0100 Subject: [PATCH 32/47] log unzipping command --- common/datasets/sms_wsj/returnn_datasets.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index eb79f59f4..2bcb30455 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -217,9 +217,10 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): raise # unzip - sp.check_call( - ["unzip", "-q", "-n", zip_cache_cached, "-d", local_base_dir] - ) + unzip_cmd = ["unzip", "-q", "-n", zip_cache_cached, "-d", local_base_dir] + print(" ".join(unzip_cmd), file=returnn_log.v4) + sp.check_call(unzip_cmd) + print("Finished unzipping", file=returnn_log.v4) sp.check_call(["chmod", "-R", "o+w", local_base_dir]) json_path_cached_mod = json_path_cached.replace(".json", ".mod.json") From f58844a02e1840a29cecf435878726ae2c8edf55 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Fri, 10 Feb 2023 17:55:56 +0100 Subject: [PATCH 33/47] rename rasr variable --- common/datasets/sms_wsj/returnn_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 2bcb30455..495fa935f 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -313,8 +313,8 @@ def __getitem__(self, seq_idx: int) -> Dict[str, np.array]: for hdf_seq_tag in hdf_seq_tags ] padded_len = max(hdf_classes_.shape[0] for hdf_classes_ in hdf_classes) - for speaker_idx, rasr_target in enumerate(hdf_classes): - total_pad_frames = padded_len - rasr_target.shape[0] + for speaker_idx, hdf_classes_speaker in enumerate(hdf_classes): + total_pad_frames = padded_len - hdf_classes_speaker.shape[0] if total_pad_frames == 0: continue pad_start = int(round(d["offset"][speaker_idx] / d["seq_len"] * padded_len)) From caa18a3ac2b78ddf18bc4d8e50af5b042915fe55 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Fri, 10 Feb 2023 17:56:09 +0100 Subject: [PATCH 34/47] use sequence buffer class --- common/datasets/sms_wsj/returnn_datasets.py | 51 +++++++++++---------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 495fa935f..433402957 100644 --- 
a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -25,6 +25,24 @@ from returnn.util.basic import OptionalNotImplementedError, NumbersDict +class SequenceBuffer(dict): + """ + Helper class to represent a buffer of sequences + """ + def __init__(self, max_size: int): + super().__init__() + self._max_size = max_size + + def __setitem__(self, key, value): + super().__setitem__(key, value) + if len(self) > self._max_size: + self.pop(next(iter(self))) + + @property + def max_size(self): + return self._max_size + + class SmsWsjBase(MapDatasetBase): """ Base class to wrap the SMS-WSJ dataset. This is not the dataset that is used in the RETURNN config, see @@ -37,11 +55,11 @@ def __init__( json_path, pre_batch_transform, data_types, - buffer=True, zip_cache=None, scenario_map_args=None, + buffer=True, + buffer_size=40, prefetch_num_workers=4, - prefetch_buffer_size=40, **kwargs, ): """ @@ -49,9 +67,11 @@ def __init__( :param str json_path: path to SMS-WSJ json file :param function pre_batch_transform: function which processes raw SMS-WSJ data :param Dict[str] data_types: data types for RETURNN, e.g. {"target_signals": {"dim": 2, "shape": (None, 2)}} - :param bool buffer: if True, use SMS-WSJ dataset prefetching and store sequences in buffer :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir :param Optional[Dict] scenario_map_args: optional kwargs for sms_wsj scenario_map_fn + :param bool buffer: if True, use SMS-WSJ dataset prefetching and store sequences in buffer + :param int buffer_size: buffer size + :param int prefetch_num_workers: number of workers for prefetching """ # noinspection PyUnresolvedReferences from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn @@ -83,10 +103,9 @@ def __init__( self._use_buffer = buffer if self._use_buffer: self._ds = self._ds.prefetch( - prefetch_num_workers, prefetch_buffer_size + prefetch_num_workers, buffer_size ).copy(freeze=True) - self._buffer = {} # type Dict[int,[Dict[str,np.array]]] - self._buffer_size = prefetch_buffer_size + self._buffer = SequenceBuffer(buffer_size) def __len__(self) -> int: return len(self._ds) @@ -149,7 +168,7 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): ) # add sequences - for idx in range(seq_idx, min(seq_idx + self._buffer_size // 2, len(self))): + for idx in range(seq_idx, min(seq_idx + self._buffer.max_size // 2, len(self))): if idx not in self._buffer: self._buffer[idx] = next(self._ds_iterator) if idx == len(self) - 1 and 0 not in self._buffer: @@ -166,7 +185,7 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): ) print(f"Current buffer indices: {self._buffer.keys()}", file=returnn_log.v5) self._ds_iterator = iter(self._ds) - for idx_ in range(min(self._buffer_size // 2, len(self))): + for idx_ in range(min(self._buffer.max_size // 2, len(self))): if idx_ not in self._buffer: self._buffer[idx_] = next(self._ds_iterator) print( @@ -174,22 +193,6 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): file=returnn_log.v5 ) - # remove sequences - if pop_seqs: - for idx in list(self._buffer): - if not ( - seq_idx - self._buffer_size // 2 - <= idx - <= seq_idx + self._buffer_size // 2 - ): - if ( - max(self._buffer.keys()) == len(self) - 1 - and idx < self._buffer_size // 2 - ): - # newly added sequences starting from 0 - continue - self._buffer.pop(idx) - @staticmethod def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): """ From 
55ff94307e5659d4a1d93f1fb37b28405c04a64e Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Fri, 10 Feb 2023 18:00:12 +0100 Subject: [PATCH 35/47] black --- common/datasets/sms_wsj/returnn_datasets.py | 24 ++++++++++++--------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 433402957..f5b1760cc 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -29,6 +29,7 @@ class SequenceBuffer(dict): """ Helper class to represent a buffer of sequences """ + def __init__(self, max_size: int): super().__init__() self._max_size = max_size @@ -102,9 +103,9 @@ def __init__( self._use_buffer = buffer if self._use_buffer: - self._ds = self._ds.prefetch( - prefetch_num_workers, buffer_size - ).copy(freeze=True) + self._ds = self._ds.prefetch(prefetch_num_workers, buffer_size).copy( + freeze=True + ) self._buffer = SequenceBuffer(buffer_size) def __len__(self) -> int: @@ -164,7 +165,7 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): if not (min(keys) <= seq_idx <= max(keys)): print( f"WARNING: seq_idx {seq_idx} outside range of keys: {self._buffer.keys()}", - file=returnn_log.v5 + file=returnn_log.v5, ) # add sequences @@ -181,16 +182,19 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): print( "WARNING: reached final index of dataset, but iterator has more sequences. " "Maybe the training was restarted from an epoch > 1?", - file=returnn_log.v3 + file=returnn_log.v3, ) - print(f"Current buffer indices: {self._buffer.keys()}", file=returnn_log.v5) + print( + f"Current buffer indices: {self._buffer.keys()}", + file=returnn_log.v5, + ) self._ds_iterator = iter(self._ds) for idx_ in range(min(self._buffer.max_size // 2, len(self))): if idx_ not in self._buffer: self._buffer[idx_] = next(self._ds_iterator) print( f"After adding start of dataset to buffer indices: {self._buffer.keys()}", - file=returnn_log.v5 + file=returnn_log.v5, ) @staticmethod @@ -215,7 +219,7 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): except sp.CalledProcessError: print( f"Cache manager: Error occurred when caching and unzipping {zip_cache}", - file=returnn_log.v2 + file=returnn_log.v2, ) raise @@ -273,7 +277,7 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): print( f"Finished preparation of zip cache data, use json in {json_path_cached_mod}", - file=returnn_log.v4 + file=returnn_log.v4, ) return json_path_cached_mod @@ -518,7 +522,7 @@ def __init__( self.num_outputs = num_outputs else: assert ( - classes_num_outputs is not None + classes_num_outputs is not None ), "either num_outputs or classes_num_outputs has to be given" self.num_outputs["target_classes"] = [ classes_num_outputs, From 8ecf218aa40065883a202df9cc14d7f59cd1eccc Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Mon, 13 Feb 2023 10:52:08 +0100 Subject: [PATCH 36/47] chmod -f --- common/datasets/sms_wsj/returnn_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index f5b1760cc..839bc86b9 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -228,7 +228,7 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): print(" ".join(unzip_cmd), file=returnn_log.v4) sp.check_call(unzip_cmd) print("Finished unzipping", 
file=returnn_log.v4) - sp.check_call(["chmod", "-R", "o+w", local_base_dir]) + sp.check_call(["chmod", "-R", "-f", "o+w", local_base_dir]) json_path_cached_mod = json_path_cached.replace(".json", ".mod.json") original_dir = None From cab1c3c00e046eddfd5cfecb98b31f7564a3900e Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Mon, 13 Feb 2023 12:01:40 +0100 Subject: [PATCH 37/47] chmod force exit 0 --- common/datasets/sms_wsj/returnn_datasets.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 839bc86b9..758e0f858 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -226,9 +226,10 @@ def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): # unzip unzip_cmd = ["unzip", "-q", "-n", zip_cache_cached, "-d", local_base_dir] print(" ".join(unzip_cmd), file=returnn_log.v4) - sp.check_call(unzip_cmd) + sp.check_output(unzip_cmd) print("Finished unzipping", file=returnn_log.v4) - sp.check_call(["chmod", "-R", "-f", "o+w", local_base_dir]) + # force exit code 0 for the case that the path does not belong to the user so permissions cannot be changed + sp.check_output(["chmod", "-R", "-f", "o+w", local_base_dir, "||", "true"]) json_path_cached_mod = json_path_cached.replace(".json", ".mod.json") original_dir = None From 15ebbdc0921ae5672a6ace0bed9fb8989500a49a Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Mon, 13 Feb 2023 16:19:32 +0100 Subject: [PATCH 38/47] read data from zip instead of unzipping --- common/datasets/sms_wsj/returnn_datasets.py | 151 ++++++++------------ 1 file changed, 62 insertions(+), 89 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 758e0f858..b72368c8c 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -24,6 +24,9 @@ # noinspection PyUnresolvedReferences from returnn.util.basic import OptionalNotImplementedError, NumbersDict +# noinspection PyUnresolvedReferences +from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn + class SequenceBuffer(dict): """ @@ -44,6 +47,49 @@ def max_size(self): return self._max_size +class ZipAudioReader(AudioReader): + """ + Reads the audio data of an example from a zip file. + """ + def __init__( + self, + zip_path=None, + zip_prefix="", + **kwargs + ): + """ + :param Optional[str] zip_path: zip archive with SMS-WSJ data + :param str zip_prefix: prefix of filename that needs to be removed for the lookup in the zip archive + """ + import zipfile + + super().__init__(**kwargs) + self._zip = zipfile.ZipFile(zip_path, "r") if zip_path is not None else None + self._zip_prefix = zip_prefix + + def _rec_audio_read(self, file): + """ + Read audio from file + + :param Union[tuple, list, dict, str] file: filename + """ + import io + import soundfile + + if isinstance(file, (tuple, list)): + return np.array([self._rec_audio_read(f) for f in file]) + elif isinstance(file, dict): + return {k: self._rec_audio_read(v) for k, v in file.items()} + else: + if self._zip is not None: + assert file.startswith(self._zip_prefix) + file_zip = file[len(self._zip_prefix):] + data, sample_rate = soundfile.read(io.BytesIO(self._zip.read(file_zip))) + else: + data, sample_rate = soundfile.read(file) + return data.T + + class SmsWsjBase(MapDatasetBase): """ Base class to wrap the SMS-WSJ dataset. 
This is not the dataset that is used in the RETURNN config, see @@ -57,6 +103,7 @@ def __init__( pre_batch_transform, data_types, zip_cache=None, + zip_prefix="", scenario_map_args=None, buffer=True, buffer_size=40, @@ -69,24 +116,35 @@ def __init__( :param function pre_batch_transform: function which processes raw SMS-WSJ data :param Dict[str] data_types: data types for RETURNN, e.g. {"target_signals": {"dim": 2, "shape": (None, 2)}} :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir + :param str zip_prefix: prefix of filename that needs to be removed for the lookup in the zip archive :param Optional[Dict] scenario_map_args: optional kwargs for sms_wsj scenario_map_fn :param bool buffer: if True, use SMS-WSJ dataset prefetching and store sequences in buffer :param int buffer_size: buffer size :param int prefetch_num_workers: number of workers for prefetching """ - # noinspection PyUnresolvedReferences - from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn super().__init__(**kwargs) self.data_types = data_types if zip_cache is not None: - json_path = self._cache_zipped_audio(zip_cache, json_path, dataset_name) + zip_cache_cached = sp.check_output(["cf", zip_cache]).strip().decode("utf8") + assert ( + zip_cache_cached != zip_cache + ), "cached and original file have the same path" + json_path_cached = sp.check_output(["cf", json_path]).strip().decode("utf8") + assert ( + json_path_cached != json_path + ), "cached and original file have the same path" + json_path = json_path_cached + audio_reader = ZipAudioReader( + zip_path=zip_cache_cached, zip_prefix=zip_prefix, keys=("original_source", "rir")) + else: + audio_reader = AudioReader(keys=("original_source", "rir")) db = SmsWsj(json_path=json_path) ds = db.get_dataset(dataset_name) - ds = ds.map(AudioReader(("original_source", "rir"))) + ds = ds.map(audio_reader) scenario_map_args = { "add_speech_image": False, @@ -197,91 +255,6 @@ def update_buffer(self, seq_idx: int, pop_seqs: bool = True): file=returnn_log.v5, ) - @staticmethod - def _cache_zipped_audio(zip_cache: str, json_path: str, dataset_name: str): - """ - Caches and unzips a given archive with SMS-WSJ data which will then be used as data dir. - This is done because caching of the single files takes extremely long. 
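For reference, the zip-based access that replaces the unzipping step boils down to the following minimal standalone sketch (the archive and member names are hypothetical placeholders, not taken from the patches):

import io
import zipfile

import soundfile

# Open the archive once, then decode individual members in memory instead of
# extracting them to disk first.
with zipfile.ZipFile("sms_wsj_audio.zip", "r") as archive:  # hypothetical archive of wav files
    raw = archive.read("train_si284/observation/example_0.wav")  # raw bytes of one hypothetical member
    data, sample_rate = soundfile.read(io.BytesIO(raw))  # soundfile accepts file-like objects
print(data.shape, sample_rate)

Reading members on demand keeps only the single zip file in the cache directory, which is why the _cache_zipped_audio helper below is removed entirely.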
- """ - print(f"Cache and unzip SMS-WSJ data from {zip_cache}", file=returnn_log.v4) - - # cache file - try: - zip_cache_cached = sp.check_output(["cf", zip_cache]).strip().decode("utf8") - assert ( - zip_cache_cached != zip_cache - ), "cached and original file have the same path" - local_base_dir = os.path.dirname(zip_cache_cached) - json_path_cached = sp.check_output(["cf", json_path]).strip().decode("utf8") - assert ( - json_path_cached != json_path - ), "cached and original file have the same path" - except sp.CalledProcessError: - print( - f"Cache manager: Error occurred when caching and unzipping {zip_cache}", - file=returnn_log.v2, - ) - raise - - # unzip - unzip_cmd = ["unzip", "-q", "-n", zip_cache_cached, "-d", local_base_dir] - print(" ".join(unzip_cmd), file=returnn_log.v4) - sp.check_output(unzip_cmd) - print("Finished unzipping", file=returnn_log.v4) - # force exit code 0 for the case that the path does not belong to the user so permissions cannot be changed - sp.check_output(["chmod", "-R", "-f", "o+w", local_base_dir, "||", "true"]) - - json_path_cached_mod = json_path_cached.replace(".json", ".mod.json") - original_dir = None - if not os.path.exists(json_path_cached_mod): - with open(json_path_cached, "r") as f: - json_dict = json.loads(f.read()) - # get original dir - original_dir = next(iter(json_dict["datasets"][dataset_name].values()))[ - "audio_path" - ]["original_source"][0] - while ( - not original_dir.endswith(os.path.basename(local_base_dir)) - and len(original_dir) > 1 - ): - original_dir = os.path.dirname(original_dir) - else: - with open(json_path_cached_mod, "r") as f: - json_dict = json.loads(f.read()) - # check if all data is available and create modified json if it does not yet exist - for dataset_name in json_dict["datasets"]: - for seq in json_dict["datasets"][dataset_name]: - for audio_key in ["original_source", "rir"]: - for seq_idx in range( - len( - json_dict["datasets"][dataset_name][seq]["audio_path"][ - audio_key - ] - ) - ): - path = json_dict["datasets"][dataset_name][seq]["audio_path"][ - audio_key - ][seq_idx] - if not os.path.exists(json_path_cached_mod): - path = path.replace(original_dir, local_base_dir) - json_dict["datasets"][dataset_name][seq]["audio_path"][ - audio_key - ][seq_idx] = path - assert path.startswith( - local_base_dir - ), f"Audio file {path} was expected to start with {local_base_dir}" - assert os.path.exists(path), f"Audio file {path} does not exist" - - if not os.path.exists(json_path_cached_mod): - with open(json_path_cached_mod, "w", encoding="utf-8") as f: - json.dump(json_dict, f, ensure_ascii=False, indent=4) - - print( - f"Finished preparation of zip cache data, use json in {json_path_cached_mod}", - file=returnn_log.v4, - ) - return json_path_cached_mod - class SmsWsjBaseWithHdfClasses(SmsWsjBase): """ From b0510452b6938d6f042efb70974ac21836d5fe49 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Mon, 13 Feb 2023 16:19:53 +0100 Subject: [PATCH 39/47] cleanup --- common/datasets/sms_wsj/returnn_datasets.py | 141 ++++++++------------ 1 file changed, 57 insertions(+), 84 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index b72368c8c..b27a0a7a6 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -3,11 +3,9 @@ """ import functools -import json import numpy as np -import os.path import subprocess as sp -from typing import Dict, Tuple, Any, Optional +from typing import Dict, Union, Any, Optional 
# noinspection PyUnresolvedReferences from returnn.datasets.basic import DatasetSeq @@ -34,6 +32,9 @@ class SequenceBuffer(dict): """ def __init__(self, max_size: int): + """ + :param int max_size: maximum number of sequences in buffer + """ super().__init__() self._max_size = max_size @@ -210,10 +211,9 @@ def get_seq_length_for_keys(self, seq_idx: int) -> NumbersDict: d[update_key] = int(data["seq_len"]) return NumbersDict(d) - def update_buffer(self, seq_idx: int, pop_seqs: bool = True): + def update_buffer(self, seq_idx: int): """ :param int seq_idx: - :param bool pop_seqs: if True, pop sequences from buffer that are outside buffer range """ if not self._use_buffer: return @@ -331,7 +331,7 @@ class SmsWsjWrapper(MapDatasetWrapper): def __init__(self, sms_wsj_base, **kwargs): """ - :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this + :param SmsWsjBase sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this """ if "seq_ordering" not in kwargs: print("Warning: no shuffling is enabled by default", file=returnn_log.v2) @@ -351,7 +351,7 @@ def _get_seq_length(seq_idx: int) -> NumbersDict: def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: """ Used to process raw SMS-WSJ data - :param inputs: input as coming from SMS-WSJ + :param Dict[str, Any] inputs: input as coming from SMS-WSJ """ return_dict = { "seq_tag": np.array(inputs["example_id"], dtype=object), @@ -377,7 +377,7 @@ def init_seq_order(self, epoch: Optional[int] = None, **kwargs) -> bool: """ out = super().init_seq_order(epoch=epoch, **kwargs) buffer_index = ((epoch or 1) - 1) * self.num_seqs % len(self._dataset) - self._dataset.update_buffer(buffer_index, pop_seqs=False) + self._dataset.update_buffer(buffer_index) return out @@ -388,28 +388,23 @@ class SmsWsjMixtureEarlyDataset(SmsWsjWrapper): def __init__( self, - dataset_name, - json_path, - num_outputs=None, - zip_cache=None, sms_wsj_base=None, + sms_wsj_kwargs=None, + num_outputs=None, **kwargs, ): """ - :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" - :param str json_path: path to SMS-WSJ json file - :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset - :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this + :param Optional[Dict[str, Any]] sms_wsj_args: kwargs to create SMS-WSJ base class if sms_wsj_base is not given + :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset """ if sms_wsj_base is None: + assert sms_wsj_kwargs is not None, "either sms_wsj_base or sms_wsj_kwargs need to be given" sms_wsj_base = SmsWsjBase( - dataset_name=dataset_name, - json_path=json_path, pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, - zip_cache=zip_cache, data_types={"target_signals": {"dim": 2, "shape": (None, 2)}}, + **sms_wsj_kwargs, ) super().__init__(sms_wsj_base, **kwargs) # typically data is raw waveform so 1-D and dense, target signals are 2-D (one for each speaker) and dense @@ -419,7 +414,7 @@ def __init__( def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: """ Used to process raw SMS-WSJ data - :param inputs: input as coming from SMS-WSJ + :param Dict[str, Any] inputs: input as coming from SMS-WSJ """ return_dict = SmsWsjWrapper._pre_batch_transform(inputs) 
return_dict.update( @@ -442,53 +437,36 @@ class SmsWsjMixtureEarlyAlignmentDataset(SmsWsjMixtureEarlyDataset): def __init__( self, - dataset_name, - json_path, + sms_wsj_base=None, + sms_wsj_kwargs=None, num_outputs=None, classes_num_outputs=None, - zip_cache=None, - hdf_file=None, - segment_mapping_fn=None, - pad_label=None, - hdf_data_key="classes", **kwargs, ): """ - :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" - :param str json_path: path to SMS-WSJ json file + :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this + :param Optional[Dict[str, Any]] sms_wsj_args: kwargs to create SMS-WSJ base class if sms_wsj_base is not given :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset :param Optional[int] classes_num_outputs: number of output labels for alignment, e.g. 9001 for that CART size - :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir - :param str hdf_file: hdf file with dumped class labels (compatible with RETURNN HDFDataset) - :param Callable segment_mapping_fn: function that maps SMS-WSJ seg. name into list of corresp. seg. names in HDF - :param Optional[int] pad_label: target label assigned to padded areas - :param str hdf_data_key: data key under which the alignment is stored in the hdf, usually "classes" or "data" """ - data_types = { - "target_signals": {"dim": 2, "shape": (None, 2)}, - "target_classes": { - "sparse": True, - "dim": classes_num_outputs, - "shape": (None, 2), - }, - } - sms_wsj_base = SmsWsjBaseWithHdfClasses( - dataset_name=dataset_name, - json_path=json_path, - pre_batch_transform=self._pre_batch_transform, - scenario_map_args={"add_speech_reverberation_early": True}, - data_types=data_types, - zip_cache=zip_cache, - hdf_file=hdf_file, - segment_mapping_fn=segment_mapping_fn, - pad_label=pad_label, - hdf_data_key=hdf_data_key, - ) + if sms_wsj_base is None: + data_types = { + "target_signals": {"dim": 2, "shape": (None, 2)}, + "target_classes": { + "sparse": True, + "dim": classes_num_outputs, + "shape": (None, 2), + }, + } + assert sms_wsj_kwargs is not None, "either sms_wsj_base or sms_wsj_kwargs need to be given" + sms_wsj_base = SmsWsjBaseWithHdfClasses( + pre_batch_transform=self._pre_batch_transform, + scenario_map_args={"add_speech_reverberation_early": True}, + data_types=data_types, + **sms_wsj_kwargs, + ) super().__init__( - dataset_name, - json_path, num_outputs=num_outputs, - zip_cache=zip_cache, sms_wsj_base=sms_wsj_base, **kwargs, ) @@ -507,7 +485,7 @@ def __init__( def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: """ Used to process raw SMS-WSJ data - :param inputs: input as coming from SMS-WSJ + :param Dict[str, Any] inputs: input as coming from SMS-WSJ """ return_dict = SmsWsjMixtureEarlyDataset._pre_batch_transform(inputs) # we need the padding information here @@ -522,46 +500,41 @@ class SmsWsjMixtureEarlyBpeDataset(SmsWsjMixtureEarlyDataset): def __init__( self, - dataset_name, - json_path, bpe, + sms_wsj_base=None, + sms_wsj_kwargs=None, text_proc=None, num_outputs=None, - zip_cache=None, **kwargs, ): """ - :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" - :param str json_path: path to SMS-WSJ json file :param Dict[str] bpe: opts for :class:`BytePairEncoding` + :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this + :param Optional[Dict[str, Any]] sms_wsj_args: kwargs to create SMS-WSJ 
base class if sms_wsj_base is not given :param Optional[Callable] text_proc: function to preprocess the transcriptions before applying BPE :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset - :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir """ from returnn.datasets.util.vocabulary import BytePairEncoding self.bpe = BytePairEncoding(**bpe) - data_types = { - "target_signals": {"dim": 2, "shape": (None, 2)}, - "target_bpe": { - "sparse": True, - "dim": self.bpe.num_labels, - "shape": (None, 2), - }, - } - sms_wsj_base = SmsWsjBase( - dataset_name=dataset_name, - json_path=json_path, - pre_batch_transform=self._pre_batch_transform, - scenario_map_args={"add_speech_reverberation_early": True}, - zip_cache=zip_cache, - data_types=data_types, - ) + if sms_wsj_base is None: + data_types = { + "target_signals": {"dim": 2, "shape": (None, 2)}, + "target_bpe": { + "sparse": True, + "dim": self.bpe.num_labels, + "shape": (None, 2), + }, + } + assert sms_wsj_kwargs is not None, "either sms_wsj_base or sms_wsj_kwargs need to be given" + sms_wsj_base = SmsWsjBase( + pre_batch_transform=self._pre_batch_transform, + scenario_map_args={"add_speech_reverberation_early": True}, + data_types=data_types, + **sms_wsj_kwargs, + ) super().__init__( - dataset_name, - json_path, num_outputs=num_outputs, - zip_cache=zip_cache, sms_wsj_base=sms_wsj_base, **kwargs, ) @@ -578,7 +551,7 @@ def __init__( def _pre_batch_transform(self, inputs: Dict[str, Any]) -> Dict[str, np.array]: """ Used to process raw SMS-WSJ data - :param inputs: input as coming from SMS-WSJ + :param Dict[str, Any] inputs: input as coming from SMS-WSJ """ return_dict = SmsWsjMixtureEarlyDataset._pre_batch_transform(inputs) for speaker, orth in enumerate(inputs["kaldi_transcription"]): From 3823e2b3279e33c534eb924e1d61758b90740449 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Mon, 13 Feb 2023 16:20:43 +0100 Subject: [PATCH 40/47] black --- common/datasets/sms_wsj/returnn_datasets.py | 27 ++++++++++++--------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index b27a0a7a6..8a6968970 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -52,12 +52,8 @@ class ZipAudioReader(AudioReader): """ Reads the audio data of an example from a zip file. 
""" - def __init__( - self, - zip_path=None, - zip_prefix="", - **kwargs - ): + + def __init__(self, zip_path=None, zip_prefix="", **kwargs): """ :param Optional[str] zip_path: zip archive with SMS-WSJ data :param str zip_prefix: prefix of filename that needs to be removed for the lookup in the zip archive @@ -84,7 +80,7 @@ def _rec_audio_read(self, file): else: if self._zip is not None: assert file.startswith(self._zip_prefix) - file_zip = file[len(self._zip_prefix):] + file_zip = file[len(self._zip_prefix) :] data, sample_rate = soundfile.read(io.BytesIO(self._zip.read(file_zip))) else: data, sample_rate = soundfile.read(file) @@ -139,7 +135,10 @@ def __init__( ), "cached and original file have the same path" json_path = json_path_cached audio_reader = ZipAudioReader( - zip_path=zip_cache_cached, zip_prefix=zip_prefix, keys=("original_source", "rir")) + zip_path=zip_cache_cached, + zip_prefix=zip_prefix, + keys=("original_source", "rir"), + ) else: audio_reader = AudioReader(keys=("original_source", "rir")) @@ -399,7 +398,9 @@ def __init__( :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset """ if sms_wsj_base is None: - assert sms_wsj_kwargs is not None, "either sms_wsj_base or sms_wsj_kwargs need to be given" + assert ( + sms_wsj_kwargs is not None + ), "either sms_wsj_base or sms_wsj_kwargs need to be given" sms_wsj_base = SmsWsjBase( pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, @@ -458,7 +459,9 @@ def __init__( "shape": (None, 2), }, } - assert sms_wsj_kwargs is not None, "either sms_wsj_base or sms_wsj_kwargs need to be given" + assert ( + sms_wsj_kwargs is not None + ), "either sms_wsj_base or sms_wsj_kwargs need to be given" sms_wsj_base = SmsWsjBaseWithHdfClasses( pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, @@ -526,7 +529,9 @@ def __init__( "shape": (None, 2), }, } - assert sms_wsj_kwargs is not None, "either sms_wsj_base or sms_wsj_kwargs need to be given" + assert ( + sms_wsj_kwargs is not None + ), "either sms_wsj_base or sms_wsj_kwargs need to be given" sms_wsj_base = SmsWsjBase( pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, From b8c3811361cb0a7ad89fa922cbc65218c30109fb Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Tue, 14 Feb 2023 10:06:13 +0100 Subject: [PATCH 41/47] fix buffer logic --- common/datasets/sms_wsj/returnn_datasets.py | 32 +++++++-------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 8a6968970..8efc27b33 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -116,7 +116,7 @@ def __init__( :param str zip_prefix: prefix of filename that needs to be removed for the lookup in the zip archive :param Optional[Dict] scenario_map_args: optional kwargs for sms_wsj scenario_map_fn :param bool buffer: if True, use SMS-WSJ dataset prefetching and store sequences in buffer - :param int buffer_size: buffer size + :param int buffer_size: buffer size, should always be larger than 2 * number of sequences in a batch :param int prefetch_num_workers: number of workers for prefetching """ @@ -226,33 +226,21 @@ def update_buffer(self, seq_idx: int): ) # add sequences - for idx in range(seq_idx, min(seq_idx + self._buffer.max_size // 2, len(self))): - if idx not in self._buffer: - 
self._buffer[idx] = next(self._ds_iterator) + for idx in range(seq_idx, seq_idx + self._buffer.max_size // 2): + buffer_idx = idx % len(self) + if buffer_idx not in self._buffer: + self._buffer[buffer_idx] = next(self._ds_iterator) if idx == len(self) - 1 and 0 not in self._buffer: print(f"Reached end of dataset, reset iterator", file=returnn_log.v4) - try: - next(self._ds_iterator) - except StopIteration: - pass - else: + rest = list(self._ds_iterator) + if len(rest) > 0: print( - "WARNING: reached final index of dataset, but iterator has more sequences. " - "Maybe the training was restarted from an epoch > 1?", + f"WARNING: reached final index of dataset, but iterator has {len(rest)} more sequences. " + f"Maybe the training was restarted from an epoch > 1?", file=returnn_log.v3, ) - print( - f"Current buffer indices: {self._buffer.keys()}", - file=returnn_log.v5, - ) self._ds_iterator = iter(self._ds) - for idx_ in range(min(self._buffer.max_size // 2, len(self))): - if idx_ not in self._buffer: - self._buffer[idx_] = next(self._ds_iterator) - print( - f"After adding start of dataset to buffer indices: {self._buffer.keys()}", - file=returnn_log.v5, - ) + self._buffer[0] = next(self._ds_iterator) class SmsWsjBaseWithHdfClasses(SmsWsjBase): From fca6ea30dac831f7c1fb38bcb52359e74602d3a1 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Tue, 14 Feb 2023 15:34:29 +0100 Subject: [PATCH 42/47] explicit num_outputs required --- common/datasets/sms_wsj/returnn_datasets.py | 64 +++------------------ 1 file changed, 8 insertions(+), 56 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 8efc27b33..dcc2e50ef 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -98,20 +98,19 @@ def __init__( dataset_name, json_path, pre_batch_transform, - data_types, + num_outputs, zip_cache=None, zip_prefix="", scenario_map_args=None, buffer=True, buffer_size=40, prefetch_num_workers=4, - **kwargs, ): """ :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" :param str json_path: path to SMS-WSJ json file :param function pre_batch_transform: function which processes raw SMS-WSJ data - :param Dict[str] data_types: data types for RETURNN, e.g. {"target_signals": {"dim": 2, "shape": (None, 2)}} + :param Dict[str] num_outputs: data types for RETURNN, e.g. {"data": {"dim": 1, "shape": (None, 1)}} :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir :param str zip_prefix: prefix of filename that needs to be removed for the lookup in the zip archive :param Optional[Dict] scenario_map_args: optional kwargs for sms_wsj scenario_map_fn @@ -120,9 +119,7 @@ def __init__( :param int prefetch_num_workers: number of workers for prefetching """ - super().__init__(**kwargs) - - self.data_types = data_types + super().__init__(data_types=num_outputs) if zip_cache is not None: zip_cache_cached = sp.check_output(["cf", zip_cache]).strip().decode("utf8") @@ -316,14 +313,14 @@ class SmsWsjWrapper(MapDatasetWrapper): Base class for datasets that can be used in RETURNN config. 
""" - def __init__(self, sms_wsj_base, **kwargs): + def __init__(self, sms_wsj_base, num_outputs, **kwargs): """ :param SmsWsjBase sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this """ if "seq_ordering" not in kwargs: print("Warning: no shuffling is enabled by default", file=returnn_log.v2) super().__init__(sms_wsj_base, **kwargs) - # self.num_outputs = ... # needs to be set in derived classes + self.num_outputs = num_outputs def _get_seq_length(seq_idx: int) -> NumbersDict: """ @@ -377,13 +374,11 @@ def __init__( self, sms_wsj_base=None, sms_wsj_kwargs=None, - num_outputs=None, **kwargs, ): """ :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this :param Optional[Dict[str, Any]] sms_wsj_args: kwargs to create SMS-WSJ base class if sms_wsj_base is not given - :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset """ if sms_wsj_base is None: assert ( @@ -392,12 +387,10 @@ def __init__( sms_wsj_base = SmsWsjBase( pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, - data_types={"target_signals": {"dim": 2, "shape": (None, 2)}}, + num_outputs=kwargs.get("num_outputs", None), **sms_wsj_kwargs, ) super().__init__(sms_wsj_base, **kwargs) - # typically data is raw waveform so 1-D and dense, target signals are 2-D (one for each speaker) and dense - self.num_outputs = num_outputs or {"data": [1, 2], "target_signals": [2, 2]} @staticmethod def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: @@ -428,49 +421,26 @@ def __init__( self, sms_wsj_base=None, sms_wsj_kwargs=None, - num_outputs=None, - classes_num_outputs=None, **kwargs, ): """ :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this :param Optional[Dict[str, Any]] sms_wsj_args: kwargs to create SMS-WSJ base class if sms_wsj_base is not given - :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset - :param Optional[int] classes_num_outputs: number of output labels for alignment, e.g. 
9001 for that CART size """ if sms_wsj_base is None: - data_types = { - "target_signals": {"dim": 2, "shape": (None, 2)}, - "target_classes": { - "sparse": True, - "dim": classes_num_outputs, - "shape": (None, 2), - }, - } assert ( sms_wsj_kwargs is not None ), "either sms_wsj_base or sms_wsj_kwargs need to be given" sms_wsj_base = SmsWsjBaseWithHdfClasses( pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, - data_types=data_types, + num_outputs=kwargs.get("num_outputs", None), **sms_wsj_kwargs, ) super().__init__( - num_outputs=num_outputs, sms_wsj_base=sms_wsj_base, **kwargs, ) - if num_outputs is not None: - self.num_outputs = num_outputs - else: - assert ( - classes_num_outputs is not None - ), "either num_outputs or classes_num_outputs has to be given" - self.num_outputs["target_classes"] = [ - classes_num_outputs, - 1, - ] # target alignments are sparse with the given dim @staticmethod def _pre_batch_transform(inputs: Dict[str, Any]) -> Dict[str, np.array]: @@ -495,7 +465,6 @@ def __init__( sms_wsj_base=None, sms_wsj_kwargs=None, text_proc=None, - num_outputs=None, **kwargs, ): """ @@ -503,43 +472,26 @@ def __init__( :param Optional[SmsWsjBase] sms_wsj_base: SMS-WSJ base class to allow inherited classes to modify this :param Optional[Dict[str, Any]] sms_wsj_args: kwargs to create SMS-WSJ base class if sms_wsj_base is not given :param Optional[Callable] text_proc: function to preprocess the transcriptions before applying BPE - :param Optional[Dict[str, List[int]]] num_outputs: num_outputs for RETURNN dataset """ from returnn.datasets.util.vocabulary import BytePairEncoding self.bpe = BytePairEncoding(**bpe) if sms_wsj_base is None: - data_types = { - "target_signals": {"dim": 2, "shape": (None, 2)}, - "target_bpe": { - "sparse": True, - "dim": self.bpe.num_labels, - "shape": (None, 2), - }, - } assert ( sms_wsj_kwargs is not None ), "either sms_wsj_base or sms_wsj_kwargs need to be given" sms_wsj_base = SmsWsjBase( pre_batch_transform=self._pre_batch_transform, scenario_map_args={"add_speech_reverberation_early": True}, - data_types=data_types, + num_outputs=kwargs.get("num_outputs", None), **sms_wsj_kwargs, ) super().__init__( - num_outputs=num_outputs, sms_wsj_base=sms_wsj_base, **kwargs, ) self.text_proc = text_proc or (lambda x: x) - if num_outputs is not None: - self.num_outputs = num_outputs - else: - self.num_outputs["target_bpe"] = [ - self.bpe.num_labels, - 1, - ] # target BPE labels are sparse with the given dim def _pre_batch_transform(self, inputs: Dict[str, Any]) -> Dict[str, np.array]: """ From 45fb9a0f4979defc608fd990070a9fdf6e3a57aa Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Tue, 14 Feb 2023 15:56:24 +0100 Subject: [PATCH 43/47] prefetch buffer size --- common/datasets/sms_wsj/returnn_datasets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index dcc2e50ef..ce22d1c37 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -105,6 +105,7 @@ def __init__( buffer=True, buffer_size=40, prefetch_num_workers=4, + prefetch_buffer_size=8, ): """ :param str dataset_name: "train_si284", "cv_dev93" or "test_eval92" @@ -117,6 +118,7 @@ def __init__( :param bool buffer: if True, use SMS-WSJ dataset prefetching and store sequences in buffer :param int buffer_size: buffer size, should always be larger than 2 * number of sequences in a batch :param int 
prefetch_num_workers: number of workers for prefetching + :param int prefetch_buffer_size: buffer size for prefetching """ super().__init__(data_types=num_outputs) @@ -158,7 +160,7 @@ def __init__( self._use_buffer = buffer if self._use_buffer: - self._ds = self._ds.prefetch(prefetch_num_workers, buffer_size).copy( + self._ds = self._ds.prefetch(prefetch_num_workers, prefetch_buffer_size).copy( freeze=True ) self._buffer = SequenceBuffer(buffer_size) From c2cb0a033305730a9d17bb5c9f6124f34725491e Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Tue, 14 Feb 2023 17:55:05 +0100 Subject: [PATCH 44/47] fix init seq order --- common/datasets/sms_wsj/returnn_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index ce22d1c37..6c0df8318 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -362,7 +362,7 @@ def init_seq_order(self, epoch: Optional[int] = None, **kwargs) -> bool: therefore the buffer does not contain the initial indices when continuing the training from an epoch > 0. """ out = super().init_seq_order(epoch=epoch, **kwargs) - buffer_index = ((epoch or 1) - 1) * self.num_seqs % len(self._dataset) + buffer_index = self.get_corpus_seq_idx(0) self._dataset.update_buffer(buffer_index) return out From a7fc5192f11fe96913f3146cad9a15b267f6929a Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Tue, 14 Feb 2023 17:55:17 +0100 Subject: [PATCH 45/47] allow shuffling --- common/datasets/sms_wsj/returnn_datasets.py | 28 ++++++++++++++------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 6c0df8318..12d0fa1ad 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -102,6 +102,7 @@ def __init__( zip_cache=None, zip_prefix="", scenario_map_args=None, + shuffle=False, buffer=True, buffer_size=40, prefetch_num_workers=4, @@ -115,6 +116,7 @@ def __init__( :param Optional[str] zip_cache: zip archive with SMS-WSJ data which can be cached, unzipped and used as data dir :param str zip_prefix: prefix of filename that needs to be removed for the lookup in the zip archive :param Optional[Dict] scenario_map_args: optional kwargs for sms_wsj scenario_map_fn + :param bool shuffle: shuffle data in SMS-WSJ dataset :param bool buffer: if True, use SMS-WSJ dataset prefetching and store sequences in buffer :param int buffer_size: buffer size, should always be larger than 2 * number of sequences in a batch :param int prefetch_num_workers: number of workers for prefetching @@ -154,17 +156,17 @@ def __init__( } ds = ds.map(functools.partial(scenario_map_fn, **scenario_map_args)) ds = ds.map(pre_batch_transform) - - self._ds = ds - self._ds_iterator = iter(self._ds) - + if shuffle: + ds = ds.shuffle(reshuffle=True) self._use_buffer = buffer if self._use_buffer: - self._ds = self._ds.prefetch(prefetch_num_workers, prefetch_buffer_size).copy( - freeze=True - ) + ds = ds.prefetch(prefetch_num_workers, prefetch_buffer_size) self._buffer = SequenceBuffer(buffer_size) + self._ds = ds + self._ds_copy = ds.copy(freeze=True) + self._ds_iterator = iter(self._ds_copy) + def __len__(self) -> int: return len(self._ds) @@ -181,7 +183,7 @@ def _get_seq_by_idx(self, seq_idx: int) -> Dict[str, np.array]: ), f"seq_idx {seq_idx} not in buffer. 
Available keys are {self._buffer.keys()}" return self._buffer[seq_idx] else: - return self._ds[seq_idx] + return self._ds_copy[seq_idx] def get_seq_tag(self, seq_idx: int) -> str: """ @@ -238,9 +240,16 @@ def update_buffer(self, seq_idx: int): f"Maybe the training was restarted from an epoch > 1?", file=returnn_log.v3, ) - self._ds_iterator = iter(self._ds) + self._ds_iterator = iter(self._ds.copy(freeze=True)) self._buffer[0] = next(self._ds_iterator) + def update_dataset_copy(self): + """ + Update the copy of the internal SMS-WSJ dataset. The copy is used because it can be indexed. It is updated in + order to obtain different shuffling for different epochs. + """ + self._ds_copy = self._ds.copy(freeze=True) + class SmsWsjBaseWithHdfClasses(SmsWsjBase): """ @@ -362,7 +371,7 @@ def init_seq_order(self, epoch: Optional[int] = None, **kwargs) -> bool: therefore the buffer does not contain the initial indices when continuing the training from an epoch > 0. """ out = super().init_seq_order(epoch=epoch, **kwargs) + self._dataset.update_dataset_copy() buffer_index = self.get_corpus_seq_idx(0) self._dataset.update_buffer(buffer_index) return out From ac508342a208c4a88c302a8fdb7fb87e0c8f4133 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Wed, 15 Feb 2023 12:35:30 +0100 Subject: [PATCH 46/47] try reading from zip 5 times --- common/datasets/sms_wsj/returnn_datasets.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 12d0fa1ad..9bce50938 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -81,7 +81,18 @@ def _rec_audio_read(self, file): if self._zip is not None: assert file.startswith(self._zip_prefix) file_zip = file[len(self._zip_prefix) :] - data, sample_rate = soundfile.read(io.BytesIO(self._zip.read(file_zip))) + data = None + for _ in range(5): + try: + data, sample_rate = soundfile.read(io.BytesIO(self._zip.read(file_zip))) + except Exception: + print( + f"data could not be read: {file_zip} from {self._zip.filename}, retry...", + file=returnn_log.v4, + ) + else: + break + assert data is not None, f"data could not be read: {file_zip} from {self._zip.filename}, abort now" else: data, sample_rate = soundfile.read(file) return data.T From 48ad8b5de6428ecc406633ac4df63f861ee9240b Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Wed, 15 Feb 2023 18:14:04 +0100 Subject: [PATCH 47/47] black --- common/datasets/sms_wsj/returnn_datasets.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/common/datasets/sms_wsj/returnn_datasets.py b/common/datasets/sms_wsj/returnn_datasets.py index 9bce50938..6c8c23972 100644 --- a/common/datasets/sms_wsj/returnn_datasets.py +++ b/common/datasets/sms_wsj/returnn_datasets.py @@ -84,7 +84,9 @@ def _rec_audio_read(self, file): data = None for _ in range(5): try: - data, sample_rate = soundfile.read(io.BytesIO(self._zip.read(file_zip))) + data, sample_rate = soundfile.read( + io.BytesIO(self._zip.read(file_zip)) + ) except Exception: print( f"data could not be read: {file_zip} from {self._zip.filename}, retry...", @@ -92,7 +94,9 @@ def _rec_audio_read(self, file): ) else: break - assert data is not None, f"data could not be read: {file_zip} from {self._zip.filename}, abort now" + assert ( + data is not None + ), f"data could not be read: {file_zip} from {self._zip.filename}, abort now"
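Taken together, the refactored classes bundle all SMS-WSJ specifics in sms_wsj_kwargs, while num_outputs and the sequence ordering stay at the RETURNN wrapper level. A minimal usage sketch, assuming a hypothetical import path and placeholder paths and dims (num_outputs follows the dict form documented in SmsWsjBase):

from common.datasets.sms_wsj.returnn_datasets import SmsWsjMixtureEarlyDataset  # assumed import path

train = SmsWsjMixtureEarlyDataset(
    # num_outputs doubles as the data_types of the underlying map dataset
    num_outputs={
        "data": {"dim": 1, "shape": (None, 1)},  # raw waveform input
        "target_signals": {"dim": 2, "shape": (None, 2)},  # one target signal per speaker
    },
    seq_ordering="random",  # avoids the "no shuffling is enabled by default" warning
    sms_wsj_kwargs={
        "dataset_name": "train_si284",
        "json_path": "/path/to/sms_wsj.json",  # placeholder path
        "shuffle": True,  # per-epoch reshuffling inside SMS-WSJ
        "buffer_size": 40,  # should stay above 2x the number of sequences per batch
    },
)

The same pattern applies to SmsWsjMixtureEarlyAlignmentDataset and SmsWsjMixtureEarlyBpeDataset, which only add their target-specific arguments (e.g. bpe) on top.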