diff --git a/todo.txt b/todo.txt index 47a96c5..6a8bbf0 100644 --- a/todo.txt +++ b/todo.txt @@ -5,3 +5,5 @@ https://github.com/JasonSWFu/VQscore https://github.com/soumimaiti/speechlmscore_tool Other metrics with pre-trained models https://github.com/Ashvala/AQUA-Tk?tab=readme-ov-file + +https://github.com/FireRedTeam/FireRedASR/tree/main \ No newline at end of file diff --git a/versa/corpus_metrics/whisper_wer.py b/versa/corpus_metrics/whisper_wer.py index 2a4b218..78e5afb 100644 --- a/versa/corpus_metrics/whisper_wer.py +++ b/versa/corpus_metrics/whisper_wer.py @@ -19,13 +19,14 @@ whisper = None from espnet2.text.cleaner import TextCleaner +from espnet2.text.phoneme_tokenizer import PhonemeTokenizer TARGET_FS = 16000 CHUNK_SIZE = 30 # seconds def whisper_wer_setup( - model_tag="default", beam_size=5, text_cleaner="whisper_basic", use_gpu=True + model_tag="default", beam_size=5, text_cleaner="whisper_basic", calc_per=False, use_gpu=True ): if model_tag == "default": model_tag = "large" @@ -37,6 +38,14 @@ def whisper_wer_setup( model = whisper.load_model(model_tag, device=device) textcleaner = TextCleaner(text_cleaner) wer_utils = {"model": model, "cleaner": textcleaner, "beam_size": beam_size} + if calc_per is True: + g2p = { + "zh": PhonemeTokenizer("pypinyin_g2p_phone_without_prosody"), + "ja": PhonemeTokenizer("pyopenjtalk"), + "en": PhonemeTokenizer("g2p_en"), + # To support additional languages, add corresponding g2p modules here. + } + wer_utils.update(g2p=g2p) return wer_utils @@ -63,9 +72,11 @@ def whisper_levenshtein_metric( pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS) fs = TARGET_FS with torch.no_grad(): - inf_text = wer_utils["model"].transcribe( + results = wer_utils["model"].transcribe( torch.tensor(pred_x).float(), beam_size=wer_utils["beam_size"] - )["text"] + ) + inf_text = results["text"] + lang = results["language"] ref_text = wer_utils["cleaner"](ref_text).strip() pred_text = wer_utils["cleaner"](inf_text).strip() @@ -126,6 +137,39 @@ def whisper_levenshtein_metric( ) assert total == len(pred_words), (total, len(pred_words)) + # process per + if "g2p" in wer_utils.keys(): + assert lang in wer_utils["g2p"].keys(), f"Not support g2p for {lang} language" + ref_words = wer_utils["g2p"][lang].text2tokens(ref_text) + pred_words = wer_utils["g2p"][lang].text2tokens(pred_text) + + ref_words = [p for phn in ref_words for p in phn.strip().split('_')] + pred_words = [p for phn in pred_words for p in phn.strip().split('_')] + + ret.update( + whisper_per_delete=0, + whisper_per_insert=0, + whisper_per_replace=0, + whisper_per_equal=0, + ) + for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words): + if op == "insert": + ret["whisper_per_" + op] = ret["whisper_per_" + op] + inf_et - inf_st + else: + ret["whisper_per_" + op] = ret["whisper_per_" + op] + ref_et - ref_st + total = ( + ret["whisper_per_delete"] + + ret["whisper_per_replace"] + + ret["whisper_per_equal"] + ) + assert total == len(ref_words), (total, len(ref_words)) + total = ( + ret["whisper_per_insert"] + + ret["whisper_per_replace"] + + ret["whisper_per_equal"] + ) + assert total == len(pred_words), (total, len(pred_words)) + return ret diff --git a/versa/scorer_shared.py b/versa/scorer_shared.py index cecd525..da13438 100644 --- a/versa/scorer_shared.py +++ b/versa/scorer_shared.py @@ -602,6 +602,7 @@ def load_score_modules(score_config, use_gt=True, use_gt_text=False, use_gpu=Fal model_tag=config.get("model_tag", "default"), beam_size=config.get("beam_size", 1), text_cleaner=config.get("text_cleaner", "whisper_basic"), + calc_per=config.get("calc_per", False), use_gpu=use_gpu, ) @@ -1392,7 +1393,7 @@ def load_summary(score_info): # NOTE(jiatong): skip text cases continue summary[key] = sum([score[key] for score in score_info]) - if "_wer" not in key and "_cer" not in key: + if "_wer" not in key and "_cer" not in key and "_per" not in key: # Average for non-WER/CER metrics summary[key] /= len(score_info) return summary