Skip to content

Commit 355c9bb

Browse files
committed
Vocabulary, allow callable labels
1 parent f38f5ae commit 355c9bb

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

returnn/datasets/util/vocabulary.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,16 @@ def __init__(self, vocab_file, seq_postfix=None,
6262
https://github.com/google/sentencepiece/blob/master/doc/special_symbols.md
6363
:param int num_labels: just for verification
6464
:param list[int]|None seq_postfix: labels will be added to the seq in self.get_seq
65-
:param list[str]|None labels:
65+
:param list[str]|(()->list[str])|None labels:
6666
"""
6767
self.vocab_file = vocab_file
6868
self.unknown_label = unknown_label
6969
self.num_labels = None # type: typing.Optional[int] # will be set by _parse_vocab
7070
self._vocab = None # type: typing.Optional[typing.Dict[str,int]] # label->idx
71+
if labels is not None and callable(labels):
72+
labels = labels()
73+
if labels is not None:
74+
assert isinstance(labels, (list, tuple))
7175
self._labels = labels
7276

7377
self._parse_vocab()

0 commit comments

Comments
 (0)