Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions todo.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ https://github.com/JasonSWFu/VQscore
https://github.com/soumimaiti/speechlmscore_tool
Other metrics with pre-trained models
https://github.com/Ashvala/AQUA-Tk?tab=readme-ov-file

https://github.com/FireRedTeam/FireRedASR/tree/main
50 changes: 47 additions & 3 deletions versa/corpus_metrics/whisper_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
whisper = None

from espnet2.text.cleaner import TextCleaner
from espnet2.text.phoneme_tokenizer import PhonemeTokenizer

TARGET_FS = 16000
CHUNK_SIZE = 30 # seconds


def whisper_wer_setup(
model_tag="default", beam_size=5, text_cleaner="whisper_basic", use_gpu=True
model_tag="default", beam_size=5, text_cleaner="whisper_basic", calc_per=False, use_gpu=True
):
if model_tag == "default":
model_tag = "large"
Expand All @@ -37,6 +38,14 @@ def whisper_wer_setup(
model = whisper.load_model(model_tag, device=device)
textcleaner = TextCleaner(text_cleaner)
wer_utils = {"model": model, "cleaner": textcleaner, "beam_size": beam_size}
if calc_per is True:
g2p = {
"zh": PhonemeTokenizer("pypinyin_g2p_phone_without_prosody"),
"ja": PhonemeTokenizer("pyopenjtalk"),
"en": PhonemeTokenizer("g2p_en"),
# To support additional languages, add corresponding g2p modules here.
}
wer_utils.update(g2p=g2p)
return wer_utils


Expand All @@ -63,9 +72,11 @@ def whisper_levenshtein_metric(
pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS)
fs = TARGET_FS
with torch.no_grad():
inf_text = wer_utils["model"].transcribe(
results = wer_utils["model"].transcribe(
torch.tensor(pred_x).float(), beam_size=wer_utils["beam_size"]
)["text"]
)
inf_text = results["text"]
lang = results["language"]

ref_text = wer_utils["cleaner"](ref_text).strip()
pred_text = wer_utils["cleaner"](inf_text).strip()
Expand Down Expand Up @@ -126,6 +137,39 @@ def whisper_levenshtein_metric(
)
assert total == len(pred_words), (total, len(pred_words))

# process per
if "g2p" in wer_utils.keys():
assert lang in wer_utils["g2p"].keys(), f"Not support g2p for {lang} language"
ref_words = wer_utils["g2p"][lang].text2tokens(ref_text)
pred_words = wer_utils["g2p"][lang].text2tokens(pred_text)

ref_words = [p for phn in ref_words for p in phn.strip().split('_')]
pred_words = [p for phn in pred_words for p in phn.strip().split('_')]

ret.update(
whisper_per_delete=0,
whisper_per_insert=0,
whisper_per_replace=0,
whisper_per_equal=0,
)
for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words):
if op == "insert":
ret["whisper_per_" + op] = ret["whisper_per_" + op] + inf_et - inf_st
else:
ret["whisper_per_" + op] = ret["whisper_per_" + op] + ref_et - ref_st
total = (
ret["whisper_per_delete"]
+ ret["whisper_per_replace"]
+ ret["whisper_per_equal"]
)
assert total == len(ref_words), (total, len(ref_words))
total = (
ret["whisper_per_insert"]
+ ret["whisper_per_replace"]
+ ret["whisper_per_equal"]
)
assert total == len(pred_words), (total, len(pred_words))

return ret


Expand Down
3 changes: 2 additions & 1 deletion versa/scorer_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def load_score_modules(score_config, use_gt=True, use_gt_text=False, use_gpu=Fal
model_tag=config.get("model_tag", "default"),
beam_size=config.get("beam_size", 1),
text_cleaner=config.get("text_cleaner", "whisper_basic"),
calc_per=config.get("calc_per", False),
use_gpu=use_gpu,
)

Expand Down Expand Up @@ -1392,7 +1393,7 @@ def load_summary(score_info):
# NOTE(jiatong): skip text cases
continue
summary[key] = sum([score[key] for score in score_info])
if "_wer" not in key and "_cer" not in key:
if "_wer" not in key and "_cer" not in key and "_per" not in key:
# Average for non-WER/CER metrics
summary[key] /= len(score_info)
return summary
Expand Down
Loading