Skip to content

Commit e750c2d

Browse files
update run_evaluate script for cased itn (#164)
* update run_evaluate script for cased itn Signed-off-by: Mariana Graterol Fuenmayor <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Mariana Graterol Fuenmayor <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cb47029 commit e750c2d

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

nemo_text_processing/inverse_text_normalization/run_evaluate.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,14 @@ def parse_args():
3333
parser = ArgumentParser()
3434
parser.add_argument("--input", help="input file path", type=str)
3535
parser.add_argument(
36-
"--lang", help="language", choices=['en', 'de', 'es', 'pt', 'ru', 'fr', 'vi', 'hy'], default="en", type=str
36+
"--lang",
37+
help="language",
38+
choices=["ar", "de", "en", "es", "es_en", "fr", "hy", "mr", "pt", "ru", "sv", "vi", "zh"],
39+
default="en",
40+
type=str,
3741
)
42+
parser.add_argument("--input_case", choices=["lower_cased", "cased"])
43+
parser.add_argument("--output_case", choices=["lower_cased", "cased"])
3844
parser.add_argument(
3945
"--cat",
4046
dest="category",
@@ -54,10 +60,15 @@ def parse_args():
5460
if args.lang == 'en':
5561
from nemo_text_processing.inverse_text_normalization.en.clean_eval_data import filter_loaded_data
5662
file_path = args.input
57-
inverse_normalizer = InverseNormalizer(lang=args.lang)
63+
inverse_normalizer = InverseNormalizer(lang=args.lang, input_case=args.input_case)
5864

5965
print("Loading training data: " + file_path)
60-
training_data = load_files([file_path])
66+
if args.output_case == "lower_cased":
67+
to_lower = True
68+
elif args.output_case == "cased":
69+
to_lower = False
70+
71+
training_data = load_files([file_path], to_lower=to_lower)
6172

6273
if args.filter:
6374
training_data = filter_loaded_data(training_data)

nemo_text_processing/text_normalization/data_loader_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
]
4747

4848

49-
def _load_kaggle_text_norm_file(file_path: str) -> List[Instance]:
49+
def _load_kaggle_text_norm_file(file_path: str, to_lower: bool) -> List[Instance]:
5050
"""
5151
https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
5252
Loads text file in the Kaggle Google text normalization file format: <semiotic class>\t<unnormalized text>\t<`self` if trivial class or normalized text>
@@ -76,8 +76,9 @@ def _load_kaggle_text_norm_file(file_path: str) -> List[Instance]:
7676
res.append(Instance(token_type=EOS_TYPE, un_normalized="", normalized=""))
7777
else:
7878
l_type, l_token, l_normalized = parts
79-
l_token = l_token.lower()
80-
l_normalized = l_normalized.lower()
79+
if to_lower:
80+
l_token = l_token.lower()
81+
l_normalized = l_normalized.lower()
8182

8283
if l_type == PLAIN_TYPE:
8384
res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_token))
@@ -86,7 +87,7 @@ def _load_kaggle_text_norm_file(file_path: str) -> List[Instance]:
8687
return res
8788

8889

89-
def load_files(file_paths: List[str], load_func=_load_kaggle_text_norm_file) -> List[Instance]:
90+
def load_files(file_paths: List[str], load_func=_load_kaggle_text_norm_file, to_lower: bool = True) -> List[Instance]:
9091
"""
9192
Load given list of text files using the `load_func` function.
9293
@@ -98,7 +99,7 @@ def load_files(file_paths: List[str], load_func=_load_kaggle_text_norm_file) ->
9899
"""
99100
res = []
100101
for file_path in file_paths:
101-
res.extend(load_func(file_path=file_path))
102+
res.extend(load_func(file_path=file_path, to_lower=to_lower))
102103
return res
103104

104105

0 commit comments

Comments
 (0)