diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 530b89a3e..8eaad6c4d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -396,7 +396,14 @@ jobs: done # Needed for some tests. - pip install --user --progress-bar=off transformers + # transformers 4.50 requires PyTorch >2.0, so stick to transformers 4.49 for now. + # (https://github.com/rwth-i6/returnn/issues/1706) + if [[ "${{matrix.python-version}}" == 3.8 ]]; then + # Need older version for Python 3.8. Install whatever is available. + pip install --user --progress-bar=off transformers + else + pip install --user --progress-bar=off transformers==4.49.0 + fi pip install --user --progress-bar=off espnet # TorchAudio needed by ESPnet. # https://pytorch.org/audio/stable/installation.html#compatibility-matrix diff --git a/.github/workflows/black.yml b/.github/workflows/ruff.yml similarity index 60% rename from .github/workflows/black.yml rename to .github/workflows/ruff.yml index c0e16a6c9..12ea8183b 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/ruff.yml @@ -1,4 +1,4 @@ -name: black +name: ruff on: push: @@ -11,7 +11,7 @@ on: - master jobs: - check-black-formatting: + check-ruff-formatting: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -19,7 +19,6 @@ jobs: with: python-version: 3.8 cache: 'pip' - cache-dependency-path: '.github/workflows/black.yml' - - run: pip install black==22.3.0 - - run: black --diff . - - run: black --check . + cache-dependency-path: '.github/workflows/ruff.yml' + - run: pip install ruff==0.11.8 + - run: ruff format --diff . diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f9936e237..70b2ff777 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -18,7 +18,7 @@ General rules when contributing to the code of RETURNN: Our code style uses most common Python conventions. If you are not an expert in Python, use PyCharm, and follow [our PyCharm configuration guide](https://github.com/rwth-i6/returnn/wiki/PyCharm-Configuration). - Apply [black](https://black.readthedocs.io/). + Apply [ruff](https://github.com/astral-sh/ruff). * Make sure all [tests](https://returnn.readthedocs.io/en/latest/advanced/test_suite.html) pass. * At the time being, we want to support earlier versions of TF 1 (consider at least TF 1.8, but maybe even TF 1.4) diff --git a/__init__.py b/__init__.py index 3acf98ab4..72c6d976e 100644 --- a/__init__.py +++ b/__init__.py @@ -7,7 +7,6 @@ We want to support the same code. """ - from __future__ import annotations import os import sys diff --git a/demos/demo-tf-search-compiled-graph.py b/demos/demo-tf-search-compiled-graph.py index 85e629bc9..0274c3c9a 100644 --- a/demos/demo-tf-search-compiled-graph.py +++ b/demos/demo-tf-search-compiled-graph.py @@ -8,12 +8,12 @@ # No RETURNN dependency needed for the basic search. Just TF itself. -import typing import os import json import argparse import tensorflow as tf import numpy +from typing import List, Optional, Tuple class Hyp: @@ -26,7 +26,7 @@ def __init__(self, idx): :param int idx: hyp idx (to identify it in a beam) """ self.idx = idx - self.source_idx = None # type: typing.Optional[int] # source hyp idx + self.source_idx: Optional[int] = None # source hyp idx self.score = 0.0 self.seq = [] # label seq @@ -91,7 +91,6 @@ def make_initial_feed_dict(): # Now loop over decoder steps. max_dec_len = 100 # TODO better default... depending on input len. or configurable... for i in range(max_dec_len): - # Loop over all stochastic variables. 
for stochastic_var in info["stochastic_var_order"]: assert isinstance(stochastic_var, str) @@ -108,9 +107,7 @@ def make_initial_feed_dict(): # TODO: length norm here? # Select new hypotheses. - best_possibilities = sorted(all_possibilities)[ - : args.beam_size - ] # type: typing.List[typing.Tuple[float,int,Hyp]] + best_possibilities: List[Tuple[float, int, Hyp]] = sorted(all_possibilities)[: args.beam_size] assert len(best_possibilities) == args.beam_size hyps = [ hyp.expand(idx=i, label=label, score=score) @@ -121,8 +118,9 @@ def make_initial_feed_dict(): session.run( info["state_vars"]["stochastic_var_scores_%s" % stochastic_var] + "/Assign...?", # TODO... feed_dict={ - info["state_vars"]["stochastic_var_scores_%s" % stochastic_var] - + "/Initial...?": [[hyp.seq[-1] for hyp in hyps]] # TODO... + info["state_vars"]["stochastic_var_scores_%s" % stochastic_var] + "/Initial...?": [ + [hyp.seq[-1] for hyp in hyps] + ] # TODO... }, ) diff --git a/docs/conf.py b/docs/conf.py index df8b1eb46..0fa3f688c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -77,7 +77,7 @@ ] # See https://github.com/rtfd/readthedocs.org/issues/283 -mathjax_path = "https://cdn.mathjax.org/mathjax/latest/MathJax.js?" "config=TeX-AMS-MML_HTMLorMML" +mathjax_path = "https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" # see https://stackoverflow.com/q/12206334/562769 numpydoc_show_class_members = False @@ -151,6 +151,7 @@ # If true, keep warnings as "system message" paragraphs in the built documents. # keep_warnings = False + # Resolve function for the linkcode extension. def linkcode_resolve(domain, info): def find_source(): diff --git a/pyproject.toml b/pyproject.toml index cea912c91..e652c5e52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,5 +13,9 @@ extend-exclude = ''' )/ ''' +[tool.ruff] +line-length = 120 +target-version = "py38" # https://github.com/rwth-i6/returnn/issues/1326 + [build-system] requires = ["setuptools", "numpy"] diff --git a/requirements-dev b/requirements-dev index fbbe5ed50..9c01631f7 100644 --- a/requirements-dev +++ b/requirements-dev @@ -1,2 +1,2 @@ -black==22.3.0 pytest +ruff==0.11.8 diff --git a/returnn/datasets/basic.py b/returnn/datasets/basic.py index 798ef00d6..15d05b631 100644 --- a/returnn/datasets/basic.py +++ b/returnn/datasets/basic.py @@ -20,7 +20,7 @@ import numpy import functools import typing -from typing import TYPE_CHECKING, Optional, Any, Union, Type, Dict, Sequence, List, Callable +from typing import TYPE_CHECKING, Optional, Any, Set, Tuple, Union, Type, Dict, Sequence, List, Callable from returnn.log import log from returnn.engine.batch import Batch, BatchSetGenerator @@ -141,12 +141,10 @@ def __init__( :param int _shard_index: local shard index, when sharding is enabled """ self.name = name or ("dataset_id%s" % id(self)) - self.lock = None # type: Optional[RLock] # Used when manipulating our data potentially from multiple threads. - self.rnd_seq_drop = None # type: typing.Optional[Random] + self.lock: Optional[RLock] = None # Used when manipulating our data potentially from multiple threads. + self.rnd_seq_drop: Optional[Random] = None self.num_inputs = 0 # usually not used, but num_outputs instead, which is more generic - self.num_outputs = ( - None - ) # type: typing.Optional[typing.Dict[str,typing.Tuple[int,int]]] # tuple is num-classes, len(shape). # nopep8 + self.num_outputs: Optional[Dict[str, Tuple[int, int]]] = None # tuple is num-classes, len(shape). 
self.window = window self.seq_ordering = seq_ordering # "default", "sorted" or "random". See self.get_seq_order_for_epoch(). self.fixed_random_seed = fixed_random_seed @@ -159,10 +157,10 @@ def __init__( self._seq_order_seq_lens_file = seq_order_seq_lens_file self._seq_order_seq_lens_by_idx = None # There is probably no use case for combining the two, so avoid potential misconfiguration. - assert ( - self.partition_epoch == 1 or self.repeat_epoch == 1 - ), "Combining partition_epoch and repeat_epoch is prohibited." - self.labels = {} # type: typing.Dict[str,typing.List[str]] + assert self.partition_epoch == 1 or self.repeat_epoch == 1, ( + "Combining partition_epoch and repeat_epoch is prohibited." + ) + self.labels: Dict[str, List[str]] = {} self.weights = {} self._num_timesteps = 0 self._num_seqs = 0 @@ -213,8 +211,8 @@ def __repr__(self): getattr(self, "epoch", ""), ) - _getnewargs_exclude_attrs = set() # type: typing.Set[str] - _getnewargs_remap = {} # type: typing.Dict[str,str] + _getnewargs_exclude_attrs: Set[str] = set() + _getnewargs_remap: Dict[str, str] = {} @staticmethod def _create_from_reduce(cls, kwargs, state) -> Dataset: @@ -660,12 +658,13 @@ def get_seq_order_for_epoch( ) old_seq_index = seq_index seq_index = [i for i in seq_index if all_seq_tags[i] in self.seq_tags_filter] - assert ( - seq_index - ), "%s: empty after applying seq_list_filter_file. Example filter tags: %r, used tags: %r" % ( - self, - sorted(self.seq_tags_filter)[:3], - [all_seq_tags[i] for i in old_seq_index[:3]], + assert seq_index, ( + "%s: empty after applying seq_list_filter_file. Example filter tags: %r, used tags: %r" + % ( + self, + sorted(self.seq_tags_filter)[:3], + [all_seq_tags[i] for i in old_seq_index[:3]], + ) ) return seq_index @@ -736,9 +735,9 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): """ self.epoch = epoch self.rnd_seq_drop = Random(self._get_random_seed_for_epoch(epoch=epoch)) - assert ( - self._num_shards == 1 or self.supports_sharding() - ), f"{self}: does not support sharding, but got num_shards == {self._num_shards}" + assert self._num_shards == 1 or self.supports_sharding(), ( + f"{self}: does not support sharding, but got num_shards == {self._num_shards}" + ) return False def finish_epoch(self, *, free_resources: bool = False): @@ -864,18 +863,16 @@ def _get_data_slice_sparse(self, seq_idx, key, start_frame, end_frame): data = self.get_data(seq_idx, key) return data[s0_start:s0_end] - def get_tag(self, sorted_seq_idx): + def get_tag(self, sorted_seq_idx: int) -> str: """ - :param int sorted_seq_idx: - :rtype: str + :param sorted_seq_idx: """ return "seq-%i" % sorted_seq_idx - def get_all_tags(self): + def get_all_tags(self) -> List[str]: """ :return: list of all seq tags, of the whole dataset, without partition epoch. Note that this is not possible with all datasets. 
- :rtype: list[str] """ raise OptionalNotImplementedError(f"{self} get_all_tags not implemented") @@ -972,16 +969,16 @@ def get_complete_frac(self, sorted_seq_idx: int, *, allow_only_lr_suitable: bool except Exception: # also not always available num_seqs = None # ignore - if math.isinf(num_seqs): + if num_seqs is not None and math.isinf(num_seqs): if allow_only_lr_suitable: # cannot compute meaningful complete_frac for infinite num_seqs return None else: num_seqs = None - assert ( - num_seqs is None or 0 <= sorted_seq_idx < num_seqs - ), f"{self}: invalid seq indices: 0 <= seq_idx ({sorted_seq_idx}) < num_seqs ({num_seqs}) violated" + assert num_seqs is None or 0 <= sorted_seq_idx < num_seqs, ( + f"{self}: invalid seq indices: 0 <= seq_idx ({sorted_seq_idx}) < num_seqs ({num_seqs}) violated" + ) return self.generic_complete_frac(sorted_seq_idx, num_seqs) @property @@ -1118,7 +1115,9 @@ def can_serialize_data(self, key: str) -> bool: def serialize_data(self, key: str, data: numpy.ndarray) -> str: """ - In case you have a :class:`Vocabulary`, just use :func:`Vocabulary.get_seq_labels`. + This is deprecated, as this is slow! + In case you have a :class:`Vocabulary`, just use :func:`Vocabulary.get_seq_labels` + or :func:`Vocabulary.serialize_labels`. :param key: e.g. "classes". self.labels[key] should be set :param numpy.ndarray data: 0D or 1D diff --git a/returnn/datasets/cached.py b/returnn/datasets/cached.py index f0955259e..11c9623be 100644 --- a/returnn/datasets/cached.py +++ b/returnn/datasets/cached.py @@ -46,9 +46,10 @@ def __init__(self, cache_byte_size=0, **kwargs): self._index_map = range(len(self._seq_index)) # sorted seq idx -> seq_index idx self._tag_idx = {} # type: typing.Dict[str,int] # map of tag -> real-seq-idx. call _update_tag_idx self.targets = {} - self.target_keys = ( - [] - ) # the keys for which we provide data; we may have labels for additional keys in self.labels + # the keys for which we provide data; + # we may have labels for additional keys in self.labels + self.target_keys = [] + self.timestamps = None def initialize(self): diff --git a/returnn/datasets/distrib_files.py b/returnn/datasets/distrib_files.py index 9f6c2d5d3..62cf5f1d9 100644 --- a/returnn/datasets/distrib_files.py +++ b/returnn/datasets/distrib_files.py @@ -13,6 +13,7 @@ from returnn.log import log from returnn.util import better_exchook from returnn.util.basic import override_env_var, try_run +from returnn.util.literal_py_to_pickle import literal_eval from returnn.util.multi_proc_non_daemonic_spawn import NonDaemonicSpawnContext from returnn.config import SubProcCopyGlobalConfigPreInitFunc from .basic import init_dataset, extend_dataset_dict_from_parent_dataset, DatasetSeq, RANDOM_SEED_OFFSET_ENV_VAR @@ -133,7 +134,7 @@ def get_sub_epoch_dataset(files_subepoch: List[Tuple[str, str]]) -> Dict[str, An def __init__( self, *, - files: List[FileTree], + files: Union[List[FileTree], os.PathLike], get_sub_epoch_dataset: Callable[[List[FileTree]], Dict[str, Any]], preload_next_n_sub_epochs: int = 1, buffer_size: int = 1, @@ -144,7 +145,11 @@ def __init__( ): """ :param files: the files to shuffle over, can also be a list of arbitrarily nested python objects - to keep associated heterogeneous data together + to keep associated heterogeneous data together. 
+ When the list grows too large to be serialized into a RETURNN config, the list of files + can also be specified as a path to a .txt file containing one file per line, + or a python file containing the repr of a list of arbitrarily nested python objects, + or a JSON file containing a list of arbitarily nested (JSON) objects. :param get_sub_epoch_dataset: callable which returns a dataset dict for a given subset of files :param preload_next_n_sub_epochs: how many sub epoch datasets to preload :param buffer_size: buffer size for each worker, amount of seqs to prefetch @@ -163,6 +168,7 @@ def __init__( self._data_keys: Optional[List[str]] = None self._num_seqs: Optional[int] = None + self._files: Optional[List[FileTree]] = None # files to use for this dataset self._workers: Dict[int, _WorkerProcParent] = {} # epoch -> worker self._files_order_cache: Dict[int, List[List[FileTree]]] = {} # full epoch (0-indexed) -> files order @@ -191,9 +197,7 @@ def __init__( self.labels = _meta_info_cache["labels"] self._data_keys = _meta_info_cache["data_keys"] self._file_sizes = _meta_info_cache["file_sizes"] - - if len(files) < self.partition_epoch: - raise ValueError(f"{self}: len(files) {len(files)} < partition_epoch {self.partition_epoch}") + self._files = _meta_info_cache["files"] def initialize(self): """init""" @@ -218,17 +222,53 @@ def _meta_info_cache(self): "labels": self.labels, "data_keys": self._data_keys, "file_sizes": self._file_sizes, + "files": self._files, } def _uses_custom_distributed_sharding(self) -> bool: return self._num_shards > 1 + def _lazy_init_file_list(self): + """ + The list of data files can either be provided as python list, or, if that grows + too large, as path to a file containing the list. + + This function initializes the list of files from whatever was given as input. + """ + if self._files is not None: + return + if isinstance(self.files, list): + self._files = self.files + elif isinstance(self.files, (str, os.PathLike)): + _, ext = os.path.splitext(self.files) + assert ext, f"{self}: no file extension on file list file {self.files}" + if ext == ".txt": + with open(self.files, "rt") as f: + stripped_lines = (line.strip() for line in f.readlines()) + self._files = [line for line in stripped_lines if line and not line.startswith("#")] + elif ext == ".json": + import json + + with open(self.files, "rt") as f: + self._files = json.load(f) + elif ext == ".py": + with open(self.files, "rb") as f: + self._files = literal_eval(f.read()) + else: + raise ValueError(f"{self}: type {ext} not supported as file list file") + assert isinstance(self._files, list) + else: + raise ValueError(f"{self}: unsupported file list ({type(self.files)}: {self.files})") + if len(self._files) < self.partition_epoch: + raise ValueError(f"{self}: len(files) {len(self._files)} < partition_epoch {self.partition_epoch}") + def _lazy_init_num_outputs(self): if self.num_outputs: return + self._lazy_init_file_list() # First, we need to know the num_inputs, num_outputs, total_num_seqs, labels. # Init the dataset with the first file. 
- dataset_dict = self._get_sub_dataset_dict(files=[self.files[0]]) + dataset_dict = self._get_sub_dataset_dict(files=[self._files[0]]) dataset = init_dataset(dataset_dict, extra_kwargs={"seq_ordering": "default"}, parent_dataset=self) self.num_inputs = dataset.num_inputs self.num_outputs = dataset.num_outputs @@ -240,8 +280,9 @@ def _lazy_init_file_sizes(self): if self._file_sizes: return + self._lazy_init_file_list() self._file_sizes = { - _get_key_for_file_tree(t): sum((os.path.getsize(fn) for fn in tree.flatten(t)), 0) for t in self.files + _get_key_for_file_tree(t): sum((os.path.getsize(fn) for fn in tree.flatten(t)), 0) for t in self._files } def __del__(self): @@ -266,6 +307,7 @@ def init_seq_order(self, epoch: Optional[int] = None, seq_list=None, seq_order=N self._num_seqs = 0 return True + self._lazy_init_file_list() self._lazy_init_file_sizes() full_epoch_0idx = (epoch - 1) // self.partition_epoch @@ -279,13 +321,13 @@ def init_seq_order(self, epoch: Optional[int] = None, seq_list=None, seq_order=N if full_epoch_0idx_ in self._files_order_cache: continue if self.seq_ordering == "default": - files_order_flat = self.files + files_order_flat = self._files elif self.seq_ordering == "random": # when sharding, _get_random_seed_for_epoch makes sure to use a fixed # random_seed_offset rnd_seed = self._get_random_seed_for_epoch(full_epoch_0idx_ * self.partition_epoch + 1) random_generator = numpy.random.RandomState(rnd_seed) - files_order_flat = list(self.files) + files_order_flat = list(self._files) random_generator.shuffle(files_order_flat) else: raise ValueError(f"{self}: seq_ordering {self.seq_ordering!r} not supported") @@ -409,8 +451,7 @@ def _distribute_evenly_by_size( # We need to decide where to add this file, to the current or the next sub epoch. if not files_per_bin[bin_idx] or ( # Better to add this file to the current sub epoch? - abs((size_taken + size) - avg_size_per_sub_epoch) - <= abs(size_taken - avg_size_per_sub_epoch) + abs((size_taken + size) - avg_size_per_sub_epoch) <= abs(size_taken - avg_size_per_sub_epoch) ): files_per_bin[bin_idx].append(f_tree) size_taken = 0 @@ -431,7 +472,8 @@ def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]: def have_seqs(self) -> bool: """have seqs""" - return bool(self.files) + self._lazy_init_file_list() + return bool(self._files) def finish_epoch(self, *, free_resources: bool = False): """finish epoch""" diff --git a/returnn/datasets/generating.py b/returnn/datasets/generating.py index 3d08804d6..96de40058 100644 --- a/returnn/datasets/generating.py +++ b/returnn/datasets/generating.py @@ -46,12 +46,12 @@ def __init__(self, input_dim, output_dim, num_seqs=float("inf"), **kwargs): output_dim["data"] = (input_dim * self.window, 2) # not sparse self.num_outputs = output_dim self.expected_load_seq_start = 0 - self._seq_order = None # type: Optional[Sequence[int]] + self._seq_order: Optional[Sequence[int]] = None self._num_seqs = num_seqs self._total_num_seqs = num_seqs self.random = numpy.random.RandomState(1) self.reached_final_seq = False - self.added_data = [] # type: typing.List[DatasetSeq] + self.added_data: List[DatasetSeq] = [] if self.seq_ordering in ("sorted", "sorted_reverse"): # For the dev/eval dataset, RETURNN automatically tries to sort them. # As this is not supported, just ignore it and reset it to the default order. 
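To illustrate the file-list handling added in _lazy_init_file_list of returnn/datasets/distrib_files.py above: besides an inline Python list, the files can now come from a .txt file (one entry per line, empty lines and lines starting with "#" are skipped), from a .json file containing a list, or from a .py file whose content is the repr of a list. A minimal sketch of the .txt variant, assuming the usual "DistributeFilesDataset" class name; the paths and the sub-dataset dict are made up for illustration:

    # files.txt (hypothetical), one file per line:
    #   /data/corpus/part-0001.hdf
    #   /data/corpus/part-0002.hdf
    train = {
        "class": "DistributeFilesDataset",
        "files": "/data/corpus/files.txt",  # instead of a (too large) inline list
        "get_sub_epoch_dataset": lambda files_subepoch: {"class": "HDFDataset", "files": files_subepoch},
        "partition_epoch": 2,  # must not exceed the number of files
    }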
@@ -792,10 +792,10 @@ def generate_seq(self, seq_idx): i1 = seq_idx i2 = i1 + seq_len * self.num_inputs features = numpy.array( - [((i % self.input_max_value) + self.input_shift) * self.input_scale for i in range(i1, i2)] + [((i % self.input_max_value) + self.input_shift) * self.input_scale for i in range(i1, i2)], dtype="float32" ).reshape((seq_len, self.num_inputs)) i1, i2 = i2, i2 + seq_len - targets = numpy.array([i % self.num_outputs["classes"][0] for i in range(i1, i2)]) + targets = numpy.array([i % self.num_outputs["classes"][0] for i in range(i1, i2)], dtype="int32") return DatasetSeq(seq_idx=seq_idx, features=features, targets=targets) @@ -904,22 +904,24 @@ def __init__( seq_len = {} for key in self.data_keys: seq_len[key] = _seq_len - assert set(data_keys) == set( - seq_len.keys() - ), "%s: the keys of seq_len (%s) must match the keys in data_keys=%s." % ( - self, - str(seq_len.keys()), - str(data_keys), + assert set(data_keys) == set(seq_len.keys()), ( + "%s: the keys of seq_len (%s) must match the keys in data_keys=%s." + % ( + self, + str(seq_len.keys()), + str(data_keys), + ) ) - assert isinstance( - output_dim, dict - ), "%s: output_dim %r must be a dict containing a definition for each key in data_keys." % (self, output_dim) - assert set(data_keys) == set( - output_dim.keys() - ), "%s: the keys of output_dim (%s) must match the keys in data_keys=%s." % ( - self, - str(output_dim.keys()), - str(data_keys), + assert isinstance(output_dim, dict), ( + "%s: output_dim %r must be a dict containing a definition for each key in data_keys." % (self, output_dim) + ) + assert set(data_keys) == set(output_dim.keys()), ( + "%s: the keys of output_dim (%s) must match the keys in data_keys=%s." + % ( + self, + str(output_dim.keys()), + str(data_keys), + ) ) super(DummyDatasetMultipleDataKeys, self).__init__( @@ -2037,7 +2039,7 @@ def __init__( """ :param str path: dir, should contain "train-*/*/*/{*.flac,*.trans.txt}", or "train-*.zip" :param str prefix: "train", "dev", "test", "dev-clean", "dev-other", ... 
- :param str|list[str]|None orth_post_process: :func:`get_post_processor_function`, applied on orth + :param str|list[str]|function|None orth_post_process: :func:`get_post_processor_function`, applied on orth :param str|dict[str]|None targets: "bpe" or "chars" or None or dict for :func:`Vocabulary.create_vocab` :param dict[str]|None audio: options for :class:`ExtractAudioFeatures` :param dict[str]|None bpe: options for :class:`BytePairEncoding` @@ -2134,9 +2136,7 @@ def _collect_trans(self): import os import zipfile - transs = ( - {} - ) # type: typing.Dict[typing.Tuple[str,int,int,int],str] # (subdir, speaker-id, chapter-id, seq-id) -> transcription # nopep8 + transs: Dict[Tuple[str, int, int, int], str] = {} # (subdir, speaker-id, chapter-id, seq-id) -> transcription if self.use_zip: for name, zip_file in self._zip_files.items(): assert isinstance(zip_file, zipfile.ZipFile) diff --git a/returnn/datasets/hdf.py b/returnn/datasets/hdf.py index ad13b4979..f0130d8e5 100644 --- a/returnn/datasets/hdf.py +++ b/returnn/datasets/hdf.py @@ -37,9 +37,9 @@ def __init__(self, files=None, use_cache_manager=False, **kwargs): :param bool use_cache_manager: uses :func:`Util.cf` for files """ super(HDFDataset, self).__init__(**kwargs) - assert ( - self.partition_epoch == 1 or self.cache_byte_size_total_limit == 0 - ), "To use partition_epoch in HDFDatasets, disable caching by setting cache_byte_size=0" + assert self.partition_epoch == 1 or self.cache_byte_size_total_limit == 0, ( + "To use partition_epoch in HDFDatasets, disable caching by setting cache_byte_size=0" + ) self._use_cache_manager = use_cache_manager self.files = [] # type: typing.List[str] # file names self.h5_files = [] # type: typing.List[h5py.File] @@ -1073,6 +1073,8 @@ class SimpleHDFWriter: which can be read later by :class:`HDFDataset`. Note that we dump to a temp file first, and only at :func:`close` we move it over to the real destination. + + Can be used as a context manager, i.e. with the `with` statement. 
""" def __init__( @@ -1201,6 +1203,7 @@ def _prepare_extra(self, extra_type, extra_labels): shape = [None] * ndim # type: typing.List[typing.Optional[int]] if ndim >= 2: shape[-1] = dim + assert all(shape[1:]), f"{self} extra {data_key!r} supports only dyn dim in first axis, got shape {shape!r}" if dtype == "string": # noinspection PyUnresolvedReferences dtype = h5py.special_dtype(vlen=str) @@ -1237,10 +1240,15 @@ def _insert_h5_inputs(self, raw_data): self._datasets[name] = self._file.create_dataset( name, raw_data.shape, raw_data.dtype, maxshape=tuple(None for _ in raw_data.shape) ) + expected_shape = (raw_data.shape[0],) + self._datasets[name].shape[1:] else: old_shape = self._datasets[name].shape self._datasets[name].resize(old_shape[0] + raw_data.shape[0], axis=0) + expected_shape = (raw_data.shape[0],) + old_shape[1:] # append raw data to dataset + assert expected_shape == raw_data.shape, ( + f"{self} insert: shape mismatch: expected {expected_shape}, got {raw_data.shape}" + ) self._datasets[name][self._file.attrs["numTimesteps"] :] = raw_data self._file.attrs["numTimesteps"] += raw_data.shape[0] self._file.attrs["numSeqs"] += 1 @@ -1286,13 +1294,17 @@ def _insert_h5_other(self, data_key, raw_data, dtype=None, add_time_dim=False, d self._seq_lengths[seq_idx, data_key_idx_0 + 1] = self._extra_num_time_steps[data_key_] self._extra_num_time_steps[data_key] += raw_data.shape[0] - self._datasets[data_key].resize(self._extra_num_time_steps[data_key], axis=0) + hdf_data = self._datasets[data_key] + hdf_data.resize(self._extra_num_time_steps[data_key], axis=0) data_key_idx = sorted(self._prepared_extra).index(data_key) + 1 self._seq_lengths[seq_idx, data_key_idx] = raw_data.shape[0] offset = self._extra_num_time_steps[data_key] - raw_data.shape[0] - hdf_data = self._datasets[data_key] + expected_shape = (raw_data.shape[0],) + hdf_data.shape[1:] + assert expected_shape == raw_data.shape, ( + f"{self} insert other {data_key!r}: shape mismatch: expected {expected_shape}, got {raw_data.shape}" + ) hdf_data[offset:] = raw_data def insert_batch(self, inputs, seq_len, seq_tag, extra=None): @@ -1403,6 +1415,12 @@ def close(self): os.remove(self.tmp_filename) self.tmp_filename = None + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + class HDFDatasetWriter: """ diff --git a/returnn/datasets/lm.py b/returnn/datasets/lm.py index 74f8d1c7f..0178c237f 100644 --- a/returnn/datasets/lm.py +++ b/returnn/datasets/lm.py @@ -7,7 +7,22 @@ from __future__ import annotations -from typing import Optional, Union, Any, Callable, Iterator, List, Tuple, Set, BinaryIO, Dict, cast, Generator +from typing import ( + Iterable, + Optional, + Sequence, + Union, + Any, + Callable, + Iterator, + List, + Tuple, + Set, + BinaryIO, + Dict, + cast, + Generator, +) import typing import os from io import IOBase @@ -55,6 +70,7 @@ def __init__( orth_symbols_file=None, orth_symbols_map_file=None, orth_replace_map_file=None, + orth_post_process=None, word_based=False, word_end_symbol=None, seq_end_symbol="[END]", @@ -101,6 +117,7 @@ def __init__( a python dict with {"": , ...} or a pickled dictionary :param str|()->str|None orth_replace_map_file: JSON file with replacement dict for orth symbols. + :param str|list[str]|function|None orth_post_process: :func:`get_post_processor_function`, applied on orth :param bool word_based: whether to parse single words, or otherwise will be character based. 
:param str|None word_end_symbol: If provided and if word_based is False (character based modeling), token to be used to represent word ends. @@ -247,6 +264,10 @@ def __init__( else: assert not orth_replace_map_file + self.orth_post_process = None + if orth_post_process: + self.orth_post_process = get_post_processor_function(orth_post_process) + num_labels = len(self.labels["data"]) if dtype: self.dtype = dtype @@ -578,6 +599,9 @@ def _collect_single_seq(self, seq_idx): seq_tag = self._seq_list[true_idx] self.next_orth_idx += 1 + if self.orth_post_process: + orth = self.orth_post_process(orth) + if self.orth_vocab is not None: data = numpy.array(self.orth_vocab.get_seq(orth), dtype=self.dtype) @@ -1463,8 +1487,8 @@ def __init__( } self._data_keys = self._source_data_keys + self._target_data_keys - self._data = {data_key: [] for data_key in self._data_keys} # type: typing.Dict[str,typing.List[numpy.ndarray]] - self._data_len = None # type: typing.Optional[int] + self._data: Dict[str, List[numpy.ndarray]] = {data_key: [] for data_key in self._data_keys} + self._data_len: Optional[int] = None self._vocabs = self._get_vocabs() self.num_outputs = {k: [max(self._vocabs[k].values()) + 1, 1] for k in self._vocabs.keys()} # all sparse @@ -1480,7 +1504,7 @@ def __init__( unknown_label.setdefault(data_key, None) self._unknown_label = unknown_label - self._seq_order = None # type: typing.Optional[typing.Sequence[int]] # seq_idx -> line_nr + self._seq_order: Optional[Sequence[int]] = None # seq_idx -> line_nr self._tag_prefix = "line-" # sequence tag is "line-n", where n is the line number self._thread = Thread(name="%r reader" % self, target=self._thread_main) self._thread.daemon = True @@ -1869,14 +1893,11 @@ def _extend_data(self, file_prefix, data_strs): assert file_prefix == self.target_file_prefix data_keys = self._target_data_keys - data = [ + data: List[List[numpy.ndarray]] = [ self._factored_words_to_numpy(data_keys, s.decode("utf8").strip().split(), self._add_postfix[file_prefix]) for s in data_strs - ] # type: typing.List[typing.List[numpy.ndarray]] # shape: (len(data_strs), len(data_keys)) - - data = zip( - *data - ) # type: typing.Iterable[typing.Tuple[numpy.ndarray]] # shape: (len(data_keys), len(data_strs)) + ] # shape: (len(data_strs), len(data_keys)) + data: Iterable[Tuple[numpy.ndarray]] = zip(*data) # shape: (len(data_keys), len(data_strs)) with self._lock: for i, data_ in enumerate(data): @@ -1899,9 +1920,9 @@ def _factored_words_to_numpy(self, data_keys, words, postfix): words_per_factor = [[]] * len(data_keys) elif len(data_keys) > 1: factored_words = [word.split(self._factor_separator) for word in words] - assert all( - len(factors) == len(data_keys) for factors in factored_words - ), "All words must have all factors. Expected: " + self._factor_separator.join(data_keys) + assert all(len(factors) == len(data_keys) for factors in factored_words), ( + "All words must have all factors. Expected: " + self._factor_separator.join(data_keys) + ) words_per_factor = zip(*factored_words) words_per_factor = [list(w) for w in words_per_factor] else: @@ -2421,7 +2442,7 @@ def get_post_processor_function(opts): for some normalization / cleanup. This function can be used to get such functions. - :param str|list[str] opts: e.g. "english_cleaners", or "get_remove_chars(',/')" + :param str|list[str]|function opts: e.g. 
"english_cleaners", or "get_remove_chars(',/')" :return: function :rtype: (str)->str """ diff --git a/returnn/datasets/meta.py b/returnn/datasets/meta.py index 0fe21905d..34c6c4ea8 100644 --- a/returnn/datasets/meta.py +++ b/returnn/datasets/meta.py @@ -247,10 +247,10 @@ def __init__( self.seq_order_control_dataset = seq_order_control_dataset # This will only initialize datasets needed for features occuring in data_map - self.datasets = { + self.datasets: Dict[str, Dataset] = { key: init_dataset(datasets[key], extra_kwargs={"name": "%s_%s" % (self.name, key)}, parent_dataset=self) for key in self.dataset_keys - } # type: typing.Dict[str,Dataset] + } self._seq_list_file = seq_list_file self.seq_list_original = self._load_seq_list(seq_list_file) @@ -260,8 +260,8 @@ def __init__( self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original[self.default_dataset_key])} - self._seq_lens = None # type: typing.Optional[typing.Dict[str,NumbersDict]] - self._num_timesteps = None # type: typing.Optional[NumbersDict] + self._seq_lens: Optional[Dict[str, NumbersDict]] = None + self._num_timesteps: Optional[NumbersDict] = None self._seq_lens_file = seq_lens_file if seq_lens_file: seq_lens = load_json(filename=seq_lens_file) @@ -290,7 +290,7 @@ def __init__( self.num_outputs = self.data_dims self.orig_seq_order_is_initialized = False - self.seq_list_ordered = None # type: typing.Optional[typing.Dict[str,typing.List[str]]] + self.seq_list_ordered: Optional[Dict[str, List[str]]] = None def _load_seq_list(self, seq_list_file: Optional[Union[str, Dict[str, str]]] = None) -> Dict[str, List[str]]: """ @@ -312,10 +312,7 @@ def _load_seq_list(self, seq_list_file: Optional[Union[str, Dict[str, str]]] = N try: seq_list = default_dataset.get_all_tags() except NotImplementedError: - raise NotImplementedError( - "Unsupported %s used as default in MetaDataset." - " Only datasets with known and tagged sequences can be used." % type(default_dataset) - ) + raise NotImplementedError(f"{default_dataset}.get_all_tags() required by {self}, but not implemented.") # Catch index out of bounds errors. # Whether the tags are actually valid will be checked in _check_dataset_seq(). @@ -774,7 +771,7 @@ def __init__(self, datasets: Sequence[Dict[str, Any]], **kwargs): for ds in self.datasets[1:]: assert ds.num_inputs == self.num_inputs assert ds.num_outputs == self.num_outputs - self.dataset_seq_idx_offsets = None # type: typing.Optional[typing.List[int]] + self.dataset_seq_idx_offsets: Optional[List[int]] = None def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): """ @@ -1020,9 +1017,9 @@ def __init__( for (dset_key, dset_data_key), data_key in data_map.items() } - self.dataset_seq_idx_boundaries = None # type: typing.Optional[typing.List[int]] - self.dataset_sorted_seq_idx_list = None # type: typing.Optional[typing.List[typing.Tuple[int,int]]] - self.used_num_seqs_per_subset = None # type: typing.Optional[typing.List[int]] + self.dataset_seq_idx_boundaries: Optional[List[int]] = None + self.dataset_sorted_seq_idx_list: Optional[List[Tuple[int, int]]] = None + self.used_num_seqs_per_subset: Optional[List[int]] = None def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): """ @@ -1183,9 +1180,9 @@ def _get_sampling_seq_order(self): :rtype: list[int] """ assert self.partition_epoch in [None, 1], "partition_epoch not supported in combination with sampling_sizes." 
- assert ( - self._seq_order_seq_lens_file is None - ), "seq_order_seq_lens_file not supported in combination with sampling_sizes." + assert self._seq_order_seq_lens_file is None, ( + "seq_order_seq_lens_file not supported in combination with sampling_sizes." + ) assert not self.unique_seq_tags, "unique_seq_tags not supported in combination with sampling_sizes." assert self.seq_tags_filter is None, "seq_order_seq_lens_file in combination with sampling_sizes." @@ -1448,7 +1445,7 @@ def __init__( self.repeat_in_between_last_frame_up_to_multiple_of = repeat_in_between_last_frame_up_to_multiple_of or {} self.pad_narrow_data_to_multiple_of_target_len = pad_narrow_data_to_multiple_of_target_len or {} if epoch_wise_filter is None: - self.epoch_wise_filter = None # type: Optional[EpochWiseFilter] + self.epoch_wise_filter: Optional[EpochWiseFilter] = None elif isinstance(epoch_wise_filter, dict): self.epoch_wise_filter = EpochWiseFilter(epoch_wise_filter) else: @@ -1474,10 +1471,8 @@ def __init__( self.seq_lens = eval(open(seq_len_file).read()) assert isinstance(self.seq_lens, dict) self.full_seq_len_list = self._get_full_seq_lens_list() - self.cur_seq_list = None # type: typing.Optional[typing.List[str]] # list of seq tags - self.cur_sub_seq_idxs = ( - None - ) # type: typing.Optional[typing.List[typing.List[int]]] # list of list of sub seq idxs + self.cur_seq_list: typing.Optional[typing.List[str]] = None # list of seq tags + self.cur_sub_seq_idxs: typing.Optional[typing.List[typing.List[int]]] = None # list of list of sub seq idxs def _get_full_seq_lens_list(self): """ @@ -1567,20 +1562,22 @@ def _collect_single_seq(self, seq_idx): if seq_idx == 0: # some extra check, but enough to do for first seq only sub_dataset_keys = self.dataset.get_data_keys() for key in self.remove_in_between_postfix: - assert ( - key in sub_dataset_keys - ), "%s: remove_in_between_postfix key %r not in sub dataset data-keys %r" % ( - self, - key, - sub_dataset_keys, + assert key in sub_dataset_keys, ( + "%s: remove_in_between_postfix key %r not in sub dataset data-keys %r" + % ( + self, + key, + sub_dataset_keys, + ) ) for key in self.repeat_in_between_last_frame_up_to_multiple_of: - assert ( - key in sub_dataset_keys - ), "%s: repeat_in_between_last_frame_up_to_multiple_of key %r not in sub dataset data-keys %r" % ( - self, - key, - sub_dataset_keys, + assert key in sub_dataset_keys, ( + "%s: repeat_in_between_last_frame_up_to_multiple_of key %r not in sub dataset data-keys %r" + % ( + self, + key, + sub_dataset_keys, + ) ) for key in self.pad_narrow_data_to_multiple_of_target_len: assert key in sub_dataset_keys, ( @@ -1590,15 +1587,16 @@ def _collect_single_seq(self, seq_idx): for sub_seq_idx, sub_seq_tag in zip(sub_seq_idxs, sub_seq_tags): self.dataset.load_seqs(sub_seq_idx, sub_seq_idx + 1) sub_dataset_tag = self.dataset.get_tag(sub_seq_idx) - assert ( - sub_dataset_tag == sub_seq_tag - ), "%s: expected tag %r for sub seq idx %i but got %r, part of seq %i %r" % ( - self, - sub_seq_tag, - sub_seq_idx, - sub_dataset_tag, - seq_idx, - seq_tag, + assert sub_dataset_tag == sub_seq_tag, ( + "%s: expected tag %r for sub seq idx %i but got %r, part of seq %i %r" + % ( + self, + sub_seq_tag, + sub_seq_idx, + sub_dataset_tag, + seq_idx, + seq_tag, + ) ) for key in self.get_data_keys(): data = self.dataset.get_data(sub_seq_idx, key) @@ -1858,12 +1856,14 @@ class VariableDataset(Dataset): based on a user-provided function. 
""" - def __init__(self, *, get_dataset, dataset_lru_cache_size: int = 1, **kwargs): + def __init__(self, *, get_dataset, dataset_lru_cache_size: int = 1, always_same_tags: bool = False, **kwargs): """ :param get_dataset: function (*, epoch: int, **_) -> Dict[str,Any], will be called for every sub-epoch. It will cache the dataset(s) from the prev call (dataset_lru_cache_size), and if the dict is the same of those, it will not recreate the dataset. - :param dataset_lru_cache_size + :param dataset_lru_cache_size: + :param always_same_tags: whether all the datasets returned by ``get_dataset`` will have the same tags + (same :func:`get_all_tags`). """ from functools import lru_cache @@ -1872,6 +1872,7 @@ def __init__(self, *, get_dataset, dataset_lru_cache_size: int = 1, **kwargs): self._dataset_dict: Optional[Dict[str, Any]] = None self._dataset: Optional[Dataset] = None self._dataset_lru_cache_size = dataset_lru_cache_size + self._always_same_tags = always_same_tags self._make_dataset = lru_cache(maxsize=self._dataset_lru_cache_size)( lambda dataset_dict: init_dataset(dataset_dict, parent_dataset=self) ) @@ -1979,6 +1980,12 @@ def is_data_sparse(self, key: str) -> bool: """is data sparse""" return self._dataset.is_data_sparse(key) + def get_all_tags(self) -> List[str]: + """all tags""" + if self._always_same_tags: + return self._dataset.get_all_tags() + raise OptionalNotImplementedError(f"{self}.get_all_tags(): always_same_tags=False, thus could be inconsistent") + class MultiEpochDataset(CachedDataset2): """ diff --git a/returnn/datasets/normalization_data.py b/returnn/datasets/normalization_data.py index 58b017ec7..43bb4b452 100644 --- a/returnn/datasets/normalization_data.py +++ b/returnn/datasets/normalization_data.py @@ -169,7 +169,7 @@ def _updateTotalSum(totalSum, intermediateSum): sumErr = np.sum(np.abs(newSum - oldSum - intermediateSum)) if sumErr > NormalizationData.SUMMATION_PRECISION: raise FloatingPointError( - "sums have very different orders of magnitude." " summation error = {}".format(sumErr) + "sums have very different orders of magnitude. 
summation error = {}".format(sumErr) ) return newSum diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index e477d7abe..abb43b095 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -141,13 +141,14 @@ def __init__( self._map_seq_stream = map_seq_stream if map_seq_stream_preserves_num_seqs is None and map_seq_stream is not None: map_seq_stream_preserves_num_seqs = getattr(map_seq_stream, "preserves_num_seqs", None) - self._map_seq_stream_preserves_num_seqs = map_seq_stream_preserves_num_seqs or False + assert map_seq_stream_preserves_num_seqs is None or isinstance(map_seq_stream_preserves_num_seqs, bool) + self._map_seq_stream_preserves_num_seqs = map_seq_stream_preserves_num_seqs self._map_outputs = map_outputs self._rng = RandomState(self._get_random_seed_for_epoch(0)) self._seq_list_for_validation: Optional[List[str]] = None self._dataset = init_dataset(self._dataset_def, parent_dataset=self) - if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs: + if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs is True: # if the stream mapper is set, the num_seqs may change and the estimation is less accurate self._estimated_num_seqs = self._dataset.estimated_num_seqs self._data_iter: Optional[Iterator[Tuple[int, TensorDict]]] = None @@ -210,7 +211,7 @@ def init_seq_order( self._data_iter = enumerate(self._build_mapping_iter()) self._data_iter_produced_num_seqs = 0 self._seq_list_for_validation = seq_list - if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs: + if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs is True: # If we don't have an iterable mapper (or the user explicitly specifies this), # we know the number of segments exactly equals the number of segments in the wrapped dataset try: @@ -243,6 +244,13 @@ def get_total_num_seqs(self, *, fast=False): assert self._dataset is not None return self._dataset.get_total_num_seqs(fast=fast) + def get_all_tags(self) -> List[str]: + """:return: all tags""" + if self._map_seq_stream is not None: + raise util.OptionalNotImplementedError(f"{self}: get_all_tags not allowed when map_seq_stream is set.") + assert self._dataset is not None + return self._dataset.get_all_tags() + def supports_sharding(self) -> bool: """:return: whether this dataset supports sharding""" assert self._dataset is not None @@ -310,9 +318,9 @@ def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDi data_iter = self._iterate_dataset() if self._map_seq_stream is not None: data_iter = self._map_seq_stream(data_iter, epoch=self.epoch, rng=self._rng, **util.get_fwd_compat_kwargs()) - assert isinstance( - data_iter, Iterator - ), f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}" + assert isinstance(data_iter, Iterator), ( + f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}" + ) return _validate_tensor_dict_iter(data_iter) def _iterate_dataset(self) -> Iterator[TensorDict]: @@ -341,9 +349,9 @@ def _iterate_dataset(self) -> Iterator[TensorDict]: tensor_dict = self._map_seq( tensor_dict, epoch=self.epoch, seq_idx=seq_index, rng=self._rng, **util.get_fwd_compat_kwargs() ) - assert isinstance( - tensor_dict, TensorDict - ), f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}" + assert isinstance(tensor_dict, TensorDict), ( + f"map_seq must produce a {TensorDict.__name__}, but 
produced {type(tensor_dict).__name__}" + ) # Re-adding the seq_tag/complete_frac here causes no harm in case they are dropped # since we don't add/drop any segments w/ the non-iterator postprocessing function. @@ -359,9 +367,9 @@ def _iterate_dataset(self) -> Iterator[TensorDict]: if self._seq_list_for_validation is not None: seq_tag = self._seq_list_for_validation[seq_index] tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item() - assert ( - tag_of_seq == seq_tag - ), f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given" + assert tag_of_seq == seq_tag, ( + f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given" + ) yield tensor_dict seq_index += 1 diff --git a/returnn/datasets/sprint.py b/returnn/datasets/sprint.py index 346a800e0..359453517 100644 --- a/returnn/datasets/sprint.py +++ b/returnn/datasets/sprint.py @@ -393,13 +393,14 @@ def add_new_data(self, features, targets=None, segment_name=None): targets = {"classes": targets} if "classes" in targets: # 'classes' is always the alignment - assert targets["classes"].shape == ( - reduce_num_frames, - ), "Number of targets %s does not match number of features %s (reduce factor %d)" % ( - # is in format (time,) - targets["classes"].shape, - (num_frames,), - self.reduce_target_factor, + assert targets["classes"].shape == (reduce_num_frames,), ( + "Number of targets %s does not match number of features %s (reduce factor %d)" + % ( + # is in format (time,) + targets["classes"].shape, + (num_frames,), + self.reduce_target_factor, + ) ) if "speaker_name" in targets: targets["speaker_name"] = targets["speaker_name"].strip() diff --git a/returnn/datasets/util/strings.py b/returnn/datasets/util/strings.py index 45896d958..bb1a52c93 100644 --- a/returnn/datasets/util/strings.py +++ b/returnn/datasets/util/strings.py @@ -2,7 +2,6 @@ Operations on strings. """ - from __future__ import annotations import numpy diff --git a/returnn/datasets/util/vocabulary.py b/returnn/datasets/util/vocabulary.py index 18d336446..8ed21c111 100644 --- a/returnn/datasets/util/vocabulary.py +++ b/returnn/datasets/util/vocabulary.py @@ -17,7 +17,6 @@ import sys import numpy -from returnn.log import log from returnn.util.basic import NotSpecified @@ -157,7 +156,7 @@ def set_random_seed(self, seed: int): def _parse_vocab(self): """ - Sets self.vocab, self.labels, self.num_labels. + Sets self._vocab, self._labels, self.num_labels. """ filename = self.vocab_file if self._labels is not None: @@ -167,34 +166,41 @@ def _parse_vocab(self): self._vocab, self._labels = self._cache[filename] self.num_labels = len(self._labels) else: + labels_from_idx = None if filename.endswith(".pkl"): import pickle - d = pickle.load(open(filename, "rb")) + labels_to_idx = pickle.load(open(filename, "rb")) else: if filename.endswith(".gz"): import gzip - file_content = gzip.open(filename, "rt").read() + file_content = gzip.open(filename, "rt", encoding="utf8").read() else: - file_content = open(filename, "r").read() + file_content = open(filename, "r", encoding="utf8").read() if file_content.startswith("{"): - d = eval(file_content) + labels_to_idx = eval(file_content) else: # Do line-based parsing. 
- lines = file_content.splitlines() - d = {line: i for (i, line) in enumerate(lines)} - assert isinstance(d, dict), f"{self}: expected dict, got {type(d).__name__} in {filename}" - labels = {idx: label for (label, idx) in sorted(d.items())} - min_label, max_label, num_labels = min(labels), max(labels), len(labels) - assert 0 == min_label - if num_labels - 1 < max_label: - print("Vocab error: not all indices used? max label: %i" % max_label, file=log.v1) - print("unused labels: %r" % ([i for i in range(max_label + 1) if i not in labels],), file=log.v2) - assert num_labels - 1 == max_label - self.num_labels = len(labels) - self._vocab = d - self._labels = [label for (idx, label) in sorted(labels.items())] + labels = file_content.splitlines() + labels_from_idx = {i: line for (i, line) in enumerate(labels)} + labels_to_idx = {line: i for (i, line) in enumerate(labels)} + assert isinstance(labels_to_idx, dict), ( + f"{self}: expected dict, got {type(labels_to_idx).__name__} in {filename}" + ) + if labels_from_idx is None: + labels_from_idx = {idx: label for (label, idx) in sorted(labels_to_idx.items())} + min_label, max_label, num_labels = min(labels_from_idx), max(labels_from_idx), len(labels_from_idx) + if 0 != min_label or num_labels - 1 != max_label: + raise Exception( + f"Vocab error: not all indices used? min label idx {min_label}, max label idx {max_label}," + f" num labels {num_labels}, " + f" unused labels: {[i for i in range(max_label + 1) if i not in labels_from_idx]}." + "There are duplicates in the vocab." + ) + self.num_labels = len(labels_from_idx) + self._vocab = labels_to_idx + self._labels = [label for (idx, label) in sorted(labels_from_idx.items())] self._cache[filename] = (self._vocab, self._labels) @classmethod @@ -333,6 +339,26 @@ def get_seq_labels(self, seq: Union[List[int], numpy.ndarray]) -> str: labels = self.labels return " ".join(map(labels.__getitem__, seq)) + def serialize_labels(self, data: numpy.ndarray) -> str: + """ + Like :func:`get_seq_labels` but a bit more generic, to not just work on sequences, + but any shape. + + Also like :func:`Dataset.serialize_data` but even slightly more generic. + """ + if data.ndim == 0: + return self.id_to_label(data.item()) + if data.ndim == 1: + return self.get_seq_labels(data) + + def _s(d_: numpy.ndarray) -> str: + assert d_.ndim >= 1 + if d_.ndim == 1: + return ",".join(self._labels[i] for i in d_) + return ",".join(f"[{_s(d_[t])}]" for t in range(d_.shape[0])) + + return _s(data) + class BytePairEncoding(Vocabulary): """ @@ -479,7 +505,13 @@ def __init__(self, **opts): """ import sentencepiece as spm # noqa + opts = opts.copy() + for k in ["model_file", "model_proto"]: + if k in opts: + # Make sure it is a string. (Could be e.g. Sis Path.) + opts[k] = str(opts[k]) self._opts = opts + opts = opts.copy() self._cache_key = opts.get("model_file", None) control_symbols = opts.pop("control_symbols", None) user_defined_symbols = opts.pop("user_defined_symbols", None) diff --git a/returnn/extern/graph_editor/subgraph.py b/returnn/extern/graph_editor/subgraph.py index 17acf6dbb..1781ad5e8 100644 --- a/returnn/extern/graph_editor/subgraph.py +++ b/returnn/extern/graph_editor/subgraph.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""SubGraphView: a subgraph view on an existing tf.Graph. 
-""" +"""SubGraphView: a subgraph view on an existing tf.Graph.""" from __future__ import annotations diff --git a/returnn/extern/graph_editor/transform.py b/returnn/extern/graph_editor/transform.py index d0aebfd62..751e765ec 100644 --- a/returnn/extern/graph_editor/transform.py +++ b/returnn/extern/graph_editor/transform.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Class to transform an subgraph into another. -""" +"""Class to transform an subgraph into another.""" from __future__ import annotations diff --git a/returnn/extern/graph_editor/util.py b/returnn/extern/graph_editor/util.py index b3f3cd1e6..ad1646401 100644 --- a/returnn/extern/graph_editor/util.py +++ b/returnn/extern/graph_editor/util.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utility functions for the graph_editor. -""" +"""Utility functions for the graph_editor.""" from __future__ import annotations diff --git a/returnn/frontend/_backend.py b/returnn/frontend/_backend.py index 8161d76fc..17fa976f6 100644 --- a/returnn/frontend/_backend.py +++ b/returnn/frontend/_backend.py @@ -1509,9 +1509,10 @@ def get_backend_by_raw_tensor_type(tensor_type: Type[T]) -> Union[Type[Backend[T else: continue - assert any( - issubclass(base_type, type_) for type_ in tensor_types - ), f"tensor type {tensor_type} base_type {base_type} not in {tensor_types}, expected for backend {backend_type}" + assert any(issubclass(base_type, type_) for type_ in tensor_types), ( + f"tensor type {tensor_type} base_type {base_type} not in {tensor_types}, " + f"expected for backend {backend_type}" + ) for base_type_ in tensor_types: register_backend_by_tensor_type(base_type_, backend_type) return backend_type diff --git a/returnn/frontend/_utils.py b/returnn/frontend/_utils.py index cff3a410e..816c37f93 100644 --- a/returnn/frontend/_utils.py +++ b/returnn/frontend/_utils.py @@ -263,7 +263,7 @@ def _slice_find_sparse_dim(v: Union[Tensor, slice, Any]) -> Optional[Dim]: def _map_slice_value_raw( - v: Union[None, slice, int, numpy.number, numpy.ndarray, Tensor[T]] + v: Union[None, slice, int, numpy.number, numpy.ndarray, Tensor[T]], ) -> Union[None, slice, int, numpy.number, T]: if v is None: return None diff --git a/returnn/frontend/array_.py b/returnn/frontend/array_.py index 32f0a6894..15b4ed944 100644 --- a/returnn/frontend/array_.py +++ b/returnn/frontend/array_.py @@ -39,6 +39,7 @@ "pad_packed", "gather", "scatter", + "scatter_mean", "scatter_argmax", "scatter_logsumexp", "scatter_logmeanexp", @@ -807,8 +808,8 @@ def scatter( :param source: [batch_dims..., indices_dim(s)..., feature_dims...] :param indices: [batch_dims..., indices_dim(s)...] -> out_dim :param indices_dim: - :param mode: "sum", "max", "min", "logsumexp", "logmeanexp", "argmax". - (Note: If you ever need mean, argmin, etc, please open an issue/PR.) + :param mode: "sum", "max", "min", "mean", "logsumexp", "logmeanexp", "argmax". + (Note: If you ever need another mode, please open an issue/PR.) :param fill_value: :param out_dim: The indices target dim. If not given, will be automatically determined as the sparse_dim from indices. @@ -817,6 +818,8 @@ def scatter( :param use_mask: :return: [batch_dims..., out_dim(s)..., feature_dims...] 
""" + if mode == "mean": + return scatter_mean(source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim) if mode == "logsumexp": return scatter_logsumexp( source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim @@ -863,6 +866,35 @@ def scatter( return out +def scatter_mean( + source: Tensor, + *, + indices: Tensor, + indices_dim: Union[Dim, Sequence[Dim]], + fill_value: Optional[Union[int, float]] = None, + out_dim: Optional[Union[Dim, Sequence[Dim]]] = None, +) -> Tensor: + """ + Scatters into new zero-tensor. + If entries in indices are duplicated, the corresponding values in source will be mean'ed together. + This is like :func:`scatter` with ``mode="mean"``. + + :param source: [batch_dims..., indices_dim(s)..., feature_dims...] + :param indices: [batch_dims..., indices_dim(s)...] -> out_dim + :param indices_dim: + :param fill_value: + :param out_dim: The indices target dim. + If not given, will be automatically determined as the sparse_dim from indices. + If multiple out dims, use indices into the merged out dims, + and then we use :func:`rf.split_dims` afterwards. + :return: [batch_dims..., out_dim(s)..., feature_dims...] + """ + ones = rf.ones(dims=indices.dims, dtype=source.dtype, device=source.device) + counts = rf.scatter(ones, indices=indices, indices_dim=indices_dim, fill_value=1, out_dim=out_dim) + y = scatter(source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim) + return y / counts + + def scatter_argmax( source: Tensor, *, diff --git a/returnn/frontend/audio/mel.py b/returnn/frontend/audio/mel.py index e924e40fd..a54bcdc85 100644 --- a/returnn/frontend/audio/mel.py +++ b/returnn/frontend/audio/mel.py @@ -2,7 +2,6 @@ Mel filterbank. """ - from __future__ import annotations from typing import Optional, Union, Tuple import functools diff --git a/returnn/frontend/audio/specaugment.py b/returnn/frontend/audio/specaugment.py index e66f217cd..9882c7e11 100644 --- a/returnn/frontend/audio/specaugment.py +++ b/returnn/frontend/audio/specaugment.py @@ -67,7 +67,9 @@ def _mask_branch(): ) return x_masked - return rf.cond(rf.get_run_ctx().train_flag | (not only_on_train), _mask_branch, lambda: x) + return rf.cond( + rf.get_run_ctx().is_train_flag_enabled(func=specaugment) | (not only_on_train), _mask_branch, lambda: x + ) def random_mask( diff --git a/returnn/frontend/const.py b/returnn/frontend/const.py index 4b5cab35f..0c1a25927 100644 --- a/returnn/frontend/const.py +++ b/returnn/frontend/const.py @@ -54,9 +54,9 @@ def full( "Use rf.convert_to_tensor to convert an arbitrary array to a tensor." ) if isinstance(fill_value, Tensor): - assert ( - fill_value.dims == () - ), f"full/fill/constant: expect scalar fill_value, got tensor with shape {fill_value.dims}." + assert fill_value.dims == (), ( + f"full/fill/constant: expect scalar fill_value, got tensor with shape {fill_value.dims}." + ) return global_backend.full( dims, fill_value, dtype=dtype, device=device, sparse_dim=sparse_dim, feature_dim=feature_dim ) diff --git a/returnn/frontend/device.py b/returnn/frontend/device.py index 7664fa7da..5f111ea46 100644 --- a/returnn/frontend/device.py +++ b/returnn/frontend/device.py @@ -2,7 +2,6 @@ Device handling. 
""" - from __future__ import annotations from typing import Optional from contextlib import contextmanager diff --git a/returnn/frontend/dropout.py b/returnn/frontend/dropout.py index 1d0315cf5..8b645f044 100644 --- a/returnn/frontend/dropout.py +++ b/returnn/frontend/dropout.py @@ -50,7 +50,7 @@ def dropout( raise ValueError(f"dropout axis {axis} not in source {source}") if isinstance(keep_prob, (float, int)) and not 0 < keep_prob <= 1: - raise ValueError("keep_prob must be a scalar tensor or a float in the " "range (0, 1], got %g" % keep_prob) + raise ValueError("keep_prob must be a scalar tensor or a float in the range (0, 1], got %g" % keep_prob) # Do nothing if we know keep_prob == 1 if isinstance(keep_prob, (float, int)) and keep_prob == 1: @@ -60,7 +60,7 @@ def dropout( return _dropout(source, keep_prob, noise_dims=noise_dims) return rf.cond( - pred=rf.get_run_ctx().train_flag, + pred=rf.get_run_ctx().is_train_flag_enabled(func=dropout), true_fn=lambda: _dropout(source, keep_prob, noise_dims=noise_dims), false_fn=lambda: source, ) diff --git a/returnn/frontend/encoder/e_branchformer.py b/returnn/frontend/encoder/e_branchformer.py index a2613c3cd..6686667ff 100644 --- a/returnn/frontend/encoder/e_branchformer.py +++ b/returnn/frontend/encoder/e_branchformer.py @@ -268,7 +268,7 @@ def __call__(self, x1: Tensor, x2: Tensor, *, spatial_dim: Dim) -> Tensor: def _make_activation( - activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module] + activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module], ) -> Union[Callable[[Tensor], Tensor], rf.Module]: if isinstance(activation, dict): activation = rf.build_from_dict(activation) diff --git a/returnn/frontend/loop.py b/returnn/frontend/loop.py index d290f85c6..9f88b85a8 100644 --- a/returnn/frontend/loop.py +++ b/returnn/frontend/loop.py @@ -273,9 +273,9 @@ def _check(path, template, x): x._push_back_delayed_check() else: # other cases: just check same type - assert type(template) is type( - x - ), f"loop var {path} template type {type(template)} does not match var type {type(x)}" + assert type(template) is type(x), ( + f"loop var {path} template type {type(template)} does not match var type {type(x)}" + ) assert not isinstance(x, Tensor), f"loop var {path} is a Tensor but should not be" tree.map_structure_with_path(_check, loop_var_templates, loop_vars) diff --git a/returnn/frontend/loss.py b/returnn/frontend/loss.py index 0098bbd64..b5c90f88a 100644 --- a/returnn/frontend/loss.py +++ b/returnn/frontend/loss.py @@ -137,7 +137,6 @@ def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim, # We are going diagonal over (Ta+1) and (Tb+1). (Similar as RETURNN native EditDistanceOp.) # You need to draw the grid on paper to understand all the index math... 
for u in range(1, n_a_max_len + n_b_max_len + 1): - prev2_dist, _ = rf.slice( buffer, axis=buffer_dim, start=buffer_offsets[u % 3], size=b_spatial_dim1, out_dim=b_spatial_dim1 ) # [Tb+1,B] diff --git a/returnn/frontend/matmul.py b/returnn/frontend/matmul.py index 396ff5d6e..222a37e64 100644 --- a/returnn/frontend/matmul.py +++ b/returnn/frontend/matmul.py @@ -2,7 +2,6 @@ Dot / matmul """ - from __future__ import annotations from typing import Sequence, Union, TypeVar from returnn.tensor import Tensor, Dim diff --git a/returnn/frontend/normalization.py b/returnn/frontend/normalization.py index 16bb70afc..ec09dbc2a 100644 --- a/returnn/frontend/normalization.py +++ b/returnn/frontend/normalization.py @@ -226,8 +226,9 @@ def __call__(self, source: Tensor) -> Tensor: if use_mask: # Generic implementation which supports masking. - use_current_batch_stats = self.running_mean is None or rf.get_run_ctx().train_flag - update_running_stats = self.running_mean is not None and rf.get_run_ctx().train_flag + train_flag = rf.get_run_ctx().is_train_flag_enabled(func=BatchNorm.__call__) + use_current_batch_stats = self.running_mean is None or train_flag + update_running_stats = self.running_mean is not None and train_flag need_current_batch_stats = rf.opt_logical_or(use_current_batch_stats, update_running_stats) mean_cur_batch, variance_cur_batch = rf.cond( diff --git a/returnn/frontend/parametrizations.py b/returnn/frontend/parametrizations.py index 466f685a7..32751d6e7 100644 --- a/returnn/frontend/parametrizations.py +++ b/returnn/frontend/parametrizations.py @@ -48,7 +48,7 @@ def _on_train() -> Tensor: # on_forward=True because we already checked for train_flag return rf.dropout(param, drop_prob=self.drop_prob, on_forward=True) - return rf.cond(rf.get_run_ctx().train_flag, _on_train, lambda: param) + return rf.cond(rf.get_run_ctx().is_train_flag_enabled(func=WeightDropout.__call__), _on_train, lambda: param) def weight_noise(module: rf.Module, param_name: str, *, std: float) -> rf.Module: @@ -84,4 +84,4 @@ def _on_train() -> Tensor: noise = rf.random_normal(param.dims, dtype=param.dtype, stddev=self.std) return param + noise - return rf.cond(rf.get_run_ctx().train_flag, _on_train, lambda: param) + return rf.cond(rf.get_run_ctx().is_train_flag_enabled(func=WeightNoise.__call__), _on_train, lambda: param) diff --git a/returnn/frontend/rec.py b/returnn/frontend/rec.py index 910c9e036..2d8e30f3d 100644 --- a/returnn/frontend/rec.py +++ b/returnn/frontend/rec.py @@ -218,7 +218,7 @@ def _zoneout(*, prev: Tensor, cur: Tensor, factor: float, out_dim: Dim, dropout_ if factor == 0.0: return cur return rf.cond( - rf.get_run_ctx().train_flag, + rf.get_run_ctx().is_train_flag_enabled(func=ZoneoutLSTM.__call__), lambda: (1 - factor) * rf.dropout(cur - prev, factor, axis=dropout_broadcast and out_dim) + prev, lambda: (1 - factor) * cur + factor * prev, ) diff --git a/returnn/frontend/reduce.py b/returnn/frontend/reduce.py index 94fdca5bc..4d78151c3 100644 --- a/returnn/frontend/reduce.py +++ b/returnn/frontend/reduce.py @@ -251,7 +251,11 @@ def _update_running_stats(): x_ = rf.reduce_mean(x, axis=[d for d in x.dims if d not in self.shape]) self.mean.assign_add(self.alpha * (x_ - self.mean)) - rf.cond((not self.update_only_in_train) or rf.get_run_ctx().train_flag, _update_running_stats, lambda: None) + rf.cond( + (not self.update_only_in_train) or rf.get_run_ctx().is_train_flag_enabled(func=RunningMean.__call__), + _update_running_stats, + lambda: None, + ) return self.mean diff --git a/returnn/frontend/run_ctx.py 
b/returnn/frontend/run_ctx.py index 53f56aa96..4fea67770 100644 --- a/returnn/frontend/run_ctx.py +++ b/returnn/frontend/run_ctx.py @@ -7,7 +7,8 @@ """ from __future__ import annotations -from typing import Optional, Union, Any, Sequence, Dict +from typing import Optional, Union, Any, Callable, Sequence, Dict, List +from types import FunctionType from dataclasses import dataclass from contextlib import contextmanager from returnn.tensor import Tensor, Dim, TensorDict, batch_dim @@ -101,7 +102,7 @@ def __init__( - "forward_step", for mark_as_output """ self._stage = stage - self._train_flag = train_flag + self._train_flags_stack: List[Dict[Optional[FunctionType], Union[Tensor, bool]]] = [{None: train_flag}] self._step = step self._epoch = epoch self.losses = {} # type: Dict[str, Loss] @@ -121,14 +122,17 @@ def stage(self) -> str: @property def train_flag(self) -> Union[bool, Tensor]: """ - :return: whether we are in training mode, i.e. the model is updated, - and we are supposed to use dropout and similar mechanisms. - In a graph-based backend, this can be dynamic. + :return: ``is_train_flag_enabled(func=None)``. See :func:`is_train_flag_enabled`. """ - return self._train_flag + return self.is_train_flag_enabled(func=None) @contextmanager - def train_flag_ctx(self, train_flag: Union[bool, Tensor]): + def train_flag_ctx( + self, + train_flag: Union[bool, Tensor], + *, + func: Optional[Union[Sequence[Union[FunctionType, Callable]], FunctionType, Callable]] = None, + ): """ Context manager to temporarily set the train_flag. @@ -137,14 +141,51 @@ def train_flag_ctx(self, train_flag: Union[bool, Tensor]): with rf.get_run_ctx().train_flag_ctx(False): ... - :param train_flag: whether we are in training mode - """ - old_train_flag = self.train_flag - self._train_flag = train_flag + :param train_flag: whether we are in training mode. + In a graph-based backend, this can be dynamic (scalar Tensor, not just bool). + :param func: if given, the train flag is only enabled/disabled for this specific function(s) + (e.g. ``rf.dropout`` or ``rf.BatchNorm.__call__``). + (See https://github.com/rwth-i6/returnn/issues/1712 for some discussion.) + (Note: We expect a Python function, not just any general Callable. But typing seems to get this wrong.) + """ + old_train_flags = self._train_flags_stack[-1] + new_train_flags = old_train_flags.copy() + if func is None: + new_train_flags[None] = train_flag + elif isinstance(func, FunctionType): + new_train_flags[func] = train_flag + elif isinstance(func, (list, tuple)): + for f in func: + if not isinstance(f, FunctionType): + raise TypeError(f"Expected function, got {type(f)}") + new_train_flags[f] = train_flag + else: + raise TypeError(f"Expected function or sequence of functions, got {type(func)}") + self._train_flags_stack.append(new_train_flags) try: yield finally: - self._train_flag = old_train_flag + last = self._train_flags_stack.pop(-1) + assert last is new_train_flags + assert len(self._train_flags_stack) >= 1 + + def is_train_flag_enabled(self, *, func: Optional[Union[FunctionType, Callable]]) -> Union[bool, Tensor]: + """ + :param func: function for which we want to check the train flag + (e.g. ``rf.dropout`` or ``rf.BatchNorm.__call__``), + or None for the global fallback. + (See https://github.com/rwth-i6/returnn/issues/1712 for some discussion.) + :return: Whether the train flag is enabled, either for the specific function, or globally. + Training is usually when the model is updated, + and we are supposed to use dropout and similar mechanisms. 
+ This is either for the specified function, or globally. + In a graph-based backend, this can also be dynamic (scalar Tensor, not just bool). + """ + train_flags = self._train_flags_stack[-1] + if func in train_flags: + return train_flags[func] + assert isinstance(func, FunctionType) + return train_flags[None] # global fallback. this should always be defined, see __init__ @property def step(self) -> Union[int, Tensor]: @@ -265,19 +306,19 @@ def mark_as_output(self, tensor: Union[Tensor, Any], name: str, *, dims: Optiona assert self.stage == "forward_step" if self.expected_outputs is not None: - assert ( - name in self.expected_outputs.data - ), f"mark_as_output: unexpected output {name!r}, we expect outputs: {self.expected_outputs}" + assert name in self.expected_outputs.data, ( + f"mark_as_output: unexpected output {name!r}, we expect outputs: {self.expected_outputs}" + ) expected_output = self.expected_outputs.data[name] if self.expected_outputs else None - assert dims is None or ( - isinstance(dims, (list, tuple)) and all(isinstance(dim, Dim) for dim in dims) - ), f"dims should be a tuple of Dims, got {dims}" + assert dims is None or (isinstance(dims, (list, tuple)) and all(isinstance(dim, Dim) for dim in dims)), ( + f"dims should be a tuple of Dims, got {dims}" + ) if dims is None and expected_output is not None: dims = expected_output.dims if dims is not None and expected_output is not None: - assert expected_output.dims == tuple( - dims - ), f"mark_as_output: {name!r} dims mismatch from expected output, given {dims}, expected {expected_output}" + assert expected_output.dims == tuple(dims), ( + f"mark_as_output: {name!r} dims mismatch from expected output, given {dims}, expected {expected_output}" + ) if not isinstance(tensor, Tensor): assert isinstance(tensor, _backend.global_backend.RawTensorType) diff --git a/returnn/frontend/signal.py b/returnn/frontend/signal.py index a4f4bb758..f9199b5f0 100644 --- a/returnn/frontend/signal.py +++ b/returnn/frontend/signal.py @@ -2,7 +2,6 @@ stft etc """ - from __future__ import annotations from typing import Optional, Tuple from returnn.tensor import Tensor, Dim diff --git a/returnn/frontend/types.py b/returnn/frontend/types.py index b0406dde5..6de8cd2c0 100644 --- a/returnn/frontend/types.py +++ b/returnn/frontend/types.py @@ -19,15 +19,13 @@ class GetModelFunc(Protocol): """get model func""" - def __call__(self, *, epoch: int, step: int) -> rf.Module: - ... + def __call__(self, *, epoch: int, step: int) -> rf.Module: ... class StepFunc(Protocol): """step func""" - def __call__(self, *, model: rf.Module, extern_data: TensorDict) -> None: - ... + def __call__(self, *, model: rf.Module, extern_data: TensorDict) -> None: ... def get_raw_tensor_type() -> Type: diff --git a/returnn/native_op.py b/returnn/native_op.py index 0bb384b92..4ee7e1608 100644 --- a/returnn/native_op.py +++ b/returnn/native_op.py @@ -291,6 +291,7 @@ class LstmGenericBase(NativeOpGenBase): :param H: gates and cell state. 3d (time,batch,dim*4) :param d: final cell state. 2d (batch,dim) """ + in_info = ( { "name": "Z", @@ -542,6 +543,7 @@ class LstmLowMem(NativeOpGenBase): :param C: cell states. 3d (time,batch,dim). gradient ignored! :param d: final cell state. 2d (batch,dim) """ + in_info = ( {"name": "X", "ndim": 3, "shape": (None, None, None), "need_contiguous": True}, {"name": "W", "ndim": 2, "shape": (None, None), "need_contiguous": True}, @@ -994,6 +996,7 @@ class NativeLstm2(NativeOpGenBase): :param H: cell-in + gates. 3d (time,batch,dim*4). gradient ignored! 
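Usage sketch for the per-function train flag introduced in run_ctx.py above. Not part of the patch: the nesting and the asserts are illustrative, and an available default run context from rf.get_run_ctx() is assumed; rf.dropout and rf.BatchNorm.__call__ are simply two of the functions that the patch itself now queries via is_train_flag_enabled.

import returnn.frontend as rf

# E.g. Monte-Carlo dropout at inference time:
# globally not training, but the train flag is enabled specifically for rf.dropout.
with rf.get_run_ctx().train_flag_ctx(False):
    with rf.get_run_ctx().train_flag_ctx(True, func=rf.dropout):
        assert rf.get_run_ctx().is_train_flag_enabled(func=rf.dropout)
        # Other functions fall back to the global flag, which is False here:
        assert not rf.get_run_ctx().is_train_flag_enabled(func=rf.BatchNorm.__call__)
        # The old train_flag property is the global fallback (func=None):
        assert not rf.get_run_ctx().train_flag
        # Model forwarding here would apply dropout but keep batch norm in eval mode.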
:param d: final cell state. 2d (batch,dim) """ + in_info = ( {"name": "X", "ndim": 3, "shape": (None, None, None), "need_contiguous": True}, {"name": "W", "ndim": 2, "shape": (None, None), "need_contiguous": True}, @@ -1423,6 +1426,7 @@ class TwoDLSTM(NativeOpGenBase): :param H: gates and cell state. 4d (timeS,timeT,batch,dim*5) ? :param d: final cell state. 3d (timeT,batch,dim) """ + in_info = ( { "name": "X", @@ -3198,6 +3202,7 @@ class FastBaumWelchOp(NativeOpGenBase): outputs: :param output: Baum-Welch alignment, scores in -log space. 3d (time,batch,dim), like am_scores """ + in_info = ( { "name": "am_scores", @@ -3620,6 +3625,7 @@ class MultiEndFastBaumWelchOp(NativeOpGenBase): outputs: :param output: Baum-Welch alignment, scores in -log space. 3d (time,batch,dim), like am_scores """ + in_info = ( { "name": "am_scores", @@ -4497,6 +4503,7 @@ class FastViterbiOp(NativeOpGenBase): :param output: Viterbi (hard) alignment, scores in +log space. 2d (time,batch) :param scores: (batch,) """ + in_info = ( { "name": "am_scores", @@ -4865,6 +4872,7 @@ class GetCtcFsaFastBwOp(NativeOpGenBase): `num_edges` should be `n_batch * (5 * (n_time - 1) + 10)` (see construction in kernel why that number). """ + in_info = ( { "name": "targets", @@ -5229,6 +5237,7 @@ class EditDistanceOp(NativeOpGenBase): outputs: :param output: 1d (batch,), int32, unnormalized edit distance """ + in_info = ( { "name": "a", @@ -5414,6 +5423,7 @@ class OptimalCompletionEditDistanceOp(NativeOpGenBase): outputs: :param output: 1d (batch,), int32, unnormalized edit distance """ + in_info = ( { "name": "a", @@ -5610,6 +5620,7 @@ class OptimalCompletionEditDistancePerSuccessorOp(NativeOpGenBase): outputs: :param output: 2d (batch,num_labels), int32, unnormalized edit distance """ + in_info = ( { "name": "a", @@ -5880,6 +5891,7 @@ class NextEditDistanceRowOp(NativeOpGenBase): outputs: :param output: 2d (batch,b_time + 1), int32, next (unnormalized) edit distance row """ + in_info = ( { "name": "last_row", @@ -6039,6 +6051,7 @@ class NextEditDistanceReduceOp(NativeOpGenBase): outputs: :param output: 2d (batch,n_labels), int32, next (unnormalized) (maybe optional) edit distance """ + in_info = ( { "name": "last_row", diff --git a/returnn/sprint/cache.py b/returnn/sprint/cache.py index 103e4db05..d9780eba3 100644 --- a/returnn/sprint/cache.py +++ b/returnn/sprint/cache.py @@ -7,7 +7,7 @@ """ from __future__ import annotations -from typing import List +from typing import List, Optional, Tuple import sys import os import typing @@ -904,9 +904,7 @@ def __init__(self, filename): self.densities[n, 1] = cov_idx self.num_mixtures = self.read_u32() - self.mixtures = [ - None - ] * self.num_mixtures # type: typing.List[typing.Optional[typing.Tuple[typing.List[int],typing.List[float]]]] # nopep8 + self.mixtures: List[Optional[Tuple[List[int], List[float]]]] = [None] * self.num_mixtures for n in range(self.num_mixtures): num_densities = self.read_u32() dns_idx = [] diff --git a/returnn/sprint/interface.py b/returnn/sprint/interface.py index 55c3443ba..69211a2ad 100644 --- a/returnn/sprint/interface.py +++ b/returnn/sprint/interface.py @@ -820,9 +820,9 @@ def _prepare_forwarding(): assert engine assert config # Should already be set via setTargetMode(). - assert config.list("extract") == [ - "posteriors" - ], "You need to have extract = posteriors in your RETURNN config. You have: %s" % config.list("extract") + assert config.list("extract") == ["posteriors"], ( + "You need to have extract = posteriors in your RETURNN config. 
You have: %s" % config.list("extract") + ) # Load network. engine.init_network_from_config(config) @@ -870,7 +870,6 @@ def _train(segment_name, features, targets=None): # The CRNN train thread started via start() will do the actual training. if TargetMode == "criterion-by-sprint": - # TODO... make_criterion_class() diff --git a/returnn/tensor/_dim_extra.py b/returnn/tensor/_dim_extra.py index 793e87545..9f31c7778 100644 --- a/returnn/tensor/_dim_extra.py +++ b/returnn/tensor/_dim_extra.py @@ -1067,13 +1067,14 @@ def set_tag_on_size_tensor(self: Dim, x, batch=None, same_as_before=False) -> Di ) ) if batch and getattr(x, "_RETURNN_dyn_size_beam", None): - assert batch.beam == getattr( - x, "_RETURNN_dyn_size_beam" - ), "%s: dyn size %s has unexpected batch %s, expected %s" % ( - self, - x, - batch, - getattr(x, "_RETURNN_dyn_size_beam"), + assert batch.beam == getattr(x, "_RETURNN_dyn_size_beam"), ( + "%s: dyn size %s has unexpected batch %s, expected %s" + % ( + self, + x, + batch, + getattr(x, "_RETURNN_dyn_size_beam"), + ) ) if self.batch and batch: assert self.batch == batch @@ -1359,8 +1360,7 @@ def is_equal( # Only auto-generated dim tags are allowed to be treated as broadcastable. # This was another suggestion from here: https://github.com/rwth-i6/returnn/issues/666 # It was not implemented like this because the auto_generated flag was only introduced later. - (self.dimension == 1 and self.auto_generated) - or (other.dimension == 1 and other.auto_generated) + (self.dimension == 1 and self.auto_generated) or (other.dimension == 1 and other.auto_generated) ): pass # pass on else: diff --git a/returnn/tensor/_tensor_extra.py b/returnn/tensor/_tensor_extra.py index 66aaf3389..5151af3c3 100644 --- a/returnn/tensor/_tensor_extra.py +++ b/returnn/tensor/_tensor_extra.py @@ -335,9 +335,9 @@ def sanity_check(self, ignore_placeholder=False, assume_complete=True): if tag.dyn_size_ext.placeholder is None: tag.complete_dyn_size() if self.placeholder is not None: - assert ( - tag.dyn_size_ext.placeholder is not None - ), "%s sanity_check: dynamic dim %s value unknown" % (self, tag) + assert tag.dyn_size_ext.placeholder is not None, ( + "%s sanity_check: dynamic dim %s value unknown" % (self, tag) + ) assert tag.is_dim_known() def get_runtime_sanity_check_op(self: Tensor): @@ -2494,8 +2494,7 @@ def get_axis_from_description(self, axis, allow_int=NotSpecified): if res_tag.match_priority > tag.match_priority: continue raise Exception( - f"{self}: get_axis_from_description({axis}) not unique." - f" use match_priority to resolve ambiguity" + f"{self}: get_axis_from_description({axis}) not unique. 
use match_priority to resolve ambiguity" ) if res_idx is None: raise Exception(f"{self}: get_axis_from_description({axis}) not found") @@ -2646,12 +2645,13 @@ def is_time_axis_dynamic(self): return self.batch_shape[self.time_dim_axis_excluding_batch] is None if self.time_dim_axis_excluding_batch in self.size_placeholder: return True - assert isinstance( - self.shape[self.time_dim_axis_excluding_batch], int - ), "%s: dynamic time axis dim (None) (axis %i) but size_placeholder %r misses information" % ( - self, - self.time_dim_axis, - self.size_placeholder, + assert isinstance(self.shape[self.time_dim_axis_excluding_batch], int), ( + "%s: dynamic time axis dim (None) (axis %i) but size_placeholder %r misses information" + % ( + self, + self.time_dim_axis, + self.size_placeholder, + ) ) return False @@ -3307,14 +3307,15 @@ def map_other_axis_to_self(other_axis: int, taken_self_axes: Set[int]) -> int: if self_axis not in taken_self_axes ] if opt == "unknown_spatial_matches": - assert ( - len(matching) <= 1 - ), "cannot match axes %s from %s to %s, failed at other %s, not unique after %s" % ( - other_axes, - other, - self, - other_axis, - opt, + assert len(matching) <= 1, ( + "cannot match axes %s from %s to %s, failed at other %s, not unique after %s" + % ( + other_axes, + other, + self, + other_axis, + opt, + ) ) if matching: break diff --git a/returnn/tensor/_tensor_op_overloads.py b/returnn/tensor/_tensor_op_overloads.py index 2ed98dd87..65f7be7fc 100644 --- a/returnn/tensor/_tensor_op_overloads.py +++ b/returnn/tensor/_tensor_op_overloads.py @@ -13,7 +13,6 @@ class _TensorOpOverloadsMixin(_TensorMixinBase): - # Note that all those ops have native implementations as well, # so keep the logic in sync. diff --git a/returnn/tensor/tensor.py b/returnn/tensor/tensor.py index c4a821a2a..e17eba6bb 100644 --- a/returnn/tensor/tensor.py +++ b/returnn/tensor/tensor.py @@ -187,7 +187,7 @@ def raw_tensor(self, value: Optional[RawTensorType]): if not backend.executing_eagerly(): backend.set_known_shape_raw(value, self.batch_shape) assert backend.get_dtype_name_raw(value) == self.dtype, ( - f"{self} dtype {self.dtype} does not match " f"raw tensor dtype {backend.get_dtype_name_raw(value)}" + f"{self} dtype {self.dtype} does not match raw tensor dtype {backend.get_dtype_name_raw(value)}" ) self._raw_tensor = value diff --git a/returnn/tensor/tensor_dict.py b/returnn/tensor/tensor_dict.py index 41ea8bf14..76d0e1aa5 100644 --- a/returnn/tensor/tensor_dict.py +++ b/returnn/tensor/tensor_dict.py @@ -91,9 +91,9 @@ def as_raw_tensor_dict( out = {} for key, value in self.data.items(): assert key not in out - assert isinstance( - value.raw_tensor, expected_value_type - ), f"key {key} {value}: unexpected {type(value.raw_tensor)}, expected {expected_value_type}" + assert isinstance(value.raw_tensor, expected_value_type), ( + f"key {key} {value}: unexpected {type(value.raw_tensor)}, expected {expected_value_type}" + ) out[key] = value.raw_tensor for i, dim in enumerate(value.dims): if exclude_duplicate_dims and dim in visited_dims: @@ -103,9 +103,9 @@ def as_raw_tensor_dict( if dim.is_batch_dim() and (dim.dyn_size_ext is None or dim.dyn_size_ext.raw_tensor is None): if include_scalar_dyn_sizes: dim_value = dim.get_dim_value() - assert isinstance( - dim_value, expected_value_type - ), f"key {key_} {dim}: unexpected {type(dim_value)}, expected {expected_value_type}" + assert isinstance(dim_value, expected_value_type), ( + f"key {key_} {dim}: unexpected {type(dim_value)}, expected {expected_value_type}" + ) out[key_] = 
dim_value elif dim.dyn_size_ext is not None: if include_scalar_dyn_sizes or dim.dyn_size_ext.dims: @@ -116,9 +116,9 @@ def as_raw_tensor_dict( out[key_] = dim.dyn_size_ext.raw_tensor elif dim.size is not None: if include_scalar_dyn_sizes and include_const_sizes: - assert isinstance( - dim.size, expected_value_type - ), f"key {key_} {dim}: unexpected {type(dim.size)}, expected {expected_value_type}" + assert isinstance(dim.size, expected_value_type), ( + f"key {key_} {dim}: unexpected {type(dim.size)}, expected {expected_value_type}" + ) out[key_] = dim.size else: raise Exception(f"cannot handle dim: {dim}") diff --git a/returnn/tf/engine.py b/returnn/tf/engine.py index 59b514409..a85c4adab 100644 --- a/returnn/tf/engine.py +++ b/returnn/tf/engine.py @@ -12,7 +12,7 @@ from __future__ import annotations -from typing import Optional +from typing import Callable, Dict, List, Optional, Union import typing import os import sys @@ -101,31 +101,29 @@ def __init__( self.store_tf_profile = engine.config.bool("store_tf_profile", False) self.store_metadata_mod_step = engine.config.int("store_metadata_mod_step", 0) self.reset_updater_vars_mod_step = engine.config.int("reset_updater_vars_mod_step", 0) - assert not ( - self.store_tf_profile and self.store_metadata_mod_step - ), "Cannot use store_tf_profile and store_metadata_mod_step at the same time" + assert not (self.store_tf_profile and self.store_metadata_mod_step), ( + "Cannot use store_tf_profile and store_metadata_mod_step at the same time" + ) self.finalized = False self.cancel_flag = False self.run_exception = None self.num_steps = None - self.device_crash_batch = None # type: typing.Optional[int] + self.device_crash_batch: Optional[int] = None self.start_time = None self.elapsed = None - self.report_prefix = None # type: typing.Optional[str] + self.report_prefix: Optional[str] = None self._results_accumulated = NumbersDict() # entries like "cost:output" or "loss" self._inv_norm_accumulated = NumbersDict() # entries like "output" self.num_frames_accumulated = NumbersDict() # for each data key (eg. "classes"), corresponding number of frames - self.results = {} # type: typing.Dict[str,float] # entries like "cost:output" or "loss" - self.score = {} # type: typing.Dict[str,float] # entries like "cost:output" - self.error = {} # type: typing.Dict[str,float] # entries like "error:output" - self.stats = ( - {} - ) # type: typing.Dict[str,typing.Union[float,numpy.ndarray,'Util.Stats']] # entries like "stats:..." + self.results: Dict[str, float] = {} # entries like "cost:output" or "loss" + self.score: Dict[str, float] = {} # entries like "cost:output" + self.error: Dict[str, float] = {} # entries like "error:output" + self.stats: Dict[str, Union[float, numpy.ndarray, "util.Stats"]] = {} # entries like "stats:..." self.extra_fetches = extra_fetches if extra_fetches is not None: assert extra_fetches_callback self.extra_fetches_callback = extra_fetches_callback - self._step_start_time = None # type: typing.Optional[float] + self._step_start_time: Optional[float] = None self._horovod_last_param_sync_time = time.time() # we assume it is synced right now self._horovod_stopped_runner = False self._horovod_finish_all = False @@ -133,9 +131,7 @@ def __init__( self._horovod_finish_all = True # With Horovod, during the main session.run, if reduce_type != grad or not training, # the following tensors are enough to ensure that we are in sync. 
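The engine and layer hunks below repeatedly apply the same modernization of type-comment annotations to inline (PEP 526) annotations; for reference, the pattern on a neutral toy class with invented names:

import typing
from typing import Dict, Optional


class _OldStyle:
    def __init__(self):
        self.report_prefix = None  # type: typing.Optional[str]
        self.results = {}  # type: typing.Dict[str,float]


class _NewStyle:
    def __init__(self):
        self.report_prefix: Optional[str] = None
        self.results: Dict[str, float] = {}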
- self._horovod_collected_reduce_inputs = ( - {} - ) # type: typing.Dict[str,(tf.Tensor,tf.Tensor)] # name -> (input,output) + self._horovod_collected_reduce_inputs: Dict[str, (tf.Tensor, tf.Tensor)] = {} # name -> (input,output) from returnn.util.basic import terminal_size @@ -196,9 +192,9 @@ def callback_on_new(): d["extra:%s" % k] = v continue assert isinstance(v, Data) - d[ - "extra:%s" % k - ] = v.placeholder # see _maybe_handle_extra_fetches, it will transform to batch-major there + d["extra:%s" % k] = ( + v.placeholder + ) # see _maybe_handle_extra_fetches, it will transform to batch-major there for i, s in v.size_placeholder.items(): d["extra:%s:size_%i" % (k, i)] = s @@ -732,9 +728,9 @@ def run(self, report_prefix: str, *, raise_exception: bool = False): run_options_.MergeFrom(run_options) # We could use tfdbg.add_debug_tensor_watch here. session_run_start_time = time.time() - fetches_results = sess.run( + fetches_results: Dict[str, Union[numpy.ndarray, str]] = sess.run( fetches_dict, feed_dict=feed_dict, options=run_options_, run_metadata=run_metadata - ) # type: typing.Dict[str,typing.Union[numpy.ndarray,str]] + ) elapsed_time_tf += time.time() - session_run_start_time writer.add_summary(fetches_results["summary"], step + step_offset) writer.add_run_metadata(run_metadata, "step_{:04d}".format(step + step_offset)) @@ -746,13 +742,13 @@ def run(self, report_prefix: str, *, raise_exception: bool = False): session_run_start_time = time.time() if self.store_tf_profile: with tf.profiler.experimental.Trace(name=report_prefix, step_num=step + step_offset): - fetches_results = sess.run( + fetches_results: Dict[str, Union[numpy.ndarray, str]] = sess.run( fetches_dict, feed_dict=feed_dict, options=run_options - ) # type: typing.Dict[str,typing.Union[numpy.ndarray,str]] + ) else: - fetches_results = sess.run( + fetches_results: Dict[str, Union[numpy.ndarray, str]] = sess.run( fetches_dict, feed_dict=feed_dict, options=run_options - ) # type: typing.Dict[str,typing.Union[numpy.ndarray,str]] + ) elapsed_time_tf += time.time() - session_run_start_time if writer and "summary" in fetches_results: writer.add_summary(fetches_results["summary"], step + step_offset) @@ -891,27 +887,27 @@ def __init__(self, config=None): BackendEngine.select_engine(default_fallback_engine=default_fallback_engine, config=self.config) assert BackendEngine.is_tensorflow_selected() self.orig_config = {} # see _maybe_update_config - self.custom_get_net_dict = None # type: typing.Optional[typing.Callable] + self.custom_get_net_dict: Optional[Callable] = None self._have_rf_get_model_func = False self._check_devices() - self.tf_session = None # type: typing.Optional[tf.compat.v1.Session] - self.network = None # type: typing.Optional[TFNetwork] - self.updater = None # type: typing.Optional[Updater] + self.tf_session: Optional[tf.compat.v1.Session] = None + self.network: Optional[TFNetwork] = None + self.updater: Optional[Updater] = None self._checked_uninitialized_vars = False self._merge_all_summaries = None - self.dataset_batches = {} # type: typing.Dict[str,BatchSetGenerator] - self.dataset_provider = None # type: typing.Optional[DatasetDataProvider] - self.train_data = None # type: typing.Optional[Dataset] - self.eval_datasets = {} # type: typing.Dict[str,Dataset] - self.start_epoch = None # type: typing.Optional[int] - self._num_trained_epochs = 0 # type: int # just a counter - self._num_net_reinit = 0 # type: int + self.dataset_batches: Dict[str, BatchSetGenerator] = {} + self.dataset_provider: 
Optional[DatasetDataProvider] = None + self.train_data: Optional[Dataset] = None + self.eval_datasets: Dict[str, Dataset] = {} + self.start_epoch: Optional[int] = None + self._num_trained_epochs: int = 0 # just a counter + self._num_net_reinit: int = 0 self.use_dynamic_train_flag = False self.use_search_flag = self.config.value("task", None) == "search" self.use_eval_flag = self.config.value("task", None) != "forward" - self._const_cache = {} # type: typing.Dict[str,tf.Tensor] - self.preload_from_files = None # type: typing.Optional[typing.Dict[str,typing.Dict[str]]] - self.max_seqs = None # type: typing.Optional[int] + self._const_cache: Dict[str, tf.Tensor] = {} + self.preload_from_files: Optional[Dict[str, Dict[str]]] = None + self.max_seqs: Optional[int] = None def finalize(self, error_occurred=False): """ @@ -1140,7 +1136,7 @@ def init_train_from_config(self, config=None, train_data=None, dev_data=None, ev self.min_seq_length = config.typed_value("min_seq_length", None) or config.float("min_seq_length", 0) self.inc_seq_length = config.float("inc_seq_length", 0) if not self.max_seq_length: - self.max_seq_length = sys.maxsize # type: typing.Union[int,float,typing.Dict[str,int],NumbersDict] + self.max_seq_length: Union[int, float, Dict[str, int], NumbersDict] = sys.maxsize if isinstance(self.max_seq_length, dict): self.max_seq_length = NumbersDict(self.max_seq_length) assert isinstance(self.max_seq_length, (int, float, NumbersDict)) @@ -1630,7 +1626,7 @@ def train(self): assert isinstance(self.start_epoch, int) epoch = self.start_epoch # Epochs start at 1. while epoch <= final_epoch: - self.epoch = epoch # type: int + self.epoch: int = epoch if isinstance(self.max_seq_length, int) and self.max_seq_length != sys.maxsize: if int(self.max_seq_length + self.inc_seq_length) != int(self.max_seq_length): print("increasing sequence lengths to", int(self.max_seq_length + self.inc_seq_length), file=log.v3) @@ -1878,9 +1874,9 @@ def _maybe_prepare_train_in_eval(self, targets_via_search=False): # We update the model params in-place. # In training, we don't want that, because it should not use the validation data. # We could reset it later when continuing the training, but it's not implemented. - assert ( - self.config.value("task", "train") != "train" - ), "task %r should be just 'eval' or so. training will break." % self.config.value("task", None) + assert self.config.value("task", "train") != "train", ( + "task %r should be just 'eval' or so. training will break." 
% self.config.value("task", None) + ) if not self.updater: self.updater = Updater( config=self.config, network=self.network, initial_learning_rate=self.initial_learning_rate @@ -1928,11 +1924,12 @@ def eval_model( allowed_outputs = {"seq_tag", "seq_len", "score", "error", "pos_score", "pos_error"} assert isinstance(output_per_seq_format, (tuple, list)), "provide output_per_seq_format" - assert ( - set(output_per_seq_format) - allowed_outputs == set() - ), "Only %r are allowed in function eval_model as output_per_seq_format, but got: %r " % ( - allowed_outputs, - output_per_seq_format, + assert set(output_per_seq_format) - allowed_outputs == set(), ( + "Only %r are allowed in function eval_model as output_per_seq_format, but got: %r " + % ( + allowed_outputs, + output_per_seq_format, + ) ) # always fetch seq_tag to map loss values to the corresponding line @@ -1968,12 +1965,10 @@ def eval_model( if "pos_error" in output_per_seq_format: extra_fetches["pos_error"] = loss_holder.get_error_value_per_pos() - seq_idx_to_tag = ( - {} - ) # type: typing.Dict[int,str] # we need this in order to write the results in the correct order later # nopep8 - results_per_seq = ( - {} - ) # type: typing.Dict[str,typing.Dict[str,typing.Union[float,str,int]]] # seq_tag -> dict. Results of fetches will be written in this dict # nopep8 + seq_idx_to_tag: Dict[int, str] = {} # we need this in order to write the results in the correct order later + results_per_seq: Dict[ + str, Dict[str, Union[float, str, int]] + ] = {} # seq_tag -> dict. Results of fetches will be written in this dict # function to save the return values of each callback to the dict `results_per_seq` # noinspection PyShadowingNames @@ -2012,7 +2007,7 @@ def extra_fetches_callback(seq_idx, seq_tags, **extra_fetches_out): if output_per_seq_file: assert len(self.get_eval_datasets()) == 1, ( - "output per sequence is only supported for one dataset (dev or eval)," "provided datasets are %r" + "output per sequence is only supported for one dataset (dev or eval),provided datasets are %r" ) % list(self.get_eval_datasets().keys()) # try to sort dataset to minimize zero-padding dataset = list(self.get_eval_datasets().values())[0] @@ -2453,9 +2448,9 @@ def search(self, dataset, do_eval=True, output_layer_names="output", output_file ) max_seq_length = self.config.typed_value("max_seq_length", None) or self.config.float("max_seq_length", 0) - assert ( - not max_seq_length - ), "Set max_seq_length = 0 for search (i.e. no maximal length). We want to keep all source sentences." + assert not max_seq_length, ( + "Set max_seq_length = 0 for search (i.e. no maximal length). We want to keep all source sentences." + ) dataset.init_seq_order(epoch=self.epoch) batches = dataset.generate_batches( @@ -2552,8 +2547,8 @@ def extra_fetches_callback(seq_idx, seq_tag, **kwargs): outputs[output_layer_idx] = bytearray(outputs[output_layer_idx]).decode("utf8") # Create lists with serialized data. All of length num_output_layers. 
- serialized_outputs = [] # type: typing.List[typing.Optional[typing.Union[str,numpy.ndarray]]] - serialized_targets = [] # type: typing.List[typing.Optional[typing.Union[str,numpy.ndarray]]] + serialized_outputs: List[Optional[Union[str, numpy.ndarray]]] = [] + serialized_targets: List[Optional[Union[str, numpy.ndarray]]] = [] # noinspection PyShadowingNames for output_layer_idx in range(num_output_layers): if output_layers[output_layer_idx].output.sparse: @@ -2572,8 +2567,8 @@ def extra_fetches_callback(seq_idx, seq_tag, **kwargs): ] else: serialized_output = None - assert not output_file, "Unable to serialize sparse output of layer '%s'." % ( - output_layer_names[output_layer_idx] + assert not output_file, ( + "Unable to serialize sparse output of layer '%s'." % (output_layer_names[output_layer_idx]) ) else: # Output dense layers as-is @@ -2594,8 +2589,8 @@ def extra_fetches_callback(seq_idx, seq_tag, **kwargs): ] else: serialized_target = None - assert not output_file, "Unable to serialize sparse target '%s'." % ( - target_keys[output_layer_idx] + assert not output_file, ( + "Unable to serialize sparse target '%s'." % (target_keys[output_layer_idx]) ) else: serialized_target = targets[output_layer_idx] diff --git a/returnn/tf/frontend_layers/_backend.py b/returnn/tf/frontend_layers/_backend.py index a9974e5bf..3dded483f 100644 --- a/returnn/tf/frontend_layers/_backend.py +++ b/returnn/tf/frontend_layers/_backend.py @@ -510,9 +510,9 @@ def set_parameter_initial_value(param: rf.Parameter[Layer], value: Union[None, T # We could also maybe move out all the dependencies. # However, it's not clear whether this is always safe. for dep in value.raw_tensor.get_tensor_dependencies(): - assert ( - dep.parent.can_access_children_from_root - ), f"dep {dep} of moved value {value} is not accessible" + assert dep.parent.can_access_children_from_root, ( + f"dep {dep} of moved value {value} is not accessible" + ) param.raw_tensor.layer_dict["init_by_layer"] = value else: param.raw_tensor.layer_dict.pop("init_by_layer", None) diff --git a/returnn/tf/frontend_layers/cond.py b/returnn/tf/frontend_layers/cond.py index bc9c5ee09..5c947d9f6 100644 --- a/returnn/tf/frontend_layers/cond.py +++ b/returnn/tf/frontend_layers/cond.py @@ -181,9 +181,9 @@ def false(self, false_value: T): After this, self.result is available. """ assert self._entered, f"{self} you need to be in the context scope" - assert ( - self._entered_state is False - ), f"{self} you need to be in the False branch, have assigned :func:`true` before" + assert self._entered_state is False, ( + f"{self} you need to be in the False branch, have assigned :func:`true` before" + ) assert not self._false_value_set nest.assert_same_structure(self._true_value, false_value) # This needs to match the true() setter logic. 
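The remaining hunks in this file and in the layer, engine and rec modules below are largely the same mechanical reformatting of assert statements; for reference, the before/after layout on an invented example (black wraps the asserted expression, ruff format keeps it on one line and parenthesizes the message):

def _check_ndim(value, expected):
    # Old (black) layout, kept here as a comment:
    #     assert (
    #         value.batch_ndim == expected.batch_ndim
    #     ), f"ndim mismatch: {value} vs {expected}"
    # New (ruff format) layout, as applied throughout this patch:
    assert value.batch_ndim == expected.batch_ndim, (
        f"ndim mismatch: {value} vs {expected}"
    )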
@@ -198,9 +198,9 @@ def false(self, false_value: T): if false_v is None: # see above false_v = rf.zeros((), dtype="int32") # dummy value else: - assert isinstance( - false_v, Tensor - ), f"unexpected {false_value!r}, only expects tensors, got {type(false_v)}" + assert isinstance(false_v, Tensor), ( + f"unexpected {false_value!r}, only expects tensors, got {type(false_v)}" + ) assert true_v.raw_tensor.parent is self.true_branch_name_ctx name = true_v.raw_tensor.name assert name not in self.false_branch_name_ctx.children diff --git a/returnn/tf/frontend_layers/debug_eager_mode.py b/returnn/tf/frontend_layers/debug_eager_mode.py index 6eff1e5e0..11633f7de 100644 --- a/returnn/tf/frontend_layers/debug_eager_mode.py +++ b/returnn/tf/frontend_layers/debug_eager_mode.py @@ -2,7 +2,6 @@ Debug eager mode """ - _debug_eager_mode_enabled = False diff --git a/returnn/tf/frontend_layers/layer.py b/returnn/tf/frontend_layers/layer.py index d41d729be..dea367a25 100644 --- a/returnn/tf/frontend_layers/layer.py +++ b/returnn/tf/frontend_layers/layer.py @@ -1104,13 +1104,13 @@ def make_net_dict_raw(self, net: Net, *, _stack: Optional[_StackInfo] = None) -> # If dyn_size_ext is not set yet, try to complete it. if dim.dyn_size_ext is None: dim.complete_dyn_size() - assert ( - dim.dyn_size_ext is not None - ), f"{sub_name_ctx}: need {dim} to be defined to be able to know about implicit dims" + assert dim.dyn_size_ext is not None, ( + f"{sub_name_ctx}: need {dim} to be defined to be able to know about implicit dims" + ) dim_tags.extend(data_template.dim_tags_set_implicit_only_wrapped) - assert len(dim_tags) == len( - set((d, d.match_priority if isinstance(d, Dim) else 0) for d in dim_tags) - ), f"duplicate dims in {sub_name_ctx} {sub_name_ctx.tensor}" + assert len(dim_tags) == len(set((d, d.match_priority if isinstance(d, Dim) else 0) for d in dim_tags)), ( + f"duplicate dims in {sub_name_ctx} {sub_name_ctx.tensor}" + ) if len(dim_tags) == len(set(dim_tags)): # might not be unique without match_priority # For some layer classes, the out_shape would be redundant. 
if layer_dict["class"] not in {"constant", "variable", "random", "subnetwork", "transpose"}: @@ -1135,9 +1135,9 @@ def make_net_dict_raw(self, net: Net, *, _stack: Optional[_StackInfo] = None) -> sub_layer_abs_name_scope = self._expected_layer_abs_name_scope(sub_name_ctx) if sub_name_ctx.layer_dict["class"] == "variable": - assert ( - sub_layer_abs_name_scope - ), f"VariableLayer {sub_name_ctx} must have a unique name in {self.root_module}" + assert sub_layer_abs_name_scope, ( + f"VariableLayer {sub_name_ctx} must have a unique name in {self.root_module}" + ) if sub_layer_abs_name_scope is not None: if ( layer_abs_name_scope_default != sub_layer_abs_name_scope @@ -1153,9 +1153,9 @@ def make_net_dict_raw(self, net: Net, *, _stack: Optional[_StackInfo] = None) -> def _map_elem_resolve(obj: Any) -> Any: if isinstance(obj, Tensor): - assert isinstance( - obj.raw_tensor, rfl.Layer - ), f"unexpected tensor {obj} with raw tensor type {type(obj.raw_tensor)}, expected rfl.Layer" + assert isinstance(obj.raw_tensor, rfl.Layer), ( + f"unexpected tensor {obj} with raw tensor type {type(obj.raw_tensor)}, expected rfl.Layer" + ) obj: Tensor[rfl.Layer] assert obj.raw_tensor.parent or net.name_ctx == obj.raw_tensor return obj.raw_tensor.get_name_in_ctx(ctx=net.name_ctx) diff --git a/returnn/tf/frontend_layers/loop.py b/returnn/tf/frontend_layers/loop.py index 6d15d43d4..aa4cbeee5 100644 --- a/returnn/tf/frontend_layers/loop.py +++ b/returnn/tf/frontend_layers/loop.py @@ -415,9 +415,9 @@ def _map_ref_to_name_ctx(tensor: Tensor, name_ctx: rfl.Layer, initial: Tensor): tensor.raw_tensor.make_all_sub_networks_and_optimize() layer_ctx_list = tensor.raw_tensor.get_abs_name_ctx_list() - assert ( - self.loop.name_ctx in layer_ctx_list - ), f"Loop state {name_ctx} should get a value inside the loop but got {tensor}" + assert self.loop.name_ctx in layer_ctx_list, ( + f"Loop state {name_ctx} should get a value inside the loop but got {tensor}" + ) # We need some special logic for MaskedComputation but maybe also for others later. # This is currently not nice, but I'm not sure about better solutions. 
for i in range(layer_ctx_list.index(self.loop.name_ctx) + 1, len(layer_ctx_list) - 1): diff --git a/returnn/tf/frontend_layers/make_layer.py b/returnn/tf/frontend_layers/make_layer.py index 799bbd3cf..4025a0623 100644 --- a/returnn/tf/frontend_layers/make_layer.py +++ b/returnn/tf/frontend_layers/make_layer.py @@ -74,7 +74,6 @@ def make_layer( raise TypeError(f"{layer}: unexpected type {type(value)} in layer_dict: {layer_dict}") try: - if out is not None: assert isinstance(out, Tensor) elif predefined_out_data is not None: diff --git a/returnn/tf/layers/base.py b/returnn/tf/layers/base.py index 19960755c..dd4fdc936 100644 --- a/returnn/tf/layers/base.py +++ b/returnn/tf/layers/base.py @@ -4,8 +4,9 @@ from __future__ import annotations -from typing import Optional, Dict, List +from typing import Optional, Dict, List, Union import typing +from typing import TYPE_CHECKING import contextlib import numpy import tensorflow as tf @@ -17,6 +18,9 @@ from returnn.tf.util.basic import OutputWithActivation, CustomUpdate, reuse_name_scope from returnn.log import log +if TYPE_CHECKING: + from tensorflow.python.training.saver import BaseSaverBuilder + class LayerBase: """ @@ -188,7 +192,7 @@ def __init__( self.name = name self.network = network self._register_layer() - self.kwargs = None # type: typing.Optional[typing.Dict[str]] # set via self.post_init + self.kwargs: Optional[Dict[str]] = None # set via self.post_init self.target = None self.targets = None if target: @@ -219,12 +223,12 @@ def __init__( "%s: out_dim handling not implemented correctly for this layer" % self ) out_shape # noqa # not used here but in fixup_out_data - self.output_before_activation = None # type: typing.Optional[OutputWithActivation] - self.output_loss = None # type: typing.Optional[tf.Tensor] + self.output_before_activation: Optional[OutputWithActivation] = None + self.output_loss: Optional[tf.Tensor] = None if copy_output_loss_from_source_idx is not None: self.output_loss = sources[copy_output_loss_from_source_idx].output_loss - self.rec_vars_outputs = {} # type: typing.Dict[str,tf.Tensor] - self.search_choices = None # type: typing.Optional[SearchChoices] + self.rec_vars_outputs: Dict[str, tf.Tensor] = {} + self.search_choices: Optional[SearchChoices] = None self._src_common_search_choices = _src_common_search_choices self._initial_output = initial_output self.need_last = need_last @@ -237,14 +241,14 @@ def __init__( # Note that this check is somewhat incomplete # (does not check multiple sources, see _ConcatInputLayer) # and there is no guarantee that a specific layer really uses this correctly. 
- assert sources[0].output.have_dim_tag( - in_dim, unique=True - ), "%s: in_dim %s not found or unique in input %s" % (self, in_dim, sources[0]) + assert sources[0].output.have_dim_tag(in_dim, unique=True), ( + "%s: in_dim %s not found or unique in input %s" % (self, in_dim, sources[0]) + ) self.have_params = False - self.params = {} # type: typing.Dict[str,tf.Variable] - self.saveable_param_replace = ( - {} - ) # type: typing.Dict[tf.Variable,typing.Union['tensorflow.python.training.saver.BaseSaverBuilder.SaveableObject',None]] # see get_saveable_params_dict() # nopep8 + self.params: Dict[str, tf.Variable] = {} + self.saveable_param_replace: Dict[ + tf.Variable, Union["BaseSaverBuilder.SaveableObject", None] + ] = {} # see get_saveable_params_dict() self.reuse_params = reuse_params self.name_scope = name_scope self.param_device = param_device @@ -264,7 +268,7 @@ def __init__( self.control_dependencies_on_output = control_dependencies_on_output self.register_as_extern_data = register_as_extern_data # Stats will be collected by the engine. - self.stats = {} # type: typing.Dict[str,tf.Tensor] + self.stats: Dict[str, tf.Tensor] = {} self._set_prev_state(state) def _set_prev_state(self, state): @@ -516,9 +520,9 @@ def _base_get_out_data_from_opts( # Special case: Input feature or sparse dim looks the same, so overtake it. out_dim = sources_data.feature_dim_or_sparse_dim if out_dim: - assert ( - out_dim.dimension == output.dim - ), f"Layer {name!r} out_dim {out_dim} does not match Data {output} via out_type {out_type}" + assert out_dim.dimension == output.dim, ( + f"Layer {name!r} out_dim {out_dim} does not match Data {output} via out_type {out_type}" + ) if output.sparse: output.sparse_dim = out_dim else: @@ -850,9 +854,9 @@ def transform_config_dict(cls, d, network, get_layer): loss_scale = d.pop("loss_scale", 1.0) if loss_scale != 1.0: if "scale" in loss_opts: - assert ( - loss_opts["scale"] == loss_scale - ), "do not use loss_scale and loss with 'scale' option together" + assert loss_opts["scale"] == loss_scale, ( + "do not use loss_scale and loss with 'scale' option together" + ) loss_opts["scale"] = loss_scale d["loss"] = cls._make_loss( class_name=d.pop("loss", None), opts=loss_opts, network=network, get_layer=get_layer @@ -2099,9 +2103,9 @@ def get_rec_initial_output(cls, batch_dim, name, output, rec_layer, initial_outp src_output = src.output.copy() if src_output.placeholder is not None: zeroed_src_shape = tf_util.get_shape(src_output.placeholder) - zeroed_src_shape = [ + zeroed_src_shape: List[Union[tf.Tensor, int]] = [ zeroed_src_shape[i] for i in range(src_output.batch_ndim) - ] # type: typing.List[typing.Union[tf.Tensor,int]] + ] else: zeroed_src_shape = [] for i, d in enumerate(src_output.batch_shape): @@ -2550,9 +2554,9 @@ def custom_getter(getter, name, *args, **kwargs): :rtype: tf.Variable|tf.Tensor """ if self.shape is not None: - assert tuple(shape) == tuple( - d.dimension for d in self.shape - ), "%s: unexpected shape %r for param %r, expected %r" % (self, shape, name, self.shape) + assert tuple(shape) == tuple(d.dimension for d in self.shape), ( + "%s: unexpected shape %r for param %r, expected %r" % (self, shape, name, self.shape) + ) abs_scope_prefix = base_layer.get_absolute_name_scope_prefix() assert not abs_scope_prefix or abs_scope_prefix.endswith("/") assert name.startswith(abs_scope_prefix) @@ -2609,10 +2613,10 @@ def __init__(self, owner, beam_size, is_decided=False, keep_raw=False): assert beam_size is not None self.owner = owner self._done_src_layer = False - 
self._src_layer = None # type: typing.Optional[LayerBase] - self.src_beams = None # type: typing.Optional[tf.Tensor] # src beam index, (batch, beam) + self._src_layer: Optional[LayerBase] = None + self.src_beams: Optional[tf.Tensor] = None # src beam index, (batch, beam) self.beam_size = beam_size - self.beam_scores = None # type: typing.Optional[tf.Tensor] # (batch, beam) + self.beam_scores: Optional[tf.Tensor] = None # (batch, beam) self.is_decided = is_decided self.keep_raw = keep_raw if not owner.output.beam: @@ -2872,22 +2876,22 @@ def __init__( """ self.base_network = base_network self.use_flatten_frames = use_flatten_frames - self.layer = None # type: typing.Optional[LayerBase] + self.layer: Optional[LayerBase] = None # All are initialized in self.init(). - self.output = None # type: typing.Optional[Data] - self.output_with_activation = None # type: typing.Optional[OutputWithActivation] - self.output_seq_lens = None # type: typing.Optional[tf.Tensor] - self.target = None # type: typing.Optional[Data] - self.target_seq_lens = None # type: typing.Optional[tf.Tensor] - self.output_flat = None # type: typing.Optional[tf.Tensor] - self.output_before_softmax_flat = None # type: typing.Optional[tf.Tensor] + self.output: Optional[Data] = None + self.output_with_activation: Optional[OutputWithActivation] = None + self.output_seq_lens: Optional[tf.Tensor] = None + self.target: Optional[Data] = None + self.target_seq_lens: Optional[tf.Tensor] = None + self.output_flat: Optional[tf.Tensor] = None + self.output_before_softmax_flat: Optional[tf.Tensor] = None if _check_output_before_softmax is not None: self._check_output_before_softmax = _check_output_before_softmax - self.target_flat = None # type: typing.Optional[tf.Tensor] + self.target_flat: Optional[tf.Tensor] = None # Maybe make configurable. For now, same as in our Theano behavior. # The loss_norm_factor is used by Runner._normalize_loss both for normalization per epoch and per batch. # It is e.g. set to 1/sum(target_seq_len), and logic of accumulation is handled in the Runner. - self.loss_norm_factor = None # type: typing.Optional[tf.Tensor] + self.loss_norm_factor: Optional[tf.Tensor] = None self.use_normalized_loss = use_normalized_loss # for the optimizer, per batch self.custom_norm_factor = custom_norm_factor self.custom_inv_norm_factor = custom_inv_norm_factor @@ -3132,18 +3136,21 @@ def _check_init(self): self.output, self.target, ) - assert ( - self.target.ndim_dense == self.output.ndim_dense - ), "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output) + assert self.target.ndim_dense == self.output.ndim_dense, ( + "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output) + ) expected_output_dim = self.get_auto_output_layer_dim(self.target.feature_dim_or_sparse_dim) - assert ( - expected_output_dim.dimension == self.output.dim - ), "Expected output dim is %r but the output has dim %r. " % ( - expected_output_dim, - self.output.feature_dim_or_sparse_dim, - ) + "Target: %s, output: %s" % ( - self.target, - self.output, + assert expected_output_dim.dimension == self.output.dim, ( + "Expected output dim is %r but the output has dim %r. 
" + % ( + expected_output_dim, + self.output.feature_dim_or_sparse_dim, + ) + + "Target: %s, output: %s" + % ( + self.target, + self.output, + ) ) if self.base_network.get_config().bool("debug_runtime_sanity_checks", False): with tf.name_scope("Loss_debug_runtime_sanity_checks"): diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 5067abd72..60a368a3b 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Optional, Union, Sequence, List, Tuple, Dict +from typing import Callable, Optional, Union, Sequence, List, Tuple, Dict import typing import tensorflow as tf import contextlib @@ -126,7 +126,7 @@ def concat_sources(src_layers, out_dim=None, allow_broadcast_all_sources=NotSpec data.placeholder = tf.concat( axis=data.feature_dim_axis, values=[layer_data.placeholder for layer_data in layers_data] ) - axes_split_info = [None] * data.batch_ndim # type: typing.List[typing.Optional[typing.List[int]]] + axes_split_info: List[Optional[List[int]]] = [None] * data.batch_ndim axes_split_info[data.feature_dim_axis] = [layer_data.dim for layer_data in layers_data] tf_util.set_param_axes_split_info(data.placeholder, axes_split_info) # Note: We will loose this info for any further op (e.g. dropout, activation, etc). Should be better... @@ -294,7 +294,7 @@ def __init__( elif mask == "dropout": assert dropout > 0 self.dropout = dropout - self.input_data = None # type: typing.Optional[Data] + self.input_data: Optional[Data] = None if self.sources: self.input_data = concat_sources_with_opt_dropout( self.sources, @@ -509,9 +509,7 @@ def get_out_data_from_opts(cls, name, sources, out_dim=None, **kwargs): assert sources sources, axes = zip(*sources) # unzip axes_int = [layer.output.get_axis_from_description(axis) for (layer, axis) in zip(sources, axes)] - concat_dim_tags = [ - layer.output.dim_tags[axis] for (layer, axis) in zip(sources, axes_int) - ] # type: typing.List[Dim] + concat_dim_tags: List[Dim] = [layer.output.dim_tags[axis] for (layer, axis) in zip(sources, axes_int)] if any(tag.dimension is None for tag in concat_dim_tags): dimension = None else: @@ -707,8 +705,8 @@ def __init__(self, search_choices_layer, sources, **kwargs): self.output = src.output.copy_as_batch_major() self.rec_vars_outputs = src.rec_vars_outputs.copy() src_search_choices = src.get_search_choices() - self.transform_func = None # type: typing.Optional[typing.Callable[[tf.Tensor],tf.Tensor]] - self.search_choices_seq = None # type: typing.Optional[typing.List[SearchChoices]] + self.transform_func: Optional[Callable[[tf.Tensor], tf.Tensor]] = None + self.search_choices_seq: Optional[List[SearchChoices]] = None if not search_choices: assert not src_search_choices assert not self.output.beam @@ -726,13 +724,7 @@ def __init__(self, search_choices_layer, sources, **kwargs): assert src_search_choices in search_choices_seq, self.network.debug_search_choices( self.search_choices_layer ) or ( - ( - "%s: No common search base:\n" - "from layer %s\n" - "search choices %s,\n" - "to layer %s\n" - "search choices\n%s." - ) + "%s: No common search base:\nfrom layer %s\nsearch choices %s,\nto layer %s\nsearch choices\n%s." 
% (self, src, src_search_choices, self.search_choices_layer, pformat(search_choices_seq)) ) search_choices_seq = search_choices_seq[: search_choices_seq.index(src_search_choices)] @@ -4436,12 +4428,13 @@ def _get_merge_axes(cls, axes, keep_order, input_data, name): :rtype: list[int] """ if keep_order: - assert isinstance(axes, (tuple, list, typing.Sequence)) and not isinstance( - axes, str - ), "%s: axes %r must be a list or tuple, to have a well defined order in input %s" % ( - name, - axes, - input_data, + assert isinstance(axes, (tuple, list, typing.Sequence)) and not isinstance(axes, str), ( + "%s: axes %r must be a list or tuple, to have a well defined order in input %s" + % ( + name, + axes, + input_data, + ) ) axes_ = [] for axis in axes: @@ -5562,11 +5555,12 @@ def __init__(self, repetitions, axis="T", out_dim=None, **kwargs): repetitions_data = repetitions_data.copy_add_dim_by_tag(axis_dim_tag, unbroadcast=True) repetitions_axis = repetitions_data.get_axis_from_description(axis, allow_int=False) assert repetitions_data.ndim == 1, "Repetitions %r must only have at most one non-batch axis" % repetitions - assert ( - repetitions_data.batch_shape[repetitions_axis] == self.input_data.batch_shape[input_axis] - ), "Axis mismatch between input (%i) and repetitions (%i)" % ( - self.input_data.batch_shape[input_axis], - repetitions_data.batch_shape[repetitions_axis], + assert repetitions_data.batch_shape[repetitions_axis] == self.input_data.batch_shape[input_axis], ( + "Axis mismatch between input (%i) and repetitions (%i)" + % ( + self.input_data.batch_shape[input_axis], + repetitions_data.batch_shape[repetitions_axis], + ) ) assert self.output.have_batch_axis() == ( @@ -6267,9 +6261,9 @@ def __init__( from returnn.util import BehaviorVersion padding = padding.upper() if isinstance(padding, str) else padding - assert padding in ["SAME", "VALID", "SAME_STATIC"] or isinstance( - padding, (int, tuple, list) - ), f"{self}: got unsupported padding {padding}" + assert padding in ["SAME", "VALID", "SAME_STATIC"] or isinstance(padding, (int, tuple, list)), ( + f"{self}: got unsupported padding {padding}" + ) assert "out_type" not in kwargs, "don't set out_type explicitly for this layer" assert len(filter_size) in (1, 2, 3), "only 1D conv, 2D conv or 3D conv supported" super(ConvLayer, self).__init__(in_dim=in_dim, out_dim=out_dim, **kwargs) @@ -6285,9 +6279,9 @@ def __init__( assert len(dilation_rate) == len(filter_size) assert not self.input_data.sparse assert self.input_data.have_batch_axis() - assert ( - self.input_data.have_feature_axis() - ), "this should be our single input feature dim now. otherwise use input_add_feature_dim" + assert self.input_data.have_feature_axis(), ( + "this should be our single input feature dim now. otherwise use input_add_feature_dim" + ) input_data, num_batch_dims = self.transform_input( self.input_data, network=self.network, @@ -7117,9 +7111,9 @@ def __init__( super(PoolLayer, self).__init__(in_dim=in_dim, out_dim=out_dim, **kwargs) assert not self.input_data.sparse assert self.input_data.have_batch_axis() - assert ( - self.input_data.have_feature_axis() - ), "this should be our single input feature dim now. otherwise use input_add_feature_dim" + assert self.input_data.have_feature_axis(), ( + "this should be our single input feature dim now. 
otherwise use input_add_feature_dim" + ) if in_dim and out_dim: assert in_dim == out_dim elif in_dim: @@ -7381,9 +7375,9 @@ def __init__( out_dim # noqa # via get_out_data_from_opts assert not self.input_data.sparse assert self.input_data.have_batch_axis() - assert ( - self.input_data.have_feature_axis() - ), "this should be our single input feature dim now. otherwise use input_add_feature_dim" + assert self.input_data.have_feature_axis(), ( + "this should be our single input feature dim now. otherwise use input_add_feature_dim" + ) input_data, num_batch_dims = ConvLayer.transform_input( self.input_data, network=self.network, @@ -7404,14 +7398,15 @@ def __init__( remove_padding = [remove_padding] * len(spatial_axes) if not isinstance(output_padding, (list, tuple)): output_padding = [output_padding] * len(spatial_axes) - assert ( - len(spatial_axes) == len(filter_size) == len(strides) == len(remove_padding) == len(output_padding) - ), "%s: expected %i-D transposed-conv for input %r but got filter %r and strides %r" % ( - self, - len(spatial_axes), - input_data, - filter_size, - strides, + assert len(spatial_axes) == len(filter_size) == len(strides) == len(remove_padding) == len(output_padding), ( + "%s: expected %i-D transposed-conv for input %r but got filter %r and strides %r" + % ( + self, + len(spatial_axes), + input_data, + filter_size, + strides, + ) ) assert len(spatial_axes) in [1, 2], "%s: %i-D not yet implemented..." % (self, len(spatial_axes)) x = input_data.placeholder @@ -8775,9 +8770,9 @@ def __init__( red1, red2, ) - assert len(a_reduce_axes) == len( - b_reduce_axes - ), "%s: sources %r, red1 %r, red2 %r, reduce axes must match in count" % (self, self.sources, red1, red2) + assert len(a_reduce_axes) == len(b_reduce_axes), ( + "%s: sources %r, red1 %r, red2 %r, reduce axes must match in count" % (self, self.sources, red1, red2) + ) if ( (BehaviorVersion.get() >= 3 and (var1 is NotSpecified or var2 is NotSpecified)) or var1 == "auto" @@ -9150,9 +9145,9 @@ def get_out_data_from_opts( raise Exception( "%s %r: " % (cls.__name__, name) + "%s not found in sources %r" % (red_axis_desc, sources) ) - assert len(a_reduce_axes) == len( - b_reduce_axes - ), "%s: sources %r, red1 %r, red2 %r, reduce axes must match in count" % (name, sources, red1, red2) + assert len(a_reduce_axes) == len(b_reduce_axes), ( + "%s: sources %r, red1 %r, red2 %r, reduce axes must match in count" % (name, sources, red1, red2) + ) if ( (BehaviorVersion.get() >= 3 and (var1 is NotSpecified or var2 is NotSpecified)) or var1 == "auto" @@ -10178,9 +10173,9 @@ def __init__( self.condition_desc = condition self.condition_layer = self._make_layer("condition", self.condition_desc) self.true_layer_desc = true_layer - self.true_layer = None # type: typing.Optional[LayerBase] + self.true_layer: Optional[LayerBase] = None self.false_layer_desc = false_layer - self.false_layer = None # type: typing.Optional[LayerBase] + self.false_layer: Optional[LayerBase] = None assert self.condition_layer.output.batch_ndim == 0 and self.condition_layer.output.dtype == "bool" self._extra_out_templates = {k: v[0] for k, v in _extra_out.items()} x, extra_out, sizes = tf_util.cond( @@ -12070,7 +12065,7 @@ def __init__( for (key, output) in extra.items() } extra = {key: output.copy_as_batch_spatial_major() for (key, output) in extra.items()} - self.extra = extra # type: typing.Dict[str,Data] + self.extra: Dict[str, Data] = extra self.dump_whole_batches = dump_whole_batches self.num_seqs_written = 0 ndim = data.ndim @@ -12454,9 +12449,9 @@ def 
__init__(self, pos_weight=None, **kwargs): def _check_init(self): assert self.target is not None - assert ( - self.target.batch_ndim == self.output.batch_ndim - ), "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output) + assert self.target.batch_ndim == self.output.batch_ndim, ( + "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output) + ) def get_value(self): """ @@ -13020,7 +13015,7 @@ def __init__( self.divide_beam_size = divide_beam_size self.subtract_average_loss = subtract_average_loss self.loss_correction_grad_only = loss_correction_grad_only - self.search_choices = None # type: typing.Optional[SearchChoices] + self.search_choices: Optional[SearchChoices] = None @classmethod def transform_config_dict(cls, d, network, get_layer): @@ -13120,9 +13115,9 @@ def _check_init(self): Does some checks on self.target and self.output, e.g. if the dense shapes matches. You can overwrite this if those checks don't make sense for your derived loss class. """ - assert ( - self.target.ndim_dense == self.output.ndim_dense - ), "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output) + assert self.target.ndim_dense == self.output.ndim_dense, ( + "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output) + ) expected_output_dim = self._embedding_dimension * (self.target.shape[1] // self._nr_of_sources) assert expected_output_dim == self.output.dim, "Expected output dim is %i but the output has dim %r. " % ( expected_output_dim, @@ -13822,9 +13817,9 @@ def sampled_loss_fn(): else: loss_fn = tf.nn.sampled_softmax_loss - assert ( - self.layer.params["W"].shape[0] == self.target.dim - ), "Expect weight matrix of shape [num_classes, dim]" + assert self.layer.params["W"].shape[0] == self.target.dim, ( + "Expect weight matrix of shape [num_classes, dim]" + ) out = loss_fn( weights=self.layer.params["W"].read_value(), # (num_classes,D). biases=self.layer.params["b"].read_value(), # (num_classes). diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 8285dd61f..45da91942 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -6,6 +6,7 @@ import contextlib import typing +from typing import Dict, Optional, Tuple, Union import tensorflow as tf import returnn.tf.compat as tf_compat @@ -1037,7 +1038,8 @@ def _get_output_cell(self, cell): scope=tf_compat.v1.get_variable_scope(), ) elif rnn_contrib and isinstance( - cell, (rnn_contrib.FusedRNNCell, rnn_contrib.LSTMBlockWrapper) # noqa # e.g. LSTMBlockFusedCell + cell, + (rnn_contrib.FusedRNNCell, rnn_contrib.LSTMBlockWrapper), # noqa # e.g. LSTMBlockFusedCell ): # Will get (time,batch,ydim). assert self._max_seq_len is None @@ -1280,9 +1282,9 @@ def get_last_hidden_state(self, key): :param str|int|None key: :rtype: tf.Tensor """ - assert ( - self._last_hidden_state is not None - ), "last-hidden-state not implemented/supported for this layer-type. try another unit. see the code." + assert self._last_hidden_state is not None, ( + "last-hidden-state not implemented/supported for this layer-type. try another unit. see the code." 
+ ) return RnnCellLayer.get_state_by_key(self._last_hidden_state, key=key) @classmethod @@ -1431,9 +1433,7 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n ) self._last_frames = {} # type: typing.Dict[str,Data] self._initial_outputs = None # type: typing.Optional[typing.Dict[str,tf.Tensor]] - self._initial_extra_outputs = ( - None - ) # type: typing.Optional[typing.Dict[str,typing.Dict[str,typing.Union[tf.Tensor,typing.Tuple[tf.Tensor,...]]]]] # nopep8 + self._initial_extra_outputs: Optional[Dict[str, Dict[str, Union[tf.Tensor, Tuple[tf.Tensor, ...]]]]] = None # input_layers_moved_out, output_layers_moved_out and layers_in_loop include (used) sub-layers as separate # entries, this way in- and outputting them to the loop via TensorArrays will be handled just as for normal @@ -1608,14 +1608,9 @@ def __repr__(lself): while parent and parent.parent: parent_names.insert(0, parent.parent_name or "?") parent = parent.parent - return ( - "(" - "allow_uninitialized_template %r, " - "parents %r)" - % ( - lself.allow_uninitialized_template, - " <- ".join(parent_names) or None, - ) + return "(allow_uninitialized_template %r, parents %r)" % ( + lself.allow_uninitialized_template, + " <- ".join(parent_names) or None, ) def _add_uninitialized_count(self): @@ -2141,16 +2136,17 @@ def get_input_moved_out(name): layer = self.input_layers_net.layers[layer_name] assert isinstance(layer, LayerBase) if layer_name not in inputs_moved_out_tas: - assert not layer.output.mark_same_time( - self._time_dim_tags - ), "%s does not expect to have matching time dim to %s" % (layer, self.parent_rec_layer) - assert ( - name != "output" and not prev - ), "Time dim does not match: RecLayer %s (%r) vs sub layer %s (%r)." % ( - self.parent_rec_layer, - self.parent_rec_layer.output.get_time_dim_tag(), - layer, - layer.output.get_time_dim_tag(), + assert not layer.output.mark_same_time(self._time_dim_tags), ( + "%s does not expect to have matching time dim to %s" % (layer, self.parent_rec_layer) + ) + assert name != "output" and not prev, ( + "Time dim does not match: RecLayer %s (%r) vs sub layer %s (%r)." + % ( + self.parent_rec_layer, + self.parent_rec_layer.output.get_time_dim_tag(), + layer, + layer.output.get_time_dim_tag(), + ) ) return layer output = layer.output.copy_template_excluding_time_dim().copy_template_set_ctx(self.net.control_flow_ctx) @@ -2376,9 +2372,9 @@ def _check_output_template_shape(self): assert output_template.output.dim == self.parent_rec_layer.output.dim assert self.parent_rec_layer.output.time_dim_axis == 0 assert not output_template.output.has_axis(self.time_dim_tag) - assert ( - output_template.output.batch_shape == self.parent_rec_layer.output.batch_shape[1:] - ), "see RecLayer.get_out_data_from_opts()" + assert output_template.output.batch_shape == self.parent_rec_layer.output.batch_shape[1:], ( + "see RecLayer.get_out_data_from_opts()" + ) def get_init_loop_vars(self): """ @@ -3014,9 +3010,9 @@ def get_choice_source_batches(): needed_outputs.add("end") assert tf.as_dtype(end_template.output.dtype) is tf.bool else: - assert ( - have_known_seq_len - ), "You need to have an 'end' layer in your rec subnet if the generated seq len is unknown." + assert have_known_seq_len, ( + "You need to have an 'end' layer in your rec subnet if the generated seq len is unknown." 
+ ) # noinspection PyProtectedMember if self.parent_rec_layer._optimize_move_layers_out: @@ -3358,11 +3354,12 @@ def maybe_transform(layer): from .basic import SelectSearchSourcesLayer prev_end_layer = choices.translate_to_this_search_beam(prev_end_layer) - assert isinstance( - prev_end_layer, SelectSearchSourcesLayer - ), "unexpected search choices: cur end %r, prev end %r" % ( - choices, - prev_end_layer.get_search_choices(), + assert isinstance(prev_end_layer, SelectSearchSourcesLayer), ( + "unexpected search choices: cur end %r, prev end %r" + % ( + choices, + prev_end_layer.get_search_choices(), + ) ) prev_end_flag = prev_end_layer.output.placeholder with tf.name_scope("dyn_seq_len"): @@ -3475,14 +3472,15 @@ def cond(i, net_vars, acc_tas, seq_len_info=None, allow_inf_max_len=False): assert fixed_seq_len is not None seq_len = fixed_seq_len if output_beam: - assert ( - not input_beam or input_beam == output_beam - ), "%s: input beam %r, output beam %r, sources %r, target %r" % ( - self.parent_rec_layer, - input_beam, - output_beam, - self.parent_rec_layer.sources, - self.parent_rec_layer.target, + assert not input_beam or input_beam == output_beam, ( + "%s: input beam %r, output beam %r, sources %r, target %r" + % ( + self.parent_rec_layer, + input_beam, + output_beam, + self.parent_rec_layer.sources, + self.parent_rec_layer.target, + ) ) assert output_template.output.batch.beam == output_beam time_dim_tag = time_dim_tag.get_for_batch_ctx( @@ -3791,9 +3789,9 @@ def _opt_search_resolve(self, layer_name, acc_ta, final_net_vars, seq_len, searc if end_layer_choice.name.startswith("prev:"): # Logic from maybe_transform. It would be translated to the current beam. end_layer_choice = self.net.layers[end_layer_choice.name[len("prev:") :]] - assert ( - end_layer_choice in choice_seq_in_frame - ), "End layer must not have a beam independent from output layer '{}'.".format(layer_name) + assert end_layer_choice in choice_seq_in_frame, ( + "End layer must not have a beam independent from output layer '{}'.".format(layer_name) + ) end_layer_choice_index = choice_seq_in_frame.index(end_layer_choice) choices_seq_until_end_layer = choice_seq_in_frame[:end_layer_choice_index] @@ -5856,12 +5854,13 @@ def get_out_data_from_opts(cls, name, sources, network, axis=None, declare_rec_t if out_dim.is_dim_known(): # usually the case except at template construction assert out_dim != rec_time_dim # rec_time_dim is unknown, so it cannot be the same if out_dim != rec_time_dim: - assert ( - declare_rec_time - ), "%s %r: must either set known axis on rec %s or enable declare_rec_time" % ( - cls.__name__, - name, - rec_time_dim, + assert declare_rec_time, ( + "%s %r: must either set known axis on rec %s or enable declare_rec_time" + % ( + cls.__name__, + name, + rec_time_dim, + ) ) rec_time_dim.declare_same_as(out_dim) out.mark_same_time(out_dim, must_match=True) @@ -6132,12 +6131,13 @@ def __init__( base_beam_in = tf.shape(scores_base)[1] # 1 in first frame, then beam_in scores_beam_in = tf.shape(scores_in)[0] // net_batch_dim beam_in = self.sources[0].output.beam.beam_size - assert ( - beam_in == base_search_choices.beam_size - ), "%r: source %r beam-size unexpected from base choice %r" % ( - self, - self.sources[0], - base_search_choices, + assert beam_in == base_search_choices.beam_size, ( + "%r: source %r beam-size unexpected from base choice %r" + % ( + self, + self.sources[0], + base_search_choices, + ) ) # About incoming beam size: # base_beam_in - 1 in first frame, then beam_in @@ -7510,9 +7510,9 @@ def 
_weights_remaining_axes(cls, base, weights, auto_squeeze, exception_prefix): base_rem_axes = base.get_axes(exclude_batch=True, exclude_time=True) base_rem_axes.remove(base.feature_dim_axis) weights_rem_axes = weights.get_axes(exclude_batch=True) - assert ( - weights.time_dim_axis is not None - ), f"{exception_prefix}: base {base}, weights {weights}, need time_dim_axis in weights" + assert weights.time_dim_axis is not None, ( + f"{exception_prefix}: base {base}, weights {weights}, need time_dim_axis in weights" + ) weights_axis_to_reduce = cls._weights_time_axis_to_reduce(weights=weights, base=base) assert weights.batch_shape[weights_axis_to_reduce] == base.batch_shape[base.time_dim_axis] weights_rem_axes.remove(weights_axis_to_reduce) @@ -9088,13 +9088,13 @@ def __init__( new_size, new_time, idxs = None, None, None if mask: if self.network.is_inside_rec_layer(): - assert ( - mask.output.shape == () and mask.output.dtype == "bool" - ), "%s: invalid mask %s (inside rec loop)" % (self, mask) + assert mask.output.shape == () and mask.output.dtype == "bool", ( + "%s: invalid mask %s (inside rec loop)" % (self, mask) + ) else: - assert ( - mask.output.have_time_axis() and mask.output.shape == (None,) and mask.output.dtype == "bool" - ), "%s: invalid mask %s (outside rec loop)" % (self, mask) + assert mask.output.have_time_axis() and mask.output.shape == (None,) and mask.output.dtype == "bool", ( + "%s: invalid mask %s (outside rec loop)" % (self, mask) + ) assert in_spatial_dim and out_spatial_dim mask_data = mask.output.copy_as_time_major() mask_t = where_bc(mask_data.placeholder, mask_data.get_sequence_mask(), tf.convert_to_tensor(False)) @@ -9785,9 +9785,9 @@ def __init__(self, mask, **kwargs): with same_control_flow_ctx(src_layer.output.placeholder): src = src_layer.output.copy_as_bt_or_tb_major() mask_out = self.mask.output - assert ( - mask_out.shape == () and mask_out.batch_shape == (None,) and mask_out.dtype == "bool" - ), "%s: invalid mask %s (inside rec loop)" % (self, self.mask) + assert mask_out.shape == () and mask_out.batch_shape == (None,) and mask_out.dtype == "bool", ( + "%s: invalid mask %s (inside rec loop)" % (self, self.mask) + ) prev_t = self._rec_previous_layer.rec_vars_outputs["t"] # [B] t = prev_t + tf.cast(mask_out.placeholder, tf.int32) # [B] self.rec_vars_outputs["t"] = t @@ -11192,9 +11192,9 @@ def __init__( and is_axis_from_description_recurrent(key_value_spatial_dim, network=self.network, data=self.input_data) ): length = self.network.get_rec_step_index() + 1 - assert ( - key_value_spatial_dim_.dimension is None - ), f"{self}: unexpected kv spatial dim {key_value_spatial_dim_}" + assert key_value_spatial_dim_.dimension is None, ( + f"{self}: unexpected kv spatial dim {key_value_spatial_dim_}" + ) assert key_value_spatial_dim_.dyn_size_ext is not None # See CumConcatLayer for similar logic if key_value_spatial_dim_.dyn_size_ext.placeholder is None: diff --git a/returnn/tf/native_op.py b/returnn/tf/native_op.py index c85c30af6..754d560b7 100644 --- a/returnn/tf/native_op.py +++ b/returnn/tf/native_op.py @@ -283,9 +283,7 @@ def make_dim_str(c): // otherwise it will trigger an assertion. 
if (IsRefType(context->input_dtype({in_idx}))) context->forward_ref_input_to_ref_output({in_idx}, {out_idx}); - """.format( - in_idx=in_idx, out_idx=out_idx - ) + """.format(in_idx=in_idx, out_idx=out_idx) code_set_io = "" for in_idx, v in enumerate(in_info): ndim = len(v["shape"]) diff --git a/returnn/tf/network.py b/returnn/tf/network.py index d47ab4f9f..aabf8098e 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -3,7 +3,9 @@ """ from __future__ import annotations -from typing import Optional, Any, Protocol, List, Tuple, Dict + +from typing import Callable, List, Optional, Any, Protocol, Tuple, Dict, TYPE_CHECKING, Union + import tensorflow as tf import sys import re @@ -19,6 +21,11 @@ from returnn.tf.util.data import Data from returnn.util import basic as util +if TYPE_CHECKING: + from returnn.config import Config + from returnn.tf.layers.base import SearchChoices + from returnn.tf.util.data import BatchInfo + class DataNotFound(Exception): """ @@ -39,8 +46,8 @@ def __init__(self, data=None, default_input="data", default_target="classes"): :param None|dict[str,dict[str]] data: optional init kwargs for Data """ super().__init__() - self._config = None # type: typing.Optional["returnn.config.Config"] - self._batch_info = None # type: typing.Optional["returnn.tf.util.data.BatchInfo"] + self._config: typing.Optional[Config] = None + self._batch_info: typing.Optional[BatchInfo] = None self.default_input = default_input self.default_target = default_target self.extra_added_keys = set() # set[str] @@ -369,8 +376,7 @@ def _extern_data_types_from_config(config): print("Warning: Using extern_data and will ignore num_inputs/num_outputs in config.", file=log.v2) else: log.print_deprecation_warning( - "Using num_inputs/num_outputs instead of extern_data is deprecated" - " and might be removed in future versions" + "Using num_inputs/num_outputs instead of extern_data is deprecated and might be removed in future versions" ) num_inputs, num_outputs = _num_inputs_outputs_from_config(config) data_dims = num_outputs.copy() @@ -502,7 +508,7 @@ class _NetworkConstructionStack: """ def __init__(self): - self.layers = [] # type: typing.List[str] + self.layers: typing.List[str] = [] self.in_flat_construct_count = 0 def append(self, layer_name): @@ -645,33 +651,31 @@ def __init__( self.extra_deps_in_extra = False self.extra_only_template = False self.is_root_in_ctx = not parent_net # default. 
might be overwritten - self.extra_nets = {} # type: typing.Dict[str,TFNetwork] - self.subnets = {} # type: typing.Dict[str,Subnetwork] + self.extra_nets: Dict[str, TFNetwork] = {} + self.subnets: Dict[str, Subnetwork] = {} self._selected_train_layers = None self._construction_stack = _NetworkConstructionStack() self.layers_desc: Dict[str, Dict[str, Any]] = {} self.layers: Dict[str, LayerBase] = {} - self.losses_dict = {} # type: typing.Dict[str,LossHolder] - self.total_loss = None # type: typing.Optional[tf.Tensor] - self.total_constraints = None # type: typing.Optional[tf.Tensor] - self.total_objective = None # type: typing.Optional[tf.Tensor] - self._global_train_step = None # type: typing.Optional[tf.Tensor] - self._global_train_step_var = None # type: typing.Optional[tf.Variable] + self.losses_dict: Dict[str, LossHolder] = {} + self.total_loss: Optional[tf.Tensor] = None + self.total_constraints: Optional[tf.Tensor] = None + self.total_objective: Optional[tf.Tensor] = None + self._global_train_step: Optional[tf.Tensor] = None + self._global_train_step_var: Optional[tf.Variable] = None self.epoch_step = None - self.saver = None # type: typing.Optional[tf.compat.v1.train.Saver] - self.extra_vars_to_save = [] # type: typing.List[tf.Variable] + self.saver: Optional[tf.compat.v1.train.Saver] = None + self.extra_vars_to_save: List[tf.Variable] = [] self.recurrent = False - self._assigner_cache = {} # type: typing.Dict[tf.Variable,tf_util.VariableAssigner] + self._assigner_cache: Dict[tf.Variable, tf_util.VariableAssigner] = {} self.concat_sources_dropout_cache: Dict[ Tuple[Tuple[LayerBase, ...], Dim, float, Optional[Tuple[Optional[int], ...]]], Data ] = {} - self._merge_all_summaries = None # type: typing.Optional[tf.Tensor] - self._graph_reset_callbacks = [] # type: typing.List[typing.Callable] - self._run_opts = {} # type: typing.Dict[str, typing.Any] - self._run_finished_callbacks = [] # type: typing.List[typing.Callable] - self._map_search_beam_to_search_choices = ( - {} - ) # type: typing.Dict[tf_util.SearchBeam,"returnn.tf.layers.base.SearchChoices"] + self._merge_all_summaries: Optional[tf.Tensor] = None + self._graph_reset_callbacks: List[Callable] = [] + self._run_opts: Dict[str, Any] = {} + self._run_finished_callbacks: List[Callable] = [] + self._map_search_beam_to_search_choices: Dict[tf_util.SearchBeam, SearchChoices] = {} def __repr__(self): s = "TFNetwork %r" % self.name @@ -1308,15 +1312,16 @@ def _create_layer(self, name, layer_class, **layer_desc): layer.output.sanity_check() # The axes should not have moved now. 
output_special_axes = layer.output.get_special_axes_dict() - assert ( - output_template_special_axes == output_special_axes - ), "%s %r: not equal: %r == %r, from data %r -> %r" % ( - layer_class.__name__, - name, - output_template_special_axes, - output_special_axes, - output_template, - layer.output, + assert output_template_special_axes == output_special_axes, ( + "%s %r: not equal: %r == %r, from data %r -> %r" + % ( + layer_class.__name__, + name, + output_template_special_axes, + output_special_axes, + output_template, + layer.output, + ) ) except TypeError: help_on_type_error_wrong_args(cls=layer_class, kwargs=list(layer_desc.keys())) @@ -1486,7 +1491,7 @@ def get_losses_initialized(self, reduce_func=None, with_total=False): else: total_loss = None total_constraints = None - losses_multi_dict = {} # type: typing.Dict[str,typing.List[typing.Tuple[typing.Optional[str],LossHolder]]] + losses_multi_dict: Dict[str, List[Tuple[Optional[str], LossHolder]]] = {} # self.layers also include extra net layers and sub layers, see add_layer. for name, layer in sorted(self.layers.items()): assert isinstance(layer, LayerBase) @@ -1869,14 +1874,15 @@ def _resolve_layer(layer_): # All end points must be mapped now. for layer in end_points: - assert ( - layer in mapped_layers - ), "end point %r not mapped.\n end points:\n%s\n mapped:\n%s\n blacklist:\n%s\n starting points:\n%s" % ( - layer, - pformat(end_points), - pformat(mapped_layers), - pformat(blacklist), - pformat(starting_points), + assert layer in mapped_layers, ( + "end point %r not mapped.\n end points:\n%s\n mapped:\n%s\n blacklist:\n%s\n starting points:\n%s" + % ( + layer, + pformat(end_points), + pformat(mapped_layers), + pformat(blacklist), + pformat(starting_points), + ) ) # Assign flatten_with_seq_len_mask cache to mapped layers. for layer, new_layer in mapped_layers.items(): @@ -2402,9 +2408,7 @@ def set_param_values_by_dict(self, values_dict, ignore_non_existing=False, **kwa Note that this excludes auxiliary params. 
""" - layers = { - layer.get_absolute_name(): layer for layer in self.get_all_layers_deep() - } # type: typing.Dict[str,LayerBase] + layers: Dict[str, LayerBase] = {layer.get_absolute_name(): layer for layer in self.get_all_layers_deep()} for layer_name, layer_values_dict in values_dict.items(): if layer_values_dict: if ignore_non_existing and layer_name not in layers: @@ -4091,9 +4095,9 @@ def _prepare(self): self._error_value = self._layer._cond_only_on_eval_opt(self.loss.get_error, default_value=0.0) else: self._error_value = self.loss.get_error() - assert ( - self._loss_value is not None or self._error_value is not None - ), "layer %r loss %r return None for loss and error" % (self._layer, self.loss) + assert self._loss_value is not None or self._error_value is not None, ( + "layer %r loss %r return None for loss and error" % (self._layer, self.loss) + ) if self._norm_factor is None: self._norm_factor = self.loss.get_normalization_factor() loss_value = self._loss_value @@ -4515,12 +4519,12 @@ def __init__( # All variables in the checkpoint: self.var_ckpt_names = set(self.reader.get_variable_to_shape_map()) # type: typing.Set[str] # All variables of the model to be loaded: - self.var_net_names = { + self.var_net_names: Dict[str, Union[tf.Variable, Any]] = { self._get_param_name(v): v for v in self.saveable_params - } # type: typing.Dict[str,typing.Union[tf.Variable,typing.Any]] + } # Model variables missing in the checkpoint: - self.missing_var_names = [] # type: typing.List[str] - self.missing_non_critical_var_names = [] # type: typing.List[str] + self.missing_var_names: List[str] = [] + self.missing_non_critical_var_names: List[str] = [] for name, v in sorted(self.var_net_names.items()): if name in self.var_ckpt_names: continue @@ -4702,10 +4706,10 @@ def get_variable_value_map(self): "rnn/lstm_cell/bias": "lstm_cell/bias", "rnn/lstm_cell/kernel": "lstm_cell/kernel", ( - "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/" "cudnn_compatible_lstm_cell/bias" + "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/bias" ): "lstm_fused_cell/bias", ( - "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/" "cudnn_compatible_lstm_cell/kernel" + "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/kernel" ): "lstm_fused_cell/kernel", } @@ -4877,7 +4881,7 @@ def __init__(self, prefix, target="lstm_block_wrapper/"): self.target = target self.keys = [target + "bias", target + "kernel"] self.prefix = prefix - self.data = None # type: typing.Optional[typing.Dict[str,numpy.ndarray]] + self.data: typing.Optional[typing.Dict[str, numpy.ndarray]] = None # noinspection PyMethodParameters def _load(sself): @@ -5140,8 +5144,7 @@ class CustomLoadParamFunc(Protocol): def __call__( self, *, name: str, shape: Tuple[int], reader: tf.compat.v1.train.NewCheckpointReader - ) -> Optional[numpy.ndarray]: - ... + ) -> Optional[numpy.ndarray]: ... 
def set_custom_post_init(var, func): diff --git a/returnn/tf/updater.py b/returnn/tf/updater.py index 3ac65a69e..353056e8b 100644 --- a/returnn/tf/updater.py +++ b/returnn/tf/updater.py @@ -219,9 +219,9 @@ def get_current_step_learning_rate(self): learning_rate_function = self.config.typed_dict.get("dynamic_learning_rate") signature = inspect.signature(learning_rate_function) - assert any( - [arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()] - ), "please specify **kwargs in dynamic_learning_rate for future compatibility" + assert any([arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()]), ( + "please specify **kwargs in dynamic_learning_rate for future compatibility" + ) if "epoch" in signature.parameters: raise NotImplementedError("TF updater: dynamic_learning_rate with epoch not supported currently") lr = learning_rate_function( diff --git a/returnn/tf/util/basic.py b/returnn/tf/util/basic.py index 766ec8856..ffa913f42 100644 --- a/returnn/tf/util/basic.py +++ b/returnn/tf/util/basic.py @@ -1799,7 +1799,7 @@ def dropout( x = tf.convert_to_tensor(x, name="x") assert isinstance(x, tf.Tensor) if isinstance(keep_prob, (float, int)) and not 0 < keep_prob <= 1: - raise ValueError("keep_prob must be a scalar tensor or a float in the " "range (0, 1], got %g" % keep_prob) + raise ValueError("keep_prob must be a scalar tensor or a float in the range (0, 1], got %g" % keep_prob) # Do nothing if we know keep_prob == 1 if isinstance(keep_prob, (float, int)) and keep_prob == 1: return x @@ -2492,9 +2492,9 @@ def get_common_shape(values, ignore_axes=(), allow_broadcast_all_sources=NotSpec import numpy assert len(values) > 0 - assert all( - [isinstance(value, (tf.Tensor, tf.Variable, float, int, numpy.number)) for value in values] - ), "types %r" % ([type(v) for v in values]) + assert all([isinstance(value, (tf.Tensor, tf.Variable, float, int, numpy.number)) for value in values]), ( + "types %r" % ([type(v) for v in values]) + ) # Filter out scalars. values = [value for value in values if isinstance(value, (tf.Tensor, tf.Variable))] assert all([value.shape.ndims is not None for value in values]), "some unknown ndim" @@ -2523,14 +2523,15 @@ def get_common_shape(values, ignore_axes=(), allow_broadcast_all_sources=NotSpec common_shape[axis] = static_dim else: # common_shape is int assert isinstance(common_shape[axis], int) - assert ( - common_shape[axis] == static_dim - ), "non matching dim %r vs %r in axis %i, value %r of values %r" % ( - common_shape[axis], - static_dim, - axis, - value, - values, + assert common_shape[axis] == static_dim, ( + "non matching dim %r vs %r in axis %i, value %r of values %r" + % ( + common_shape[axis], + static_dim, + axis, + value, + values, + ) ) # Check validate_broadcast_all_sources need_broadcast = {id(value): False for value in values} @@ -2576,9 +2577,9 @@ def unbroadcast_to_common_shape(value, common_shape, ignore_axes=(), allow_only_ for axis in ignore_axes: assert 0 <= axis < ndim tile_multiples[axis] = 1 - assert all( - [m is not None for m in tile_multiples] - ), "ignore_axes %r probably missing some axis for common shape %r" % (ignore_axes, common_shape) + assert all([m is not None for m in tile_multiples]), ( + "ignore_axes %r probably missing some axis for common shape %r" % (ignore_axes, common_shape) + ) if all([isinstance(m, int) and m == 1 for m in tile_multiples]): # We have a no-op. 
return value @@ -6611,7 +6612,6 @@ def find_unsupported_devices_in_graph(graph, dev_name, ignore=None): class _DeviceAttrMod: - _tf_mod = None @classmethod @@ -7680,13 +7680,14 @@ def copy_graph(cls, fetches, target_op, fetch_helper_tensors, stop_at_ts=(), ver _, info = copier(sgv, dst_graph=sgv.graph, dst_scope="", reuse_dst_scope=True) assert isinstance(info, graph_editor.TransformerInfo) target_op_transformed = info.transformed(target_op) - assert isinstance( - target_op_transformed, tf.Operation - ), "\ntarget_op\n%r,\nfetches\n%r,\nstop_at_ts\n%s,\nops\n%s" % ( - target_op, - fetches, - pformat(stop_at_ts), - pformat(ops), + assert isinstance(target_op_transformed, tf.Operation), ( + "\ntarget_op\n%r,\nfetches\n%r,\nstop_at_ts\n%s,\nops\n%s" + % ( + target_op, + fetches, + pformat(stop_at_ts), + pformat(ops), + ) ) fetch_helpers = [] for x in fetch_helper_tensors: diff --git a/returnn/torch/data/extern_data.py b/returnn/torch/data/extern_data.py index c512fc0e4..15df9bbf2 100644 --- a/returnn/torch/data/extern_data.py +++ b/returnn/torch/data/extern_data.py @@ -28,12 +28,14 @@ def raw_dict_to_extern_data( extern_data_template: TensorDict, device: Union[str, torch.device], float_dtype: Optional[Union[str, torch.dtype]] = None, + with_eval_targets: bool = False, ) -> TensorDict: """ :param extern_data_raw: This comes out of the DataLoader, via our collate_batch. :param extern_data_template: Specified via `extern_data` in the config. :param device: E.g. the GPU. :param float_dtype: + :param with_eval_targets: if False, we skip all tensors with ``available_for_inference=False``. :return: tensor dict, like extern_data_template, but with raw tensors set to Torch tensors, on the right device. """ if isinstance(float_dtype, str): @@ -47,14 +49,16 @@ def raw_dict_to_extern_data( batch_dim.dyn_size_ext = Tensor(batch_dim.name or "batch", dims=[], dtype="int32") extern_data = TensorDict() for k, data in extern_data_template.data.items(): + if not with_eval_targets and not data.available_for_inference: + continue data = data.copy_template() raw_tensor = extern_data_raw[k] assert len(raw_tensor.shape) == data.batch_ndim, f"ndim mismatch for {k}: {raw_tensor.shape} vs {data}" for i, dim in enumerate(data.dims): if dim.dimension is not None: - assert ( - dim.dimension == raw_tensor.shape[i] - ), f"shape mismatch for {k}: {raw_tensor.shape} vs {data.batch_shape}" + assert dim.dimension == raw_tensor.shape[i], ( + f"shape mismatch for {k}: {raw_tensor.shape} vs {data.batch_shape}" + ) if isinstance(raw_tensor, torch.Tensor): if raw_tensor.dtype.is_floating_point and float_dtype: raw_tensor = raw_tensor.to(dtype=float_dtype) @@ -77,8 +81,7 @@ def raw_dict_to_extern_data( and (data.dims[1].dyn_size_ext is None or data.dims[1].dyn_size_ext.raw_tensor is None) ): assert k + ":seq_len" in extern_data_raw, ( - f"extern_data {data}, dyn spatial dim, missing {k}:seq_len in raw dict, " - f"check dataset or collate_batch" + f"extern_data {data}, dyn spatial dim, missing {k}:seq_len in raw dict, check dataset or collate_batch" ) size = extern_data_raw[k + ":seq_len"] # Sequence lengths have to be on CPU for the later call to rnn.pack_padded_sequence diff --git a/returnn/torch/data/pipeline.py b/returnn/torch/data/pipeline.py index fa6237dde..8609ba89b 100644 --- a/returnn/torch/data/pipeline.py +++ b/returnn/torch/data/pipeline.py @@ -123,7 +123,6 @@ def __iter__(self): chunking_data_keys = list(self._chunk_size.keys()) for data_dict in self._dataset: - if not chunking_data_keys: chunking_data_keys = 
list(data_dict.keys()) # use all if not configured separately chunking_data_key_black_list = ["seq_tag", "seq_idx", "num_seqs", "epoch", "complete_frac"] @@ -150,9 +149,9 @@ def __iter__(self): if num_chunks is None: num_chunks = len(chunks) else: - assert num_chunks == len( - chunks - ), "Chunking resulted in different number of chunks for different data keys." + assert num_chunks == len(chunks), ( + "Chunking resulted in different number of chunks for different data keys." + ) data_chunks[data_key] = chunks diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py index b0b19400b..445dfab05 100644 --- a/returnn/torch/engine.py +++ b/returnn/torch/engine.py @@ -66,42 +66,44 @@ def __init__(self, config: Config): self.model_filename = self.config.value("model", None) self._mp_manager = torch.multiprocessing.Manager() self._epoch_mp_shared = self._mp_manager.Value("i", 0) - self.train_dataset = None # type: Optional[Dataset] + self.train_dataset: Optional[Dataset] = None self.eval_datasets = {} - self.extern_data = None # type: Optional[TensorDict] - self._train_dataloader = None # type: Optional[DataLoader] - self._eval_dataloaders = {} # type: Dict[str, DataLoader] + self.extern_data: Optional[TensorDict] = None + self._train_dataloader: Optional[DataLoader] = None + self._eval_dataloaders: Dict[str, DataLoader] = {} - self._start_epoch = None # type: Optional[int] - self._final_epoch = None # type: Optional[int] - self._min_seq_length = config.typed_value("min_seq_length", None) or config.int( + self._start_epoch: Optional[int] = None + self._final_epoch: Optional[int] = None + self._min_seq_length: Union[int, float, Dict[str, int], NumbersDict] = config.typed_value( "min_seq_length", None - ) # type: Union[int,float,Dict[str,int],NumbersDict] - self._max_seq_length = config.typed_value("max_seq_length", None) or config.int( + ) or config.int("min_seq_length", None) + self._max_seq_length: Union[int, float, Dict[str, int], NumbersDict] = config.typed_value( "max_seq_length", None - ) # type: Union[int,float,Dict[str,int],NumbersDict] - self._orig_model = None # type: Optional[Union[rf.Module, torch.nn.Module]] - self._pt_model = None # type: Optional[torch.nn.Module] - self._train_step_func = None # type: Optional[Callable] - self._forward_step_func = self.config.typed_value("forward_step") # type: Optional[Callable] - self._forward_step_expected_outputs = None # type: Optional[TensorDict] + ) or config.int("max_seq_length", None) + self._orig_model: Optional[Union[rf.Module, torch.nn.Module]] = None + self._pt_model: Optional[torch.nn.Module] = None + self._epoch_start_func: Optional[Callable] = self.config.typed_value("epoch_start") + self._epoch_end_func: Optional[Callable] = self.config.typed_value("epoch_end") + self._train_step_func: Optional[Callable] = None + self._forward_step_func: Optional[Callable] = self.config.typed_value("forward_step") + self._forward_step_expected_outputs: Optional[TensorDict] = None if self.config.typed_value("model_outputs") is not None: self._forward_step_expected_outputs = TensorDict() self._forward_step_expected_outputs.update(self.config.typed_value("model_outputs"), auto_convert=True) self._save_model_epoch_interval = 1 self._ignore_param_set: Set[str] = set() # for the updater and for saving the model checkpoint - self._updater = None # type: Optional[Updater] + self._updater: Optional[Updater] = None self._use_autocast = False - self._autocast_dtype = None # type: Optional[str] - self._grad_scaler = None # type: Optional[amp.GradScaler] + 
self._autocast_dtype: Optional[str] = None + self._grad_scaler: Optional[amp.GradScaler] = None dev_ = get_device_from_config_opt(config.value("device", None)) self._device = dev_.result print("Using device:", self._device, f"({dev_.reason or '?'})", file=log.v2) - self._torch_distributed_ctx = None # type: Optional[DistributedContext] - self._ddp_pt_model = None # type: Optional[DistributedDataParallel] + self._torch_distributed_ctx: Optional[DistributedContext] = None + self._ddp_pt_model: Optional[DistributedDataParallel] = None if config.typed_value("torch_distributed") is not None: self._torch_distributed_ctx = dist_get_ctx(config=config) @@ -132,6 +134,13 @@ def __init__(self, config: Config): self._forward_auto_split_batch_on_oom = config.bool("forward_auto_split_batch_on_oom", False) self._stop_on_nonfinite_train_score = config.bool("stop_on_nonfinite_train_score", True) + if config.bool("use_tensorboard", False): + from torch.utils.tensorboard import SummaryWriter + + self._tensorboard_writer = SummaryWriter() + else: + self._tensorboard_writer = None + default_float_dtype = config.value("default_float_dtype", None) if default_float_dtype is not None: assert isinstance(default_float_dtype, str) @@ -255,6 +264,9 @@ def train(self): self.init_train_epoch() self.train_epoch() + if self._tensorboard_writer: + self._tensorboard_writer.close() + print(f"Finished training at epoch {self.epoch}, global train step {self.global_train_step}", file=log.v3) def init_train_epoch(self): @@ -319,6 +331,26 @@ def _maybe_report_dev_memory_stats(self): ] print(f"Memory usage ({self._device}):", " ".join(stats), file=log.v1) + def _on_epoch_start(self, *, dataset_name: str): + if self._epoch_start_func: + self._epoch_start_func( + epoch=self.epoch, + step=self.global_train_step, + model=self._orig_model, + dataset_name=dataset_name, + **util.get_fwd_compat_kwargs(), + ) + + def _on_epoch_end(self, *, dataset_name: str): + if self._epoch_end_func: + self._epoch_end_func( + epoch=self.epoch, + step=self.global_train_step, + model=self._orig_model, + dataset_name=dataset_name, + **util.get_fwd_compat_kwargs(), + ) + def train_epoch(self): """ train one (sub)epoch @@ -346,6 +378,8 @@ def train_epoch(self): self._maybe_reset_dev_memory_caches() self._reset_dev_memory_stats() + self._on_epoch_start(dataset_name="train") + if self.config.bool("debug_shell_before_train_loop", False): print("debug_shell_before_train_loop", file=log.v1) debug_shell(user_ns=locals(), user_global_ns=globals(), exit_afterwards=False) @@ -411,6 +445,7 @@ def train_epoch(self): extern_data_template=self.extern_data, device=self._device, float_dtype=self._default_float_dtype, + with_eval_targets=True, ) self._run_step(extern_data, train_flag=True, train_func=True) @@ -481,6 +516,10 @@ def train_epoch(self): batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None, log_memory_usage_device=self._device if self._log_memory_usage else None, ) + if self._tensorboard_writer: + # write losses/errors to tensorboard + for key, val in eval_info.items(): + self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step) if self._stop_on_nonfinite_train_score: if any(np.isinf(v) or np.isnan(v) for v in accumulated_losses_dict.values()): @@ -564,6 +603,8 @@ def _debug_func() -> torch.Tensor: self._maybe_report_dev_memory_stats() + self._on_epoch_end(dataset_name="train") + if self.epoch % self._save_model_epoch_interval == 0 or self.epoch == self._final_epoch: if self.model_filename: 
self._save_model() @@ -612,6 +653,8 @@ def eval_model(self, *, skip_already_evaluated: bool = False): print(f"Evaluating dataset {dataset_name!r}", file=log.v3) + self._on_epoch_start(dataset_name=dataset_name) + accumulated_losses_dict = NumbersDict() accumulated_inv_norm_factors_dict = NumbersDict() step_idx = 0 @@ -632,6 +675,7 @@ def eval_model(self, *, skip_already_evaluated: bool = False): extern_data_template=self.extern_data, device=self._device, float_dtype=self._default_float_dtype, + with_eval_targets=True, ) self._run_step(extern_data, train_func=True) @@ -665,12 +709,18 @@ def eval_model(self, *, skip_already_evaluated: bool = False): start_elapsed=step_end_time - eval_start_time, log_memory_usage_device=self._device if self._log_memory_usage else None, ) + step_idx += 1 assert step_idx > 0, f"No data in dataset {dataset_name!r}." accumulated_losses_dict = accumulated_losses_dict / accumulated_inv_norm_factors_dict accumulated_losses_dict = self._maybe_extend_losses_info(accumulated_losses_dict) + if self._tensorboard_writer: + # write losses/errors to tensorboard + for key, val in accumulated_losses_dict.items(): + self._tensorboard_writer.add_scalar(f"{dataset_name}/{key}", val, global_step=self.epoch) + self.learning_rate_control.set_epoch_error( self.epoch, {f"{dataset_name}_loss_{k}": v for k, v in accumulated_losses_dict.items()} ) @@ -685,6 +735,8 @@ def eval_model(self, *, skip_already_evaluated: bool = False): _has_data = torch.tensor([False], device="cpu", dtype=torch.int8) torch.distributed.broadcast(_has_data, src=0) + self._on_epoch_end(dataset_name=dataset_name) + if not self._torch_distributed_ctx or self._torch_distributed_ctx.rank() == 0: print( f"Epoch {self.epoch} evaluation:", @@ -942,8 +994,13 @@ def _load_model(self): continue if opts["filename"] is None: print(f"Pre-load (initialize) weights for key '{preload_key}'", file=log.v3) - pattern = opts["pattern"] - match = re.compile(fnmatch.translate(pattern)).match + if opts.get("pattern", None) is not None: + pattern = opts["pattern"] + match = re.compile(fnmatch.translate(pattern)).match + elif opts.get("prefix", None) is not None: + match = re.compile(re.escape(opts["prefix"]) + ".*").fullmatch + else: + raise ValueError(f"preload key {preload_key} without file {opts}: no pattern or prefix given") remove = [] for name in self._pt_model.state_dict().keys(): if match(name) and name in missing_keys: @@ -1467,7 +1524,7 @@ def _print_process( info += ["%.3f sec/step" % step_duration] if start_elapsed is not None: info += ["elapsed %s" % hms(start_elapsed)] - if complete_frac is not None: + if complete_frac not in (-1, None): assert 1 >= complete_frac > 0, f"{step} step, {complete_frac} complete_frac" assert start_elapsed is not None total_time_estimated = start_elapsed / complete_frac diff --git a/returnn/torch/frontend/_backend.py b/returnn/torch/frontend/_backend.py index 82f210810..3c92a8333 100644 --- a/returnn/torch/frontend/_backend.py +++ b/returnn/torch/frontend/_backend.py @@ -421,9 +421,9 @@ def concat( else: # not allow_broadcast for source, dim in sources: templ_dims = other_dims[:axis] + [dim] + other_dims[axis:] - assert set(templ_dims) == set( - source.dims - ), f"concat {source} {dim} not allowed with allow_broadcast=False" + assert set(templ_dims) == set(source.dims), ( + f"concat {source} {dim} not allowed with allow_broadcast=False" + ) source_ = source.copy_transpose(templ_dims) sources_raw.append(source_.raw_tensor) out = Tensor( @@ -612,9 +612,9 @@ def softmax_cross_entropy_with_logits(*, 
logits: Tensor, targets: Tensor, axis: assert axis in logits.dims, "Specified axis not present in logits." if axis == targets.sparse_dim: - assert ( - logits.dims_set - {axis} == targets.dims_set - ), "logits Dims and target Dims have to match (except for implicit sparse_dim)." + assert logits.dims_set - {axis} == targets.dims_set, ( + "logits Dims and target Dims have to match (except for implicit sparse_dim)." + ) logits_dim_order = list(targets.dims) if len(logits_dim_order) > 0: @@ -629,9 +629,9 @@ def softmax_cross_entropy_with_logits(*, logits: Tensor, targets: Tensor, axis: targets.raw_tensor = targets.raw_tensor.long() else: - assert ( - not targets.sparse_dim - ), "We expect that cross entropy would always be calculated along the sparse dim, if there is one." + assert not targets.sparse_dim, ( + "We expect that cross entropy would always be calculated along the sparse dim, if there is one." + ) assert logits.dims_set == targets.dims_set, "logits Dims and target Dims have to match." assert axis in targets.dims, "Specified axis not present in targets." @@ -1348,12 +1348,12 @@ def matmul(a: _TT, b: _TT, *, reduce: Union[Dim, Sequence[Dim]], use_mask: bool a_dims = a.dims b_dims = b.dims - assert all( - dim in a_dims for dim in reduce - ), f"'a' does not have the specified reduce dim(s) {reduce} (a dims: {a_dims})" - assert all( - dim in b_dims for dim in reduce - ), f"'b' does not have the specified reduce dim(s) {reduce} (b dims: {b_dims})" + assert all(dim in a_dims for dim in reduce), ( + f"'a' does not have the specified reduce dim(s) {reduce} (a dims: {a_dims})" + ) + assert all(dim in b_dims for dim in reduce), ( + f"'b' does not have the specified reduce dim(s) {reduce} (b dims: {b_dims})" + ) if len(reduce) > 1: reduce = list(reduce) @@ -1859,7 +1859,7 @@ def batch_norm( bias=beta.raw_tensor if affine else None, # training: means whether we should use the current batch statistics # + update the running statistics (if given) - training=rf.get_run_ctx().train_flag or (running_mean is None), + training=rf.get_run_ctx().is_train_flag_enabled(func=rf.BatchNorm.__call__) or (running_mean is None), momentum=momentum, eps=epsilon, ) @@ -2236,7 +2236,7 @@ def lstm( has_biases=has_biases, num_layers=1, dropout=0.0, - train=rf.get_run_ctx().train_flag, + train=rf.get_run_ctx().is_train_flag_enabled(func=rf.LSTM.__call__), bidirectional=False, ) diff --git a/returnn/torch/frontend/bridge.py b/returnn/torch/frontend/bridge.py index ab422aa3b..a44f3ef5a 100644 --- a/returnn/torch/frontend/bridge.py +++ b/returnn/torch/frontend/bridge.py @@ -178,9 +178,9 @@ def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> rf_param = getattr(self._rf_module, name, None) if not isinstance(rf_param, rf.Parameter): return # just ignore - assert isinstance( - param, torch.nn.Parameter - ), f"{self} register_parameter {name}: did not get a Parameter but {type(param).__name__}" + assert isinstance(param, torch.nn.Parameter), ( + f"{self} register_parameter {name}: did not get a Parameter but {type(param).__name__}" + ) rf_param.raw_tensor = param def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: diff --git a/returnn/torch/updater.py b/returnn/torch/updater.py index 791d158d9..682ffc648 100644 --- a/returnn/torch/updater.py +++ b/returnn/torch/updater.py @@ -39,7 +39,7 @@ def _init_optimizer_classes_dict(): def get_optimizer_class( - class_name: Union[str, Type[torch.optim.Optimizer], Callable[[], Type[torch.optim.Optimizer]]] + 
class_name: Union[str, Type[torch.optim.Optimizer], Callable[[], Type[torch.optim.Optimizer]]], ) -> Type[torch.optim.Optimizer]: """ :param class_name: Optimizer class, either as str (e.g. "adam"), as type (torch.optim.Adam) or callable. @@ -121,9 +121,9 @@ def __init__(self, *, config, network, device, initial_learning_rate=1.0): import inspect signature = inspect.signature(self.learning_rate_function) - assert any( - [arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()] - ), "please specify **kwargs in dynamic_learning_rate for future compatibility" + assert any([arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()]), ( + "please specify **kwargs in dynamic_learning_rate for future compatibility" + ) if "network" in signature.parameters: raise ValueError("Torch updater: dynamic_learning_rate network is TF specific") else: @@ -497,10 +497,9 @@ def _get_optimizer_param_groups( # Split in parameter groups only if decouple_constraints is set and the optimizer accepts weight_decay. cls_init_kwargs = _get_class_init_kwargs(optim_class) if "weight_decay" not in cls_init_kwargs: - assert ( - "weight_decay" not in optimizer_opts - ), "weight_decay not accepted by the chosen optimizer. Accepted values: %s" % ", ".join( - "%s" % optim_name for optim_name in cls_init_kwargs + assert "weight_decay" not in optimizer_opts, ( + "weight_decay not accepted by the chosen optimizer. Accepted values: %s" + % ", ".join("%s" % optim_name for optim_name in cls_init_kwargs) ) return network_params @@ -564,7 +563,7 @@ def _get_optimizer_param_groups( def _wrap_user_blacklist_wd_modules( - mods: Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]] + mods: Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]], ) -> Tuple[type, ...]: assert isinstance(mods, (list, tuple)), f"invalid blacklist_weight_decay_modules {mods!r}" res = [] diff --git a/returnn/torch/util/debug_inf_nan.py b/returnn/torch/util/debug_inf_nan.py index 0caa46dd5..de2855790 100644 --- a/returnn/torch/util/debug_inf_nan.py +++ b/returnn/torch/util/debug_inf_nan.py @@ -30,7 +30,6 @@ So we don't stop on the first occurrence but just report all of them. 
""" - from __future__ import annotations import sys @@ -90,7 +89,6 @@ def debug_inf_nan( print(f"Caught RuntimeError in backward: {exc}", file=file) else: # without grad - with trace_ops: func() diff --git a/returnn/torch/util/exception_helper.py b/returnn/torch/util/exception_helper.py index d3d6ace89..92b0de75b 100644 --- a/returnn/torch/util/exception_helper.py +++ b/returnn/torch/util/exception_helper.py @@ -79,7 +79,7 @@ def help_on_torch_exception( def _help_data_or_array( - value: Union[torch.Tensor, np.ndarray, bool, object] + value: Union[torch.Tensor, np.ndarray, bool, object], ) -> Tuple[str, Tuple[Union[int, float], Union[int, float]]]: """ :param value: diff --git a/returnn/torch/util/scaled_gradient.py b/returnn/torch/util/scaled_gradient.py index 27ff0115a..d28e9ddc7 100644 --- a/returnn/torch/util/scaled_gradient.py +++ b/returnn/torch/util/scaled_gradient.py @@ -14,7 +14,6 @@ https://github.com/tadeephuy/GradientReversal/blob/5d9857d63/gradient_reversal/functional.py """ - from __future__ import annotations from typing import Optional, Union import torch diff --git a/returnn/util/basic.py b/returnn/util/basic.py index 5c2e54078..4ad9d4949 100644 --- a/returnn/util/basic.py +++ b/returnn/util/basic.py @@ -705,7 +705,7 @@ def _convert(mo: re.Match) -> str: return delim if mo.group("invalid") is not None: i = mo.start("invalid") - raise ValueError(f"Invalid placeholder in string: {s[i:i+2]!r}...") + raise ValueError(f"Invalid placeholder in string: {s[i : i + 2]!r}...") raise ValueError(f"Unrecognized named group in pattern {pattern}") return pattern_.sub(_convert, s) @@ -1811,7 +1811,6 @@ def json_remove_comments(string, strip_space=True): index = 0 for match in re.finditer(tokenizer, string): - if not (in_multi or in_single): tmp = string[index : match.start()] if not in_string and strip_space: @@ -3153,42 +3152,53 @@ def lock(self): Acquires the lock. """ import time - import errno wait_count = 0 while True: - # Try to create directory if it does not exist. - try: - os.makedirs(self.directory) - except OSError as exc: - # Possible errors: - # ENOENT (No such file or directory), e.g. if some parent directory was deleted. - # EEXIST (File exists), if the dir already exists. - if exc.errno not in [errno.ENOENT, errno.EEXIST]: - # Other error, so reraise. - # Common ones are e.g.: - # ENOSPC (No space left on device) - # EACCES (Permission denied) - raise - # Ignore those errors. + if self.try_lock(): + break + # We did not get the lock. Wait a bit, and then retry. + time.sleep(min(wait_count * 0.1, 1.0)) + wait_count += 1 + if wait_count == 10: + print("Waiting for lock-file: %s" % self.lockfile) + + def try_lock(self) -> bool: + """ + Tries to acquire the lock. + + :return: whether the lock was acquired + """ + + import errno + + # Try to create directory if it does not exist. + try: + os.makedirs(self.directory) + except OSError as exc: + # Possible errors: + # ENOENT (No such file or directory), e.g. if some parent directory was deleted. + # EEXIST (File exists), if the dir already exists. + if exc.errno not in [errno.ENOENT, errno.EEXIST]: + # Other error, so reraise. + # Common ones are e.g.: + # ENOSPC (No space left on device) + # EACCES (Permission denied) + raise + # Ignore those errors. + + for _ in range(2): # Now try to create the lock. try: self.fd = os.open(self.lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR) - return + return True except OSError as exc: - # Possible errors: - # ENOENT (No such file or directory), e.g. if the directory was deleted. 
- # EEXIST (File exists), if the lock already exists. if exc.errno not in [errno.ENOENT, errno.EEXIST]: - raise # Other error, so reraise. - # We did not get the lock. - # Check if it is a really old one. - self.maybe_remove_old_lockfile() - # Wait a bit, and then retry. - time.sleep(1) - wait_count += 1 - if wait_count == 10: - print("Waiting for lock-file: %s" % self.lockfile) + raise # raise any other error + # We did not get the lock. + # Remove potential stale lockfile before retrying. + self.maybe_remove_old_lockfile() + return False def unlock(self): """ @@ -3673,10 +3683,14 @@ def get_hostname(): def is_running_on_cluster(): """ - :return: i6 specific. Whether we run on some of the cluster nodes. + :return: i6 / Slurm specific. Whether we run on some of the cluster nodes. :rtype: bool """ - return get_hostname().startswith("cluster-cn-") or get_hostname().startswith("cn-") + return ( + get_hostname().startswith("cluster-cn-") + or get_hostname().startswith("cn-") + or os.environ.get("SLURM_JOB_ID", None) + ) start_time = time.time() @@ -4285,7 +4299,7 @@ def cf(filename): return filename # for debugging try: cached_fn = check_output(["cf", filename]).strip().decode("utf8") - except CalledProcessError: + except (CalledProcessError, OSError): if not _cf_msg_printed: print("Cache manager: Error occurred, using local file") _cf_msg_printed = True diff --git a/returnn/util/better_exchook.py b/returnn/util/better_exchook.py index 7d8d74246..2d8821c36 100644 --- a/returnn/util/better_exchook.py +++ b/returnn/util/better_exchook.py @@ -58,6 +58,7 @@ import keyword import inspect import contextlib +from weakref import WeakKeyDictionary try: import typing @@ -1564,6 +1565,9 @@ def get_func_str_from_code_object(co, frame=None): return co.co_name +_func_from_code_object_cache = WeakKeyDictionary() # code object -> function + + def get_func_from_code_object(co, frame=None): """ :param types.CodeType co: @@ -1580,6 +1584,11 @@ def get_func_from_code_object(co, frame=None): import types assert isinstance(co, (types.CodeType, DummyFrame)) + co_is_code_object = isinstance(co, types.CodeType) + if co_is_code_object: + candidate = _func_from_code_object_cache.get(co) + if candidate: + return candidate _attr_name = "__code__" if PY3 else "func_code" if frame and frame.f_code.co_nlocals > 0: func_name = frame.f_code.co_name @@ -1587,18 +1596,23 @@ def get_func_from_code_object(co, frame=None): if frame_self is not None: candidate = getattr(frame_self.__class__, func_name, None) if candidate and (getattr(candidate, _attr_name, None) is co or isinstance(co, DummyFrame)): + if co_is_code_object: + _func_from_code_object_cache[co] = candidate return candidate try: candidate = getattr(_get_loaded_module_from_filename(co.co_filename), co.co_name, None) except ImportError: # some modules have lazy loaders, but those might fail here candidate = None if candidate and (getattr(candidate, _attr_name, None) is co or isinstance(co, DummyFrame)): + if co_is_code_object: + _func_from_code_object_cache[co] = candidate return candidate if isinstance(co, DummyFrame): return None candidates = gc.get_referrers(co) candidates = [f for f in candidates if getattr(f, _attr_name, None) is co] if candidates: + _func_from_code_object_cache[co] = candidates[0] return candidates[0] return None diff --git a/returnn/util/file_cache.py b/returnn/util/file_cache.py index 1002912ad..43af6bee3 100644 --- a/returnn/util/file_cache.py +++ b/returnn/util/file_cache.py @@ -126,11 +126,8 @@ def get_file(self, src_filename: str) -> str: raise 
e if last_error is not None: raise last_error - info_file = self._get_info_filename(dst_filename) - os.utime(dst_filename, None) - os.utime(info_file, None) # protect info file from tempreaper, which looks at the mtime - self._touch_files_thread.files_extend([dst_filename, info_file]) + self._touch_files_thread.files_extend([dst_filename, self._get_info_filename(dst_filename)]) return dst_filename def release_files(self, filenames: Union[str, Iterable[str]]): @@ -181,6 +178,9 @@ def cleanup(self, *, need_at_least_free_space_size: int = 0): elif self._is_info_filename(fn): # skip keepalive files, they are processed together with the file they guard continue + elif self._is_lock_filename(fn): + # skip lock files, removing them would accidentally release locks + continue try: f_stat = os.stat(fn) except Exception as exc: @@ -203,45 +203,85 @@ def cleanup(self, *, need_at_least_free_space_size: int = 0): for mtime, neg_size, fn in all_files: size = -neg_size delete_reason = None - if cur_time - mtime > self._cleanup_files_always_older_than_days * 60 * 60 * 24: - delete_reason = f"File is {(cur_time - mtime) / 60 / 60 / 24:.1f} days old" - else: - reached_more_recent_files = True - if not delete_reason and need_at_least_free_space_size > cur_expected_free: - # Still must delete some files. - if cur_time - mtime > cur_used_time_threshold: - delete_reason = f"Still need more space, file is {(cur_time - mtime) / 60 / 60:.1f} hours old" - else: - raise Exception( - f"We cannot free enough space on {self.cache_directory}.\n" - f"Needed: {human_bytes_size(need_at_least_free_space_size)},\n" - f"currently available: {human_bytes_size(cur_expected_free)},\n" - f"oldest file is still too recent: {fn}.\n" - f"{report_size_str}" - ) - if not delete_reason and want_free_space_size > cur_expected_free: - if cur_time - mtime > self._cleanup_files_wanted_older_than_days * 60 * 60 * 24: - delete_reason = f"Still want more space, file is {(cur_time - mtime) / 60 / 60:.1f} hours old" - else: - # All further files are even more recent, so we would neither cleanup them, - # so we can also just stop now. - break - if delete_reason: - cur_expected_free += size - print( - f"FileCache: Delete file {fn}, size {human_bytes_size(size)}. {delete_reason}." - f" After deletion, have {human_bytes_size(cur_expected_free)} free space." - ) - try: - os.remove(fn) - except Exception as exc: - print(f"FileCache: Error while removing {fn}: {type(exc).__name__}: {exc}") - cur_expected_free -= size + lock_dir, lock_file_name = self._get_lock_filename(fn) + lock_file = LockFile(directory=lock_dir, name=lock_file_name, lock_timeout=self._lock_timeout) + if not lock_file.try_lock(): + print(f"FileCache: lock for {fn} is currently held, skipping.") + continue + try: + # Re-check mtime with lock, could have been updated by another + # process in the meantime. + # We do not update the `mtime` variable here, because the code + # assumes that the list of files is sorted by mtime to abort + # early when enough space has been made. + # Instead, we treat the case where the mtime was updated during + # cleanup as an outlier and continue as if no other mtimes had + # changed. + # See for discussion: + # - https://github.com/rwth-i6/returnn/issues/1675 + # - https://github.com/rwth-i6/returnn/pull/1709 try: - os.remove(self._get_info_filename(fn)) + cur_mtime = os.stat(fn).st_mtime + except FileNotFoundError: + # File was deleted while waiting for the lock, or because it was + # a temporary copy file and was renamed to its final location. 
+ # Since we don't know whether it was actually deleted or just + # renamed, we leave cur_expected_free unchanged. + continue except Exception as exc: - print(f"FileCache: Ignoring error file removing info file of {fn}: {type(exc).__name__}: {exc}") + print(f"FileCache: Error refreshing mtime of {fn}: {type(exc).__name__}: {exc}") + continue + if cur_mtime > mtime and (time.time() - cur_mtime) <= cur_used_time_threshold: + print(f"FileCache: {fn} has been updated during cleanup, skipping.") + continue + if cur_time - mtime > self._cleanup_files_always_older_than_days * 60 * 60 * 24: + delete_reason = f"File is {(cur_time - mtime) / 60 / 60 / 24:.1f} days old" + else: + reached_more_recent_files = True + if not delete_reason and need_at_least_free_space_size > cur_expected_free: + # Still must delete some files. + if cur_time - mtime > cur_used_time_threshold: + delete_reason = f"Still need more space, file is {(cur_time - mtime) / 60 / 60:.1f} hours old" + else: + raise Exception( + f"We cannot free enough space on {self.cache_directory}.\n" + f"Needed: {human_bytes_size(need_at_least_free_space_size)},\n" + f"currently available: {human_bytes_size(cur_expected_free)},\n" + f"oldest file is still too recent: {fn}.\n" + f"{report_size_str}" + ) + if not delete_reason and want_free_space_size > cur_expected_free: + if cur_time - mtime > self._cleanup_files_wanted_older_than_days * 60 * 60 * 24: + delete_reason = f"Still want more space, file is {(cur_time - mtime) / 60 / 60:.1f} hours old" + else: + # All further files are even more recent, so we would neither cleanup them, + # so we can also just stop now. + break + + if delete_reason: + cur_expected_free += size + print( + f"FileCache: Delete file {fn}, size {human_bytes_size(size)}. {delete_reason}." + f" After deletion, have {human_bytes_size(cur_expected_free)} free space." + ) + try: + os.remove(fn) + except Exception as exc: + if not isinstance(exc, FileNotFoundError): + print(f"FileCache: Error while removing {fn}: {type(exc).__name__}: {exc}") + + # We don't know whether the file was just renamed or actually deleted, so + # we do as if its space has not been freed. + cur_expected_free -= size + try: + os.remove(self._get_info_filename(fn)) + except FileNotFoundError: + pass + except Exception as exc: + print(f"FileCache: Ignoring error file removing info file of {fn}: {type(exc).__name__}: {exc}") + finally: + lock_file.unlock() if reached_more_recent_files and want_free_space_size <= cur_expected_free: # Have enough free space now. @@ -311,6 +351,16 @@ def _get_info_filename(filename: str) -> str: """:return: the name of the corresponding info file to `filename`.""" return f"{filename}.returnn-info" + @staticmethod + def _get_lock_filename(filename: str) -> Tuple[str, str]: + """:return: lock file target directory and lock file name""" + return os.path.dirname(filename), os.path.basename(filename) + ".returnn-lock" + + @staticmethod + def _is_lock_filename(filename: str) -> bool: + """:return: whether `filename` points to a lock file.""" + return filename.endswith(".returnn-lock") + @staticmethod def _is_info_filename(filename: str) -> bool: """:return: whether `filename` points to a info file.""" @@ -324,13 +374,22 @@ def _copy_file_if_needed(self, src_filename: str, dst_filename: str): dst_dir = os.path.dirname(dst_filename) os.makedirs(dst_dir, exist_ok=True) + lock_dir, lock_file = self._get_lock_filename(dst_filename) + info_file_name = self._get_info_filename(dst_filename) + # Copy the file, while holding a lock. 
See comment on lock_timeout above. with LockFile( - directory=dst_dir, name=os.path.basename(dst_filename) + ".lock", lock_timeout=self._lock_timeout + directory=lock_dir, name=lock_file, lock_timeout=self._lock_timeout ) as lock, self._touch_files_thread.files_added_context(lock.lockfile): # Maybe it was copied in the meantime, while waiting for the lock. if self._check_existing_copied_file_maybe_cleanup(src_filename, dst_filename): print(f"FileCache: using existing file {dst_filename}") + # Update mtime while holding lock, to synchronize with any concurrent cleanup. + # See for discussion: + # - https://github.com/rwth-i6/returnn/issues/1675 + # - https://github.com/rwth-i6/returnn/pull/1709 + os.utime(dst_filename, None) + os.utime(info_file_name, None) return print(f"FileCache: Copy file {src_filename} to cache") @@ -352,7 +411,7 @@ def _copy_file_if_needed(self, src_filename: str, dst_filename: str): with self._touch_files_thread.files_added_context(dst_dir): # save mtime before the copy process to have it pessimistic orig_mtime_ns = os.stat(src_filename).st_mtime_ns - FileInfo(mtime_ns=orig_mtime_ns).save(self._get_info_filename(dst_filename)) + FileInfo(mtime_ns=orig_mtime_ns).save(info_file_name) _copy_with_prealloc(src_filename, dst_tmp_filename) os.rename(dst_tmp_filename, dst_filename) diff --git a/tests/_setup_test_env.py b/tests/_setup_test_env.py index e3d95c77a..57e609992 100644 --- a/tests/_setup_test_env.py +++ b/tests/_setup_test_env.py @@ -8,7 +8,6 @@ See :func:`setup` below for details. """ - from __future__ import annotations diff --git a/tests/rf_utils.py b/tests/rf_utils.py index 2f1e73661..55e7feb4d 100644 --- a/tests/rf_utils.py +++ b/tests/rf_utils.py @@ -119,9 +119,9 @@ def run_model( assert random_journal.reached_end() print("Output PT/TF:", out_pt, out_tf) - assert set(out_pt.data.keys()) == set( - out_tf.data.keys() - ), f"PT output {list(out_pt.data.keys())} vs TF output {list(out_tf.data.keys())}" + assert set(out_pt.data.keys()) == set(out_tf.data.keys()), ( + f"PT output {list(out_pt.data.keys())} vs TF output {list(out_tf.data.keys())}" + ) for k, v_pt in out_pt.data.items(): v_tf = out_tf[k] # We cannot really check the dims directly for equality, diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index 306e3c20b..2b6469b88 100644 --- a/tests/test_Dataset.py +++ b/tests/test_Dataset.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Iterator, List, Dict, Optional +from typing import Any, Iterator, TYPE_CHECKING, List, Dict, Optional import os import sys import _setup_test_env # noqa @@ -16,10 +16,14 @@ from returnn.util.basic import NumbersDict from returnn.util import better_exchook +if TYPE_CHECKING: + from returnn.tensor import TensorDict + def dummy_iter_dataset(dataset: Dataset, *, epoch: int = 1) -> List[DatasetSeq]: """ - :param Dataset dataset: + :param dataset: + :param epoch: :return: seqs """ dataset.init_seq_order(epoch=epoch) @@ -457,7 +461,6 @@ def get_seq_len(i): "sort_bin_shuffle:3", "sort_bin_shuffle_x2:.10", ]: - dataset.seq_ordering = seq_ordering # test full epoch @@ -953,7 +956,7 @@ def iter_identity(x, **kwargs): def test_DistributeFilesDataset(): from returnn.datasets.distrib_files import DistributeFilesDataset - from test_HDFDataset import generate_hdf_from_other + from test_HDFDataset import generate_hdf_from_other, get_test_tmp_file # Create a few HDF files such that we can easily verify the data later. 
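    # (Each sequence stores consecutive integers, so the checks below can verify the exact
    # global sequence order across sub-epochs.)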
hdf_files = [] @@ -1065,6 +1068,57 @@ def _get_subepoch_dset_w_unknown_num_seqs(files_subepoch: List[str]) -> Dict[str concat_dataset.init_seq_order(epoch=sub_epoch) concat_dataset.load_seqs(0, 1) + # Test DFD loading data from a text file + + files_name = get_test_tmp_file(suffix=".txt") + with open(files_name, "wt") as f: + for hdf_file in hdf_files: + f.write(hdf_file + "\n") + with open(files_name, "rt") as f: + print(f.read()) + concat_dataset = init_dataset( + { + "class": "DistributeFilesDataset", + "files": files_name, + "get_sub_epoch_dataset": _get_sub_epoch_dataset, + "partition_epoch": partition_epoch, + } + ) + assert isinstance(concat_dataset, DistributeFilesDataset) + assert concat_dataset.get_data_keys() == ["classes"] + num_hdfs_per_part = num_hdf_files // partition_epoch + global_seq_idx = 0 + for sub_epoch in range(1, partition_epoch + 1): + print(f"Sub-epoch {sub_epoch}...") + concat_dataset.init_seq_order(sub_epoch) + if sub_epoch == 1: + assert concat_dataset._files_order_cache == { + 0: [hdf_files[ep * num_hdfs_per_part : (ep + 1) * num_hdfs_per_part] for ep in range(partition_epoch)] + } + assert ( + concat_dataset._workers[sub_epoch].dataset_dict["files"] + == hdf_files[(sub_epoch - 1) * num_hdfs_per_part : sub_epoch * num_hdfs_per_part] + ) + # We preload one sub-epoch. + assert set(concat_dataset._workers.keys()) == {sub_epoch, sub_epoch + 1} # cur sub-epoch + next sub-epoch + for ep, worker in concat_dataset._workers.items(): + assert worker.worker_proc.is_alive() + next_part_idx = sub_epoch % partition_epoch # wrap around at the end + assert ( + concat_dataset._workers[sub_epoch + 1].dataset_dict["files"] + == hdf_files[next_part_idx * num_hdfs_per_part : (next_part_idx + 1) * num_hdfs_per_part] + ) + local_seq_idx = 0 + while concat_dataset.is_less_than_num_seqs(local_seq_idx): + print(f"Sub-epoch {sub_epoch}, seq {local_seq_idx} (global seq {global_seq_idx})...") + concat_dataset.load_seqs(local_seq_idx, local_seq_idx + 1) + data = concat_dataset.get_data(local_seq_idx, "classes") + assert data.shape == (seq_len,) + assert data.tolist() == list(range(global_seq_idx * seq_len, (global_seq_idx + 1) * seq_len)) + local_seq_idx += 1 + global_seq_idx += 1 + assert global_seq_idx == num_hdf_files * num_seqs + def test_PostprocessingDataset(): from returnn.tensor.tensor_dict import TensorDict @@ -1159,6 +1213,36 @@ def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict] assert func(2) == 21 +def _post_process_map_seq_no_op(tdict: TensorDict, **_other) -> TensorDict: + return tdict + + +def test_PostprocessingDataset_pickle(): + from returnn.datasets.postprocessing import PostprocessingDataset + import pickle + + ds0_opts = {"class": "DummyDataset", "input_dim": 2, "output_dim": 3, "num_seqs": 20} + ds0 = init_dataset(ds0_opts) + ds0_seqs = dummy_iter_dataset(ds0) + + ds_opts = { + "class": "PostprocessingDataset", + "dataset": ds0_opts, + "map_seq": _post_process_map_seq_no_op, + } + ds = init_dataset(ds_opts) + assert isinstance(ds, PostprocessingDataset) + ds_seqs = dummy_iter_dataset(ds) + compare_dataset_seqs(ds0_seqs, ds_seqs) + + s = pickle.dumps(ds) + ds2 = pickle.loads(s) + assert isinstance(ds2, PostprocessingDataset) + assert ds2 is not ds + ds2_seqs = dummy_iter_dataset(ds2) + compare_dataset_seqs(ds0_seqs, ds2_seqs) + + def test_MultiEpochDataset(): from returnn.datasets.meta import MultiEpochDataset from returnn.datasets.cached2 import CachedDataset2 diff --git a/tests/test_HDFDataset.py b/tests/test_HDFDataset.py index 
fb30ff21f..e3a3e75b9 100644 --- a/tests/test_HDFDataset.py +++ b/tests/test_HDFDataset.py @@ -576,6 +576,81 @@ def test_SimpleHDFWriter_labels(): assert reader.seq_lens[i]["data"] == seq_len +def test_SimpleHDFWriter_extra_with_feat(): + from returnn.tensor import TensorDict, Dim, batch_dim + from returnn.tensor.utils import tensor_dict_fill_random_numpy_ + + fn = get_test_tmp_file(suffix=".hdf") + os.remove(fn) # SimpleHDFWriter expects that the file does not exist + + # See i6_experiments.users.zeyer.forward_to_hdf._returnn_get_forward_callback. + # (Maybe also somewhat similar: HDFDumpLayer) + spatial_dim = Dim(None, name="spatial") + k_dim = Dim(5, name="k") + vocab_dim = Dim(11, name="vocab") + expected_outputs = TensorDict() + expected_outputs.update( + { + "output": {"dims": [batch_dim, spatial_dim, k_dim], "dtype": "int32", "sparse_dim": vocab_dim}, + "k_log_probs": {"dims": [batch_dim, spatial_dim, k_dim], "dtype": "float32"}, + }, + auto_convert=True, + ) + output = expected_outputs["output"] + + writer = SimpleHDFWriter( + filename=fn, + dim=output.dim, + ndim=output.ndim, + labels=output.vocab and output.vocab.labels, + # Note: in HDFDumpLayer, we use ndim=min(v.ndim - len(v.size_placeholder) + 1, 2) instead of just v.ndim... + extra_type={k: (v.shape[-1], v.ndim, v.dtype) for k, v in expected_outputs.data.items() if k != "output"}, + extra_labels={k: v.vocab.labels for k, v in expected_outputs.data.items() if k != "output" and v.vocab}, + ) + + expected_outputs.reset_content() + tensor_dict_fill_random_numpy_(expected_outputs) + + # Note: HDFDumpLayer does a lot of extra stuff for the extra data... + # We are following i6_experiments.users.zeyer.forward_to_hdf here. + # We iterate over the batch dim, to insert each seq separately, + # which is important that the extra data seq lens are correct + # (insert_batch would not know this otherwise). + batch_size = output.raw_tensor.shape[0] + for b in range(batch_size): + writer.insert_batch( + inputs=output.raw_tensor[b : b + 1, : spatial_dim.dyn_size[b]], + # We are following i6_experiments.users.zeyer.forward_to_hdf here. + # This logic then in turn will flatten the dims internally. + seq_len={0: spatial_dim.dyn_size[b : b + 1], 1: [k_dim.dimension]}, + seq_tag=[f"seq-{b}"], + extra={ + k: v.raw_tensor[b : b + 1, : spatial_dim.dyn_size[b]] + for k, v in expected_outputs.data.items() + if k != "output" + }, + ) + + writer.close() + + dataset = HDFDataset(files=[fn]) + reader = DatasetTestReader(dataset=dataset) + reader.read_all() + assert reader.num_seqs > 0 + assert set(reader.data.keys()) == {"data", "k_log_probs", "sizes"} + for i in range(reader.num_seqs): + # data dims will be flattened. 
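        # (The flat "data" entry has length time * k; together with the per-seq "sizes" entry
        # read below, it could be restored to its original [time, k] shape, e.g. data0.reshape(*sizes0).)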
+ data0 = reader.data["data"][i] + print(f"** seq_idx {i} data key 'data' shape {data0.shape}") + assert data0.ndim == 1 and data0.shape[0] % k_dim.dimension == 0 + time0 = data0.shape[0] // k_dim.dimension + sizes0 = reader.data["sizes"][i] + assert sizes0.shape == (2,) + assert sizes0.tolist() == [time0, k_dim.dimension] + probs0 = reader.data["k_log_probs"][i] + assert probs0.shape == (time0, k_dim.dimension) + + class Old2018HDFDataset(CachedDataset): """ Copied and adapted from an early RETURNN version: diff --git a/tests/test_TFNativeOp.py b/tests/test_TFNativeOp.py index d820976e1..077c4f9c5 100644 --- a/tests/test_TFNativeOp.py +++ b/tests/test_TFNativeOp.py @@ -2144,17 +2144,13 @@ def search(): :rtype: list[dict[int,(float,int)]] """ start_idx, _ = start_end_states[:, sequence_idx] - states = defaultdict( - lambda: (zero_score, -1) - ) # type: typing.Dict[int,typing.Tuple[float,int]] # state-idx -> score/edge # nopep8 + states = defaultdict(lambda: (zero_score, -1)) # type: typing.Dict[int,typing.Tuple[float,int]] # state-idx -> score/edge # nopep8 states[start_idx] = (0.0, -1) res = [] # type: typing.List[typing.Dict[int,typing.Tuple[float,int]]] for t in range(n_time): if t >= am_seq_len[sequence_idx]: break - scores = defaultdict( - list - ) # type: typing.Dict[int,typing.List[typing.Tuple[float,int]]] # state-idx -> list[score/edge] # nopep8 + scores = defaultdict(list) # type: typing.Dict[int,typing.List[typing.Tuple[float,int]]] # state-idx -> list[score/edge] # nopep8 for edge_idx in range(n_edges): from_idx, to_idx, emission_idx, sequence_idx_ = edges[:, edge_idx] if sequence_idx_ != sequence_idx: diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index c3d82bf20..7c5f55e25 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -1519,9 +1519,9 @@ def test_cnn_building_block(): seq_len = 10 seq_lens = numpy.array([10, 10, 10, 10, 10], dtype=numpy.int32) feed = { - network.extern_data.get_default_input_data() - .placeholder: numpy.random.rand(n_batch, seq_len, num_inputs) - .astype("f"), + network.extern_data.get_default_input_data().placeholder: numpy.random.rand( + n_batch, seq_len, num_inputs + ).astype("f"), network.extern_data.get_default_input_data().size_placeholder[0]: seq_lens, } v = session.run(out, feed_dict=feed) diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py index 2f6d99302..ff6d2ad80 100644 --- a/tests/test_TFNetworkRecLayer.py +++ b/tests/test_TFNetworkRecLayer.py @@ -13524,8 +13524,7 @@ def test_rel_pos_self_attention_left_ctx_explicit_vs_layer(): "conformer_block_01_self_att_ln_rel_pos_enc": { "class": "relative_positional_encoding", "clipping": 16, - "forward_weights_init": "variance_scaling_initializer(mode='fan_avg', distribution='uniform', " - "scale=1.0)", + "forward_weights_init": "variance_scaling_initializer(mode='fan_avg', distribution='uniform', scale=1.0)", "from": "conformer_block_01_self_att_ln_concat", # [B*C, 2*W, D] "n_out": n_total_dim // n_head, }, @@ -13621,8 +13620,7 @@ def test_rel_pos_self_attention_left_ctx_explicit_vs_layer(): "conformer_block_01_self_att_ln_rel_pos_enc": { "class": "relative_positional_encoding", "clipping": 16, - "forward_weights_init": "variance_scaling_initializer(mode='fan_avg', distribution='uniform', " - "scale=1.0)", + "forward_weights_init": "variance_scaling_initializer(mode='fan_avg', distribution='uniform', scale=1.0)", "from": "conformer_block_01_self_att_ln_concat", "key_value_spatial_dim": concat_window_dim, "out_dim": 
enc_dim_per_head_dim, @@ -13903,9 +13901,9 @@ def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer(): value_dim=value_dim, ) net_dict["output"]["unit"]["multi_layer_att"]["is_output_layer"] = True - net_dict["output"]["unit"]["multi_layer_qkv0"][ - "is_output_layer" - ] = True # we need to set the matrix here + net_dict["output"]["unit"]["multi_layer_qkv0"]["is_output_layer"] = ( + True # we need to set the matrix here + ) else: net_dict = { "single_layer_att": { diff --git a/tests/test_TranslationDataset.py b/tests/test_TranslationDataset.py index 3c1616d3e..b3fb75acc 100644 --- a/tests/test_TranslationDataset.py +++ b/tests/test_TranslationDataset.py @@ -15,9 +15,7 @@ dummy_source_text = ( - "This is some example text.\n" - "It is used to test the translation dataset.\n" - "We will write it into a temporary file.\n" + "This is some example text.\nIt is used to test the translation dataset.\nWe will write it into a temporary file.\n" ) dummy_target_text = ( @@ -67,7 +65,6 @@ def test_translation_dataset(): target_file.write(dummy_target_text.encode("utf-8")) for postfix in ["", " "]: # test with and without postfix - # Replace one word by . # This way it will not appear in the vocabulary (and is added to the vocabulary). # We will test below whether this word is assigned the unknown id by checking whether the reconstruction also @@ -114,18 +111,18 @@ def test_translation_dataset(): num_source_factors = 2 -dummy_source_text_factor_0 = "This is some example text.\n" "The factors here have no meaning\n" -dummy_source_text_factor_1 = "a b c d e\n" "a b c d e f\n" +dummy_source_text_factor_0 = "This is some example text.\nThe factors here have no meaning\n" +dummy_source_text_factor_1 = "a b c d e\na b c d e f\n" dummy_source_text_factored_format = ( - "This|a is|b some|c example|d text.|e\n" "The|a factors|b here|c have|d no|e meaning|f\n" + "This|a is|b some|c example|d text.|e\nThe|a factors|b here|c have|d no|e meaning|f\n" ) num_target_factors = 3 -dummy_target_text_factor_0 = "Das ist ein Beispieltext.\n" "Die Factors hier haben keinen Sinn.\n" -dummy_target_text_factor_1 = "a b c d\n" "a b c d e f\n" -dummy_target_text_factor_2 = "1 2 3 4\n" "1 2 3 4 5 6\n" +dummy_target_text_factor_0 = "Das ist ein Beispieltext.\nDie Factors hier haben keinen Sinn.\n" +dummy_target_text_factor_1 = "a b c d\na b c d e f\n" +dummy_target_text_factor_2 = "1 2 3 4\n1 2 3 4 5 6\n" dummy_target_text_factored_format = ( - "Das|a|1 ist|b|2 ein|c|3 Beispieltext.|d|4\n" "Die|a|1 Factors|b|2 hier|c|3 haben|d|4 keinen|e|5 Sinn.|f|6\n" + "Das|a|1 ist|b|2 ein|c|3 Beispieltext.|d|4\nDie|a|1 Factors|b|2 hier|c|3 haben|d|4 keinen|e|5 Sinn.|f|6\n" ) diff --git a/tests/test_rf_base.py b/tests/test_rf_base.py index 22e168eee..8a32cf1d5 100644 --- a/tests/test_rf_base.py +++ b/tests/test_rf_base.py @@ -68,6 +68,21 @@ def _forward_step(*, model: _Net, extern_data: TensorDict): # Now come some tests for some base functionality. 
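# test_train_flag below checks both the global train flag and the per-function override:
# train_flag_ctx(True, func=rf.dropout) leaves the global flag as-is (still False at that point),
# while is_train_flag_enabled(func=rf.dropout) then reports True for that function only.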
+def test_train_flag(): + rf.init_train_step_run_ctx(train_flag=False) + assert rf.get_run_ctx().train_flag is False + assert rf.get_run_ctx().is_train_flag_enabled(func=rf.dropout) is False + with rf.get_run_ctx().train_flag_ctx(True): + assert rf.get_run_ctx().train_flag is True + assert rf.get_run_ctx().is_train_flag_enabled(func=rf.dropout) is True + with rf.get_run_ctx().train_flag_ctx(False): + assert rf.get_run_ctx().train_flag is False + assert rf.get_run_ctx().is_train_flag_enabled(func=rf.dropout) is False + with rf.get_run_ctx().train_flag_ctx(True, func=rf.dropout): + assert rf.get_run_ctx().train_flag is False + assert rf.get_run_ctx().is_train_flag_enabled(func=rf.dropout) is True + + def test_state(): # https://github.com/rwth-i6/returnn/issues/1329 import tree diff --git a/tests/test_torch_frontend.py b/tests/test_torch_frontend.py index fdcedd3af..b285f09f3 100644 --- a/tests/test_torch_frontend.py +++ b/tests/test_torch_frontend.py @@ -400,7 +400,6 @@ def _loss_rf_padded(logits: Tensor, targets: Tensor) -> torch.Tensor: prev_bias_grad = None for loss_fn in [_loss_pt_packed, _loss_rf_padded, _loss_rf_packed]: - torch.manual_seed(42) batch_dim = Dim(dimension=3, name="batch") @@ -564,8 +563,7 @@ def _filter_rf_allocs_list(allocs: List[Dict[str, Any]]) -> List[Dict[str, Any]] print("Naive:", allocs_naive) assert len(allocs_rf) == len(allocs_naive) == 1 assert ( - list(allocs_rf.values())[0]["size"] - == list(allocs_naive.values())[0]["size"] + list(allocs_rf.values())[0]["size"] == list(allocs_naive.values())[0]["size"] # On CPU, it should match, but on GPU, it will allocate more. # == rf_pack_padded_res.numel() * sizeof_float ) diff --git a/tools/dump-dataset.py b/tools/dump-dataset.py index 38f3da192..944a98dcb 100755 --- a/tools/dump-dataset.py +++ b/tools/dump-dataset.py @@ -18,6 +18,7 @@ import argparse import numpy from returnn.datasets import init_dataset, Dataset +from returnn.datasets.util.vocabulary import Vocabulary from returnn.util.basic import Stats, hms, hms_fraction, pretty_print, NumbersDict from returnn.util import basic as util @@ -53,9 +54,9 @@ def dump_dataset(options): ", ".join(f"{k!r}: {v[:3]}... len {len(v)}" for k, v in dataset.labels.items()) or "None", file=log.v3, ) - assert ( - options.key in dataset.get_data_keys() - ), f"key {options.key!r} not in {dataset.get_data_keys()} (targets {dataset.get_target_list()})" + assert options.key in dataset.get_data_keys(), ( + f"key {options.key!r} not in {dataset.get_data_keys()} (targets {dataset.get_target_list()})" + ) max_seq_length = NumbersDict(options.max_seq_length) min_seq_length = NumbersDict(options.min_seq_length) @@ -87,6 +88,13 @@ def dump_dataset(options): print("Done.") return + # Inlined and cached can_serialize_data / serialize_data. 
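    # One Vocabulary per data key is built once here, so the per-seq dump loop below can call
    # vocabs[key].serialize_labels(...) directly instead of dataset.serialize_data() for every sequence.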
+ vocabs = {} + for key in dataset.get_data_keys(): + labels = dataset.labels.get(key) + if labels and len(labels) > 1: + vocabs[key] = Vocabulary.create_vocab_from_labels(labels) + dump_file = None if options.type == "numpy": print("Dump files: %r*%r" % (options.dump_prefix, options.dump_postfix), file=log.v3) @@ -192,9 +200,8 @@ def dump_dataset(options): elif options.type == "stdout": print("seq %s tag:" % progress, dataset.get_tag(seq_idx)) extra = "" - if "data" in dataset.labels and len(dataset.labels["data"]) > 1: - assert dataset.can_serialize_data("data") - extra += " (%r)" % dataset.serialize_data(key="data", data=data) + if "data" in vocabs: + extra += " (%r)" % vocabs["data"].serialize_labels(data) print("seq %s data: %s%s" % (progress, pretty_print(data), extra)) elif options.type == "print_shape": print("seq %s data shape:" % progress, data.shape) @@ -210,9 +217,8 @@ def dump_dataset(options): ) elif options.type == "stdout": extra = "" - if target in dataset.labels and len(dataset.labels[target]) > 1: - assert dataset.can_serialize_data(target) - extra += " (%r)" % dataset.serialize_data(key=target, data=targets) + if target in vocabs: + extra += " (%r)" % vocabs[target].serialize_labels(targets) print("seq %i target %r: %s%s" % (seq_idx, target, pretty_print(targets), extra)) elif options.type == "print_shape": print("seq %i target %r shape:" % (seq_idx, target), targets.shape) diff --git a/tools/hdf_dump_translation_dataset.py b/tools/hdf_dump_translation_dataset.py index 8aaf233e8..372ee97c8 100755 --- a/tools/hdf_dump_translation_dataset.py +++ b/tools/hdf_dump_translation_dataset.py @@ -264,9 +264,7 @@ def _write_lengths(self, source_lengths, target_lengths): target_sequence_lengths[data_key] = target_lengths # Now sort by key. - key_lengths_tuples_sorted = sorted( - target_sequence_lengths.items(), key=lambda x: x[0] - ) # type: typing.List[typing.Tuple[str,typing.List[int]]] # nopep8 + key_lengths_tuples_sorted = sorted(target_sequence_lengths.items(), key=lambda x: x[0]) # type: typing.List[typing.Tuple[str,typing.List[int]]] # nopep8 target_lengths_sorted = [key_length_tuple[1] for key_length_tuple in key_lengths_tuples_sorted] # Finally, add one time the source lengths for the input ("data") and convert to numpy. @@ -320,10 +318,10 @@ def _finalize_data(self): """ # Make sure the number of lines given by the user was correct. # Otherwise lengths and labels would have trailing zeros. - assert ( - self.number_of_lines == self._number_of_processed_lines - ), "Fewer lines ({}) in the corpus files " "than specified ({}).".format( - self._number_of_processed_lines, self.number_of_lines + assert self.number_of_lines == self._number_of_processed_lines, ( + "Fewer lines ({}) in the corpus files than specified ({}).".format( + self._number_of_processed_lines, self.number_of_lines + ) ) # Trim datasets to actually occupied length, i.e. remove unused reserved space. @@ -349,9 +347,9 @@ def _line_to_indices(self, line, side): else: if words: words_split_into_factors = [word.split(self.factor_separator) for word in words] - assert all( - len(factors) == len(data_keys) for factors in words_split_into_factors - ), "All words must have all factors. Expected: " + self.factor_separator.join(data_keys) + assert all(len(factors) == len(data_keys) for factors in words_split_into_factors), ( + "All words must have all factors. 
Expected: " + self.factor_separator.join(data_keys) + ) word_list_per_factor = zip(*words_split_into_factors) else: word_list_per_factor = [[]] * len(data_keys) @@ -431,7 +429,7 @@ def parse_args(): "-d", "--data_buffer_size", type=int, - help="How much space to reserve in the HDF dataset " "at once (in number of integers).", + help="How much space to reserve in the HDF dataset at once (in number of integers).", default=5000000, ) diff --git a/tools/tf_avg_checkpoints.py b/tools/tf_avg_checkpoints.py index 05e43e066..83624ab3b 100755 --- a/tools/tf_avg_checkpoints.py +++ b/tools/tf_avg_checkpoints.py @@ -95,7 +95,7 @@ def main(_): tf_compat.v1.logging.info("%s ", c) var_list = tf.train.list_variables(checkpoints[0]) var_values, var_dtypes = {}, {} - for (name, shape) in var_list: + for name, shape in var_list: var_values[name] = numpy.zeros(shape) for checkpoint in checkpoints: reader = tf.train.load_checkpoint(checkpoint) diff --git a/tools/tf_inspect_checkpoint.py b/tools/tf_inspect_checkpoint.py index 503a6c3c0..900561d2a 100755 --- a/tools/tf_inspect_checkpoint.py +++ b/tools/tf_inspect_checkpoint.py @@ -135,7 +135,7 @@ def main(unused_argv): Main entry: """ if not FLAGS.file_name: - print("Usage: inspect_checkpoint --file_name=checkpoint_file_name " "[--tensor_name=tensor_to_print]") + print("Usage: inspect_checkpoint --file_name=checkpoint_file_name [--tensor_name=tensor_to_print]") sys.exit(1) else: print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name, FLAGS.all_tensors) diff --git a/tools/torch_export_to_onnx.py b/tools/torch_export_to_onnx.py index 2973f771d..df3fb5eb7 100644 --- a/tools/torch_export_to_onnx.py +++ b/tools/torch_export_to_onnx.py @@ -187,9 +187,9 @@ def main(): init(config_filename=args.config, checkpoint=args.checkpoint, log_verbosity=args.verbosity, device=args.device) model_outputs_dict = config.typed_value("model_outputs") - assert ( - model_outputs_dict is not None - ), "The specified config needs to have explicit model outputs. Please define `model_outputs` in your config." + assert model_outputs_dict is not None, ( + "The specified config needs to have explicit model outputs. Please define `model_outputs` in your config." + ) model_outputs = TensorDict() model_outputs.update(model_outputs_dict, auto_convert=True) @@ -207,9 +207,9 @@ def main(): is_rf_module = isinstance(model, rf.Module) is_pt_module = isinstance(model, torch.nn.Module) - assert ( - is_rf_module or is_pt_module - ), "The module returned by get_model() isn't a returnn.frontend.Module or a torch.nn.Module." + assert is_rf_module or is_pt_module, ( + "The module returned by get_model() isn't a returnn.frontend.Module or a torch.nn.Module." 
+ ) export_func = config.typed_value("export") or torch.onnx.export forward_step_func = config.typed_value("forward_step") @@ -252,16 +252,16 @@ def main(): if args.input_names: input_names = args.input_names.split(",") - assert set(extern_data_raw.keys()) == set( - input_names - ), f"missmatch between input_names {input_names} and extern_data keys {list(extern_data_raw.keys())}" + assert set(extern_data_raw.keys()) == set(input_names), ( + f"missmatch between input_names {input_names} and extern_data keys {list(extern_data_raw.keys())}" + ) else: input_names = list(extern_data_raw.keys()) if args.output_names: output_names = args.output_names.split(",") - assert set(model_outputs_raw_keys) == set( - output_names - ), f"missmatch between output_names {output_names} and model_outputs keys {model_outputs_raw_keys}" + assert set(model_outputs_raw_keys) == set(output_names), ( + f"missmatch between output_names {output_names} and model_outputs keys {model_outputs_raw_keys}" + ) else: output_names = model_outputs_raw_keys diff --git a/tools/torch_scale_tuning.py b/tools/torch_scale_tuning.py index 0cfb5533b..add3f6bd9 100755 --- a/tools/torch_scale_tuning.py +++ b/tools/torch_scale_tuning.py @@ -176,7 +176,7 @@ def main(): for i in range(0, len(seq_list_ordered_by_len), args.batch_size): print( f"Batch {len(batches)}, seqs {i} - {min(i + args.batch_size, len(seq_list_ordered_by_len))}" - f" / {len(seq_list_ordered_by_len)}, {i/len(seq_list_ordered_by_len)*100:.1f}%" + f" / {len(seq_list_ordered_by_len)}, {i / len(seq_list_ordered_by_len) * 100:.1f}%" ) batch_seq_tag_list: List[str] = seq_list_ordered_by_len[i : i + args.batch_size]