diff --git a/.gitignore b/.gitignore index 58a6eeb..7f15191 100644 --- a/.gitignore +++ b/.gitignore @@ -169,4 +169,14 @@ fadtk/ scoreq/ fairseq/ UTMOSv2/ + +# Versa optional metric installer output and model caches +versa_cache/ +tools/NISQA/ +tools/Noresqa/ +tools/SRMRpy/ +tools/audiobox-aesthetics/ +tools/emotion2vec/ +ssl-singer-identity/ +pretrained_models/ wvmos/ diff --git a/README.md b/README.md index b10d0ce..38b81bb 100644 --- a/README.md +++ b/README.md @@ -55,10 +55,10 @@ For metrics marked without "x" in the "Auto-Install" column of our metrics table ```bash # Test core functionality -python versa/test/test_pipeline/test_general.py +python -m pytest test/test_general.py # Test specific metrics that require additional installation -python versa/test/test_pipeline/test_{metric}.py +python -m pytest test/test_metrics/test_{metric}.py ``` @@ -69,7 +69,7 @@ python versa/test/test_pipeline/test_{metric}.py ```bash # Direct usage with file paths python versa/bin/scorer.py \ - --score_config egs/speech.yaml \ + --score_config egs/speech_cpu.yaml \ --gt test/test_samples/test1 \ --pred test/test_samples/test2 \ --output_file test_result \ @@ -77,7 +77,7 @@ python versa/bin/scorer.py \ # With SCP-style input python versa/bin/scorer.py \ - --score_config egs/speech.yaml \ + --score_config egs/speech_cpu.yaml \ --gt test/test_samples/test1.scp \ --pred test/test_samples/test2.scp \ --output_file test_result \ @@ -85,7 +85,7 @@ python versa/bin/scorer.py \ # With Kaldi-ARK style input (compatible with ESPnet) python versa/bin/scorer.py \ - --score_config egs/speech.yaml \ + --score_config egs/speech_cpu.yaml \ --gt test/test_samples/test1.scp \ --pred test/test_samples/test2.scp \ --output_file test_result \ @@ -93,7 +93,7 @@ python versa/bin/scorer.py \ # Including text transcription information python versa/bin/scorer.py \ - --score_config egs/separate_metrics/wer.yaml \ + --score_config egs/separate_metrics/wer_tiny.yaml \ --gt test/test_samples/test1.scp \ --pred test/test_samples/test2.scp \ --output_file test_result \ diff --git a/docs/metric_migration.md b/docs/metric_migration.md new file mode 100644 index 0000000..3212b0a --- /dev/null +++ b/docs/metric_migration.md @@ -0,0 +1,151 @@ +# Metric Migration Guide + +This guide summarizes the preferred process for migrating existing Versa metrics +to the new object-oriented metric interface. + +## Migration Goal + +Use `versa.definition.BaseMetric` as the source of truth for metric +implementations. Preserve user-facing behavior, but do not preserve legacy +internal helper APIs unless they are still needed by public callers. 
+
+Preserve:
+
+- YAML metric names
+- CLI/scorer behavior
+- output score keys
+- documented config defaults
+- optional dependency behavior
+
+Clean up:
+
+- old function-style metric internals
+- duplicated setup code
+- eager optional dependency imports
+- tests that only exercise legacy helper functions
+
+## Required Metric Shape
+
+Each migrated metric should provide (see the sketch before the candidate list
+below):
+
+- a `BaseMetric` subclass
+- `_setup(self)` for config defaults, dependency checks, and model setup
+- `compute(self, predictions, references=None, metadata=None)` for scoring
+- `get_metadata(self)` returning `MetricMetadata`
+- `register_<metric_name>_metric(registry)` as the registry integration point
+
+`compute` should:
+
+- validate required inputs
+- read the sample rate from `metadata.get("sample_rate", 16000)` when needed
+- return the same output keys users already receive
+- avoid changing user-visible numeric conventions unless the migration requires it
+
+## Metadata Checklist
+
+Every metric registration should define:
+
+- canonical metric name
+- `MetricCategory`: `INDEPENDENT`, `DEPENDENT`, `NON_MATCH`, or `DISTRIBUTIONAL`
+- `MetricType`: usually `FLOAT` for one score or `DICT` for grouped scores
+- `requires_reference`
+- `requires_text`
+- `gpu_compatible`
+- `auto_install`
+- dependency import names
+- short description
+- paper reference and implementation source when known
+- useful aliases for existing YAML or common names
+
+## Optional Dependencies
+
+Optional dependencies must not break `import versa`.
+
+Use guarded imports inside metric modules, and raise a clear `ImportError` from
+`_setup` when a required optional package is missing. Register optional metrics
+from `versa/__init__.py` through `_optional_metric_import(...)`.
+
+## Tests
+
+Prefer tests for the new public path:
+
+- metric class behavior
+- registry registration and aliases
+- `VersaScorer` pipeline behavior with existing sample audio when lightweight
+- missing optional dependency behavior
+- unchanged user-facing output keys
+
+Do not add tests solely to preserve old internal helper APIs unless those APIs
+remain part of the public interface.
+
+Base-install focused tests currently live in:
+
+- `test/test_metrics/test_base_metrics.py`
+- `test/test_pipeline/test_base_metrics_pipeline.py`
+
+## Migration Candidates
+
+The following modules still appear to use the old interface because they do not
+define or import `BaseMetric`. This list is based on a repository scan and should
+be updated as each metric is migrated.
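+
+Before walking the candidate list, here is a minimal sketch of the target
+shape described above, combining the guarded-import pattern with a
+`BaseMetric` subclass. The `toy_rms` metric, the `soundfile` guard, the
+`self.config` attribute, and the exact `MetricMetadata` fields and registry
+call are illustrative assumptions; treat `versa/definition.py` and an
+already-migrated module such as `versa/utterance_metrics/stoi.py` as the
+authoritative references.
+
+```python
+import numpy as np
+
+from versa.definition import (
+    BaseMetric,
+    MetricCategory,
+    MetricMetadata,
+    MetricType,
+)
+
+try:  # guarded optional dependency: must never break `import versa`
+    import soundfile  # noqa: F401  (illustrative optional package)
+except ImportError:
+    soundfile = None
+
+
+class ToyRmsMetric(BaseMetric):
+    def _setup(self):
+        # Config defaults and dependency checks belong here, not in compute().
+        if soundfile is None:
+            raise ImportError("toy_rms requires soundfile: pip install soundfile")
+        self.scale = self.config.get("scale", 1.0)
+
+    def compute(self, predictions, references=None, metadata=None):
+        if predictions is None:
+            raise ValueError("Predicted signal must be provided")
+        metadata = metadata or {}
+        fs = metadata.get("sample_rate", 16000)  # read, never hard-code
+        if fs <= 0:
+            raise ValueError("sample_rate must be positive")
+        # Keep the user-facing output key stable across the migration.
+        return {"toy_rms": float(self.scale * np.sqrt(np.mean(predictions**2)))}
+
+    def get_metadata(self):
+        return MetricMetadata(
+            name="toy_rms",
+            category=MetricCategory.INDEPENDENT,
+            metric_type=MetricType.FLOAT,
+            requires_reference=False,
+            requires_text=False,
+            gpu_compatible=False,
+            auto_install=True,
+            dependencies=["numpy", "soundfile"],
+            description="Toy RMS level, for illustration only.",
+        )
+
+
+def register_toy_rms_metric(registry):
+    # Registry integration point; copy the exact registration/alias calls
+    # from an existing register_*_metric function rather than this sketch.
+    registry.register_metric(ToyRmsMetric, aliases=["rms_level"])
+```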
+
+### Corpus and Distributional Metrics
+
+- `versa/corpus_metrics/fad.py`
+- `versa/corpus_metrics/individual_fad.py`
+- `versa/corpus_metrics/kid.py`
+- `versa/corpus_metrics/clap_score.py`
+
+### Already Migrated Examples
+
+Use these as local references when migrating the remaining metrics:
+
+- `versa/sequence_metrics/mcd_f0.py`
+- `versa/sequence_metrics/signal_metric.py`
+- `versa/sequence_metrics/warpq.py`
+- `versa/corpus_metrics/espnet_wer.py`
+- `versa/corpus_metrics/owsm_wer.py`
+- `versa/corpus_metrics/whisper_wer.py`
+- `versa/utterance_metrics/log_wmse.py`
+- `versa/utterance_metrics/pseudo_mos.py`
+- `versa/utterance_metrics/qwen2_audio.py`
+- `versa/utterance_metrics/qwen_omni.py`
+- `versa/utterance_metrics/speaking_rate.py`
+- `versa/utterance_metrics/scoreq.py`
+- `versa/utterance_metrics/se_snr.py`
+- `versa/utterance_metrics/sheet_ssqa.py`
+- `versa/utterance_metrics/singer.py`
+- `versa/utterance_metrics/speaker.py`
+- `versa/utterance_metrics/stoi.py`
+- `versa/utterance_metrics/pesq_score.py`
+- `versa/utterance_metrics/squim.py`
+- `versa/utterance_metrics/universa.py`
+- `versa/utterance_metrics/vad.py`
+- `versa/utterance_metrics/visqol_score.py`
+- `versa/utterance_metrics/vqscore.py`
+
+## Verification
+
+Run focused checks before broader validation:
+
+```bash
+/opt/homebrew/bin/mamba run -n versa-dev python -m pytest -q
+/opt/homebrew/bin/mamba run -n versa-dev python -m black --check .
+/opt/homebrew/bin/mamba run -n versa-dev python -m flake8
+```
+
+The base migration tests use mocks for heavy model-backed metrics. They validate
+registry integration, pipeline wiring, input handling, and output keys, but do
+not prove checkpoint download or real inference.
+
+Run optional real model checks locally after installing the metric dependencies:
+
+```bash
+tools/install_scoreq.sh
+VERSA_RUN_REAL_MODEL_TESTS=1 \
+    /opt/homebrew/bin/mamba run -n versa-dev python -m pytest \
+    test/test_pipeline/test_scoreq.py -q -s
+```
+
+These tests are marked `real_model` and are skipped unless
+`VERSA_RUN_REAL_MODEL_TESTS=1` is set.
diff --git a/docs/supported_metrics.md b/docs/supported_metrics.md
index 7c4ef0f..9fde2d5 100644
--- a/docs/supported_metrics.md
+++ b/docs/supported_metrics.md
@@ -13,7 +13,7 @@ We include x mark if the metric is auto-installed in versa.
| 6 | x | PESQ in TorchAudio-Squim | squim_no_ref | torch_squim_pesq | [torch_squim](https://pytorch.org/audio/main/tutorials/squim_tutorial.html) | [paper](https://arxiv.org/abs/2304.01448) | | 7 | x | STOI in TorchAudio-Squim | squim_no_ref | torch_squim_stoi | [torch_squim](https://pytorch.org/audio/main/tutorials/squim_tutorial.html) | [paper](https://arxiv.org/abs/2304.01448) | | 8 | x | SI-SDR in TorchAudio-Squim | squim_no_ref | torch_squim_si_sdr | [torch_squim](https://pytorch.org/audio/main/tutorials/squim_tutorial.html) | [paper](https://arxiv.org/abs/2304.01448) | -| 9 | x | Singing voice MOS | pseudo_mos | singmos_v1 |[singmos](https://github.com/South-Twilight/SingMOS) | [paper](https://arxiv.org/abs/2406.10911) | +| 9 | x | Singing voice MOS | singmos | singmos |[singmos](https://github.com/South-Twilight/SingMOS/tree/main) | [paper](https://arxiv.org/abs/2406.10911) | | 10 | x | Sheet SSQA MOS Models | sheet_ssqa | sheet_ssqa |[Sheet](https://github.com/unilight/sheet/tree/main) | [paper](https://arxiv.org/abs/2411.03715) | | 11 | | UTMOSv2: UTokyo-SaruLab MOS Prediction System | utmosv2 | utmosv2 |[UTMOSv2](https://github.com/sarulab-speech/UTMOSv2) | [paper](https://arxiv.org/abs/2409.09305) | | 12 | | Speech Contrastive Regression for Quality Assessment without reference (ScoreQ) | scoreq_nr | scoreq_nr |[ScoreQ](https://github.com/ftshijt/scoreq/tree/main) | [paper](https://arxiv.org/pdf/2410.06675) | @@ -50,9 +50,9 @@ We include x mark if the metric is auto-installed in versa. | 43 | x | Qwen2 Recording Environment - Background | qwen2_speech_background_environment_metric | qwen2_speech_background_environment_metric | [Qwen2 Audio](https://github.com/QwenLM/Qwen2-Audio) | [paper](https://arxiv.org/abs/2407.10759) | | 44 | x | Qwen2 Recording Environment - Quality | qwen2_recording_quality_metric | qwen2_recording_quality_metric | [Qwen2 Audio](https://github.com/QwenLM/Qwen2-Audio) | [paper](https://arxiv.org/abs/2407.10759) | | 45 | x | Qwen2 Recording Environment - Channel Type | qwen2_channel_type_metric | qwen2_channel_type_metric | [Qwen2 Audio](https://github.com/QwenLM/Qwen2-Audio) | [paper](https://arxiv.org/abs/2407.10759) | -| 46 | x | Dimensional Emotion | w2v2_dimensional_emotion | w2v2_dimensional_emotion | [w2v2-how-to](https://github.com/audeering/w2v2-how-to) | [paper](https://arxiv.org/pdf/2203.07378) | -| 47 | | Uni-VERSA (Versatile Speech Assessment with a Unified Framework) - No Reference | universa_noref | universa_score | [Uni-VERSA](https://huggingface.co/collections/espnet/universa-6834e7c0a28225bffb6e2526) | [paper](https://arxiv.org/abs/2505.20741) | -| 48 | | ARECHO (Audio Reference Echo Cancellation and Codec Quality Assessment) - No Reference | arecho_noref | arecho_score | [ARECHO](https://huggingface.co/espnet/arecho_base_v0) | [paper](https://arxiv.org/abs/2505.20741) | +| 46 | x | Dimensional Emotion | emo_vad | arousal_emo_vad, valence_emo_vad, dominance_emo_vad | [w2v2-how-to](https://github.com/audeering/w2v2-how-to) | [paper](https://arxiv.org/pdf/2203.07378) | +| 47 | x | Uni-VERSA (Versatile Speech Assessment with a Unified Framework) | universa, universa_noref, universa_audioref, universa_textref, universa_fullref | universa_{sub_metrics} | [Uni-VERSA](https://huggingface.co/collections/espnet/universa-6834e7c0a28225bffb6e2526) | [paper](https://arxiv.org/abs/2505.20741) | +| 48 | | ARECHO (Audio Reference Echo Cancellation and Codec Quality Assessment) - No Reference | arecho, arecho_noref | arecho_{sub_metrics} | 
[ARECHO](https://huggingface.co/espnet/arecho_base_v0) | [paper](https://arxiv.org/abs/2505.20741) | | 49 | x | DNSMOS Pro: A Reduced-Size DNN for Probabilistic MOS of Speech | pseudo_mos | dnsmos_pro_bvcc | [DNSMOSPro](https://github.com/fcumlin/DNSMOSPro/tree/main) | [paper](https://www.isca-archive.org/interspeech_2024/cumlin24_interspeech.html) | | 50 | x | DNSMOS Pro: A Reduced-Size DNN for Probabilistic MOS of Speech | pseudo_mos | dnsmos_pro_nisqa | [DNSMOSPro](https://github.com/fcumlin/DNSMOSPro/tree/main) | [paper](https://www.isca-archive.org/interspeech_2024/cumlin24_interspeech.html) | | 51 | x | DNSMOS Pro: A Reduced-Size DNN for Probabilistic MOS of Speech | pseudo_mos | dnsmos_pro_vcc2018 | [DNSMOSPro](https://github.com/fcumlin/DNSMOSPro/tree/main) | [paper](https://www.isca-archive.org/interspeech_2024/cumlin24_interspeech.html) | @@ -61,6 +61,7 @@ We include x mark if the metric is auto-installed in versa. | 54 | x | VQScore (Self-Supervised Speech Quality Estimation and Enhancement Using Only Clean Speech) | vqscore | vqscore | [VQScore](https://github.com/JasonSWFu/VQscore) | [paper](https://arxiv.org/abs/2402.16321) | | 55 | x | Singing voice MOS | pseudo_mos | singmos_pro |[singmos](https://github.com/South-Twilight/SingMOS) | [paper](https://arxiv.org/abs/2510.01812) | + ### Dependent Metrics |Number| Auto-Install | Metric Name (Auto-Install) | Key in config | Key in report | Code Source | References | |---|---|------------------|---------------|---------------|-----------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------| @@ -70,7 +71,7 @@ We include x mark if the metric is auto-installed in versa. | 4 | x | Signal-to-interference Ratio (SIR) | signal_metric | sir | [espnet](https://github.com/espnet/espnet) | - | | 5 | x | Signal-to-artifact Ratio (SAR) | signal_metric | sar | [espnet](https://github.com/espnet/espnet) | - | | 6 | x | Signal-to-distortion Ratio (SDR) | signal_metric | sdr | [espnet](https://github.com/espnet/espnet) | - | -| 7 | x | Convolutional scale-invariant signal-to-distortion ratio (CI-SDR) | signal_metric | ci-sdr | [ci_sdr](https://github.com/fgnt/ci_sdr) | [paper](https://arxiv.(org/abs/2011.15003) | +| 7 | x | Convolutional scale-invariant signal-to-distortion ratio (CI-SDR) | signal_metric | ci-sdr | [ci_sdr](https://github.com/fgnt/ci_sdr) | [paper](https://arxiv.org/abs/2011.15003) | | 8 | x | Scale-invariant signal-to-noise ratio (SI-SNR) | signal_metric | si-snr | [espnet](https://github.com/espnet/espnet) | [paper](https://arxiv.org/abs/1711.00541) | | 9 | x | Perceptual Evaluation of Speech Quality (PESQ) | pesq | pesq | [pesq](https://pypi.org/project/pesq/) | [paper](https://ieeexplore.ieee.org/document/941023) | | 10 | x | Short-Time Objective Intelligibility (STOI) | stoi | stoi | [pystoi](https://github.com/mpariente/pystoi) | [paper](https://ieeexplore.ieee.org/document/5495701) | @@ -89,11 +90,10 @@ We include x mark if the metric is auto-installed in versa. 
| 23 | | Composite Objective Speech Quality (composite) | pysepm | pysepm_Csig, pysepm_Cbak, pysepm_Covl | [pysepm](https://github.com/shimhz/pysepm.git) | [Paper](https://ecs.utdallas.edu/loizou/speech/obj_paper_jan08.pdf)| | 24 | | Coherence and speech intelligibility index (CSII) | pysepm | pysepm_csii_high, pysepm_csii_mid, pysepm_csii_low | [pysepm](https://github.com/shimhz/pysepm.git) | [Paper](https://www.researchgate.net/profile/James-Kates-2/publication/7842209_Coherence_and_the_speech_intelligibility_index/links/546f5dab0cf2d67fc0310f88/Coherence-and-the-speech-intelligibility-index.pdf)| | 25 | | Normalized-covariance measure (NCM) | pysepm | pysepm_ncm | [pysepm](https://github.com/shimhz/pysepm.git) | [Paper](https://pmc.ncbi.nlm.nih.gov/articles/PMC3037773/pdf/JASMAN-000128-003715_1.pdf)| -| 26 | | Uni-VERSA (Versatile Speech Assessment with a Unified Framework) with Audio Reference | universa_audioref | universa_score | [Uni-VERSA](https://huggingface.co/collections/espnet/universa-6834e7c0a28225bffb6e2526) | [paper](https://arxiv.org/abs/2505.20741) | -| 27 | | ARECHO (Audio Reference Echo Cancellation and Codec Quality Assessment) with Audio Reference | arecho_audioref | arecho_score | [ARECHO](https://huggingface.co/espnet/arecho_base_v0) | [paper](https://arxiv.org/abs/2505.20741) | -| 28 | x | Chroma-related Alignment | chroma_alignment | chroma_{stft,cqt,cens}_{cosine, euclidean}_dtw{"", _log, _raw} | - | - | -| 29 | x | Deep Perceptual Audio Metric (DPAM) | dpam | dpam_distance | [PerceptualAudio_Pytorch](https://github.com/adrienchaton/PerceptualAudio_pytorch) | [paper](https://arxiv.org/abs/2001.04460) | -| 30 | x | Contrastive learning-based Deep Perceptual Audio Metric (CDPAM) | cdpam | cdpam_distance | [PerceptualAudio](https://github.com/pranaymanocha/PerceptualAudio/cdpam) | [paper](https://arxiv.org/abs/2102.05109) | +| 26 | x | Uni-VERSA (Versatile Speech Assessment with a Unified Framework) with Paired Reference | universa | universa_{sub_metrics} | [Uni-VERSA](https://huggingface.co/collections/espnet/universa-6834e7c0a28225bffb6e2526) | [paper](https://arxiv.org/abs/2505.20741) | +| 27 | x | Chroma-related Alignment | chroma_alignment | chroma_{stft,cqt,cens}_{cosine, euclidean}_dtw{"", _log, _raw} | - | - | +| 28 | x | Deep Perceptual Audio Metric (DPAM) | dpam | dpam_distance | [PerceptualAudio_Pytorch](https://github.com/adrienchaton/PerceptualAudio_pytorch) | [paper](https://arxiv.org/abs/2001.04460) | +| 29 | x | Contrastive learning-based Deep Perceptual Audio Metric (CDPAM) | cdpam | cdpam_distance | [PerceptualAudio](https://github.com/pranaymanocha/PerceptualAudio/cdpam) | [paper](https://arxiv.org/abs/2102.05109) | ### Non-match Metrics @@ -111,11 +111,8 @@ We include x mark if the metric is auto-installed in versa. 
 | 9 | | Contrastive Language-Audio Pretraining Score (CLAP Score) | clap_score | clap_score | [fadtk](https://github.com/gudgud96/frechet-audio-distance) | [paper](https://arxiv.org/abs/2301.12661) |
 | 10 | | Accompaniment Prompt Adherence (APA) | apa | apa | [Sony-audio-metrics](https://github.com/SonyCSLParis/audio-metrics) | [paper](https://arxiv.org/abs/2404.00775) |
 | 11 | | Log Likelihood Ratio (LLR) | pysepm | pysepm_llr | [pysepm](https://github.com/shimhz/pysepm.git) | [Paper](https://ecs.utdallas.edu/loizou/speech/obj_paper_jan08.pdf)|
-| 12 | | Uni-VERSA (Versatile Speech Assessment with a Unified Framework) with Text Reference | universa_textref | universa_score | [Uni-VERSA](https://huggingface.co/collections/espnet/universa-6834e7c0a28225bffb6e2526) | [paper](https://arxiv.org/abs/2505.20741) |
-| 13 | | Uni-VERSA (Versatile Speech Assessment with a Unified Framework) with Full Reference | universa_fullref | universa_score | [Uni-VERSA](https://huggingface.co/collections/espnet/universa-6834e7c0a28225bffb6e2526) | [paper](https://arxiv.org/abs/2505.20741) |
-| 14 | | ARECHO (Audio Reference Echo Cancellation and Codec Quality Assessment) with Text Reference | arecho_textref | arecho_score | [ARECHO](https://huggingface.co/espnet/arecho_base_v0) | [paper](https://arxiv.org/abs/2505.20741) |
-| 15 | | ARECHO (Audio Reference Echo Cancellation and Codec Quality Assessment) with Full Reference | arecho_fullref | arecho_score | [ARECHO](https://huggingface.co/espnet/arecho_base_v0) | [paper](https://arxiv.org/abs/2505.20741) |
-| 16 | | Singer Embedding Similarity | singer | singer_similarity | [SSL-Singer-Identity](https://github.com/SonyCSLParis/ssl-singer-identity) | [paper](https://hal.science/hal-04186048v1) |
+| 12 | x | Uni-VERSA (Versatile Speech Assessment with a Unified Framework) with Paired Text | universa | universa_{sub_metrics} | [Uni-VERSA](https://huggingface.co/collections/espnet/universa-6834e7c0a28225bffb6e2526) | [paper](https://arxiv.org/abs/2505.20741) |
+| 13 | | Singer Embedding Similarity | singer | singer_similarity | [SSL-Singer-Identity](https://github.com/SonyCSLParis/ssl-singer-identity) | [paper](https://hal.science/hal-04186048v1) |
 
 ### Distributional Metrics (in verifying)
 
diff --git a/egs/demo/se.yaml b/egs/demo/se.yaml
index e7500bf..9d54afe 100644
--- a/egs/demo/se.yaml
+++ b/egs/demo/se.yaml
@@ -90,4 +90,4 @@
 # --nisqa_loud_pred: NISQA loudness prediction
 # NOTE(jiatong): pretrain model can be downloaded with `./tools/setup_nisqa.sh`
 - name: nisqa
-  nisqa_model_path: ./tools/NISQA/weights/nisqa.tar
+  nisqa_model_path: versa_cache/nisqa/nisqa.tar
diff --git a/egs/separate_metrics/cdpam_distance.yaml b/egs/separate_metrics/cdpam_distance.yaml
new file mode 100644
index 0000000..a73d5fd
--- /dev/null
+++ b/egs/separate_metrics/cdpam_distance.yaml
@@ -0,0 +1,5 @@
+# CDPAM distance metrics
+# CDPAM distance between audio samples
+# More info in https://github.com/pranaymanocha/PerceptualAudio
+# -- cdpam_distance: the CDPAM distance between audio samples
+- name: cdpam_distance
\ No newline at end of file
diff --git a/egs/separate_metrics/chroma_alignment.yaml b/egs/separate_metrics/chroma_alignment.yaml
new file mode 100644
index 0000000..858b08b
--- /dev/null
+++ b/egs/separate_metrics/chroma_alignment.yaml
@@ -0,0 +1,20 @@
+# Chroma Alignment related metrics
+# Chroma-based distance estimation with dynamic programming alignment
+# Uses librosa chroma features (STFT, CQT, CENS) with DTW alignment
+# -- chroma_stft_cosine_dtw: STFT chroma features
with cosine distance and DTW +# -- chroma_stft_euclidean_dtw: STFT chroma features with euclidean distance and DTW +# -- chroma_cqt_cosine_dtw: CQT chroma features with cosine distance and DTW +# -- chroma_cqt_euclidean_dtw: CQT chroma features with euclidean distance and DTW +# -- chroma_cens_cosine_dtw: CENS chroma features with cosine distance and DTW +# -- chroma_cens_euclidean_dtw: CENS chroma features with euclidean distance and DTW +# -- chroma_stft_cosine_dtw_raw: Raw DTW distance with higher scaling +# -- chroma_stft_cosine_dtw_log: Log-scaled DTW distance +- name: chroma_alignment + sample_rate: 22050 + feature_types: ["stft", "cqt", "cens"] + distance_metrics: ["cosine", "euclidean"] + scale_factor: 100.0 + normalize: True + normalize_by_path: True + return_alignment: False + chroma_kwargs: {} \ No newline at end of file diff --git a/egs/separate_metrics/dpam_distance.yaml b/egs/separate_metrics/dpam_distance.yaml new file mode 100644 index 0000000..ba195df --- /dev/null +++ b/egs/separate_metrics/dpam_distance.yaml @@ -0,0 +1,5 @@ +# DPAM distance metrics +# DPAM distance between audio samples +# More info in https://github.com/adrienchaton/PerceptualAudio_Pytorch +# -- dpam_distance: the DPAM distance between audio samples +- name: dpam_distance \ No newline at end of file diff --git a/egs/separate_metrics/emo_vad.yaml b/egs/separate_metrics/emo_vad.yaml new file mode 100644 index 0000000..579043d --- /dev/null +++ b/egs/separate_metrics/emo_vad.yaml @@ -0,0 +1,7 @@ +# EmoVad related metrics +# Dimensional emotion prediction (arousal, valence, dominance) using w2v2-how-to +# More info in https://github.com/audeering/w2v2-how-to +# -- arousal_emo_vad: the dimensional emotion prediction with w2v2 +# -- valence_emo_vad: the dimensional emotion prediction with w2v2 +# -- dominance_emo_vad: the dimensional emotion prediction with w2v2 +- name: emo_vad \ No newline at end of file diff --git a/egs/separate_metrics/lid.yaml b/egs/separate_metrics/lid.yaml index 750e07a..00c18da 100644 --- a/egs/separate_metrics/lid.yaml +++ b/egs/separate_metrics/lid.yaml @@ -1,10 +1,10 @@ - -# Word error rate with ESPnet-OWSM model +# Language Identification with ESPnet-OWSM model # More model_tag can be from the ESPnet huggingface https://huggingface.co/espnet . # The default model is `espnet/owsm_v3.1_ebf`. 
-# --lid: the nbest language tag +# --language: the nbest language tag - name: lid model_tag: default nbest: 5 + use_gpu: false diff --git a/egs/separate_metrics/nisqa.yaml b/egs/separate_metrics/nisqa.yaml index 67ee222..8c97a17 100644 --- a/egs/separate_metrics/nisqa.yaml +++ b/egs/separate_metrics/nisqa.yaml @@ -3,8 +3,9 @@ # -- nisqa_noi_pred: NISQA noise prediction # -- nisqa_dis_pred: NISQA distortion prediction # -- nisqa_col_pred: NISQA color prediction -# --nisqa_loud_pred: NISQA loudness prediction +# -- nisqa_loud_pred: NISQA loudness prediction # NOTE(jiatong): pretrain model can be downloaded with `./tools/setup_nisqa.sh` -- name: nisqa - nisqa_model_path: ./tools/NISQA/weights/nisqa.tar +- name: nisqa + nisqa_model_path: versa_cache/nisqa/nisqa.tar + use_gpu: false diff --git a/egs/separate_metrics/nomad.yaml b/egs/separate_metrics/nomad.yaml index 49cc4cb..167fdf3 100644 --- a/egs/separate_metrics/nomad.yaml +++ b/egs/separate_metrics/nomad.yaml @@ -2,3 +2,4 @@ # -- nomad: nomad reference-based model - name: nomad model_cache: versa_cache/nomad_pt-models + use_gpu: false diff --git a/egs/separate_metrics/noresqa.yaml b/egs/separate_metrics/noresqa.yaml index 07db66f..e61da95 100644 --- a/egs/separate_metrics/noresqa.yaml +++ b/egs/separate_metrics/noresqa.yaml @@ -1,4 +1,8 @@ # noresqa related metrics -# -- noresqa: non-matching reference based speech quality assessment -- name: noresqa - metric_type: 1 #0: NORESQA-score, 1: NORESQA-MOS \ No newline at end of file +# -- noresqa_mos: NORESQA-MOS (metric_type=1) +# -- noresqa_score: NORESQA-score (metric_type=0) +- name: noresqa_mos + metric_type: 1 # 0: NORESQA-score, 1: NORESQA-MOS + model_tag: default + cache_dir: versa_cache/noresqa_model + use_gpu: false \ No newline at end of file diff --git a/egs/separate_metrics/pesq.yaml b/egs/separate_metrics/pesq.yaml new file mode 100644 index 0000000..a94a401 --- /dev/null +++ b/egs/separate_metrics/pesq.yaml @@ -0,0 +1,11 @@ +# PESQ: Perceptual Evaluation of Speech Quality +# https://www.itu.int/rec/T-REC-P.862 +# +# PESQ is a reference-based metric that measures speech quality +# by comparing a degraded signal to a reference signal. +# +# Supported sample rates: +# - 8kHz: narrowband (nb) mode +# - 16kHz: wideband (wb) mode +# - Other rates: automatically resampled to nearest supported rate +- name: pesq \ No newline at end of file diff --git a/egs/separate_metrics/sigmos.yaml b/egs/separate_metrics/sigmos.yaml index add5900..290e8a6 100644 --- a/egs/separate_metrics/sigmos.yaml +++ b/egs/separate_metrics/sigmos.yaml @@ -1,3 +1,3 @@ -# sigmos (independent) metric - -- name: sigmos +# sigmos (independent) metric + +- name: sigmos diff --git a/egs/separate_metrics/w2v2_dimensional_emotion.yaml b/egs/separate_metrics/w2v2_dimensional_emotion.yaml deleted file mode 100644 index ec12464..0000000 --- a/egs/separate_metrics/w2v2_dimensional_emotion.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# Dimensional emotion prediction calculated based on w2v2 -# More info in https://github.com/audeering/w2v2-how-to - -# --w2v2_dimensional_emotion: the dimensional emotion prediction with w2v2 -- name: w2v2_dimensional_emotion diff --git a/egs/separate_metrics/wer_tiny.yaml b/egs/separate_metrics/wer_tiny.yaml new file mode 100644 index 0000000..9239d63 --- /dev/null +++ b/egs/separate_metrics/wer_tiny.yaml @@ -0,0 +1,5 @@ +# Lightweight WER/CER example for README smoke tests. 
+- name: whisper_wer + model_tag: tiny + beam_size: 1 + text_cleaner: whisper_basic diff --git a/egs/separate_metrics/wvmos.yaml b/egs/separate_metrics/wvmos.yaml index 1ce48d3..dd1163a 100644 --- a/egs/separate_metrics/wvmos.yaml +++ b/egs/separate_metrics/wvmos.yaml @@ -1,3 +1,3 @@ -# wvmos (independent) metric - -- name: wvmos +# wvmos (independent) metric + +- name: wvmos diff --git a/egs/universa_prepare/gpu_subset.yaml b/egs/universa_prepare/gpu_subset.yaml index e753d93..3528e0e 100644 --- a/egs/universa_prepare/gpu_subset.yaml +++ b/egs/universa_prepare/gpu_subset.yaml @@ -31,7 +31,7 @@ # --nisqa_loud_pred: NISQA loudness prediction # NOTE(jiatong): pretrain model can be downloaded with `./tools/setup_nisqa.sh` - name: nisqa - nisqa_model_path: ./tools/NISQA/weights/nisqa.tar + nisqa_model_path: versa_cache/nisqa/nisqa.tar # discrete speech metrics # -- speech_bert: speech bert score diff --git a/egs/universa_prepare/universa_prepare.yaml b/egs/universa_prepare/universa_prepare.yaml index 019a44d..738d1ff 100644 --- a/egs/universa_prepare/universa_prepare.yaml +++ b/egs/universa_prepare/universa_prepare.yaml @@ -56,7 +56,7 @@ # --nisqa_loud_pred: NISQA loudness prediction # NOTE(jiatong): pretrain model can be downloaded with `./tools/setup_nisqa.sh` - name: nisqa - nisqa_model_path: ./tools/NISQA/weights/nisqa.tar + nisqa_model_path: versa_cache/nisqa/nisqa.tar # discrete speech metrics # -- speech_bert: speech bert score diff --git a/scripts/show_result.py b/scripts/show_result.py index 7a20ba1..da8806b 100644 --- a/scripts/show_result.py +++ b/scripts/show_result.py @@ -4,11 +4,7 @@ import statistics import os import glob -import matplotlib.pyplot as plt -import seaborn as sns -import pandas as pd import numpy as np -from pathlib import Path from typing import Dict, List, Any, Optional @@ -449,6 +445,10 @@ def estimate_metric_quality(metric_name: str, values: List[float]) -> Dict[str, def create_visualizations(metrics_stats: Dict, discovery: Dict, output_dir: str = None): """Create visualizations organized by metric categories""" + import matplotlib.pyplot as plt + import pandas as pd + import seaborn as sns + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) @@ -647,6 +647,8 @@ def print_metric_analysis(metrics_stats: Dict, discovery: Dict): def export_results_to_csv(metrics_stats: Dict, discovery: Dict, output_file: str): """Export comprehensive results to CSV files""" + import pandas as pd + # Main results file results_data = [] diff --git a/setup.cfg b/setup.cfg index eae086c..8723050 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,3 +13,5 @@ testpaths = test python_files = test_*.py python_functions = test_* python_classes = Test* +markers = + real_model: optional tests that download/load real metric models and are skipped unless explicitly enabled diff --git a/setup.py b/setup.py index 92a2c68..2a8bd69 100644 --- a/setup.py +++ b/setup.py @@ -1,62 +1,108 @@ from setuptools import setup, find_packages +import os + + +# Read README for long description +def read_readme(): + readme_path = os.path.join(os.path.dirname(__file__), "README.md") + if os.path.exists(readme_path): + with open(readme_path, "r", encoding="utf-8") as f: + return f.read() + return "A package for versatile evaluation of speech and audio" + setup( name="versa-speech-audio-toolkit", version="1.0.0", + author="Jiatong Shi", + author_email="ftshijt@gmail.com", + description="A package for versatile evaluation of speech and audio", + long_description=read_readme(), + 
long_description_content_type="text/markdown", + url="https://github.com/wavlab-speech/versa.git", packages=find_packages(), + python_requires=">=3.8", + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Multimedia :: Sound/Audio :: Analysis", + ], + keywords=["speech", "audio", "metrics", "evaluation", "machine learning"], install_requires=[ + # Core ML and Deep Learning + "torch", + "torchaudio", + "transformers>=4.36.2", "accelerate", - "audioread", - "ci-sdr", - "Cython", - "Distance", - "editdistance", - "einops", - "espnet @ git+https://github.com/ftshijt/espnet.git@espnet_inference#egg=espnet", - "espnet-tts-frontend", - "fast-bss-eval", - "fastdtw", "huggingface-hub", - "hydra-core", - "idna", - "importlib-metadata", - "kaggle", - "kaldiio", - "lazy_loader", - "Levenshtein", - "librosa", - "mir-eval", - "omegaconf", - "onnxruntime", - # NOTE(jiatong): use the latest commit for python 3.13 - "openai-whisper @ git+https://github.com/openai/whisper.git", + "safetensors", + "tokenizers", + "einops", "opt-einsum", - "pesq", - "protobuf", + # Audio Processing + "librosa", + "soundfile", + "audioread", + "resampy", + "torchlibrosa", + "pyworld", "pysptk", + # Speech and Audio Evaluation Metrics + "pesq", "pystoi", - "python-dateutil", - "pyworld", - "pyyaml", + "mir-eval", + "fast-bss-eval", + "ci-sdr", + "speechmos", + # Text Processing and Distance Metrics + "Levenshtein", + "editdistance", + "Distance", "rapidfuzz", - "resampy", - "safetensors", - "scikit-learn", "sentencepiece", - "setuptools", - "soundfile", - "speechmos", + # Scientific Computing + "scikit-learn", "sympy", "threadpoolctl", - "tokenizers", - "torch", - "torch-complex", - "torchaudio", - "torchlibrosa", - "s3prl @ git+https://github.com/ftshijt/s3prl.git@numpy2#egg=s3prl", - "transformers>=4.36.2", + # Configuration and Utilities + "hydra-core", + "omegaconf", + "pyyaml", + "protobuf", + "python-dateutil", + "lazy_loader", + # Build and Compatibility + "Cython", + "setuptools", + "importlib-metadata", + "idna", + # Optional/External Services + "kaggle", + "kaldiio", + "fastdtw", + "onnxruntime", + # Git Dependencies - Speech/Audio Frameworks + "espnet @ git+https://github.com/ftshijt/espnet.git@espnet_inference#egg=espnet", + "espnet-tts-frontend", "espnet_model_zoo", + "s3prl", + # Git Dependencies - Audio Models + # NOTE: Using latest commit for Python 3.13 compatibility + "openai-whisper @ git+https://github.com/openai/whisper.git", + # Git Dependencies - Evaluation Metrics "discrete-speech-metrics @ git+https://github.com/ftshijt/DiscreteSpeechMetrics.git@v1.0.2", + # Additional Dependencies + "torch-complex", "cdpam", ], extras_require={ @@ -65,6 +111,18 @@ "pytest-cov>=2.10.0", "black>=22.3.0", "flake8>=4.0.0", + "isort>=5.0.0", + "mypy>=0.900", + ], + "docs": [ + "sphinx>=4.0.0", + "sphinx-rtd-theme>=1.0.0", + "myst-parser>=0.17.0", + ], + "jupyter": [ + "jupyter>=1.0.0", + "ipykernel>=6.0.0", + "matplotlib>=3.3.0", ], }, entry_points={ @@ -72,9 +130,6 @@ "versa-score=versa.bin.scorer:main", ], }, - author="Jiatong Shi", 
- author_email="ftshijt@gmail.com", - description="A package for versatile evaluation of speech and audio", - url="https://github.com/shinjiwlab/versa.git", - keywords="speech metrics", + include_package_data=True, + zip_safe=False, ) diff --git a/test/test_general.py b/test/test_general.py index 3f327d7..e7e8005 100644 --- a/test/test_general.py +++ b/test/test_general.py @@ -41,13 +41,17 @@ "torch_squim_stoi": 0.6027805209159851, "torch_squim_pesq": 1.1683127880096436, "torch_squim_si_sdr": -11.109052658081055, - "dpam_distance": 0.15004253387451172, - "cdpam_distance": 0.05146043747663498, + "dpam_distance": 0.4179654121398926, + "cdpam_distance": 0.039433546364307404, "dnsmos_pro_bvcc": 1.1717286109924316, "dnsmos_pro_nisqa": 1.4733699560165405, "dnsmos_pro_vcc2018": 1.930935263633728, } +TEST_TOLERANCE = { + "se_si_snr": 1e-3, +} + @pytest.fixture def setup_paths(): @@ -125,9 +129,11 @@ def test_scoring_pipeline(setup_paths, load_config, caplog): continue # Check if values match within tolerance - assert ( - abs(TEST_INFO[key] - summary[key]) <= 1e-4 - ), f"Value issue in scorer {key}: expected {TEST_INFO[key]}, got {summary[key]}" + tolerance = TEST_TOLERANCE.get(key, 1e-4) + assert abs(TEST_INFO[key] - summary[key]) <= tolerance, ( + f"Value issue in scorer {key}: expected {TEST_INFO[key]}, " + f"got {summary[key]}" + ) print("Check successful", flush=True) diff --git a/test/test_metrics/test_asvspoof.py b/test/test_metrics/test_asvspoof.py new file mode 100644 index 0000000..5c9551c --- /dev/null +++ b/test/test_metrics/test_asvspoof.py @@ -0,0 +1,175 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest +import torch + +from versa.utterance_metrics.asvspoof_score import ASVSpoofMetric, is_aasist_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. 
+ generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif(not is_aasist_available(), reason="AASIST not available") +@pytest.mark.parametrize( + "model_tag,use_gpu", + [ + ("default", False), + ], +) +def test_utterance_asvspoof(model_tag, use_gpu, fixed_audio): + """ + Test the ASVspoof metric using the fixed audio. + The test uses deterministic data so that the result is always reproducible. + """ + config = {"model_tag": model_tag, "use_gpu": use_gpu} + + metric = ASVSpoofMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, metadata=metadata) + + asvspoof_score = result["asvspoof_score"] + + # Check that the score is a valid probability (between 0 and 1) + assert ( + 0.0 <= asvspoof_score <= 1.0 + ), f"ASVspoof score {asvspoof_score} is not between 0 and 1" + + # Check that the result contains the expected key + assert "asvspoof_score" in result, "Result should contain 'asvspoof_score' key" + + +@pytest.mark.skipif(not is_aasist_available(), reason="AASIST not available") +def test_asvspoof_metric_metadata(): + """Test that the ASVspoof metric has correct metadata.""" + config = {"use_gpu": False} + metric = ASVSpoofMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "asvspoof" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is False + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "torch" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif(not is_aasist_available(), reason="AASIST not available") +def test_asvspoof_metric_resampling(): + """Test that the ASVspoof metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = ASVSpoofMetric(config) + + # Test with 44.1kHz audio (should be resampled to 16kHz) + audio_44k = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k, metadata=metadata_44k) + + # Test with 16kHz audio (no resampling needed) + audio_16k = np.random.random(16000) + metadata_16k = {"sample_rate": 16000} + result_16k = metric.compute(audio_16k, metadata=metadata_16k) + + # Both should return valid scores + assert 0.0 <= result_44k["asvspoof_score"] <= 1.0 + assert 0.0 <= result_16k["asvspoof_score"] <= 1.0 + + +@pytest.mark.skipif(not is_aasist_available(), reason="AASIST not available") +def test_asvspoof_metric_invalid_input(): + """Test that the ASVspoof metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = ASVSpoofMetric(config) + + # Test 
with None input + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, metadata={"sample_rate": 16000}) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() diff --git a/test/test_metrics/test_audiobox_aesthetics.py b/test/test_metrics/test_audiobox_aesthetics.py new file mode 100644 index 0000000..ed9e8e0 --- /dev/null +++ b/test/test_metrics/test_audiobox_aesthetics.py @@ -0,0 +1,237 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest + +from versa.utterance_metrics.audiobox_aesthetics_score import ( + AudioBoxAestheticsMetric, + is_audiobox_aesthetics_available, +) + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. 
+ """ + return load_wav_as_array(fixed_audio_wav) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_audiobox_aesthetics_available(), reason="AudioBox Aesthetics not available" +) +@pytest.mark.parametrize( + "batch_size,precision,use_gpu", + [ + (1, "bf16", False), + (2, "fp32", False), + ], +) +def test_utterance_audiobox_aesthetics(batch_size, precision, use_gpu, fixed_audio): + """ + Test the AudioBox Aesthetics metric using the fixed audio. + The test uses deterministic data so that the result is always reproducible. + """ + config = {"batch_size": batch_size, "precision": precision, "use_gpu": use_gpu} + + metric = AudioBoxAestheticsMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, metadata=metadata) + + # Check that the result contains the expected keys + expected_keys = [ + "audiobox_aesthetics_CE", + "audiobox_aesthetics_CU", + "audiobox_aesthetics_PC", + "audiobox_aesthetics_PQ", + ] + + for key in expected_keys: + assert key in result, f"Result should contain '{key}' key" + assert isinstance(result[key], (int, float)), f"Score {key} should be numeric" + + # Check that all scores are reasonable (not negative for these metrics) + for key in expected_keys: + assert result[key] >= 0, f"Score {key} should be non-negative" + + +@pytest.mark.skipif( + not is_audiobox_aesthetics_available(), reason="AudioBox Aesthetics not available" +) +def test_audiobox_aesthetics_metric_metadata(): + """Test that the AudioBox Aesthetics metric has correct metadata.""" + config = {"use_gpu": False} + metric = AudioBoxAestheticsMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "audiobox_aesthetics" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is False + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "audiobox_aesthetics" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif( + not is_audiobox_aesthetics_available(), reason="AudioBox Aesthetics not available" +) +def test_audiobox_aesthetics_metric_different_sample_rates(): + """Test that the AudioBox Aesthetics metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = AudioBoxAestheticsMetric(config) + + # Test with 44.1kHz audio + audio_44k = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k, metadata=metadata_44k) + + # Test with 16kHz audio + audio_16k = np.random.random(16000) + metadata_16k = {"sample_rate": 16000} + result_16k = metric.compute(audio_16k, metadata=metadata_16k) + + # Both should return valid scores with expected keys + expected_keys = [ + "audiobox_aesthetics_CE", + "audiobox_aesthetics_CU", + "audiobox_aesthetics_PC", + "audiobox_aesthetics_PQ", + ] + + for key in expected_keys: + assert key in result_44k, f"44kHz result should contain '{key}' key" + assert key in result_16k, f"16kHz result should contain '{key}' key" + + +@pytest.mark.skipif( + not is_audiobox_aesthetics_available(), reason="AudioBox Aesthetics not available" +) +def test_audiobox_aesthetics_metric_invalid_input(): + """Test that the AudioBox Aesthetics metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = AudioBoxAestheticsMetric(config) + + # Test with None input + with pytest.raises(ValueError, match="Predicted signal must be provided"): 
+ metric.compute(None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_audiobox_aesthetics_available(), reason="AudioBox Aesthetics not available" +) +def test_audiobox_aesthetics_metric_config_options(): + """Test that the AudioBox Aesthetics metric handles different configuration options.""" + # Test with different batch sizes + config_small_batch = {"batch_size": 1, "use_gpu": False} + metric_small = AudioBoxAestheticsMetric(config_small_batch) + + config_large_batch = {"batch_size": 4, "use_gpu": False} + metric_large = AudioBoxAestheticsMetric(config_large_batch) + + # Test with different precision + config_fp32 = {"precision": "fp32", "use_gpu": False} + metric_fp32 = AudioBoxAestheticsMetric(config_fp32) + + # All should work without errors + audio = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result_small = metric_small.compute(audio, metadata=metadata) + result_large = metric_large.compute(audio, metadata=metadata) + result_fp32 = metric_fp32.compute(audio, metadata=metadata) + + # All should return the same structure + expected_keys = [ + "audiobox_aesthetics_CE", + "audiobox_aesthetics_CU", + "audiobox_aesthetics_PC", + "audiobox_aesthetics_PQ", + ] + + for key in expected_keys: + assert key in result_small + assert key in result_large + assert key in result_fp32 + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() diff --git a/test/test_metrics/test_base_metrics.py b/test/test_metrics/test_base_metrics.py new file mode 100644 index 0000000..4b50c37 --- /dev/null +++ b/test/test_metrics/test_base_metrics.py @@ -0,0 +1,1011 @@ +import numpy as np +import pytest +import torch + +from versa.corpus_metrics.espnet_wer import ( + EspnetWerMetric, + register_espnet_wer_metric, +) +from versa.corpus_metrics.owsm_wer import OwsmWerMetric, register_owsm_wer_metric +from versa.corpus_metrics.whisper_wer import ( + WhisperWerMetric, + register_whisper_wer_metric, +) +from versa.definition import MetricRegistry +from versa.sequence_metrics.mcd_f0 import McdF0Metric, register_mcd_f0_metric +from versa.sequence_metrics.signal_metric import SignalMetric, register_signal_metric +from versa.sequence_metrics.warpq import WarpqMetric, register_warpq_metric +from versa.utterance_metrics.log_wmse import LogWmseMetric, register_log_wmse_metric +from versa.utterance_metrics.pseudo_mos import ( + PseudoMosMetric, + register_pseudo_mos_metric, +) +from versa.utterance_metrics.pysepm import PysepmMetric, register_pysepm_metric +from versa.utterance_metrics.qwen2_audio import ( + QWEN2_AUDIO_METRIC_CLASSES, + register_qwen2_audio_metric, +) +from versa.utterance_metrics.qwen_omni import ( + QWEN_OMNI_METRIC_CLASSES, + register_qwen_omni_metric, +) +from versa.utterance_metrics.scoreq import ( + ScoreqNrMetric, + ScoreqRefMetric, + register_scoreq_metric, +) +from versa.utterance_metrics.se_snr import SeSnrMetric, register_se_snr_metric +from versa.utterance_metrics.sheet_ssqa import ( + SheetSsqaMetric, + register_sheet_ssqa_metric, +) +from versa.utterance_metrics.singer import SingerMetric, register_singer_metric +from versa.utterance_metrics.speaker import SpeakerMetric, register_speaker_metric +from versa.utterance_metrics.squim import ( + SquimNoRefMetric, + SquimRefMetric, + register_squim_metric, +) +from 
versa.utterance_metrics.vad import VadMetric, register_vad_metric +from versa.utterance_metrics.universa import UniversaMetric, register_universa_metric +from versa.utterance_metrics.visqol_score import VisqolMetric, register_visqol_metric +from versa.utterance_metrics.vqscore import VqscoreMetric, register_vqscore_metric + + +def _audio_pair(length=16000): + t = np.linspace(0, 1, length, endpoint=False) + pred = 0.5 * np.sin(2 * np.pi * 220 * t).astype(np.float32) + ref = 0.5 * np.sin(2 * np.pi * 221 * t).astype(np.float32) + return pred, ref + + +def test_signal_metric_class_returns_existing_keys(): + pred, ref = _audio_pair() + metric = SignalMetric() + + scores = metric.compute(pred, ref) + + assert set(scores) == {"sdr", "sir", "sar", "si_snr", "ci_sdr"} + assert all(isinstance(value, (float, np.floating)) for value in scores.values()) + + +def test_mcd_f0_metric_class_returns_existing_keys(monkeypatch): + calls = {} + + def dummy_mcd_f0(pred_x, gt_x, fs, f0min, f0max, **kwargs): + calls["fs"] = fs + calls["f0min"] = f0min + calls["f0max"] = f0max + calls["kwargs"] = kwargs + return {"mcd": 1.2, "f0rmse": 3.4, "f0corr": 0.5} + + monkeypatch.setattr( + "versa.sequence_metrics.mcd_f0._ensure_mcd_f0_dependencies", lambda: None + ) + monkeypatch.setattr("versa.sequence_metrics.mcd_f0.mcd_f0", dummy_mcd_f0) + + pred, ref = _audio_pair() + metric = McdF0Metric({"f0min": 50, "f0max": 700, "dtw": True}) + scores = metric.compute(pred, ref, metadata={"sample_rate": 22050}) + + assert scores == {"mcd": 1.2, "f0rmse": 3.4, "f0corr": 0.5} + assert calls["fs"] == 22050 + assert calls["f0min"] == 50 + assert calls["f0max"] == 700 + assert calls["kwargs"]["dtw"] is True + + +def test_register_mcd_f0_metric(): + registry = MetricRegistry() + + register_mcd_f0_metric(registry) + + assert registry.get_metric("mcd_f0") is McdF0Metric + assert registry.get_metric("mcd") is McdF0Metric + assert registry.get_metadata("mcd_f0").requires_reference is True + + +def test_mcd_f0_missing_dependency(monkeypatch): + monkeypatch.setattr( + "versa.sequence_metrics.mcd_f0._ensure_mcd_f0_dependencies", + lambda: (_ for _ in ()).throw(ImportError("mcd_f0 requires pysptk")), + ) + + with pytest.raises(ImportError, match="mcd_f0 requires"): + McdF0Metric() + + +def test_warpq_metric_class_returns_existing_key(monkeypatch): + calls = {} + + class DummyWarpqModel: + args = {"sr": 16000} + + def dummy_warpq_setup(**kwargs): + calls["setup"] = kwargs + return DummyWarpqModel() + + def dummy_warpq(model, pred_x, gt_x, fs=8000): + calls["compute_fs"] = fs + return {"warpq": 3.8} + + monkeypatch.setattr("versa.sequence_metrics.warpq.warpq_setup", dummy_warpq_setup) + monkeypatch.setattr("versa.sequence_metrics.warpq.warpq", dummy_warpq) + + pred, ref = _audio_pair() + metric = WarpqMetric({"fs": 16000, "n_mfcc": 20, "apply_vad": True}) + scores = metric.compute(pred, ref, metadata={"sample_rate": 22050}) + + assert scores == {"warpq": 3.8} + assert calls["setup"]["fs"] == 16000 + assert calls["setup"]["n_mfcc"] == 20 + assert calls["setup"]["apply_vad"] is True + assert calls["compute_fs"] == 22050 + + +def test_register_warpq_metric(): + registry = MetricRegistry() + + register_warpq_metric(registry) + + assert registry.get_metric("warpq") is WarpqMetric + assert registry.get_metric("warp_q") is WarpqMetric + assert registry.get_metadata("warpq").requires_reference is True + + +def test_warpq_missing_dependency(monkeypatch): + monkeypatch.setattr("versa.sequence_metrics.warpq.warpqMetric", None) + + with 
pytest.raises(ImportError, match="Please install WARP-Q"): + WarpqMetric() + + +def test_warpq_resamples_with_keyword_sample_rates(monkeypatch): + import versa.sequence_metrics.warpq as warpq_module + + calls = [] + + class DummyWarpqModel: + args = {"sr": 8000} + + def evaluate_versa(self, gt_x, pred_x): + calls.append(("evaluate", gt_x.shape[0], pred_x.shape[0])) + return 2.5 + + def dummy_resample(audio, orig_sr, target_sr): + calls.append(("resample", orig_sr, target_sr)) + return audio[:2] + + monkeypatch.setattr( + warpq_module, + "resample_audio", + dummy_resample, + ) + + scores = warpq_module.warpq(DummyWarpqModel(), np.arange(4), np.arange(4), fs=16000) + + assert scores == {"warpq": 2.5} + assert calls == [ + ("resample", 16000, 8000), + ("resample", 16000, 8000), + ("evaluate", 2, 2), + ] + + +def test_espnet_wer_metric_class_uses_reference_text(monkeypatch): + calls = {} + + def dummy_setup(**kwargs): + calls["setup"] = kwargs + return {"model": "dummy", "beam_size": kwargs["beam_size"]} + + monkeypatch.setattr( + "versa.corpus_metrics.espnet_wer.espnet_wer_setup", + dummy_setup, + ) + + def dummy_metric(wer_utils, pred_x, ref_text, fs=16000): + calls["wer_utils"] = wer_utils + calls["ref_text"] = ref_text + calls["fs"] = fs + return {"espnet_hyp_text": "hello", "espnet_wer_equal": 1} + + monkeypatch.setattr( + "versa.corpus_metrics.espnet_wer.espnet_levenshtein_metric", + dummy_metric, + ) + + pred, _ = _audio_pair() + metric = EspnetWerMetric({"beam_size": 7}) + scores = metric.compute(pred, metadata={"sample_rate": 22050, "text": "hello"}) + + assert scores == {"espnet_hyp_text": "hello", "espnet_wer_equal": 1} + assert calls["setup"]["cache_dir"] == "versa_cache/espnet_model_zoo" + assert calls["wer_utils"]["beam_size"] == 7 + assert calls["ref_text"] == "hello" + assert calls["fs"] == 22050 + + +def test_owsm_wer_metric_class_uses_reference_text(monkeypatch): + calls = {} + + def dummy_setup(**kwargs): + calls["setup"] = kwargs + return {"model": "dummy", "beam_size": kwargs["beam_size"]} + + monkeypatch.setattr( + "versa.corpus_metrics.owsm_wer.owsm_wer_setup", + dummy_setup, + ) + + def dummy_metric(wer_utils, pred_x, ref_text, fs=16000): + calls["ref_text"] = ref_text + calls["fs"] = fs + return {"owsm_hyp_text": "hello", "owsm_wer_equal": 1} + + monkeypatch.setattr( + "versa.corpus_metrics.owsm_wer.owsm_levenshtein_metric", + dummy_metric, + ) + + pred, _ = _audio_pair() + metric = OwsmWerMetric() + scores = metric.compute(pred, references="hello", metadata={"sample_rate": 16000}) + + assert scores == {"owsm_hyp_text": "hello", "owsm_wer_equal": 1} + assert calls["setup"]["cache_dir"] == "versa_cache/espnet_model_zoo" + assert calls["ref_text"] == "hello" + assert calls["fs"] == 16000 + + +def test_whisper_wer_metric_class_uses_cached_text(monkeypatch): + calls = {} + + def dummy_setup(**kwargs): + calls["setup"] = kwargs + return {"model": "dummy", "beam_size": kwargs["beam_size"]} + + monkeypatch.setattr( + "versa.corpus_metrics.whisper_wer.whisper_wer_setup", + dummy_setup, + ) + + def dummy_metric(wer_utils, pred_x, ref_text, fs=16000, cache_pred_text=None): + calls["ref_text"] = ref_text + calls["cache_pred_text"] = cache_pred_text + return {"whisper_hyp_text": cache_pred_text, "whisper_wer_equal": 1} + + monkeypatch.setattr( + "versa.corpus_metrics.whisper_wer.whisper_levenshtein_metric", + dummy_metric, + ) + + pred, _ = _audio_pair() + metric = WhisperWerMetric() + scores = metric.compute( + pred, + metadata={ + "sample_rate": 16000, + "text": "hello", + 
"general_cache": {"whisper_hyp_text": "cached hello"}, + }, + ) + + assert scores == {"whisper_hyp_text": "cached hello", "whisper_wer_equal": 1} + assert calls["setup"]["cache_dir"] == "versa_cache/whisper" + assert calls["ref_text"] == "hello" + assert calls["cache_pred_text"] == "cached hello" + + +def test_register_wer_metrics(): + registry = MetricRegistry() + + register_espnet_wer_metric(registry) + register_owsm_wer_metric(registry) + register_whisper_wer_metric(registry) + + assert registry.get_metric("espnet_wer") is EspnetWerMetric + assert registry.get_metric("owsm_wer") is OwsmWerMetric + assert registry.get_metric("whisper_wer") is WhisperWerMetric + assert registry.get_metadata("espnet_wer").requires_text is True + assert registry.get_metadata("owsm_wer").requires_text is True + assert registry.get_metadata("whisper_wer").requires_text is True + + +def test_whisper_wer_missing_dependency(monkeypatch): + monkeypatch.setattr("versa.corpus_metrics.whisper_wer.whisper", None) + + with pytest.raises(RuntimeError, match="openai-whisper is not installed"): + WhisperWerMetric() + + +def test_register_signal_metric(): + registry = MetricRegistry() + + register_signal_metric(registry) + + assert registry.get_metric("signal_metric") is SignalMetric + assert registry.get_metric("signal") is SignalMetric + assert registry.get_metadata("signal_metric").requires_reference is True + + +def test_se_snr_metric_class_returns_existing_keys(monkeypatch): + calls = {} + + class DummySeModel: + pass + + def dummy_setup(**kwargs): + calls["setup"] = kwargs + return DummySeModel() + + def dummy_se_snr(model, pred_x, fs): + calls["model"] = model + calls["pred_x"] = pred_x + calls["fs"] = fs + return { + "se_sdr": 1.0, + "se_sar": 2.0, + "se_si_snr": 3.0, + "se_ci_sdr": 4.0, + } + + monkeypatch.setattr("versa.utterance_metrics.se_snr.se_snr_setup", dummy_setup) + monkeypatch.setattr("versa.utterance_metrics.se_snr.se_snr", dummy_se_snr) + + pred, _ = _audio_pair() + metric = SeSnrMetric({"model_tag": "test-tag", "use_gpu": True}) + scores = metric.compute(pred, metadata={"sample_rate": 22050}) + + assert scores == { + "se_sdr": 1.0, + "se_sar": 2.0, + "se_si_snr": 3.0, + "se_ci_sdr": 4.0, + } + assert calls["setup"]["cache_dir"] == "versa_cache/espnet_model_zoo" + assert calls["setup"]["model_tag"] == "test-tag" + assert calls["setup"]["use_gpu"] is True + assert calls["fs"] == 22050 + + +def test_register_se_snr_metric(): + registry = MetricRegistry() + + register_se_snr_metric(registry) + + assert registry.get_metric("se_snr") is SeSnrMetric + assert registry.get_metric("se_snr_metric") is SeSnrMetric + assert registry.get_metadata("se_snr").requires_reference is False + + +def test_se_snr_missing_dependency(monkeypatch): + monkeypatch.setattr("versa.utterance_metrics.se_snr.SeparateSpeech", None) + + with pytest.raises(ImportError, match="se_snr requires espnet"): + SeSnrMetric() + + +def test_speaker_metric_class_returns_existing_key(monkeypatch): + calls = {} + + class DummySpeakerModel: + pass + + def dummy_setup(**kwargs): + calls["setup"] = kwargs + return DummySpeakerModel() + + def dummy_speaker_metric(model, pred_x, gt_x, fs): + calls["model"] = model + calls["fs"] = fs + return {"spk_similarity": 0.75} + + monkeypatch.setattr( + "versa.utterance_metrics.speaker.speaker_model_setup", dummy_setup + ) + monkeypatch.setattr( + "versa.utterance_metrics.speaker.speaker_metric", dummy_speaker_metric + ) + + pred, ref = _audio_pair() + metric = SpeakerMetric({"model_tag": "test-speaker", 
"use_gpu": True}) + scores = metric.compute(pred, ref, metadata={"sample_rate": 22050}) + + assert scores == {"spk_similarity": 0.75} + assert calls["setup"]["cache_dir"] == "versa_cache/espnet_model_zoo" + assert calls["setup"]["model_tag"] == "test-speaker" + assert calls["setup"]["use_gpu"] is True + assert calls["fs"] == 22050 + + +def test_register_speaker_metric(): + registry = MetricRegistry() + + register_speaker_metric(registry) + + assert registry.get_metric("speaker") is SpeakerMetric + assert registry.get_metric("spk_similarity") is SpeakerMetric + assert registry.get_metadata("speaker").requires_reference is True + + +def test_speaker_missing_dependency(monkeypatch): + monkeypatch.setattr("versa.utterance_metrics.speaker.Speech2Embedding", None) + + with pytest.raises(ImportError, match="speaker requires espnet"): + SpeakerMetric() + + +def test_singer_metric_class_returns_existing_key(monkeypatch): + calls = {} + + class DummySingerModel: + pass + + def dummy_setup(**kwargs): + calls["setup"] = kwargs + return DummySingerModel() + + def dummy_singer_metric(model, pred_x, gt_x, fs, target_sr=44100): + calls["model"] = model + calls["fs"] = fs + calls["target_sr"] = target_sr + return {"singer_similarity": 0.5} + + monkeypatch.setattr( + "versa.utterance_metrics.singer.singer_model_setup", dummy_setup + ) + monkeypatch.setattr( + "versa.utterance_metrics.singer.singer_metric", dummy_singer_metric + ) + + pred, ref = _audio_pair() + metric = SingerMetric({"model_name": "contrastive", "target_sr": 48000}) + scores = metric.compute(pred, ref, metadata={"sample_rate": 22050}) + + assert scores == {"singer_similarity": 0.5} + assert calls["setup"]["model_name"] == "contrastive" + assert calls["fs"] == 22050 + assert calls["target_sr"] == 48000 + + +def test_register_singer_metric(): + registry = MetricRegistry() + + register_singer_metric(registry) + + assert registry.get_metric("singer") is SingerMetric + assert registry.get_metric("singer_similarity") is SingerMetric + assert registry.get_metadata("singer").requires_reference is True + + +def test_singer_missing_dependency(monkeypatch): + monkeypatch.setattr( + "versa.utterance_metrics.singer.singer_model_setup", + lambda **kwargs: (_ for _ in ()).throw( + ImportError("Please run `install_ssl-singer-identity.sh` in tools.") + ), + ) + + with pytest.raises(ImportError, match="install_ssl-singer-identity"): + SingerMetric() + + +def test_log_wmse_metric_class_returns_existing_key(monkeypatch): + calls = {} + + class DummyLogWMSE: + def __init__(self, **kwargs): + calls["setup"] = kwargs + + def __call__(self, unproc_x, proc_x, gt_x): + calls["unproc_shape"] = tuple(unproc_x.shape) + calls["proc_shape"] = tuple(proc_x.shape) + calls["gt_shape"] = tuple(gt_x.shape) + return torch.tensor([0.33]) + + monkeypatch.setattr("versa.utterance_metrics.log_wmse.LogWMSE", DummyLogWMSE) + + pred, ref = _audio_pair() + unprocessed = pred * 0.5 + metric = LogWmseMetric({"audio_length": 2.0, "sample_rate": 48000}) + scores = metric.compute( + pred, + ref, + metadata={"sample_rate": 16000, "unprocessed": unprocessed}, + ) + + assert scores == {"log_wmse": pytest.approx(0.33)} + assert calls["setup"]["audio_length"] == 2.0 + assert calls["setup"]["sample_rate"] == 48000 + assert calls["unproc_shape"] == (1, 1, pred.shape[0]) + assert calls["proc_shape"] == (1, 1, 1, pred.shape[0]) + assert calls["gt_shape"] == (1, 1, 1, ref.shape[0]) + + +def test_register_log_wmse_metric(): + registry = MetricRegistry() + + register_log_wmse_metric(registry) + + 
assert registry.get_metric("log_wmse") is LogWmseMetric + assert registry.get_metric("log-wmse") is LogWmseMetric + assert registry.get_metadata("log_wmse").requires_reference is True + + +def test_log_wmse_missing_dependency(monkeypatch): + monkeypatch.setattr("versa.utterance_metrics.log_wmse.LogWMSE", None) + + with pytest.raises(ImportError, match="torch-log-wmse"): + LogWmseMetric() + + +def test_pseudo_mos_metric_class_returns_existing_keys(monkeypatch): + calls = {} + + def dummy_setup(predictor_types, predictor_args, cache_dir, use_gpu): + calls["setup"] = { + "predictor_types": predictor_types, + "predictor_args": predictor_args, + "cache_dir": cache_dir, + "use_gpu": use_gpu, + } + return {"utmos": object()}, {"utmos": 16000} + + def dummy_metric(pred, fs, predictor_dict, predictor_fs, use_gpu=False): + calls["fs"] = fs + calls["use_gpu"] = use_gpu + return {"utmos": 4.2} + + monkeypatch.setattr( + "versa.utterance_metrics.pseudo_mos.pseudo_mos_setup", dummy_setup + ) + monkeypatch.setattr( + "versa.utterance_metrics.pseudo_mos.pseudo_mos_metric", dummy_metric + ) + + pred, _ = _audio_pair() + metric = PseudoMosMetric( + { + "predictor_types": ["utmos"], + "predictor_args": {"utmos": {"fs": 16000}}, + "cache_dir": "cache", + "use_gpu": True, + } + ) + scores = metric.compute(pred, metadata={"sample_rate": 22050}) + + assert scores == {"utmos": 4.2} + assert calls["setup"]["predictor_types"] == ["utmos"] + assert calls["setup"]["cache_dir"] == "cache" + assert calls["setup"]["use_gpu"] is True + assert calls["fs"] == 22050 + assert calls["use_gpu"] is True + + +def test_register_pseudo_mos_metric(): + registry = MetricRegistry() + + register_pseudo_mos_metric(registry) + + assert registry.get_metric("pseudo_mos") is PseudoMosMetric + assert registry.get_metric("utmos") is PseudoMosMetric + assert registry.get_metadata("pseudo_mos").requires_reference is False + + +def test_universa_metric_class_auto_selects_references(monkeypatch): + calls = {} + + def dummy_universa_metric( + audio_data, ref_audio=None, ref_text=None, original_sr=16000, ref_sr=None + ): + calls["ref_audio"] = ref_audio + calls["ref_text"] = ref_text + calls["original_sr"] = original_sr + calls["ref_sr"] = ref_sr + return {"universa_mos": 3.5} + + monkeypatch.setattr( + "versa.utterance_metrics.universa.universa_metric", dummy_universa_metric + ) + + pred, ref = _audio_pair() + metric = UniversaMetric() + scores = metric.compute( + pred, + ref, + metadata={"sample_rate": 22050, "text": "hello"}, + ) + + assert scores == {"universa_mos": 3.5} + assert calls["ref_audio"] is ref + assert calls["ref_text"] == "hello" + assert calls["original_sr"] == 22050 + + +def test_register_universa_metric(): + registry = MetricRegistry() + + register_universa_metric(registry) + + assert registry.get_metric("universa") is UniversaMetric + assert registry.get_metric("uni_versa") is UniversaMetric + assert registry.get_metadata("universa").requires_reference is False + + +def test_universa_missing_dependency(monkeypatch): + monkeypatch.setattr("versa.utterance_metrics.universa.UniversaInference", None) + + with pytest.raises(ImportError, match="universa requires espnet"): + UniversaMetric({"model_type": "noref"}).compute(np.zeros(16000)) + + +def test_qwen2_audio_metric_class_returns_existing_key(monkeypatch): + calls = {} + + monkeypatch.setattr( + "versa.utterance_metrics.qwen2_audio.qwen2_model_setup", + lambda **kwargs: {"model": "dummy"}, + ) + + def dummy_base_metric( + qwen_utils, pred_x, fs=16000, custom_prompt=None, 
max_length=1000 + ): + calls["fs"] = fs + calls["custom_prompt"] = custom_prompt + calls["max_length"] = max_length + return "young adult" + + monkeypatch.setattr( + "versa.utterance_metrics.qwen2_audio.qwen2_base_metric", + dummy_base_metric, + ) + + pred, _ = _audio_pair() + metric_class = QWEN2_AUDIO_METRIC_CLASSES["speaker_age"] + metric = metric_class({"prompt": "Age?", "max_length": 77}) + scores = metric.compute(pred, metadata={"sample_rate": 22050}) + + assert scores == {"qwen_speaker_age": "young adult"} + assert calls["fs"] == 22050 + assert calls["custom_prompt"] == "Age?" + assert calls["max_length"] == 77 + + +def test_register_qwen2_audio_metric(): + registry = MetricRegistry() + + register_qwen2_audio_metric(registry) + + metric_class = QWEN2_AUDIO_METRIC_CLASSES["speaker_age"] + assert registry.get_metric("qwen2_audio_speaker_age") is metric_class + assert registry.get_metric("qwen2_speaker_age_metric") is metric_class + assert registry.get_metadata("qwen2_audio_speaker_age").requires_reference is False + + +def test_qwen_omni_metric_class_returns_existing_key(monkeypatch): + calls = {} + + monkeypatch.setattr( + "versa.utterance_metrics.qwen_omni.qwen_omni_model_setup", + lambda **kwargs: {"model": "dummy"}, + ) + + def dummy_base_metric( + qwen_utils, pred_x, fs=16000, custom_prompt=None, max_length=500 + ): + calls["fs"] = fs + calls["custom_prompt"] = custom_prompt + calls["max_length"] = max_length + return "happy" + + monkeypatch.setattr( + "versa.utterance_metrics.qwen_omni.qwen_omni_base_metric", + dummy_base_metric, + ) + + pred, _ = _audio_pair() + metric_class = QWEN_OMNI_METRIC_CLASSES["speech_emotion"] + metric = metric_class({"prompt": "Emotion?", "max_length": 88}) + scores = metric.compute(pred, metadata={"sample_rate": 22050}) + + assert scores == {"qwen_omni_speech_emotion": "happy"} + assert calls["fs"] == 22050 + assert calls["custom_prompt"] == "Emotion?" 
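+ # max_length from the metric config should reach the underlying Qwen-Omni call unchanged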
+ assert calls["max_length"] == 88 + + +def test_register_qwen_omni_metric(): + registry = MetricRegistry() + + register_qwen_omni_metric(registry) + + metric_class = QWEN_OMNI_METRIC_CLASSES["speech_emotion"] + assert registry.get_metric("qwen_omni_speech_emotion") is metric_class + assert registry.get_metric("qwen_omni_speech_emotion_metric") is metric_class + assert registry.get_metadata("qwen_omni_speech_emotion").requires_reference is False + + +def test_squim_no_ref_metric_uses_cached_model(monkeypatch): + class DummyObjectiveBundle: + @staticmethod + def get_model(): + return lambda pred_x: ( + torch.tensor([0.6]), + torch.tensor([1.2]), + torch.tensor([-3.4]), + ) + + monkeypatch.setattr("versa.utterance_metrics.squim.SQUIM_AVAILABLE", True) + monkeypatch.setattr( + "versa.utterance_metrics.squim.SQUIM_OBJECTIVE", DummyObjectiveBundle + ) + + pred, _ = _audio_pair() + metric = SquimNoRefMetric() + scores = metric.compute(pred, metadata={"sample_rate": 16000}) + + assert scores == { + "torch_squim_stoi": pytest.approx(0.6), + "torch_squim_pesq": pytest.approx(1.2), + "torch_squim_si_sdr": pytest.approx(-3.4), + } + + +def test_squim_ref_metric_uses_cached_model(monkeypatch): + class DummySubjectiveBundle: + @staticmethod + def get_model(): + return lambda pred_x, ref_x: torch.tensor([4.2]) + + monkeypatch.setattr("versa.utterance_metrics.squim.SQUIM_AVAILABLE", True) + monkeypatch.setattr( + "versa.utterance_metrics.squim.SQUIM_SUBJECTIVE", DummySubjectiveBundle + ) + + pred, ref = _audio_pair() + metric = SquimRefMetric() + scores = metric.compute(pred, ref, metadata={"sample_rate": 16000}) + + assert scores == {"torch_squim_mos": pytest.approx(4.2)} + + +def test_register_squim_metric(): + registry = MetricRegistry() + + register_squim_metric(registry) + + assert registry.get_metric("squim_ref") is SquimRefMetric + assert registry.get_metric("squim_no_ref") is SquimNoRefMetric + assert registry.get_metric("squim") is SquimNoRefMetric + assert registry.get_metadata("squim_ref").requires_reference is True + assert registry.get_metadata("squim_no_ref").requires_reference is False + + +def test_pysepm_registration_and_missing_dependency(monkeypatch): + registry = MetricRegistry() + register_pysepm_metric(registry) + + assert registry.get_metric("pysepm") is PysepmMetric + assert registry.get_metric("pysepm_metric") is PysepmMetric + assert registry.get_metadata("pysepm").requires_reference is True + + monkeypatch.setattr("versa.utterance_metrics.pysepm.pysepm", None) + with pytest.raises(ImportError, match="pysepm is not installed"): + PysepmMetric() + + +def test_vad_metric_class_returns_existing_key(monkeypatch): + calls = {} + + def dummy_get_speech_ts(pred_x, model, **kwargs): + calls["pred_x"] = pred_x + calls["model"] = model + calls["kwargs"] = kwargs + return [{"start": 0.1, "end": 0.4}] + + monkeypatch.setattr( + "versa.utterance_metrics.vad.torch.hub.load", + lambda **kwargs: ("dummy-model", (dummy_get_speech_ts, None, None, None)), + ) + + pred, _ = _audio_pair() + metric = VadMetric( + { + "threshold": 0.3, + "min_speech_duration_ms": 100, + "max_speech_duration_s": 10, + "min_silence_duration_ms": 200, + "speech_pad_ms": 40, + } + ) + scores = metric.compute(pred, metadata={"sample_rate": 16000}) + + assert scores == {"vad_info": [{"start": 0.1, "end": 0.4}]} + assert calls["model"] == "dummy-model" + assert calls["kwargs"]["sampling_rate"] == 16000 + assert calls["kwargs"]["threshold"] == 0.3 + assert calls["kwargs"]["min_speech_duration_ms"] == 100 + assert 
calls["kwargs"]["max_speech_duration_s"] == 10 + assert calls["kwargs"]["min_silence_duration_ms"] == 200 + assert calls["kwargs"]["speech_pad_ms"] == 40 + + +def test_register_vad_metric(): + registry = MetricRegistry() + + register_vad_metric(registry) + + assert registry.get_metric("vad") is VadMetric + assert registry.get_metric("silero_vad") is VadMetric + assert registry.get_metadata("vad").requires_reference is False + + +def test_sheet_ssqa_metric_class_returns_existing_key(monkeypatch): + class DummyInnerModel: + def to(self, device): + self.device = device + return self + + class DummySheetModel: + def __init__(self): + self.model = DummyInnerModel() + + def predict(self, wav): + return 3.25 + + monkeypatch.setattr( + "versa.utterance_metrics.sheet_ssqa.torch.hub.load", + lambda *args, **kwargs: DummySheetModel(), + ) + + pred, _ = _audio_pair() + metric = SheetSsqaMetric({"cache_dir": "test-cache", "use_gpu": False}) + scores = metric.compute(pred, metadata={"sample_rate": 16000}) + + assert scores == {"sheet_ssqa": 3.25} + + +def test_register_sheet_ssqa_metric(): + registry = MetricRegistry() + + register_sheet_ssqa_metric(registry) + + assert registry.get_metric("sheet_ssqa") is SheetSsqaMetric + assert registry.get_metric("sheet") is SheetSsqaMetric + assert registry.get_metadata("sheet_ssqa").requires_reference is False + + +def test_scoreq_metric_classes_return_existing_keys(monkeypatch): + calls = [] + + class DummyScoreq: + def __init__(self, data_domain, mode, cache_dir, device): + self.mode = mode + calls.append( + { + "data_domain": data_domain, + "mode": mode, + "cache_dir": cache_dir, + "device": device, + } + ) + + def predict(self, test_path, ref_path): + assert test_path is not None + if self.mode == "ref": + assert ref_path is not None + return 1.2 + assert ref_path is None + return 2.4 + + monkeypatch.setattr("versa.utterance_metrics.scoreq.Scoreq", DummyScoreq) + + pred, ref = _audio_pair() + nr_metric = ScoreqNrMetric({"data_domain": "natural", "model_cache": "cache-a"}) + ref_metric = ScoreqRefMetric({"cache_dir": "cache-b"}) + + assert nr_metric.compute(pred, metadata={"sample_rate": 16000}) == { + "scoreq_nr": 2.4 + } + assert ref_metric.compute(pred, ref, metadata={"sample_rate": 16000}) == { + "scoreq_ref": 1.2 + } + assert calls[0] == { + "data_domain": "natural", + "mode": "nr", + "cache_dir": "cache-a", + "device": "cpu", + } + assert calls[1]["mode"] == "ref" + assert calls[1]["cache_dir"] == "cache-b" + + +def test_register_scoreq_metric(): + registry = MetricRegistry() + + register_scoreq_metric(registry) + + assert registry.get_metric("scoreq_nr") is ScoreqNrMetric + assert registry.get_metric("scoreq_ref") is ScoreqRefMetric + assert registry.get_metric("scoreq") is ScoreqNrMetric + assert registry.get_metadata("scoreq_nr").requires_reference is False + assert registry.get_metadata("scoreq_ref").requires_reference is True + + +def test_scoreq_missing_dependency(monkeypatch): + monkeypatch.setattr("versa.utterance_metrics.scoreq.Scoreq", None) + + with pytest.raises(ModuleNotFoundError, match="scoreq is not installed"): + ScoreqNrMetric() + + +def test_vqscore_metric_class_returns_existing_key(monkeypatch): + class DummyVqscoreModel: + device = "cpu" + input_transform = "none" + + def CNN_1D_encoder(self, sp_input): + return torch.ones((1, 2, 3)) + + def quantizer(self, z, stochastic=False, update=False): + return z.transpose(2, 1), None, None, None + + monkeypatch.setattr( + "versa.utterance_metrics.vqscore.vqscore_setup", + lambda use_gpu=False: 
DummyVqscoreModel(), + ) + + pred, _ = _audio_pair() + metric = VqscoreMetric() + scores = metric.compute(pred, metadata={"sample_rate": 16000}) + + assert scores == {"vqscore": pytest.approx(1.0, abs=1e-4)} + + +def test_register_vqscore_metric(): + registry = MetricRegistry() + + register_vqscore_metric(registry) + + assert registry.get_metric("vqscore") is VqscoreMetric + assert registry.get_metric("vq_score") is VqscoreMetric + assert registry.get_metadata("vqscore").requires_reference is False + + +def test_visqol_metric_class_returns_existing_key(monkeypatch): + class DummySimilarityResult: + moslqo = 4.1 + + class DummyApi: + def Measure(self, gt_x, pred_x): + return DummySimilarityResult() + + monkeypatch.setattr( + "versa.utterance_metrics.visqol_score.visqol_setup", + lambda model: (DummyApi(), 16000), + ) + + pred, ref = _audio_pair() + metric = VisqolMetric({"model": "speech"}) + scores = metric.compute(pred, ref, metadata={"sample_rate": 16000}) + + assert scores == {"visqol": 4.1} + + +def test_register_visqol_metric(): + registry = MetricRegistry() + + register_visqol_metric(registry) + + assert registry.get_metric("visqol") is VisqolMetric + assert registry.get_metric("VISQOL") is VisqolMetric + assert registry.get_metadata("visqol").requires_reference is True + + +def test_visqol_missing_dependency(monkeypatch): + monkeypatch.setattr("versa.utterance_metrics.visqol_score.visqol_lib_py", None) + monkeypatch.setattr("versa.utterance_metrics.visqol_score.visqol_config_pb2", None) + + with pytest.raises(ImportError, match="visqol is not installed"): + VisqolMetric() diff --git a/test/test_metrics/test_cdpam.py b/test/test_metrics/test_cdpam.py deleted file mode 100755 index 8828dda..0000000 --- a/test/test_metrics/test_cdpam.py +++ /dev/null @@ -1,83 +0,0 @@ -import wave -from pathlib import Path - -import numpy as np -import pytest - -from versa.utterance_metrics.cdpam_distance import cdpam_metric, cdpam_model_setup - -# Assume the fixed WAV file fixtures and helper function are defined as in the ASR matching test. -# For example: - - -def generate_fixed_wav( - filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None -): - """ - Generate a deterministic WAV file with a modulated sine wave. - """ - t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) - if envelope_func is None: - envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) - else: - envelope = envelope_func(t) - audio = envelope * np.sin(2 * np.pi * base_freq * t) - amplitude = np.iinfo(np.int16).max - data = (audio * amplitude).astype(np.int16) - with wave.open(str(filename), "w") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(sample_rate) - wf.writeframes(data.tobytes()) - - -def load_wav_as_array(wav_path, sample_rate=16000): - """ - Load a WAV file and convert it to a NumPy array scaled to [-1, 1]. 
- """ - with wave.open(str(wav_path), "rb") as wf: - frames = wf.getnframes() - audio_data = wf.readframes(frames) - audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) - return audio_array / np.iinfo(np.int16).max - - -@pytest.fixture(scope="session") -def fixed_audio_wav(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("audio_data") - audio_file = tmp_dir / "fixed_audio.wav" - generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) - return audio_file - - -@pytest.fixture(scope="session") -def fixed_ground_truth_wav(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("audio_data") - gt_file = tmp_dir / "fixed_ground_truth.wav" - # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. - generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) - return gt_file - - -@pytest.fixture(scope="session") -def fixed_audio(fixed_audio_wav): - return load_wav_as_array(fixed_audio_wav) - - -@pytest.fixture(scope="session") -def fixed_ground_truth(fixed_ground_truth_wav): - return load_wav_as_array(fixed_ground_truth_wav) - - -# ------------------------------- -# CDPAM Metric Definition and Tests -# ------------------------------- -def test_cdpam_metric_identical(fixed_audio): - """ - When comparing an audio signal with itself, the cdpam distance should be 0.0. - """ - model = cdpam_model_setup() - scores = cdpam_metric(model, fixed_audio, fixed_audio, 16000) - assert ( - scores["cdpam_distance"] == 0.0 - ), f"Expected cdpam distance == 0.0 for identical signals, got {scores['cdpam_distance']}" diff --git a/test/test_metrics/test_cdpam_distance.py b/test/test_metrics/test_cdpam_distance.py new file mode 100644 index 0000000..20ae14f --- /dev/null +++ b/test/test_metrics/test_cdpam_distance.py @@ -0,0 +1,271 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest + +from versa.utterance_metrics.cdpam_distance import ( + CdpamDistanceMetric, + is_cdpam_available, +) + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. 
+ """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a ground truth WAV file to be used as reference audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) + return gt_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the ground truth audio file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_cdpam_distance(use_gpu, fixed_audio, fixed_ground_truth): + """ + Test the CDPAM distance metric using the fixed audio files. + The test uses deterministic data so that the result is always reproducible. 
+ """ + config = {"use_gpu": use_gpu} + + metric = CdpamDistanceMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, fixed_ground_truth, metadata=metadata) + + # Check that the result contains the expected key + assert "cdpam_distance" in result, "Result should contain 'cdpam_distance' key" + + # Check that the result is a float + cdpam_dist = result["cdpam_distance"] + assert isinstance(cdpam_dist, float), "cdpam_distance should be a float" + + # Check that the distance score is reasonable (should be non-negative) + assert cdpam_dist >= 0.0, f"CDPAM distance should be non-negative, got {cdpam_dist}" + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_metadata(): + """Test that the CDPAM distance metric has correct metadata.""" + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "cdpam_distance" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "cdpam" in metadata.dependencies + assert "torch" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_different_sample_rates(): + """Test that the CDPAM distance metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + + # Test with 44.1kHz audio (should be resampled to 22.05kHz) + audio_44k_1 = np.random.random(44100) + audio_44k_2 = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k_1, audio_44k_2, metadata=metadata_44k) + + # Test with 22.05kHz audio (no resampling needed) + audio_22k_1 = np.random.random(22050) + audio_22k_2 = np.random.random(22050) + metadata_22k = {"sample_rate": 22050} + result_22k = metric.compute(audio_22k_1, audio_22k_2, metadata=metadata_22k) + + # Both should return valid scores with expected keys + assert ( + "cdpam_distance" in result_44k + ), "44kHz result should contain 'cdpam_distance' key" + assert ( + "cdpam_distance" in result_22k + ), "22kHz result should contain 'cdpam_distance' key" + + # Both should return float values + assert isinstance(result_44k["cdpam_distance"], float) + assert isinstance(result_22k["cdpam_distance"], float) + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_invalid_input(): + """Test that the CDPAM distance metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + + # Test with None predictions + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(22050), metadata={"sample_rate": 22050}) + + # Test with None references + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(22050), None, metadata={"sample_rate": 22050}) + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_config_options(): + """Test that the CDPAM distance metric handles different configuration options.""" + # Test with GPU disabled + config_cpu = {"use_gpu": False} + metric_cpu = 
CdpamDistanceMetric(config_cpu) + + # The CPU-configured metric should work without errors + audio1 = np.random.random(22050) + audio2 = np.random.random(22050) + metadata = {"sample_rate": 22050} + + result_cpu = metric_cpu.compute(audio1, audio2, metadata=metadata) + + # Should return the same structure + assert "cdpam_distance" in result_cpu + assert isinstance(result_cpu["cdpam_distance"], float) + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_identical_signals(): + """Test that the CDPAM distance metric gives zero distance for identical signals.""" + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + + # Test with identical signals + audio = np.random.random(22050) + result = metric.compute(audio, audio, metadata=metadata) + + # The distance should be exactly 0.0 for identical signals + assert ( + result["cdpam_distance"] == 0.0 + ), "Identical signals should have zero distance" + + +@pytest.mark.skipif(not is_cdpam_available(), reason="CDPAM not available") +def test_cdpam_distance_metric_consistent_results(): + """Test that the CDPAM distance metric gives consistent results for the same inputs.""" + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + + # Reuse one randomly drawn pair of signals for both calls + audio1 = np.random.random(22050) + audio2 = np.random.random(22050) + result1 = metric.compute(audio1, audio2, metadata=metadata) + result2 = metric.compute(audio1, audio2, metadata=metadata) + + # Results should be identical for the same inputs + np.testing.assert_almost_equal( + result1["cdpam_distance"], + result2["cdpam_distance"], + decimal=6, + err_msg="Results should be identical for the same inputs", + ) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_ground_truth_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() diff --git a/test/test_metrics/test_chroma_alignment.py b/test/test_metrics/test_chroma_alignment.py new file mode 100644 index 0000000..9a2cff6 --- /dev/null +++ b/test/test_metrics/test_chroma_alignment.py @@ -0,0 +1,327 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest + +from versa.utterance_metrics.chroma_alignment import ChromaAlignmentMetric + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=22050, base_freq=440, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. 
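+ # int16 full scale (32767) maps the [-1, 1] float signal onto the PCM range.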
+ amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 440 Hz sine wave (A4 note). + generate_fixed_wav(audio_file, duration=1.0, sample_rate=22050, base_freq=440) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as ground truth. + This one uses a different duration but same frequency to test DTW alignment. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Generate a ground truth file with a 440 Hz sine wave but different duration. + generate_fixed_wav(gt_file, duration=1.2, sample_rate=22050, base_freq=440) + return gt_file + + +@pytest.fixture(scope="session") +def different_pitch_wav(tmp_path_factory): + """ + Create a WAV file with a different pitch for testing distance metrics. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + diff_file = tmp_dir / "different_pitch.wav" + # Generate a file with a 554.37 Hz sine wave (C#5 note). + generate_fixed_wav(diff_file, duration=1.0, sample_rate=22050, base_freq=554.37) + return diff_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=22050): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the fixed ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +@pytest.fixture(scope="session") +def different_pitch_audio(different_pitch_wav): + """ + Load the different pitch audio file as a NumPy array. + """ + return load_wav_as_array(different_pitch_wav) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.parametrize( + "scale_factor,feature_types,distance_metrics", + [ + (100.0, ["stft"], ["cosine"]), + (50.0, ["stft", "cqt"], ["cosine", "euclidean"]), + (200.0, ["stft", "cqt", "cens"], ["cosine"]), + ], +) +def test_utterance_chroma_alignment( + scale_factor, feature_types, distance_metrics, fixed_audio, fixed_ground_truth +): + """ + Test the Chroma Alignment metric using the fixed audio and ground truth. + The test uses deterministic data so that the result is always reproducible. 
+ """ + config = { + "scale_factor": scale_factor, + "feature_types": feature_types, + "distance_metrics": distance_metrics, + "normalize": True, + "normalize_by_path": True, + } + + metric = ChromaAlignmentMetric(config) + metadata = {"sample_rate": 22050} + result = metric.compute(fixed_audio, fixed_ground_truth, metadata=metadata) + + # Check that the result contains the expected keys + for feat_type in feature_types: + for dist_metric in distance_metrics: + key = f"chroma_{feat_type}_{dist_metric}_dtw" + assert key in result, f"Result should contain '{key}' key" + assert isinstance( + result[key], (int, float) + ), f"Score {key} should be numeric" + assert result[key] >= 0, f"Score {key} should be non-negative" + + # Check for additional scaled variants + if "stft" in feature_types and "cosine" in distance_metrics: + assert "chroma_stft_cosine_dtw_raw" in result + assert "chroma_stft_cosine_dtw_log" in result + + +def test_chroma_alignment_metric_metadata(): + """Test that the Chroma Alignment metric has correct metadata.""" + config = {"scale_factor": 100.0} + metric = ChromaAlignmentMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "chroma_alignment" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is False + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + assert "scipy" in metadata.dependencies + + +def test_chroma_alignment_metric_different_pitches(fixed_audio, different_pitch_audio): + """Test that the Chroma Alignment metric gives higher distances for different pitches.""" + config = {"scale_factor": 100.0} + metric = ChromaAlignmentMetric(config) + metadata = {"sample_rate": 22050} + + # Test with same pitch (should give lower distance) + result_same = metric.compute(fixed_audio, fixed_audio, metadata=metadata) + + # Test with different pitch (should give higher distance) + result_different = metric.compute( + fixed_audio, different_pitch_audio, metadata=metadata + ) + + # The distance should be higher for different pitches + for key in result_same: + if key in result_different and not key.endswith("_log"): + # Log-scaled metric works differently, so skip it + assert ( + result_different[key] >= result_same[key] + ), f"Distance should be higher for different pitches in {key}" + + +def test_chroma_alignment_metric_invalid_input(): + """Test that the Chroma Alignment metric handles invalid inputs correctly.""" + config = {"scale_factor": 100.0} + metric = ChromaAlignmentMetric(config) + + # Test with None input + with pytest.raises( + ValueError, match="Both predicted and ground truth signals must be provided" + ): + metric.compute(None, np.random.random(22050), metadata={"sample_rate": 22050}) + + with pytest.raises( + ValueError, match="Both predicted and ground truth signals must be provided" + ): + metric.compute(np.random.random(22050), None, metadata={"sample_rate": 22050}) + + +def test_chroma_alignment_metric_config_options(): + """Test that the Chroma Alignment metric handles different configuration options.""" + # Test with different scale factors + config_small_scale = { + "scale_factor": 50.0, + "feature_types": ["stft"], + "distance_metrics": ["cosine"], + } + metric_small = ChromaAlignmentMetric(config_small_scale) + + config_large_scale = { + "scale_factor": 200.0, + "feature_types": ["stft"], + "distance_metrics": ["cosine"], + } + 
metric_large = ChromaAlignmentMetric(config_large_scale) + + # Test with normalization options + config_no_norm = { + "normalize": False, + "feature_types": ["stft"], + "distance_metrics": ["cosine"], + } + metric_no_norm = ChromaAlignmentMetric(config_no_norm) + + # All should work without errors + audio = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 22050)) + audio2 = np.sin(2 * np.pi * 880 * np.linspace(0, 1, 22050)) + metadata = {"sample_rate": 22050} + result_small = metric_small.compute(audio, audio2, metadata=metadata) + result_large = metric_large.compute(audio, audio2, metadata=metadata) + result_no_norm = metric_no_norm.compute(audio, audio2, metadata=metadata) + + # All should return the same structure + assert "chroma_stft_cosine_dtw" in result_small + assert "chroma_stft_cosine_dtw" in result_large + assert "chroma_stft_cosine_dtw" in result_no_norm + + # Scale factor should affect the magnitude + assert ( + result_large["chroma_stft_cosine_dtw"] > result_small["chroma_stft_cosine_dtw"] + ) + + +def test_chroma_alignment_metric_alignment_paths(): + """Test that the Chroma Alignment metric can return alignment paths when requested.""" + config = { + "scale_factor": 100.0, + "feature_types": ["stft"], + "distance_metrics": ["cosine"], + "return_alignment": True, + } + + metric = ChromaAlignmentMetric(config) + metadata = {"sample_rate": 22050} + audio = np.random.random(22050) + + result = metric.compute(audio, audio, metadata=metadata) + + # Should contain alignments when requested + assert "alignments" in result + assert "chroma_stft_cosine_dtw" in result["alignments"] + + +def test_chroma_alignment_metric_multidimensional_input(): + """Test that the Chroma Alignment metric handles multidimensional input correctly.""" + config = { + "scale_factor": 100.0, + "feature_types": ["stft"], + "distance_metrics": ["cosine"], + } + metric = ChromaAlignmentMetric(config) + metadata = {"sample_rate": 22050} + + # Test with 2D input (should be flattened) + audio_2d = np.random.random((22050, 1)) + result_2d = metric.compute(audio_2d, audio_2d, metadata=metadata) + + # Test with 1D input + audio_1d = np.random.random(22050) + result_1d = metric.compute(audio_1d, audio_1d, metadata=metadata) + + # Both shapes should be accepted and expose the expected score key + assert "chroma_stft_cosine_dtw" in result_2d + assert "chroma_stft_cosine_dtw" in result_1d + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist( + fixed_audio_wav, fixed_ground_truth_wav, different_pitch_wav +): + """ + Verify that the fixed WAV files were created. 
+ """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() + assert Path(different_pitch_wav).exists() diff --git a/test/test_metrics/test_definition.py b/test/test_metrics/test_definition.py new file mode 100644 index 0000000..59b3f3d --- /dev/null +++ b/test/test_metrics/test_definition.py @@ -0,0 +1,41 @@ +from versa.definition import ( + BaseMetric, + MetricCategory, + MetricFactory, + MetricMetadata, + MetricRegistry, + MetricType, +) + + +class DummyMetric(BaseMetric): + def _setup(self): + pass + + def compute(self, predictions, references=None, metadata=None): + return {"dummy": 1.0} + + def get_metadata(self): + return DUMMY_METADATA + + +DUMMY_METADATA = MetricMetadata( + name="dummy", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["definitely_missing_dependency_for_versa_test"], + description="Dummy metric for registry tests.", +) + + +def test_metric_factory_create_suite_with_missing_dependency_and_default_config(): + registry = MetricRegistry() + registry.register(DummyMetric, DUMMY_METADATA) + + suite = MetricFactory(registry).create_metric_suite(["dummy"]) + + assert suite.compute_all(predictions=None) == {"dummy": {"dummy": 1.0}} diff --git a/test/test_metrics/test_discrete_speech.py b/test/test_metrics/test_discrete_speech.py index f439ed3..027f08e 100644 --- a/test/test_metrics/test_discrete_speech.py +++ b/test/test_metrics/test_discrete_speech.py @@ -5,137 +5,304 @@ import pytest from versa.utterance_metrics.discrete_speech import ( - discrete_speech_setup, - discrete_speech_metric, + DiscreteSpeechMetric, + is_discrete_speech_available, ) -# Reuse the same helper functions from your STOI test +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- def generate_fixed_wav( filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None ): """ Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. """ t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. if envelope_func is None: envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) else: envelope = envelope_func(t) audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. amplitude = np.iinfo(np.int16).max data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. with wave.open(str(filename), "w") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. wf.setframerate(sample_rate) wf.writeframes(data.tobytes()) -def load_wav_as_array(wav_path, sample_rate=16000): - """ - Load a WAV file and convert it to a NumPy array scaled to [-1, 1]. 
- """ - with wave.open(str(wav_path), "rb") as wf: - frames = wf.getnframes() - audio_data = wf.readframes(frames) - audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) - return audio_array / np.iinfo(np.int16).max - - +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- @pytest.fixture(scope="session") def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ tmp_dir = tmp_path_factory.mktemp("audio_data") audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) return audio_file @pytest.fixture(scope="session") def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as ground truth. + This one uses a different base frequency (e.g., 300 Hz) so that the test + intentionally simulates a mismatch. + """ tmp_dir = tmp_path_factory.mktemp("audio_data") gt_file = tmp_dir / "fixed_ground_truth.wav" - # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. + # Generate a ground truth file with a 300 Hz sine wave. generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) return gt_file +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + @pytest.fixture(scope="session") def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ return load_wav_as_array(fixed_audio_wav) @pytest.fixture(scope="session") def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the fixed ground truth file as a NumPy array. + """ return load_wav_as_array(fixed_ground_truth_wav) -@pytest.fixture(scope="session") -def discrete_speech_predictors(): - """Set up discrete speech predictors once per test session.""" - return discrete_speech_setup(use_gpu=False) - - # ------------------------------- -# Discrete Speech Metric Tests +# Test Functions # ------------------------------- -def test_discrete_speech_metric_identical(fixed_audio, discrete_speech_predictors): +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_discrete_speech_identical(use_gpu, fixed_audio): """ + Test the Discrete Speech metric using identical audio signals. When comparing an audio signal with itself, the discrete speech scores should be high. 
""" - scores = discrete_speech_metric( - discrete_speech_predictors, fixed_audio, fixed_audio, 16000 - ) + config = {"use_gpu": use_gpu} + + metric = DiscreteSpeechMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, fixed_audio, metadata=metadata) # Check that all expected metrics are present - assert "speech_bert" in scores - assert "speech_bleu" in scores - assert "speech_token_distance" in scores + assert "speech_bert" in result, "Result should contain 'speech_bert' key" + assert "speech_bleu" in result, "Result should contain 'speech_bleu' key" + assert ( + "speech_token_distance" in result + ), "Result should contain 'speech_token_distance' key" # For identical signals, scores should be relatively high # Note: Perfect scores (1.0) are not always expected for discrete speech metrics assert ( - scores["speech_bert"] > 0.9 - ), f"Expected SpeechBERT score > 0.5 for identical signals, got {scores['speech_bert']}" + result["speech_bert"] > 0.9 + ), f"Expected SpeechBERT score > 0.9 for identical signals, got {result['speech_bert']}" assert ( - scores["speech_bleu"] > 0.9 - ), f"Expected SpeechBLEU score > 0.3 for identical signals, got {scores['speech_bleu']}" + result["speech_bleu"] > 0.9 + ), f"Expected SpeechBLEU score > 0.9 for identical signals, got {result['speech_bleu']}" assert ( - scores["speech_token_distance"] > 0.9 - ), f"Expected SpeechTokenDistance score > 0.3 for identical signals, got {scores['speech_token_distance']}" + result["speech_token_distance"] > 0.9 + ), f"Expected SpeechTokenDistance score > 0.9 for identical signals, got {result['speech_token_distance']}" -def test_discrete_speech_metric_different( - fixed_audio, fixed_ground_truth, discrete_speech_predictors -): +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_discrete_speech_different(use_gpu, fixed_audio, fixed_ground_truth): """ + Test the Discrete Speech metric using different audio signals. When comparing two different fixed signals, the discrete speech scores should be lower than identical signals. 
""" + config = {"use_gpu": use_gpu} + + metric = DiscreteSpeechMetric(config) + metadata = {"sample_rate": 16000} + # Get scores for identical signals first - identical_scores = discrete_speech_metric( - discrete_speech_predictors, fixed_audio, fixed_audio, 16000 - ) + identical_result = metric.compute(fixed_audio, fixed_audio, metadata=metadata) # Get scores for different signals - different_scores = discrete_speech_metric( - discrete_speech_predictors, fixed_audio, fixed_ground_truth, 16000 + different_result = metric.compute( + fixed_audio, fixed_ground_truth, metadata=metadata ) # Check that all expected metrics are present - assert "speech_bert" in different_scores - assert "speech_bleu" in different_scores - assert "speech_token_distance" in different_scores + assert "speech_bert" in different_result, "Result should contain 'speech_bert' key" + assert "speech_bleu" in different_result, "Result should contain 'speech_bleu' key" + assert ( + "speech_token_distance" in different_result + ), "Result should contain 'speech_token_distance' key" # Different signals should have lower scores than identical signals assert ( - different_scores["speech_bert"] <= identical_scores["speech_bert"] - ), f"Expected SpeechBERT score for different signals ({different_scores['speech_bert']}) to be <= identical signals ({identical_scores['speech_bert']})" - + different_result["speech_bert"] <= identical_result["speech_bert"] + ), f"Expected SpeechBERT score for different signals ({different_result['speech_bert']}) to be <= identical signals ({identical_result['speech_bert']})" assert ( - different_scores["speech_bleu"] <= identical_scores["speech_bleu"] - ), f"Expected SpeechBLEU score for different signals ({different_scores['speech_bleu']}) to be <= identical signals ({identical_scores['speech_bleu']})" - + different_result["speech_bleu"] <= identical_result["speech_bleu"] + ), f"Expected SpeechBLEU score for different signals ({different_result['speech_bleu']}) to be <= identical signals ({identical_result['speech_bleu']})" assert ( - different_scores["speech_token_distance"] - <= identical_scores["speech_token_distance"] - ), f"Expected SpeechTokenDistance score for different signals ({different_scores['speech_token_distance']}) to be <= identical signals ({identical_scores['speech_token_distance']})" + different_result["speech_token_distance"] + <= identical_result["speech_token_distance"] + ), f"Expected SpeechTokenDistance score for different signals ({different_result['speech_token_distance']}) to be <= identical signals ({identical_result['speech_token_distance']})" + + +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +def test_discrete_speech_metric_metadata(): + """Test that the Discrete Speech metric has correct metadata.""" + config = {"use_gpu": False} + metric = DiscreteSpeechMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "discrete_speech" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "discrete_speech_metrics" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +def test_discrete_speech_metric_different_sample_rates(): + """Test that the Discrete 
Speech metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = DiscreteSpeechMetric(config) + + # Test with 44.1kHz audio (should be resampled to 16kHz) + audio_44k = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k, audio_44k, metadata=metadata_44k) + + # Test with 16kHz audio (no resampling needed) + audio_16k = np.random.random(16000) + metadata_16k = {"sample_rate": 16000} + result_16k = metric.compute(audio_16k, audio_16k, metadata=metadata_16k) + + # Both should return valid scores with expected keys + expected_keys = ["speech_bert", "speech_bleu", "speech_token_distance"] + + for key in expected_keys: + assert key in result_44k, f"44kHz result should contain '{key}' key" + assert key in result_16k, f"16kHz result should contain '{key}' key" + assert isinstance( + result_44k[key], (int, float) + ), f"Score {key} should be numeric" + assert isinstance( + result_16k[key], (int, float) + ), f"Score {key} should be numeric" + + +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +def test_discrete_speech_metric_invalid_input(): + """Test that the Discrete Speech metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = DiscreteSpeechMetric(config) + + # Test with None input + with pytest.raises( + ValueError, match="Both predicted and ground truth signals must be provided" + ): + metric.compute(None, np.random.random(16000), metadata={"sample_rate": 16000}) + + with pytest.raises( + ValueError, match="Both predicted and ground truth signals must be provided" + ): + metric.compute(np.random.random(16000), None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_discrete_speech_available(), reason="Discrete Speech Metrics not available" +) +def test_discrete_speech_metric_config_options(): + """Test that the Discrete Speech metric handles different configuration options.""" + # Test with GPU disabled + config_cpu = {"use_gpu": False} + metric_cpu = DiscreteSpeechMetric(config_cpu) + + # Test with different sample rate + config_custom_sr = {"use_gpu": False, "sample_rate": 22050} + metric_custom_sr = DiscreteSpeechMetric(config_custom_sr) + + # All should work without errors + audio = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result_cpu = metric_cpu.compute(audio, audio, metadata=metadata) + result_custom_sr = metric_custom_sr.compute(audio, audio, metadata=metadata) + + # All should return the same structure + expected_keys = ["speech_bert", "speech_bleu", "speech_token_distance"] + + for key in expected_keys: + assert key in result_cpu + assert key in result_custom_sr + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_ground_truth_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() diff --git a/test/test_metrics/test_dpam.py b/test/test_metrics/test_dpam.py deleted file mode 100755 index e9557bc..0000000 --- a/test/test_metrics/test_dpam.py +++ /dev/null @@ -1,83 +0,0 @@ -import wave -from pathlib import Path - -import numpy as np -import pytest - -from versa.utterance_metrics.dpam_distance import dpam_metric, dpam_model_setup - -# Assume the fixed WAV file fixtures and helper function are defined as in the ASR matching test. 
-# For example: - - -def generate_fixed_wav( - filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None -): - """ - Generate a deterministic WAV file with a modulated sine wave. - """ - t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) - if envelope_func is None: - envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) - else: - envelope = envelope_func(t) - audio = envelope * np.sin(2 * np.pi * base_freq * t) - amplitude = np.iinfo(np.int16).max - data = (audio * amplitude).astype(np.int16) - with wave.open(str(filename), "w") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(sample_rate) - wf.writeframes(data.tobytes()) - - -def load_wav_as_array(wav_path, sample_rate=16000): - """ - Load a WAV file and convert it to a NumPy array scaled to [-1, 1]. - """ - with wave.open(str(wav_path), "rb") as wf: - frames = wf.getnframes() - audio_data = wf.readframes(frames) - audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) - return audio_array / np.iinfo(np.int16).max - - -@pytest.fixture(scope="session") -def fixed_audio_wav(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("audio_data") - audio_file = tmp_dir / "fixed_audio.wav" - generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) - return audio_file - - -@pytest.fixture(scope="session") -def fixed_ground_truth_wav(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("audio_data") - gt_file = tmp_dir / "fixed_ground_truth.wav" - # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. - generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) - return gt_file - - -@pytest.fixture(scope="session") -def fixed_audio(fixed_audio_wav): - return load_wav_as_array(fixed_audio_wav) - - -@pytest.fixture(scope="session") -def fixed_ground_truth(fixed_ground_truth_wav): - return load_wav_as_array(fixed_ground_truth_wav) - - -# ------------------------------- -# DPAM Metric Definition and Tests -# ------------------------------- -def test_dpam_metric_identical(fixed_audio): - """ - When comparing an audio signal with itself, the dpam distance should be 0.0. - """ - model = dpam_model_setup() - scores = dpam_metric(model, fixed_audio, fixed_audio, 16000) - assert ( - scores["dpam_distance"] == 0.0 - ), f"Expected dpam distance == 0.0 for identical signals, got {scores['dpam_distance']}" diff --git a/test/test_metrics/test_dpam_distance.py b/test/test_metrics/test_dpam_distance.py new file mode 100644 index 0000000..b07d936 --- /dev/null +++ b/test/test_metrics/test_dpam_distance.py @@ -0,0 +1,266 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest + +from versa.utterance_metrics.dpam_distance import DpamDistanceMetric + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. 
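+    # The default below is 0.5 + 0.5 * sin(2 * pi * 0.5 * t), which stays in
+    # [0, 1] and modulates the carrier at 0.5 Hz, keeping the fixture
+    # deterministic but non-stationary.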
+ if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a ground truth WAV file to be used as reference audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) + return gt_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the ground truth audio file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_dpam_distance(use_gpu, fixed_audio, fixed_ground_truth): + """ + Test the DPAM distance metric using the fixed audio files. + The test uses deterministic data so that the result is always reproducible. 
+ """ + config = {"use_gpu": use_gpu} + + metric = DpamDistanceMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, fixed_ground_truth, metadata=metadata) + + # Check that the result contains the expected key + assert "dpam_distance" in result, "Result should contain 'dpam_distance' key" + + # Check that the result is a float + dpam_dist = result["dpam_distance"] + assert isinstance(dpam_dist, float), "dpam_distance should be a float" + + # Check that the distance score is reasonable (should be non-negative) + assert dpam_dist >= 0.0, f"DPAM distance should be non-negative, got {dpam_dist}" + + +def test_dpam_distance_metric_metadata(): + """Test that the DPAM distance metric has correct metadata.""" + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "dpam_distance" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "torch" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + assert "filelock" in metadata.dependencies + + +def test_dpam_distance_metric_different_sample_rates(): + """Test that the DPAM distance metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + + # Test with 44.1kHz audio (should be resampled to 22.05kHz) + audio_44k_1 = np.random.random(44100) + audio_44k_2 = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k_1, audio_44k_2, metadata=metadata_44k) + + # Test with 22.05kHz audio (no resampling needed) + audio_22k_1 = np.random.random(22050) + audio_22k_2 = np.random.random(22050) + metadata_22k = {"sample_rate": 22050} + result_22k = metric.compute(audio_22k_1, audio_22k_2, metadata=metadata_22k) + + # Both should return valid scores with expected keys + assert ( + "dpam_distance" in result_44k + ), "44kHz result should contain 'dpam_distance' key" + assert ( + "dpam_distance" in result_22k + ), "22kHz result should contain 'dpam_distance' key" + + # Both should return float values + assert isinstance(result_44k["dpam_distance"], float) + assert isinstance(result_22k["dpam_distance"], float) + + +def test_dpam_distance_metric_invalid_input(): + """Test that the DPAM distance metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + + # Test with None predictions + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(22050), metadata={"sample_rate": 22050}) + + # Test with None references + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(22050), None, metadata={"sample_rate": 22050}) + + +def test_dpam_distance_metric_config_options(): + """Test that the DPAM distance metric handles different configuration options.""" + # Test with GPU disabled + config_cpu = {"use_gpu": False} + metric_cpu = DpamDistanceMetric(config_cpu) + + # Test with custom cache directory + config_custom_cache = {"use_gpu": False, "cache_dir": "custom_cache"} + metric_custom_cache = DpamDistanceMetric(config_custom_cache) + + # All should work without errors + audio1 = np.random.random(22050) + audio2 = np.random.random(22050) + metadata = {"sample_rate": 
22050} + + result_cpu = metric_cpu.compute(audio1, audio2, metadata=metadata) + result_custom_cache = metric_custom_cache.compute(audio1, audio2, metadata=metadata) + + # All should return the same structure + assert "dpam_distance" in result_cpu + assert "dpam_distance" in result_custom_cache + assert isinstance(result_cpu["dpam_distance"], float) + assert isinstance(result_custom_cache["dpam_distance"], float) + + +def test_dpam_distance_metric_identical_signals(): + """Test that the DPAM distance metric gives zero distance for identical signals.""" + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + + # Test with identical signals + audio = np.random.random(22050) + result = metric.compute(audio, audio, metadata=metadata) + + # Results should be 0.0 for identical signals + assert result["dpam_distance"] == 0.0, "Identical signals should have zero distance" + + +def test_dpam_distance_metric_consistent_results(): + """Test that the DPAM distance metric gives consistent results for the same inputs.""" + config = {"use_gpu": False} + metric = DpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + + # Test with fixed signals + audio1 = np.random.random(22050) + audio2 = np.random.random(22050) + result1 = metric.compute(audio1, audio2, metadata=metadata) + result2 = metric.compute(audio1, audio2, metadata=metadata) + + # Results should be identical for the same inputs + np.testing.assert_almost_equal( + result1["dpam_distance"], + result2["dpam_distance"], + decimal=6, + err_msg="Results should be identical for the same inputs", + ) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_ground_truth_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() diff --git a/test/test_metrics/test_emo_similarity.py b/test/test_metrics/test_emo_similarity.py new file mode 100644 index 0000000..c781e96 --- /dev/null +++ b/test/test_metrics/test_emo_similarity.py @@ -0,0 +1,278 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest + +from versa.utterance_metrics.emo_similarity import Emo2vecMetric, is_emo2vec_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. 
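+        # (Note: the wave module raises an error if these parameters are not
+        # all set before writeframes, so each one is declared explicitly.)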
+ wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_audio_wav_2(tmp_path_factory): + """ + Create a second fixed WAV file to be used as reference audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_2.wav" + # Generate an audio file with a 200 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=200) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_audio_2(fixed_audio_wav_2): + """ + Load the second fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav_2) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_emotion(use_gpu, fixed_audio, fixed_audio_2): + """ + Test the Emotion metric using the fixed audio files. + The test uses deterministic data so that the result is always reproducible. 
+ """ + config = {"use_gpu": use_gpu} + + metric = Emo2vecMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, fixed_audio_2, metadata=metadata) + + # Check that the result contains the expected key + assert ( + "emotion_similarity" in result + ), "Result should contain 'emotion_similarity' key" + + # Check that the result is a float + emotion_sim = result["emotion_similarity"] + assert isinstance(emotion_sim, float), "emotion_similarity should be a float" + + # Check that the similarity score is reasonable (between -1 and 1 for cosine similarity) + assert ( + -1.0 <= emotion_sim <= 1.0 + ), f"Emotion similarity should be between -1 and 1, got {emotion_sim}" + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_metadata(): + """Test that the Emotion metric has correct metadata.""" + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "emotion" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "emo2vec_versa" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_different_sample_rates(): + """Test that the Emotion metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + + # Test with 44.1kHz audio (should be resampled to 16kHz) + audio_44k_1 = np.random.random(44100) + audio_44k_2 = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k_1, audio_44k_2, metadata=metadata_44k) + + # Test with 16kHz audio (no resampling needed) + audio_16k_1 = np.random.random(16000) + audio_16k_2 = np.random.random(16000) + metadata_16k = {"sample_rate": 16000} + result_16k = metric.compute(audio_16k_1, audio_16k_2, metadata=metadata_16k) + + # Both should return valid scores with expected keys + assert ( + "emotion_similarity" in result_44k + ), "44kHz result should contain 'emotion_similarity' key" + assert ( + "emotion_similarity" in result_16k + ), "16kHz result should contain 'emotion_similarity' key" + + # Both should return float values + assert isinstance(result_44k["emotion_similarity"], float) + assert isinstance(result_16k["emotion_similarity"], float) + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_invalid_input(): + """Test that the Emotion metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + + # Test with None predictions + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(16000), metadata={"sample_rate": 16000}) + + # Test with None references + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(16000), None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_config_options(): + """Test that the Emotion metric handles different configuration options.""" + # Test with GPU disabled + config_cpu = {"use_gpu": False} + metric_cpu = Emo2vecMetric(config_cpu) + + 
# Test with different model tag + config_custom_model = {"use_gpu": False, "model_tag": "base"} + metric_custom_model = Emo2vecMetric(config_custom_model) + + # All should work without errors + audio1 = np.random.random(16000) + audio2 = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result_cpu = metric_cpu.compute(audio1, audio2, metadata=metadata) + result_custom_model = metric_custom_model.compute(audio1, audio2, metadata=metadata) + + # All should return the same structure + assert "emotion_similarity" in result_cpu + assert "emotion_similarity" in result_custom_model + assert isinstance(result_cpu["emotion_similarity"], float) + assert isinstance(result_custom_model["emotion_similarity"], float) + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_identical_signals(): + """Test that the Emotion metric gives high similarity for identical signals.""" + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + metadata = {"sample_rate": 16000} + + # Test with identical signals + audio = np.random.random(16000) + result = metric.compute(audio, audio, metadata=metadata) + + # Results should be very close to 1.0 for identical signals + assert ( + result["emotion_similarity"] > 0.99 + ), "Identical signals should have very high similarity" + + +@pytest.mark.skipif(not is_emo2vec_available(), reason="Emo2vec not available") +def test_emotion_metric_consistent_results(): + """Test that the Emotion metric gives consistent results for the same inputs.""" + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + metadata = {"sample_rate": 16000} + + # Test with fixed signals + audio1 = np.random.random(16000) + audio2 = np.random.random(16000) + result1 = metric.compute(audio1, audio2, metadata=metadata) + result2 = metric.compute(audio1, audio2, metadata=metadata) + + # Results should be identical for the same inputs + np.testing.assert_almost_equal( + result1["emotion_similarity"], + result2["emotion_similarity"], + decimal=6, + err_msg="Results should be identical for the same inputs", + ) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_audio_wav_2): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_audio_wav_2).exists() diff --git a/test/test_metrics/test_emo_vad.py b/test/test_metrics/test_emo_vad.py index 6ffee75..80ac31b 100644 --- a/test/test_metrics/test_emo_vad.py +++ b/test/test_metrics/test_emo_vad.py @@ -4,68 +4,296 @@ import numpy as np import pytest -from versa.utterance_metrics.emo_vad import dim_emo_pred, w2v2_emo_dim_setup - -# Assume the fixed WAV file fixtures and helper function are defined as in the dimentionsal emotion prediction test. -# For example: +from versa.utterance_metrics.emo_vad import EmoVadMetric, is_transformers_available +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- def generate_fixed_wav( filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None ): """ Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. 
+ - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. """ t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. if envelope_func is None: envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) else: envelope = envelope_func(t) audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. amplitude = np.iinfo(np.int16).max data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. with wave.open(str(filename), "w") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. wf.setframerate(sample_rate) wf.writeframes(data.tobytes()) +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- def load_wav_as_array(wav_path, sample_rate=16000): """ - Load a WAV file and convert it to a NumPy array scaled to [-1, 1]. + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. """ with wave.open(str(wav_path), "rb") as wf: frames = wf.getnframes() audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) return audio_array / np.iinfo(np.int16).max -@pytest.fixture(scope="session") -def fixed_audio_wav(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("audio_data") - audio_file = tmp_dir / "fixed_audio.wav" - generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) - return audio_file - - @pytest.fixture(scope="session") def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ return load_wav_as_array(fixed_audio_wav) # ------------------------------- -# emo_vad Metric Definition and Tests +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +@pytest.mark.parametrize( + "use_gpu", + [ + False, + ], +) +def test_utterance_emo_vad(use_gpu, fixed_audio): + """ + Test the EmoVad metric using the fixed audio. + The test uses deterministic data so that the result is always reproducible. 
+ """ + config = {"use_gpu": use_gpu} + + metric = EmoVadMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, metadata=metadata) + + # Check that the result contains the expected key + assert "arousal_emo_vad" in result, "Result should contain 'arousal_emo_vad' key" + assert "valence_emo_vad" in result, "Result should contain 'valence_emo_vad' key" + assert ( + "dominance_emo_vad" in result + ), "Result should contain 'dominance_emo_vad' key" + + # Check that the result is a numpy array with 3 values (arousal, valence, dominance) + arousal = result["arousal_emo_vad"] + valence = result["valence_emo_vad"] + dominance = result["dominance_emo_vad"] + assert isinstance(arousal, float), "arousal_emo_vad should be a float" + assert isinstance(valence, float), "valence_emo_vad should be a float" + assert isinstance(dominance, float), "dominance_emo_vad should be a float" + + # Check that all values are numeric and reasonable (emotion scores are typically between 0 and 1) + assert ( + 0.0 <= arousal <= 1.0 + ), f"Arousal score should be between 0 and 1, got {arousal}" + assert ( + 0.0 <= valence <= 1.0 + ), f"Valence score should be between 0 and 1, got {valence}" + assert ( + 0.0 <= dominance <= 1.0 + ), f"Dominance score should be between 0 and 1, got {dominance}" + + +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +def test_emo_vad_metric_metadata(): + """Test that the EmoVad metric has correct metadata.""" + config = {"use_gpu": False} + metric = EmoVadMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "emo_vad" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is False + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert "transformers" in metadata.dependencies + assert "torch" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +def test_emo_vad_metric_different_sample_rates(): + """Test that the EmoVad metric handles different sample rates correctly.""" + config = {"use_gpu": False} + metric = EmoVadMetric(config) + + # Test with 44.1kHz audio (should be resampled to 16kHz) + audio_44k = np.random.random(44100) + metadata_44k = {"sample_rate": 44100} + result_44k = metric.compute(audio_44k, metadata=metadata_44k) + + # Test with 16kHz audio (no resampling needed) + audio_16k = np.random.random(16000) + metadata_16k = {"sample_rate": 16000} + result_16k = metric.compute(audio_16k, metadata=metadata_16k) + + # Both should return valid scores with expected keys + assert ( + "arousal_emo_vad" in result_44k + ), "44kHz result should contain 'arousal_emo_vad' key" + assert ( + "valence_emo_vad" in result_44k + ), "44kHz result should contain 'valence_emo_vad' key" + assert ( + "dominance_emo_vad" in result_44k + ), "44kHz result should contain 'dominance_emo_vad' key" + assert ( + "arousal_emo_vad" in result_16k + ), "16kHz result should contain 'arousal_emo_vad' key" + assert ( + "valence_emo_vad" in result_16k + ), "16kHz result should contain 'valence_emo_vad' key" + assert ( + "dominance_emo_vad" in result_16k + ), "16kHz result should contain 'dominance_emo_vad' key" + + # Both should return numpy arrays with 3 values + assert ( + type(result_44k["arousal_emo_vad"]) == float + ), "arousal_emo_vad should be 
a float" + assert ( + type(result_44k["valence_emo_vad"]) == float + ), "valence_emo_vad should be a float" + assert ( + type(result_44k["dominance_emo_vad"]) == float + ), "dominance_emo_vad should be a float" + assert ( + type(result_16k["arousal_emo_vad"]) == float + ), "arousal_emo_vad should be a float" + assert ( + type(result_16k["valence_emo_vad"]) == float + ), "valence_emo_vad should be a float" + assert ( + type(result_16k["dominance_emo_vad"]) == float + ), "dominance_emo_vad should be a float" + + +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +def test_emo_vad_metric_invalid_input(): + """Test that the EmoVad metric handles invalid inputs correctly.""" + config = {"use_gpu": False} + metric = EmoVadMetric(config) + + # Test with None input + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +def test_emo_vad_metric_config_options(): + """Test that the EmoVad metric handles different configuration options.""" + # Test with GPU disabled + config_cpu = {"use_gpu": False} + metric_cpu = EmoVadMetric(config_cpu) + + # Test with different model tag + config_custom_model = { + "use_gpu": False, + "model_tag": "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim", + } + metric_custom_model = EmoVadMetric(config_custom_model) + + # All should work without errors + audio = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result_cpu = metric_cpu.compute(audio, metadata=metadata) + result_custom_model = metric_custom_model.compute(audio, metadata=metadata) + + # All should return the same structure + assert "arousal_emo_vad" in result_cpu + assert "valence_emo_vad" in result_cpu + assert "dominance_emo_vad" in result_cpu + assert "arousal_emo_vad" in result_custom_model + assert "valence_emo_vad" in result_custom_model + assert "dominance_emo_vad" in result_custom_model + assert ( + type(result_cpu["arousal_emo_vad"]) == float + ), "arousal_emo_vad should be a float" + assert ( + type(result_cpu["valence_emo_vad"]) == float + ), "valence_emo_vad should be a float" + assert ( + type(result_cpu["dominance_emo_vad"]) == float + ), "dominance_emo_vad should be a float" + + +@pytest.mark.skipif( + not is_transformers_available(), reason="Transformers not available" +) +def test_emo_vad_metric_identical_signals(): + """Test that the EmoVad metric gives consistent results for identical signals.""" + config = {"use_gpu": False} + metric = EmoVadMetric(config) + metadata = {"sample_rate": 16000} + + # Test with identical signals + audio = np.random.random(16000) + result1 = metric.compute(audio, metadata=metadata) + result2 = metric.compute(audio, metadata=metadata) + + # Results should be identical for the same input + np.testing.assert_array_almost_equal( + result1["arousal_emo_vad"], + result2["arousal_emo_vad"], + decimal=6, + err_msg="Results should be identical for the same input", + ) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) # ------------------------------- -def test_emo_vad_metric_identical(fixed_audio): +def test_fixed_wav_files_exist(fixed_audio_wav): """ - When comparing an audio signal with itself, the STOI score should be 1.0. + Verify that the fixed WAV files were created. 
""" - emo_utils = w2v2_emo_dim_setup() - scores = dim_emo_pred(emo_utils, fixed_audio, 16000) - assert scores["aro_val_dom_emo"] == pytest.approx( - np.array([0.3982302, 0.43092448, 0.41154572], dtype=np.float32), - rel=1e-3, - abs=1e-6, - ), f"Expected aro_val_dom_emo of [0.3982302, 0.43092448, 0.41154572] for identical signals, got {scores['aro_val_dom_emo']}" + assert Path(fixed_audio_wav).exists() diff --git a/test/test_metrics/test_nisqa.py b/test/test_metrics/test_nisqa.py new file mode 100644 index 0000000..54d742f --- /dev/null +++ b/test/test_metrics/test_nisqa.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Unit tests for NISQA metric.""" + +import wave +from pathlib import Path +from unittest.mock import Mock, patch + +import numpy as np +import pytest +import torch + +from versa.utterance_metrics.nisqa import ( + NisqaMetric, + nisqa_metric, + nisqa_model_setup, +) + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. 
+ """ + return load_wav_as_array(fixed_audio_wav) + + +# ------------------------------- +# Mock NISQA Model Fixture +# ------------------------------- +@pytest.fixture +def mock_nisqa_model(): + """Create a mock NISQA model for testing.""" + model = Mock() + model.device = "cpu" + model.args = {"model": "NISQA"} + return model + + +# ------------------------------- +# Test NISQA Metric Class +# ------------------------------- +class TestNisqaMetric: + """Test the NisqaMetric class.""" + + def test_initialization_without_model_path(self): + """Test that initialization fails without model path.""" + config = {"use_gpu": False} + with pytest.raises(ValueError, match="NISQA model path must be provided"): + NisqaMetric(config) + + @patch("versa.utterance_metrics.nisqa.torch.load") + @patch("versa.utterance_metrics.nisqa.NL.NISQA") + def test_initialization_success(self, mock_nisqa_class, mock_torch_load): + """Test successful initialization of NisqaMetric.""" + # Mock the checkpoint + mock_checkpoint = { + "args": { + "model": "NISQA", + "ms_seg_length": 15, + "ms_n_mels": 48, + "cnn_model": "resnet", + "cnn_c_out_1": 32, + "cnn_c_out_2": 32, + "cnn_c_out_3": 32, + "cnn_kernel_size": 3, + "cnn_dropout": 0.1, + "cnn_pool_1": 2, + "cnn_pool_2": 2, + "cnn_pool_3": 2, + "cnn_fc_out_h": 128, + "td": "lstm", + "td_sa_d_model": 128, + "td_sa_nhead": 8, + "td_sa_pos_enc": "sin", + "td_sa_num_layers": 2, + "td_sa_h": 128, + "td_sa_dropout": 0.1, + "td_lstm_h": 128, + "td_lstm_num_layers": 2, + "td_lstm_dropout": 0.1, + "td_lstm_bidirectional": True, + "td_2": "lstm", + "td_2_sa_d_model": 128, + "td_2_sa_nhead": 8, + "td_2_sa_pos_enc": "sin", + "td_2_sa_num_layers": 2, + "td_2_sa_h": 128, + "td_2_sa_dropout": 0.1, + "td_2_lstm_h": 128, + "td_2_lstm_num_layers": 2, + "td_2_lstm_dropout": 0.1, + "td_2_lstm_bidirectional": True, + "pool": "att", + "pool_att_h": 128, + "pool_att_dropout": 0.1, + }, + "model_state_dict": {}, + } + mock_torch_load.return_value = mock_checkpoint + + # Mock the NISQA model + mock_model = Mock() + mock_model.load_state_dict.return_value = ([], []) # No missing/unexpected keys + mock_nisqa_class.return_value = mock_model + + config = { + "nisqa_model_path": "versa_cache/nisqa/nisqa.tar", + "use_gpu": False, + } + + metric = NisqaMetric(config) + assert metric.model is not None + assert metric.model.device == "cpu" + + def test_compute_with_none_predictions(self): + """Test that compute raises error with None predictions.""" + config = { + "nisqa_model_path": "versa_cache/nisqa/nisqa.tar", + "use_gpu": False, + } + metric = NisqaMetric(config) + + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None) + + @patch("versa.utterance_metrics.nisqa.NL.versa_eval_mos") + def test_compute_success(self, mock_eval_mos, mock_nisqa_model): + """Test successful computation of NISQA scores.""" + # Mock the evaluation function + mock_eval_mos.return_value = { + "mos_pred": [[0.5]], + "noi_pred": [[1.0]], + "dis_pred": [[2.0]], + "col_pred": [[1.5]], + "loud_pred": [[1.2]], + } + + config = { + "nisqa_model_path": "versa_cache/nisqa/nisqa.tar", + "use_gpu": False, + } + metric = NisqaMetric(config) + metric.model = mock_nisqa_model + + audio = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result = metric.compute(audio, metadata=metadata) + + assert "nisqa_mos_pred" in result + assert "nisqa_noi_pred" in result + assert "nisqa_dis_pred" in result + assert "nisqa_col_pred" in result + assert "nisqa_loud_pred" in result + assert 
result["nisqa_mos_pred"] == 0.5 + + def test_get_metadata(self): + """Test that get_metadata returns correct metadata.""" + config = { + "nisqa_model_path": "versa_cache/nisqa/nisqa.tar", + "use_gpu": False, + } + metric = NisqaMetric(config) + + metadata = metric.get_metadata() + assert metadata.name == "nisqa" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "float" + assert not metadata.requires_reference + assert not metadata.requires_text + assert metadata.gpu_compatible + + +# ------------------------------- +# Integration Tests +# ------------------------------- +@pytest.mark.integration +class TestNisqaIntegration: + """Integration tests for NISQA metric.""" + + @pytest.mark.parametrize( + "sample_rate,use_gpu", + [ + (16000, False), + (22050, False), + (48000, False), + ], + ) + def test_nisqa_with_different_sample_rates(self, sample_rate, use_gpu, fixed_audio): + """Test NISQA with different sample rates.""" + # Skip if NISQA dependencies are not available + try: + import versa.utterance_metrics.nisqa_utils.nisqa_lib + except ImportError: + pytest.skip("NISQA dependencies not available") + + # This test would require a real NISQA model file + # For now, we'll just test the basic structure + config = { + "use_gpu": use_gpu, + } + + # Test that the metric can be instantiated (without actual model loading) + with pytest.raises(ValueError, match="NISQA model path must be provided"): + NisqaMetric(config) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + + +# ------------------------------- +# Test Registration Function +# ------------------------------- +def test_register_nisqa_metric(): + """Test the registration function.""" + from versa.utterance_metrics.nisqa import register_nisqa_metric + + # Mock registry + mock_registry = Mock() + + # Register the metric + register_nisqa_metric(mock_registry) + + # Verify registration was called + mock_registry.register.assert_called_once() + + # Verify the call arguments + call_args = mock_registry.register.call_args + assert call_args[0][0] == NisqaMetric # First argument should be the class + assert call_args[0][1].name == "nisqa" # Second argument should be metadata diff --git a/test/test_metrics/test_nomad.py b/test/test_metrics/test_nomad.py new file mode 100644 index 0000000..3f011f2 --- /dev/null +++ b/test/test_metrics/test_nomad.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Unit tests for NOMAD metric.""" + +import wave +from pathlib import Path +from unittest.mock import Mock, patch + +import numpy as np +import pytest +import torch + +from versa.utterance_metrics.nomad import ( + NomadMetric, + is_nomad_available, +) + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. 
+ - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as ground truth. + This one uses a different base frequency (e.g., 300 Hz) so that the test + intentionally simulates a mismatch. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Generate a ground truth file with a 300 Hz sine wave. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) + return gt_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the fixed ground truth file as a NumPy array. 
+ """ + return load_wav_as_array(fixed_ground_truth_wav) + + +# ------------------------------- +# Mock NOMAD Model Fixture +# ------------------------------- +@pytest.fixture +def mock_nomad_model(): + """Create a mock NOMAD model for testing.""" + model = Mock() + model.predict.return_value = 0.5 # Mock prediction value + return model + + +# ------------------------------- +# Test NOMAD Metric Class +# ------------------------------- +class TestNomadMetric: + """Test the NomadMetric class.""" + + def test_initialization_without_nomad(self): + """Test that initialization fails without nomad dependency.""" + with patch("versa.utterance_metrics.nomad.NOMAD_AVAILABLE", False): + config = {"use_gpu": False, "model_cache": "test_cache"} + with pytest.raises(ImportError, match="nomad is not installed"): + NomadMetric(config) + + @patch("versa.utterance_metrics.nomad.Nomad") + def test_initialization_success(self, mock_nomad_class): + """Test successful initialization of NomadMetric.""" + # Mock the NOMAD class + mock_model = Mock() + mock_nomad_class.return_value = mock_model + + config = { + "use_gpu": False, + "model_cache": "test_cache", + } + + metric = NomadMetric(config) + assert metric.model is not None + mock_nomad_class.assert_called_once_with(device="cpu", cache_dir="test_cache") + + def test_compute_with_none_predictions(self): + """Test that compute raises error with None predictions.""" + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_model = Mock() + mock_nomad_class.return_value = mock_model + + config = {"use_gpu": False, "model_cache": "test_cache"} + metric = NomadMetric(config) + + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(16000)) + + def test_compute_with_none_references(self): + """Test that compute raises error with None references.""" + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_model = Mock() + mock_nomad_class.return_value = mock_model + + config = {"use_gpu": False, "model_cache": "test_cache"} + metric = NomadMetric(config) + + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(16000), None) + + @patch("versa.utterance_metrics.nomad.resample_audio") + def test_compute_success(self, mock_resample, mock_nomad_model): + """Test successful computation of NOMAD score.""" + # Mock the resample function + mock_resample.side_effect = lambda x, orig_sr, target_sr: x + + config = {"use_gpu": False, "model_cache": "test_cache"} + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_nomad_class.return_value = mock_nomad_model + metric = NomadMetric(config) + + audio = np.random.random(16000) + gt_audio = np.random.random(16000) + metadata = {"sample_rate": 16000} + + result = metric.compute(audio, gt_audio, metadata=metadata) + + assert "nomad" in result + assert result["nomad"] == 0.5 + mock_nomad_model.predict.assert_called_once() + + @patch("versa.utterance_metrics.nomad.resample_audio") + def test_compute_with_resampling(self, mock_resample, mock_nomad_model): + """Test computation with resampling.""" + # Mock the resample function + mock_resample.side_effect = lambda x, orig_sr, target_sr: x + + config = {"use_gpu": False, "model_cache": "test_cache"} + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_nomad_class.return_value = mock_nomad_model + metric = NomadMetric(config) + + audio = np.random.random(8000) # Different sample 
rate + gt_audio = np.random.random(8000) + metadata = {"sample_rate": 8000} + + result = metric.compute(audio, gt_audio, metadata=metadata) + + assert "nomad" in result + # Verify resampling was called + assert mock_resample.call_count == 2 + + def test_get_metadata(self): + """Test that get_metadata returns correct metadata.""" + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_model = Mock() + mock_nomad_class.return_value = mock_model + + config = {"use_gpu": False, "model_cache": "test_cache"} + metric = NomadMetric(config) + + metadata = metric.get_metadata() + assert metadata.name == "nomad" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference + assert not metadata.requires_text + assert metadata.gpu_compatible + + +# ------------------------------- +# Test Utility Functions +# ------------------------------- +class TestUtilityFunctions: + """Test utility functions.""" + + @patch("versa.utterance_metrics.nomad.NOMAD_AVAILABLE", True) + def test_is_nomad_available_true(self): + """Test is_nomad_available when NOMAD is available.""" + assert is_nomad_available() is True + + @patch("versa.utterance_metrics.nomad.NOMAD_AVAILABLE", False) + def test_is_nomad_available_false(self): + """Test is_nomad_available when NOMAD is not available.""" + assert is_nomad_available() is False + + +# ------------------------------- +# Integration Tests +# ------------------------------- +@pytest.mark.integration +class TestNomadIntegration: + """Integration tests for NOMAD metric.""" + + @pytest.mark.parametrize( + "sample_rate,use_gpu", + [ + (16000, False), + (22050, False), + (48000, False), + ], + ) + def test_nomad_with_different_sample_rates( + self, sample_rate, use_gpu, fixed_audio, fixed_ground_truth + ): + """Test NOMAD with different sample rates.""" + # Skip if NOMAD dependencies are not available + if not is_nomad_available(): + pytest.skip("NOMAD dependencies not available") + + # This test would require a real NOMAD model file + # For now, we'll just test the basic structure + config = { + "use_gpu": use_gpu, + "model_cache": "test_cache", + } + + # Test that the metric can be instantiated (without actual model loading) + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_model = Mock() + mock_nomad_class.return_value = mock_model + + metric = NomadMetric(config) + assert metric.model is not None + + +# ------------------------------- +# Example Test Function Using the Reused WAV Files +# ------------------------------- +@pytest.mark.parametrize( + "use_gpu,cache_dir", + [ + (False, "test_cache"), + (True, "test_cache"), + ], +) +def test_utterance_nomad(use_gpu, cache_dir, fixed_audio, fixed_ground_truth): + """ + Test the NOMAD metric using the fixed audio and ground truth. + The test uses deterministic data so that the result is always reproducible. 
+ """ + with patch("versa.utterance_metrics.nomad.Nomad") as mock_nomad_class: + mock_model = Mock() + mock_model.predict.return_value = 0.5 + mock_nomad_class.return_value = mock_model + + # Use the new class-based API + config = {"use_gpu": use_gpu, "model_cache": cache_dir} + metric = NomadMetric(config) + metadata = {"sample_rate": 16000} + result = metric.compute(fixed_audio, fixed_ground_truth, metadata=metadata) + nomad_score = result["nomad"] + + # We expect the score to be 0.5 based on our mock + assert nomad_score == pytest.approx( + 0.5, rel=1e-3, abs=1e-6 + ), "value from nomad_score {} is mismatch from the defined one {}".format( + nomad_score, 0.5 + ) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_ground_truth_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() + + +# ------------------------------- +# Test Registration Function +# ------------------------------- +def test_register_nomad_metric(): + """Test the registration function.""" + from versa.utterance_metrics.nomad import register_nomad_metric + + # Mock registry + mock_registry = Mock() + + # Register the metric + register_nomad_metric(mock_registry) + + # Verify registration was called + mock_registry.register.assert_called_once() + + # Verify the call arguments + call_args = mock_registry.register.call_args + assert call_args[0][0] == NomadMetric # First argument should be the class + assert call_args[0][1].name == "nomad" # Second argument should be metadata diff --git a/test/test_metrics/test_noresqa.py b/test/test_metrics/test_noresqa.py new file mode 100644 index 0000000..8c6e78f --- /dev/null +++ b/test/test_metrics/test_noresqa.py @@ -0,0 +1,333 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest +import torch +from packaging.version import parse as V + +from versa.utterance_metrics.noresqa import NoresqaMetric, is_noresqa_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. + if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. 
+ wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as ground truth. + This one uses a different base frequency (e.g., 300 Hz) so that the test + intentionally simulates a mismatch. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Generate a ground truth file with a 300 Hz sine wave. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) + return gt_file + + +@pytest.fixture(scope="session") +def fixed_audio_8k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 8kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_8k.wav" + # Generate an audio file with a 150 Hz sine wave at 8kHz. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=8000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_8k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 8kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth_8k.wav" + # Generate a ground truth file with a 300 Hz sine wave at 8kHz. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=8000, base_freq=300) + return gt_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the fixed ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +@pytest.fixture(scope="session") +def fixed_audio_8k(fixed_audio_8k_wav): + """ + Load the fixed 8kHz audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_8k_wav, sample_rate=8000) + + +@pytest.fixture(scope="session") +def fixed_ground_truth_8k(fixed_ground_truth_8k_wav): + """ + Load the fixed 8kHz ground truth file as a NumPy array. 
+ """ + return load_wav_as_array(fixed_ground_truth_8k_wav, sample_rate=8000) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_noresqa_available(), + reason="noresqa is not available", +) +@pytest.mark.parametrize( + "metric_type,model_tag,use_gpu", + [ + (1, "default", False), # NORESQA-MOS + (0, "default", False), # NORESQA-score + ], +) +def test_noresqa_metric_basic( + metric_type, model_tag, use_gpu, fixed_audio, fixed_ground_truth +): + """ + Test the NORESQA metric with basic configuration. + """ + config = { + "metric_type": metric_type, + "model_tag": model_tag, + "use_gpu": use_gpu, + "cache_dir": "test_cache/noresqa_model", + } + + metric = NoresqaMetric(config) + result = metric.compute( + fixed_audio, fixed_ground_truth, metadata={"sample_rate": 16000} + ) + + # Check that result contains noresqa_score field + if metric_type == 0: + assert "noresqa_score" in result + assert isinstance(result["noresqa_score"], (int, float, np.number)) + assert not np.isnan(result["noresqa_score"]) + assert not np.isinf(result["noresqa_score"]) + elif metric_type == 1: + assert "noresqa_mos" in result + assert isinstance(result["noresqa_mos"], (int, float, np.number)) + assert not np.isnan(result["noresqa_mos"]) + assert not np.isinf(result["noresqa_mos"]) + + +@pytest.mark.skipif( + not is_noresqa_available(), + reason="noresqa is not available", +) +def test_noresqa_metric_resampling(fixed_audio_8k, fixed_ground_truth_8k): + """ + Test the NORESQA metric with audio that needs resampling. + """ + config = { + "metric_type": 1, # NORESQA-MOS + "model_tag": "default", + "use_gpu": False, + "cache_dir": "test_cache/noresqa_model", + } + + metric = NoresqaMetric(config) + result = metric.compute( + fixed_audio_8k, fixed_ground_truth_8k, metadata={"sample_rate": 8000} + ) + + # Check that result contains noresqa_score field + assert "noresqa_mos" in result + assert isinstance(result["noresqa_mos"], (int, float, np.number)) + assert not np.isnan(result["noresqa_mos"]) + assert not np.isinf(result["noresqa_mos"]) + + +@pytest.mark.skipif( + not is_noresqa_available(), + reason="noresqa is not available", +) +def test_noresqa_metric_invalid_input(): + """ + Test the NORESQA metric with invalid input. + """ + config = { + "metric_type": 1, + "model_tag": "default", + "use_gpu": False, + "cache_dir": "test_cache/noresqa_model", + } + + metric = NoresqaMetric(config) + + # Test with None predictions + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(16000), metadata={"sample_rate": 16000}) + + # Test with None references + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(16000), None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_noresqa_available(), + reason="noresqa is not available", +) +@pytest.mark.parametrize("metric_type", [0, 1]) +def test_noresqa_metric_metadata(metric_type): + """ + Test the NORESQA metric metadata. 
+ """ + config = { + "metric_type": metric_type, + "model_tag": "default", + "use_gpu": False, + "cache_dir": "test_cache/noresqa_model", + } + + metric = NoresqaMetric(config) + metadata = metric.get_metadata() + + expected_name = "noresqa_mos" if metric_type == 1 else "noresqa_score" + assert metadata.name == expected_name + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert metadata.auto_install is False + assert "fairseq" in metadata.dependencies + assert "torch" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +def test_noresqa_metric_not_available(): + """ + Test the NORESQA metric when noresqa is not available. + """ + # This test should be skipped if noresqa is available + if is_noresqa_available(): + pytest.skip("noresqa is available, skipping this test") + + config = { + "metric_type": 1, + "model_tag": "default", + "use_gpu": False, + "cache_dir": "test_cache/noresqa_model", + } + + with pytest.raises(ImportError, match="NORESQA dependencies are not available"): + NoresqaMetric(config) + + +@pytest.mark.skipif( + not is_noresqa_available(), + reason="noresqa is not available", +) +def test_noresqa_metric_invalid_metric_type(): + """ + Test the NORESQA metric with invalid metric_type. + """ + config = { + "metric_type": 2, # Invalid metric type + "model_tag": "default", + "use_gpu": False, + "cache_dir": "test_cache/noresqa_model", + } + + with pytest.raises(RuntimeError, match="Invalid metric_type"): + NoresqaMetric(config) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist( + fixed_audio_wav, + fixed_ground_truth_wav, + fixed_audio_8k_wav, + fixed_ground_truth_8k_wav, +): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() + assert Path(fixed_audio_8k_wav).exists() + assert Path(fixed_ground_truth_8k_wav).exists() diff --git a/test/test_metrics/test_owsm_lid.py b/test/test_metrics/test_owsm_lid.py new file mode 100644 index 0000000..44c0c2c --- /dev/null +++ b/test/test_metrics/test_owsm_lid.py @@ -0,0 +1,240 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest +import torch +from packaging.version import parse as V + +from versa.utterance_metrics.owsm_lid import OwsmLidMetric, is_espnet2_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. 
+ if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_audio_8k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 8kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_8k.wav" + # Generate an audio file with a 150 Hz sine wave at 8kHz. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=8000, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_audio_8k(fixed_audio_8k_wav): + """ + Load the fixed 8kHz audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_8k_wav, sample_rate=8000) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_espnet2_available(), + reason="espnet2 is not available", +) +@pytest.mark.parametrize( + "model_tag,nbest,use_gpu", + [ + ("default", 3, False), + ("default", 5, False), + ("espnet/owsm_v3.1_ebf", 3, False), + ], +) +def test_owsm_lid_metric_basic(model_tag, nbest, use_gpu, fixed_audio): + """ + Test the OWSM LID metric with basic configuration. + """ + config = { + "model_tag": model_tag, + "nbest": nbest, + "use_gpu": use_gpu, + } + + metric = OwsmLidMetric(config) + result = metric.compute(fixed_audio, metadata={"sample_rate": 16000}) + + # Check that result contains language field + assert "language" in result + assert isinstance(result["language"][0][0], str) + assert len(result["language"]) > 0 + + +@pytest.mark.skipif( + not is_espnet2_available(), + reason="espnet2 is not available", +) +def test_owsm_lid_metric_resampling(fixed_audio_8k): + """ + Test the OWSM LID metric with audio that needs resampling. 
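+    The 8 kHz input should be resampled internally to the rate the model expects.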
+ """ + config = { + "model_tag": "default", + "nbest": 3, + "use_gpu": False, + } + + metric = OwsmLidMetric(config) + result = metric.compute(fixed_audio_8k, metadata={"sample_rate": 8000}) + + # Check that result contains language field + assert "language" in result + assert isinstance(result["language"][0][0], str) + assert len(result["language"]) > 0 + + +@pytest.mark.skipif( + not is_espnet2_available(), + reason="espnet2 is not available", +) +def test_owsm_lid_metric_invalid_input(): + """ + Test the OWSM LID metric with invalid input. + """ + config = { + "model_tag": "default", + "nbest": 3, + "use_gpu": False, + } + + metric = OwsmLidMetric(config) + + # Test with None input + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_espnet2_available(), + reason="espnet2 is not available", +) +def test_owsm_lid_metric_metadata(): + """ + Test the OWSM LID metric metadata. + """ + config = { + "model_tag": "default", + "nbest": 3, + "use_gpu": False, + } + + metric = OwsmLidMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "lid" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "list" + assert metadata.requires_reference is False + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert metadata.auto_install is False + assert "espnet2" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +def test_owsm_lid_metric_espnet2_not_available(): + """ + Test the OWSM LID metric when espnet2 is not available. + """ + # This test should be skipped if espnet2 is available + if is_espnet2_available(): + pytest.skip("espnet2 is available, skipping this test") + + config = { + "model_tag": "default", + "nbest": 3, + "use_gpu": False, + } + + with pytest.raises(ImportError, match="espnet2 is not properly installed"): + OwsmLidMetric(config) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_audio_8k_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_audio_8k_wav).exists() diff --git a/test/test_metrics/test_pam.py b/test/test_metrics/test_pam.py new file mode 100644 index 0000000..b4fcd5d --- /dev/null +++ b/test/test_metrics/test_pam.py @@ -0,0 +1,337 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest +import torch +from packaging.version import parse as V + +from versa.utterance_metrics.pam import PamMetric, PAM, is_pam_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. 
+ if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_audio_44k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 44.1kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_44k.wav" + # Generate an audio file with a 150 Hz sine wave at 44.1kHz. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=44100, base_freq=150) + return audio_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. + """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_audio_44k(fixed_audio_44k_wav): + """ + Load the fixed 44.1kHz audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_44k_wav, sample_rate=44100) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_pam_available(), + reason="PAM dependencies are not available", +) +@pytest.mark.parametrize( + "repro,use_gpu", + [ + (True, False), + (False, False), + ], +) +def test_pam_metric_basic(repro, use_gpu, fixed_audio): + """ + Test the PAM metric with basic configuration. 
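+    The config spells out the CLAP-style settings (HTSAT audio encoder, GPT-2 text model).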
+ """ + config = { + "repro": repro, + "use_gpu": use_gpu, + "cache_dir": "test_cache/pam", + "text_model": "gpt2", + "text_len": 77, + "transformer_embed_dim": 768, + "audioenc_name": "HTSAT", + "out_emb": 768, + "sampling_rate": 44100, + "duration": 7, + "fmin": 50, + "fmax": 8000, + "n_fft": 1024, + "hop_size": 320, + "mel_bins": 64, + "window_size": 1024, + "d_proj": 1024, + "temperature": 0.003, + "num_classes": 527, + "batch_size": 1024, + "demo": False, + } + + metric = PamMetric(config) + result = metric.compute(fixed_audio, metadata={"sample_rate": 16000}) + + # Check that result contains pam_score field + assert "pam_score" in result + assert isinstance(result["pam_score"], (int, float, np.number)) + assert not np.isnan(result["pam_score"]) + assert not np.isinf(result["pam_score"]) + # PAM score should be between 0 and 1 + assert 0.0 <= result["pam_score"] <= 1.0 + + +@pytest.mark.skipif( + not is_pam_available(), + reason="PAM dependencies are not available", +) +def test_pam_metric_resampling(fixed_audio_44k): + """ + Test the PAM metric with audio that needs resampling. + """ + config = { + "repro": True, + "use_gpu": False, + "cache_dir": "test_cache/pam", + "text_model": "gpt2", + "text_len": 77, + "transformer_embed_dim": 768, + "audioenc_name": "HTSAT", + "out_emb": 768, + "sampling_rate": 44100, + "duration": 7, + "fmin": 50, + "fmax": 8000, + "n_fft": 1024, + "hop_size": 320, + "mel_bins": 64, + "window_size": 1024, + "d_proj": 1024, + "temperature": 0.003, + "num_classes": 527, + "batch_size": 1024, + "demo": False, + } + + metric = PamMetric(config) + result = metric.compute(fixed_audio_44k, metadata={"sample_rate": 44100}) + + # Check that result contains pam_score field + assert "pam_score" in result + assert isinstance(result["pam_score"], (int, float, np.number)) + assert not np.isnan(result["pam_score"]) + assert not np.isinf(result["pam_score"]) + # PAM score should be between 0 and 1 + assert 0.0 <= result["pam_score"] <= 1.0 + + +@pytest.mark.skipif( + not is_pam_available(), + reason="PAM dependencies are not available", +) +def test_pam_metric_invalid_input(): + """ + Test the PAM metric with invalid input. + """ + config = { + "repro": True, + "use_gpu": False, + "cache_dir": "test_cache/pam", + "text_model": "gpt2", + "text_len": 77, + "transformer_embed_dim": 768, + "audioenc_name": "HTSAT", + "out_emb": 768, + "sampling_rate": 44100, + "duration": 7, + "fmin": 50, + "fmax": 8000, + "n_fft": 1024, + "hop_size": 320, + "mel_bins": 64, + "window_size": 1024, + "d_proj": 1024, + "temperature": 0.003, + "num_classes": 527, + "batch_size": 1024, + "demo": False, + } + + metric = PamMetric(config) + + # Test with None input + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_pam_available(), + reason="PAM dependencies are not available", +) +def test_pam_metric_metadata(): + """ + Test the PAM metric metadata. 
+ """ + config = { + "repro": True, + "use_gpu": False, + "cache_dir": "test_cache/pam", + "text_model": "gpt2", + "text_len": 77, + "transformer_embed_dim": 768, + "audioenc_name": "HTSAT", + "out_emb": 768, + "sampling_rate": 44100, + "duration": 7, + "fmin": 50, + "fmax": 8000, + "n_fft": 1024, + "hop_size": 320, + "mel_bins": 64, + "window_size": 1024, + "d_proj": 1024, + "temperature": 0.003, + "num_classes": 527, + "batch_size": 1024, + "demo": False, + } + + metric = PamMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "pam" + assert metadata.category.value == "independent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is False + assert metadata.requires_text is False + assert metadata.gpu_compatible is True + assert metadata.auto_install is False + assert "torch" in metadata.dependencies + assert "torchaudio" in metadata.dependencies + assert "transformers" in metadata.dependencies + assert "huggingface_hub" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +def test_pam_metric_not_available(): + """ + Test the PAM metric when PAM dependencies are not available. + """ + # This test should be skipped if PAM is available + if is_pam_available(): + pytest.skip("PAM dependencies are available, skipping this test") + + config = { + "repro": True, + "use_gpu": False, + "cache_dir": "test_cache/pam", + "text_model": "gpt2", + "text_len": 77, + "transformer_embed_dim": 768, + "audioenc_name": "HTSAT", + "out_emb": 768, + "sampling_rate": 44100, + "duration": 7, + "fmin": 50, + "fmax": 8000, + "n_fft": 1024, + "hop_size": 320, + "mel_bins": 64, + "window_size": 1024, + "d_proj": 1024, + "temperature": 0.003, + "num_classes": 527, + "batch_size": 1024, + "demo": False, + } + + with pytest.raises(RuntimeError, match="Failed to initialize PAM model"): + PamMetric(config) + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist(fixed_audio_wav, fixed_audio_44k_wav): + """ + Verify that the fixed WAV files were created. + """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_audio_44k_wav).exists() diff --git a/test/test_metrics/test_pesq_score.py b/test/test_metrics/test_pesq_score.py new file mode 100644 index 0000000..7c28186 --- /dev/null +++ b/test/test_metrics/test_pesq_score.py @@ -0,0 +1,367 @@ +import wave +from pathlib import Path + +import numpy as np +import pytest +import torch +from packaging.version import parse as V + +from versa.utterance_metrics.pesq_score import PesqMetric, is_pesq_available + + +# ------------------------------- +# Helper: Generate a fixed WAV file +# ------------------------------- +def generate_fixed_wav( + filename, duration=1.0, sample_rate=16000, base_freq=150, envelope_func=None +): + """ + Generate a deterministic WAV file with a modulated sine wave. + + Parameters: + - filename: Path (str or Path) to write the WAV file. + - duration: Duration of the audio in seconds. + - sample_rate: Number of samples per second. + - base_freq: Frequency (in Hz) of the sine wave. + - envelope_func: Optional function to generate a custom amplitude envelope. + If None, a default sine-based envelope is used. + """ + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + # Use default envelope if none is provided. 
+ if envelope_func is None: + envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 0.5 * t) + else: + envelope = envelope_func(t) + audio = envelope * np.sin(2 * np.pi * base_freq * t) + + # Scale to 16-bit PCM. + amplitude = np.iinfo(np.int16).max + data = (audio * amplitude).astype(np.int16) + + # Write the WAV file. + with wave.open(str(filename), "w") as wf: + wf.setnchannels(1) # Mono audio. + wf.setsampwidth(2) # 16 bits per sample. + wf.setframerate(sample_rate) + wf.writeframes(data.tobytes()) + + +# ------------------------------- +# Session-Scoped Fixtures to Create WAV Files +# ------------------------------- +@pytest.fixture(scope="session") +def fixed_audio_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as test audio. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio.wav" + # Generate an audio file with a 150 Hz sine wave. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=16000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_wav(tmp_path_factory): + """ + Create a fixed WAV file to be used as ground truth. + This one uses a different base frequency (e.g., 300 Hz) so that the test + intentionally simulates a mismatch. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth.wav" + # Generate a ground truth file with a 300 Hz sine wave. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) + return gt_file + + +@pytest.fixture(scope="session") +def fixed_audio_8k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 8kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_8k.wav" + # Generate an audio file with a 150 Hz sine wave at 8kHz. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=8000, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_8k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 8kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth_8k.wav" + # Generate a ground truth file with a 300 Hz sine wave at 8kHz. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=8000, base_freq=300) + return gt_file + + +@pytest.fixture(scope="session") +def fixed_audio_22k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 22.05kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + audio_file = tmp_dir / "fixed_audio_22k.wav" + # Generate an audio file with a 150 Hz sine wave at 22.05kHz. + generate_fixed_wav(audio_file, duration=1.0, sample_rate=22050, base_freq=150) + return audio_file + + +@pytest.fixture(scope="session") +def fixed_ground_truth_22k_wav(tmp_path_factory): + """ + Create a fixed WAV file with 22.05kHz sample rate to test resampling. + """ + tmp_dir = tmp_path_factory.mktemp("audio_data") + gt_file = tmp_dir / "fixed_ground_truth_22k.wav" + # Generate a ground truth file with a 300 Hz sine wave at 22.05kHz. + generate_fixed_wav(gt_file, duration=1.0, sample_rate=22050, base_freq=300) + return gt_file + + +# ------------------------------- +# Fixtures to Load WAV Files into NumPy Arrays +# ------------------------------- +def load_wav_as_array(wav_path, sample_rate=16000): + """ + Load a WAV file and convert it into a NumPy array of floats scaled to [-1, 1]. 
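+    The sample_rate argument is informational only; no resampling happens here.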
+ """ + with wave.open(str(wav_path), "rb") as wf: + frames = wf.getnframes() + audio_data = wf.readframes(frames) + # Convert from 16-bit PCM. + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) + return audio_array / np.iinfo(np.int16).max + + +@pytest.fixture(scope="session") +def fixed_audio(fixed_audio_wav): + """ + Load the fixed audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_wav) + + +@pytest.fixture(scope="session") +def fixed_ground_truth(fixed_ground_truth_wav): + """ + Load the fixed ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_wav) + + +@pytest.fixture(scope="session") +def fixed_audio_8k(fixed_audio_8k_wav): + """ + Load the fixed 8kHz audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_8k_wav, sample_rate=8000) + + +@pytest.fixture(scope="session") +def fixed_ground_truth_8k(fixed_ground_truth_8k_wav): + """ + Load the fixed 8kHz ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_8k_wav, sample_rate=8000) + + +@pytest.fixture(scope="session") +def fixed_audio_22k(fixed_audio_22k_wav): + """ + Load the fixed 22.05kHz audio file as a NumPy array. + """ + return load_wav_as_array(fixed_audio_22k_wav, sample_rate=22050) + + +@pytest.fixture(scope="session") +def fixed_ground_truth_22k(fixed_ground_truth_22k_wav): + """ + Load the fixed 22.05kHz ground truth file as a NumPy array. + """ + return load_wav_as_array(fixed_ground_truth_22k_wav, sample_rate=22050) + + +# ------------------------------- +# Test Functions +# ------------------------------- +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +@pytest.mark.parametrize( + "sample_rate", + [8000, 16000], +) +def test_pesq_metric_basic(sample_rate, fixed_audio, fixed_ground_truth): + """ + Test the PESQ metric with basic configuration. + """ + config = {} + + metric = PesqMetric(config) + result = metric.compute( + fixed_audio, fixed_ground_truth, metadata={"sample_rate": sample_rate} + ) + + # Check that result contains pesq field + assert "pesq" in result + assert isinstance(result["pesq"], (int, float, np.number)) + assert not np.isnan(result["pesq"]) + # PESQ score should be between -0.5 and 4.5 + assert -0.5 <= result["pesq"] <= 4.5 + + +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +def test_pesq_metric_8k_resampling(fixed_audio_8k, fixed_ground_truth_8k): + """ + Test the PESQ metric with 8kHz audio that needs resampling. + """ + config = {} + + metric = PesqMetric(config) + result = metric.compute( + fixed_audio_8k, fixed_ground_truth_8k, metadata={"sample_rate": 8000} + ) + + # Check that result contains pesq field + assert "pesq" in result + assert isinstance(result["pesq"], (int, float, np.number)) + assert not np.isnan(result["pesq"]) + # PESQ score should be between -0.5 and 4.5 + assert -0.5 <= result["pesq"] <= 4.5 + + +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +def test_pesq_metric_22k_resampling(fixed_audio_22k, fixed_ground_truth_22k): + """ + Test the PESQ metric with 22.05kHz audio that needs resampling. 
+ """ + config = {} + + metric = PesqMetric(config) + result = metric.compute( + fixed_audio_22k, fixed_ground_truth_22k, metadata={"sample_rate": 22050} + ) + + # Check that result contains pesq field + assert "pesq" in result + assert isinstance(result["pesq"], (int, float, np.number)) + assert not np.isnan(result["pesq"]) + # PESQ score should be between -0.5 and 4.5 + assert -0.5 <= result["pesq"] <= 4.5 + + +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +def test_pesq_metric_invalid_input(): + """ + Test the PESQ metric with invalid input. + """ + config = {} + + metric = PesqMetric(config) + + # Test with None predictions + with pytest.raises(ValueError, match="Predicted signal must be provided"): + metric.compute(None, np.random.random(16000), metadata={"sample_rate": 16000}) + + # Test with None references + with pytest.raises(ValueError, match="Reference signal must be provided"): + metric.compute(np.random.random(16000), None, metadata={"sample_rate": 16000}) + + +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +def test_pesq_metric_metadata(): + """ + Test the PESQ metric metadata. + """ + config = {} + + metric = PesqMetric(config) + metadata = metric.get_metadata() + + assert metadata.name == "pesq" + assert metadata.category.value == "dependent" + assert metadata.metric_type.value == "float" + assert metadata.requires_reference is True + assert metadata.requires_text is False + assert metadata.gpu_compatible is False + assert metadata.auto_install is False + assert "pesq" in metadata.dependencies + assert "librosa" in metadata.dependencies + assert "numpy" in metadata.dependencies + + +def test_pesq_metric_not_available(): + """ + Test the PESQ metric when pesq is not available. + """ + # This test should be skipped if pesq is available + if is_pesq_available(): + pytest.skip("pesq is available, skipping this test") + + config = {} + + with pytest.raises(ImportError, match="pesq is not properly installed"): + PesqMetric(config) + + +@pytest.mark.skipif( + not is_pesq_available(), + reason="pesq is not available", +) +def test_pesq_metric_same_audio(): + """ + Test the PESQ metric with identical audio (should give high score). + """ + config = {} + + metric = PesqMetric(config) + # Use the same audio for both prediction and reference + audio = np.random.random(16000) + result = metric.compute(audio, audio, metadata={"sample_rate": 16000}) + + # Check that result contains pesq field + assert "pesq" in result + assert isinstance(result["pesq"], (int, float, np.number)) + assert not np.isnan(result["pesq"]) + # PESQ score should be between -0.5 and 5 + assert -0.5 <= result["pesq"] <= 5 + + +# ------------------------------- +# Additional Example Test to Verify the File Creation (Optional) +# ------------------------------- +def test_fixed_wav_files_exist( + fixed_audio_wav, + fixed_ground_truth_wav, + fixed_audio_8k_wav, + fixed_ground_truth_8k_wav, + fixed_audio_22k_wav, + fixed_ground_truth_22k_wav, +): + """ + Verify that the fixed WAV files were created. 
+ """ + assert Path(fixed_audio_wav).exists() + assert Path(fixed_ground_truth_wav).exists() + assert Path(fixed_audio_8k_wav).exists() + assert Path(fixed_ground_truth_8k_wav).exists() + assert Path(fixed_audio_22k_wav).exists() + assert Path(fixed_ground_truth_22k_wav).exists() diff --git a/test/test_metrics/test_sigmos.py b/test/test_metrics/test_sigmos.py new file mode 100644 index 0000000..29ce1a1 --- /dev/null +++ b/test/test_metrics/test_sigmos.py @@ -0,0 +1,47 @@ +import numpy as np + +from versa.definition import MetricRegistry +from versa.utterance_metrics import sigmos +from versa.utterance_metrics.sigmos import SigmosMetric, register_sigmos_metric + + +class DummySigmosModel: + def run(self, audio, sr=None): + return { + "SIGMOS_COL": 1.0, + "SIGMOS_DISC": 2.0, + "SIGMOS_LOUD": 3.0, + "SIGMOS_REVERB": 4.0, + "SIGMOS_SIG": 5.0, + "SIGMOS_OVRL": 6.0, + } + + +def test_sigmos_metric_class_returns_expected_keys(monkeypatch): + monkeypatch.setattr( + sigmos, "sigmos_setup", lambda model_dir=None: DummySigmosModel() + ) + + metric = SigmosMetric({"model_dir": "unused"}) + result = metric.compute( + np.zeros(48000, dtype=np.float32), metadata={"sample_rate": 48000} + ) + + assert result == { + "SIGMOS_COL": 1.0, + "SIGMOS_DISC": 2.0, + "SIGMOS_LOUD": 3.0, + "SIGMOS_REVERB": 4.0, + "SIGMOS_SIG": 5.0, + "SIGMOS_OVRL": 6.0, + } + + +def test_register_sigmos_metric(): + registry = MetricRegistry() + + register_sigmos_metric(registry) + + assert registry.get_metric("sigmos") is SigmosMetric + assert registry.get_metric("sig_mos") is SigmosMetric + assert registry.get_metadata("sigmos").requires_reference is False diff --git a/test/test_metrics/test_stoi.py b/test/test_metrics/test_stoi.py index f8d80a4..ecaa97a 100755 --- a/test/test_metrics/test_stoi.py +++ b/test/test_metrics/test_stoi.py @@ -1,13 +1,15 @@ import wave -from pathlib import Path import numpy as np import pytest -from versa.utterance_metrics.stoi import stoi_metric - -# Assume the fixed WAV file fixtures and helper function are defined as in the ASR matching test. -# For example: +from versa.definition import MetricRegistry +from versa.utterance_metrics.stoi import ( + EstoiMetric, + StoiMetric, + register_stoi_metric, + stoi_metric, +) def generate_fixed_wav( @@ -54,7 +56,7 @@ def fixed_audio_wav(tmp_path_factory): def fixed_ground_truth_wav(tmp_path_factory): tmp_dir = tmp_path_factory.mktemp("audio_data") gt_file = tmp_dir / "fixed_ground_truth.wav" - # Use a different base frequency for ground truth (e.g. 300 Hz) to simulate a mismatch. + # Use a different base frequency to simulate a mismatch. 
generate_fixed_wav(gt_file, duration=1.0, sample_rate=16000, base_freq=300) return gt_file @@ -90,3 +92,36 @@ def test_stoi_metric_different(fixed_audio, fixed_ground_truth): assert ( scores["stoi"] < 1.0 ), f"Expected STOI below 1.0 for different signals, got {scores['stoi']}" + + +def test_stoi_metric_class_matches_legacy_function(fixed_audio, fixed_ground_truth): + legacy_scores = stoi_metric(fixed_audio, fixed_ground_truth, 16000) + metric = StoiMetric() + + scores = metric.compute( + fixed_audio, fixed_ground_truth, metadata={"sample_rate": 16000} + ) + + assert scores["stoi"] == pytest.approx(legacy_scores["stoi"]) + + +def test_estoi_metric_class_uses_extended_mode(fixed_audio, fixed_ground_truth): + metric = EstoiMetric() + + scores = metric.compute( + fixed_audio, fixed_ground_truth, metadata={"sample_rate": 16000} + ) + + assert "estoi" in scores + assert isinstance(scores["estoi"], float) + + +def test_register_stoi_metric(): + registry = MetricRegistry() + + register_stoi_metric(registry) + + assert registry.get_metric("stoi") is StoiMetric + assert registry.get_metric("estoi") is EstoiMetric + assert registry.get_metric("stoi_metric") is StoiMetric + assert registry.get_metadata("estoi").requires_reference is True diff --git a/test/test_metrics/test_wvmos.py b/test/test_metrics/test_wvmos.py new file mode 100644 index 0000000..67cd7a4 --- /dev/null +++ b/test/test_metrics/test_wvmos.py @@ -0,0 +1,31 @@ +import numpy as np + +from versa.definition import MetricRegistry +from versa.utterance_metrics import wvmos +from versa.utterance_metrics.wvmos import WvmosMetric, register_wvmos_metric + + +def test_wvmos_metric_class_returns_existing_key(monkeypatch): + monkeypatch.setattr(wvmos, "wvmos_setup", lambda use_gpu=False: object()) + monkeypatch.setattr( + wvmos, + "wvmos_calculate", + lambda model, pred_x, gen_sr: {"wvmos": 0.75}, + ) + + metric = WvmosMetric({"use_gpu": False}) + result = metric.compute( + np.zeros(16000, dtype=np.float32), metadata={"sample_rate": 16000} + ) + + assert result == {"wvmos": 0.75} + + +def test_register_wvmos_metric(): + registry = MetricRegistry() + + register_wvmos_metric(registry) + + assert registry.get_metric("wvmos") is WvmosMetric + assert registry.get_metric("wv_mos") is WvmosMetric + assert registry.get_metadata("wvmos").requires_reference is False diff --git a/test/test_pipeline/test_asr_match.py b/test/test_pipeline/test_asr_match.py index 2388eae..19ac420 100755 --- a/test/test_pipeline/test_asr_match.py +++ b/test/test_pipeline/test_asr_match.py @@ -4,25 +4,21 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.asr_matching import register_asr_match_metric -TEST_INFO = { - "asr_match_error_rate": 0.0, -} +TEST_INFO = {"asr_match_error_rate": 0.0} def info_update(): - # find files if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -31,7 +27,15 @@ def info_update(): with open("egs/separate_metrics/asr_match.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register ASR-Match metric + registry = MetricRegistry() + 
register_asr_match_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,17 +43,17 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" continue - # the plc mos is undeterministic if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( diff --git a/test/test_pipeline/test_asvspoof.py b/test/test_pipeline/test_asvspoof.py index 2f76b39..e5ed31f 100755 --- a/test/test_pipeline/test_asvspoof.py +++ b/test/test_pipeline/test_asvspoof.py @@ -4,12 +4,10 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.asvspoof_score import register_asvspoof_metric TEST_INFO = { "asvspoof_score": 8.472739e-08, @@ -23,6 +21,7 @@ def info_update(): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -31,7 +30,15 @@ def info_update(): with open("egs/separate_metrics/asvspoof.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register ASVspoof metric + registry = MetricRegistry() + register_asvspoof_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,12 +46,13 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): diff --git a/test/test_pipeline/test_audiobox_aesthetics.py b/test/test_pipeline/test_audiobox_aesthetics.py index a1c8f45..6a06d63 100755 --- a/test/test_pipeline/test_audiobox_aesthetics.py +++ b/test/test_pipeline/test_audiobox_aesthetics.py @@ -4,11 +4,11 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, +from versa.scorer_shared import 
VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.audiobox_aesthetics_score import ( + register_audiobox_aesthetics_metric, ) TEST_INFO = { @@ -32,7 +32,15 @@ def info_update(): ) as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register AudioBox Aesthetics metric + registry = MetricRegistry() + register_audiobox_aesthetics_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=False, use_gpu=False, @@ -40,11 +48,13 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files=None, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=None, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: # the plc mos is undeterministic diff --git a/test/test_pipeline/test_base_metrics_pipeline.py b/test/test_pipeline/test_base_metrics_pipeline.py new file mode 100644 index 0000000..8c2ca9c --- /dev/null +++ b/test/test_pipeline/test_base_metrics_pipeline.py @@ -0,0 +1,675 @@ +import pytest +import torch + +from versa.corpus_metrics.espnet_wer import register_espnet_wer_metric +from versa.corpus_metrics.owsm_wer import register_owsm_wer_metric +from versa.corpus_metrics.whisper_wer import register_whisper_wer_metric +from versa.definition import MetricRegistry +from versa.scorer_shared import VersaScorer, find_files +from versa.sequence_metrics.mcd_f0 import register_mcd_f0_metric +from versa.sequence_metrics.signal_metric import register_signal_metric +from versa.sequence_metrics.warpq import register_warpq_metric +from versa.utterance_metrics.log_wmse import register_log_wmse_metric +from versa.utterance_metrics.pseudo_mos import register_pseudo_mos_metric +from versa.utterance_metrics.qwen2_audio import register_qwen2_audio_metric +from versa.utterance_metrics.qwen_omni import register_qwen_omni_metric +from versa.utterance_metrics.scoreq import register_scoreq_metric +from versa.utterance_metrics.se_snr import register_se_snr_metric +from versa.utterance_metrics.sheet_ssqa import register_sheet_ssqa_metric +from versa.utterance_metrics.singer import register_singer_metric +from versa.utterance_metrics.speaker import register_speaker_metric +from versa.utterance_metrics.squim import register_squim_metric +from versa.utterance_metrics.stoi import register_stoi_metric +from versa.utterance_metrics.vad import register_vad_metric +from versa.utterance_metrics.universa import register_universa_metric +from versa.utterance_metrics.visqol_score import register_visqol_metric +from versa.utterance_metrics.vqscore import register_vqscore_metric + + +def _sample_files(): + gen_files = find_files("test/test_samples/test2") + gt_files = find_files("test/test_samples/test1") + return gen_files, gt_files + + +def test_stoi_and_signal_pipeline_with_registry(): + gen_files, gt_files = _sample_files() + registry = MetricRegistry() + register_stoi_metric(registry) + register_signal_metric(registry) + scorer = VersaScorer(registry) + metric_suite = 
scorer.load_metrics( + [{"name": "stoi"}, {"name": "signal_metric"}], + use_gt=True, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", + ) + + assert score_info + assert "stoi" in score_info[0] + assert "sdr" in score_info[0] + assert "ci_sdr" in score_info[0] + + +def test_mcd_f0_pipeline_with_registry_and_mocked_metric(monkeypatch): + monkeypatch.setattr( + "versa.sequence_metrics.mcd_f0._ensure_mcd_f0_dependencies", lambda: None + ) + monkeypatch.setattr( + "versa.sequence_metrics.mcd_f0.mcd_f0", + lambda pred_x, gt_x, fs, f0min, f0max, **kwargs: { + "mcd": 1.2, + "f0rmse": 3.4, + "f0corr": 0.5, + }, + ) + + gen_files, gt_files = _sample_files() + registry = MetricRegistry() + register_mcd_f0_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "mcd_f0"}], + use_gt=True, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["mcd"] == 1.2 + assert score_info[0]["f0rmse"] == 3.4 + assert score_info[0]["f0corr"] == 0.5 + + +def test_warpq_pipeline_with_registry_and_mocked_metric(monkeypatch): + class DummyWarpqModel: + args = {"sr": 16000} + + monkeypatch.setattr( + "versa.sequence_metrics.warpq.warpq_setup", + lambda **kwargs: DummyWarpqModel(), + ) + monkeypatch.setattr( + "versa.sequence_metrics.warpq.warpq", + lambda model, pred_x, gt_x, fs=8000: {"warpq": 3.8}, + ) + + gen_files, gt_files = _sample_files() + registry = MetricRegistry() + register_warpq_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "warpq"}], + use_gt=True, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["warpq"] == 3.8 + + +def test_se_snr_pipeline_with_registry_and_mocked_metric(monkeypatch): + class DummySeModel: + pass + + monkeypatch.setattr( + "versa.utterance_metrics.se_snr.se_snr_setup", + lambda **kwargs: DummySeModel(), + ) + monkeypatch.setattr( + "versa.utterance_metrics.se_snr.se_snr", + lambda model, pred_x, fs: { + "se_sdr": 1.0, + "se_sar": 2.0, + "se_si_snr": 3.0, + "se_ci_sdr": 4.0, + }, + ) + + gen_files, _ = _sample_files() + registry = MetricRegistry() + register_se_snr_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "se_snr"}], + use_gt=False, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["se_sdr"] == 1.0 + assert score_info[0]["se_sar"] == 2.0 + assert score_info[0]["se_si_snr"] == 3.0 + assert score_info[0]["se_ci_sdr"] == 4.0 + + +def test_speaker_pipeline_with_registry_and_mocked_metric(monkeypatch): + class DummySpeakerModel: + pass + + monkeypatch.setattr( + "versa.utterance_metrics.speaker.speaker_model_setup", + lambda **kwargs: DummySpeakerModel(), + ) + monkeypatch.setattr( + "versa.utterance_metrics.speaker.speaker_metric", + lambda model, pred_x, gt_x, fs: {"spk_similarity": 0.75}, + ) + + gen_files, gt_files = _sample_files() + registry = MetricRegistry() + register_speaker_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "speaker"}], + use_gt=True, + 
use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["spk_similarity"] == 0.75 + + +def test_singer_pipeline_with_registry_and_mocked_metric(monkeypatch): + class DummySingerModel: + pass + + monkeypatch.setattr( + "versa.utterance_metrics.singer.singer_model_setup", + lambda **kwargs: DummySingerModel(), + ) + monkeypatch.setattr( + "versa.utterance_metrics.singer.singer_metric", + lambda model, pred_x, gt_x, fs, target_sr=44100: {"singer_similarity": 0.5}, + ) + + gen_files, gt_files = _sample_files() + registry = MetricRegistry() + register_singer_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "singer"}], + use_gt=True, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["singer_similarity"] == 0.5 + + +def test_log_wmse_pipeline_with_registry_and_mocked_model(monkeypatch): + class DummyLogWMSE: + def __init__(self, **kwargs): + pass + + def __call__(self, unproc_x, proc_x, gt_x): + return torch.tensor([0.33]) + + monkeypatch.setattr("versa.utterance_metrics.log_wmse.LogWMSE", DummyLogWMSE) + + gen_files, gt_files = _sample_files() + registry = MetricRegistry() + register_log_wmse_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "log_wmse"}], + use_gt=True, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["log_wmse"] == pytest.approx(0.33) + + +def test_pseudo_mos_pipeline_with_registry_and_mocked_metric(monkeypatch): + monkeypatch.setattr( + "versa.utterance_metrics.pseudo_mos.pseudo_mos_setup", + lambda predictor_types, predictor_args, cache_dir, use_gpu: ( + {"utmos": object()}, + {"utmos": 16000}, + ), + ) + monkeypatch.setattr( + "versa.utterance_metrics.pseudo_mos.pseudo_mos_metric", + lambda pred, fs, predictor_dict, predictor_fs, use_gpu=False: {"utmos": 4.2}, + ) + + gen_files, _ = _sample_files() + registry = MetricRegistry() + register_pseudo_mos_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "pseudo_mos", "predictor_types": ["utmos"]}], + use_gt=False, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["utmos"] == 4.2 + + +def test_universa_pipeline_with_registry_and_mocked_metric(monkeypatch): + def dummy_universa_metric( + audio_data, ref_audio=None, ref_text=None, original_sr=16000, ref_sr=None + ): + return {"universa_mos": 3.5} + + monkeypatch.setattr( + "versa.utterance_metrics.universa.universa_metric", + dummy_universa_metric, + ) + + gen_files, gt_files = _sample_files() + registry = MetricRegistry() + register_universa_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "universa"}], + use_gt=True, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["universa_mos"] == 3.5 + + +def test_qwen_metric_pipelines_with_registry_and_mocked_models(monkeypatch): + 
monkeypatch.setattr( + "versa.utterance_metrics.qwen2_audio.qwen2_model_setup", + lambda **kwargs: {"model": "qwen2"}, + ) + monkeypatch.setattr( + "versa.utterance_metrics.qwen2_audio.qwen2_base_metric", + lambda qwen_utils, pred_x, fs=16000, custom_prompt=None, max_length=1000: ( + "young adult" + ), + ) + monkeypatch.setattr( + "versa.utterance_metrics.qwen_omni.qwen_omni_model_setup", + lambda **kwargs: {"model": "omni"}, + ) + monkeypatch.setattr( + "versa.utterance_metrics.qwen_omni.qwen_omni_base_metric", + lambda qwen_utils, pred_x, fs=16000, custom_prompt=None, max_length=500: ( + "happy" + ), + ) + + gen_files, _ = _sample_files() + registry = MetricRegistry() + register_qwen2_audio_metric(registry) + register_qwen_omni_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [ + {"name": "qwen2_audio_speaker_age"}, + {"name": "qwen_omni_speech_emotion"}, + ], + use_gt=False, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["qwen_speaker_age"] == "young adult" + assert score_info[0]["qwen_omni_speech_emotion"] == "happy" + + +def test_squim_pipeline_with_registry_and_mocked_models(monkeypatch): + class DummyObjectiveBundle: + @staticmethod + def get_model(): + return lambda pred_x: ( + torch.tensor([0.6]), + torch.tensor([1.2]), + torch.tensor([-3.4]), + ) + + class DummySubjectiveBundle: + @staticmethod + def get_model(): + return lambda pred_x, ref_x: torch.tensor([4.2]) + + monkeypatch.setattr("versa.utterance_metrics.squim.SQUIM_AVAILABLE", True) + monkeypatch.setattr( + "versa.utterance_metrics.squim.SQUIM_OBJECTIVE", DummyObjectiveBundle + ) + monkeypatch.setattr( + "versa.utterance_metrics.squim.SQUIM_SUBJECTIVE", DummySubjectiveBundle + ) + + gen_files, gt_files = _sample_files() + registry = MetricRegistry() + register_squim_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "squim_no_ref"}, {"name": "squim_ref"}], + use_gt=True, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["torch_squim_stoi"] == 0.6 + assert score_info[0]["torch_squim_pesq"] == 1.2 + assert score_info[0]["torch_squim_si_sdr"] == -3.4 + assert score_info[0]["torch_squim_mos"] == 4.2 + + +def test_vad_pipeline_with_registry_and_mocked_model(monkeypatch): + def dummy_get_speech_ts(pred_x, model, **kwargs): + return [{"start": 0.1, "end": 0.2}] + + monkeypatch.setattr( + "versa.utterance_metrics.vad.torch.hub.load", + lambda **kwargs: ("dummy-model", (dummy_get_speech_ts, None, None, None)), + ) + + gen_files, _ = _sample_files() + registry = MetricRegistry() + register_vad_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "vad"}], + use_gt=False, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["vad_info"] == [{"start": 0.1, "end": 0.2}] + + +def test_sheet_ssqa_pipeline_with_registry_and_mocked_model(monkeypatch): + class DummyInnerModel: + def to(self, device): + return self + + class DummySheetModel: + def __init__(self): + self.model = DummyInnerModel() + + def predict(self, wav): + return 3.25 + + monkeypatch.setattr( + 
"versa.utterance_metrics.sheet_ssqa.torch.hub.load", + lambda *args, **kwargs: DummySheetModel(), + ) + + gen_files, _ = _sample_files() + registry = MetricRegistry() + register_sheet_ssqa_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "sheet_ssqa"}], + use_gt=False, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["sheet_ssqa"] == 3.25 + + +def test_scoreq_pipeline_with_registry_and_mocked_model(monkeypatch): + class DummyScoreq: + def __init__(self, data_domain, mode, cache_dir, device): + self.mode = mode + + def predict(self, test_path, ref_path): + if self.mode == "ref": + return 1.2 + return 2.4 + + monkeypatch.setattr("versa.utterance_metrics.scoreq.Scoreq", DummyScoreq) + + gen_files, gt_files = _sample_files() + registry = MetricRegistry() + register_scoreq_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "scoreq_nr"}, {"name": "scoreq_ref"}], + use_gt=True, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["scoreq_nr"] == 2.4 + assert score_info[0]["scoreq_ref"] == 1.2 + + +def test_vqscore_pipeline_with_registry_and_mocked_model(monkeypatch): + class DummyVqscoreModel: + device = "cpu" + input_transform = "none" + + def CNN_1D_encoder(self, sp_input): + return torch.ones((1, 2, 3)) + + def quantizer(self, z, stochastic=False, update=False): + return z.transpose(2, 1), None, None, None + + monkeypatch.setattr( + "versa.utterance_metrics.vqscore.vqscore_setup", + lambda use_gpu=False: DummyVqscoreModel(), + ) + + gen_files, _ = _sample_files() + registry = MetricRegistry() + register_vqscore_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "vqscore"}], + use_gt=False, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["vqscore"] == pytest.approx(1.0, abs=1e-4) + + +def test_visqol_pipeline_with_registry_and_mocked_model(monkeypatch): + class DummySimilarityResult: + moslqo = 4.1 + + class DummyApi: + def Measure(self, gt_x, pred_x): + return DummySimilarityResult() + + monkeypatch.setattr( + "versa.utterance_metrics.visqol_score.visqol_setup", + lambda model: (DummyApi(), 16000), + ) + + gen_files, gt_files = _sample_files() + registry = MetricRegistry() + register_visqol_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "visqol", "model": "speech"}], + use_gt=True, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["visqol"] == 4.1 + + +def test_wer_pipeline_with_registry_and_mocked_models(monkeypatch): + monkeypatch.setattr( + "versa.corpus_metrics.espnet_wer.espnet_wer_setup", + lambda **kwargs: {"model": "espnet"}, + ) + monkeypatch.setattr( + "versa.corpus_metrics.owsm_wer.owsm_wer_setup", + lambda **kwargs: {"model": "owsm"}, + ) + monkeypatch.setattr( + "versa.corpus_metrics.whisper_wer.whisper_wer_setup", + lambda **kwargs: {"model": "whisper"}, + ) + monkeypatch.setattr( + 
"versa.corpus_metrics.espnet_wer.espnet_levenshtein_metric", + lambda wer_utils, pred_x, ref_text, fs=16000: { + "espnet_hyp_text": ref_text, + "espnet_wer_equal": 1, + }, + ) + monkeypatch.setattr( + "versa.corpus_metrics.owsm_wer.owsm_levenshtein_metric", + lambda wer_utils, pred_x, ref_text, fs=16000: { + "owsm_hyp_text": ref_text, + "owsm_wer_equal": 1, + }, + ) + monkeypatch.setattr( + "versa.corpus_metrics.whisper_wer.whisper_levenshtein_metric", + lambda wer_utils, pred_x, ref_text, fs=16000, cache_pred_text=None: { + "whisper_hyp_text": cache_pred_text or ref_text, + "whisper_wer_equal": 1, + }, + ) + + gen_files, _ = _sample_files() + text_info = {key: "hello world" for key in gen_files} + registry = MetricRegistry() + register_espnet_wer_metric(registry) + register_owsm_wer_metric(registry) + register_whisper_wer_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( + [{"name": "espnet_wer"}, {"name": "owsm_wer"}, {"name": "whisper_wer"}], + use_gt=False, + use_gt_text=True, + use_gpu=False, + ) + + score_info = scorer.score_utterances( + gen_files, + metric_suite, + text_info=text_info, + output_file=None, + io="soundfile", + ) + + assert score_info + assert score_info[0]["espnet_hyp_text"] == "hello world" + assert score_info[0]["owsm_hyp_text"] == "hello world" + assert score_info[0]["whisper_hyp_text"] == "hello world" diff --git a/test/test_pipeline/test_cdpam_distance.py b/test/test_pipeline/test_cdpam_distance.py new file mode 100644 index 0000000..b18b3a0 --- /dev/null +++ b/test/test_pipeline/test_cdpam_distance.py @@ -0,0 +1,71 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.cdpam_distance import register_cdpam_distance_metric + +TEST_INFO = { + "cdpam_distance": 0.039433546364307404, +} + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/cdpam_distance.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register CDPAM distance metric + registry = MetricRegistry() + register_cdpam_distance_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=gt_files, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + # for sir" + continue + # the plc mos is undeterministic + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git 
a/test/test_pipeline/test_chroma_alignment.py b/test/test_pipeline/test_chroma_alignment.py new file mode 100644 index 0000000..2ee3743 --- /dev/null +++ b/test/test_pipeline/test_chroma_alignment.py @@ -0,0 +1,79 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.chroma_alignment import register_chroma_alignment_metric + +TEST_INFO = { + "chroma_stft_cosine_dtw": 0.8895886718828439, + "chroma_stft_euclidean_dtw": 45.091055545199, + "chroma_cqt_cosine_dtw": 1.1888872845493323, + "chroma_cqt_euclidean_dtw": 56.16051355647546, + "chroma_cens_cosine_dtw": 0.6962623125421354, + "chroma_cens_euclidean_dtw": 38.38994047138499, + "chroma_stft_cosine_dtw_raw": 8.895886718828438, + "chroma_stft_cosine_dtw_log": 20.508107511971517, +} + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + gt_files = None + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/chroma_alignment.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register Chroma Alignment metric + registry = MetricRegistry() + register_chroma_alignment_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + # for SIR, which can be inf + continue + # the PLC MOS metric is non-deterministic + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_discrete_speech.py b/test/test_pipeline/test_discrete_speech.py new file mode 100644 index 0000000..9d4357e --- /dev/null +++ b/test/test_pipeline/test_discrete_speech.py @@ -0,0 +1,74 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.discrete_speech import register_discrete_speech_metric + +TEST_INFO = { + "speech_bert": 0.9727544784545898, + "speech_bleu": 0.6699938983346256, + "speech_token_distance": 0.850506056080969, +} + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + gt_files = None + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with 
open("egs/separate_metrics/discrete_speech.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register Discrete Speech metric + registry = MetricRegistry() + register_discrete_speech_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + # for sir" + continue + # the plc mos is undeterministic + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_dpam_distance.py b/test/test_pipeline/test_dpam_distance.py new file mode 100644 index 0000000..1418d19 --- /dev/null +++ b/test/test_pipeline/test_dpam_distance.py @@ -0,0 +1,71 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.dpam_distance import register_dpam_distance_metric + +TEST_INFO = { + "dpam_distance": 0.4179654121398926, +} + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/dpam_distance.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register DPAM distance metric + registry = MetricRegistry() + register_dpam_distance_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=gt_files, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + # for sir" + continue + # the plc mos is undeterministic + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_emo_similarity.py b/test/test_pipeline/test_emo_similarity.py index afa5cf5..fac8ba8 100755 --- a/test/test_pipeline/test_emo_similarity.py +++ 
b/test/test_pipeline/test_emo_similarity.py @@ -4,12 +4,10 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.emo_similarity import register_emo2vec_metric TEST_INFO = { "emotion_similarity": 0.9984976053237915, @@ -31,7 +29,15 @@ def info_update(): with open("egs/separate_metrics/emo_similarity.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register Emotion metric + registry = MetricRegistry() + register_emo2vec_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,11 +45,13 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): diff --git a/test/test_pipeline/test_emo_vad.py b/test/test_pipeline/test_emo_vad.py new file mode 100644 index 0000000..45a4641 --- /dev/null +++ b/test/test_pipeline/test_emo_vad.py @@ -0,0 +1,69 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.emo_vad import register_emo_vad_metric + +TEST_INFO = { + "arousal_emo_vad": 0.663333535194397, + "valence_emo_vad": 0.5060539245605469, + "dominance_emo_vad": 0.6355133056640625, +} + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/emo_vad.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register EmoVad metric + registry = MetricRegistry() + register_emo_vad_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=False, + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=None, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + # for SIR, which can be inf + continue + # the PLC MOS metric is non-deterministic + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + raise ValueError( + "Value issue in the test case, might be some issue in scorer 
{}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_lid.py b/test/test_pipeline/test_lid.py deleted file mode 100755 index 83ec124..0000000 --- a/test/test_pipeline/test_lid.py +++ /dev/null @@ -1,52 +0,0 @@ -import logging -import math -import os - -import yaml - -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) - - -def info_update(): - - # find files - if os.path.isdir("test/test_samples/test2"): - gen_files = find_files("test/test_samples/test2") - - # find reference file - if os.path.isdir("test/test_samples/test1"): - gt_files = find_files("test/test_samples/test1") - - logging.info("The number of utterances = %d" % len(gen_files)) - - with open("egs/separate_metrics/lid.yaml", "r", encoding="utf-8") as f: - score_config = yaml.full_load(f) - - score_modules = load_score_modules( - score_config, - use_gt=(True if gt_files is not None else False), - use_gpu=False, - ) - - assert len(score_config) > 0, "no scoring function is provided" - - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" - ) - print("Summary: {}".format((score_info), flush=True)) - - if abs(score_info[0]["language"][0][1] - 0.8865218162536621) > 1e-4: - raise ValueError( - "Value issue in the test case, might be some issue in scorer lanugage" - ) - - print("check successful", flush=True) - - -if __name__ == "__main__": - info_update() diff --git a/test/test_pipeline/test_nisqa.py b/test/test_pipeline/test_nisqa.py index 8e212d4..ba6b92c 100755 --- a/test/test_pipeline/test_nisqa.py +++ b/test/test_pipeline/test_nisqa.py @@ -1,32 +1,37 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Test pipeline for NISQA metric using the VersaScorer API.""" + import logging import math import os import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.nisqa import register_nisqa_metric -TEST_INFO = { - "nisqa_mos_pred": 0.4359583258628845, - "nisqa_noi_pred": 1.5543216466903687, - "nisqa_dis_pred": 2.293217182159424, - "nisqa_col_pred": 1.059649109840393, - "nisqa_loud_pred": 1.2060534954071045, -} +TEST_INFO = { + "nisqa_mos_pred": 0.6706185, + "nisqa_noi_pred": 1.1474215, + "nisqa_dis_pred": 1.8560395, + "nisqa_col_pred": 0.90170276, + "nisqa_loud_pred": 1.4682994, +} def info_update(): - # find files if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -35,7 +40,15 @@ def info_update(): with open("egs/separate_metrics/nisqa.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register NISQA metric + registry = MetricRegistry() + register_nisqa_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -43,14 +56,18 @@ def info_update(): assert len(score_config) > 0, 
"no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: - if abs(TEST_INFO[key] - summary[key]) > 1e-4: + if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): + continue + if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( key diff --git a/test/test_pipeline/test_nomad.py b/test/test_pipeline/test_nomad.py index 20e741c..838d859 100755 --- a/test/test_pipeline/test_nomad.py +++ b/test/test_pipeline/test_nomad.py @@ -1,26 +1,31 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Test pipeline for NOMAD metric using the VersaScorer API.""" + import logging import math import os import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.nomad import register_nomad_metric TEST_INFO = {"nomad": 0.0336} def info_update(): - # find files if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -29,7 +34,15 @@ def info_update(): with open("egs/separate_metrics/nomad.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register NOMAD metric + registry = MetricRegistry() + register_nomad_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -37,17 +50,17 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" continue - # the plc mos is undeterministic if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( diff --git a/test/test_pipeline/test_noresqa.py b/test/test_pipeline/test_noresqa.py index a3ebb37..8b78719 100755 --- a/test/test_pipeline/test_noresqa.py +++ b/test/test_pipeline/test_noresqa.py @@ -1,28 +1,31 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 
+ +"""Test pipeline for NORESQA metric using the VersaScorer API.""" + import logging import math import os import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.noresqa import register_noresqa_metric -TEST_INFO = { - "noresqa": 12.010879979211092, # need to be updated -} +TEST_INFO = {"noresqa_mos": 1.051746129989624} def info_update(): - # find files if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -31,7 +34,15 @@ def info_update(): with open("egs/separate_metrics/noresqa.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register NORESQA metric + registry = MetricRegistry() + register_noresqa_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,17 +50,17 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" continue - # the plc mos is undeterministic if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( diff --git a/test/test_pipeline/test_owsm_lid.py b/test/test_pipeline/test_owsm_lid.py new file mode 100755 index 0000000..0588099 --- /dev/null +++ b/test/test_pipeline/test_owsm_lid.py @@ -0,0 +1,64 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.owsm_lid import register_owsm_lid_metric + +TEST_INFO = {"language": 0.8865218162536621} + + +def info_update(): + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + # find reference file + gt_files = None + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/lid.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register OWSM LID metric + registry = MetricRegistry() + register_owsm_lid_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else 
False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" + ) + + print("Scorer score_info: {}".format(score_info)) + + best_hyper = score_info[0]["language"][0][1] + if abs(best_hyper - TEST_INFO["language"]) > 1e-4: + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + "language" + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_pam.py b/test/test_pipeline/test_pam.py index 315a939..a3dd9f7 100755 --- a/test/test_pipeline/test_pam.py +++ b/test/test_pipeline/test_pam.py @@ -4,46 +4,56 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.pam import register_pam_metric -TEST_INFO = {"pam_score": 0.01386283989995718} +TEST_INFO = {"pam_score": 0.3535262942314148} def info_update(): - # find files if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") + # find reference file + gt_files = None + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + logging.info("The number of utterances = %d" % len(gen_files)) with open("egs/separate_metrics/pam.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register PAM metric + registry = MetricRegistry() + register_pam_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, - use_gt=False, + use_gt=(True if gt_files is not None else False), use_gpu=False, ) assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" continue - # the plc mos is undeterministic if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( diff --git a/test/test_pipeline/test_pesq_score.py b/test/test_pipeline/test_pesq_score.py new file mode 100644 index 0000000..f658f7d --- /dev/null +++ b/test/test_pipeline/test_pesq_score.py @@ -0,0 +1,65 @@ +import logging +import math +import os + +import yaml + +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.pesq_score import register_pesq_metric + +TEST_INFO = {"pesq": 1.5722705125808716} # Expected PESQ score for test audio + + +def info_update(): + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = 
find_files("test/test_samples/test2") + + # find reference file + gt_files = None + if os.path.isdir("test/test_samples/test1"): + gt_files = find_files("test/test_samples/test1") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/pesq.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + # Create registry and register PESQ metric + registry = MetricRegistry() + register_pesq_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( + score_config, + use_gt=(True if gt_files is not None else False), + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" + ) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) + + for key in summary: + if abs(TEST_INFO[key] - summary[key]) > 1e-4: + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + key + ) + ) + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/test/test_pipeline/test_scoreq.py b/test/test_pipeline/test_scoreq.py index 11ff633..feec3d8 100755 --- a/test/test_pipeline/test_scoreq.py +++ b/test/test_pipeline/test_scoreq.py @@ -1,61 +1,76 @@ -import logging +import importlib.util import math import os +import pytest import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.definition import MetricRegistry +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.utterance_metrics.scoreq import register_scoreq_metric TEST_INFO = {"scoreq_ref": 1.0068472623825073, "scoreq_nr": 1.7731} +RUN_REAL_MODEL_TESTS = os.environ.get("VERSA_RUN_REAL_MODEL_TESTS") == "1" -def info_update(): +def _load_scoreq_config(): + with open("egs/separate_metrics/scoreq.yaml", "r", encoding="utf-8") as f: + return yaml.full_load(f) - # find files - if os.path.isdir("test/test_samples/test2"): - gen_files = find_files("test/test_samples/test2") - # find reference file - if os.path.isdir("test/test_samples/test1"): - gt_files = find_files("test/test_samples/test1") +def _sample_files(): + gen_path = "test/test_samples/test2" + gt_path = "test/test_samples/test1" + if not os.path.isdir(gen_path) or not os.path.isdir(gt_path): + pytest.skip("Required test sample directories are not available") + return find_files(gen_path), find_files(gt_path) - logging.info("The number of utterances = %d" % len(gen_files)) - with open("egs/separate_metrics/scoreq.yaml", "r", encoding="utf-8") as f: - score_config = yaml.full_load(f) +def _scoreq_is_available(): + return importlib.util.find_spec("scoreq_versa") is not None - score_modules = load_score_modules( - score_config, - use_gt=(True if gt_files is not None else False), - use_gpu=False, - ) + +@pytest.mark.real_model +@pytest.mark.skipif( + not RUN_REAL_MODEL_TESTS, + reason="Set VERSA_RUN_REAL_MODEL_TESTS=1 to run real model-backed checks", +) +@pytest.mark.skipif( + not _scoreq_is_available(), + reason="scoreq_versa is not installed; run tools/install_scoreq.sh first", +) +def test_scoreq_pipeline_with_real_model(): + """Run ScoreQ through the registry/scorer path with real model inference.""" + 
gen_files, gt_files = _sample_files() + score_config = _load_scoreq_config() + + registry = MetricRegistry() + register_scoreq_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics(score_config, use_gt=True, use_gpu=False) assert len(score_config) > 0, "no scoring function is provided" + assert set(metric_suite.metrics) == {"scoreq_ref", "scoreq_nr"} - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + score_info = scorer.score_utterances( + gen_files, + metric_suite, + gt_files=gt_files, + output_file=None, + io="soundfile", ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) - - for key in summary: - if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" - continue - # the plc mos is undeterministic - if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": - raise ValueError( - "Value issue in the test case, might be some issue in scorer {}".format( - key - ) - ) - print("check successful", flush=True) + summary = compute_summary(score_info) + + for key, expected_value in TEST_INFO.items(): + assert key in summary + if math.isinf(expected_value): + assert math.isinf(summary[key]) + else: + assert summary[key] == pytest.approx(expected_value, abs=1e-4) if __name__ == "__main__": - info_update() + if not RUN_REAL_MODEL_TESTS: + raise SystemExit("Set VERSA_RUN_REAL_MODEL_TESTS=1 to run this check") + pytest.main([__file__, "-q", "-s"]) diff --git a/test/test_pipeline/test_sigmos.py b/test/test_pipeline/test_sigmos.py index 4f569e7..2735b61 100644 --- a/test/test_pipeline/test_sigmos.py +++ b/test/test_pipeline/test_sigmos.py @@ -1,15 +1,12 @@ import logging -import math import os import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.definition import MetricRegistry +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.utterance_metrics.sigmos import register_sigmos_metric TEST_INFO = { "SIGMOS_COL": 1.3242647647857666, @@ -26,34 +23,25 @@ def info_update(): if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") - # find reference file - if os.path.isdir("test/test_samples/test1"): - gt_files = find_files("test/test_samples/test1") - logging.info("The number of utterances = %d" % len(gen_files)) with open("egs/separate_metrics/sigmos.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( - score_config, - use_gt=(True if gt_files is not None else False), - use_gpu=False, - ) + registry = MetricRegistry() + register_sigmos_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics(score_config, use_gt=False, use_gpu=False) assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=None, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: - if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" - continue - # the plc mos is undeterministic if abs(TEST_INFO[key] - summary[key]) > 1e-4 and 
key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( diff --git a/test/test_pipeline/test_speaking_rate.py b/test/test_pipeline/test_speaking_rate.py index 52d2143..2a67f19 100755 --- a/test/test_pipeline/test_speaking_rate.py +++ b/test/test_pipeline/test_speaking_rate.py @@ -1,14 +1,14 @@ -import logging -import math import os +import pytest import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, +from versa.definition import MetricRegistry +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.utterance_metrics.speaking_rate import ( + SpeakingRateMetric, + register_speaking_rate_metric, ) TEST_INFO = { @@ -16,7 +16,54 @@ } -def info_update(): +class DummyWhisperModel: + def transcribe(self, audio, beam_size=5): + return {"text": "one two three four"} + + +class DummyWhisper: + @staticmethod + def load_model(model_tag, device="cpu"): + return DummyWhisperModel() + + +class DummyTextCleaner: + def __init__(self, cleaner): + self.cleaner = cleaner + + +@pytest.fixture() +def mocked_speaking_rate_dependencies(monkeypatch): + monkeypatch.setattr("versa.utterance_metrics.speaking_rate.whisper", DummyWhisper) + monkeypatch.setattr( + "versa.utterance_metrics.speaking_rate.TextCleaner", DummyTextCleaner + ) + + +def test_speaking_rate_metric_class_uses_existing_keys( + mocked_speaking_rate_dependencies, +): + metric = SpeakingRateMetric({"use_gpu": False}) + scores = metric.compute([0.0] * 16000, metadata={"sample_rate": 16000}) + + assert scores == { + "speaking_rate": pytest.approx(4.0), + "whisper_hyp_text": "one two three four", + } + + +def test_speaking_rate_registration(): + registry = MetricRegistry() + + register_speaking_rate_metric(registry) + + assert registry.get_metric("speaking_rate") is SpeakingRateMetric + assert registry.get_metric("speaking_rate_metric") is SpeakingRateMetric + assert registry.get_metric("swr") is SpeakingRateMetric + assert registry.get_metadata("speaking_rate").requires_reference is False + + +def info_update(mocked_speaking_rate_dependencies=None): # find files if os.path.isdir("test/test_samples/test2"): @@ -26,12 +73,13 @@ def info_update(): if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") - logging.info("The number of utterances = %d" % len(gen_files)) - with open("egs/separate_metrics/speaking_rate.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + registry = MetricRegistry() + register_speaking_rate_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,25 +87,20 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: - if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" - continue - # the plc mos is undeterministic - if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != 
"plcmos": - raise ValueError( - "Value issue in the test case, might be some issue in scorer {}".format( - key - ) - ) + assert summary[key] == pytest.approx(TEST_INFO[key], abs=1e-4) print("check successful", flush=True) +def test_speaking_rate_pipeline_with_registry(mocked_speaking_rate_dependencies): + info_update(mocked_speaking_rate_dependencies) + + if __name__ == "__main__": info_update() diff --git a/test/test_pipeline/test_srmr.py b/test/test_pipeline/test_srmr.py index b97866e..184e39c 100755 --- a/test/test_pipeline/test_srmr.py +++ b/test/test_pipeline/test_srmr.py @@ -4,16 +4,12 @@ import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.definition import MetricRegistry +from versa.utterance_metrics.srmr import register_srmr_metric -TEST_INFO = { - "srmr": 0.6123816687905584, -} +TEST_INFO = {"srmr": 0.6123816687905584} def info_update(): @@ -23,6 +19,7 @@ def info_update(): gen_files = find_files("test/test_samples/test2") # find reference file + gt_files = None if os.path.isdir("test/test_samples/test1"): gt_files = find_files("test/test_samples/test1") @@ -31,7 +28,15 @@ def info_update(): with open("egs/separate_metrics/srmr.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( + # Create registry and register SRMR metric + registry = MetricRegistry() + register_srmr_metric(registry) + + # Initialize VersaScorer with the populated registry + scorer = VersaScorer(registry) + + # Load metrics using the new API + metric_suite = scorer.load_metrics( score_config, use_gt=(True if gt_files is not None else False), use_gpu=False, @@ -39,18 +44,16 @@ def info_update(): assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + # Score utterances using the new API + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: - if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" - continue - # the plc mos is undeterministic - if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": + if abs(TEST_INFO[key] - summary[key]) > 1e-4: raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( key diff --git a/test/test_pipeline/test_wvmos.py b/test/test_pipeline/test_wvmos.py index 6eed09f..44c87fe 100644 --- a/test/test_pipeline/test_wvmos.py +++ b/test/test_pipeline/test_wvmos.py @@ -1,15 +1,12 @@ import logging -import math import os import yaml -from versa.scorer_shared import ( - find_files, - list_scoring, - load_score_modules, - load_summary, -) +from versa.definition import MetricRegistry +from versa.scorer_shared import VersaScorer, compute_summary +from versa.utils_shared import find_files +from versa.utterance_metrics.wvmos import register_wvmos_metric TEST_INFO = {"wvmos": 0.621284008026123} @@ -19,34 +16,25 @@ def info_update(): if os.path.isdir("test/test_samples/test2"): gen_files = find_files("test/test_samples/test2") - # find reference file - if os.path.isdir("test/test_samples/test1"): - gt_files = 
find_files("test/test_samples/test1") - logging.info("The number of utterances = %d" % len(gen_files)) with open("egs/separate_metrics/wvmos.yaml", "r", encoding="utf-8") as f: score_config = yaml.full_load(f) - score_modules = load_score_modules( - score_config, - use_gt=(True if gt_files is not None else False), - use_gpu=False, - ) + registry = MetricRegistry() + register_wvmos_metric(registry) + scorer = VersaScorer(registry) + metric_suite = scorer.load_metrics(score_config, use_gt=False, use_gpu=False) assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( - gen_files, score_modules, gt_files, output_file=None, io="soundfile" + score_info = scorer.score_utterances( + gen_files, metric_suite, gt_files=None, output_file=None, io="soundfile" ) - summary = load_summary(score_info) - print("Summary: {}".format(load_summary(score_info)), flush=True) + summary = compute_summary(score_info) + print("Summary: {}".format(summary), flush=True) for key in summary: - if math.isinf(TEST_INFO[key]) and math.isinf(summary[key]): - # for sir" - continue - # the plc mos is undeterministic if abs(TEST_INFO[key] - summary[key]) > 1e-4 and key != "plcmos": raise ValueError( "Value issue in the test case, might be some issue in scorer {}".format( diff --git a/tools/easy_install.sh b/tools/easy_install.sh index 02e05f0..c72398e 100755 --- a/tools/easy_install.sh +++ b/tools/easy_install.sh @@ -8,6 +8,8 @@ set -eou pipefail . ./install_utmosv2.sh || echo "error in utmosv2" # . ./install_warpq.sh . ./install_scoreq.sh || echo "error in scoreq" +. ./install_log_wmse.sh || echo "error in log_wmse" +. ./install_vqscore.sh || echo "error in vqscore" . ./install_nomad.sh || echo "error in nomad" . ./install_asvspoof.sh || echo "error in asvspoof" . ./install_pysepm.sh || echo "error in pysepm" diff --git a/tools/install_asvspoof.sh b/tools/install_asvspoof.sh index bc92c90..f7640a1 100755 --- a/tools/install_asvspoof.sh +++ b/tools/install_asvspoof.sh @@ -1,4 +1,9 @@ #!/bin/bash +set -e + +cd "$(dirname "$0")" + ## cloning the AASIST repo into the checkpoint folder +mkdir -p checkpoints git clone https://github.com/clovaai/aasist.git checkpoints/aasist diff --git a/tools/install_fairseq.sh b/tools/install_fairseq.sh index 005f3aa..92c4df4 100755 --- a/tools/install_fairseq.sh +++ b/tools/install_fairseq.sh @@ -1,11 +1,18 @@ #!/bin/bash +set -euo pipefail + +PYTHON_BIN="${PYTHON:-python}" +if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then + echo "ERROR: Python executable '$PYTHON_BIN' not found. Activate your environment or set PYTHON=/path/to/python." + exit 1 +fi # Repository information REPO_OWNER="ftshijt" REPO_NAME="fairseq" REPO_PATH="$REPO_OWNER/$REPO_NAME" BRANCH="versa" -EXPECTED_COMMIT_ID="612be207e0afe60859ec393608ef89bba0e5246c" +EXPECTED_COMMIT_ID="7c814e9580e24f69bd6198b400ec12bc3f90fd51" # Old version: EXPECTED_COMMIT_ID="0e35caead74528f04e741986b78ff0b4b543dbe6" # Function to check if repository exists @@ -122,6 +129,26 @@ check_local_commit() { return 0 } +ensure_legacy_config_stubs() { + mkdir -p fairseq/config/criterion fairseq/config/task fairseq/config/optimizer fairseq/config/lr_scheduler + + if [ ! -f fairseq/config/criterion/cross_entropy.yaml ]; then + printf '# @package _group_\n\n_name: cross_entropy\n' > fairseq/config/criterion/cross_entropy.yaml + fi + if [ ! -f fairseq/config/task/audio_pretraining.yaml ]; then + printf '# @package _group_\n\n_name: audio_pretraining\n' > fairseq/config/task/audio_pretraining.yaml + fi + if [ ! 
-f fairseq/config/optimizer/adam.yaml ]; then + printf '# @package _group_\n\n_name: adam\n' > fairseq/config/optimizer/adam.yaml + fi + if [ ! -f fairseq/config/lr_scheduler/polynomial_decay.yaml ]; then + printf '# @package _group_\n\n_name: polynomial_decay\n' > fairseq/config/lr_scheduler/polynomial_decay.yaml + fi + if [ ! -f fairseq/config/lr_scheduler/fixed.yaml ]; then + printf '# @package _group_\n\n_name: fixed\n' > fairseq/config/lr_scheduler/fixed.yaml + fi +} + # Function to validate the specific commit after cloning (if a specific commit is required) validate_cloned_commit() { local expected_commit=$1 @@ -159,29 +186,30 @@ fi # Check if we already have the correct version installed if [ -n "$EXPECTED_COMMIT_ID" ] && check_local_commit "$EXPECTED_COMMIT_ID"; then echo "Local repository is already at the expected commit: $EXPECTED_COMMIT_ID" - echo "Skipping reinstallation." - exit 0 -fi - -# Clean up existing directory if it exists -if [ -d "fairseq" ]; then - echo "Removing existing fairseq directory..." - rm -rf fairseq -fi + cd fairseq +else + # Clean up existing directory if it exists + if [ -d "fairseq" ]; then + echo "Removing existing fairseq directory..." + rm -rf fairseq + fi -# Clone the repository -echo "Cloning repository: $REPO_PATH (branch: $BRANCH)..." -git clone -b "$BRANCH" "https://github.com/$REPO_PATH.git" -cd fairseq + # Clone the repository + echo "Cloning repository: $REPO_PATH (branch: $BRANCH)..." + git clone -b "$BRANCH" "https://github.com/$REPO_PATH.git" + cd fairseq -# Validate the commit if specified -if [ -n "$EXPECTED_COMMIT_ID" ]; then - validate_cloned_commit "$EXPECTED_COMMIT_ID" + # Validate the commit if specified + if [ -n "$EXPECTED_COMMIT_ID" ]; then + validate_cloned_commit "$EXPECTED_COMMIT_ID" + fi fi +ensure_legacy_config_stubs + # Install the package echo "Installing fairseq package..." -pip install -e . +"$PYTHON_BIN" -m pip install -e . cd .. echo "=== fairseq installation completed successfully ===" diff --git a/tools/install_log_wmse.sh b/tools/install_log_wmse.sh new file mode 100755 index 0000000..cbc438d --- /dev/null +++ b/tools/install_log_wmse.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -euo pipefail + +PYTHON_BIN="${PYTHON:-python}" +if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then + echo "ERROR: Python executable '$PYTHON_BIN' not found. Activate your environment or set PYTHON=/path/to/python." + exit 1 +fi + +"$PYTHON_BIN" -m pip install torch-log-wmse diff --git a/tools/install_noresqa.sh b/tools/install_noresqa.sh index 924d38f..1298ec6 100755 --- a/tools/install_noresqa.sh +++ b/tools/install_noresqa.sh @@ -1,12 +1,16 @@ #/bin/bash +set -e -rm -rf Noresqa +cd "$(dirname "$0")" -# # NOTE(hyejin): a versa-specialized implementation for Noresqa -git clone https://github.com/ftshijt/Noresqa.git - -wget https://github.com/facebookresearch/Noresqa/raw/refs/heads/main/models/model_noresqa_mos.pth -wget wget https://github.com/facebookresearch/Noresqa/raw/refs/heads/main/models/model_noresqa.pth -mv model_noresqa_mos.pth Noresqa/models/model_noresqa_mos.pth -mv model_noresqa.pth Noresqa/models/model_noresqa.pth +mkdir -p ../versa_cache/noresqa_model +if [ ! -f ../versa_cache/noresqa_model/model_noresqa_mos.pth ]; then + wget -O ../versa_cache/noresqa_model/model_noresqa_mos.pth https://github.com/facebookresearch/Noresqa/raw/refs/heads/main/models/model_noresqa_mos.pth +fi +if [ ! 
-f ../versa_cache/noresqa_model/model_noresqa.pth ]; then + wget -O ../versa_cache/noresqa_model/model_noresqa.pth https://github.com/facebookresearch/Noresqa/raw/refs/heads/main/models/model_noresqa.pth +fi +if [ ! -f ../versa_cache/noresqa_model/wav2vec_small.pt ]; then + wget -O ../versa_cache/noresqa_model/wav2vec_small.pt https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt +fi diff --git a/tools/install_scoreq.sh b/tools/install_scoreq.sh index e329953..09087b7 100755 --- a/tools/install_scoreq.sh +++ b/tools/install_scoreq.sh @@ -1,13 +1,33 @@ -#/bin/bash +#!/bin/bash +set -euo pipefail -if [ -d "scoreq" ]; then - rm -rf scoreq +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +PYTHON_BIN="${PYTHON:-python}" + +if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then + echo "ERROR: Python executable '$PYTHON_BIN' not found. Activate your environment or set PYTHON=/path/to/python." + exit 1 fi -./install_fairseq.sh || { echo "fairseq installation exit"; } +cd "$REPO_ROOT" + +"$SCRIPT_DIR/install_fairseq.sh" || { + echo "fairseq installation failed" + exit 1 +} # # NOTE(jiatong): a versa-specialized implementation for scoreq -git clone https://github.com/ftshijt/scoreq.git +if [ -d "scoreq/.git" ]; then + git -C scoreq fetch origin + git -C scoreq checkout main + git -C scoreq pull --ff-only origin main +elif [ -d "scoreq" ]; then + echo "ERROR: scoreq exists but is not a git checkout. Move it aside and retry." + exit 1 +else + git clone https://github.com/ftshijt/scoreq.git +fi cd scoreq -pip install -e . +"$PYTHON_BIN" -m pip install -e . --no-deps cd .. diff --git a/tools/install_srmr.sh b/tools/install_srmr.sh index 41fe6db..ca00b9a 100755 --- a/tools/install_srmr.sh +++ b/tools/install_srmr.sh @@ -1,11 +1,13 @@ #/bin/bash +set -e -rm -rf srmr +cd "$(dirname "$0")" + +rm -rf SRMRpy # # NOTE(hyejin): a versa-specialized implementation for pysepm git clone https://github.com/shimhz/SRMRpy.git cd SRMRpy pip install -e . -cd .. diff --git a/tools/install_ssl-singer-identity.sh b/tools/install_ssl-singer-identity.sh index b707eca..f824cba 100755 --- a/tools/install_ssl-singer-identity.sh +++ b/tools/install_ssl-singer-identity.sh @@ -1,12 +1,22 @@ -#/bin/bash +#!/bin/bash +set -euo pipefail -if [ -d "ssl-singer-identity" ]; then - rm -rf ssl-singer-identity +PYTHON_BIN="${PYTHON:-python}" +if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then + echo "ERROR: Python executable '$PYTHON_BIN' not found. Activate your environment or set PYTHON=/path/to/python." + exit 1 fi -# # NOTE(jiatong): a versa-specialized implementation for singer identity -git clone https://github.com/ftshijt/ssl-singer-identity.git -cd ssl-singer-identity -pip install -e . -cd .. +cd "$(dirname "$0")" + +tmpdir="$(mktemp -d)" +trap 'rm -rf "$tmpdir"' EXIT + +# NOTE(jiatong): a versa-specialized implementation for singer identity. 
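+# After cloning, the two "$PYTHON_BIN" -c calls below patch singer_identity/utils/fetch_pretrained.py in the fresh checkout: newer huggingface_hub releases take token= instead of the removed use_auth_token= argument, and the custom.py fallback should trigger on any fetch error, not only ValueError.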
+git clone --depth 1 https://github.com/ftshijt/ssl-singer-identity.git "$tmpdir/ssl-singer-identity" +cd "$tmpdir/ssl-singer-identity" +"$PYTHON_BIN" -c "from pathlib import Path; p=Path('singer_identity/utils/fetch_pretrained.py'); s=p.read_text(); old='''repo_id=source,\n filename=filename,\n use_auth_token=use_auth_token,'''; new='''repo_id=source,\n filename=filename,\n token=use_auth_token or None,'''; p.write_text(s.replace(old, new))" +"$PYTHON_BIN" -c "from pathlib import Path; p=Path('singer_identity/utils/fetch_pretrained.py'); s=p.read_text(); old='''except ValueError:\n if pymodule_file == \"custom.py\":'''; new='''except Exception:\n if pymodule_file == \"custom.py\":'''; p.write_text(s.replace(old, new))" +"$PYTHON_BIN" -m pip install . +"$PYTHON_BIN" -m pip install nnAudio torchvision diff --git a/tools/install_vqscore.sh b/tools/install_vqscore.sh new file mode 100755 index 0000000..0b713a1 --- /dev/null +++ b/tools/install_vqscore.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +VQSCORE_DIR="$REPO_ROOT/versa/utterance_metrics/VQscore" + +cd "$REPO_ROOT" + +if [ ! -f ".gitmodules" ]; then + echo "ERROR: VQScore is configured as a git submodule, but .gitmodules was not found." + exit 1 +fi + +git submodule update --init --recursive versa/utterance_metrics/VQscore + +if [ ! -f "$VQSCORE_DIR/models/VQVAE_models.py" ]; then + echo "ERROR: VQScore submodule initialized, but models/VQVAE_models.py is missing." + echo "Check the submodule checkout at $VQSCORE_DIR." + exit 1 +fi + +if [ ! -f "$VQSCORE_DIR/config/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github.yaml" ]; then + echo "ERROR: VQScore config is missing from $VQSCORE_DIR/config." + exit 1 +fi + +if [ ! -f "$VQSCORE_DIR/exp/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github/checkpoint-dnsmos_ovr_CC=0.835.pkl" ]; then + echo "ERROR: VQScore checkpoint is missing from $VQSCORE_DIR/exp." + echo "The upstream submodule must include or download checkpoint-dnsmos_ovr_CC=0.835.pkl before running vqscore." + exit 1 +fi + +echo "VQScore submodule and checkpoint are ready." diff --git a/tools/install_wvmos.sh b/tools/install_wvmos.sh index a4253d0..a382b2d 100644 --- a/tools/install_wvmos.sh +++ b/tools/install_wvmos.sh @@ -1,12 +1,22 @@ -#/bin/bash +#!/bin/bash -if [ -d "wvmos" ]; then - rm -rf wvmos -fi +set -euo pipefail -# # Clone and install wvmos -git clone https://github.com/AndreevP/wvmos.git -cd wvmos -pip install -e . -cd .. +cd "$(dirname "$0")" +PYTHON_BIN="${PYTHON:-python}" +tmpdir="$(mktemp -d)" +trap 'rm -rf "$tmpdir"' EXIT + +git clone --depth 1 https://github.com/AndreevP/wvmos.git "$tmpdir/wvmos" +cd "$tmpdir/wvmos" +"$PYTHON_BIN" -m pip install . 
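+ +# NOTE: the heredoc below warms the model caches at install time, downloading the WV-MOS checkpoint and the facebook/wav2vec2-base weights so the first scoring run can reuse them.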
+ +"$PYTHON_BIN" - <<'PY' +import wvmos +from transformers import Wav2Vec2Model, Wav2Vec2Processor + +wvmos.get_wvmos(cuda=False) +Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base") +Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base") +PY diff --git a/tools/pysepm/pysepm/intelligibilityMeasures.py b/tools/pysepm/pysepm/intelligibilityMeasures.py index 5aee96f..77ccf19 100755 --- a/tools/pysepm/pysepm/intelligibilityMeasures.py +++ b/tools/pysepm/pysepm/intelligibilityMeasures.py @@ -1,50 +1,86 @@ -from scipy.signal import stft,resample,butter,lfilter,hilbert +from scipy.signal import stft, resample, butter, lfilter, hilbert from scipy.interpolate import interp1d -from pystoi import stoi as pystoi # https://github.com/mpariente/pystoi +from pystoi import stoi as pystoi # https://github.com/mpariente/pystoi import numpy as np -from .util import extract_overlapped_windows,resample_matlab_like +from .util import extract_overlapped_windows, resample_matlab_like stoi = pystoi -def fwseg_noise(clean_speech, processed_speech,fs,frameLen=0.03, overlap=0.75): - + +def fwseg_noise(clean_speech, processed_speech, fs, frameLen=0.03, overlap=0.75): + clean_length = len(clean_speech) processed_length = len(processed_speech) - rms_all=np.linalg.norm(clean_speech)/np.sqrt(processed_length) - - winlength = round(frameLen*fs) #window length in samples - skiprate = int(np.floor((1-overlap)*frameLen*fs)) #window skip in samples - max_freq = fs/2 #maximum bandwidth - num_crit = 16 # number of critical bands - n_fft = int(2**np.ceil(np.log2(2*winlength))) - n_fftby2 = int(n_fft/2) - - cent_freq=np.zeros((num_crit,)) - bandwidth=np.zeros((num_crit,)) + rms_all = np.linalg.norm(clean_speech) / np.sqrt(processed_length) + + winlength = round(frameLen * fs) # window length in samples + skiprate = int(np.floor((1 - overlap) * frameLen * fs)) # window skip in samples + max_freq = fs / 2 # maximum bandwidth + num_crit = 16 # number of critical bands + n_fft = int(2 ** np.ceil(np.log2(2 * winlength))) + n_fftby2 = int(n_fft / 2) + + cent_freq = np.zeros((num_crit,)) + bandwidth = np.zeros((num_crit,)) # ---------------------------------------------------------------------- # Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz) # ---------------------------------------------------------------------- - cent_freq[0] = 150.0000; bandwidth[0] = 100.0000; - cent_freq[1] = 250.000; bandwidth[1] = 100.0000; - cent_freq[2] = 350.000; bandwidth[2] = 100.0000; - cent_freq[3] = 450.000; bandwidth[3] = 110.0000; - cent_freq[4] = 570.000; bandwidth[4] = 120.0000; - cent_freq[5] = 700.000; bandwidth[5] = 140.0000; - cent_freq[6] = 840.000; bandwidth[6] = 150.0000; - cent_freq[7] = 1000.000; bandwidth[7] = 160.000; - cent_freq[8] = 1170.000; bandwidth[8] = 190.000; - cent_freq[9] = 1370.000; bandwidth[9] = 210.000; - cent_freq[10] = 1600.000; bandwidth[10]= 240.000; - cent_freq[11] = 1850.000; bandwidth[11]= 280.000; - cent_freq[12] = 2150.000; bandwidth[12]= 320.000; - cent_freq[13] = 2500.000; bandwidth[13]= 380.000; - cent_freq[14] = 2900.000; bandwidth[14]= 450.000; - cent_freq[15] = 3400.000; bandwidth[15]= 550.000; - - Weight=np.array([0.0192,0.0312,0.0926,0.1031,0.0735,0.0611,0.0495,0.044,0.044,0.049,0.0486,0.0493, 0.049,0.0547,0.0555,0.0493]) - + cent_freq[0] = 150.0000 + bandwidth[0] = 100.0000 + cent_freq[1] = 250.000 + bandwidth[1] = 100.0000 + cent_freq[2] = 350.000 + bandwidth[2] = 100.0000 + cent_freq[3] = 450.000 + bandwidth[3] = 110.0000 + cent_freq[4] = 570.000 + bandwidth[4] = 120.0000 
+ cent_freq[5] = 700.000 + bandwidth[5] = 140.0000 + cent_freq[6] = 840.000 + bandwidth[6] = 150.0000 + cent_freq[7] = 1000.000 + bandwidth[7] = 160.000 + cent_freq[8] = 1170.000 + bandwidth[8] = 190.000 + cent_freq[9] = 1370.000 + bandwidth[9] = 210.000 + cent_freq[10] = 1600.000 + bandwidth[10] = 240.000 + cent_freq[11] = 1850.000 + bandwidth[11] = 280.000 + cent_freq[12] = 2150.000 + bandwidth[12] = 320.000 + cent_freq[13] = 2500.000 + bandwidth[13] = 380.000 + cent_freq[14] = 2900.000 + bandwidth[14] = 450.000 + cent_freq[15] = 3400.000 + bandwidth[15] = 550.000 + + Weight = np.array( + [ + 0.0192, + 0.0312, + 0.0926, + 0.1031, + 0.0735, + 0.0611, + 0.0495, + 0.044, + 0.044, + 0.049, + 0.0486, + 0.0493, + 0.049, + 0.0547, + 0.0555, + 0.0493, + ] + ) + # ---------------------------------------------------------------------- # Set up the critical band filters. Note here that Gaussianly shaped # filters are used. Also, the sum of the filter weights are equivalent @@ -52,228 +88,334 @@ def fwseg_noise(clean_speech, processed_speech,fs,frameLen=0.03, overlap=0.75): # zero. # ---------------------------------------------------------------------- - all_f0=np.zeros((num_crit,)) - crit_filter=np.zeros((num_crit,int(n_fftby2))) - g = np.zeros((num_crit,n_fftby2)) - - b = bandwidth; - q = cent_freq/1000; - p = 4*1000*q/b; # Eq. (7) - - #15.625=4000/256 - j = np.arange(0,n_fftby2) - + all_f0 = np.zeros((num_crit,)) + crit_filter = np.zeros((num_crit, int(n_fftby2))) + g = np.zeros((num_crit, n_fftby2)) + + b = bandwidth + q = cent_freq / 1000 + p = 4 * 1000 * q / b + # Eq. (7) + + # 15.625=4000/256 + j = np.arange(0, n_fftby2) + for i in range(num_crit): - g[i,:]=np.abs(1-j*(fs/n_fft)/(q[i]*1000));# Eq. (9) - crit_filter[i,:] = (1+p[i]*g[i,:])*np.exp(-p[i]*g[i,:]);# Eq. (8) - - num_frames = int(clean_length/skiprate-(winlength/skiprate)); # number of frames - start = 0 # starting sample - hannWin = 0.5*(1-np.cos(2*np.pi*np.arange(1,winlength+1)/(winlength+1))) - - f,t,clean_spec=stft(clean_speech[0:int(num_frames)*skiprate+int(winlength-skiprate)], fs=fs, window=hannWin, nperseg=winlength, noverlap=winlength-skiprate, nfft=n_fft, detrend=False, return_onesided=False, boundary=None, padded=False) - f,t,processed_spec=stft(processed_speech[0:int(num_frames)*skiprate+int(winlength-skiprate)], fs=fs, window=hannWin, nperseg=winlength, noverlap=winlength-skiprate, nfft=n_fft, detrend=False, return_onesided=False, boundary=None, padded=False) - - clean_frames = extract_overlapped_windows(clean_speech[0:int(num_frames)*skiprate+int(winlength-skiprate)],winlength,winlength-skiprate,None) - rms_seg = np.linalg.norm(clean_frames,axis=-1)/np.sqrt(winlength); - rms_db = 20*np.log10(rms_seg/rms_all); - #-------------------------------------------------------------- + g[i, :] = np.abs(1 - j * (fs / n_fft) / (q[i] * 1000)) + # Eq. (9) + crit_filter[i, :] = (1 + p[i] * g[i, :]) * np.exp(-p[i] * g[i, :]) + # Eq. 
(8) + + num_frames = int(clean_length / skiprate - (winlength / skiprate)) + # number of frames + start = 0 # starting sample + hannWin = 0.5 * ( + 1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)) + ) + + f, t, clean_spec = stft( + clean_speech[0 : int(num_frames) * skiprate + int(winlength - skiprate)], + fs=fs, + window=hannWin, + nperseg=winlength, + noverlap=winlength - skiprate, + nfft=n_fft, + detrend=False, + return_onesided=False, + boundary=None, + padded=False, + ) + f, t, processed_spec = stft( + processed_speech[0 : int(num_frames) * skiprate + int(winlength - skiprate)], + fs=fs, + window=hannWin, + nperseg=winlength, + noverlap=winlength - skiprate, + nfft=n_fft, + detrend=False, + return_onesided=False, + boundary=None, + padded=False, + ) + + clean_frames = extract_overlapped_windows( + clean_speech[0 : int(num_frames) * skiprate + int(winlength - skiprate)], + winlength, + winlength - skiprate, + None, + ) + rms_seg = np.linalg.norm(clean_frames, axis=-1) / np.sqrt(winlength) + rms_db = 20 * np.log10(rms_seg / rms_all) + # -------------------------------------------------------------- # cal r2_high,r2_middle,r2_low - highInd = np.where(rms_db>=0) + highInd = np.where(rms_db >= 0) highInd = highInd[0] - middleInd = np.where((rms_db>=-10) & (rms_db<0)) + middleInd = np.where((rms_db >= -10) & (rms_db < 0)) middleInd = middleInd[0] - lowInd = np.where(rms_db<-10) + lowInd = np.where(rms_db < -10) lowInd = lowInd[0] - - num_high = np.sum(clean_spec[0:n_fftby2,highInd]*np.conj(processed_spec[0:n_fftby2,highInd]),axis=-1) - denx_high = np.sum(np.abs(clean_spec[0:n_fftby2,highInd])**2,axis=-1) - deny_high = np.sum(np.abs(processed_spec[0:n_fftby2,highInd])**2,axis=-1); - - num_middle = np.sum(clean_spec[0:n_fftby2,middleInd]*np.conj(processed_spec[0:n_fftby2,middleInd]),axis=-1) - denx_middle = np.sum(np.abs(clean_spec[0:n_fftby2,middleInd])**2,axis=-1) - deny_middle = np.sum(np.abs(processed_spec[0:n_fftby2,middleInd])**2,axis=-1); - - num_low = np.sum(clean_spec[0:n_fftby2,lowInd]*np.conj(processed_spec[0:n_fftby2,lowInd]),axis=-1) - denx_low = np.sum(np.abs(clean_spec[0:n_fftby2,lowInd])**2,axis=-1) - deny_low = np.sum(np.abs(processed_spec[0:n_fftby2,lowInd])**2,axis=-1); - - num2_high = np.abs(num_high)**2; - r2_high = num2_high/(denx_high*deny_high); - - num2_middle = np.abs(num_middle)**2; - r2_middle = num2_middle/(denx_middle*deny_middle); - - num2_low = np.abs(num_low)**2; - r2_low = num2_low/(denx_low*deny_low); - #-------------------------------------------------------------- + + num_high = np.sum( + clean_spec[0:n_fftby2, highInd] * np.conj(processed_spec[0:n_fftby2, highInd]), + axis=-1, + ) + denx_high = np.sum(np.abs(clean_spec[0:n_fftby2, highInd]) ** 2, axis=-1) + deny_high = np.sum(np.abs(processed_spec[0:n_fftby2, highInd]) ** 2, axis=-1) + + num_middle = np.sum( + clean_spec[0:n_fftby2, middleInd] + * np.conj(processed_spec[0:n_fftby2, middleInd]), + axis=-1, + ) + denx_middle = np.sum(np.abs(clean_spec[0:n_fftby2, middleInd]) ** 2, axis=-1) + deny_middle = np.sum(np.abs(processed_spec[0:n_fftby2, middleInd]) ** 2, axis=-1) + + num_low = np.sum( + clean_spec[0:n_fftby2, lowInd] * np.conj(processed_spec[0:n_fftby2, lowInd]), + axis=-1, + ) + denx_low = np.sum(np.abs(clean_spec[0:n_fftby2, lowInd]) ** 2, axis=-1) + deny_low = np.sum(np.abs(processed_spec[0:n_fftby2, lowInd]) ** 2, axis=-1) + + num2_high = np.abs(num_high) ** 2 + r2_high = num2_high / (denx_high * deny_high) + + num2_middle = np.abs(num_middle) ** 2 + r2_middle = num2_middle 
/ (denx_middle * deny_middle) + + num2_low = np.abs(num_low) ** 2 + r2_low = num2_low / (denx_low * deny_low) + # -------------------------------------------------------------- # cal distortion frame by frame - - clean_spec = np.abs(clean_spec); - processed_spec = np.abs(processed_spec)**2; - - W_freq=Weight - - processed_energy = crit_filter.dot((processed_spec[0:n_fftby2,highInd].T*r2_high).T) - de_processed_energy= crit_filter.dot((processed_spec[0:n_fftby2,highInd].T*(1-r2_high)).T) - SDR = processed_energy/de_processed_energy;# Eq 13 in Kates (2005) - SDRlog=10*np.log10(SDR); - SDRlog_lim = SDRlog - SDRlog_lim[SDRlog_lim<-15]=-15 - SDRlog_lim[SDRlog_lim>15]=15 # limit between [-15, 15] - Tjm = (SDRlog_lim+15)/30; - distortionh = W_freq.dot(Tjm)/np.sum(W_freq,axis=0) - distortionh[distortionh<0]=0 - - - processed_energy = crit_filter.dot((processed_spec[0:n_fftby2,middleInd].T*r2_middle).T) - de_processed_energy= crit_filter.dot((processed_spec[0:n_fftby2,middleInd].T*(1-r2_middle)).T) - SDR = processed_energy/de_processed_energy;# Eq 13 in Kates (2005) - SDRlog=10*np.log10(SDR); - SDRlog_lim = SDRlog - SDRlog_lim[SDRlog_lim<-15]=-15 - SDRlog_lim[SDRlog_lim>15]=15 # limit between [-15, 15] - Tjm = (SDRlog_lim+15)/30; - distortionm = W_freq.dot(Tjm)/np.sum(W_freq,axis=0) - distortionm[distortionm<0]=0 - - processed_energy = crit_filter.dot((processed_spec[0:n_fftby2,lowInd].T*r2_low).T) - de_processed_energy= crit_filter.dot((processed_spec[0:n_fftby2,lowInd].T*(1-r2_low)).T) - SDR = processed_energy/de_processed_energy;# Eq 13 in Kates (2005) - SDRlog=10*np.log10(SDR); - SDRlog_lim = SDRlog - SDRlog_lim[SDRlog_lim<-15]=-15 - SDRlog_lim[SDRlog_lim>15]=15 # limit between [-15, 15] - Tjm = (SDRlog_lim+15)/30; - distortionl = W_freq.dot(Tjm)/np.sum(W_freq,axis=0) - distortionl[distortionl<0]=0 - - return distortionh,distortionm,distortionl - - -def csii(clean_speech, processed_speech,sample_rate): - sampleLen= min(len( clean_speech), len( processed_speech)) - clean_speech= clean_speech[0: sampleLen] - processed_speech= processed_speech[0: sampleLen] - vec_CSIIh,vec_CSIIm,vec_CSIIl = fwseg_noise(clean_speech, processed_speech, sample_rate) - - CSIIh=np.mean(vec_CSIIh) - CSIIm=np.mean(vec_CSIIm) - CSIIl=np.mean(vec_CSIIl) - return CSIIh,CSIIm,CSIIl - - - -def get_band(M,Fs): + + clean_spec = np.abs(clean_spec) + processed_spec = np.abs(processed_spec) ** 2 + + W_freq = Weight + + processed_energy = crit_filter.dot( + (processed_spec[0:n_fftby2, highInd].T * r2_high).T + ) + de_processed_energy = crit_filter.dot( + (processed_spec[0:n_fftby2, highInd].T * (1 - r2_high)).T + ) + SDR = processed_energy / de_processed_energy + # Eq 13 in Kates (2005) + SDRlog = 10 * np.log10(SDR) + SDRlog_lim = SDRlog + SDRlog_lim[SDRlog_lim < -15] = -15 + SDRlog_lim[SDRlog_lim > 15] = 15 # limit between [-15, 15] + Tjm = (SDRlog_lim + 15) / 30 + distortionh = W_freq.dot(Tjm) / np.sum(W_freq, axis=0) + distortionh[distortionh < 0] = 0 + + processed_energy = crit_filter.dot( + (processed_spec[0:n_fftby2, middleInd].T * r2_middle).T + ) + de_processed_energy = crit_filter.dot( + (processed_spec[0:n_fftby2, middleInd].T * (1 - r2_middle)).T + ) + SDR = processed_energy / de_processed_energy + # Eq 13 in Kates (2005) + SDRlog = 10 * np.log10(SDR) + SDRlog_lim = SDRlog + SDRlog_lim[SDRlog_lim < -15] = -15 + SDRlog_lim[SDRlog_lim > 15] = 15 # limit between [-15, 15] + Tjm = (SDRlog_lim + 15) / 30 + distortionm = W_freq.dot(Tjm) / np.sum(W_freq, axis=0) + distortionm[distortionm < 0] = 0 + + processed_energy = 
crit_filter.dot(
+        (processed_spec[0:n_fftby2, lowInd].T * r2_low).T
+    )
+    de_processed_energy = crit_filter.dot(
+        (processed_spec[0:n_fftby2, lowInd].T * (1 - r2_low)).T
+    )
+    SDR = processed_energy / de_processed_energy
+    # Eq 13 in Kates (2005)
+    SDRlog = 10 * np.log10(SDR)
+    SDRlog_lim = SDRlog
+    SDRlog_lim[SDRlog_lim < -15] = -15
+    SDRlog_lim[SDRlog_lim > 15] = 15  # limit between [-15, 15]
+    Tjm = (SDRlog_lim + 15) / 30
+    distortionl = W_freq.dot(Tjm) / np.sum(W_freq, axis=0)
+    distortionl[distortionl < 0] = 0
+
+    return distortionh, distortionm, distortionl
+
+
+def csii(clean_speech, processed_speech, sample_rate):
+    sampleLen = min(len(clean_speech), len(processed_speech))
+    clean_speech = clean_speech[0:sampleLen]
+    processed_speech = processed_speech[0:sampleLen]
+    vec_CSIIh, vec_CSIIm, vec_CSIIl = fwseg_noise(
+        clean_speech, processed_speech, sample_rate
+    )
+
+    CSIIh = np.mean(vec_CSIIh)
+    CSIIm = np.mean(vec_CSIIm)
+    CSIIl = np.mean(vec_CSIIl)
+    return CSIIh, CSIIm, CSIIl
+
+
+def get_band(M, Fs):
     # This function sets the bandpass filter band edges.
     # It assumes that the sampling frequency is 8000 Hz.
-    A = 165
-    a = 2.1
-    K = 1
-    L = 35
-    CF = 300;
-    x_100 = (L/a)*np.log10(CF/A + K)
-    CF = Fs/2-600
-    x_8000 = (L/a)*np.log10(CF/A + K);
-    LX = x_8000 - x_100
-    x_step = LX / M
-    x = np.arange(x_100,x_8000+x_step+1e-20,x_step)
+    A = 165
+    a = 2.1
+    K = 1
+    L = 35
+    CF = 300
+    x_100 = (L / a) * np.log10(CF / A + K)
+    CF = Fs / 2 - 600
+    x_8000 = (L / a) * np.log10(CF / A + K)
+    LX = x_8000 - x_100
+    x_step = LX / M
+    x = np.arange(x_100, x_8000 + x_step + 1e-20, x_step)
     if len(x) == M:
-        np.append(x,x_8000)
+        x = np.append(x, x_8000)  # NOTE: np.append returns a new array; keep the result
 
-    BAND = A*(10**(a*x/L) - K)
+    BAND = A * (10 ** (a * x / L) - K)
     return BAND
 
+
 def get_ansis(BAND):
-    fcenter=(BAND[0:-1]+BAND[1:])/2;
+    fcenter = (BAND[0:-1] + BAND[1:]) / 2
     # Data from Table B.1 in "ANSI (1997). S3.5–1997 Methods for Calculation of the Speech Intelligibility
     # Index. New York: American National Standards Institute."
-    f=np.array([150,250,350,450,570,700,840,1000,1170,1370,1600,1850,2150,2500,2900,3400,4000,4800,5800,7000,8500])
-    BIF=np.array([0.0192,0.0312,0.0926,0.1031,0.0735,0.0611,0.0495,0.0440,0.0440,0.0490,0.0486,0.0493,0.0490,0.0547,0.0555,0.0493,0.0359,0.0387,0.0256,0.0219,0.0043])
-    f_ANSI = interp1d(f,BIF)
-    ANSIs= f_ANSI(fcenter);
-    return fcenter,ANSIs
-
-
-def ncm(clean_speech,processed_speech,fs):
-
-    if fs != 8000 and fs != 16000:
-        raise ValueError('fs must be either 8 kHz or 16 kHz')
-
-
-
-    x= clean_speech # clean signal
-    y= processed_speech # noisy signal
+    f = np.array(
+        [
+            150,
+            250,
+            350,
+            450,
+            570,
+            700,
+            840,
+            1000,
+            1170,
+            1370,
+            1600,
+            1850,
+            2150,
+            2500,
+            2900,
+            3400,
+            4000,
+            4800,
+            5800,
+            7000,
+            8500,
+        ]
+    )
+    BIF = np.array(
+        [
+            0.0192,
+            0.0312,
+            0.0926,
+            0.1031,
+            0.0735,
+            0.0611,
+            0.0495,
+            0.0440,
+            0.0440,
+            0.0490,
+            0.0486,
+            0.0493,
+            0.0490,
+            0.0547,
+            0.0555,
+            0.0493,
+            0.0359,
+            0.0387,
+            0.0256,
+            0.0219,
+            0.0043,
+        ]
+    )
+    f_ANSI = interp1d(f, BIF)
+    ANSIs = f_ANSI(fcenter)
+    return fcenter, ANSIs
+
+
+def ncm(clean_speech, processed_speech, fs):
+
+    if fs != 8000 and fs != 16000:
+        raise ValueError("fs must be either 8 kHz or 16 kHz")
+
+    x = clean_speech  # clean signal
+    y = processed_speech  # noisy signal
     F_SIGNAL = fs
-    F_ENVELOPE = 32 # limits modulations to 0<f<16 Hz
-    M_CHANNELS = 20 # number of channels
-    BAND = get_band(M_CHANNELS,F_SIGNAL)
-    fcenter,WEIGHT = get_ansis(BAND)
-    Lx = len(x)
-    Ly = len(y)
-    if Lx > Ly:
-        x= x[0:Ly]
+    F_ENVELOPE = 32  # limits modulations to 0<f<16 Hz
+    M_CHANNELS = 20  # number of channels
+    BAND = get_band(M_CHANNELS, F_SIGNAL)
+    fcenter, WEIGHT = get_ansis(BAND)
+    Lx = len(x)
+    Ly = len(y)
+    if Lx > Ly:
+        x = x[0:Ly]
     if Ly > Lx:
-        y= y[0:Lx]
+        y = y[0:Lx]
 
-    Lx = len(x);
-    Ly = len(y);
+    Lx = len(x)
+    Ly = len(y)
 
-    X_BANDS = np.zeros((Lx,M_CHANNELS))
-    Y_BANDS = np.zeros((Lx,M_CHANNELS))
+    X_BANDS = np.zeros((Lx, M_CHANNELS))
+    Y_BANDS = np.zeros((Lx, M_CHANNELS))
 
     # DESIGN BANDPASS FILTERS
     for a in range(M_CHANNELS):
-        B_bp,A_bp = butter( 4 , np.array([BAND[a],BAND[a+1]])*(2/F_SIGNAL),btype='bandpass')
-        X_BANDS[:,a] = lfilter( B_bp , A_bp , x )
-        Y_BANDS[:,a] = lfilter( B_bp , A_bp , y )
+        B_bp, A_bp = butter(
+            4, np.array([BAND[a], BAND[a + 1]]) * (2 / F_SIGNAL), btype="bandpass"
+        )
+        X_BANDS[:, a] = lfilter(B_bp, A_bp, x)
+        Y_BANDS[:, a] = lfilter(B_bp, A_bp, y)
 
     gcd = np.gcd(F_SIGNAL, F_ENVELOPE)
     # CALCULATE HILBERT ENVELOPES, and resample at F_ENVELOPE Hz
-    analytic_x = hilbert( X_BANDS,axis=0);
-    X = np.abs( analytic_x );
-    #X = resample( X , round(len(x)/F_SIGNAL*F_ENVELOPE));
-    X = resample_matlab_like(X,F_ENVELOPE,F_SIGNAL)
-    analytic_y = hilbert( Y_BANDS,axis=0);
-    Y = np.abs( analytic_y );
-    #Y = resample( Y , round(len(x)/F_SIGNAL*F_ENVELOPE));
-    Y = resample_matlab_like(Y,F_ENVELOPE,F_SIGNAL)
+    analytic_x = hilbert(X_BANDS, axis=0)
+    X = np.abs(analytic_x)
+    # X = resample( X , round(len(x)/F_SIGNAL*F_ENVELOPE));
+    X = resample_matlab_like(X, F_ENVELOPE, F_SIGNAL)
+    analytic_y = hilbert(Y_BANDS, axis=0)
+    Y = np.abs(analytic_y)
+    # Y = resample( Y , round(len(x)/F_SIGNAL*F_ENVELOPE));
+    Y = resample_matlab_like(Y, F_ENVELOPE, F_SIGNAL)
 
     ## ---compute weights based on clean signal's rms envelopes----- #
-    Ldx, pp=X.shape
-    p=3 # power exponent - see Eq. 12
+    Ldx, pp = X.shape
+    p = 3  # power exponent - see Eq. 12
     ro2 = np.zeros((M_CHANNELS,))
     asnr = np.zeros((M_CHANNELS,))
     TI = np.zeros((M_CHANNELS,))
     for k in range(M_CHANNELS):
-        x_tmp= X[ :, k]
-        y_tmp= Y[ :, k]
-        lambda_x= np.linalg.norm( x_tmp- np.mean( x_tmp))**2
-        lambda_y= np.linalg.norm( y_tmp- np.mean( y_tmp))**2
-        lambda_xy= np.sum( (x_tmp- np.mean( x_tmp))*(y_tmp- np.mean( y_tmp)))
-        ro2[k]= (lambda_xy**2)/ (lambda_x*lambda_y)
-        asnr[k]= 10*np.log10( (ro2[k]+ 1e-20)/ (1- ro2[k]+ 1e-20)); # Eq.9 in [1]
-
-        if asnr[k]< -15:
-            asnr[k]= -15
-        elif asnr[k]> 15:
-            asnr[k]= 15
-
-        TI[k]= (asnr[k]+ 15)/ 30 # Eq.10 in [1]
-
-    WEIGHT=WEIGHT[:-1] # NOTE(jiatong): fix
-    ncm_val= WEIGHT.dot(TI)/np.sum(WEIGHT) # Eq.11
+        x_tmp = X[:, k]
+        y_tmp = Y[:, k]
+        lambda_x = np.linalg.norm(x_tmp - np.mean(x_tmp)) ** 2
+        lambda_y = np.linalg.norm(y_tmp - np.mean(y_tmp)) ** 2
+        lambda_xy = np.sum((x_tmp - np.mean(x_tmp)) * (y_tmp - np.mean(y_tmp)))
+        ro2[k] = (lambda_xy**2) / (lambda_x * lambda_y)
+        asnr[k] = 10 * np.log10((ro2[k] + 1e-20) / (1 - ro2[k] + 1e-20))
+        # Eq.9 in [1]
+
+        if asnr[k] < -15:
+            asnr[k] = -15
+        elif asnr[k] > 15:
+            asnr[k] = 15
+
+        TI[k] = (asnr[k] + 15) / 30  # Eq.10 in [1]
+
+    WEIGHT = WEIGHT[:-1]  # NOTE(jiatong): fix
+    ncm_val = WEIGHT.dot(TI) / np.sum(WEIGHT)  # Eq.11
 
     return ncm_val
diff --git a/tools/setup_nisqa.sh b/tools/setup_nisqa.sh
index 7cfdc0d..1e97a80 100755
--- a/tools/setup_nisqa.sh
+++ b/tools/setup_nisqa.sh
@@ -1,8 +1,14 @@
 #/bin/bash
-if [ -d "NISQA" ]; then
-    rm -rf NISQA
-fi
+set -e
 
-# # NOTE(jiatong): only for pre-trained model
-git clone https://github.com/gabrielmittag/NISQA.git
+cd "$(dirname "$0")"
+
+tmpdir="$(mktemp -d)"
+trap 'rm -rf "$tmpdir"' EXIT
+
+# NOTE(jiatong): only for pre-trained model weights.
+git clone --depth 1 https://github.com/gabrielmittag/NISQA.git "$tmpdir/NISQA"
+mkdir -p ../versa_cache/nisqa
+cp "$tmpdir/NISQA"/weights/*.tar ../versa_cache/nisqa/
+cp "$tmpdir/NISQA"/weights/LICENSE_model_weights ../versa_cache/nisqa/
diff --git a/versa/__init__.py b/versa/__init__.py
index c5fdcf1..188368b 100644
--- a/versa/__init__.py
+++ b/versa/__init__.py
@@ -1,116 +1,225 @@
+import importlib
 import logging
+import os
+from pathlib import Path
 
 __version__ = "0.0.1"  # noqa: F401
 
-from versa.sequence_metrics.mcd_f0 import mcd_f0
-from versa.sequence_metrics.signal_metric import signal_metric
-
-try:
-    from versa.utterance_metrics.discrete_speech import (
-        discrete_speech_metric,
-        discrete_speech_setup,
-    )
-except ImportError:
-    logging.info(
-        "Please pip install git+https://github.com/ftshijt/DiscreteSpeechMetrics.git and retry"
-    )
-except RuntimeError:
-    logging.info(
-        "Issues detected in discrete speech metrics, please double check the environment."
- ) - -from versa.utterance_metrics.pseudo_mos import pseudo_mos_metric, pseudo_mos_setup - -try: - from versa.utterance_metrics.pesq_score import pesq_metric -except ImportError: - logging.info("Please install pesq with `pip install pesq` and retry") - -try: - from versa.utterance_metrics.stoi import stoi_metric, estoi_metric -except ImportError: - logging.info("Please install pystoi with `pip install pystoi` and retry") - -try: - from versa.utterance_metrics.speaker import speaker_metric, speaker_model_setup -except ImportError: - logging.info("Please install espnet with `pip install espnet` and retry") - -try: - from versa.utterance_metrics.singer import singer_metric, singer_model_setup -except ImportError: - logging.info("Please install ...") - -try: - from versa.utterance_metrics.visqol_score import visqol_metric, visqol_setup -except ImportError: - logging.info( - "Please install visqol follow https://github.com/google/visqol and retry" - ) - -from versa.corpus_metrics.espnet_wer import espnet_levenshtein_metric, espnet_wer_setup -from versa.corpus_metrics.fad import fad_scoring, fad_setup -from versa.corpus_metrics.owsm_wer import owsm_levenshtein_metric, owsm_wer_setup -from versa.corpus_metrics.whisper_wer import ( - whisper_levenshtein_metric, - whisper_wer_setup, -) -from versa.utterance_metrics.asr_matching import asr_match_metric, asr_match_setup -from versa.utterance_metrics.audiobox_aesthetics_score import ( - audiobox_aesthetics_score, - audiobox_aesthetics_setup, -) -from versa.utterance_metrics.emotion import emo2vec_setup, emo_sim -from versa.utterance_metrics.nomad import nomad, nomad_setup -from versa.utterance_metrics.noresqa import noresqa_metric, noresqa_model_setup -from versa.utterance_metrics.owsm_lid import language_id, owsm_lid_model_setup -from versa.utterance_metrics.pysepm import pysepm_metric -from versa.utterance_metrics.qwen2_audio import ( - qwen2_channel_type_metric, - qwen2_language_metric, - qwen2_laughter_crying_metric, - qwen2_model_setup, - qwen2_overlapping_speech_metric, - qwen2_pitch_range_metric, - qwen2_recording_quality_metric, - qwen2_speaker_age_metric, - qwen2_speaker_count_metric, - qwen2_speaker_gender_metric, - qwen2_speaking_style_metric, - qwen2_speech_background_environment_metric, - qwen2_speech_clarity_metric, - qwen2_speech_emotion_metric, - qwen2_speech_impairment_metric, - qwen2_speech_purpose_metric, - qwen2_speech_rate_metric, - qwen2_speech_register_metric, - qwen2_speech_volume_level_metric, - qwen2_vocabulary_complexity_metric, - qwen2_voice_pitch_metric, - qwen2_voice_type_metric, - qwen2_singing_technique_metric, -) -from versa.utterance_metrics.qwen_omni import ( - qwen_omni_model_setup, - qwen_omni_singing_technique_metric, -) -from versa.utterance_metrics.scoreq import ( - scoreq_nr, - scoreq_nr_setup, - scoreq_ref, - scoreq_ref_setup, -) -from versa.utterance_metrics.se_snr import se_snr, se_snr_setup -from versa.utterance_metrics.sheet_ssqa import sheet_ssqa, sheet_ssqa_setup -from versa.utterance_metrics.speaking_rate import ( - speaking_rate_metric, - speaking_rate_model_setup, -) -from versa.utterance_metrics.squim import squim_metric, squim_metric_no_ref -from versa.utterance_metrics.srmr import srmr_metric -from versa.utterance_metrics.chroma_alignment import chroma_metric -from versa.utterance_metrics.wvmos import wvmos_setup, wvmos_calculate -from versa.utterance_metrics.sigmos import sigmos_setup, sigmos_calculate -from versa.utterance_metrics.dpam_distance import dpam_metric, dpam_model_setup -from 
versa.utterance_metrics.cdpam_distance import cdpam_metric, cdpam_model_setup -from versa.utterance_metrics.vqscore import vqscore_metric, vqscore_setup +logger = logging.getLogger(__name__) + +os.environ.setdefault( + "NUMBA_CACHE_DIR", str(Path.cwd() / "versa_cache" / "numba_cache") +) + + +def _optional_metric_import(module_name, names, install_hint=None): + """Import optional metric symbols without making package import fail.""" + try: + module = importlib.import_module(module_name) + except ImportError: + if install_hint: + logger.info(install_hint) + else: + logger.info("Optional metric module %s is not available", module_name) + return + except RuntimeError: + logger.info("Issues detected in %s; please check the environment.", module_name) + return + + for name in names: + globals()[name] = getattr(module, name) + + +_optional_metric_import( + "versa.sequence_metrics.mcd_f0", + ("McdF0Metric", "register_mcd_f0_metric"), +) +# from versa.sequence_metrics.signal_metric import SignalMetric, register_signal_metric +_optional_metric_import( + "versa.sequence_metrics.warpq", + ("WarpqMetric", "register_warpq_metric"), +) + +_optional_metric_import( + "versa.utterance_metrics.discrete_speech", + ("DiscreteSpeechMetric", "register_discrete_speech_metric"), + ( + "Please pip install " + "git+https://github.com/ftshijt/DiscreteSpeechMetrics.git and retry" + ), +) + +_optional_metric_import( + "versa.utterance_metrics.pseudo_mos", + ("PseudoMosMetric", "register_pseudo_mos_metric"), +) + +_optional_metric_import( + "versa.utterance_metrics.pesq_score", + ("PesqMetric", "register_pesq_metric"), + "Please install pesq with `pip install pesq` and retry", +) + +# try: +# from versa.utterance_metrics.stoi import StoiMetric, register_stoi_metric +# except ImportError: +# logging.info("Please install pystoi with `pip install pystoi` and retry") +_optional_metric_import( + "versa.utterance_metrics.stoi", + ("StoiMetric", "EstoiMetric", "register_stoi_metric"), + "Please install pystoi with `pip install pystoi` and retry", +) + +_optional_metric_import( + "versa.utterance_metrics.speaker", + ("SpeakerMetric", "register_speaker_metric"), +) + +_optional_metric_import( + "versa.utterance_metrics.singer", + ("SingerMetric", "register_singer_metric"), + "Please install singer_identity following tools/install_ssl-singer-identity.sh", +) + +_optional_metric_import( + "versa.utterance_metrics.visqol_score", + ("VisqolMetric", "register_visqol_metric"), + "Please install visqol following https://github.com/google/visqol and retry", +) + +_optional_metric_import( + "versa.corpus_metrics.espnet_wer", + ("EspnetWerMetric", "register_espnet_wer_metric"), +) +# from versa.corpus_metrics.fad import FadMetric, register_fad_metric +_optional_metric_import( + "versa.corpus_metrics.owsm_wer", + ("OwsmWerMetric", "register_owsm_wer_metric"), +) +_optional_metric_import( + "versa.corpus_metrics.whisper_wer", + ("WhisperWerMetric", "register_whisper_wer_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.asr_matching", + ("ASRMatchMetric", "register_asr_match_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.audiobox_aesthetics_score", + ("AudioBoxAestheticsMetric", "register_audiobox_aesthetics_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.emo_similarity", + ("Emo2vecMetric", "register_emo2vec_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.nomad", + ("NomadMetric", "register_nomad_metric"), +) +_optional_metric_import( + 
"versa.utterance_metrics.noresqa", + ("NoresqaMetric", "register_noresqa_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.owsm_lid", + ("OwsmLidMetric", "register_owsm_lid_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.log_wmse", + ("LogWmseMetric", "register_log_wmse_metric"), + "Please install torch-log-wmse and retry", +) +_optional_metric_import( + "versa.utterance_metrics.universa", + ("UniversaMetric", "register_universa_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.arecho", + ("ArechoMetric", "register_arecho_metric"), +) + +# from versa.utterance_metrics.pysepm import PysepmMetric, register_pysepm_metric +_optional_metric_import( + "versa.utterance_metrics.pysepm", + ("PysepmMetric", "register_pysepm_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.qwen2_audio", + ("Qwen2AudioMetric", "register_qwen2_audio_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.qwen_omni", + ("QwenOmniMetric", "register_qwen_omni_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.scoreq", + ( + "ScoreqMetric", + "ScoreqNrMetric", + "ScoreqRefMetric", + "register_scoreq_metric", + ), +) +_optional_metric_import( + "versa.utterance_metrics.se_snr", + ("SeSnrMetric", "register_se_snr_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.sheet_ssqa", + ("SheetSsqaMetric", "register_sheet_ssqa_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.speaking_rate", + ("SpeakingRateMetric", "register_speaking_rate_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.squim", + ("SquimMetric", "SquimRefMetric", "SquimNoRefMetric", "register_squim_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.srmr", + ("SRMRMetric", "register_srmr_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.chroma_alignment", + ("ChromaAlignmentMetric", "register_chroma_alignment_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.dpam_distance", + ("DpamDistanceMetric", "register_dpam_distance_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.cdpam_distance", + ("CdpamDistanceMetric", "register_cdpam_distance_metric"), +) + +_optional_metric_import( + "versa.utterance_metrics.vqscore", + ("VqscoreMetric", "register_vqscore_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.vad", + ("VadMetric", "register_vad_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.nisqa", + ("NisqaMetric", "register_nisqa_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.pam", + ("PamMetric", "register_pam_metric"), +) +_optional_metric_import( + "versa.sequence_metrics.signal_metric", + ("SignalMetric", "register_signal_metric"), +) +_optional_metric_import( + "versa.utterance_metrics.sigmos", + ("SigmosMetric", "register_sigmos_metric"), + "Please install SigMOS dependencies and retry", +) +_optional_metric_import( + "versa.utterance_metrics.wvmos", + ("WvmosMetric", "register_wvmos_metric"), + "Please install WVMOS following tools/install_wvmos.sh", +) diff --git a/versa/audio_utils.py b/versa/audio_utils.py new file mode 100644 index 0000000..6e1abfb --- /dev/null +++ b/versa/audio_utils.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +# Copyright 2026 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Small audio helpers shared by metric wrappers.""" + +from math import gcd + +from scipy.signal import resample_poly + + +def resample_audio(audio, orig_sr, target_sr): + """Resample 1-D 
diff --git a/versa/audio_utils.py b/versa/audio_utils.py
new file mode 100644
index 0000000..6e1abfb
--- /dev/null
+++ b/versa/audio_utils.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Jiatong Shi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Small audio helpers shared by metric wrappers."""
+
+from math import gcd
+
+from scipy.signal import resample_poly
+
+
+def resample_audio(audio, orig_sr, target_sr):
+    """Resample 1-D audio without importing librosa's numba-heavy audio module."""
+    if orig_sr == target_sr:
+        return audio
+    divisor = gcd(int(orig_sr), int(target_sr))
+    return resample_poly(audio, int(target_sr) // divisor, int(orig_sr) // divisor)
diff --git a/versa/bin/scorer.py b/versa/bin/scorer.py
index 2431006..99392df 100644
--- a/versa/bin/scorer.py
+++ b/versa/bin/scorer.py
@@ -13,11 +13,8 @@
 
 from versa.scorer_shared import (
     audio_loader_setup,
-    corpus_scoring,
-    list_scoring,
-    load_corpus_modules,
-    load_score_modules,
-    load_summary,
+    VersaScorer,
+    compute_summary,
 )
 
@@ -141,47 +138,58 @@ def main():
 
     with open(args.score_config, "r", encoding="utf-8") as f:
         score_config = yaml.full_load(f)
 
-    score_modules = load_score_modules(
+    # Initialize VersaScorer
+    scorer = VersaScorer()
+
+    # Load utterance-level metrics
+    utterance_metrics = scorer.load_metrics(
         score_config,
         use_gt=(True if gt_files is not None else False),
         use_gt_text=(True if text_info is not None else False),
         use_gpu=args.use_gpu,
     )
 
-    if len(score_modules) > 0:
-        score_info = list_scoring(
+    # Perform utterance-level scoring
+    if len(utterance_metrics.metrics) > 0:
+        score_info = scorer.score_utterances(
             gen_files,
-            score_modules,
+            utterance_metrics,
             gt_files,
             text_info,
             output_file=args.output_file,
             io=args.io,
         )
-        logging.info("Summary: {}".format(load_summary(score_info)))
+        logging.info("Summary: {}".format(compute_summary(score_info)))
     else:
         logging.info("No utterance-level scoring function is provided.")
 
-    corpus_score_modules = load_corpus_modules(
+    # Load corpus-level metrics (distributional metrics)
+    corpus_metrics = scorer.load_metrics(
         score_config,
+        use_gt=(True if gt_files is not None else False),
+        use_gt_text=(True if text_info is not None else False),
         use_gpu=args.use_gpu,
-        cache_folder=args.cache_folder,
-        io=args.io,
     )
-    assert (
-        len(corpus_score_modules) > 0 or len(score_modules) > 0
-    ), "no scoring function is provided"
-    if len(corpus_score_modules) > 0:
-        corpus_score_info = corpus_scoring(
-            args.pred,
-            corpus_score_modules,
-            args.gt,
+
+    # Filter for corpus-level metrics and perform corpus scoring
+    from versa.definition import MetricCategory
+
+    corpus_suite = corpus_metrics.filter_by_category(MetricCategory.DISTRIBUTIONAL)
+    if len(corpus_suite.metrics) > 0:
+        corpus_score_info = scorer.score_corpus(
+            gen_files,
+            corpus_suite,
+            gt_files,
             text_info,
-            output_file=args.output_file + ".corpus",
+            output_file=args.output_file + ".corpus" if args.output_file else None,
         )
         logging.info("Corpus Summary: {}".format(corpus_score_info))
     else:
         logging.info("No corpus-level scoring function is provided.")
-    return
+
+    # Ensure at least one scoring function is provided
+    if len(utterance_metrics.metrics) == 0 and len(corpus_suite.metrics) == 0:
+        raise ValueError("No scoring function is provided")
 
 
 if __name__ == "__main__":
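The refactored CLI above is a thin wrapper around `VersaScorer`, so the same flow can be driven programmatically. A minimal sketch, reusing the calls exactly as `scorer.py` makes them; the file mapping and the `io` value are hypothetical placeholders (the CLI builds `gen_files` through `audio_loader_setup`):

```python
import yaml

from versa.scorer_shared import VersaScorer, compute_summary

with open("egs/speech_cpu.yaml", "r", encoding="utf-8") as f:
    score_config = yaml.full_load(f)

scorer = VersaScorer()
# No references or transcripts in this run, so use_gt/use_gt_text stay False.
suite = scorer.load_metrics(score_config, use_gt=False, use_gt_text=False, use_gpu=False)

gen_files = {"utt1": "test/test_samples/test2/utt1.wav"}  # hypothetical mapping
score_info = scorer.score_utterances(
    gen_files, suite, None, None, output_file=None, io="soundfile"  # io value assumed
)
print(compute_summary(score_info))
```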
diff --git a/versa/corpus_metrics/espnet_wer.py b/versa/corpus_metrics/espnet_wer.py
index c85bd30..d7898cc 100644
--- a/versa/corpus_metrics/espnet_wer.py
+++ b/versa/corpus_metrics/espnet_wer.py
@@ -1,156 +1,259 @@
-#!/usr/bin/env python3
-
-# Copyright 2024 Jiatong Shi
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-import logging
-
-import librosa
-import numpy as np
-import torch
-from espnet2.bin.asr_inference import Speech2Text
-from espnet2.text.cleaner import TextCleaner
-from Levenshtein import opcodes
-
-TARGET_FS = 16000
-CHUNK_SIZE = 30  # seconds
-
-
-def espnet_wer_setup(
-    model_tag="default", beam_size=5, text_cleaner="whisper_basic", use_gpu=True
-):
-    if model_tag == "default":
-        model_tag = "espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp"
-    device = "cuda" if use_gpu else "cpu"
-    model = Speech2Text.from_pretrained(
-        model_tag=model_tag,
-        device=device,
-        beam_size=beam_size,
-    )
-    textcleaner = TextCleaner(text_cleaner)
-    if "whisper" in text_cleaner:
-        try:
-            import whisper
-        except ImportError:
-            logging.warning(
-                "Whipser-based cleaner is used but openai-whisper is not installed"
-            )
-    wer_utils = {"model": model, "cleaner": textcleaner, "beam_size": beam_size}
-    return wer_utils
-
-
-def espnet_predict(
-    model,
-    speech,
-    fs: int,
-    beam_size: int = 5,
-):
-    """Generate predictions using the espnet model. (from URGENT Challenge)
-
-    Args:
-        model (torch.nn.Module): espnet model.
-        speech (np.ndarray): speech signal < 120s (time,)
-        fs (int): sampling rate in Hz.
-        beam_size (int): beam size used in beam search.
-    Returns:
-        text (str): predicted text
-    """
-    model.beam_search.beam_size = int(beam_size)
-
-    assert fs == 16000, (fs, 16000)
-
-    # assuming 10 tokens per second
-    model.maxlenratio = -min(300, int((len(speech) / TARGET_FS) * 10))
-
-    speech = librosa.util.fix_length(speech, size=(TARGET_FS * CHUNK_SIZE))
-    text = model(speech)[0][0]
-
-    return text
-
-
-def espnet_levenshtein_metric(wer_utils, pred_x, ref_text, fs=16000):
-    """Calculate the Levenshtein distance between ref and inf ASR results.
-
-    Args:
-        wer_utils (dict): a utility dict for WER calculation.
-            including: espnet model ("model"), text cleaner ("textcleaner"), and
-            beam size ("beam size")
-        pred_x (np.ndarray): test signal (time,)
-        ref_text (string): reference transcript
-        fs (int): sampling rate in Hz
-    Returns:
-        ret (dict): ditionary containing occurrences of edit operations
-    """
-    if fs != TARGET_FS:
-        pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS)
-        fs = TARGET_FS
-    with torch.no_grad():
-        inf_txt = espnet_predict(
-            wer_utils["model"],
-            pred_x,
-            fs,
-            beam_size=wer_utils["beam_size"],
-        )
-
-    ref_text = wer_utils["cleaner"](ref_text)
-    pred_text = wer_utils["cleaner"](inf_txt)
-
-    # process wer
-    ref_words = ref_text.strip().split()
-    pred_words = pred_text.strip().split()
-    ret = {
-        "espnet_hyp_text": pred_text,
-        "ref_text": ref_text,
-        "espnet_wer_delete": 0,
-        "espnet_wer_insert": 0,
-        "espnet_wer_replace": 0,
-        "espnet_wer_equal": 0,
-    }
-    for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words):
-        if op == "insert":
-            ret["espnet_wer_" + op] = ret["espnet_wer_" + op] + inf_et - inf_st
-        else:
-            ret["espnet_wer_" + op] = ret["espnet_wer_" + op] + ref_et - ref_st
-    total = (
-        ret["espnet_wer_delete"] + ret["espnet_wer_replace"] + ret["espnet_wer_equal"]
-    )
-    assert total == len(ref_words), (total, len(ref_words))
-    total = (
-        ret["espnet_wer_insert"] + ret["espnet_wer_replace"] + ret["espnet_wer_equal"]
-    )
-    assert total == len(pred_words), (total, len(pred_words))
-
-    # process cer
-    ref_words = [c for c in ref_text]
-    pred_words = [c for c in pred_text]
-    ret.update(
-        espnet_cer_delete=0,
-        espnet_cer_insert=0,
-        espnet_cer_replace=0,
-        espnet_cer_equal=0,
-    )
-    for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words):
-        if op == "insert":
-            ret["espnet_cer_" + op] = ret["espnet_cer_" + op] + inf_et - inf_st
-        else:
-            ret["espnet_cer_" + op] = ret["espnet_cer_" + op] + ref_et - ref_st
-    total = (
-        ret["espnet_cer_delete"] + ret["espnet_cer_replace"] + ret["espnet_cer_equal"]
-    )
-    assert total == len(ref_words), (total, len(ref_words))
-    total = (
-        ret["espnet_cer_insert"] + ret["espnet_cer_replace"] + ret["espnet_cer_equal"]
-    )
-    assert total == len(pred_words), (total, len(pred_words))
-
-    return ret
-
-
-if __name__ == "__main__":
-    a = np.random.random(16000)
-    wer_utils = espnet_wer_setup()
-    print(
-        "metrics: {}".format(
-            espnet_levenshtein_metric(wer_utils, a, "test a sentence.", 16000)
-        )
-    )
+#!/usr/bin/env python3
+
+# Copyright 2024 Jiatong Shi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import importlib.util
+import logging
+
+import librosa
+import numpy as np
+import torch
+from Levenshtein import opcodes
+
+from versa.audio_utils import resample_audio
+from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType
+
+
+def _ensure_torchaudio_legacy_backend_api():
+    try:
+        import torchaudio
+    except ImportError:
+        return
+
+    if not hasattr(torchaudio, "set_audio_backend"):
+        torchaudio.set_audio_backend = lambda *args, **kwargs: None
+
+
+_ensure_torchaudio_legacy_backend_api()
+
+try:
+    from espnet2.bin.asr_inference import Speech2Text
+    from espnet2.text.cleaner import TextCleaner
+except ImportError:
+    Speech2Text = None
+    TextCleaner = None
+
+TARGET_FS = 16000
+CHUNK_SIZE = 30  # seconds
+
+
+def espnet_wer_setup(
+    model_tag="default",
+    beam_size=5,
+    text_cleaner="whisper_basic",
+    use_gpu=True,
+    cache_dir=None,
+):
+    if model_tag == "default":
+        model_tag = (
+            "espnet/"
+            "simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_"
+            "bpe5000_sp"
+        )
+    device = "cuda" if use_gpu else "cpu"
+    if Speech2Text is None or TextCleaner is None:
+        raise ImportError("espnet_wer requires espnet. Please install espnet and retry")
+    if cache_dir is None:
+        model = Speech2Text.from_pretrained(
+            model_tag=model_tag,
+            device=device,
+            beam_size=beam_size,
+        )
+    else:
+        try:
+            from espnet_model_zoo.downloader import ModelDownloader
+        except ImportError:
+            raise ImportError(
+                "espnet_wer requires espnet_model_zoo. Please install it and retry"
+            )
+        model_kwargs = ModelDownloader(cachedir=cache_dir).download_and_unpack(
+            model_tag
+        )
+        model = Speech2Text(device=device, beam_size=beam_size, **model_kwargs)
+    textcleaner = TextCleaner(text_cleaner)
+    if "whisper" in text_cleaner:
+        if importlib.util.find_spec("whisper") is None:
+            logging.warning(
+                "Whisper-based cleaner is used but openai-whisper is not installed"
+            )
+    wer_utils = {"model": model, "cleaner": textcleaner, "beam_size": beam_size}
+    return wer_utils
+
+
+def espnet_predict(
+    model,
+    speech,
+    fs: int,
+    beam_size: int = 5,
+):
+    """Generate predictions using the espnet model. (from URGENT Challenge)
+
+    Args:
+        model (torch.nn.Module): espnet model.
+        speech (np.ndarray): speech signal < 120s (time,)
+        fs (int): sampling rate in Hz.
+        beam_size (int): beam size used in beam search.
+    Returns:
+        text (str): predicted text
+    """
+    model.beam_search.beam_size = int(beam_size)
+
+    assert fs == 16000, (fs, 16000)
+
+    # assuming 10 tokens per second
+    model.maxlenratio = -min(300, int((len(speech) / TARGET_FS) * 10))
+
+    speech = librosa.util.fix_length(speech, size=(TARGET_FS * CHUNK_SIZE))
+    text = model(speech)[0][0]
+
+    return text
+
+
+def espnet_levenshtein_metric(wer_utils, pred_x, ref_text, fs=16000):
+    """Calculate the Levenshtein distance between ref and inf ASR results.
+
+    Args:
+        wer_utils (dict): a utility dict for WER calculation.
+            including: espnet model ("model"), text cleaner ("cleaner"), and
+            beam size ("beam_size")
+        pred_x (np.ndarray): test signal (time,)
+        ref_text (string): reference transcript
+        fs (int): sampling rate in Hz
+    Returns:
+        ret (dict): dictionary containing occurrences of edit operations
+    """
+    if fs != TARGET_FS:
+        pred_x = resample_audio(pred_x, fs, TARGET_FS)
+        fs = TARGET_FS
+    with torch.no_grad():
+        inf_txt = espnet_predict(
+            wer_utils["model"],
+            pred_x,
+            fs,
+            beam_size=wer_utils["beam_size"],
+        )
+
+    ref_text = wer_utils["cleaner"](ref_text)
+    pred_text = wer_utils["cleaner"](inf_txt)
+
+    # process wer
+    ref_words = ref_text.strip().split()
+    pred_words = pred_text.strip().split()
+    ret = {
+        "espnet_hyp_text": pred_text,
+        "ref_text": ref_text,
+        "espnet_wer_delete": 0,
+        "espnet_wer_insert": 0,
+        "espnet_wer_replace": 0,
+        "espnet_wer_equal": 0,
+    }
+    for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words):
+        if op == "insert":
+            ret["espnet_wer_" + op] = ret["espnet_wer_" + op] + inf_et - inf_st
+        else:
+            ret["espnet_wer_" + op] = ret["espnet_wer_" + op] + ref_et - ref_st
+    total = (
+        ret["espnet_wer_delete"] + ret["espnet_wer_replace"] + ret["espnet_wer_equal"]
+    )
+    assert total == len(ref_words), (total, len(ref_words))
+    total = (
+        ret["espnet_wer_insert"] + ret["espnet_wer_replace"] + ret["espnet_wer_equal"]
+    )
+    assert total == len(pred_words), (total, len(pred_words))
+
+    # process cer
+    ref_words = [c for c in ref_text]
+    pred_words = [c for c in pred_text]
+    ret.update(
+        espnet_cer_delete=0,
+        espnet_cer_insert=0,
+        espnet_cer_replace=0,
+        espnet_cer_equal=0,
+    )
+    for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words):
+        if op == "insert":
+            ret["espnet_cer_" + op] = ret["espnet_cer_" + op] + inf_et - inf_st
+        else:
+            ret["espnet_cer_" + op] = ret["espnet_cer_" + op] + ref_et - ref_st
+    total = (
+        ret["espnet_cer_delete"] + ret["espnet_cer_replace"] + ret["espnet_cer_equal"]
+    )
+    assert total == len(ref_words), (total, len(ref_words))
+    total = (
+        ret["espnet_cer_insert"] + ret["espnet_cer_replace"] + ret["espnet_cer_equal"]
+    )
+    assert total == len(pred_words), (total, len(pred_words))
+
+    return ret
+
+
+class EspnetWerMetric(BaseMetric):
+    """ESPnet ASR-based WER/CER edit counts."""
+
+    def _setup(self):
+        self.model_tag = self.config.get("model_tag", "default")
+        self.beam_size = self.config.get("beam_size", 5)
+        self.text_cleaner = self.config.get("text_cleaner", "whisper_basic")
+        self.use_gpu = self.config.get("use_gpu", True)
+        self.cache_dir = self.config.get("cache_dir", "versa_cache/espnet_model_zoo")
+        self.wer_utils = espnet_wer_setup(
+            model_tag=self.model_tag,
+            beam_size=self.beam_size,
+            text_cleaner=self.text_cleaner,
+            use_gpu=self.use_gpu,
+            cache_dir=self.cache_dir,
+        )
+
+    def compute(self, predictions, references=None, metadata=None):
+        if predictions is None:
+            raise ValueError("Predicted signal must be provided")
+
+        metadata = metadata or {}
+        ref_text = metadata.get("text")
+        if ref_text is None and isinstance(references, str):
+            ref_text = references
+        if ref_text is None:
+            raise ValueError("Reference text must be provided")
+
+        fs = metadata.get("sample_rate", 16000)
+        return espnet_levenshtein_metric(
+            self.wer_utils,
+            np.asarray(predictions),
+            ref_text,
+            fs=fs,
+        )
+
+    def get_metadata(self):
+        return _espnet_wer_metadata()
+
+
+def _espnet_wer_metadata():
+    return MetricMetadata(
+        name="espnet_wer",
+        category=MetricCategory.NON_MATCH,
+        metric_type=MetricType.DICT,
+        requires_reference=False,
+        requires_text=True,
+        gpu_compatible=True,
+        auto_install=False,
+        dependencies=["espnet2", "Levenshtein", "librosa", "numpy", "torch"],
+        description="ESPnet ASR-based WER and CER edit counts",
+        paper_reference="https://arxiv.org/pdf/1804.00015",
+        implementation_source="https://github.com/espnet/espnet",
+    )
+
+
+def register_espnet_wer_metric(registry):
+    """Register ESPnet WER with the registry."""
+    registry.register(
+        EspnetWerMetric,
+        _espnet_wer_metadata(),
+        aliases=["espnet_asr_wer", "espnet_wer_metric"],
+    )
+
+
+if __name__ == "__main__":
+    a = np.random.random(16000)
+    metric = EspnetWerMetric()
+    print(metric.compute(a, metadata={"sample_rate": 16000, "text": "test sentence"}))
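Since every option in `_setup` above is read from `self.config`, a CPU-only, small-beam run is just a dict away. A sketch under one assumption: that `BaseMetric.__init__` accepts a `config` mapping (the bare `EspnetWerMetric()` in the `__main__` block relies on the same defaults):

```python
import numpy as np

from versa.corpus_metrics.espnet_wer import EspnetWerMetric

# Keys mirror the config.get(...) defaults in _setup(); the values are illustrative.
metric = EspnetWerMetric(
    config={
        "use_gpu": False,
        "beam_size": 1,
        "cache_dir": "versa_cache/espnet_model_zoo",
    }
)
pred = np.random.random(16000)  # one second of placeholder 16 kHz audio
result = metric.compute(pred, metadata={"sample_rate": 16000, "text": "test sentence"})
print(result["espnet_wer_equal"], result["espnet_wer_replace"])
```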
diff --git a/versa/corpus_metrics/owsm_wer.py b/versa/corpus_metrics/owsm_wer.py
index c0bbc7c..cabd67f 100644
--- a/versa/corpus_metrics/owsm_wer.py
+++ b/versa/corpus_metrics/owsm_wer.py
@@ -1,221 +1,311 @@
-#!/usr/bin/env python3
-
-# Copyright 2024 Jiatong Shi
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-import logging
-
-import librosa
-import numpy as np
-import torch
-from espnet2.bin.s2t_inference import Speech2Text
-from espnet2.text.cleaner import TextCleaner
-from Levenshtein import opcodes
-
-TARGET_FS = 16000
-CHUNK_SIZE = 30  # seconds
-
-
-def owsm_wer_setup(
-    model_tag="default", beam_size=5, text_cleaner="whisper_basic", use_gpu=True
-):
-    if model_tag == "default":
-        model_tag = "espnet/owsm_v3.1_ebf"
-    device = "cuda" if use_gpu else "cpu"
-    model = Speech2Text.from_pretrained(
-        model_tag=model_tag,
-        device=device,
-        task_sym="<asr>",
-        beam_size=beam_size,
-        predict_time=False,
-    )
-    textcleaner = TextCleaner(text_cleaner)
-    if "whisper" in text_cleaner:
-        try:
-            import whisper
-        except ImportError:
-            logging.warning(
-                "Whipser-based cleaner is used but openai-whisper is not installed"
-            )
-    wer_utils = {"model": model, "cleaner": textcleaner, "beam_size": beam_size}
-    return wer_utils
-
-
-# Copied from Whisper utils
-def format_timestamp(
-    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
-):
-    assert seconds >= 0, "non-negative timestamp expected"
-    milliseconds = round(seconds * 1000.0)
-
-    hours = milliseconds // 3_600_000
-    milliseconds -= hours * 3_600_000
-
-    minutes = milliseconds // 60_000
-    milliseconds -= minutes * 60_000
-
-    seconds = milliseconds // 1_000
-    milliseconds -= seconds * 1_000
-
-    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
-    return (
-        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
-    )
-
-
-def owsm_predict(
-    model,
-    speech,
-    fs: int,
-    src_lang: str = "none",
-    beam_size: int = 5,
-    long_form: bool = False,
-    text_prev: str = "",
-):
-    """Generate predictions using the OWSM model. (from URGENT Challenge)
-
-    Args:
-        model (torch.nn.Module): OWSM model.
-        speech (np.ndarray): speech signal < 120s (time,)
-        fs (int): sampling rate in Hz.
-        src_lang (str): source language in ISO 639-2 Code.
-        beam_size (int): beam size used in beam search.
-        long_form (bool): perform long-form decoding for audios longer than 30s.
-            If an exception happens, it will fall back to standard decoding on the
-            initial 30s.
-        text_prev (str): generation will be conditioned on this prompt if provided.
-    Returns:
-        text (str): predicted text
-    """
-    task_sym = "<asr>"
-    model.beam_search.beam_size = int(beam_size)
-
-    assert fs == 16000, (fs, 16000)
-
-    # Detect language using the first 30s of speech
-    if src_lang == "none":
-        from espnet2.bin.s2t_inference_language import Speech2Language as Speech2Lang
-
-        # default 30 seconds chunk for owsm training
-        src_lang = model(
-            librosa.util.fix_length(speech, size=(TARGET_FS * CHUNK_SIZE))
-        )[0][0].strip()[1:-1]
-    lang_sym = f"<{src_lang}>"
-
-    # ASR or ST
-    if long_form:  # speech will be padded in decode_long()
-        try:
-            model.maxlenratio = -300
-            utts = model.decode_long(
-                speech,
-                condition_on_prev_text=False,
-                init_text=text_prev,
-                end_time_threshold="<29.00>",
-                lang_sym=lang_sym,
-                task_sym=task_sym,
-            )
-
-            text = []
-            for t1, t2, res in utts:
-                text.append(
-                    f"[{format_timestamp(seconds=t1)} --> "
-                    f"{format_timestamp(seconds=t2)}] {res}"
-                )
-            text = "\n".join(text)
-
-            return text
-        except:
-            print(
-                "An exception occurred in long-form decoding. "
-                "Fall back to standard decoding (only first 30s)"
-            )
-
-    # assuming 10 tokens per second
-    model.maxlenratio = -min(300, int((len(speech) / TARGET_FS) * 10))
-
-    speech = librosa.util.fix_length(speech, size=(TARGET_FS * CHUNK_SIZE))
-    text = model(speech, text_prev, lang_sym=lang_sym, task_sym=task_sym)[0][-2]
-
-    return text
-
-
-def owsm_levenshtein_metric(wer_utils, pred_x, ref_text, fs=16000):
-    """Calculate the Levenshtein distance between ref and inf ASR results.
-
-    Args:
-        wer_utils (dict): a utility dict for WER calculation.
-            including: owsm model ("model"), text cleaner ("textcleaner"), and
-            beam size ("beam size")
-        pred_x (np.ndarray): test signal (time,)
-        ref_text (string): reference transcript
-        fs (int): sampling rate in Hz
-    Returns:
-        ret (dict): ditionary containing occurrences of edit operations
-    """
-    if fs != TARGET_FS:
-        pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS)
-        fs = TARGET_FS
-    with torch.no_grad():
-        inf_txt = owsm_predict(
-            wer_utils["model"],
-            pred_x,
-            fs,
-            src_lang="eng",
-            beam_size=wer_utils["beam_size"],
-            long_form=len(pred_x) > CHUNK_SIZE * fs,
-        )
-
-    ref_text = wer_utils["cleaner"](ref_text).strip()
-    pred_text = wer_utils["cleaner"](inf_txt).strip()
-
-    # process wer
-    ref_words = ref_text.strip().split()
-    pred_words = pred_text.strip().split()
-    ret = {
-        "owsm_hyp_text": pred_text,
-        "ref_text": ref_text,
-        "owsm_wer_delete": 0,
-        "owsm_wer_insert": 0,
-        "owsm_wer_replace": 0,
-        "owsm_wer_equal": 0,
-    }
-    for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words):
-        if op == "insert":
-            ret["owsm_wer_" + op] = ret["owsm_wer_" + op] + inf_et - inf_st
-        else:
-            ret["owsm_wer_" + op] = ret["owsm_wer_" + op] + ref_et - ref_st
-    total = ret["owsm_wer_delete"] + ret["owsm_wer_replace"] + ret["owsm_wer_equal"]
-    assert total == len(ref_words), (total, len(ref_words))
-    total = ret["owsm_wer_insert"] + ret["owsm_wer_replace"] + ret["owsm_wer_equal"]
-    assert total == len(pred_words), (total, len(pred_words))
-
-    # process cer
-    ref_words = [c for c in ref_text]
-    pred_words = [c for c in pred_text]
-    ret.update(
-        owsm_cer_delete=0,
-        owsm_cer_insert=0,
-        owsm_cer_replace=0,
-        owsm_cer_equal=0,
-    )
-    for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words):
-        if op == "insert":
-            ret["owsm_cer_" + op] = ret["owsm_cer_" + op] + inf_et - inf_st
-        else:
-            ret["owsm_cer_" + op] = ret["owsm_cer_" + op] + ref_et - ref_st
-    total = ret["owsm_cer_delete"] + ret["owsm_cer_replace"] + ret["owsm_cer_equal"]
ret["owsm_cer_equal"] - assert total == len(ref_words), (total, len(ref_words)) - total = ret["owsm_cer_insert"] + ret["owsm_cer_replace"] + ret["owsm_cer_equal"] - assert total == len(pred_words), (total, len(pred_words)) - - return ret - - -if __name__ == "__main__": - a = np.random.random(16000) - wer_utils = owsm_wer_setup() - print( - "metrics: {}".format( - owsm_levenshtein_metric(wer_utils, a, "test a sentence.", 16000) - ) - ) +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import importlib.util +import logging + +import librosa +import numpy as np +import torch +from Levenshtein import opcodes + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +try: + from espnet2.bin.s2t_inference import Speech2Text + from espnet2.text.cleaner import TextCleaner +except ImportError: + Speech2Text = None + TextCleaner = None + +TARGET_FS = 16000 +CHUNK_SIZE = 30 # seconds + + +def owsm_wer_setup( + model_tag="default", + beam_size=5, + text_cleaner="whisper_basic", + use_gpu=True, + cache_dir=None, +): + if model_tag == "default": + model_tag = "espnet/owsm_v3.1_ebf" + device = "cuda" if use_gpu else "cpu" + if Speech2Text is None or TextCleaner is None: + raise ImportError("owsm_wer requires espnet. Please install espnet and retry") + if cache_dir is None: + model = Speech2Text.from_pretrained( + model_tag=model_tag, + device=device, + task_sym="", + beam_size=beam_size, + predict_time=False, + ) + else: + try: + from espnet_model_zoo.downloader import ModelDownloader + except ImportError: + raise ImportError( + "owsm_wer requires espnet_model_zoo. Please install it and retry" + ) + model_kwargs = ModelDownloader(cachedir=cache_dir).download_and_unpack( + model_tag + ) + model = Speech2Text( + device=device, + task_sym="", + beam_size=beam_size, + predict_time=False, + **model_kwargs, + ) + textcleaner = TextCleaner(text_cleaner) + if "whisper" in text_cleaner: + if importlib.util.find_spec("whisper") is None: + logging.warning( + "Whipser-based cleaner is used but openai-whisper is not installed" + ) + wer_utils = {"model": model, "cleaner": textcleaner, "beam_size": beam_size} + return wer_utils + + +# Copied from Whisper utils +def format_timestamp( + seconds: float, always_include_hours: bool = False, decimal_marker: str = "." +): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + ) + + +def owsm_predict( + model, + speech, + fs: int, + src_lang: str = "none", + beam_size: int = 5, + long_form: bool = False, + text_prev: str = "", +): + """Generate predictions using the OWSM model. (from URGENT Challenge) + + Args: + model (torch.nn.Module): OWSM model. + speech (np.ndarray): speech signal < 120s (time,) + fs (int): sampling rate in Hz. + src_lang (str): source language in ISO 639-2 Code. + beam_size (int): beam size used in beam search. + long_form (bool): perform long-form decoding for audios longer than 30s. + If an exception happens, it will fall back to standard decoding on the + initial 30s. 
+        text_prev (str): generation will be conditioned on this prompt if provided.
+    Returns:
+        text (str): predicted text
+    """
+    task_sym = "<asr>"
+    model.beam_search.beam_size = int(beam_size)
+
+    assert fs == 16000, (fs, 16000)
+
+    # Detect language using the first 30s of speech
+    if src_lang == "none":
+        # default 30 seconds chunk for owsm training
+        src_lang = model(
+            librosa.util.fix_length(speech, size=(TARGET_FS * CHUNK_SIZE))
+        )[0][0].strip()[1:-1]
+    lang_sym = f"<{src_lang}>"
+
+    # ASR or ST
+    if long_form:  # speech will be padded in decode_long()
+        try:
+            model.maxlenratio = -300
+            utts = model.decode_long(
+                speech,
+                condition_on_prev_text=False,
+                init_text=text_prev,
+                end_time_threshold="<29.00>",
+                lang_sym=lang_sym,
+                task_sym=task_sym,
+            )
+
+            text = []
+            for t1, t2, res in utts:
+                text.append(
+                    f"[{format_timestamp(seconds=t1)} --> "
+                    f"{format_timestamp(seconds=t2)}] {res}"
+                )
+            text = "\n".join(text)
+
+            return text
+        except Exception:
+            print(
+                "An exception occurred in long-form decoding. "
+                "Fall back to standard decoding (only first 30s)"
+            )
+
+    # assuming 10 tokens per second
+    model.maxlenratio = -min(300, int((len(speech) / TARGET_FS) * 10))
+
+    speech = librosa.util.fix_length(speech, size=(TARGET_FS * CHUNK_SIZE))
+    text = model(speech, text_prev, lang_sym=lang_sym, task_sym=task_sym)[0][-2]
+
+    return text
+
+
+def owsm_levenshtein_metric(wer_utils, pred_x, ref_text, fs=16000):
+    """Calculate the Levenshtein distance between ref and inf ASR results.
+
+    Args:
+        wer_utils (dict): a utility dict for WER calculation.
+            including: owsm model ("model"), text cleaner ("cleaner"), and
+            beam size ("beam_size")
+        pred_x (np.ndarray): test signal (time,)
+        ref_text (string): reference transcript
+        fs (int): sampling rate in Hz
+    Returns:
+        ret (dict): dictionary containing occurrences of edit operations
+    """
+    if fs != TARGET_FS:
+        pred_x = resample_audio(pred_x, fs, TARGET_FS)
+        fs = TARGET_FS
+    with torch.no_grad():
+        inf_txt = owsm_predict(
+            wer_utils["model"],
+            pred_x,
+            fs,
+            src_lang="eng",
+            beam_size=wer_utils["beam_size"],
+            long_form=len(pred_x) > CHUNK_SIZE * fs,
+        )
+
+    ref_text = wer_utils["cleaner"](ref_text).strip()
+    pred_text = wer_utils["cleaner"](inf_txt).strip()
+
+    # process wer
+    ref_words = ref_text.strip().split()
+    pred_words = pred_text.strip().split()
+    ret = {
+        "owsm_hyp_text": pred_text,
+        "ref_text": ref_text,
+        "owsm_wer_delete": 0,
+        "owsm_wer_insert": 0,
+        "owsm_wer_replace": 0,
+        "owsm_wer_equal": 0,
+    }
+    for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words):
+        if op == "insert":
+            ret["owsm_wer_" + op] = ret["owsm_wer_" + op] + inf_et - inf_st
+        else:
+            ret["owsm_wer_" + op] = ret["owsm_wer_" + op] + ref_et - ref_st
+    total = ret["owsm_wer_delete"] + ret["owsm_wer_replace"] + ret["owsm_wer_equal"]
+    assert total == len(ref_words), (total, len(ref_words))
+    total = ret["owsm_wer_insert"] + ret["owsm_wer_replace"] + ret["owsm_wer_equal"]
+    assert total == len(pred_words), (total, len(pred_words))
+
+    # process cer
+    ref_words = [c for c in ref_text]
+    pred_words = [c for c in pred_text]
+    ret.update(
+        owsm_cer_delete=0,
+        owsm_cer_insert=0,
+        owsm_cer_replace=0,
+        owsm_cer_equal=0,
+    )
+    for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words):
+        if op == "insert":
+            ret["owsm_cer_" + op] = ret["owsm_cer_" + op] + inf_et - inf_st
+        else:
+            ret["owsm_cer_" + op] = ret["owsm_cer_" + op] + ref_et - ref_st
+    total = ret["owsm_cer_delete"] + ret["owsm_cer_replace"] + ret["owsm_cer_equal"]
assert total == len(ref_words), (total, len(ref_words)) + total = ret["owsm_cer_insert"] + ret["owsm_cer_replace"] + ret["owsm_cer_equal"] + assert total == len(pred_words), (total, len(pred_words)) + + return ret + + +class OwsmWerMetric(BaseMetric): + """OWSM ASR-based WER/CER edit counts.""" + + def _setup(self): + self.model_tag = self.config.get("model_tag", "default") + self.beam_size = self.config.get("beam_size", 5) + self.text_cleaner = self.config.get("text_cleaner", "whisper_basic") + self.use_gpu = self.config.get("use_gpu", True) + self.cache_dir = self.config.get("cache_dir", "versa_cache/espnet_model_zoo") + self.wer_utils = owsm_wer_setup( + model_tag=self.model_tag, + beam_size=self.beam_size, + text_cleaner=self.text_cleaner, + use_gpu=self.use_gpu, + cache_dir=self.cache_dir, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + + metadata = metadata or {} + ref_text = metadata.get("text") + if ref_text is None and isinstance(references, str): + ref_text = references + if ref_text is None: + raise ValueError("Reference text must be provided") + + fs = metadata.get("sample_rate", 16000) + return owsm_levenshtein_metric( + self.wer_utils, + np.asarray(predictions), + ref_text, + fs=fs, + ) + + def get_metadata(self): + return _owsm_wer_metadata() + + +def _owsm_wer_metadata(): + return MetricMetadata( + name="owsm_wer", + category=MetricCategory.NON_MATCH, + metric_type=MetricType.DICT, + requires_reference=False, + requires_text=True, + gpu_compatible=True, + auto_install=False, + dependencies=["espnet2", "Levenshtein", "librosa", "numpy", "torch"], + description="OWSM ASR-based WER and CER edit counts", + paper_reference="https://arxiv.org/abs/2309.13876", + implementation_source="https://github.com/espnet/espnet", + ) + + +def register_owsm_wer_metric(registry): + """Register OWSM WER with the registry.""" + registry.register( + OwsmWerMetric, + _owsm_wer_metadata(), + aliases=["owsm_asr_wer", "owsm_wer_metric"], + ) + + +if __name__ == "__main__": + a = np.random.random(16000) + metric = OwsmWerMetric() + print(metric.compute(a, metadata={"sample_rate": 16000, "text": "test sentence"})) diff --git a/versa/corpus_metrics/whisper_wer.py b/versa/corpus_metrics/whisper_wer.py index 2a4b218..a1e0e90 100644 --- a/versa/corpus_metrics/whisper_wer.py +++ b/versa/corpus_metrics/whisper_wer.py @@ -1,139 +1,218 @@ -#!/usr/bin/env python3 - -# Copyright 2024 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -import logging - -import librosa -import numpy as np -import torch -from Levenshtein import opcodes - -try: - import whisper -except ImportError: - logging.warning( - "Whisper is not properly installed. 
Please install following https://github.com/openai/whisper" - ) - whisper = None - -from espnet2.text.cleaner import TextCleaner - -TARGET_FS = 16000 -CHUNK_SIZE = 30 # seconds - - -def whisper_wer_setup( - model_tag="default", beam_size=5, text_cleaner="whisper_basic", use_gpu=True -): - if model_tag == "default": - model_tag = "large" - device = "cuda" if use_gpu else "cpu" - if whisper is None: - raise RuntimeError( - "Whisper WER is used for evaluation while openai-whisper is not installed" - ) - model = whisper.load_model(model_tag, device=device) - textcleaner = TextCleaner(text_cleaner) - wer_utils = {"model": model, "cleaner": textcleaner, "beam_size": beam_size} - return wer_utils - - -def whisper_levenshtein_metric( - wer_utils, pred_x, ref_text, fs=16000, cache_pred_text=None -): - """Calculate the Levenshtein distance between ref and inf ASR results. - - Args: - wer_utils (dict): a utility dict for WER calculation. - including: whisper model ("model"), text cleaner ("textcleaner"), and - beam size ("beam size") - pred_x (np.ndarray): test signal (time,) - ref_text (string): reference transcript - cache_pred_text (string): transcription from cache (previous modules) - fs (int): sampling rate in Hz - Returns: - ret (dict): ditionary containing occurrences of edit operations - """ - if cache_pred_text is not None: - inf_text = cache_pred_text - else: - if fs != TARGET_FS: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS) - fs = TARGET_FS - with torch.no_grad(): - inf_text = wer_utils["model"].transcribe( - torch.tensor(pred_x).float(), beam_size=wer_utils["beam_size"] - )["text"] - - ref_text = wer_utils["cleaner"](ref_text).strip() - pred_text = wer_utils["cleaner"](inf_text).strip() - - # process wer - ref_words = ref_text.strip().split() - pred_words = pred_text.strip().split() - ret = { - "whisper_hyp_text": pred_text, - "ref_text": ref_text, - "whisper_wer_delete": 0, - "whisper_wer_insert": 0, - "whisper_wer_replace": 0, - "whisper_wer_equal": 0, - } - for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words): - if op == "insert": - ret["whisper_wer_" + op] = ret["whisper_wer_" + op] + inf_et - inf_st - else: - ret["whisper_wer_" + op] = ret["whisper_wer_" + op] + ref_et - ref_st - total = ( - ret["whisper_wer_delete"] - + ret["whisper_wer_replace"] - + ret["whisper_wer_equal"] - ) - assert total == len(ref_words), (total, len(ref_words)) - total = ( - ret["whisper_wer_insert"] - + ret["whisper_wer_replace"] - + ret["whisper_wer_equal"] - ) - assert total == len(pred_words), (total, len(pred_words)) - - # process cer - ref_words = [c for c in ref_text] - pred_words = [c for c in pred_text] - ret.update( - whisper_cer_delete=0, - whisper_cer_insert=0, - whisper_cer_replace=0, - whisper_cer_equal=0, - ) - for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words): - if op == "insert": - ret["whisper_cer_" + op] = ret["whisper_cer_" + op] + inf_et - inf_st - else: - ret["whisper_cer_" + op] = ret["whisper_cer_" + op] + ref_et - ref_st - total = ( - ret["whisper_cer_delete"] - + ret["whisper_cer_replace"] - + ret["whisper_cer_equal"] - ) - assert total == len(ref_words), (total, len(ref_words)) - total = ( - ret["whisper_cer_insert"] - + ret["whisper_cer_replace"] - + ret["whisper_cer_equal"] - ) - assert total == len(pred_words), (total, len(pred_words)) - - return ret - - -if __name__ == "__main__": - a = np.random.random(16000) - wer_utils = whisper_wer_setup() - print( - "metrics: {}".format( - 
whisper_levenshtein_metric(wer_utils, a, "test a sentence.", 16000)
-        )
-    )
+#!/usr/bin/env python3
+
+# Copyright 2024 Jiatong Shi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import logging
+
+import numpy as np
+import torch
+from Levenshtein import opcodes
+
+from versa.audio_utils import resample_audio
+from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType
+
+try:
+    import whisper
+except ImportError:
+    logging.warning(
+        "Whisper is not properly installed. Please install following "
+        "https://github.com/openai/whisper"
+    )
+    whisper = None
+
+try:
+    from espnet2.text.cleaner import TextCleaner
+except ImportError:
+    TextCleaner = None
+
+TARGET_FS = 16000
+CHUNK_SIZE = 30  # seconds
+
+
+def whisper_wer_setup(
+    model_tag="default",
+    beam_size=5,
+    text_cleaner="whisper_basic",
+    use_gpu=True,
+    cache_dir="versa_cache/whisper",
+):
+    if model_tag == "default":
+        model_tag = "large"
+    device = "cuda" if use_gpu else "cpu"
+    if whisper is None:
+        raise RuntimeError(
+            "Whisper WER is used for evaluation while openai-whisper is not installed"
+        )
+    if TextCleaner is None:
+        raise ImportError("whisper_wer requires espnet TextCleaner. Install espnet")
+    model = whisper.load_model(model_tag, device=device, download_root=cache_dir)
+    textcleaner = TextCleaner(text_cleaner)
+    wer_utils = {"model": model, "cleaner": textcleaner, "beam_size": beam_size}
+    return wer_utils
+
+
+def whisper_levenshtein_metric(
+    wer_utils, pred_x, ref_text, fs=16000, cache_pred_text=None
+):
+    """Calculate the Levenshtein distance between ref and inf ASR results.
+
+    Args:
+        wer_utils (dict): a utility dict for WER calculation,
+            including: whisper model ("model"), text cleaner ("cleaner"), and
+            beam size ("beam_size")
+        pred_x (np.ndarray): test signal (time,)
+        ref_text (string): reference transcript
+        cache_pred_text (string): transcription from cache (previous modules)
+        fs (int): sampling rate in Hz
+    Returns:
+        ret (dict): dictionary containing occurrences of edit operations
+    """
+    if cache_pred_text is not None:
+        inf_text = cache_pred_text
+    else:
+        if fs != TARGET_FS:
+            pred_x = resample_audio(pred_x, fs, TARGET_FS)
+            fs = TARGET_FS
+        with torch.no_grad():
+            inf_text = wer_utils["model"].transcribe(
+                torch.tensor(pred_x).float(), beam_size=wer_utils["beam_size"]
+            )["text"]
+
+    ref_text = wer_utils["cleaner"](ref_text).strip()
+    pred_text = wer_utils["cleaner"](inf_text).strip()
+
+    # process wer
+    ref_words = ref_text.strip().split()
+    pred_words = pred_text.strip().split()
+    ret = {
+        "whisper_hyp_text": pred_text,
+        "ref_text": ref_text,
+        "whisper_wer_delete": 0,
+        "whisper_wer_insert": 0,
+        "whisper_wer_replace": 0,
+        "whisper_wer_equal": 0,
+    }
+    for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_words, pred_words):
+        if op == "insert":
+            ret["whisper_wer_" + op] = ret["whisper_wer_" + op] + inf_et - inf_st
+        else:
+            ret["whisper_wer_" + op] = ret["whisper_wer_" + op] + ref_et - ref_st
+    total = (
+        ret["whisper_wer_delete"]
+        + ret["whisper_wer_replace"]
+        + ret["whisper_wer_equal"]
+    )
+    assert total == len(ref_words), (total, len(ref_words))
+    total = (
+        ret["whisper_wer_insert"]
+        + ret["whisper_wer_replace"]
+        + ret["whisper_wer_equal"]
+    )
+    assert total == len(pred_words), (total, len(pred_words))
+
+    # process cer
+    ref_words = [c for c in ref_text]
+    pred_words = [c for c in pred_text]
+    ret.update(
+        whisper_cer_delete=0,
+        whisper_cer_insert=0,
+        whisper_cer_replace=0,
+        whisper_cer_equal=0,
+    )
+    for op, ref_st, ref_et, inf_st,
inf_et in opcodes(ref_words, pred_words): + if op == "insert": + ret["whisper_cer_" + op] = ret["whisper_cer_" + op] + inf_et - inf_st + else: + ret["whisper_cer_" + op] = ret["whisper_cer_" + op] + ref_et - ref_st + total = ( + ret["whisper_cer_delete"] + + ret["whisper_cer_replace"] + + ret["whisper_cer_equal"] + ) + assert total == len(ref_words), (total, len(ref_words)) + total = ( + ret["whisper_cer_insert"] + + ret["whisper_cer_replace"] + + ret["whisper_cer_equal"] + ) + assert total == len(pred_words), (total, len(pred_words)) + + return ret + + +class WhisperWerMetric(BaseMetric): + """Whisper ASR-based WER/CER edit counts.""" + + def _setup(self): + self.model_tag = self.config.get("model_tag", "default") + self.beam_size = self.config.get("beam_size", 5) + self.text_cleaner = self.config.get("text_cleaner", "whisper_basic") + self.use_gpu = self.config.get("use_gpu", True) + self.cache_dir = self.config.get("cache_dir", "versa_cache/whisper") + self.wer_utils = whisper_wer_setup( + model_tag=self.model_tag, + beam_size=self.beam_size, + text_cleaner=self.text_cleaner, + use_gpu=self.use_gpu, + cache_dir=self.cache_dir, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + + metadata = metadata or {} + ref_text = metadata.get("text") + if ref_text is None and isinstance(references, str): + ref_text = references + if ref_text is None: + raise ValueError("Reference text must be provided") + + cache_pred_text = metadata.get("whisper_hyp_text") + general_cache = metadata.get("general_cache") + if cache_pred_text is None and general_cache: + cache_pred_text = general_cache.get("whisper_hyp_text") + + fs = metadata.get("sample_rate", 16000) + return whisper_levenshtein_metric( + self.wer_utils, + np.asarray(predictions), + ref_text, + fs=fs, + cache_pred_text=cache_pred_text, + ) + + def get_metadata(self): + return _whisper_wer_metadata() + + +def _whisper_wer_metadata(): + return MetricMetadata( + name="whisper_wer", + category=MetricCategory.NON_MATCH, + metric_type=MetricType.DICT, + requires_reference=False, + requires_text=True, + gpu_compatible=True, + auto_install=False, + dependencies=["whisper", "espnet2", "Levenshtein", "librosa", "numpy", "torch"], + description="Whisper ASR-based WER and CER edit counts", + paper_reference="https://arxiv.org/abs/2212.04356", + implementation_source="https://github.com/openai/whisper", + ) + + +def register_whisper_wer_metric(registry): + """Register Whisper WER with the registry.""" + registry.register( + WhisperWerMetric, + _whisper_wer_metadata(), + aliases=["whisper_asr_wer", "whisper_wer_metric"], + ) + + +if __name__ == "__main__": + a = np.random.random(16000) + metric = WhisperWerMetric() + print(metric.compute(a, metadata={"sample_rate": 16000, "text": "test sentence"})) diff --git a/versa/definition.py b/versa/definition.py new file mode 100644 index 0000000..e98045c --- /dev/null +++ b/versa/definition.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Any +from dataclasses import dataclass +from enum import Enum +import logging + + +class MetricCategory(Enum): + INDEPENDENT = "independent" + DEPENDENT = "dependent" + NON_MATCH = "non_match" + DISTRIBUTIONAL = "distributional" + + +class MetricType(Enum): + STRING = "string" + FLOAT = "float" + INT = "int" + BOOL = 
"bool" + LIST = "list" + DICT = "dict" + TUPLE = "tuple" + ARRAY = "array" + TIME = "time" + + +@dataclass +class MetricMetadata: + name: str + category: MetricCategory + metric_type: MetricType + requires_reference: bool + requires_text: bool + gpu_compatible: bool + auto_install: bool + dependencies: List[str] + description: str + paper_reference: Optional[str] = None + implementation_source: Optional[str] = None + + +class MetricRegistry: + """Centralized registry for all metrics with automatic discovery.""" + + def __init__(self): + self._metrics: Dict[str, type] = {} + self._metadata: Dict[str, MetricMetadata] = {} + self._aliases: Dict[str, str] = {} + + def register( + self, metric_class: type, metadata: MetricMetadata, aliases: List[str] = None + ): + """Register a metric with its metadata.""" + self._metrics[metadata.name] = metric_class + self._metadata[metadata.name] = metadata + + # Register aliases + if aliases: + for alias in aliases: + self._aliases[alias] = metadata.name + + def get_metric(self, name: str) -> type: + """Get metric class by name or alias.""" + real_name = self._aliases.get(name, name) + return self._metrics.get(real_name) + + def get_metadata(self, name: str) -> MetricMetadata: + """Get metric metadata by name or alias.""" + real_name = self._aliases.get(name, name) + return self._metadata.get(real_name) + + def list_metrics( + self, category: MetricCategory = None, metric_type: MetricType = None + ) -> List[str]: + """List available metrics with optional filtering.""" + metrics = [] + for name, metadata in self._metadata.items(): + if category and metadata.category != category: + continue + if metric_type and metadata.metric_type != metric_type: + continue + metrics.append(name) + return sorted(metrics) + + +class BaseMetric(ABC): + """Abstract base class for all metrics.""" + + def __init__(self, config: Dict[str, Any] = None): + self.config = config or {} + self.logger = logging.getLogger(self.__class__.__name__) + self._setup() + + @abstractmethod + def _setup(self): + """Initialize metric-specific components.""" + pass + + @abstractmethod + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Any: + """Compute the metric score.""" + pass + + @abstractmethod + def get_metadata(self) -> MetricMetadata: + """Return metric metadata.""" + pass + + def validate_inputs(self, predictions: Any, references: Any = None) -> bool: + """Validate input data before computation.""" + return True + + def preprocess(self, data: Any) -> Any: + """Preprocess data before metric computation.""" + return data + + def postprocess(self, scores: Any) -> Any: + """Postprocess scores after computation.""" + return scores + + +class GPUMetric(BaseMetric): + """Base class for GPU-compatible metrics.""" + + def __init__(self, config: Dict[str, Any] = None, device: str = "cuda"): + self.device = device + super().__init__(config) + + def to_device(self, data: Any) -> Any: + """Move data to specified device.""" + if hasattr(data, "to"): + return data.to(self.device) + return data + + +class MetricFactory: + """Factory for creating metric instances with dependency management.""" + + def __init__(self, registry: MetricRegistry): + self.registry = registry + self._dependency_cache = {} + self.logger = logging.getLogger(self.__class__.__name__) + + def create_metric(self, name: str, config: Dict[str, Any] = None) -> BaseMetric: + """Create a metric instance with proper dependency resolution.""" + metadata = self.registry.get_metadata(name) + 
+        metric_class = self.registry.get_metric(name)
+
+        if not metric_class:
+            raise ValueError(f"Metric '{name}' not found in registry")
+
+        # Check optional dependencies (warn when missing; nothing is installed here)
+        self._ensure_dependencies(metadata.dependencies)
+
+        return metric_class(config)
+
+    def create_metric_suite(
+        self, metric_names: List[str], config: Dict[str, Any] = None
+    ) -> "MetricSuite":
+        """Create a suite of metrics."""
+        metrics = {}
+        config = config or {}
+        for name in metric_names:
+            metrics[name] = self.create_metric(name, config.get(name, {}))
+        return MetricSuite(metrics)
+
+    def _ensure_dependencies(self, dependencies: List[str]):
+        """Ensure all dependencies are available."""
+        for dep in dependencies:
+            if dep not in self._dependency_cache:
+                try:
+                    __import__(dep)
+                    self._dependency_cache[dep] = True
+                except ImportError:
+                    self.logger.warning(f"Dependency '{dep}' not available")
+                    self._dependency_cache[dep] = False
+
+
+class MetricSuite:
+    """Container for multiple metrics with batch processing capabilities."""
+
+    def __init__(self, metrics: Dict[str, BaseMetric]):
+        self.metrics = metrics
+        self.logger = logging.getLogger(self.__class__.__name__)
+
+    def compute_all(
+        self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None
+    ) -> Dict[str, Any]:
+        """Compute all metrics in the suite."""
+        results = {}
+        for name, metric in self.metrics.items():
+            try:
+                results[name] = metric.compute(predictions, references, metadata)
+            except Exception as e:
+                self.logger.error(f"Error computing metric '{name}': {e}")
+                results[name] = None
+        return results
+
+    def compute_parallel(
+        self,
+        predictions: Any,
+        references: Any = None,
+        metadata: Dict[str, Any] = None,
+        n_workers: int = 4,
+    ) -> Dict[str, Any]:
+        """Compute metrics in parallel (not yet implemented)."""
+        raise NotImplementedError(
+            "Parallel metric computation is not implemented yet; "
+            "use compute_all() instead."
+        )
+
+    def filter_by_category(self, category: MetricCategory) -> "MetricSuite":
+        """Filter metrics by category."""
+        filtered_metrics = {
+            name: metric
+            for name, metric in self.metrics.items()
+            if metric.get_metadata().category == category
+        }
+        return MetricSuite(filtered_metrics)
diff --git a/versa/metrics.py b/versa/metrics.py
index 1e86b19..3257a3e 100644
--- a/versa/metrics.py
+++ b/versa/metrics.py
@@ -3,6 +3,12 @@
 # Copyright 2025 Jiatong Shi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 
+
+DICT_METRIC = [
+    "match_details",
+    "language",
+]
+
 STR_METRIC = [
     "vad_info",
     "language",
@@ -31,34 +37,13 @@
     "espnet_hyp_text",
     "owsm_hyp_text",
     "whisper_hyp_text",
-    "arecho_qwen_vocabulary_complexity",
-    "arecho_qwen_speaker_age",
-    "arecho_qwen_voice_pitch",
-    "arecho_qwen_speech_purpose",
-    "arecho_qwen_speech_emotion",
-    "arecho_qwen_language",
-    "arecho_qwen_speech_clarity",
-    "arecho_qwen_recording_quality",
-    "arecho_qwen_speech_background_environment",
-    "arecho_qwen_channel_type",
-    "arecho_qwen_speech_rate",
-    "arecho_qwen_speaker_count",
-    "arecho_qwen_speech_volume_level",
-    "arecho_qwen_speaker_gender",
-    "arecho_qwen_pitch_range",
-    "arecho_qwen_speech_register",
-    "arecho_qwen_laughter_crying",
-    "arecho_qwen_speaking_style",
-    "arecho_qwen_voice_type",
-    "arecho_qwen_speech_impairment",
-    "arecho_rir_room_size",
-    "arecho_real_language",
-    "arecho_language",
 ]
 
 NUM_METRIC = [
     "dnsmos_overall",
     "dnsmos_p808",
+    "dns_overall",
+    "dns_p808",
     "nisqa",
     "utmos",
     "plcmos",
@@ -81,9 +66,11 @@
     "audiobox_aesthetics_CU",
     "audiobox_aesthetics_PC",
     "audiobox_aesthetics_PQ",
-    "cdpam",
-    "dpam",
+    "cdpam_distance",
+    "dpam_distance",
     "mcd",
+    "f0corr",
+    "f0rmse",
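+    # NOTE: legacy spellings (e.g. "f0_corr", "ci-sdr") are kept alongside the
+    # new ones so existing user-facing score keys remain recognized.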
"f0_corr", "f0_rmse", "sir", @@ -91,8 +78,11 @@ "sdr", "ci-sdr", "si-snr", + "ci_sdr", + "si_snr", "pesq", "stoi", + "estoi", "speech_bert", "speech_belu", "speech_token_distance", @@ -135,23 +125,47 @@ "owsm_cer_equal", "whisper_wer", "whisper_wer_delete", - "espnet_wer_insert", - "espnet_wer_replace", - "espnet_wer_equal", + "whisper_wer_insert", + "whisper_wer_replace", + "whisper_wer_equal", "whisper_cer", "whisper_cer_delete", - "espnet_cer_insert", - "espnet_cer_replace", - "espnet_cer_equal", + "whisper_cer_insert", + "whisper_cer_replace", + "whisper_cer_equal", "emotion_similarity", "spk_similarity", + "singer_similarity", "nomad", "clap_score", "apa", "pysepm_llr", + "chroma_stft_cosine_dtw", + "chroma_stft_euclidean_dtw", + "chroma_cqt_cosine_dtw", + "chroma_cqt_euclidean_dtw", + "chroma_cens_cosine_dtw", + "chroma_cens_euclidean_dtw", + "chroma_stft_cosine_dtw_raw", + "chroma_stft_cosine_dtw_log", + "speech_bert", + "speech_bleu", + "speech_token_distance", + "arousal_emo_vad", + "valence_emo_vad", + "dominance_emo_vad", "dnsmos_pro_bvcc", "dnsmos_pro_nisqa", "dnsmos_pro_vcc2018", + "singmos_pro", + "nisqa_mos_pred", + "nisqa_noi_pred", + "nisqa_dis_pred", + "nisqa_col_pred", + "nisqa_loud_pred", + "noresqa_mos", + "noresqa_score", + "pam_score", "arecho_srmr", "arecho_voicemos_real_mos", "arecho_rt60", diff --git a/versa/scorer_shared.py b/versa/scorer_shared.py index cecd525..13fcd9f 100644 --- a/versa/scorer_shared.py +++ b/versa/scorer_shared.py @@ -1,1485 +1,477 @@ -#!/usr/bin/env python3 - -# Copyright 2024 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import logging - -import json -import kaldiio -import librosa -import soundfile as sf -import yaml -from tqdm import tqdm - -from versa.metrics import STR_METRIC, NUM_METRIC -from versa.utils_shared import ( - check_all_same, - check_minimum_length, - default_numpy_serializer, - find_files, - load_audio, - wav_normalize, -) - - -def audio_loader_setup(audio, io): - # get ready compute embeddings - if io == "kaldi": - audio_files = kaldiio.load_scp(audio) - elif io == "dir": - audio_files = find_files(audio) - elif io == "soundfile": - audio_files = {} - with open(audio) as f: - for line in f.readlines(): - key, value = line.strip().split(maxsplit=1) - if value.endswith("|"): - raise ValueError( - "Not supported wav.scp format. 
Set IO interface to kaldi" - ) - audio_files[key] = value - return audio_files - - -def load_score_modules(score_config, use_gt=True, use_gt_text=False, use_gpu=False): - assert score_config, "no scoring function is provided" - score_modules = {} - for config in score_config: - print(config, flush=True) - if config["name"] == "mcd_f0": - if not use_gt: - logging.warning( - "Cannot use mcd/f0 metrics because no gt audio is provided" - ) - continue - - logging.info("Loading MCD & F0 evaluation...") - from versa import mcd_f0 - - score_modules["mcd_f0"] = { - "module": mcd_f0, - "args": { - "f0min": config.get("f0min", 0), - "f0max": config.get("f0max", 24000), - "mcep_shift": config.get("mcep_shift", 5), - "mcep_fftl": config.get("mcep_fftl", 1024), - "mcep_dim": config.get("mcep_dim", 39), - "mcep_alpha": config.get("mcep_alpha", 0.466), - "seq_mismatch_tolerance": config.get("seq_mismatch_tolerance", 0.1), - "power_threshold": config.get("power_threshold", -20), - "dtw": config.get("dtw", False), - }, - } - logging.info("Initiate MCD & F0 evaluation successfully.") - - elif config["name"] == "signal_metric": - if not use_gt: - logging.warning( - "Cannot use signal metric because no gt audio is provided" - ) - continue - - logging.info("Loading signal metric evaluation...") - from versa import signal_metric - - score_modules["signal_metric"] = {"module": signal_metric} - logging.info("Initiate signal metric evaluation successfully.") - - elif config["name"] == "warpq": - if not use_gt: - logging.warning("Cannot use warpq because no gt audio is provided") - continue - - logging.info("Loading WARPQ metric evaluation...") - from versa.sequence_metrics.warpq import warpq, warpq_setup - - score_modules["warpq"] = {"model": warpq_setup(), "module": warpq} - logging.info("Initiate WARP-Q metric...") - - elif config["name"] == "nisqa": - logging.info("Loading NISQA evaluation...") - from versa.utterance_metrics.nisqa import nisqa_metric, nisqa_model_setup - - # Load the NISQA model - nisqa_model = nisqa_model_setup( - nisqa_model_path=config.get( - "model_path", "./tools/NISQA/weights/nisqa.tar" - ), - use_gpu=use_gpu, - ) - score_modules["nisqa"] = { - "module": nisqa_metric, - "model": nisqa_model, - } - logging.info("Initiate NISQA evaluation successfully.") - - elif config["name"] == "discrete_speech": - if not use_gt: - logging.warning( - "Cannot use discrete speech metric because no gt audio is provided" - ) - continue - - logging.info("Loading discrete speech evaluation...") - from versa import discrete_speech_metric, discrete_speech_setup - - score_modules["discrete_speech"] = { - "module": discrete_speech_metric, - "model": discrete_speech_setup(use_gpu=use_gpu), - } - logging.info("Initiate discrete speech evaluation successfully.") - - elif config["name"] == "pseudo_mos": - logging.info("Loading pseudo MOS evaluation...") - from versa import pseudo_mos_metric, pseudo_mos_setup - - predictor_dict, predictor_fs = pseudo_mos_setup( - use_gpu=use_gpu, - predictor_types=config.get("predictor_types", ["utmos"]), - predictor_args=config.get("predictor_args", {}), - ) - score_modules["pseudo_mos"] = { - "module": pseudo_mos_metric, - "args": { - "predictor_dict": predictor_dict, - "predictor_fs": predictor_fs, - "use_gpu": use_gpu, - }, - } - logging.info("Initiate pseudo MOS evaluation successfully.") - - elif config["name"] == "pesq": - if not use_gt: - logging.warning( - "Cannot use pesq metric because no gt audio is provided" - ) - continue - - logging.info("Loading pesq evaluation...") - 
from versa import pesq_metric - - score_modules["pesq"] = {"module": pesq_metric} - logging.info("Initiate pesq evaluation successfully.") - - elif config["name"] == "stoi": - if not use_gt: - logging.warning( - "Cannot use stoi metric because no gt audio is provided" - ) - continue - - logging.info("Loading stoi evaluation...") - from versa import stoi_metric - - score_modules["stoi"] = {"module": stoi_metric} - logging.info("Initiate stoi evaluation successfully.") - - elif config["name"] == "estoi": - if not use_gt: - logging.warning( - "Cannot use estoi metric because no gt audio is provided" - ) - continue - - logging.info("Loading stoi evaluation...") - from versa import estoi_metric - - score_modules["estoi"] = {"module": estoi_metric} - logging.info("Initiate stoi evaluation successfully.") - - elif config["name"] == "visqol": - if not use_gt: - logging.warning( - "Cannot use visqol metric because no gt audio is provided" - ) - continue - - logging.info("Loading visqol evaluation...") - try: - from versa import visqol_metric, visqol_setup - except ImportError: - logging.warning( - "VISQOL not installed, please check `tools` for installation guideline" - ) - continue - - api, fs = visqol_setup(model=config.get("model", "default")) - score_modules["visqol"] = { - "module": visqol_metric, - "args": {"api": api, "api_fs": fs}, - } - logging.info("Initiate visqol evaluation successfully.") - - elif config["name"] == "speaker": - if not use_gt: - logging.warning( - "Cannot use speaker metric because no gt audio is provided" - ) - continue - - logging.info("Loading speaker evaluation...") - from versa import speaker_metric, speaker_model_setup - - spk_model = speaker_model_setup( - model_tag=config.get("model_tag", "default"), - model_path=config.get("model_path", None), - model_config=config.get("model_config", None), - use_gpu=use_gpu, - ) - score_modules["speaker"] = { - "module": speaker_metric, - "args": {"model": spk_model}, - } - logging.info("Initiate speaker evaluation successfully.") - - elif config["name"] == "singer": - if not use_gt: - logging.warning( - "Cannot use singer metric because no gt audio is provided" - ) - continue - - logging.info("Loading singer evaluation...") - from versa import singer_metric, singer_model_setup - - singer_model = singer_model_setup( - model_name=config.get("model_name", "byol"), - model_path=config.get("model_path", None), - use_gpu=use_gpu, - torchscript=config.get("torchscript", False), - ) - - score_modules["singer"] = { - "module": singer_metric, - "args": {"model": singer_model}, - } - logging.info("Initiate singer evaluation successfully.") - - elif config["name"] == "sheet_ssqa": - logging.info("Loading Sheet SSQA models for evaluation...") - from versa import sheet_ssqa, sheet_ssqa_setup - - sheet_model = sheet_ssqa_setup( - model_tag=config.get("model_tag", "default"), - model_path=config.get("model_path", None), - model_config=config.get("model_config", None), - use_gpu=use_gpu, - ) - score_modules["sheet_ssqa"] = { - "module": sheet_ssqa, - "args": {"model": sheet_model, "use_gpu": use_gpu}, - } - logging.info("Initiate Sheet SSQA evaluation successfully.") - - elif config["name"] == "squim_ref": - if not use_gt: - logging.warning("Cannot use squim_ref because no gt audio is provided") - continue - - logging.info("Loading squim metrics with reference") - from versa import squim_metric - - score_modules["squim_ref"] = { - "module": squim_metric, - } - logging.info("Initiate torch squim (with reference) successfully") - - elif 
config["name"] == "squim_no_ref": - logging.info("Loading squim metrics with reference") - from versa import squim_metric_no_ref - - score_modules["squim_no_ref"] = { - "module": squim_metric_no_ref, - } - logging.info("Initiate torch squim (without reference) successfully") - - elif config["name"] == "espnet_wer": - if not use_gt_text: - logging.warning("Cannot use espnet_wer because no gt text is provided") - continue - - logging.info("Loading espnet_wer metric with reference text") - from versa import espnet_levenshtein_metric, espnet_wer_setup - - score_modules["espnet_wer"] = { - "module": espnet_levenshtein_metric, - "args": espnet_wer_setup( - model_tag=config.get("model_tag", "default"), - beam_size=config.get("beam_size", 1), - text_cleaner=config.get("text_cleaner", "whisper_basic"), - use_gpu=use_gpu, - ), - } - logging.info("Initiate ESPnet WER calculation successfully") - - elif config["name"] == "owsm_wer": - if not use_gt_text: - logging.warning("Cannot use owsm_wer because no gt text is provided") - continue - - logging.info("Loading owsm_wer metric with reference text") - from versa import owsm_levenshtein_metric, owsm_wer_setup - - score_modules["owsm_wer"] = { - "module": owsm_levenshtein_metric, - "args": owsm_wer_setup( - model_tag=config.get("model_tag", "default"), - beam_size=config.get("beam_size", 1), - text_cleaner=config.get("text_cleaner", "whisper_basic"), - use_gpu=use_gpu, - ), - } - logging.info("Initiate ESPnet-OWSM WER calculation successfully") - - elif config["name"] == "whisper_wer": - if not use_gt_text: - logging.warning("Cannot use whisper_wer because no gt text is provided") - continue - - logging.info("Loading whisper_wer metric with reference text") - from versa import whisper_levenshtein_metric, whisper_wer_setup - - # Load whisper model if it is already loaded - if ( - "speaking_rate" in score_modules.keys() - or "asr_matching" in score_modules.keys() - ): - args_cache = score_modules["speaking_rate"]["args"] - else: - args_cache = whisper_wer_setup( - model_tag=config.get("model_tag", "default"), - beam_size=config.get("beam_size", 1), - text_cleaner=config.get("text_cleaner", "whisper_basic"), - use_gpu=use_gpu, - ) - - score_modules["whisper_wer"] = { - "module": whisper_levenshtein_metric, - "args": args_cache, - } - logging.info("Initiate Whisper WER calculation successfully") - - elif config["name"] == "scoreq_ref": - if not use_gt: - logging.warning("Cannot use scoreq_ref because no gt audio is provided") - continue - - logging.info("Loading scoreq metrics with reference") - from versa import scoreq_ref, scoreq_ref_setup - - model = scoreq_ref_setup( - data_domain=config.get("data_domain", "synthetic"), - cache_dir=config.get("model_cache", "versa_cache/scoreq_pt-models"), - use_gpu=use_gpu, - ) - - score_modules["scoreq_ref"] = { - "module": scoreq_ref, - "model": model, - } - logging.info("Initiate scoreq (with reference) successfully") - - elif config["name"] == "scoreq_nr": - logging.info("Loading scoreq metrics without reference") - from versa import scoreq_nr, scoreq_nr_setup - - model = scoreq_nr_setup( - data_domain=config.get("data_domain", "synthetic"), - cache_dir=config.get("model_cache", "versa_cache/scoreq_pt-models"), - use_gpu=use_gpu, - ) - - score_modules["scoreq_nr"] = { - "module": scoreq_nr, - "model": model, - } - logging.info("Initiate scoreq (with reference) successfully") - - elif config["name"] == "nomad": - if not use_gt: - logging.warning("Cannot use nomad because no gt audio is provided") - continue - - 
logging.info("Loading nomad metrics with reference") - from versa import nomad, nomad_setup - - model = nomad_setup( - cache_dir=config.get("model_cache", "versa_cache/nomad_pt-models"), - use_gpu=use_gpu, - ) - - score_modules["nomad"] = { - "module": nomad, - "model": model, - } - logging.info("Initiate nomad successfully") - - elif config["name"] == "emo2vec_similarity": - if not use_gt: - logging.warning( - "Cannot use emo2vec_similarity metric because no gt audio is provided" - ) - continue - - logging.info("Loading emo2vec metrics with reference") - from versa import emo2vec_setup, emo_sim - - model = emo2vec_setup( - model_tag=config.get("model_tag", "default"), - model_path=config.get("model_path", None), - use_gpu=use_gpu, - ) - - score_modules["emotion"] = { - "module": emo_sim, - "model": model, - } - logging.info("Initiate emo2vec successfully") - - elif config["name"] == "w2v2_dimensional_emotion": - from versa import w2v2_emo_dim_setup, w2v2_emo_dim_metric - - args_cache = w2v2_emo_dim_setup() - score_modules["w2v2_dimensional_emotion"] = { - "module": w2v2_emo_dim_metric, - "args": args_cache, - } - logging.info("Initiate w2v2_dimensional_emotion successfully") - - elif config["name"] == "se_snr": - logging.info("Loading se_snr metrics with reference") - from versa import se_snr, se_snr_setup - - model = se_snr_setup( - model_tag=config.get("model_tag", "default"), - model_path=config.get("model_path", None), - use_gpu=use_gpu, - ) - - score_modules["se_snr"] = { - "module": se_snr, - "model": model, - } - logging.info("Initiate se_snr successfully") - - elif config["name"] == "pam": - logging.info("Loading pam metric without reference...") - from versa.utterance_metrics.pam import pam_metric, pam_model_setup - - pam_model = pam_model_setup(model_config=config, use_gpu=use_gpu) - score_modules["pam"] = { - "module": pam_metric, - "model": pam_model, - } - logging.info("Initiate pam metric successfully.") - elif config["name"] == "vad": - logging.info("Loading vad metric without reference...") - from versa.utterance_metrics.vad import vad_metric, vad_model_setup - - vad_model = vad_model_setup( - threshold=config.get("threshold", 0.5), - min_speech_duration_ms=config.get("min_speech_duration_ms", 250), - max_speech_duration_s=config.get("max_speech_duration_s", float("inf")), - min_silence_duration_ms=config.get("min_silence_duration_ms", 100), - speech_pad_ms=config.get("speech_pad_ms", 30), - ) - score_modules["vad"] = { - "module": vad_metric, - "args": vad_model, - } - logging.info("Initiate vad metric successfully.") - - elif config["name"] == "asvspoof_score": - logging.info("Loading asvspoof score metric without reference...") - from versa.utterance_metrics.asvspoof_score import ( - asvspoof_metric, - deepfake_detection_model_setup, - ) - - deepfake_detection_model = deepfake_detection_model_setup(use_gpu=use_gpu) - score_modules["asvspoof_score"] = { - "module": asvspoof_metric, - "model": deepfake_detection_model, - } - logging.info("Initiate asvspoof score metric successfully.") - - elif config["name"] == "pysepm": - if not use_gt: - logging.warning("Cannot use pysepm because no gt audio is provided") - continue - - logging.info("Loading pysepm metrics with reference") - from versa import pysepm_metric - - score_modules["pysepm"] = { - "module": pysepm_metric, - "args": { - "frame_len": config.get("frame_len", 0.03), - "overlap": config.get("overlap", 0.75), - }, - } - logging.info("Initiate pysepm successfully") - - elif config["name"] == "srmr": - 
logging.info("Loading srmr metrics with reference") - from versa import srmr_metric - - score_modules["srmr"] = { - "module": srmr_metric, - "args": { - "n_cochlear_filters": config.get("n_cochlear_filters", 23), - "low_freq": config.get("low_freq", 125), - "min_cf": config.get("min_cf", 128), - "max_cf": config.get("max_cf", 128), - "fast": config.get("fast", True), - "norm": config.get("norm", False), - }, - } - logging.info("Initiate srmr successfully") - - elif config["name"] == "noresqa": - if not use_gt: - logging.warning("Cannot use noresqa because no gt audio is provided") - continue - - logging.info("Loading noresqa metrics with reference") - - from versa.utterance_metrics.noresqa import ( - noresqa_metric, - noresqa_model_setup, - ) - - noresqa_model = noresqa_model_setup( - metric_type=config.get("metric_type", 0), - cache_dir=config.get("cache_dir", "versa_cache/noresqa_model"), - use_gpu=use_gpu, - ) - score_modules["noresqa"] = { - "module": noresqa_metric, - "args": { - "metric_type": config.get("metric_type", 0), - "model": noresqa_model, - }, - } - logging.info("Initiate noresqa score metric successfully.") - - elif config["name"] == "speaking_rate": - logging.info("Loading speaking rate metrics without reference") - from versa import speaking_rate_metric, speaking_rate_model_setup - - # Load whisper model if it is already loaded - if "whisper_wer" in score_modules.keys(): - speaking_rate_model = score_modules["whisper_wer"]["args"] - else: - speaking_rate_model = speaking_rate_model_setup( - model_tag=config.get("model_tag", "default"), - beam_size=config.get("beam_size", 1), - text_cleaner=config.get("text_cleaner", "whisper_basic"), - use_gpu=use_gpu, - ) - - score_modules["speaking_rate"] = { - "module": speaking_rate_metric, - "args": speaking_rate_model, - } - logging.info("Initiate speaking rate metric successfully.") - - elif config["name"] == "asr_match": - if not use_gt: - logging.warning("Cannot use asr_match because no gt audio is provided") - continue - - logging.info("Loading asr_match metric with reference text") - from versa import asr_match_metric, asr_match_setup - - # Load whisper model if it is already loaded - if "whisper_wer" in score_modules.keys(): - asr_model = score_modules["whisper_wer"]["args"] - elif "speaking_rate" in score_modules.keys(): - asr_model = score_modules["speaking_rate"]["args"] - else: - asr_model = asr_match_setup( - model_tag=config.get("model_tag", "default"), - beam_size=config.get("beam_size", 1), - text_cleaner=config.get("text_cleaner", "whisper_basic"), - use_gpu=use_gpu, - ) - - score_modules["asr_match"] = { - "module": asr_match_metric, - "args": asr_model, - } - logging.info("Initiate asr_match metric successfully") - - elif config["name"] == "lid": - logging.info("Loading language identification metric") - from versa import language_id, owsm_lid_model_setup - - owsm_model = owsm_lid_model_setup( - model_tag=config.get("model_tag", "default"), - nbest=config.get("nbest", 3), - use_gpu=use_gpu, - ) - - score_modules["lid"] = { - "module": language_id, - "args": owsm_model, - } - - elif config["name"] == "audiobox_aesthetics": - logging.info("Loading audiobox aesthetics metric") - from versa import audiobox_aesthetics_score, audiobox_aesthetics_setup - - audiobox_model = audiobox_aesthetics_setup( - model_path=config.get("model_path", None), - batch_size=config.get("batch_size", 1), - precision=config.get("precision", "bf16"), - cache_dir=config.get("cache_dir", "versa_cache/audiobox"), - 
use_huggingface=config.get("use_huggingface", True), - use_gpu=use_gpu, - ) - - score_modules["audiobox_aesthetics"] = { - "module": audiobox_aesthetics_score, - "args": {"model": audiobox_model}, - } - logging.info("Initiate audiobox aesthetics metric successfully") - - elif config["name"] == "cdpam": - if not use_gt: - logging.warning( - "Cannot use cdpam metrics because no gt audio is provided" - ) - continue - logging.info("Loading cdpam evaluation...") - from versa import cdpam_metric, cdpam_model_setup - - cdpam_model = cdpam_model_setup(use_gpu=use_gpu) - score_modules["cdpam"] = { - "module": cdpam_metric, - "args": {"model": cdpam_model}, - } - logging.info("Initiate cdpam evaluation successfully.") - - elif config["name"] == "dpam": - if not use_gt: - logging.warning( - "Cannot use dpam metrics because no gt audio is provided" - ) - continue - logging.info("Loading dpam evaluation...") - from versa import dpam_metric, dpam_model_setup - - dpam_model = dpam_model_setup(use_gpu=use_gpu) - score_modules["dpam"] = { - "module": dpam_metric, - "args": {"model": dpam_model}, - } - logging.info("Initiate dpam evaluation successfully.") - - elif "qwen_omni" in config["name"]: - logging.info("Loading qwen omni model") - from versa import qwen_omni_model_setup - - if "qwen_omni" not in score_modules.keys(): - qwen_omni_model = qwen_omni_model_setup( - model_tag=config.get("model_tag", "default"), - ) - score_modules["qwen_omni"] = { - "module": qwen_omni_model, - "start_prompt": config.get("start_prompt", None), - } - - if config["name"] == "qwen_omni_singing_technique": - from versa import qwen_omni_singing_technique_metric - - score_modules["qwen_omni_singing_technique"] = { - "module": qwen_omni_singing_technique_metric, - "prompt": config.get("prompt", None), - } - # To add qwen-omni modules for others - - elif "qwen2_audio" in config["name"]: - logging.info("Loading qwen2-audio model") - from versa import qwen2_model_setup - - if "qwen2_audio" not in score_modules.keys(): - qwen_model = qwen2_model_setup( - model_tag=config.get("model_tag", "default"), - ) - score_modules["qwen2_audio"] = { - "module": qwen_model, - "start_prompt": config.get("start_prompt", None), - } - - # 1. Speaker Characteristics - if config["name"] == "qwen2_audio_speaker_count": - from versa import qwen2_speaker_count_metric - - score_modules["qwen2_audio_speaker_count"] = { - "module": qwen2_speaker_count_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speaker_gender": - from versa import qwen2_speaker_gender_metric - - score_modules["qwen2_audio_speaker_gender"] = { - "module": qwen2_speaker_gender_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speaker_age": - from versa import qwen2_speaker_age_metric - - score_modules["qwen2_audio_speaker_age"] = { - "module": qwen2_speaker_age_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_impairment": - from versa import qwen2_speech_impairment_metric - - score_modules["qwen2_audio_speech_impairment"] = { - "module": qwen2_speech_impairment_metric, - "prompt": config.get("prompt", None), - } - - # 2. 
Voice Properties - elif config["name"] == "qwen2_audio_voice_pitch": - from versa import qwen2_voice_pitch_metric - - score_modules["qwen2_audio_voice_pitch"] = { - "module": qwen2_voice_pitch_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_pitch_range": - from versa import qwen2_pitch_range_metric - - score_modules["qwen2_audio_pitch_range"] = { - "module": qwen2_pitch_range_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_voice_type": - from versa import qwen2_voice_type_metric - - score_modules["qwen2_audio_voice_type"] = { - "module": qwen2_voice_type_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_volume_level": - from versa import qwen2_speech_volume_level_metric - - score_modules["qwen2_audio_speech_volume_level"] = { - "module": qwen2_speech_volume_level_metric, - "prompt": config.get("prompt", None), - } - - # 3. Speech Content - elif config["name"] == "qwen2_audio_language": - from versa import qwen2_language_metric - - score_modules["qwen2_audio_language"] = { - "module": qwen2_language_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_register": - from versa import qwen2_speech_register_metric - - score_modules["qwen2_audio_speech_register"] = { - "module": qwen2_speech_register_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_vocabulary_complexity": - from versa import qwen2_vocabulary_complexity_metric - - score_modules["qwen2_audio_vocabulary_complexity"] = { - "module": qwen2_vocabulary_complexity_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_purpose": - from versa import qwen2_speech_purpose_metric - - score_modules["qwen2_audio_speech_purpose"] = { - "module": qwen2_speech_purpose_metric, - "prompt": config.get("prompt", None), - } - - # 4. Speech Delivery - elif config["name"] == "qwen2_audio_speech_emotion": - from versa import qwen2_speech_emotion_metric - - score_modules["qwen2_audio_speech_emotion"] = { - "module": qwen2_speech_emotion_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_clarity": - from versa import qwen2_speech_clarity_metric - - score_modules["qwen2_audio_speech_clarity"] = { - "module": qwen2_speech_clarity_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speech_rate": - from versa import qwen2_speech_rate_metric - - score_modules["qwen2_audio_speech_rate"] = { - "module": qwen2_speech_rate_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_speaking_style": - from versa import qwen2_speaking_style_metric - - score_modules["qwen2_audio_speaking_style"] = { - "module": qwen2_speaking_style_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_laughter_crying": - from versa import qwen2_laughter_crying_metric - - score_modules["qwen2_audio_laughter_crying"] = { - "module": qwen2_laughter_crying_metric, - "prompt": config.get("prompt", None), - } - - # 5. Interaction Patterns - elif config["name"] == "qwen2_audio_overlapping_speech": - from versa import qwen2_overlapping_speech_metric - - score_modules["qwen2_audio_overlapping_speech"] = { - "module": qwen2_overlapping_speech_metric, - "prompt": config.get("prompt", None), - } - - # 6. 
Recording Environment - elif config["name"] == "qwen2_audio_speech_background_environment": - from versa import qwen2_speech_background_environment_metric - - score_modules["qwen2_audio_speech_background_environment"] = { - "module": qwen2_speech_background_environment_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_recording_quality": - from versa import qwen2_recording_quality_metric - - score_modules["qwen2_audio_recording_quality"] = { - "module": qwen2_recording_quality_metric, - "prompt": config.get("prompt", None), - } - elif config["name"] == "qwen2_audio_channel_type": - from versa import qwen2_channel_type_metric - - score_modules["qwen2_audio_channel_type"] = { - "module": qwen2_channel_type_metric, - "prompt": config.get("prompt", None), - } - - # 7. Vocal Evaluation - elif config["name"] == "qwen2_audio_singing_technique": - from versa import qwen2_singing_technique_metric - - score_modules["qwen2_audio_singing_technique"] = { - "module": qwen2_singing_technique_metric, - "prompt": config.get("prompt", None), - } - - logging.info( - "Initiate qwen2 audio metric: {} successfully".format(config["name"]) - ) - elif "chroma_alignment" in config["name"]: - from versa import chroma_metric - - score_modules["chroma_alignment"] = { - "module": chroma_metric, - "args": { - "scale_factor": config.get("scale_factor", 100), - }, - } - elif "wvmos" in config["name"]: - logging.info("Loading WVMOS metric") - from versa import wvmos_setup, wvmos_calculate - - model = wvmos_setup( - use_gpu=use_gpu, - ) - score_modules["wvmos"] = { - "module": wvmos_calculate, - "args": {"model": model}, - } - logging.info("Initiate WVMOS metric successfully") - elif "sigmos" in config["name"]: - logging.info("Loading SIGMOS metric") - from versa import sigmos_setup, sigmos_calculate - - model = sigmos_setup() - - score_modules["sigmos"] = { - "module": sigmos_calculate, - "args": {"model": model}, - } - logging.info("Initiate SIGMOS metric successfully") - elif "vqscore" in config["name"]: - logging.info("Loading VQScore model") - from versa import vqscore_metric, vqscore_setup - - vqscore_model = vqscore_setup(use_gpu=use_gpu) - score_modules["vqscore"] = { - "module": vqscore_metric, - "args": {"model": vqscore_model}, - } - logging.info("Initiate VQScore evaluation successfully.") - - elif config["name"] == "universa_noref": - logging.info("Loading Universa no-reference model...") - from versa.utterance_metrics.universa import ( - universa_noref_metric, - universa_model_setup, - ) - - universa_model = universa_model_setup( - model_tag=config.get("model_tag", "noref"), - use_gpu=use_gpu, - ) - score_modules["universa_noref"] = { - "module": universa_noref_metric, - "model": universa_model, - } - logging.info("Initiate Universa no-reference evaluation successfully.") - - elif config["name"] == "universa_audioref": - if not use_gt: - logging.warning( - "Cannot use universa_audioref because no gt audio is provided" - ) - continue - - logging.info("Loading Universa audio reference model...") - from versa.utterance_metrics.universa import ( - universa_audioref_metric, - universa_model_setup, - ) - - universa_model = universa_model_setup( - model_tag=config.get("model_tag", "audioref"), - use_gpu=use_gpu, - ) - score_modules["universa_audioref"] = { - "module": universa_audioref_metric, - "model": universa_model, - } - logging.info("Initiate Universa audio reference evaluation successfully.") - - elif config["name"] == "universa_textref": - if not use_gt_text: - 
logging.warning( - "Cannot use universa_textref because no gt text is provided" - ) - continue - - logging.info("Loading Universa text reference model...") - from versa.utterance_metrics.universa import ( - universa_textref_metric, - universa_model_setup, - ) - - universa_model = universa_model_setup( - model_tag=config.get("model_tag", "textref"), - use_gpu=use_gpu, - ) - score_modules["universa_textref"] = { - "module": universa_textref_metric, - "model": universa_model, - } - logging.info("Initiate Universa text reference evaluation successfully.") - - elif config["name"] == "universa_fullref": - if not use_gt or not use_gt_text: - logging.warning( - "Cannot use universa_fullref because no gt audio or text is provided" - ) - continue - - logging.info("Loading Universa full reference model...") - from versa.utterance_metrics.universa import ( - universa_fullref_metric, - universa_model_setup, - ) - - universa_model = universa_model_setup( - model_tag=config.get("model_tag", "fullref"), - use_gpu=use_gpu, - ) - score_modules["universa_fullref"] = { - "module": universa_fullref_metric, - "model": universa_model, - } - logging.info("Initiate Universa full reference evaluation successfully.") - - elif config["name"] == "arecho": - logging.info("Loading ARECHO no-reference model...") - from versa.utterance_metrics.arecho import ( - arecho_noref_metric, - arecho_model_setup, - ) - - arecho_model = arecho_model_setup( - model_tag=config.get("model_tag", "base_v0"), - use_gpu=use_gpu, - ) - score_modules["arecho"] = { - "module": arecho_noref_metric, - "model": arecho_model, - } - logging.info("Initiate ARECHO no-reference evaluation successfully.") - - return score_modules - - -def process_cache_info(cache_info, score_modules, output_file): - batch_score_info = [] - for utt_info in cache_info: - key, gen_wav, gt_wav, gen_sr, text = utt_info - utt_score = {"key": key} - try: - utt_score.update( - use_score_modules(score_modules, gen_wav, gt_wav, gen_sr, text) - ) - except Exception as e: - print("error processing file: {} with error {}".format(key, e)) - batch_score_info.append(utt_score) - if output_file is not None: - printable_result = json.dumps(utt_score, default=default_numpy_serializer) - output_file.write(f"{printable_result}\n") - return batch_score_info - - -def use_score_modules(score_modules, gen_wav, gt_wav, gen_sr, text=None): - utt_score = {} - - # general cache information to reduce recaculation - general_cache = { - "whisper_hyp_text": None, - } - for key in score_modules.keys(): - if key == "mcd_f0" or key == "chroma_alignment": - score = score_modules[key]["module"]( - gen_wav, gt_wav, gen_sr, **score_modules[key]["args"] - ) - elif key == "signal_metric": - try: - score = score_modules[key]["module"](gen_wav, gt_wav) - except ValueError as e: - logging.warning( - "Value error in signal metric. Usually due to silence audio: {}".format( - e - ) - ) - continue - elif key == "warpq": - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gt_wav, gen_sr - ) - elif key == "nisqa": - try: - score = score_modules[key]["module"]( - score_modules[key]["model"], - gen_wav, - gen_sr, - ) - except ValueError as e: - logging.warning( - "Value error in NISQA metric. 
Usually due to silence audio: {}".format( - e - ) - ) - continue - elif key == "discrete_speech": - score = score_modules[key]["module"]( - score_modules[key]["model"], - gen_wav, - gt_wav, - gen_sr, - ) - elif key == "pseudo_mos": - score = score_modules[key]["module"]( - gen_wav, gen_sr, **score_modules[key]["args"] - ) - elif key in ["pesq", "stoi", "estoi"]: - score = score_modules[key]["module"](gen_wav, gt_wav, gen_sr) - elif key == "visqol": - score = score_modules[key]["module"]( - score_modules[key]["args"]["api"], - score_modules[key]["args"]["api_fs"], - gen_wav, - gt_wav, - gen_sr, - ) - elif key == "speaker" or key == "singer": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], gen_wav, gt_wav, gen_sr - ) - elif key == "sheet_ssqa": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], - gen_wav, - gen_sr, - use_gpu=score_modules[key]["args"]["use_gpu"], - ) - elif key == "squim_ref": - score = score_modules[key]["module"](gen_wav, gt_wav, gen_sr) - elif key == "squim_no_ref": - score = score_modules[key]["module"](gen_wav, gen_sr) - elif key == "nomad": - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gt_wav, gen_sr - ) - elif key == "espnet_wer" or key == "owsm_wer" or key == "whisper_wer": - score = score_modules[key]["module"]( - score_modules[key]["args"], - gen_wav, - text, - gen_sr, - ) - if key == "whisper_wer": - general_cache["whisper_hyp_text"] = score["whisper_hyp_text"] - elif key in ["scoreq_ref", "emotion"]: - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gt_wav, gen_sr - ) - elif key in ["scoreq_nr", "se_snr"]: - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gen_sr - ) - elif key in ["pam", "asvspoof_score"]: - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, fs=gen_sr - ) - elif key in ["vad", "lid", "w2v2_dimensional_emotion"]: - score = score_modules[key]["module"]( - score_modules[key]["args"], - gen_wav, - gen_sr, - ) - elif key == "pysepm": - score = score_modules[key]["module"](gen_wav, gt_wav, fs=gen_sr) - elif key == "srmr": - score = score_modules[key]["module"](gen_wav, fs=gen_sr) - elif key == "noresqa": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], - gen_wav, - gt_wav, - fs=gen_sr, - metric_type=score_modules[key]["args"]["metric_type"], - ) - elif key == "speaking_rate": - cache_text = None - if general_cache.get("whisper_hyp_text", None) is not None: - cache_text = utt_score["whisper_hyp_text"] - score = score_modules[key]["module"]( - score_modules[key]["args"], - gen_wav, - cache_text, - gen_sr, - ) - if cache_text is None: - general_cache["whisper_hyp_text"] = score["whisper_hyp_text"] - elif key == "asr_match": - cache_text = None - if general_cache.get("whisper_hyp_text", None) is not None: - cache_text = utt_score["whisper_hyp_text"] - score = score_modules[key]["module"]( - score_modules[key]["args"], - gen_wav, - gt_wav, - cache_text, - gen_sr, - ) - if cache_text is None: - general_cache["whisper_hyp_text"] = score["whisper_hyp_text"] - elif key == "audiobox_aesthetics": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], - gen_wav, - gen_sr, - ) - elif key == "dpam" or key == "cdpam": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], - gen_wav, - gt_wav, - gen_sr, - ) - - elif "qwen2_audio" in key: - if key == "qwen2_audio": - continue # skip the base model, only use the 
specific metrics - # Support qwen2_audio metrics - score = score_modules[key]["module"]( - score_modules["qwen2_audio"]["module"], - gen_wav, - gen_sr, - custom_prompt=score_modules[key]["prompt"], - ) - elif "qwen_omni" in key: - if key == "qwen_omni": - continue - score = score_modules[key]["module"]( - score_modules["qwen_omni"]["module"], - gen_wav, - gen_sr, - custom_prompt=score_modules[key]["prompt"], - ) - elif key == "wvmos": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], - gen_wav, - gen_sr, - ) - elif key == "sigmos": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], - gen_wav, - gen_sr, - ) - elif key == "vqscore": - score = score_modules[key]["module"]( - score_modules[key]["args"]["model"], gen_wav, gen_sr - ) - elif key == "universa_noref": - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gen_sr - ) - elif key == "universa_audioref": - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gen_sr, gt_wav - ) - elif key == "universa_textref": - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gen_sr, ref_text=text - ) - elif key == "universa_fullref": - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gen_sr, gt_wav, text - ) - elif key == "arecho": - score = score_modules[key]["module"]( - score_modules[key]["model"], gen_wav, gen_sr - ) - else: - raise NotImplementedError( - f"Not supported metrics: {key}, check egs/separate_metrics/README.md for supported metrics" - ) - - logging.info(f"Score for {key} is {score}") - utt_score.update(score) - return utt_score - - -def list_scoring( - gen_files, - score_modules, - gt_files=None, - text_info=None, - output_file=None, - io="kaldi", - batch_size=1, -): - if output_file is not None: - f = open(output_file, "w", encoding="utf-8") - else: - f = None - - score_info = [] - cache_info = [] # for batch processing - for key in tqdm(gen_files.keys()): - try: - # Step1: load source speech and conduct basic checks - gen_sr, gen_wav = load_audio(gen_files[key], io) - gen_wav = wav_normalize(gen_wav) - except Exception as e: - print(f"Error loading audio file for key '{key}': {gen_files[key]}") - print(f"Error details: {e}") - continue # Skip this file and move to the next one - - # length check - if not check_minimum_length(gen_wav.shape[0] / gen_sr, score_modules.keys()): - logging.warning( - "audio {} (generated, length {}) is too short to be evaluated with some metric metrics, skipping".format( - key, gen_wav.shape[0] / gen_sr - ) - ) - continue - - # Step2: load reference (gt) speech and conduct basic checks - if gt_files is not None: - if key not in gen_files.keys(): - logging.warning( - "key {} not found in ground truth files though provided, skipping".format( - key - ) - ) - continue - - gt_sr, gt_wav = load_audio(gt_files[key], io) - gt_wav = wav_normalize(gt_wav) - - # check ground truth audio files - if check_all_same(gt_wav): - logging.warning( - "gt audio of key {} has only the same value, skipping".format(key) - ) - continue - - # length check - if not check_minimum_length(gt_wav.shape[0] / gt_sr, score_modules.keys()): - logging.warning( - "audio {} (ground truth, length {}) is too short to be evaluated with many metrics, skipping".format( - key, gt_wav.shape[0] / gt_sr - ) - ) - continue - else: - gt_wav = None - gt_sr = None - - # Step3: load text information if provided - text = None - if text_info is not None: - if key not in text_info.keys(): - 
logging.warning( - "key {} not found in ground truth transcription though provided, skipping".format( - key - ) - ) - continue - else: - text = text_info[key] - - # Step4: check if the sampling rate of generated and gt audio are the same - if gt_sr is not None and gen_sr > gt_sr: - logging.warning( - "Resampling the generated audio to match the ground truth audio" - ) - gen_wav = librosa.resample(gen_wav, orig_sr=gen_sr, target_sr=gt_sr) - gen_sr = gt_sr - elif gt_sr is not None and gen_sr < gt_sr: - logging.warning( - "Resampling the ground truth audio to match the generated audio" - ) - gt_wav = librosa.resample(gt_wav, orig_sr=gt_sr, target_sr=gen_sr) - - # Step5: cache for batch processing - utterance_info = (key, gen_wav, gt_wav, gen_sr, text) - - cache_info.append(utterance_info) - if len(cache_info) == batch_size: - # Process after a batch is collected - score_info.extend(process_cache_info(cache_info, score_modules, f)) - cache_info = [] - else: - # continue collect the batch - continue - - # Process left-over batch - score_info.extend(process_cache_info(cache_info, score_modules, f)) - - logging.info("Scoring completed and save score at {}".format(output_file)) - return score_info - - -def load_summary(score_info): - summary = {} - for key in score_info[0].keys(): - if key in STR_METRIC or key == "key": - # NOTE(jiatong): skip text cases - continue - summary[key] = sum([score[key] for score in score_info]) - if "_wer" not in key and "_cer" not in key: - # Average for non-WER/CER metrics - summary[key] /= len(score_info) - return summary - - -def load_corpus_modules( - score_config, cache_folder="versa_cache", use_gpu=False, io="kaldi" -): - score_modules = {} - for config in score_config: - if config["name"] == "fad": - logging.info("Loading FAD evaluation with specific models...") - # TODO(jiatong): fad will automatically use cuda if detected - # need to sync to the same space - from versa import fad_scoring, fad_setup - - fad_info = fad_setup( - fad_embedding=config.get("fad_embedding", "default"), - baseline=config.get("baseline_audio", "missing"), - cache_dir=config.get("cache_dir", cache_folder), - use_inf=config.get("use_inf", False), - io=io, - ) - - fad_key = "fad_{}".format(config.get("model", "default")) - - score_modules[fad_key] = { - "module": fad_scoring, - "args": fad_info, - } - logging.info( - "Initiate {} calculation evaluation successfully.".format(fad_key) - ) - elif config["name"] == "kid": - logging.info("Loading KID evaluation with specific models...") - from versa import kid_scoring, kid_setup - - kid_info = kid_setup( - model_tag=config.get("model_tag", "default"), - model_path=config.get("model_path", None), - model_config=config.get("model_config", None), - use_gpu=use_gpu, - ) - kid_key = "kid_{}".format(config.get("model", "default")) - score_modules[kid_key] = { - "module": kid_scoring, - "args": kid_info, - } - logging.info( - "Initiate {} calculation evaluation successfully.".format(kid_key) - ) - - return score_modules - - -def corpus_scoring( - gen_files, - score_modules, - base_files=None, - text_info=None, - output_file=None, -): - score_info = {} - for key in score_modules.keys(): - if key.startswith("fad"): - fad_info = score_modules[key]["args"] - if base_files is not None: - fad_info["baseline"] = base_files - elif fad_info["baseline"] == "missing": - raise ValueError("Baseline audio not provided for FAD") - score_result = score_modules[key]["module"]( - gen_files, fad_info, key_info=key - ) - elif key.startswith("kld"): - kid_info = 
score_modules[key]["args"] - if base_files is not None: - kid_info["baseline"] = base_files - elif kid_info["baseline"] == "missing": - raise ValueError("Baseline audio not provided for FAD") - score_result = score_modules[key]["module"]( - gen_files, kid_info, key_info=key - ) - else: - raise NotImplementedError("Not supported {}".format(key)) - score_info.update(score_result) - - if output_file is not None: - with open(output_file, "w") as f: - yaml.dump(score_info, f) - return score_info +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import logging +import json +import kaldiio +import soundfile as sf +import yaml +from typing import Dict, List, Optional, Any, Union +from tqdm import tqdm + +from versa.audio_utils import resample_audio +from versa.definition import ( + BaseMetric, + GPUMetric, + MetricRegistry, + MetricFactory, + MetricSuite, + MetricCategory, + MetricType, + MetricMetadata, +) +from versa.metrics import STR_METRIC, NUM_METRIC +from versa.utils_shared import ( + check_all_same, + check_minimum_length, + default_numpy_serializer, + find_files, + load_audio, + wav_normalize, +) + + +def audio_loader_setup(audio, io): + # get ready compute embeddings + if io == "kaldi": + audio_files = kaldiio.load_scp(audio) + elif io == "dir": + audio_files = find_files(audio) + elif io == "soundfile": + audio_files = {} + with open(audio) as f: + for line in f.readlines(): + key, value = line.strip().split(maxsplit=1) + if value.endswith("|"): + raise ValueError( + "Not supported wav.scp format. Set IO interface to kaldi" + ) + audio_files[key] = value + return audio_files + + +def _create_populated_registry() -> MetricRegistry: + """Create a registry populated with all importable package metrics.""" + import versa as versa_package + + registry = MetricRegistry() + for name in dir(versa_package): + if not name.startswith("register_") or not name.endswith("_metric"): + continue + register_fn = getattr(versa_package, name) + if not callable(register_fn): + continue + try: + register_fn(registry) + except Exception as e: + logging.getLogger(__name__).warning( + "Failed to register metric via %s: %s", name, e + ) + return registry + + +def load_score_modules( + score_config: List[Dict[str, Any]], + use_gt: bool = True, + use_gt_text: bool = False, + use_gpu: bool = False, +) -> MetricSuite: + """Legacy wrapper for loading utterance-level scoring modules.""" + assert score_config, "no scoring function is provided" + scorer = VersaScorer(_create_populated_registry()) + return scorer.load_metrics( + score_config, + use_gt=use_gt, + use_gt_text=use_gt_text, + use_gpu=use_gpu, + ) + + +def list_scoring( + gen_files: Dict[str, str], + score_modules: MetricSuite, + gt_files: Optional[Dict[str, str]] = None, + text_info: Optional[Dict[str, str]] = None, + output_file: Optional[str] = None, + io: str = "kaldi", + batch_size: int = 1, +) -> List[Dict[str, Any]]: + """Legacy wrapper for scoring a list of utterances.""" + scorer = VersaScorer(_create_populated_registry()) + return scorer.score_utterances( + gen_files, + score_modules, + gt_files=gt_files, + text_info=text_info, + output_file=output_file, + io=io, + batch_size=batch_size, + ) + + +def load_summary(score_info: List[Dict[str, Any]]) -> Dict[str, Any]: + """Legacy alias for summary computation.""" + return compute_summary(score_info) + + +def load_corpus_modules( + score_config: List[Dict[str, Any]], + use_gpu: bool = False, + cache_folder: Optional[str] = None, + io: 
str = "kaldi", +) -> Dict[str, Any]: + """Legacy wrapper for loading corpus-level scoring modules.""" + assert score_config, "no scoring function is provided" + score_modules = {} + logger = logging.getLogger(__name__) + + for config in score_config: + metric_name = config["name"] + try: + if metric_name == "fad": + from versa.corpus_metrics.fad import fad_setup + + cache_dir = config.get( + "cache_dir", + f"{cache_folder}/fad" if cache_folder else "versa_cache/fad", + ) + score_modules["fad"] = fad_setup( + baseline=None, + fad_embedding=config.get("fad_embedding", "default"), + cache_dir=cache_dir, + use_inf=config.get("use_inf", True), + io=config.get("io", io), + ) + elif metric_name == "kid": + from versa.corpus_metrics.kid import kid_setup + + cache_dir = config.get( + "cache_dir", + f"{cache_folder}/kid" if cache_folder else "versa_cache/kid", + ) + score_modules["kid"] = kid_setup( + baseline=None, + kid_embedding=config.get( + "kid_embedding", config.get("fad_embedding", "default") + ), + cache_dir=cache_dir, + use_inf=config.get("use_inf", True), + io=config.get("io", io), + ) + except Exception as e: + logger.error("Failed to load corpus metric %s: %s", metric_name, e) + + return score_modules + + +def corpus_scoring( + pred_x: str, + score_modules: Dict[str, Any], + baseline: Optional[str] = None, + output_file: Optional[str] = None, +) -> Dict[str, Any]: + """Legacy wrapper for scoring corpus-level metrics.""" + score_info = {} + + for metric_name, module_info in score_modules.items(): + if baseline is not None: + module_info["baseline"] = baseline + + if metric_name == "fad": + from versa.corpus_metrics.fad import fad_scoring + + score_info.update(fad_scoring(pred_x, module_info, key_info="fad")) + elif metric_name == "kid": + from versa.corpus_metrics.kid import kid_scoring + + score_info.update(kid_scoring(pred_x, module_info, key_info="kid")) + + if output_file: + with open(output_file, "w") as f: + yaml.dump(score_info, f) + + return score_info + + +class ScoreProcessor: + """Handles batch processing and caching of scores.""" + + def __init__(self, metric_suite: MetricSuite, output_file: Optional[str] = None): + self.metric_suite = metric_suite + self.output_file = output_file + self.logger = logging.getLogger(self.__class__.__name__) + + if output_file: + self.file_handle = open(output_file, "w", encoding="utf-8") + else: + self.file_handle = None + + def process_batch(self, cache_info: List[tuple]) -> List[Dict[str, Any]]: + """Process a batch of cached utterance information.""" + batch_score_info = [] + for utt_info in cache_info: + key, gen_wav, gt_wav, gen_sr, text = utt_info + utt_score = {"key": key} + + try: + # Prepare metadata for metric computation + metadata = { + "key": key, + "sample_rate": gen_sr, + "text": text, + "general_cache": {"whisper_hyp_text": None}, + } + + # Compute all metrics + scores = self.metric_suite.compute_all( + predictions=gen_wav, references=gt_wav, metadata=metadata + ) + + # Flatten the metric results + for metric_name, metric_results in scores.items(): + if isinstance(metric_results, dict): + utt_score.update(metric_results) + else: + utt_score[metric_name] = metric_results + + except Exception as e: + self.logger.error(f"Error processing file: {key} with error {e}") + + batch_score_info.append(utt_score) + + if self.file_handle: + printable_result = json.dumps( + utt_score, default=default_numpy_serializer + ) + self.file_handle.write(f"{printable_result}\n") + + return batch_score_info + + def close(self): + """Close file handle 
if open."""
+        if self.file_handle:
+            self.file_handle.close()
+
+
+class VersaScorer:
+    """Main scorer class that orchestrates the scoring process."""
+
+    def __init__(self, registry: Optional[MetricRegistry] = None):
+        self.registry = registry or self._create_default_registry()
+        self.factory = MetricFactory(self.registry)
+        self.logger = logging.getLogger(self.__class__.__name__)
+
+    def _create_default_registry(self) -> MetricRegistry:
+        """Create and populate the default metric registry."""
+        return _create_populated_registry()
+
+    def load_metrics(
+        self,
+        score_config: List[Dict[str, Any]],
+        use_gt: bool = True,
+        use_gt_text: bool = False,
+        use_gpu: bool = False,
+    ) -> MetricSuite:
+        """Load and configure metrics based on configuration."""
+        metrics = {}
+
+        for config in score_config:
+            metric_name = config["name"]
+
+            try:
+                # Check if metric requires ground truth
+                metadata = self.registry.get_metadata(metric_name)
+                if metadata and metadata.requires_reference and not use_gt:
+                    self.logger.warning(
+                        f"Cannot use {metric_name} because no ground truth is provided"
+                    )
+                    continue
+
+                if metadata and metadata.requires_text and not use_gt_text:
+                    self.logger.warning(
+                        f"Cannot use {metric_name} because no ground truth text is provided"
+                    )
+                    continue
+
+                # Create metric instance
+                metric_config = {**config, "use_gpu": use_gpu}
+                metric = self.factory.create_metric(metric_name, metric_config)
+                metrics[metric_name] = metric
+
+                self.logger.info(f"Loaded {metric_name} successfully")
+
+            except Exception as e:
+                self.logger.error(f"Failed to load metric {metric_name}: {e}")
+                continue
+
+        return MetricSuite(metrics)
+
+    def score_utterances(
+        self,
+        gen_files: Dict[str, str],
+        metric_suite: MetricSuite,
+        gt_files: Optional[Dict[str, str]] = None,
+        text_info: Optional[Dict[str, str]] = None,
+        output_file: Optional[str] = None,
+        io: str = "kaldi",
+        batch_size: int = 1,
+    ) -> List[Dict[str, Any]]:
+        """Score individual utterances."""
+
+        processor = ScoreProcessor(metric_suite, output_file)
+        score_info = []
+        cache_info = []
+
+        try:
+            for key in tqdm(gen_files.keys()):
+                # Step1: Load and validate generated audio; skip unreadable
+                # files instead of aborting the whole run (legacy behavior)
+                try:
+                    gen_sr, gen_wav = load_audio(gen_files[key], io)
+                    gen_wav = wav_normalize(gen_wav)
+                except Exception as e:
+                    self.logger.error(
+                        f"Error loading audio for key {key}: {e}, skipping"
+                    )
+                    continue
+
+                if not self._validate_audio(gen_wav, gen_sr, key, "generated"):
+                    continue
+
+                # Step2: Load and validate ground truth audio
+                gt_wav, gt_sr = None, None
+                if gt_files is not None:
+                    if key not in gt_files:
+                        self.logger.warning(
+                            f"Ground truth not found for key {key}, skipping"
+                        )
+                        continue
+
+                    gt_sr, gt_wav = load_audio(gt_files[key], io)
+                    gt_wav = wav_normalize(gt_wav)
+
+                    if not self._validate_audio(gt_wav, gt_sr, key, "ground truth"):
+                        continue
+
+                # Step3: Load text information
+                if text_info is not None and key not in text_info:
+                    self.logger.warning(f"Text not found for key {key}, skipping")
+                    continue
+                text = text_info.get(key) if text_info else None
+
+                # Step4: Resample if needed
+                gen_wav, gt_wav, gen_sr = self._align_sample_rates(
+                    gen_wav, gt_wav, gen_sr, gt_sr
+                )
+
+                # Step5: Cache for batch processing
+                utterance_info = (key, gen_wav, gt_wav, gen_sr, text)
+                cache_info.append(utterance_info)
+
+                if len(cache_info) >= batch_size:
+                    score_info.extend(processor.process_batch(cache_info))
+                    cache_info = []
+
+            # Process remaining items
+            if cache_info:
+                score_info.extend(processor.process_batch(cache_info))
+
+        finally:
+            processor.close()
+
+        self.logger.info(f"Scoring completed. 
Results saved to {output_file}") + return score_info + + def score_corpus( + self, + gen_files: Dict[str, str], + metric_suite: MetricSuite, + base_files: Optional[Dict[str, str]] = None, + text_info: Optional[Dict[str, str]] = None, + output_file: Optional[str] = None, + ) -> Dict[str, Any]: + """Score at corpus level (e.g., FAD, KID).""" + + score_info = {} + + # Filter for distributional metrics + distributional_metrics = metric_suite.filter_by_category( + MetricCategory.DISTRIBUTIONAL + ) + + for name, metric in distributional_metrics.metrics.items(): + try: + metadata = {"baseline_files": base_files, "text_info": text_info} + + score_result = metric.compute( + predictions=gen_files, references=base_files, metadata=metadata + ) + score_info.update({name: score_result}) + + except Exception as e: + self.logger.error(f"Error computing corpus metric {name}: {e}") + + if output_file: + with open(output_file, "w") as f: + yaml.dump(score_info, f) + + return score_info + + def _validate_audio(self, wav: Any, sr: int, key: str, audio_type: str) -> bool: + """Validate audio data.""" + # Length check + if not check_minimum_length( + wav.shape[0] / sr, [] + ): # Metric names would be passed here + self.logger.warning( + f"Audio {key} ({audio_type}, length {wav.shape[0] / sr}) is too short, skipping" + ) + return False + + # Check for silent audio + if check_all_same(wav): + self.logger.warning( + f"Audio {key} ({audio_type}) has only the same value, skipping" + ) + return False + + return True + + def _align_sample_rates( + self, gen_wav: Any, gt_wav: Any, gen_sr: int, gt_sr: Optional[int] + ) -> tuple: + """Align sample rates between generated and ground truth audio.""" + if gt_sr is None: + return gen_wav, gt_wav, gen_sr + + if gen_sr > gt_sr: + self.logger.warning("Resampling generated audio to match ground truth") + gen_wav = resample_audio(gen_wav, gen_sr, gt_sr) + gen_sr = gt_sr + elif gen_sr < gt_sr: + self.logger.warning( + "Resampling ground truth audio to match generated audio" + ) + gt_wav = resample_audio(gt_wav, gt_sr, gen_sr) + + return gen_wav, gt_wav, gen_sr + + +def compute_summary(score_info: List[Dict[str, Any]]) -> Dict[str, Any]: + """Compute summary statistics from individual scores.""" + if not score_info: + return {} + + summary = {} + for key in score_info[0].keys(): + if key not in NUM_METRIC: + continue + + values = [ + score[key] + for score in score_info + if key in score and score[key] is not None + ] + if not values: + continue + + summary[key] = sum(values) + if "_wer" not in key and "_cer" not in key: + summary[key] /= len(values) + + return summary diff --git a/versa/sequence_metrics/mcd_f0.py b/versa/sequence_metrics/mcd_f0.py index 4e45d21..c2f2fca 100644 --- a/versa/sequence_metrics/mcd_f0.py +++ b/versa/sequence_metrics/mcd_f0.py @@ -7,11 +7,33 @@ import logging import numpy as np -import pysptk -import pyworld as pw -import scipy -from fastdtw import fastdtw -from scipy.signal import firwin, lfilter + +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +try: + import pysptk + import pyworld as pw + import scipy + from fastdtw import fastdtw + from scipy.signal import firwin, lfilter +except ImportError: + pysptk = None + pw = None + scipy = None + fastdtw = None + firwin = None + lfilter = None + + +def _ensure_mcd_f0_dependencies(): + if any( + dependency is None + for dependency in (pysptk, pw, scipy, fastdtw, firwin, lfilter) + ): + raise ImportError( + "mcd_f0 requires pysptk, pyworld, scipy, and fastdtw. 
" + "Please install these dependencies and retry" + ) def low_cut_filter(x, fs, cutoff=70): @@ -26,6 +48,7 @@ def low_cut_filter(x, fs, cutoff=70): (ndarray): Low cut filtered waveform sequence """ + _ensure_mcd_f0_dependencies() nyquist = fs // 2 norm_cutoff = cutoff / nyquist @@ -131,6 +154,7 @@ def world_extract( mcep_alpha=0.466, filter_cutoff=70, ): + _ensure_mcd_f0_dependencies() # scale from [-1, 1] to [-32768, 32767] x = x * np.iinfo(np.int16).max @@ -175,6 +199,7 @@ def mcd_f0( power_threshold=-20, dtw=False, ): + _ensure_mcd_f0_dependencies() pred_feats = world_extract( pred_x, fs, f0min, f0max, mcep_shift, mcep_fftl, mcep_dim, mcep_alpha @@ -224,8 +249,9 @@ def mcd_f0( f0corr = scipy.stats.pearsonr(pred_f0_dtw, gt_f0_dtw)[0] except ValueError: logging.warning( - "No nonzero f0 is found. Skip f0rmse f0corr computation and set them to NaN. " - "This might due to unconverge training. Please tune the training time and hypers." + "No nonzero f0 is found. Skip f0rmse f0corr computation and " + "set them to NaN. This might due to unconverge training. " + "Please tune the training time and hypers." ) f0rmse = np.nan f0corr = np.nan @@ -235,10 +261,12 @@ def mcd_f0( pred_seq_len = len(pred_feats["f0"]) gt_seq_len = len(gt_feats["f0"]) min_len = min(pred_seq_len, gt_seq_len) - assert (pred_seq_len + gt_seq_len - 2 * min_len) / ( + mismatch_ratio = (pred_seq_len + gt_seq_len - 2 * min_len) / ( pred_seq_len + gt_seq_len - ) < seq_mismatch_tolerance, "two input sequence mismatch ratio over threshold {}".format( - seq_mismatch_tolerance + ) + assert mismatch_ratio < seq_mismatch_tolerance, ( + "two input sequence mismatch ratio over threshold " + f"{seq_mismatch_tolerance}" ) diff2sum = np.sum( (pred_feats["mcep"][:min_len] - gt_feats["mcep"][:min_len]) ** 2, 1 @@ -258,9 +286,78 @@ def mcd_f0( } +class McdF0Metric(BaseMetric): + """Mel cepstral distortion and F0 metrics.""" + + def _setup(self): + _ensure_mcd_f0_dependencies() + self.f0min = self.config.get("f0min", 40) + self.f0max = self.config.get("f0max", 800) + self.mcep_shift = self.config.get("mcep_shift", 5) + self.mcep_fftl = self.config.get("mcep_fftl", 1024) + self.mcep_dim = self.config.get("mcep_dim", 39) + self.mcep_alpha = self.config.get("mcep_alpha", 0.466) + self.seq_mismatch_tolerance = self.config.get("seq_mismatch_tolerance", 0.1) + self.power_threshold = self.config.get("power_threshold", -20) + self.dtw = self.config.get("dtw", False) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + if references is None: + raise ValueError("Reference signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return mcd_f0( + np.asarray(predictions), + np.asarray(references), + fs, + self.f0min, + self.f0max, + mcep_shift=self.mcep_shift, + mcep_fftl=self.mcep_fftl, + mcep_dim=self.mcep_dim, + mcep_alpha=self.mcep_alpha, + seq_mismatch_tolerance=self.seq_mismatch_tolerance, + power_threshold=self.power_threshold, + dtw=self.dtw, + ) + + def get_metadata(self): + return _mcd_f0_metadata() + + +def _mcd_f0_metadata(): + return MetricMetadata( + name="mcd_f0", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.DICT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["pysptk", "pyworld", "scipy", "fastdtw", "numpy"], + description="Mel cepstral distortion, F0 RMSE, and F0 correlation", + 
paper_reference="https://ieeexplore.ieee.org/document/407206", + implementation_source=( + "https://github.com/espnet/espnet and " + "https://github.com/unilight/s3prl-vc" + ), + ) + + +def register_mcd_f0_metric(registry): + """Register MCD/F0 metrics with the registry.""" + registry.register( + McdF0Metric, + _mcd_f0_metadata(), + aliases=["mcd", "mcd_f0_metric"], + ) + + # debug code if __name__ == "__main__": a = np.random.random(16000) b = np.random.random(16000) - print(a, b) - print("metrics: {}".format(mcd_f0(a, b, 16000, 1, 8000, dtw=True))) + metric = McdF0Metric({"dtw": True, "f0min": 1, "f0max": 8000}) + print("metrics: {}".format(metric.compute(a, b, metadata={"sample_rate": 16000}))) diff --git a/versa/sequence_metrics/signal_metric.py b/versa/sequence_metrics/signal_metric.py index 712efaf..7b586e3 100644 --- a/versa/sequence_metrics/signal_metric.py +++ b/versa/sequence_metrics/signal_metric.py @@ -10,6 +10,8 @@ import torch from mir_eval.separation import bss_eval_sources +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + def calculate_si_snr(pred_x, gt_x, zero_mean=None, clamp_db=None, pairwise=False): # TODO(jiatong): pass zero_mean and clamp_db setup to the function @@ -59,9 +61,50 @@ def signal_metric(pred_x, gt_x): } +class SignalMetric(BaseMetric): + """Reference-based signal distortion metrics.""" + + def _setup(self): + pass + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + if references is None: + raise ValueError("Reference signal must be provided") + return signal_metric(np.asarray(predictions), np.asarray(references)) + + def get_metadata(self): + return _signal_metadata() + + +def _signal_metadata(): + return MetricMetadata( + name="signal_metric", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.DICT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["ci_sdr", "fast_bss_eval", "mir_eval", "numpy", "torch"], + description="Reference-based SDR, SIR, SAR, SI-SNR, and CI-SDR metrics", + implementation_source="https://github.com/espnet/espnet", + ) + + +def register_signal_metric(registry): + """Register signal distortion metrics with the registry.""" + registry.register( + SignalMetric, + _signal_metadata(), + aliases=["signal", "snr_related"], + ) + + # debug code if __name__ == "__main__": a = np.random.random(16000) b = np.random.random(16000) - print(a, b) - print("metrics: {}".format(signal_metric(a, b))) + metric = SignalMetric() + print("metrics: {}".format(metric.compute(a, b))) diff --git a/versa/sequence_metrics/warpq.py b/versa/sequence_metrics/warpq.py index 53a5f75..e6dda79 100644 --- a/versa/sequence_metrics/warpq.py +++ b/versa/sequence_metrics/warpq.py @@ -5,11 +5,12 @@ import logging -logger = logging.getLogger(__name__) - -import librosa import numpy as np -import torch + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +logger = logging.getLogger(__name__) try: from WARPQ.WARPQmetric import warpqMetric @@ -39,7 +40,8 @@ def warpq_setup( } if warpqMetric is None: raise ImportError( - "Please install WARP-Q from /tools/install_warpq.sh, and retry after installation" + "Please install WARP-Q from /tools/install_warpq.sh, " + "and retry after installation" ) model = warpqMetric(args) logger.info("Mapping model is not loaded for current implementation.") @@ 
-56,15 +58,77 @@ def warpq(model, pred_x, gt_x, fs=8000): """ target_fs = model.args["sr"] if target_fs != fs: - gt_x = librosa.resample(gt_x, fs, target_fs) - pred_x = librosa.resample(pred_x, fs, target_fs) + gt_x = resample_audio(gt_x, fs, target_fs) + pred_x = resample_audio(pred_x, fs, target_fs) score = model.evaluate_versa(gt_x, pred_x) return {"warpq": score} +class WarpqMetric(BaseMetric): + """WARP-Q dynamic time warping cost metric.""" + + def _setup(self): + self.fs = self.config.get("fs", 8000) + self.n_mfcc = self.config.get("n_mfcc", 13) + self.fmax = self.config.get("fmax", 4000) + self.patch_size = self.config.get("patch_size", 0.5) + self.sigma = self.config.get("sigma", [[1, 0], [0, 3], [1, 3]]) + self.apply_vad = self.config.get("apply_vad", False) + self.model = warpq_setup( + fs=self.fs, + n_mfcc=self.n_mfcc, + fmax=self.fmax, + patch_size=self.patch_size, + sigma=self.sigma, + apply_vad=self.apply_vad, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + if references is None: + raise ValueError("Reference signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return warpq( + self.model, + np.asarray(predictions), + np.asarray(references), + fs=fs, + ) + + def get_metadata(self): + return _warpq_metadata() + + +def _warpq_metadata(): + return MetricMetadata( + name="warpq", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["WARPQ", "librosa", "numpy"], + description="WARP-Q dynamic time warping cost metric", + paper_reference="https://arxiv.org/abs/2102.10449", + implementation_source="https://github.com/wjassim/WARP-Q", + ) + + +def register_warpq_metric(registry): + """Register WARP-Q with the registry.""" + registry.register( + WarpqMetric, + _warpq_metadata(), + aliases=["warpq_metric", "warp_q"], + ) + + if __name__ == "__main__": - model = warpq_setup() test_audio = np.zeros(16000) ref_audio = np.zeros(16000) - print(warpq(model, ref_audio, test_audio, 8000)) + metric = WarpqMetric() + print(metric.compute(test_audio, ref_audio, metadata={"sample_rate": 8000})) diff --git a/versa/utterance_metrics/arecho.py b/versa/utterance_metrics/arecho.py index 7943da9..b1c4b73 100644 --- a/versa/utterance_metrics/arecho.py +++ b/versa/utterance_metrics/arecho.py @@ -8,6 +8,8 @@ import librosa import soundfile as sf +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + def arecho_model_setup(model_tag="default", use_gpu=False): """ @@ -144,6 +146,57 @@ def arecho_noref_metric(model, pred_x, fs): return arecho_metric(model, pred_x, fs) +class ArechoMetric(BaseMetric): + """ARECHO no-reference speech quality metric.""" + + def _setup(self): + self.model_tag = self.config.get("model_tag", "default") + self.use_gpu = self.config.get("use_gpu", False) + self.model = arecho_model_setup( + model_tag=self.model_tag, + use_gpu=self.use_gpu, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + metadata = metadata or {} + fs = metadata.get("sample_rate", 16000) + pred_x = np.asarray(predictions) + return arecho_noref_metric(self.model, pred_x, fs) + + def get_metadata(self): + return _arecho_metadata() + + +def _arecho_metadata(): + return MetricMetadata( + name="arecho", + 
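+        # NOTE: INDEPENDENT because ARECHO scores the prediction without a
+        # reference (requires_reference=False), and DICT because it returns
+        # a dictionary of grouped quality scores.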
category=MetricCategory.INDEPENDENT, + metric_type=MetricType.DICT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["espnet2", "torch", "librosa", "numpy"], + description=( + "ARECHO no-reference audio reference echo cancellation and codec " + "quality assessment" + ), + paper_reference="https://arxiv.org/abs/2505.20741", + implementation_source="https://huggingface.co/espnet/arecho_base_v0", + ) + + +def register_arecho_metric(registry): + """Register ARECHO with the metric registry.""" + registry.register( + ArechoMetric, + _arecho_metadata(), + aliases=["arecho_noref"], + ) + + if __name__ == "__main__": # Test the implementation print("Testing ARECHO metric implementation...") diff --git a/versa/utterance_metrics/asr_matching.py b/versa/utterance_metrics/asr_matching.py index d77797f..8635d9d 100644 --- a/versa/utterance_metrics/asr_matching.py +++ b/versa/utterance_metrics/asr_matching.py @@ -1,263 +1,249 @@ -#!/usr/bin/env python3 - -# Copyright 2024 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -import logging -from typing import Dict, Optional, Union, Any - -import librosa -import numpy as np -import torch -from Levenshtein import opcodes - -logger = logging.getLogger(__name__) - -# Handle optional whisper dependency -try: - import whisper - - WHISPER_AVAILABLE = True -except ImportError: - logger.warning( - "Whisper is not properly installed. " - "Please install following https://github.com/openai/whisper" - ) - whisper = None - WHISPER_AVAILABLE = False - -from espnet2.text.cleaner import TextCleaner - -# Constants -TARGET_FS = 16000 -CHUNK_SIZE = 30 # seconds - - -class WhisperNotAvailableError(RuntimeError): - """Exception raised when Whisper is required but not available.""" - - pass - - -def asr_match_setup( - model_tag: str = "default", - beam_size: int = 5, - text_cleaner: str = "whisper_basic", - use_gpu: bool = True, -) -> Dict[str, Any]: - """ - Set up ASR matching utilities. - - Args: - model_tag: Whisper model tag. Options include "tiny", "base", "small", - "medium", "large", or "large-v2". Defaults to "large". - beam_size: Beam size for decoding. - text_cleaner: Text cleaner type for post-processing. - use_gpu: Whether to use GPU for computation. - - Returns: - Dictionary containing the model, text cleaner, and beam size. - - Raises: - WhisperNotAvailableError: If Whisper is not installed but is required. - RuntimeError: If model loading fails. 
- """ - if not WHISPER_AVAILABLE: - raise WhisperNotAvailableError( - "Whisper WER is used for evaluation while openai-whisper is not installed" - ) - - # Use the large model by default - if model_tag == "default": - model_tag = "large" - - # Set device based on availability and user preference - device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu" - - try: - # Load the Whisper model - logger.info(f"Loading Whisper model '{model_tag}' on {device}") - model = whisper.load_model(model_tag, device=device) - - # Initialize text cleaner - textcleaner = TextCleaner(text_cleaner) - - # Return utilities dictionary - return {"model": model, "cleaner": textcleaner, "beam_size": beam_size} - except Exception as e: - raise RuntimeError(f"Failed to initialize Whisper model: {str(e)}") from e - - -def asr_match_metric( - wer_utils: Dict[str, Any], - pred_x: np.ndarray, - gt_x: np.ndarray, - cache_pred_text: Optional[str] = None, - fs: int = 16000, -) -> Dict[str, Union[float, str]]: - """ - Calculate the ASR match error rate and related metrics. - - This function compares the ASR transcription of the predicted audio - with the transcription of the ground truth audio to compute character-level - edit distance metrics. - - Args: - wer_utils: A utility dict for WER calculation including: - - model: whisper model - - cleaner: text cleaner - - beam_size: beam size for decoding - pred_x: Predicted/test signal as a numpy array (time,) - gt_x: Ground truth signal as a numpy array (time,) - cache_pred_text: Optional pre-computed transcription for pred_x - fs: Sampling rate of the input audio in Hz - - Returns: - Dictionary containing: - - asr_match_error_rate: The character error rate - - whisper_hyp_text: The transcription of the predicted audio - - Raises: - ValueError: If input data is invalid - RuntimeError: If transcription fails - """ - # Validate inputs - if pred_x is None or gt_x is None: - raise ValueError("Both predicted and ground truth signals must be provided") - - # Make sure inputs are numpy arrays - pred_x = np.asarray(pred_x) - gt_x = np.asarray(gt_x) - - # Process the speech to be evaluated - if cache_pred_text is not None: - inf_text = cache_pred_text - else: - try: - # Resample if necessary - if fs != TARGET_FS: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS) - - # Convert to tensor and transcribe - with torch.no_grad(): - transcription = wer_utils["model"].transcribe( - torch.tensor(pred_x).float(), beam_size=wer_utils["beam_size"] - ) - inf_text = transcription["text"] - except Exception as e: - raise RuntimeError( - f"Failed to transcribe predicted signal: {str(e)}" - ) from e - - # Process the ground truth speech - try: - # Resample if necessary - if fs != TARGET_FS: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=TARGET_FS) - - # Convert to tensor and transcribe - with torch.no_grad(): - transcription = wer_utils["model"].transcribe( - torch.tensor(gt_x).float(), beam_size=wer_utils["beam_size"] - ) - gt_text = transcription["text"] - except Exception as e: - raise RuntimeError(f"Failed to transcribe ground truth signal: {str(e)}") from e - - # Clean the text using the provided cleaner - ref_text = wer_utils["cleaner"](gt_text) - pred_text = wer_utils["cleaner"](inf_text) - - # Convert texts to character lists for edit distance calculation - ref_chars = list(ref_text) - pred_chars = list(pred_text) - - # Initialize result dictionary with operation counts - result = { - "asr_match_delete": 0, # Deletions: chars in reference but not in prediction - 
"asr_match_insert": 0, # Insertions: chars in prediction but not in reference - "asr_match_replace": 0, # Substitutions: chars that differ between ref and pred - "asr_match_equal": 0, # Matches: chars that are the same in ref and pred - } - - # Calculate edit operations using Levenshtein - for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_chars, pred_chars): - if op == "insert": - result["asr_match_" + op] += inf_et - inf_st - else: - result["asr_match_" + op] += ref_et - ref_st - - # Validate operation counts - total_ref = ( - result["asr_match_delete"] - + result["asr_match_replace"] - + result["asr_match_equal"] - ) - if total_ref != len(ref_chars): - logger.warning( - f"Reference operation count mismatch: {total_ref} vs {len(ref_chars)}" - ) - - total_pred = ( - result["asr_match_insert"] - + result["asr_match_replace"] - + result["asr_match_equal"] - ) - if total_pred != len(pred_chars): - logger.warning( - f"Prediction operation count mismatch: {total_pred} vs {len(pred_chars)}" - ) - - # Calculate error rate - if len(ref_chars) == 0: - # Handle empty reference case - asr_match_error_rate = 1.0 - logger.warning("Reference text is empty, setting error rate to 1.0") - else: - # Calculate character error rate - asr_match_error_rate = ( - result["asr_match_delete"] - + result["asr_match_insert"] - + result["asr_match_replace"] - ) / len(ref_chars) - - # Return results - return { - "asr_match_error_rate": asr_match_error_rate, - "whisper_hyp_text": inf_text, - # Additional metrics that might be useful - "ref_text_length": len(ref_chars), - "pred_text_length": len(pred_chars), - "match_details": result, - } - - -def is_whisper_available(): - """ - Check if the Whisper package is available. - - Returns: - bool: True if Whisper is available, False otherwise. - """ - return WHISPER_AVAILABLE - - -if __name__ == "__main__": - # Example usage - try: - # Generate random test audio (1 second at 16kHz) - test_audio = np.random.random(TARGET_FS) - - # Set up ASR matching utilities - wer_utils = asr_match_setup(model_tag="tiny", use_gpu=torch.cuda.is_available()) - - # Calculate metrics - metrics = asr_match_metric(wer_utils, test_audio, test_audio, None, TARGET_FS) - - # Print results - print(f"ASR Match Error Rate: {metrics['asr_match_error_rate']:.4f}") - print(f"Transcription: '{metrics['whisper_hyp_text']}'") - except WhisperNotAvailableError: - print("This script requires the Whisper package. Please install it first.") - except Exception as e: - print(f"Error running ASR match: {str(e)}") +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import logging +from typing import Dict, Optional, Union, Any + +import numpy as np +import torch +from Levenshtein import opcodes + +logger = logging.getLogger(__name__) + +# Handle optional whisper dependency +try: + import whisper + + WHISPER_AVAILABLE = True +except ImportError: + logger.warning( + "Whisper is not properly installed. 
" + "Please install following https://github.com/openai/whisper" + ) + whisper = None + WHISPER_AVAILABLE = False + +try: + from espnet2.text.cleaner import TextCleaner +except ImportError: + TextCleaner = None +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + +# Constants +TARGET_FS = 16000 +CHUNK_SIZE = 30 # seconds + + +class WhisperNotAvailableError(RuntimeError): + """Exception raised when Whisper is required but not available.""" + + pass + + +def is_whisper_available(): + """ + Check if the Whisper package is available. + + Returns: + bool: True if Whisper is available, False otherwise. + """ + return WHISPER_AVAILABLE + + +class ASRMatchMetric(BaseMetric): + """ASR-oriented Mismatch Error Rate (ASR-Match) metric using Whisper.""" + + def _setup(self): + self.model_tag = self.config.get("model_tag", "default") + self.beam_size = self.config.get("beam_size", 5) + self.text_cleaner = self.config.get("text_cleaner", "whisper_basic") + self.use_gpu = self.config.get("use_gpu", True) + self.cache_dir = self.config.get("cache_dir", "versa_cache/whisper") + self.wer_utils = asr_match_setup( + self.model_tag, + self.beam_size, + self.text_cleaner, + use_gpu=self.use_gpu, + cache_dir=self.cache_dir, + ) + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + pred_x = predictions + gt_x = references + fs = 16000 + cache_pred_text = None + if metadata is not None: + fs = metadata.get("sample_rate", 16000) + cache_pred_text = metadata.get("cache_pred_text", None) + # Validate inputs + if pred_x is None or gt_x is None: + raise ValueError("Both predicted and ground truth signals must be provided") + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + return asr_match_metric(self.wer_utils, pred_x, gt_x, cache_pred_text, fs) + + def get_metadata(self) -> MetricMetadata: + return MetricMetadata( + name="asr_match", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["whisper", "espnet2", "Levenshtein", "librosa", "torch"], + description="ASR-oriented Mismatch Error Rate (ASR-Match) using Whisper for reference-based speech evaluation.", + paper_reference=None, + implementation_source="https://github.com/ftshijt/versa", + ) + + +def register_asr_match_metric(registry): + """Register ASR-Match metric with the registry.""" + metric_metadata = MetricMetadata( + name="asr_match", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["whisper", "espnet2", "Levenshtein", "librosa", "torch"], + description="ASR-oriented Mismatch Error Rate (ASR-Match) using Whisper for reference-based speech evaluation.", + paper_reference=None, + implementation_source="https://github.com/ftshijt/versa", + ) + registry.register( + ASRMatchMetric, metric_metadata, aliases=["ASRMatch", "asr_match_error_rate"] + ) + + +def asr_match_setup( + model_tag="default", + beam_size=5, + text_cleaner="whisper_basic", + use_gpu=True, + cache_dir="versa_cache/whisper", +): + """Legacy function API for setting up ASR-Match.""" + if not WHISPER_AVAILABLE: + raise ImportError( + "Whisper is not properly installed. 
Please install following https://github.com/openai/whisper" + ) + if TextCleaner is None: + raise ImportError("asr_match requires espnet TextCleaner. Install espnet") + if model_tag == "default": + model_tag = "large" + device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu" + try: + model = whisper.load_model(model_tag, device=device, download_root=cache_dir) + cleaner = TextCleaner(text_cleaner) + except Exception as e: + raise RuntimeError(f"Failed to initialize Whisper model: {str(e)}") from e + return { + "model": model, + "cleaner": cleaner, + "beam_size": beam_size, + "device": device, + } + + +def asr_match_metric(wer_utils, pred_x, gt_x, cache_pred_text=None, fs=16000): + """Legacy function API for computing ASR-Match.""" + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + model = wer_utils["model"] + cleaner = wer_utils["cleaner"] + beam_size = wer_utils.get("beam_size", 5) + + if cache_pred_text is not None: + inf_text = cache_pred_text + else: + try: + if fs != TARGET_FS: + pred_x = resample_audio(pred_x, fs, TARGET_FS) + with torch.no_grad(): + transcription = model.transcribe( + torch.tensor(pred_x).float(), beam_size=beam_size + ) + inf_text = transcription["text"] + except Exception as e: + raise RuntimeError( + f"Failed to transcribe predicted signal: {str(e)}" + ) from e + + try: + if fs != TARGET_FS: + gt_x = resample_audio(gt_x, fs, TARGET_FS) + with torch.no_grad(): + transcription = model.transcribe( + torch.tensor(gt_x).float(), beam_size=beam_size + ) + gt_text = transcription["text"] + except Exception as e: + raise RuntimeError(f"Failed to transcribe ground truth signal: {str(e)}") from e + + ref_text = cleaner(gt_text) + pred_text = cleaner(inf_text) + ref_chars = list(ref_text) + pred_chars = list(pred_text) + result = { + "asr_match_delete": 0, + "asr_match_insert": 0, + "asr_match_replace": 0, + "asr_match_equal": 0, + } + for op, ref_st, ref_et, inf_st, inf_et in opcodes(ref_chars, pred_chars): + if op == "insert": + result["asr_match_" + op] += inf_et - inf_st + else: + result["asr_match_" + op] += ref_et - ref_st + + if len(ref_chars) == 0: + asr_match_error_rate = 1.0 + logger.warning("Reference text is empty, setting error rate to 1.0") + else: + asr_match_error_rate = ( + result["asr_match_delete"] + + result["asr_match_insert"] + + result["asr_match_replace"] + ) / len(ref_chars) + + return { + "asr_match_error_rate": asr_match_error_rate, + "whisper_hyp_text": inf_text, + "ref_text_length": len(ref_chars), + "pred_text_length": len(pred_chars), + "match_details": result, + } + + +if __name__ == "__main__": + # Example usage for the class-based metric + try: + # Generate random test audio (1 second at 16kHz) + test_audio = np.random.random(TARGET_FS) + # Set up ASR matching metric + config = { + "model_tag": "tiny", + "beam_size": 1, + "text_cleaner": "whisper_basic", + "use_gpu": torch.cuda.is_available(), + } + metric = ASRMatchMetric(config) + # Calculate metrics + metrics = metric.compute( + test_audio, test_audio, metadata={"sample_rate": TARGET_FS} + ) + # Print results + print(f"ASR Match Error Rate: {metrics['asr_match_error_rate']:.4f}") + print(f"Transcription: '{metrics['whisper_hyp_text']}'") + except WhisperNotAvailableError: + print("This script requires the Whisper package. 
Please install it first.") + except Exception as e: + print(f"Error running ASR match: {str(e)}") diff --git a/versa/utterance_metrics/asvspoof_score.py b/versa/utterance_metrics/asvspoof_score.py index 473c184..5ab4dae 100644 --- a/versa/utterance_metrics/asvspoof_score.py +++ b/versa/utterance_metrics/asvspoof_score.py @@ -13,79 +13,183 @@ """ import json +import logging import os import sys +from pathlib import Path +from typing import Dict, Any, Optional, Union -import librosa import numpy as np import torch -sys.path.append("./tools/checkpoints/aasist") -from models.AASIST import Model as AASIST # noqa: E402 +logger = logging.getLogger(__name__) +# Handle optional AASIST dependency +try: + aasist_path = ( + Path(__file__).resolve().parents[2] / "tools" / "checkpoints" / "aasist" + ) + sys.path.append(str(aasist_path)) + from models.AASIST import Model as AASIST # noqa: E402 -def deepfake_detection_model_setup( - model_tag="default", model_path=None, model_config=None, use_gpu=False -): - """Setup deepfake detection model. + AASIST_AVAILABLE = True +except ImportError: + logger.warning( + "AASIST is not properly installed. " + "Please install following https://github.com/clovaai/aasist" + ) + AASIST = None + AASIST_AVAILABLE = False - Args: - model_tag (str): Model tag. Defaults to "default". - model_path (str, optional): Path to model weights. Defaults to None. - model_config (str, optional): Path to model config. Defaults to None. - use_gpu (bool, optional): Whether to use GPU. Defaults to False. +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType - Returns: - AASIST: The loaded model. - """ - device = "cuda" if use_gpu else "cpu" - - if model_path is not None and model_config is not None: - with open(model_config, "r") as f_json: - config = json.loads(f_json.read()) - model = AASIST(config["model_config"]).to(device) - model.load_state_dict(torch.load(model_path, map_location=device)) - else: - if model_tag == "default": - model_root = "./tools/checkpoints/aasist" - model_config = os.path.join(model_root, "config/AASIST.conf") - model_path = os.path.join(model_root, "models/weights/AASIST.pth") - - with open(model_config, "r") as f_json: - config = json.loads(f_json.read()) - model = AASIST(config["model_config"]).to(device) - model.load_state_dict(torch.load(model_path, map_location=device)) - else: - raise NotImplementedError - model.device = device - return model +class AASISTNotAvailableError(RuntimeError): + """Exception raised when AASIST is required but not available.""" -def asvspoof_metric(model, pred_x, fs): - """Calculate ASVspoof score for audio. + pass - Args: - model (AASIST): The loaded deepfake detection model. - pred_x (np.ndarray): Audio signal. - fs (int): Sampling rate. + +def is_aasist_available(): + """ + Check if the AASIST package is available. Returns: - dict: Dictionary containing the ASVspoof score. + bool: True if AASIST is available, False otherwise. 
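+
+    Example (a minimal guard for optional-dependency handling):
+
+        >>> if not is_aasist_available():
+        ...     print("AASIST missing; asvspoof scoring unavailable")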
""" - # NOTE(jiatong): only work for 16000 Hz - if fs != 16000: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) + return AASIST_AVAILABLE + + +class ASVSpoofMetric(BaseMetric): + """ASVspoof deepfake detection metric using AASIST model.""" - pred_x = torch.from_numpy(pred_x).unsqueeze(0).float().to(model.device) - model.eval() - with torch.no_grad(): - _, output = model(pred_x) - output = torch.softmax(output, dim=1) - output = output.squeeze(0).cpu().numpy() - return {"asvspoof_score": output[1]} + def _setup(self): + """Initialize ASVspoof-specific components.""" + if not AASIST_AVAILABLE: + raise ImportError( + "AASIST is not properly installed. Please install following https://github.com/clovaai/aasist" + ) + + self.model_tag = self.config.get("model_tag", "default") + self.model_path = self.config.get("model_path", None) + self.model_config = self.config.get("model_config", None) + self.use_gpu = self.config.get("use_gpu", False) + + self.device = "cuda" if self.use_gpu and torch.cuda.is_available() else "cpu" + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize AASIST model: {str(e)}") from e + + def _setup_model(self): + """Setup the AASIST model.""" + if self.model_path is not None and self.model_config is not None: + with open(self.model_config, "r") as f_json: + config = json.loads(f_json.read()) + model = AASIST(config["model_config"]).to(self.device) + model.load_state_dict( + torch.load(self.model_path, map_location=self.device) + ) + else: + if self.model_tag == "default": + model_root = "./tools/checkpoints/aasist" + model_config = os.path.join(model_root, "config/AASIST.conf") + model_path = os.path.join(model_root, "models/weights/AASIST.pth") + + with open(model_config, "r") as f_json: + config = json.loads(f_json.read()) + model = AASIST(config["model_config"]).to(self.device) + model.load_state_dict( + torch.load(model_path, map_location=self.device) + ) + else: + raise NotImplementedError( + f"Model tag '{self.model_tag}' not implemented" + ) + + model.device = self.device + return model + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate ASVspoof score for audio. + + Args: + predictions: Audio signal to evaluate. + references: Not used for this metric. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the ASVspoof score. 
+ """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate input + if pred_x is None: + raise ValueError("Predicted signal must be provided") + + pred_x = np.asarray(pred_x) + + # NOTE(jiatong): only work for 16000 Hz + if fs != 16000: + pred_x = resample_audio(pred_x, fs, 16000) + + pred_x = torch.from_numpy(pred_x).unsqueeze(0).float().to(self.device) + self.model.eval() + with torch.no_grad(): + _, output = self.model(pred_x) + output = torch.softmax(output, dim=1) + output = output.squeeze(0).cpu().numpy() + + return {"asvspoof_score": output[1]} + + def get_metadata(self) -> MetricMetadata: + """Return ASVspoof metric metadata.""" + return MetricMetadata( + name="asvspoof", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy"], + description="ASVspoof deepfake detection score using AASIST model for speech authenticity assessment", + paper_reference="https://github.com/clovaai/aasist", + implementation_source="https://github.com/clovaai/aasist", + ) + + +def register_asvspoof_metric(registry): + """Register ASVspoof metric with the registry.""" + metric_metadata = MetricMetadata( + name="asvspoof", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy"], + description="ASVspoof deepfake detection score using AASIST model for speech authenticity assessment", + paper_reference="https://github.com/clovaai/aasist", + implementation_source="https://github.com/clovaai/aasist", + ) + registry.register( + ASVSpoofMetric, metric_metadata, aliases=["ASVSpoof", "asvspoof_score"] + ) if __name__ == "__main__": a = np.random.random(16000) - model = deepfake_detection_model_setup(use_gpu=False) - print(f"metrics: {asvspoof_metric(model, a, 16000)}") + + # Test the new class-based metric + config = {"use_gpu": False} + metric = ASVSpoofMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/audiobox_aesthetics_score.py b/versa/utterance_metrics/audiobox_aesthetics_score.py index 71a05cb..af2d833 100644 --- a/versa/utterance_metrics/audiobox_aesthetics_score.py +++ b/versa/utterance_metrics/audiobox_aesthetics_score.py @@ -6,87 +6,184 @@ """Module for evaluating audio using AudioBox Aesthetics models.""" import json +import logging import os +from typing import Dict, Any, Optional, Union import numpy as np +logger = logging.getLogger(__name__) + +# Handle optional audiobox_aesthetics dependency try: import audiobox_aesthetics.infer import audiobox_aesthetics.utils + + AUDIOBOX_AESTHETICS_AVAILABLE = True except ImportError: + logger.warning( + "audiobox_aesthetics is not properly installed. " + "Please install with tools/install_audiobox-aesthetics.sh first." + ) audiobox_aesthetics = None + AUDIOBOX_AESTHETICS_AVAILABLE = False +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType -def audiobox_aesthetics_setup( - model_path=None, - batch_size=1, - precision="bf16", - cache_dir="versa_cache/audiobox", - use_huggingface=True, - use_gpu=False, -): - """Set up the AudioBox Aesthetics model for inference. - - Args: - model_path (str, optional): Path to model weights. Defaults to None. 
- batch_size (int, optional): Batch size for inference. Defaults to 1. - precision (str, optional): Precision for inference. Defaults to "bf16". - cache_dir (str, optional): Directory to cache model. Defaults to "versa_cache/audiobox". - use_huggingface (bool, optional): Whether to use Hugging Face. Defaults to True. - use_gpu (bool, optional): Whether to use GPU. Defaults to False. - Returns: - AesWavlmPredictorMultiOutput: The loaded model. +class AudioBoxAestheticsNotAvailableError(RuntimeError): + """Exception raised when AudioBox Aesthetics is required but not available.""" + + pass + - Raises: - ImportError: If audiobox_aesthetics is not installed. +def is_audiobox_aesthetics_available(): """ - if audiobox_aesthetics is None: - raise ImportError( - "Please install with tools/install_audiobox-aesthetics.sh first." - ) + Check if the AudioBox Aesthetics package is available. - device = "cuda" if use_gpu else "cpu" + Returns: + bool: True if AudioBox Aesthetics is available, False otherwise. + """ + return AUDIOBOX_AESTHETICS_AVAILABLE - if model_path is None: - if use_huggingface: - model_path = audiobox_aesthetics.utils.load_model(model_path) - else: - os.makedirs(cache_dir, exist_ok=True) - model_path = os.path.join( - cache_dir, audiobox_aesthetics.utils.DEFAULT_CKPT_FNAME + +class AudioBoxAestheticsMetric(BaseMetric): + """AudioBox Aesthetics metric for audio quality assessment.""" + + def _setup(self): + """Initialize AudioBox Aesthetics-specific components.""" + if not AUDIOBOX_AESTHETICS_AVAILABLE: + raise ImportError( + "audiobox_aesthetics is not properly installed. " + "Please install with tools/install_audiobox-aesthetics.sh first." ) - model_url = audiobox_aesthetics.utils.DEFAULT_S3_URL - if not os.path.exists(model_path): - print(f"Downloading model from {model_url} to {model_path}") - audiobox_aesthetics.utils.download_file(model_url, model_path) - - predictor = audiobox_aesthetics.infer.AesWavlmPredictorMultiOutput( - checkpoint_pth=model_path, - device=device, - batch_size=batch_size, - precision=precision, - ) - return predictor + self.model_path = self.config.get("model_path", None) + self.batch_size = self.config.get("batch_size", 1) + self.precision = self.config.get("precision", "bf16") + self.cache_dir = self.config.get("cache_dir", "versa_cache/audiobox") + self.use_huggingface = self.config.get("use_huggingface", True) + self.use_gpu = self.config.get("use_gpu", False) + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError( + f"Failed to initialize AudioBox Aesthetics model: {str(e)}" + ) from e + + def _setup_model(self): + """Setup the AudioBox Aesthetics model.""" + device = "cuda" if self.use_gpu else "cpu" + + if self.model_path is None: + if self.use_huggingface: + try: + import huggingface_hub + except ImportError as e: + raise ImportError( + "Please install huggingface_hub or set use_huggingface=False " + "to download the AudioBox Aesthetics checkpoint directly." 
+ ) from e + + os.makedirs(self.cache_dir, exist_ok=True) + model_path = huggingface_hub.hf_hub_download( + audiobox_aesthetics.utils.DEFAULT_HF_REPO, + audiobox_aesthetics.utils.DEFAULT_CKPT_FNAME, + cache_dir=self.cache_dir, + ) + logger.info("Load AudioBox Aesthetics checkpoint from %s", model_path) + else: + os.makedirs(self.cache_dir, exist_ok=True) + model_path = os.path.join( + self.cache_dir, audiobox_aesthetics.utils.DEFAULT_CKPT_FNAME + ) + model_url = audiobox_aesthetics.utils.DEFAULT_S3_URL + if not os.path.exists(model_path): + print(f"Downloading model from {model_url} to {model_path}") + audiobox_aesthetics.utils.download_file(model_url, model_path) + else: + model_path = self.model_path -def audiobox_aesthetics_score(model, pred_x, fs): - """Calculate AudioBox Aesthetics scores for audio. + predictor = audiobox_aesthetics.infer.AesWavlmPredictorMultiOutput( + checkpoint_pth=model_path, + device=device, + batch_size=self.batch_size, + precision=self.precision, + ) + return predictor + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate AudioBox Aesthetics scores for audio. + + Args: + predictions: Audio signal to evaluate. + references: Not used for this metric. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the AudioBox Aesthetics scores. + """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate input + if pred_x is None: + raise ValueError("Predicted signal must be provided") + + pred_x = np.asarray(pred_x) + + output = json.loads(self.model.forward_versa([(pred_x, fs)])[0]) + output = {"audiobox_aesthetics_" + k: v for k, v in output.items()} + return output + + def get_metadata(self) -> MetricMetadata: + """Return AudioBox Aesthetics metric metadata.""" + return MetricMetadata( + name="audiobox_aesthetics", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["audiobox_aesthetics", "numpy"], + description="AudioBox Aesthetics scores for audio quality assessment using WavLM-based models", + paper_reference="https://github.com/facebookresearch/audiobox-aesthetics", + implementation_source="https://github.com/facebookresearch/audiobox-aesthetics", + ) - Args: - model (AesWavlmPredictorMultiOutput): The loaded model. - pred_x (np.ndarray): Audio signal. - fs (int): Sampling rate. - Returns: - dict: Dictionary containing the AudioBox Aesthetics scores. 
- """ - output = json.loads(model.forward_versa([(pred_x, fs)])[0]) - output = {"audiobox_aesthetics_" + k: v for k, v in output.items()} - return output +def register_audiobox_aesthetics_metric(registry): + """Register AudioBox Aesthetics metric with the registry.""" + metric_metadata = MetricMetadata( + name="audiobox_aesthetics", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["audiobox_aesthetics", "numpy"], + description="AudioBox Aesthetics scores for audio quality assessment using WavLM-based models", + paper_reference="https://github.com/facebookresearch/audiobox-aesthetics", + implementation_source="https://github.com/facebookresearch/audiobox-aesthetics", + ) + registry.register( + AudioBoxAestheticsMetric, + metric_metadata, + aliases=["AudioBoxAesthetics", "audiobox_aesthetics"], + ) if __name__ == "__main__": a = np.random.random(16000) - model = audiobox_aesthetics_setup() - print(f"metrics: {audiobox_aesthetics_score(model, a, 16000)}") + + # Test the new class-based metric + config = {"use_gpu": False} + metric = AudioBoxAestheticsMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/cdpam_distance.py b/versa/utterance_metrics/cdpam_distance.py index 772e23e..2f11638 100644 --- a/versa/utterance_metrics/cdpam_distance.py +++ b/versa/utterance_metrics/cdpam_distance.py @@ -1,33 +1,166 @@ -import torch -import librosa -import numpy as np +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Module for CDPAM distance metrics.""" + +import logging from functools import partial -import cdpam +from typing import Dict, Any, Optional, Union + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + +# Handle optional cdpam dependency +try: + import cdpam + + CDPAM_AVAILABLE = True +except ImportError: + logger.warning("cdpam is not properly installed. " "Please install cdpam and retry") + cdpam = None + CDPAM_AVAILABLE = False + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + + +class CdpamNotAvailableError(RuntimeError): + """Exception raised when cdpam is required but not available.""" + + pass + + +def is_cdpam_available(): + """ + Check if the cdpam package is available. -TARGET_FS = 22050 + Returns: + bool: True if cdpam is available, False otherwise. 
+ """ + return CDPAM_AVAILABLE -def cdpam_model_setup(use_gpu=False): - device = "cpu" if not use_gpu else "cuda" - _original_torch_load = torch.load - torch.load = partial(torch.load, weights_only=False) - model = cdpam.CDPAM(dev=device) - torch.load = _original_torch_load - return model +class CdpamDistanceMetric(BaseMetric): + """CDPAM distance metric.""" + TARGET_FS = 22050 -def cdpam_metric(model, pred_x, gt_x, fs): - if fs != TARGET_FS: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS) - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=TARGET_FS) - pred_x = (torch.from_numpy(pred_x).unsqueeze(0) * 32768).round() - gt_x = (torch.from_numpy(gt_x).unsqueeze(0) * 32768).round() - dist = model.forward(gt_x, pred_x) - return {"cdpam_distance": dist.detach().cpu().numpy().item()} + def _setup(self): + """Initialize CDPAM-specific components.""" + if not CDPAM_AVAILABLE: + raise ImportError( + "cdpam is not properly installed. " "Please install cdpam and retry" + ) + + self.use_gpu = self.config.get("use_gpu", False) + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize CDPAM model: {str(e)}") from e + + def _setup_model(self): + """Setup the CDPAM model.""" + device = "cpu" if not self.use_gpu else "cuda" + # Suppress PyTorch config registration warnings during model loading + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="Skipping config registration for" + ) + _original_torch_load = torch.load + torch.load = partial(torch.load, weights_only=False) + model = cdpam.CDPAM(dev=device) + torch.load = _original_torch_load + return model + + def compute( + self, predictions: Any, references: Any, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate CDPAM distance between two audio samples. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the CDPAM distance score. 
+ """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 22050) if metadata else 22050 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + if fs != self.TARGET_FS: + pred_x = resample_audio(pred_x, fs, self.TARGET_FS) + gt_x = resample_audio(gt_x, fs, self.TARGET_FS) + + pred_x = (torch.from_numpy(pred_x).unsqueeze(0) * 32768).round() + gt_x = (torch.from_numpy(gt_x).unsqueeze(0) * 32768).round() + dist = self.model.forward(gt_x, pred_x) + + return {"cdpam_distance": dist.detach().cpu().numpy().item()} + + def get_metadata(self) -> MetricMetadata: + """Return CDPAM distance metric metadata.""" + return MetricMetadata( + name="cdpam_distance", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["cdpam", "torch", "librosa", "numpy"], + description="CDPAM distance between audio samples", + paper_reference="https://github.com/facebookresearch/audiocraft", + implementation_source="https://github.com/facebookresearch/audiocraft", + ) + + +def register_cdpam_distance_metric(registry): + """Register CDPAM distance metric with the registry.""" + metric_metadata = MetricMetadata( + name="cdpam_distance", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["cdpam", "torch", "librosa", "numpy"], + description="CDPAM distance between audio samples", + paper_reference="https://github.com/facebookresearch/audiocraft", + implementation_source="https://github.com/facebookresearch/audiocraft", + ) + registry.register( + CdpamDistanceMetric, + metric_metadata, + aliases=["CdpamDistance", "cdpam_distance", "cdpam"], + ) if __name__ == "__main__": a = np.random.random(22050) b = np.random.random(22050) - model = cdpam_model_setup() - print("metrics: {}".format(cdpam_metric(model, a, b, 22050))) + + # Test the new class-based metric + config = {"use_gpu": False} + metric = CdpamDistanceMetric(config) + metadata = {"sample_rate": 22050} + score = metric.compute(a, b, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/chroma_alignment.py b/versa/utterance_metrics/chroma_alignment.py index 54f030d..4fb8b3f 100644 --- a/versa/utterance_metrics/chroma_alignment.py +++ b/versa/utterance_metrics/chroma_alignment.py @@ -4,10 +4,16 @@ # Chroma-based distance estimation with dynamic programming alignment # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +import logging +from typing import Dict, Any, Optional, Union, Tuple, List + import librosa import numpy as np from scipy.spatial.distance import cosine, euclidean -from typing import Tuple, Dict, Optional + +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + +logger = logging.getLogger(__name__) def calculate_chroma_features(audio, sr=22050, feature_type="stft", **kwargs): @@ -161,127 +167,166 @@ def calculate_chroma_distance( return dtw_dist, alignment_path -def chroma_metric(pred_x, gt_x, sr=22050, return_alignment=False, scale_factor=100.0): - """ - Calculate multiple chroma-based distance metrics. 
+class ChromaAlignmentMetric(BaseMetric):
+    """Chroma-based distance estimation with dynamic programming alignment."""
-    Args:
-        pred_x: Predicted audio signal (1D numpy array)
-        gt_x: Ground truth audio signal (1D numpy array)
-        sr: Sample rate
-        return_alignment: Whether to return alignment paths
-        scale_factor: Multiplicative scaling factor for distances
-
-    Returns:
-        Dictionary of chroma distance metrics
-    """
-    # Ensure 1D arrays
-    if pred_x.ndim > 1:
-        pred_x = pred_x.flatten()
-    if gt_x.ndim > 1:
-        gt_x = gt_x.flatten()
-
-    results = {}
-    alignments = {} if return_alignment else None
-
-    # Different chroma feature types
-    feature_types = [
-        "stft",
-        "cqt",
-        "cens",
-    ]  # 'vqt' might not be available in all librosa versions
-    distance_metrics = ["cosine", "euclidean"]
-
-    for feat_type in feature_types:
-        for dist_metric in distance_metrics:
-            try:
-                dtw_dist, alignment = calculate_chroma_distance(
-                    pred_x,
-                    gt_x,
-                    sr=sr,
-                    feature_type=feat_type,
-                    distance_metric=dist_metric,
-                    scale_factor=scale_factor,
-                )
-
-                metric_name = f"chroma_{feat_type}_{dist_metric}_dtw"
-                results[metric_name] = dtw_dist
-
-                if return_alignment:
-                    alignments[metric_name] = alignment
-
-            except Exception as e:
-                print(
-                    f"Warning: Could not calculate {feat_type} with {dist_metric}: {e}"
-                )
-                continue
-
-    # Add additional scaled variants
-    try:
-        # Raw DTW distance (no path normalization, higher scale)
-        dtw_dist_raw, _ = calculate_chroma_distance(
-            pred_x,
-            gt_x,
-            sr=sr,
-            feature_type="stft",
-            distance_metric="cosine",
-            scale_factor=1000.0,
-            normalize_by_path=True,
+    def _setup(self):
+        """Initialize Chroma Alignment-specific components."""
+        self.sample_rate = self.config.get("sample_rate", 22050)
+        self.feature_types = self.config.get("feature_types", ["stft", "cqt", "cens"])
+        self.distance_metrics = self.config.get(
+            "distance_metrics", ["cosine", "euclidean"]
         )
-        results["chroma_stft_cosine_dtw_raw"] = dtw_dist_raw
-
-        # Log-scaled distance
-        dtw_dist_base, _ = calculate_chroma_distance(
-            pred_x,
-            gt_x,
-            sr=sr,
-            feature_type="stft",
-            distance_metric="cosine",
-            scale_factor=1.0,
-            normalize_by_path=True,
+        self.scale_factor = self.config.get("scale_factor", 100.0)
+        self.normalize = self.config.get("normalize", True)
+        self.normalize_by_path = self.config.get("normalize_by_path", True)
+        self.return_alignment = self.config.get("return_alignment", False)
+        self.chroma_kwargs = self.config.get("chroma_kwargs", {})
+
+    def compute(
+        self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None
+    ) -> Dict[str, Union[float, str]]:
+        """Calculate chroma-based distance metrics.
+
+        Args:
+            predictions: Predicted audio signal.
+            references: Ground truth audio signal.
+            metadata: Optional metadata containing sample_rate.
+
+        Returns:
+            dict: Dictionary containing chroma distance metrics.
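+
+        Example (illustrative only; assumes two mono numpy arrays)::
+
+            metric = ChromaAlignmentMetric({"feature_types": ["stft"]})
+            scores = metric.compute(pred, ref, metadata={"sample_rate": 22050})
+            print(scores["chroma_stft_cosine_dtw"])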
+ """ + pred_x = predictions + gt_x = references + sr = ( + metadata.get("sample_rate", self.sample_rate) + if metadata + else self.sample_rate ) - results["chroma_stft_cosine_dtw_log"] = -np.log10(dtw_dist_base + 1e-10) * 10 - except Exception as e: - print(f"Warning: Could not calculate additional scaled metrics: {e}") - - if return_alignment: - return results, alignments - return results + # Validate inputs + if pred_x is None or gt_x is None: + raise ValueError("Both predicted and ground truth signals must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + # Ensure 1D arrays + if pred_x.ndim > 1: + pred_x = pred_x.flatten() + if gt_x.ndim > 1: + gt_x = gt_x.flatten() + + results = {} + alignments = {} if self.return_alignment else None + + # Calculate metrics for different feature types and distance metrics + for feat_type in self.feature_types: + for dist_metric in self.distance_metrics: + try: + dtw_dist, alignment = calculate_chroma_distance( + pred_x, + gt_x, + sr=sr, + feature_type=feat_type, + distance_metric=dist_metric, + scale_factor=self.scale_factor, + normalize=self.normalize, + normalize_by_path=self.normalize_by_path, + **self.chroma_kwargs, + ) + + metric_name = f"chroma_{feat_type}_{dist_metric}_dtw" + results[metric_name] = dtw_dist + + if self.return_alignment and alignments is not None: + alignments[metric_name] = alignment + + except Exception as e: + logger.warning( + f"Could not calculate {feat_type} with {dist_metric}: {e}" + ) + continue + + # Add additional scaled variants + try: + # Raw DTW distance (no path normalization, higher scale) + dtw_dist_raw, _ = calculate_chroma_distance( + pred_x, + gt_x, + sr=sr, + feature_type="stft", + distance_metric="cosine", + scale_factor=1000.0, + normalize_by_path=True, + normalize=self.normalize, + **self.chroma_kwargs, + ) + results["chroma_stft_cosine_dtw_raw"] = dtw_dist_raw + + # Log-scaled distance + dtw_dist_base, _ = calculate_chroma_distance( + pred_x, + gt_x, + sr=sr, + feature_type="stft", + distance_metric="cosine", + scale_factor=1.0, + normalize_by_path=True, + normalize=self.normalize, + **self.chroma_kwargs, + ) + results["chroma_stft_cosine_dtw_log"] = ( + -np.log10(dtw_dist_base + 1e-10) * 10 + ) + except Exception as e: + logger.warning(f"Could not calculate additional scaled metrics: {e}") + + if self.return_alignment and alignments is not None: + results["alignments"] = alignments + + return results + + def get_metadata(self) -> MetricMetadata: + """Return Chroma Alignment metric metadata.""" + return MetricMetadata( + name="chroma_alignment", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["librosa", "numpy", "scipy"], + description="Chroma-based distance estimation with dynamic programming alignment for audio similarity assessment", + paper_reference="https://librosa.org/doc/latest/generated/librosa.feature.chroma_stft.html", + implementation_source="https://github.com/librosa/librosa", + ) -def simple_chroma_distance( - pred_x, - gt_x, - sr=22050, - feature_type="stft", - distance_metric="cosine", - scale_factor=100.0, -): - """ - Args: - pred_x: Predicted audio signal - gt_x: Ground truth audio signal - sr: Sample rate - feature_type: Chroma feature type - distance_metric: Distance metric - scale_factor: Multiplicative scaling factor - Returns: - DTW distance value - """ - dtw_dist, _ = calculate_chroma_distance( - pred_x, - gt_x, - sr=sr, - 
feature_type=feature_type, - distance_metric=distance_metric, - scale_factor=scale_factor, +def register_chroma_alignment_metric(registry): + """Register Chroma Alignment metric with the registry.""" + metric_metadata = MetricMetadata( + name="chroma_alignment", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["librosa", "numpy", "scipy"], + description="Chroma-based distance estimation with dynamic programming alignment for audio similarity assessment", + paper_reference="https://librosa.org/doc/latest/generated/librosa.feature.chroma_stft.html", + implementation_source="https://github.com/librosa/librosa", + ) + registry.register( + ChromaAlignmentMetric, + metric_metadata, + aliases=["ChromaAlignment", "chroma_alignment"], ) - return dtw_dist -# Debug code if __name__ == "__main__": # Create test signals with different lengths sr = 22050 @@ -295,48 +340,9 @@ def simple_chroma_distance( pred_signal = np.sin(2 * np.pi * 440 * t1) # A4 note gt_signal = np.sin(2 * np.pi * 440 * t2) # Same note, different length - # Create a more different signal for testing - diff_signal = np.sin(2 * np.pi * 554.37 * t1) # C#5 note (different pitch) - - print(f"Predicted signal length: {len(pred_signal)} samples ({duration1}s)") - print(f"Ground truth signal length: {len(gt_signal)} samples ({duration2}s)") - - # Calculate chroma metrics for similar signals - print("\n=== SIMILAR SIGNALS (same pitch, different length) ===") - metrics_similar = chroma_metric(pred_signal, gt_signal, sr=sr, scale_factor=100.0) - for metric_name, value in metrics_similar.items(): - print(f"{metric_name}: {value:.4f}") - - # Calculate chroma metrics for different signals - print("\n=== DIFFERENT SIGNALS (different pitch) ===") - metrics_different = chroma_metric( - pred_signal, diff_signal, sr=sr, scale_factor=100.0 - ) - for metric_name, value in metrics_different.items(): - print(f"{metric_name}: {value:.4f}") - - # Simple interface examples with different scale factors - print("\n=== SIMPLE INTERFACE WITH DIFFERENT SCALES ===") - print( - f"Scale 1.0: {simple_chroma_distance(pred_signal, gt_signal, sr=sr, scale_factor=1.0):.4f}" - ) - print( - f"Scale 10.0: {simple_chroma_distance(pred_signal, gt_signal, sr=sr, scale_factor=10.0):.4f}" - ) - print( - f"Scale 100.0: {simple_chroma_distance(pred_signal, gt_signal, sr=sr, scale_factor=100.0):.4f}" - ) - print( - f"Scale 1000.0: {simple_chroma_distance(pred_signal, gt_signal, sr=sr, scale_factor=1000.0):.4f}" - ) - - # Test with random signals (should give larger distances) - print("\n=== RANDOM SIGNALS (should give larger distances) ===") - random_signal1 = np.random.randn(int(sr * 2.0)) - random_signal2 = np.random.randn(int(sr * 2.0)) - - metrics_random = chroma_metric( - random_signal1, random_signal2, sr=sr, scale_factor=100.0 - ) - for metric_name, value in list(metrics_random.items())[:3]: # Show first 3 metrics - print(f"{metric_name}: {value:.4f}") + # Test the new class-based metric + config = {"scale_factor": 100.0} + metric = ChromaAlignmentMetric(config) + metadata = {"sample_rate": sr} + score = metric.compute(pred_signal, gt_signal, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/discrete_speech.py b/versa/utterance_metrics/discrete_speech.py index f628906..a28773f 100644 --- a/versa/utterance_metrics/discrete_speech.py +++ b/versa/utterance_metrics/discrete_speech.py @@ -5,94 +5,238 @@ """Module 
for discrete speech metrics evaluation.""" +import importlib.util import logging +import os +from pathlib import Path +from typing import Dict, Any, Optional, Union -import librosa import numpy as np -try: - from discrete_speech_metrics import SpeechBERTScore, SpeechBLEU, SpeechTokenDistance -except ImportError: - raise ImportError("Please install discrete_speech_metrics and retry") - logger = logging.getLogger(__name__) +DISCRETE_SPEECH_AVAILABLE = ( + importlib.util.find_spec("discrete_speech_metrics") is not None +) +if not DISCRETE_SPEECH_AVAILABLE: + logger.warning( + "discrete_speech_metrics is not properly installed. " + "Please install discrete_speech_metrics and retry" + ) -def discrete_speech_setup(use_gpu=False): - """Set up discrete speech metrics. +SpeechBERTScore = None +SpeechBLEU = None +SpeechTokenDistance = None - Args: - use_gpu (bool, optional): Whether to use GPU. Defaults to False. - Returns: - dict: Dictionary containing the initialized metrics. - """ - # NOTE(jiatong) existing discrete speech metrics only works for 16khz - # We keep the paper best setting. To use other settings, please conduct the - # test on your own. +def _configure_huggingface_cache(cache_dir): + cache_path = Path(cache_dir).resolve() + cache_path.mkdir(parents=True, exist_ok=True) + os.environ.setdefault("HF_HOME", str(cache_path)) + os.environ.setdefault("HF_HUB_CACHE", str(cache_path / "hub")) + os.environ.setdefault("TRANSFORMERS_CACHE", str(cache_path / "transformers")) - speech_bert = SpeechBERTScore( - sr=16000, model_type="wavlm-large", layer=14, use_gpu=use_gpu - ) - speech_bleu = SpeechBLEU( - sr=16000, - model_type="hubert-base", - vocab=200, - layer=11, - n_ngram=2, - remove_repetition=True, - use_gpu=use_gpu, - ) - speech_token_distance = SpeechTokenDistance( - sr=16000, - model_type="hubert-base", - vocab=200, - layer=6, - distance_type="jaro-winkler", - remove_repetition=False, - use_gpu=use_gpu, - ) - return { - "speech_bert": speech_bert, - "speech_bleu": speech_bleu, - "speech_token_distance": speech_token_distance, - } +def _load_discrete_speech_classes(cache_dir): + global SpeechBERTScore, SpeechBLEU, SpeechTokenDistance -def discrete_speech_metric(discrete_speech_predictors, pred_x, gt_x, fs): - """Calculate discrete speech metrics. + if not DISCRETE_SPEECH_AVAILABLE: + raise ImportError( + "discrete_speech_metrics is not properly installed. " + "Please install discrete_speech_metrics and retry" + ) + _configure_huggingface_cache(cache_dir) + if SpeechBERTScore is None or SpeechBLEU is None or SpeechTokenDistance is None: + from discrete_speech_metrics import ( + SpeechBERTScore as _SpeechBERTScore, + SpeechBLEU as _SpeechBLEU, + SpeechTokenDistance as _SpeechTokenDistance, + ) - Args: - discrete_speech_predictors (dict): Dictionary of speech metrics. - pred_x (np.ndarray): Predicted audio signal. - gt_x (np.ndarray): Ground truth audio signal. - fs (int): Sampling rate. + SpeechBERTScore = _SpeechBERTScore + SpeechBLEU = _SpeechBLEU + SpeechTokenDistance = _SpeechTokenDistance + return SpeechBERTScore, SpeechBLEU, SpeechTokenDistance - Returns: - dict: Dictionary containing the metric scores. - Raises: - NotImplementedError: If an unsupported metric is provided. 
- """ - scores = {} +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + - if fs != 16000: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) +class DiscreteSpeechNotAvailableError(RuntimeError): + """Exception raised when discrete_speech_metrics is required but not available.""" - for key in discrete_speech_predictors.keys(): - if key == "speech_bert": - score, _, _ = discrete_speech_predictors[key].score(gt_x, pred_x) - elif key == "speech_bleu" or key == "speech_token_distance": - score = discrete_speech_predictors[key].score(gt_x, pred_x) - else: - raise NotImplementedError(f"Not supported {key}") - scores[key] = score - return scores + pass + + +def is_discrete_speech_available(): + """ + Check if the discrete_speech_metrics package is available. + + Returns: + bool: True if discrete_speech_metrics is available, False otherwise. + """ + return DISCRETE_SPEECH_AVAILABLE + + +class DiscreteSpeechMetric(BaseMetric): + """Discrete speech metrics for audio evaluation.""" + + def _setup(self): + """Initialize Discrete Speech-specific components.""" + if not DISCRETE_SPEECH_AVAILABLE: + raise ImportError( + "discrete_speech_metrics is not properly installed. " + "Please install discrete_speech_metrics and retry" + ) + + self.use_gpu = self.config.get("use_gpu", False) + self.sample_rate = self.config.get("sample_rate", 16000) + self.cache_dir = self.config.get("cache_dir", "versa_cache/huggingface") + speech_bert_cls, speech_bleu_cls, speech_token_distance_cls = ( + _load_discrete_speech_classes(self.cache_dir) + ) + + # NOTE(jiatong) existing discrete speech metrics only works for 16khz + # We keep the paper best setting. To use other settings, please conduct the + # test on your own. + + try: + self.speech_bert = speech_bert_cls( + sr=self.sample_rate, + model_type="wavlm-large", + layer=14, + use_gpu=self.use_gpu, + ) + self.speech_bleu = speech_bleu_cls( + sr=self.sample_rate, + model_type="hubert-base", + vocab=200, + layer=11, + n_ngram=2, + remove_repetition=True, + use_gpu=self.use_gpu, + ) + self.speech_token_distance = speech_token_distance_cls( + sr=self.sample_rate, + model_type="hubert-base", + vocab=200, + layer=6, + distance_type="jaro-winkler", + remove_repetition=False, + use_gpu=self.use_gpu, + ) + except Exception as e: + raise RuntimeError( + f"Failed to initialize discrete speech metrics: {str(e)}" + ) from e + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate discrete speech metrics. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing the metric scores. 
+ """ + pred_x = predictions + gt_x = references + fs = ( + metadata.get("sample_rate", self.sample_rate) + if metadata + else self.sample_rate + ) + + # Validate inputs + if pred_x is None or gt_x is None: + raise ValueError("Both predicted and ground truth signals must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + scores = {} + + if fs != self.sample_rate: + gt_x = resample_audio(gt_x, fs, self.sample_rate) + pred_x = resample_audio(pred_x, fs, self.sample_rate) + + # Calculate SpeechBERT score + try: + score, _, _ = self.speech_bert.score(gt_x, pred_x) + scores["speech_bert"] = score + except Exception as e: + logger.warning(f"Could not calculate SpeechBERT score: {e}") + scores["speech_bert"] = 0.0 + + # Calculate SpeechBLEU score + try: + score = self.speech_bleu.score(gt_x, pred_x) + scores["speech_bleu"] = score + except Exception as e: + logger.warning(f"Could not calculate SpeechBLEU score: {e}") + scores["speech_bleu"] = 0.0 + + # Calculate SpeechTokenDistance score + try: + score = self.speech_token_distance.score(gt_x, pred_x) + scores["speech_token_distance"] = score + except Exception as e: + logger.warning(f"Could not calculate SpeechTokenDistance score: {e}") + scores["speech_token_distance"] = 0.0 + + return scores + + def get_metadata(self) -> MetricMetadata: + """Return Discrete Speech metric metadata.""" + return MetricMetadata( + name="discrete_speech", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["discrete_speech_metrics", "librosa", "numpy"], + description="Discrete speech metrics including SpeechBERT, SpeechBLEU, and SpeechTokenDistance for audio evaluation", + paper_reference="https://github.com/ftshijt/discrete_speech_metrics", + implementation_source="https://github.com/ftshijt/discrete_speech_metrics", + ) + + +def register_discrete_speech_metric(registry): + """Register Discrete Speech metric with the registry.""" + metric_metadata = MetricMetadata( + name="discrete_speech", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["discrete_speech_metrics", "librosa", "numpy"], + description="Discrete speech metrics including SpeechBERT, SpeechBLEU, and SpeechTokenDistance for audio evaluation", + paper_reference="https://github.com/ftshijt/discrete_speech_metrics", + implementation_source="https://github.com/ftshijt/discrete_speech_metrics", + ) + registry.register( + DiscreteSpeechMetric, + metric_metadata, + aliases=["DiscreteSpeech", "discrete_speech"], + ) if __name__ == "__main__": a = np.random.random(16000) b = np.random.random(16000) - predictor = discrete_speech_setup() - print(discrete_speech_metric(predictor, a, b, 16000)) + + # Test the new class-based metric + config = {"use_gpu": False} + metric = DiscreteSpeechMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, b, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/dpam_distance.py b/versa/utterance_metrics/dpam_distance.py index 19e5ac5..07b3b1b 100644 --- a/versa/utterance_metrics/dpam_distance.py +++ b/versa/utterance_metrics/dpam_distance.py @@ -1,14 +1,24 @@ -import torch -import torch.nn as nn -import librosa -import numpy as np -import urllib.request +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 
(http://www.apache.org/licenses/LICENSE-2.0) + +"""Module for DPAM distance metrics.""" + import logging +import urllib.request import filelock from pathlib import Path +from typing import Dict, Any, Optional, Union -TARGET_FS = 22050 -MODEL_URL = "https://raw.githubusercontent.com/adrienchaton/PerceptualAudio_Pytorch/refs/heads/master/pretrained/dataset_combined_linear_tshrink.pth" +import numpy as np +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType class lossnet(nn.Module): @@ -77,37 +87,134 @@ def forward(self, xref, xper): return dist -def dpam_model_setup(cache_dir="versa_cache", use_gpu=False): - device = "cpu" if not use_gpu else "cuda" - model_path = Path(cache_dir) / "dpam" / "dataset_combined_linear.pth" - model_path.parent.mkdir(parents=True, exist_ok=True) - with filelock.FileLock(model_path.with_suffix(".lock")): - if not model_path.exists(): - logging.info(f"Downloading model to {model_path}...") - urllib.request.urlretrieve(MODEL_URL, model_path) - logging.info("Download complete.") - state = torch.load(model_path, map_location="cpu", weights_only=False)["state"] - prefix = "model_dist." - state = {k[len(prefix) :]: v for k, v in state.items() if k.startswith(prefix)} - model = lossnet(nconv=14, nchan=16, dp=0, dist_act="tshrink") - model.load_state_dict(state) - model.to(device) - model.eval() - return model - - -def dpam_metric(model, pred_x, gt_x, fs): - if fs != TARGET_FS: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS) - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=TARGET_FS) - pred_x = torch.from_numpy(pred_x).unsqueeze(0).float() - gt_x = torch.from_numpy(gt_x).unsqueeze(0).float() - dist = model(gt_x, pred_x) - return {"dpam_distance": dist.detach().cpu().numpy().item()} +class DpamDistanceMetric(BaseMetric): + """DPAM distance metric.""" + + TARGET_FS = 22050 + MODEL_URL = "https://raw.githubusercontent.com/adrienchaton/PerceptualAudio_Pytorch/refs/heads/master/pretrained/dataset_combined_linear_tshrink.pth" + + def _setup(self): + """Initialize DPAM-specific components.""" + self.use_gpu = self.config.get("use_gpu", False) + self.cache_dir = self.config.get("cache_dir", "versa_cache") + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize DPAM model: {str(e)}") from e + + def _setup_model(self): + """Setup the DPAM model.""" + device = "cpu" if not self.use_gpu else "cuda" + model_path = Path(self.cache_dir) / "dpam" / "dataset_combined_linear.pth" + model_path.parent.mkdir(parents=True, exist_ok=True) + + with filelock.FileLock(model_path.with_suffix(".lock")): + if not model_path.exists(): + logger.info(f"Downloading model to {model_path}...") + urllib.request.urlretrieve(self.MODEL_URL, model_path) + logger.info("Download complete.") + + # Suppress PyTorch config registration warnings during model loading + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="Skipping config registration for" + ) + checkpoint = torch.load(model_path, map_location="cpu", weights_only=False) + + state = checkpoint["state"] + prefix = "model_dist." 
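+        # The checkpoint stores the full training state; keep only the weights
+        # of the distance network, which live under the "model_dist." prefix.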
+        state = {k[len(prefix) :]: v for k, v in state.items() if k.startswith(prefix)}
+        model = lossnet(nconv=14, nchan=16, dp=0, dist_act="tshrink")
+        model.load_state_dict(state)
+        model.to(device)
+        model.eval()
+        return model
+
+    def compute(
+        self, predictions: Any, references: Any, metadata: Dict[str, Any] = None
+    ) -> Dict[str, Union[float, str]]:
+        """Calculate DPAM distance between two audio samples.
+
+        Args:
+            predictions: Predicted audio signal.
+            references: Ground truth audio signal.
+            metadata: Optional metadata containing sample_rate.
+
+        Returns:
+            dict: Dictionary containing the DPAM distance score.
+        """
+        pred_x = predictions
+        gt_x = references
+        fs = metadata.get("sample_rate", 22050) if metadata else 22050
+
+        # Validate inputs
+        if pred_x is None:
+            raise ValueError("Predicted signal must be provided")
+        if gt_x is None:
+            raise ValueError("Reference signal must be provided")
+
+        pred_x = np.asarray(pred_x)
+        gt_x = np.asarray(gt_x)
+
+        if fs != self.TARGET_FS:
+            pred_x = resample_audio(pred_x, fs, self.TARGET_FS)
+            gt_x = resample_audio(gt_x, fs, self.TARGET_FS)
+
+        pred_x = torch.from_numpy(pred_x).unsqueeze(0).float()
+        gt_x = torch.from_numpy(gt_x).unsqueeze(0).float()
+        dist = self.model(gt_x, pred_x)
+
+        return {"dpam_distance": dist.detach().cpu().numpy().item()}
+
+    def get_metadata(self) -> MetricMetadata:
+        """Return DPAM distance metric metadata."""
+        return MetricMetadata(
+            name="dpam_distance",
+            category=MetricCategory.DEPENDENT,
+            metric_type=MetricType.FLOAT,
+            requires_reference=True,
+            requires_text=False,
+            gpu_compatible=True,
+            auto_install=False,
+            dependencies=["torch", "librosa", "numpy", "filelock"],
+            description="DPAM distance between audio samples",
+            paper_reference="https://github.com/adrienchaton/PerceptualAudio_Pytorch",
+            implementation_source="https://github.com/adrienchaton/PerceptualAudio_Pytorch",
+        )
+
+
+def register_dpam_distance_metric(registry):
+    """Register DPAM distance metric with the registry."""
+    metric_metadata = MetricMetadata(
+        name="dpam_distance",
+        category=MetricCategory.DEPENDENT,
+        metric_type=MetricType.FLOAT,
+        requires_reference=True,
+        requires_text=False,
+        gpu_compatible=True,
+        auto_install=False,
+        dependencies=["torch", "librosa", "numpy", "filelock"],
+        description="DPAM distance between audio samples",
+        paper_reference="https://github.com/adrienchaton/PerceptualAudio_Pytorch",
+        implementation_source="https://github.com/adrienchaton/PerceptualAudio_Pytorch",
+    )
+    registry.register(
+        DpamDistanceMetric,
+        metric_metadata,
+        aliases=["DpamDistance", "dpam_distance", "dpam"],
+    )
 
 
 if __name__ == "__main__":
     a = np.random.random(22050)
     b = np.random.random(22050)
-    model = dpam_model_setup()
-    print("metrics: {}".format(dpam_metric(model, a, b, 22050)))
+
+    # Test the new class-based metric
+    config = {"use_gpu": False}
+    metric = DpamDistanceMetric(config)
+    metadata = {"sample_rate": 22050}
+    score = metric.compute(a, b, metadata=metadata)
+    print(f"metrics: {score}")
diff --git a/versa/utterance_metrics/emo_similarity.py b/versa/utterance_metrics/emo_similarity.py
new file mode 100644
index 0000000..f91ffff
--- /dev/null
+++ b/versa/utterance_metrics/emo_similarity.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python3
+
+# Copyright 2024 Jiatong Shi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Module for emotion similarity metrics using EMO2VEC."""
+
+import logging
+import os
+from pathlib import Path
+from typing import Dict, Any, Optional, Union
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+# Handle optional emo2vec dependency
+try:
+    import emo2vec_versa
+    from emo2vec_versa.emo2vec_class import EMO2VEC
+
+    EMO2VEC_AVAILABLE = True
+except ImportError:
+    logger.info(
+        "emo2vec is not installed. Please install the package via "
+        "`tools/install_emo2vec.sh`"
+    )
+    EMO2VEC = None
+    EMO2VEC_AVAILABLE = False
+
+from versa.audio_utils import resample_audio
+from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType
+
+
+class Emo2vecNotAvailableError(RuntimeError):
+    """Exception raised when emo2vec is required but not available."""
+
+    pass
+
+
+def is_emo2vec_available():
+    """
+    Check if the emo2vec package is available.
+
+    Returns:
+        bool: True if emo2vec is available, False otherwise.
+    """
+    return EMO2VEC_AVAILABLE
+
+
+class Emo2vecMetric(BaseMetric):
+    """Emotion similarity metric using EMO2VEC."""
+
+    def _setup(self):
+        """Initialize Emotion-specific components."""
+        if not EMO2VEC_AVAILABLE:
+            raise ImportError(
+                "emo2vec_versa not found. Please install from tools/installers"
+            )
+
+        self.model_tag = self.config.get("model_tag", "default")
+        self.model_path = self.config.get("model_path", None)
+        self.use_gpu = self.config.get("use_gpu", False)
+
+        try:
+            self.model = self._setup_model()
+        except Exception as e:
+            raise RuntimeError(f"Failed to initialize Emotion model: {str(e)}") from e
+
+    def _setup_model(self):
+        """Setup the Emotion model."""
+        if self.model_path is not None:
+            model = EMO2VEC(self.model_path, use_gpu=self.use_gpu)
+        else:
+            if self.model_tag == "default" or self.model_tag == "base":
+                model_path = (
+                    Path(os.path.abspath(emo2vec_versa.__file__)).parent
+                    / "emotion2vec_base.pt"
+                )
+            else:
+                raise ValueError(f"Unknown model_tag for emo2vec: {self.model_tag}")
+
+            # check if model exists
+            if not model_path.exists():
+                raise FileNotFoundError(f"Model file not found: {model_path}")
+
+            model = EMO2VEC(checkpoint_dir=str(model_path), use_gpu=self.use_gpu)
+
+        return model
+
+    def compute(
+        self, predictions: Any, references: Any, metadata: Dict[str, Any] = None
+    ) -> Dict[str, Union[float, str]]:
+        """Calculate emotion similarity between two audio samples.
+
+        Args:
+            predictions: Predicted audio signal.
+            references: Ground truth audio signal.
+            metadata: Optional metadata containing sample_rate.
+
+        Returns:
+            dict: Dictionary containing the emotion similarity score.
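+
+        Example (illustrative only; assumes emo2vec_versa was installed via
+        tools/install_emo2vec.sh)::
+
+            metric = Emo2vecMetric({"use_gpu": False})
+            scores = metric.compute(pred, ref, metadata={"sample_rate": 16000})
+            print(scores["emotion_similarity"])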
+ """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + # NOTE(jiatong): only work for 16000 Hz + if fs != 16000: + gt_x = resample_audio(gt_x, fs, 16000) + pred_x = resample_audio(pred_x, fs, 16000) + + embedding_gen = self.model.extract_feature(pred_x, fs=16000) + embedding_gt = self.model.extract_feature(gt_x, fs=16000) + similarity = np.dot(embedding_gen, embedding_gt) / ( + np.linalg.norm(embedding_gen) * np.linalg.norm(embedding_gt) + ) + + return {"emotion_similarity": float(similarity)} + + def get_metadata(self) -> MetricMetadata: + """Return Emotion metric metadata.""" + return MetricMetadata( + name="emotion", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["emo2vec_versa", "librosa", "numpy"], + description="Emotion similarity between audio samples using EMO2VEC", + paper_reference="https://github.com/ddlBoJack/emotion2vec", + implementation_source="https://github.com/ddlBoJack/emotion2vec", + ) + + +def register_emo2vec_metric(registry): + """Register Emotion metric with the registry.""" + metric_metadata = MetricMetadata( + name="emotion", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["emo2vec_versa", "librosa", "numpy"], + description="Emotion similarity between audio samples using EMO2VEC", + paper_reference="https://github.com/ddlBoJack/emotion2vec", + implementation_source="https://github.com/ddlBoJack/emotion2vec", + ) + registry.register( + Emo2vecMetric, + metric_metadata, + aliases=[ + "Emotion", + "emotion", + "emo2vec", + "emo2vec_similarity", + "emo_similarity", + "emotion_similarity", + ], + ) + + +if __name__ == "__main__": + a = np.random.random(16000) + b = np.random.random(16000) + + # Test the new class-based metric + config = {"use_gpu": False} + metric = Emo2vecMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, b, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/emo_vad.py b/versa/utterance_metrics/emo_vad.py index 48def8f..a7c5d35 100644 --- a/versa/utterance_metrics/emo_vad.py +++ b/versa/utterance_metrics/emo_vad.py @@ -8,19 +8,51 @@ import logging import os from pathlib import Path +from typing import Dict, Any, Optional, Union -import librosa import numpy as np +import torch +import torch.nn as nn logger = logging.getLogger(__name__) -import torch -import torch.nn as nn -from transformers import Wav2Vec2Processor -from transformers.models.wav2vec2.modeling_wav2vec2 import ( - Wav2Vec2Model, - Wav2Vec2PreTrainedModel, -) +# Handle optional transformers dependency +try: + from transformers import Wav2Vec2Processor + from transformers.models.wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Model, + Wav2Vec2PreTrainedModel, + ) + + TRANSFORMERS_AVAILABLE = True +except ImportError: + logger.warning( + "transformers is not properly installed. 
" + "Please install transformers and retry" + ) + Wav2Vec2Processor = None + Wav2Vec2Model = None + Wav2Vec2PreTrainedModel = None + TRANSFORMERS_AVAILABLE = False + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + + +class TransformersNotAvailableError(RuntimeError): + """Exception raised when transformers is required but not available.""" + + pass + + +def is_transformers_available(): + """ + Check if the transformers package is available. + + Returns: + bool: True if transformers is available, False otherwise. + """ + return TRANSFORMERS_AVAILABLE class RegressionHead(nn.Module): @@ -54,6 +86,7 @@ def __init__(self, config): super().__init__(config) self.config = config + self.all_tied_weights_keys = {} self.wav2vec2 = Wav2Vec2Model(config) self.classifier = RegressionHead(config) self.init_weights() @@ -71,53 +104,145 @@ def forward( return hidden_states, logits -def w2v2_emo_dim_setup( - model_tag="default", model_path=None, model_config=None, use_gpu=False -): - if use_gpu: - device = "cuda" - else: - device = "cpu" - if model_path is not None and model_config is not None: - model = EmotionModel.from_pretrained( - pretrained_model_name_or_path=model_path, config=model_config - ).to(device) - else: - if model_tag == "default": - model_tag = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" - model = EmotionModel.from_pretrained(model_tag).to(device) - processor = Wav2Vec2Processor.from_pretrained( - "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" +class EmoVadMetric(BaseMetric): + """Dimensional emotion prediction metric using w2v2-how-to.""" + + def _setup(self): + """Initialize EmoVad-specific components.""" + if not TRANSFORMERS_AVAILABLE: + raise ImportError( + "transformers is not properly installed. " + "Please install transformers and retry" + ) + + self.model_tag = self.config.get("model_tag", "default") + self.model_path = self.config.get("model_path", None) + self.model_config = self.config.get("model_config", None) + self.processor_path = self.config.get("processor_path", None) + self.cache_dir = self.config.get("cache_dir", "versa_cache/huggingface") + self.use_gpu = self.config.get("use_gpu", False) + + self.device = "cuda" if self.use_gpu and torch.cuda.is_available() else "cpu" + + try: + self.model, self.processor = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize EmoVad model: {str(e)}") from e + + def _setup_model(self): + """Setup the EmoVad model.""" + if self.model_path is not None and self.model_config is not None: + model = EmotionModel.from_pretrained( + pretrained_model_name_or_path=self.model_path, + config=self.model_config, + cache_dir=self.cache_dir, + ).to(self.device) + else: + if self.model_tag == "default": + model_tag = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" + else: + model_tag = self.model_tag + model = EmotionModel.from_pretrained( + model_tag, cache_dir=self.cache_dir + ).to(self.device) + + processor_path = ( + self.processor_path + or self.model_path + or "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" + ) + processor = Wav2Vec2Processor.from_pretrained( + processor_path, cache_dir=self.cache_dir + ) + + return model, processor + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate dimensional emotion (arousal, dominance, valence) of input audio samples. 
+
+        Args:
+            predictions: Audio signal to evaluate.
+            references: Not used for this metric.
+            metadata: Optional metadata containing sample_rate.
+
+        Returns:
+            dict: Dictionary containing the dimensional emotion predictions.
+        """
+        pred_x = predictions
+        fs = metadata.get("sample_rate", 16000) if metadata else 16000
+
+        # Validate input
+        if pred_x is None:
+            raise ValueError("Predicted signal must be provided")
+
+        pred_x = np.asarray(pred_x)
+
+        # NOTE(jiatong): only work for 16000 Hz
+        if fs != 16000:
+            pred_x = resample_audio(pred_x, fs, 16000)
+
+        pred_x = self.processor(pred_x, sampling_rate=16000)
+        pred_x = pred_x["input_values"][0]
+        pred_x = pred_x.reshape(1, -1)
+        pred_x = torch.from_numpy(pred_x).to(self.device)
+
+        with torch.no_grad():
+            avd_emo = self.model(pred_x)[1].squeeze(0).cpu().numpy()
+
+        arousal, dominance, valence = avd_emo
+        arousal = arousal.item()
+        dominance = dominance.item()
+        valence = valence.item()
+
+        return {
+            "arousal_emo_vad": arousal,
+            "valence_emo_vad": valence,
+            "dominance_emo_vad": dominance,
+        }
+
+    def get_metadata(self) -> MetricMetadata:
+        """Return EmoVad metric metadata."""
+        return MetricMetadata(
+            name="emo_vad",
+            category=MetricCategory.INDEPENDENT,
+            metric_type=MetricType.FLOAT,
+            requires_reference=False,
+            requires_text=False,
+            gpu_compatible=True,
+            auto_install=False,
+            dependencies=["transformers", "torch", "librosa", "numpy"],
+            description="Dimensional emotion prediction (arousal, valence, dominance) using w2v2-how-to",
+            paper_reference="https://github.com/audeering/w2v2-how-to",
+            implementation_source="https://github.com/audeering/w2v2-how-to",
+        )
+
+
+def register_emo_vad_metric(registry):
+    """Register EmoVad metric with the registry."""
+    metric_metadata = MetricMetadata(
+        name="emo_vad",
+        category=MetricCategory.INDEPENDENT,
+        metric_type=MetricType.FLOAT,
+        requires_reference=False,
+        requires_text=False,
+        gpu_compatible=True,
+        auto_install=False,
+        dependencies=["transformers", "torch", "librosa", "numpy"],
+        description="Dimensional emotion prediction (arousal, valence, dominance) using w2v2-how-to",
+        paper_reference="https://github.com/audeering/w2v2-how-to",
+        implementation_source="https://github.com/audeering/w2v2-how-to",
     )
-    emo_utils = {"model": model, "processor": processor, "device": device}
-    return emo_utils
-
-
-def dim_emo_pred(emo_utils, pred_x, fs):
-    """Calculate dimensional emotion (arousal, dominance, valence) of input audio samples.
-
-    Args:
-        model (w2v2-how-to): The loaded EMO2VEC model.
-        pred_x (np.ndarray): Predicted audio signal.
-        fs (int): Sampling rate.
-
-    Returns:
-        dict: Dictionary containing the dimensional emotion predictions.
- """ - # NOTE(jiatong): only work for 16000 Hz - if fs != 16000: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - pred_x = emo_utils["processor"](pred_x, sampling_rate=16000) - pred_x = pred_x["input_values"][0] - pred_x = pred_x.reshape(1, -1) - pred_x = torch.from_numpy(pred_x).to(emo_utils["device"]) - with torch.no_grad(): - avd_emo = emo_utils["model"](pred_x)[1].squeeze(0).cpu().numpy() - - return {"aro_val_dom_emo": avd_emo} + registry.register(EmoVadMetric, metric_metadata, aliases=["EmoVad", "emo_vad"]) if __name__ == "__main__": a = np.random.random(16000) - emo_utils = w2v2_emo_dim_setup() - print(f"metrics: {dim_emo_pred(emo_utils, a, 16000)}") + + # Test the new class-based metric + config = {"use_gpu": False} + metric = EmoVadMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, metadata=metadata) + print(f"metrics: {score}") diff --git a/versa/utterance_metrics/emotion.py b/versa/utterance_metrics/emotion.py deleted file mode 100644 index a4945d3..0000000 --- a/versa/utterance_metrics/emotion.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2024 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Module for emotion similarity metrics using EMO2VEC.""" - -import logging -import os -from pathlib import Path - -import librosa -import numpy as np - -logger = logging.getLogger(__name__) - -try: - import emo2vec_versa - from emo2vec_versa.emo2vec_class import EMO2VEC -except ImportError: - logger.info( - "emo2vec is not installed. Please install the package via " - "`tools/install_emo2vec.sh`" - ) - EMO2VEC = None - - -def emo2vec_setup(model_tag="default", model_path=None, use_gpu=False): - """Set up EMO2VEC model for emotion embedding extraction. - - Args: - model_tag (str, optional): Model tag. Defaults to "default". - model_path (str, optional): Path to model weights. Defaults to None. - use_gpu (bool, optional): Whether to use GPU. Defaults to False. - - Returns: - EMO2VEC: The loaded model. - - Raises: - ImportError: If emo2vec_versa is not installed. - ValueError: If model_tag is unknown. - FileNotFoundError: If model file is not found. - """ - if EMO2VEC is None: - raise ImportError( - "emo2vec_versa not found. Please install from tools/installers" - ) - - if model_path is not None: - model = EMO2VEC(model_path, use_gpu=use_gpu) - else: - if model_tag == "default" or model_tag == "base": - model_path = ( - Path(os.path.abspath(emo2vec_versa.__file__)).parent - / "emotion2vec_base.pt" - ) - else: - raise ValueError(f"Unknown model_tag for emo2vec: {model_tag}") - - # check if model exists - if not model_path.exists(): - raise FileNotFoundError(f"Model file not found: {model_path}") - - model = EMO2VEC(checkpoint_dir=str(model_path), use_gpu=use_gpu) - return model - - -def emo_sim(model, pred_x, gt_x, fs): - """Calculate emotion similarity between two audio samples. - - Args: - model (EMO2VEC): The loaded EMO2VEC model. - pred_x (np.ndarray): Predicted audio signal. - gt_x (np.ndarray): Ground truth audio signal. - fs (int): Sampling rate. - - Returns: - dict: Dictionary containing the emotion similarity score. 
- """ - # NOTE(jiatong): only work for 16000 Hz - if fs != 16000: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - - embedding_gen = model.extract_feature(pred_x, fs=16000) - embedding_gt = model.extract_feature(gt_x, fs=16000) - similarity = np.dot(embedding_gen, embedding_gt) / ( - np.linalg.norm(embedding_gen) * np.linalg.norm(embedding_gt) - ) - return {"emotion_similarity": similarity} - - -if __name__ == "__main__": - a = np.random.random(16000) - b = np.random.random(16000) - model = emo2vec_setup() - print(f"metrics: {emo_sim(model, a, b, 16000)}") diff --git a/versa/utterance_metrics/log_wmse.py b/versa/utterance_metrics/log_wmse.py index bd21be0..92ee05e 100644 --- a/versa/utterance_metrics/log_wmse.py +++ b/versa/utterance_metrics/log_wmse.py @@ -5,23 +5,46 @@ import logging -import librosa import numpy as np import torch +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + logger = logging.getLogger(__name__) try: - import torchaudio - import torchaudio.functional as F from torch_log_wmse import LogWMSE logger.info("Using the torch-log-wmse package for evaluation") except ImportError: - raise ImportError("Please install torch-log-wmse and retry") + LogWMSE = None + + +def _ensure_log_wmse_available(): + if LogWMSE is None: + raise ImportError("Please install torch-log-wmse and retry") + +def _as_unprocessed_tensor(audio): + if isinstance(audio, torch.Tensor): + tensor = audio.float() + else: + tensor = torch.from_numpy(np.asarray(audio)).float() + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0).unsqueeze(0) + elif tensor.ndim == 2: + tensor = tensor.unsqueeze(0) + return tensor -def log_wmse(unproc_x, proc_x, gt_x, fs): + +def _as_stem_tensor(audio): + tensor = _as_unprocessed_tensor(audio) + if tensor.ndim == 3: + tensor = tensor.unsqueeze(1) + return tensor + + +def log_wmse(unproc_x, proc_x, gt_x, fs, model=None): """Calculate LogWMSE metric between audio samples. Args: @@ -36,18 +59,86 @@ def log_wmse(unproc_x, proc_x, gt_x, fs): Returns: dict: Dictionary containing the LogWMSE score. """ - # Instantiate logWMSE - # Set `return_as_loss=False` to return as a positive metric (Default: True) - # Set `bypass_filter=True` to bypass frequency weighting (Default: False) - inst_log_wmse = LogWMSE( - audio_length=1.0, - sample_rate=44100, - return_as_loss=True, # optional + _ensure_log_wmse_available() + + if model is None: + # Set `return_as_loss=False` to return as a positive metric. + # Set `bypass_filter=True` to bypass frequency weighting. 
+        model = LogWMSE(
+            audio_length=1.0,
+            sample_rate=44100,
+            return_as_loss=True,
+        )
+
+    log_wmse_score = model(unproc_x, proc_x, gt_x)
+    score = log_wmse_score.detach().cpu().numpy()
+    if np.size(score) == 1:
+        score = float(np.asarray(score).reshape(-1)[0])
+    return {"log_wmse": score}
+
+
+class LogWmseMetric(BaseMetric):
+    """Log-weighted mean square error."""
+
+    def _setup(self):
+        _ensure_log_wmse_available()
+        self.audio_length = self.config.get("audio_length", 1.0)
+        self.sample_rate = self.config.get("sample_rate", 44100)
+        self.return_as_loss = self.config.get("return_as_loss", True)
+        self.bypass_filter = self.config.get("bypass_filter")
+
+        kwargs = {
+            "audio_length": self.audio_length,
+            "sample_rate": self.sample_rate,
+            "return_as_loss": self.return_as_loss,
+        }
+        if self.bypass_filter is not None:
+            kwargs["bypass_filter"] = self.bypass_filter
+        self.model = LogWMSE(**kwargs)
+
+    def compute(self, predictions, references=None, metadata=None):
+        if predictions is None:
+            raise ValueError("Predicted signal must be provided")
+        if references is None:
+            raise ValueError("Reference signal must be provided")
+
+        metadata = metadata or {}
+        unprocessed = metadata.get("unprocessed")
+        if unprocessed is None:
+            unprocessed = metadata.get("unproc_x", predictions)
+
+        unproc_x = _as_unprocessed_tensor(unprocessed)
+        proc_x = _as_stem_tensor(predictions)
+        gt_x = _as_stem_tensor(references)
+        fs = metadata.get("sample_rate", 16000)
+        return log_wmse(unproc_x, proc_x, gt_x, fs, model=self.model)
+
+    def get_metadata(self):
+        return _log_wmse_metadata()
+
+
+def _log_wmse_metadata():
+    return MetricMetadata(
+        name="log_wmse",
+        category=MetricCategory.DEPENDENT,
+        metric_type=MetricType.FLOAT,
+        requires_reference=True,
+        requires_text=False,
+        gpu_compatible=False,
+        auto_install=False,
+        dependencies=["torch_log_wmse", "torch", "numpy"],
+        description="Log-weighted mean square error",
+        implementation_source="https://github.com/nomonosound/log-wmse-audio-quality",
     )
 
-    log_wmse_score = inst_log_wmse(unproc_x, proc_x, gt_x)
-    return {"torch log-wmse": log_wmse_score.detach().numpy()}
+
+def register_log_wmse_metric(registry):
+    """Register log-weighted mean square error with the registry."""
+    registry.register(
+        LogWmseMetric,
+        _log_wmse_metadata(),
+        aliases=["log-wmse", "torch_log_wmse"],
+    )
 
 
 if __name__ == "__main__":
@@ -56,13 +147,13 @@ def log_wmse(unproc_x, proc_x, gt_x, fs):
     https://github.com/crlandsc/torch-log-wmse
 
     Unlike many audio quality metrics, logWMSE accepts a triple of audio inputs:
-    - unprocessed audio (e.g. a raw, noisy recording) # [batch, audio_channels, sample]
-    - processed audio (e.g. a denoised recording) # [batch, audio_stems, audio_channels, sample]
-    - target audio (e.g. a clean reference without noise) # [batch, audio_stems, audio_channels, sample]
+    - unprocessed audio (raw/noisy recording), shape [batch, channels, sample]
+    - processed audio (denoised recording), shape [batch, stems, channels, sample]
+    - target clean reference, shape [batch, stems, channels, sample]
 
     * audio_length: length of the audio
-      sample_rate: 44100 the metric performs an internal resampling to 44.1kHz for consistency
+      sample_rate: 44100 for the package's internal resampling
       audio_stems: # of audio stems (e.g. vocals, drums, bass, other)
       audio_channels: mono=1, stereo=2
       batch: batch size
diff --git a/versa/utterance_metrics/nisqa.py b/versa/utterance_metrics/nisqa.py
index 84fee55..fbe89fb 100644
--- a/versa/utterance_metrics/nisqa.py
+++ b/versa/utterance_metrics/nisqa.py
@@ -1,178 +1,282 @@
-#!/usr/bin/env python3
-
-# Copyright 2025 Jiatong Shi
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-import librosa
-import numpy as np
-import torch
-
-import versa.utterance_metrics.nisqa_utils.nisqa_lib as NL
-
-
-def nisqa_model_setup(nisqa_model_path=None, use_gpu=False):
-    """
-    Setup the NISQA model for evaluation.
-    Args:
-        nisqa_model_path (str): Path to the NISQA model checkpoint.
-        use_gpu (bool): If True, use GPU for computation. Default is False.
-
-    Returns:
-        model: The loaded NISQA model.
-
-    Raises:
-        ValueError: If the model path is not provided or the checkpoint is invalid.
-    """
-
-    # Check if GPU is available
-    if use_gpu and not torch.cuda.is_available():
-        raise RuntimeError("GPU is not available. Please set use_gpu=False.")
-
-    # Set device
-    if use_gpu:
-        device = "cuda"
-    else:
-        device = "cpu"
-    # Check if the model path is provided
-    if nisqa_model_path is None:
-        raise ValueError("NISQA model path must be provided.")
-
-    checkpoint = torch.load(nisqa_model_path, map_location="cpu")
-    args = checkpoint.get("args", None)
-    if args is None:
-        raise ValueError(
-            "Model checkpoint does not contain the required arguments. Might due to a wrong checkpoint."
-        )
-
-    if args["model"] == "NISQA_DIM":
-        args["dim"] = True
-        args["csv_mos_train"] = None  # column names hardcoded for dim models
-        args["csv_mos_val"] = None
-    else:
-        args["dim"] = False
-
-    if args["model"] == "NISQA_DE":
-        args["double_ended"] = True
-    else:
-        args["double_ended"] = False
-        args["csv_ref"] = None
-
-    # Load Model
-    model_args = {
-        "ms_seg_length": args["ms_seg_length"],
-        "ms_n_mels": args["ms_n_mels"],
-        "cnn_model": args["cnn_model"],
-        "cnn_c_out_1": args["cnn_c_out_1"],
-        "cnn_c_out_2": args["cnn_c_out_2"],
-        "cnn_c_out_3": args["cnn_c_out_3"],
-        "cnn_kernel_size": args["cnn_kernel_size"],
-        "cnn_dropout": args["cnn_dropout"],
-        "cnn_pool_1": args["cnn_pool_1"],
-        "cnn_pool_2": args["cnn_pool_2"],
-        "cnn_pool_3": args["cnn_pool_3"],
-        "cnn_fc_out_h": args["cnn_fc_out_h"],
-        "td": args["td"],
-        "td_sa_d_model": args["td_sa_d_model"],
-        "td_sa_nhead": args["td_sa_nhead"],
-        "td_sa_pos_enc": args["td_sa_pos_enc"],
-        "td_sa_num_layers": args["td_sa_num_layers"],
-        "td_sa_h": args["td_sa_h"],
-        "td_sa_dropout": args["td_sa_dropout"],
-        "td_lstm_h": args["td_lstm_h"],
-        "td_lstm_num_layers": args["td_lstm_num_layers"],
-        "td_lstm_dropout": args["td_lstm_dropout"],
-        "td_lstm_bidirectional": args["td_lstm_bidirectional"],
-        "td_2": args["td_2"],
-        "td_2_sa_d_model": args["td_2_sa_d_model"],
-        "td_2_sa_nhead": args["td_2_sa_nhead"],
-        "td_2_sa_pos_enc": args["td_2_sa_pos_enc"],
-        "td_2_sa_num_layers": args["td_2_sa_num_layers"],
-        "td_2_sa_h": args["td_2_sa_h"],
-        "td_2_sa_dropout": args["td_2_sa_dropout"],
-        "td_2_lstm_h": args["td_2_lstm_h"],
-        "td_2_lstm_num_layers": args["td_2_lstm_num_layers"],
-        "td_2_lstm_dropout": args["td_2_lstm_dropout"],
-        "td_2_lstm_bidirectional": args["td_2_lstm_bidirectional"],
-        "pool": args["pool"],
-        "pool_att_h": args["pool_att_h"],
-        "pool_att_dropout": args["pool_att_dropout"],
-    }
-
-    if args["double_ended"]:
-        model_args.update(
-            {
-                "de_align": args["de_align"],
-                "de_align_apply": args["de_align_apply"],
-                "de_fuse_dim": args["de_fuse_dim"],
-                "de_fuse": args["de_fuse"],
-            }
-        )
-
-    if args["model"] == "NISQA":
-        model = NL.NISQA(**model_args)
-    elif args["model"] == "NISQA_DIM":
-        model = NL.NISQA_DIM(**model_args)
-    elif args["model"] == "NISQA_DE":
-        model = NL.NISQA_DE(**model_args)
-    else:
-        raise NotImplementedError("Model not available")
-
-    # Load weights
-    missing_keys, unexpected_keys = model.load_state_dict(
-        checkpoint["model_state_dict"], strict=True
-    )
-    if missing_keys:
-        print("[NISQA] missing_keys:")
-        print(missing_keys)
-    if unexpected_keys:
-        print("[NISQA] unexpected_keys:")
-        print(unexpected_keys)
-    model.args = args
-    model.device = device
-    return model
-
-
-def nisqa_metric(nisqa_model, pred_x, fs):
-    """
-    Calculate the NISQA score for a given audio signal.
-
-    Args:
-        nisqa_model: The NISQA model to use for evaluation.
-        pred_x (np.ndarray): The audio signal to be evaluated (1D array).
-        fs (int): The sampling rate of the audio signal in Hz.
-
-    Returns:
-        dict: A dictionary containing the NISQA score and other metrics.
-    """
-    model_sr = 48e3  # NISQA model's expected sampling rate
-    if fs != model_sr:
-        # Resample the audio signal to the model's expected sampling rate
-        pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=model_sr)
-        fs = model_sr
-
-    # Evaluate the NISQA score
-    with torch.no_grad():
-        metrics = NL.versa_eval_mos(
-            [pred_x], nisqa_model, 1, nisqa_model.device, num_workers=0
-        )
-
-    final_result = {}
-    for metrics_key in metrics.keys():
-        # Check if the metric is a list and take the first element for batch=1
-        final_result["nisqa_" + metrics_key] = metrics[metrics_key][0][0]
-
-    return final_result
-
-
-if __name__ == "__main__":
-    a = np.random.random(16000)
-    fs = 16000
-    try:
-        nisqa_model = nisqa_model_setup(
-            nisqa_model_path="/home/jiatong/projects/espnet/tools/versa/tools/NISQA/weights/nisqa.tar",
-            use_gpu=True,
-        )
-        score = nisqa_metric(nisqa_model, a, fs)
-        print("NISQA Score: {}".format(score))
    except NotImplementedError as e:
-        print(e)
+#!/usr/bin/env python3
+
+# Copyright 2025 Jiatong Shi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Module for NISQA speech quality assessment metrics."""
+
+import logging
+import warnings
+from typing import Dict, Any, Optional, Union
+
+import numpy as np
+import torch
+
+from versa.audio_utils import resample_audio
+from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType
+
+try:
+    import versa.utterance_metrics.nisqa_utils.nisqa_lib as NL
+
+    NISQA_LIB_AVAILABLE = True
+except ImportError:
+    from types import SimpleNamespace
+
+    NL = SimpleNamespace(
+        NISQA=None,
+        NISQA_DIM=None,
+        NISQA_DE=None,
+        versa_eval_mos=None,
+    )
+    NISQA_LIB_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+
+def _nisqa_lib():
+    if not NISQA_LIB_AVAILABLE and NL.NISQA is None and NL.versa_eval_mos is None:
+        raise ImportError(
+            "NISQA dependencies are not available. Please run tools/setup_nisqa.sh "
+            "and install optional dependencies such as matplotlib."
+    )
+    return NL
+
+
+class NisqaMetric(BaseMetric):
+    """NISQA speech quality assessment metric."""
+
+    TARGET_FS = 48000  # NISQA model's expected sampling rate
+
+    def _setup(self):
+        """Initialize NISQA-specific components."""
+        self.nisqa_model_path = self.config.get("nisqa_model_path")
+        self.use_gpu = self.config.get("use_gpu", False)
+
+        if not self.nisqa_model_path:
+            raise ValueError("NISQA model path must be provided in config")
+
+        try:
+            self.model = self._setup_model()
+        except Exception as e:
+            self.model = None
+            self._setup_error = e
+
+    def _setup_model(self):
+        """Setup the NISQA model."""
+        nisqa_lib = _nisqa_lib()
+        # Check if GPU is available
+        if self.use_gpu and not torch.cuda.is_available():
+            raise RuntimeError("GPU is not available. Please set use_gpu=False.")
+
+        # Set device
+        device = "cuda" if self.use_gpu else "cpu"
+
+        # Suppress PyTorch config registration warnings during model loading
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore", message="Skipping config registration for"
+            )
+            checkpoint = torch.load(self.nisqa_model_path, map_location="cpu")
+
+        args = checkpoint.get("args", None)
+        if args is None:
+            raise ValueError(
+                "Model checkpoint does not contain the required arguments. "
+                "This might be due to a wrong checkpoint."
+            )
+
+        if args["model"] == "NISQA_DIM":
+            args["dim"] = True
+            args["csv_mos_train"] = None  # column names hardcoded for dim models
+            args["csv_mos_val"] = None
+        else:
+            args["dim"] = False
+
+        if args["model"] == "NISQA_DE":
+            args["double_ended"] = True
+        else:
+            args["double_ended"] = False
+            args["csv_ref"] = None
+
+        # Load Model
+        model_args = {
+            "ms_seg_length": args["ms_seg_length"],
+            "ms_n_mels": args["ms_n_mels"],
+            "cnn_model": args["cnn_model"],
+            "cnn_c_out_1": args["cnn_c_out_1"],
+            "cnn_c_out_2": args["cnn_c_out_2"],
+            "cnn_c_out_3": args["cnn_c_out_3"],
+            "cnn_kernel_size": args["cnn_kernel_size"],
+            "cnn_dropout": args["cnn_dropout"],
+            "cnn_pool_1": args["cnn_pool_1"],
+            "cnn_pool_2": args["cnn_pool_2"],
+            "cnn_pool_3": args["cnn_pool_3"],
+            "cnn_fc_out_h": args["cnn_fc_out_h"],
+            "td": args["td"],
+            "td_sa_d_model": args["td_sa_d_model"],
+            "td_sa_nhead": args["td_sa_nhead"],
+            "td_sa_pos_enc": args["td_sa_pos_enc"],
+            "td_sa_num_layers": args["td_sa_num_layers"],
+            "td_sa_h": args["td_sa_h"],
+            "td_sa_dropout": args["td_sa_dropout"],
+            "td_lstm_h": args["td_lstm_h"],
+            "td_lstm_num_layers": args["td_lstm_num_layers"],
+            "td_lstm_dropout": args["td_lstm_dropout"],
+            "td_lstm_bidirectional": args["td_lstm_bidirectional"],
+            "td_2": args["td_2"],
+            "td_2_sa_d_model": args["td_2_sa_d_model"],
+            "td_2_sa_nhead": args["td_2_sa_nhead"],
+            "td_2_sa_pos_enc": args["td_2_sa_pos_enc"],
+            "td_2_sa_num_layers": args["td_2_sa_num_layers"],
+            "td_2_sa_h": args["td_2_sa_h"],
+            "td_2_sa_dropout": args["td_2_sa_dropout"],
+            "td_2_lstm_h": args["td_2_lstm_h"],
+            "td_2_lstm_num_layers": args["td_2_lstm_num_layers"],
+            "td_2_lstm_dropout": args["td_2_lstm_dropout"],
+            "td_2_lstm_bidirectional": args["td_2_lstm_bidirectional"],
+            "pool": args["pool"],
+            "pool_att_h": args["pool_att_h"],
+            "pool_att_dropout": args["pool_att_dropout"],
+        }
+
+        if args["double_ended"]:
+            model_args.update(
+                {
+                    "de_align": args["de_align"],
+                    "de_align_apply": args["de_align_apply"],
+                    "de_fuse_dim": args["de_fuse_dim"],
+                    "de_fuse": args["de_fuse"],
+                }
+            )
+
+        if args["model"] == "NISQA":
+            model = nisqa_lib.NISQA(**model_args)
+        elif args["model"] == "NISQA_DIM":
+            model = nisqa_lib.NISQA_DIM(**model_args)
+        elif args["model"] == 
"NISQA_DE": + model = nisqa_lib.NISQA_DE(**model_args) + else: + raise NotImplementedError("Model not available") + + # Load weights + missing_keys, unexpected_keys = model.load_state_dict( + checkpoint["model_state_dict"], strict=True + ) + if missing_keys: + logger.warning("[NISQA] missing_keys: %s", missing_keys) + if unexpected_keys: + logger.warning("[NISQA] unexpected_keys: %s", unexpected_keys) + + model.args = args + model.device = device + return model + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate NISQA scores for speech quality assessment. + + Args: + predictions: Audio signal to be evaluated. + references: Not used for NISQA (single-ended metric). + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing NISQA scores. + """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + + if self.model is None: + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize NISQA model: {str(e)}") from e + + pred_x = np.asarray(pred_x) + + # Resample if necessary + if fs != self.TARGET_FS: + pred_x = resample_audio(pred_x, fs, self.TARGET_FS) + fs = self.TARGET_FS + + # Evaluate the NISQA score + with torch.no_grad(): + metrics = _nisqa_lib().versa_eval_mos( + [pred_x], self.model, 1, self.model.device, num_workers=0 + ) + + final_result = {} + for metrics_key in metrics.keys(): + # Check if the metric is a list and take the first element for batch=1 + final_result["nisqa_" + metrics_key] = metrics[metrics_key][0][0] + + return final_result + + def get_metadata(self) -> MetricMetadata: + """Return NISQA metric metadata.""" + return MetricMetadata( + name="nisqa", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy"], + description="NISQA speech quality assessment metric", + paper_reference="https://github.com/gabrielmittag/NISQA", + implementation_source="https://github.com/gabrielmittag/NISQA", + ) + + +def register_nisqa_metric(registry): + """Register NISQA metric with the registry.""" + metric_metadata = MetricMetadata( + name="nisqa", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy"], + description="NISQA speech quality assessment metric", + paper_reference="https://github.com/gabrielmittag/NISQA", + implementation_source="https://github.com/gabrielmittag/NISQA", + ) + registry.register( + NisqaMetric, + metric_metadata, + aliases=["Nisqa", "nisqa"], + ) + + +def nisqa_model_setup(nisqa_model_path, use_gpu=False): + """Legacy function API for setting up NISQA.""" + return NisqaMetric({"nisqa_model_path": nisqa_model_path, "use_gpu": use_gpu}).model + + +def nisqa_metric(model, pred_x, fs): + """Legacy function API for computing NISQA.""" + metric = object.__new__(NisqaMetric) + metric.config = {} + metric.model = model + return NisqaMetric.compute(metric, pred_x, metadata={"sample_rate": fs}) + + +if __name__ == "__main__": + a = np.random.random(16000) + fs = 16000 + try: + nisqa_model = nisqa_model_setup( + 
nisqa_model_path="versa_cache/nisqa/nisqa.tar", + use_gpu=True, + ) + score = nisqa_metric(nisqa_model, a, fs) + print("NISQA Score: {}".format(score)) + except NotImplementedError as e: + print(e) diff --git a/versa/utterance_metrics/nisqa_utils/nisqa_lib.py b/versa/utterance_metrics/nisqa_utils/nisqa_lib.py index c1fbec4..f2e69c3 100644 --- a/versa/utterance_metrics/nisqa_utils/nisqa_lib.py +++ b/versa/utterance_metrics/nisqa_utils/nisqa_lib.py @@ -2,6 +2,7 @@ """ @author: Gabriel Mittag, TU-Berlin """ + import copy import math import multiprocessing @@ -383,8 +384,8 @@ def __init__( ) def _split_ref_deg(self, x, n_wins): - (x, y) = torch.chunk(x, 2, dim=2) - (n_wins_x, n_wins_y) = torch.chunk(n_wins, 2, dim=1) + x, y = torch.chunk(x, 2, dim=2) + n_wins_x, n_wins_y = torch.chunk(n_wins, 2, dim=1) n_wins_x = n_wins_x.view(-1) n_wins_y = n_wins_y.view(-1) return x, y, n_wins_x, n_wins_y @@ -484,7 +485,7 @@ def __init__( raise NotImplementedError("Framwise model not available") def forward(self, x, n_wins): - (bs, length, channels, height, width) = x.shape + bs, length, channels, height, width = x.shape x_packed = pack_padded_sequence( x, n_wins.cpu(), batch_first=True, enforce_sorted=False ) diff --git a/versa/utterance_metrics/nomad.py b/versa/utterance_metrics/nomad.py index 3033c6b..1997aa3 100644 --- a/versa/utterance_metrics/nomad.py +++ b/versa/utterance_metrics/nomad.py @@ -1,63 +1,150 @@ #!/usr/bin/env python3 +# Copyright 2024 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import logging +"""Module for NOMAD speech quality assessment metrics.""" -logger = logging.getLogger(__name__) +import logging +from typing import Dict, Any, Optional, Union -import librosa import numpy as np import torch +logger = logging.getLogger(__name__) + +# Handle optional nomad dependency try: from nomad_versa import Nomad + + NOMAD_AVAILABLE = True except ImportError: - logger.info( + logger.warning( "nomad is not installed. Please use `tools/install_nomad.sh` to install" ) Nomad = None + NOMAD_AVAILABLE = False +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType -def nomad_setup(use_gpu=False, cache_dir="versa_cache/nomad_pt-models"): - if use_gpu: - device = "cuda" - else: - device = "cpu" - if Nomad is None: - raise ModuleNotFoundError( - "nomad is not installed. Please use `tools/install_nomad.sh` to install" - ) +class NomadNotAvailableError(RuntimeError): + """Exception raised when nomad is required but not available.""" - return Nomad(device=device, cache_dir=cache_dir) + pass -def nomad(model, pred_x, gt_x, fs): +def is_nomad_available(): """ - Reference: - A. Ragano, J. Skoglund and A. Hines, - "NOMAD: Unsupervised Learning of Perceptual Embeddings For Speech Enhancement and Non-Matching Reference Audio Quality Assessment," - ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Seoul, Korea, Republic of, 2024, pp. 1011-1015 - Codebase: - https://github.com/alessandroragano/nomad + Check if the nomad package is available. + Returns: + bool: True if nomad is available, False otherwise. 
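+
+    Example (a minimal availability guard before registering the metric):
+        >>> from versa.utterance_metrics.nomad import is_nomad_available
+        >>> run_nomad = is_nomad_available()
+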
""" - - # NOTE(hyejin): current model only have 16k options - if fs != 16000: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - - return { - "nomad": model.predict(nmr=gt_x, deg=pred_x), - } + return NOMAD_AVAILABLE + + +class NomadMetric(BaseMetric): + """NOMAD speech quality assessment metric.""" + + TARGET_FS = 16000 # NOMAD model's expected sampling rate + + def _setup(self): + """Initialize NOMAD-specific components.""" + if not NOMAD_AVAILABLE and Nomad is None: + raise ImportError( + "nomad is not installed. Please use `tools/install_nomad.sh` to install" + ) + + self.use_gpu = self.config.get("use_gpu", False) + self.cache_dir = self.config.get("model_cache", "versa_cache/nomad_pt-models") + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize NOMAD model: {str(e)}") from e + + def _setup_model(self): + """Setup the NOMAD model.""" + device = "cuda" if self.use_gpu else "cpu" + + if Nomad is None: + raise ModuleNotFoundError( + "nomad is not installed. Please use `tools/install_nomad.sh` to install" + ) + + return Nomad(device=device, cache_dir=self.cache_dir) + + def compute( + self, predictions: Any, references: Any, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate NOMAD score for speech quality assessment. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing NOMAD score. + """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + # Resample if necessary (NOMAD only supports 16kHz) + if fs != self.TARGET_FS: + gt_x = resample_audio(gt_x, fs, self.TARGET_FS) + pred_x = resample_audio(pred_x, fs, self.TARGET_FS) + + return { + "nomad": self.model.predict(nmr=gt_x, deg=pred_x), + } + + def get_metadata(self) -> MetricMetadata: + """Return NOMAD metric metadata.""" + return MetricMetadata( + name="nomad", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["nomad_versa", "torch", "librosa", "numpy"], + description="NOMAD: Unsupervised Learning of Perceptual Embeddings For Speech Enhancement and Non-Matching Reference Audio Quality Assessment", + paper_reference="https://ieeexplore.ieee.org/document/10447047", + implementation_source="https://github.com/alessandroragano/nomad", + ) -if __name__ == "__main__": - a = np.random.random(16000) - b = np.random.random(16000) - nomad_model = nomad_setup(use_gpu=True) - fs = 16000 - nomad_score = nomad(nomad_model, a, b, fs) - print(nomad_score) +def register_nomad_metric(registry): + """Register NOMAD metric with the registry.""" + metric_metadata = MetricMetadata( + name="nomad", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["nomad_versa", "torch", "librosa", "numpy"], + description="NOMAD: Unsupervised Learning of Perceptual Embeddings For Speech Enhancement and Non-Matching Reference Audio Quality 
Assessment", + paper_reference="https://ieeexplore.ieee.org/document/10447047", + implementation_source="https://github.com/alessandroragano/nomad", + ) + registry.register( + NomadMetric, + metric_metadata, + aliases=["Nomad", "nomad"], + ) diff --git a/versa/utterance_metrics/noresqa.py b/versa/utterance_metrics/noresqa.py index b9391c0..639cae1 100644 --- a/versa/utterance_metrics/noresqa.py +++ b/versa/utterance_metrics/noresqa.py @@ -3,129 +3,290 @@ # Copyright 2024 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +"""Module for NORESQA speech quality assessment metrics.""" import logging import os -import sys +import warnings +from typing import Dict, Any, Union -import librosa import numpy as np import torch - -logger = logging.getLogger(__name__) - from urllib.request import urlretrieve -import torch.nn as nn - -base_path = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../../tools/Noresqa") -) -sys.path.insert(0, base_path) +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType +logger = logging.getLogger(__name__) +# Handle optional dependencies try: import fairseq + + FAIRSEQ_AVAILABLE = True except ImportError: - logger.info( + logger.warning( "fairseq is not installed. Please use `tools/install_fairseq.sh` to install" ) + fairseq = None + FAIRSEQ_AVAILABLE = False try: - from model import NORESQA - from utils import ( + from versa.utterance_metrics.noresqa_utils.noresqa_model import NORESQA + from versa.utterance_metrics.noresqa_utils.noresqa_utils import ( feats_loading, model_prediction_noresqa, model_prediction_noresqa_mos, ) + NORESQA_AVAILABLE = True except ImportError: - logger.info( - "noresqa is not installed. Please use `tools/install_noresqa.sh` to install" + logger.warning( + "NORESQA dependencies are not available. 
" + "Please use `tools/install_noresqa.sh` to install model checkpoints" ) - Noresqa = None - - -def noresqa_model_setup( - model_tag="default", - metric_type=0, - cache_dir="versa_cache/noresqa_model", - use_gpu=False, -): - if use_gpu: - device = "cuda" - else: - device = "cpu" - - if model_tag == "default": - - if not os.path.isdir(cache_dir): - print("Creating checkpoints directory") - os.makedirs(cache_dir) - - url_w2v = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt" - w2v_path = os.path.join(cache_dir, "wav2vec_small.pt") - if not os.path.isfile(w2v_path): - print("Downloading wav2vec 2.0 started") - urlretrieve(url_w2v, w2v_path) - print("wav2vec 2.0 download completed") - - model = NORESQA( - output=40, output2=40, metric_type=metric_type, config_path=w2v_path - ) + NORESQA = None + feats_loading = None + model_prediction_noresqa = None + model_prediction_noresqa_mos = None + NORESQA_AVAILABLE = False - if metric_type == 0: - model_checkpoint_path = "{}/models/model_noresqa.pth".format(base_path) - state = torch.load(model_checkpoint_path, map_location="cpu")["state_base"] - elif metric_type == 1: - model_checkpoint_path = "{}/models/model_noresqa_mos.pth".format(base_path) - state = torch.load(model_checkpoint_path, map_location="cpu")["state_dict"] +class NoresqaNotAvailableError(RuntimeError): + """Exception raised when noresqa is required but not available.""" - pretrained_dict = {} - for k, v in state.items(): - if "module" in k: - pretrained_dict[k.replace("module.", "")] = v - else: - pretrained_dict[k] = v - model_dict = model.state_dict() - model_dict.update(pretrained_dict) - model.load_state_dict(pretrained_dict) + pass - # change device as needed - model.to(device) - model.device = device - model.eval() - sfmax = nn.Softmax(dim=1) +def is_noresqa_available(): + """ + Check if the noresqa package is available. - else: - raise NotImplementedError + Returns: + bool: True if noresqa is available, False otherwise. + """ + return NORESQA_AVAILABLE and FAIRSEQ_AVAILABLE - return model +class NoresqaMetric(BaseMetric): + """NORESQA speech quality assessment metric.""" -def noresqa_metric(model, gt_x, pred_x, fs, metric_type=1): - # NOTE(hyejin): only work for 16000 Hz - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - nmr_feat, test_feat = feats_loading(pred_x, gt_x, noresqa_or_noresqaMOS=metric_type) - test_feat = torch.from_numpy(test_feat).float().to(model.device).unsqueeze(0) - nmr_feat = torch.from_numpy(nmr_feat).float().to(model.device).unsqueeze(0) + TARGET_FS = 16000 # NORESQA model's expected sampling rate + DEFAULT_CACHE_DIR = "versa_cache/noresqa_model" - with torch.no_grad(): - if metric_type == 0: - noresqa_pout, noresqa_qout = model_prediction_noresqa( - test_feat, nmr_feat, model + def _setup(self): + """Initialize NORESQA-specific components.""" + if not NORESQA_AVAILABLE: + raise ImportError( + "NORESQA dependencies are not available. " + "Please use `tools/install_noresqa.sh` to install model checkpoints" + ) + if not FAIRSEQ_AVAILABLE: + raise ImportError( + "fairseq is not installed. 
" + "Please use `tools/install_fairseq.sh` to install" ) - return {"noresqa_score": noresqa_pout} - elif metric_type == 1: - mos_score = model_prediction_noresqa_mos(test_feat, nmr_feat, model) - return {"noresqa_score": mos_score} + self.model_tag = self.config.get("model_tag", "default") + self.metric_type = self.config.get( + "metric_type", 1 + ) # 0: NORESQA-score, 1: NORESQA-MOS + if self.metric_type not in (0, 1): + raise RuntimeError(f"Invalid metric_type: {self.metric_type}") + + self.cache_dir = self.config.get("cache_dir", self.DEFAULT_CACHE_DIR) + self.use_gpu = self.config.get("use_gpu", False) + + try: + self.model = self._setup_model() + except Exception as e: + raise RuntimeError(f"Failed to initialize NORESQA model: {str(e)}") from e + + def _setup_model(self): + """Setup the NORESQA model.""" + device = "cuda" if self.use_gpu else "cpu" + + if self.model_tag == "default": + if not os.path.isdir(self.cache_dir): + logger.info("Creating checkpoints directory") + os.makedirs(self.cache_dir, exist_ok=True) + + url_w2v = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt" + try: + w2v_path = self._checkpoint_path("wav2vec_small.pt") + except FileNotFoundError: + w2v_path = os.path.join(self.cache_dir, "wav2vec_small.pt") + logger.info("Downloading wav2vec 2.0 started") + urlretrieve(url_w2v, w2v_path) + logger.info("wav2vec 2.0 download completed") + + model = NORESQA( + output=40, + output2=40, + metric_type=self.metric_type, + config_path=w2v_path, + ) + + if self.metric_type == 0: + model_checkpoint_path = self._checkpoint_path("model_noresqa.pth") + # Suppress PyTorch config registration warnings during model loading + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="Skipping config registration for" + ) + state = torch.load(model_checkpoint_path, map_location="cpu")[ + "state_base" + ] + elif self.metric_type == 1: + model_checkpoint_path = self._checkpoint_path("model_noresqa_mos.pth") + # Suppress PyTorch config registration warnings during model loading + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="Skipping config registration for" + ) + state = torch.load(model_checkpoint_path, map_location="cpu")[ + "state_dict" + ] + else: + raise ValueError(f"Invalid metric_type: {self.metric_type}") + + pretrained_dict = {} + for k, v in state.items(): + if "module" in k: + pretrained_dict[k.replace("module.", "")] = v + else: + pretrained_dict[k] = v + model_dict = model.state_dict() + model_dict.update(pretrained_dict) + model.load_state_dict(pretrained_dict) + + # change device as needed + model.to(device) + model.device = device + model.eval() + + else: + raise NotImplementedError(f"Model tag '{self.model_tag}' not implemented") -if __name__ == "__main__": - a = np.random.random(16000) - b = np.random.random(16000) - model = noresqa_model_setup(use_gpu=True) - print("metrics: {}".format(noresqa_metric(model, a, b, 16000))) + return model + + def _checkpoint_path(self, filename): + candidates = [ + os.path.join(self.cache_dir, filename), + os.path.join(self.DEFAULT_CACHE_DIR, filename), + ] + for path in candidates: + if os.path.isfile(path): + return path + raise FileNotFoundError( + f"Missing NORESQA model file '{filename}'. " + "Please run `tools/install_noresqa.sh` first." + ) + + def compute( + self, predictions: Any, references: Any, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate NORESQA score for speech quality assessment. 
+ + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing NORESQA score. + """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + # Resample to 16kHz (NORESQA only works with 16kHz) + if fs != self.TARGET_FS: + gt_x = resample_audio(gt_x, fs, self.TARGET_FS) + pred_x = resample_audio(pred_x, fs, self.TARGET_FS) + + nmr_feat, test_feat = feats_loading( + pred_x, gt_x, noresqa_or_noresqaMOS=self.metric_type + ) + test_feat = ( + torch.from_numpy(test_feat).float().to(self.model.device).unsqueeze(0) + ) + nmr_feat = torch.from_numpy(nmr_feat).float().to(self.model.device).unsqueeze(0) + + with torch.no_grad(): + if self.metric_type == 0: + noresqa_pout, noresqa_qout = model_prediction_noresqa( + test_feat, nmr_feat, self.model + ) + return {"noresqa_score": noresqa_pout} + elif self.metric_type == 1: + mos_score = model_prediction_noresqa_mos( + test_feat, nmr_feat, self.model + ) + return {"noresqa_mos": mos_score} + else: + raise ValueError(f"Invalid metric_type: {self.metric_type}") + + def get_metadata(self) -> MetricMetadata: + """Return NORESQA metric metadata.""" + metric_name = "noresqa_mos" if self.metric_type == 1 else "noresqa_score" + description = "NORESQA-MOS" if self.metric_type == 1 else "NORESQA-score" + + return MetricMetadata( + name=metric_name, + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["fairseq", "torch", "librosa", "numpy"], + description=( + f"{description}: Non-matching reference based speech quality " + "assessment" + ), + paper_reference="https://arxiv.org/abs/2104.09411", + implementation_source="https://github.com/facebookresearch/NORESQA", + ) + + +def register_noresqa_metric(registry): + """Register NORESQA metric with the registry.""" + # Register both metric types + for metric_type, metric_name in [(0, "noresqa_score"), (1, "noresqa_mos")]: + description = "NORESQA-MOS" if metric_type == 1 else "NORESQA-score" + + metric_metadata = MetricMetadata( + name=metric_name, + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["fairseq", "torch", "librosa", "numpy"], + description=( + f"{description}: Non-matching reference based speech quality " + "assessment" + ), + paper_reference="https://arxiv.org/abs/2104.09411", + implementation_source="https://github.com/facebookresearch/NORESQA", + ) + registry.register( + NoresqaMetric, + metric_metadata, + aliases=[ + f"Noresqa{metric_type}", + "noresqa" if metric_type == 1 else f"noresqa_type_{metric_type}", + metric_name, + ], + ) diff --git a/versa/utterance_metrics/noresqa_utils/LICENSE b/versa/utterance_metrics/noresqa_utils/LICENSE new file mode 100644 index 0000000..d1bbe80 --- /dev/null +++ b/versa/utterance_metrics/noresqa_utils/LICENSE @@ -0,0 +1,400 @@ + +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law 
firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. 
Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. 
Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. 
retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. 
Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/versa/utterance_metrics/noresqa_utils/__init__.py b/versa/utterance_metrics/noresqa_utils/__init__.py new file mode 100644 index 0000000..34d6ba7 --- /dev/null +++ b/versa/utterance_metrics/noresqa_utils/__init__.py @@ -0,0 +1 @@ +"""Vendored NORESQA model components.""" diff --git a/versa/utterance_metrics/noresqa_utils/noresqa_model.py b/versa/utterance_metrics/noresqa_utils/noresqa_model.py new file mode 100644 index 0000000..d0f6d7a --- /dev/null +++ b/versa/utterance_metrics/noresqa_utils/noresqa_model.py @@ -0,0 +1,387 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm +import numpy as np +from librosa.filters import mel as librosa_mel_fn +from torch.nn import Parameter +from functools import wraps +import fairseq +from fairseq import tasks +import pickle + + +class model_dimred(nn.Module): + + def __init__( + self, + in_channel=64, + conv1x1=16, + reduce3x3=24, + conv3x3=32, + reduce5x5=16, + conv5x5=8, + pool_proj=8, + pool=2, + ): + super(model_dimred, self).__init__() + + self.modules1 = nn.ModuleList() + self.modules1.append(nn.Conv2d(in_channel, conv1x1, 1, (1, 1), 0)) + self.modules1.append(nn.Conv2d(in_channel, reduce3x3, 1, 1, 0)) + self.modules1.append(nn.Conv2d(reduce3x3, conv3x3, 3, (1, 1), 1)) + self.modules1.append(nn.Conv2d(in_channel, reduce5x5, 1, 1, 0)) + self.modules1.append(nn.Conv2d(reduce5x5, conv5x5, 5, (1, 1), 2)) + self.modules1.append(nn.MaxPool2d((3, 3), stride=(1, 1), padding=(1, 1))) + self.modules1.append(nn.Conv2d(in_channel, pool_proj, 1, 1, 0)) + self.modules1.append(nn.MaxPool2d((1, pool))) + + def forward(self, x): + + a = F.relu(self.modules1[0](x)) + b = F.relu(self.modules1[2]((F.relu(self.modules1[1](x))))) + c = F.relu(self.modules1[4]((F.relu(self.modules1[3](x))))) + d = F.relu(self.modules1[5](x)) + d = F.relu(self.modules1[6](d)) + x1 = torch.cat((a, b, c, d), axis=1) + x2 = F.relu(self.modules1[7](x1)) + return x2 + + +class base_encoder(nn.Module): + def __init__(self, dev=torch.device("cpu")): + super(base_encoder, self).__init__() + self.dev = dev + + self.modelA = model_dimred(in_channel=2, pool=4) + self.modelB = model_dimred(in_channel=64, pool=4) + self.modelC = model_dimred(in_channel=64, pool=4) + self.modelD = model_dimred(in_channel=64, pool=2) 
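+        # Four inception-style blocks chained in forward(): each concatenates
+        # 1x1, reduced-3x3, reduced-5x5, and pooled-projection branches, then
+        # max-pools along the last axis (pool=4, 4, 4, 2) to shrink the input.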
+ + def forward(self, x): + x = self.modelD(self.modelC(self.modelB(self.modelA(x)))) + return x + + +class which_clean(nn.Module): + def __init__(self): + super(which_clean, self).__init__() + n_layers = 2 + + self.encoder = nn.ModuleList() + self.ebatch = nn.ModuleList() + self.dp = nn.ModuleList() + filter_size = 5 + dp_num = 0.50 + self.encoder.append(nn.Conv1d(128, 32, filter_size, padding=filter_size // 2)) + self.ebatch.append(nn.BatchNorm1d(32)) + self.dp.append(nn.Dropout(p=dp_num)) + self.encoder.append(nn.Conv1d(32, 8, filter_size, padding=filter_size // 2)) + self.ebatch.append(nn.BatchNorm1d(8)) + self.dp.append(nn.Dropout(p=dp_num)) + self.encoder.append(nn.Conv1d(8, 2, filter_size, padding=filter_size // 2)) + self.ebatch.append(nn.BatchNorm1d(2)) + self.dp.append(nn.Dropout(p=dp_num)) + + def forward(self, x): + + for i in range(3): + x = self.encoder[i](x) + x = self.ebatch[i](x) + if i != 2: + x = F.leaky_relu(x, 0.1) + x = self.dp[i](x) + return x + + +class how_snr(nn.Module): + def __init__(self, dim_emb=32, output=50): + super(how_snr, self).__init__() + n_layers = 2 + + self.encoder = nn.ModuleList() + self.ebatch = nn.ModuleList() + self.dp = nn.ModuleList() + filter_size = 5 + dp_num = 0.50 + self.encoder.append(nn.Conv1d(128, 64, filter_size, padding=filter_size // 2)) + self.ebatch.append(nn.BatchNorm1d(64)) + self.dp.append(nn.Dropout(p=dp_num)) + self.encoder.append(nn.Conv1d(64, 32, filter_size, padding=filter_size // 2)) + self.ebatch.append(nn.BatchNorm1d(32)) + self.dp.append(nn.Dropout(p=dp_num)) + self.encoder.append( + nn.Conv1d(32, output, filter_size, padding=filter_size // 2) + ) + self.ebatch.append(nn.BatchNorm1d(output)) + self.dp.append(nn.Dropout(p=dp_num)) + + def forward(self, x): + + for i in range(3): + x = self.encoder[i](x) + x = self.ebatch[i](x) + if i != 2: + x = F.leaky_relu(x, 0.1) + x = self.dp[i](x) + return x + + +class how_snr_snr(nn.Module): + def __init__(self, dim_emb=32, output=50): + super(how_snr_snr, self).__init__() + n_layers = 2 + + self.encoder = nn.ModuleList() + self.ebatch = nn.ModuleList() + self.dp = nn.ModuleList() + filter_size = 5 + dp_num = 0.50 + self.encoder.append(nn.Conv1d(128, 64, filter_size, padding=filter_size // 2)) + self.ebatch.append(nn.BatchNorm1d(64)) + self.dp.append(nn.Dropout(p=dp_num)) + self.encoder.append(nn.Conv1d(64, 32, filter_size, padding=filter_size // 2)) + self.ebatch.append(nn.BatchNorm1d(32)) + self.dp.append(nn.Dropout(p=dp_num)) + self.encoder.append( + nn.Conv1d(32, output, filter_size, padding=filter_size // 2) + ) + self.ebatch.append(nn.BatchNorm1d(output)) + self.dp.append(nn.Dropout(p=dp_num)) + + def forward(self, x): + + for i in range(3): + x = self.encoder[i](x) + x = self.ebatch[i](x) + if i != 2: + x = F.leaky_relu(x, 0.1) + x = self.dp[i](x) + return x + + +class NORESQA(nn.Module): + + def __init__( + self, + dev=torch.device("cpu"), + minit=1, + output=20, + output2=16, + metric_type=0, + config_path="models/wav2vec_small.pt", + ): + super(NORESQA, self).__init__() + + self.metric_type = metric_type + if metric_type == 0: + self.base_encoder = base_encoder() + self.base_encoder_2 = TemporalConvNet( + num_inputs=128, num_channels=[32, 64, 128, 64], kernel_size=3 + ) + + self.which_clean = which_clean() + self.how_snr_sdr = how_snr(output=output) + self.how_snr_snr = how_snr_snr(output=output2) + if minit == 1: + self.base_encoder.apply(weights_init) + self.which_clean.apply(weights_init) + self.how_snr_sdr.apply(weights_init) + self.how_snr_snr.apply(weights_init) + 
self.CE = nn.CrossEntropyLoss(reduction="mean") + + elif metric_type == 1: + + SSL_OUT_DIM = 768 + ssl_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [config_path] + ) + + ssl_model = ssl_model[0] + + ssl_model.remove_pretraining_modules() + self.main_model = MosPredictor(ssl_model, SSL_OUT_DIM) + self.linear_layer = nn.Linear(SSL_OUT_DIM, 32) + + self.quantification = PoolAtt(d_input=64, output_size=5) + self.preference = PoolAtt(d_input=64, output_size=2) + + def forward(self, x1, x2=None): + + if self.metric_type == 0: + x1 = self.base_encoder.forward(x1) + x2 = self.base_encoder.forward(x2) + x1 = self.base_encoder_2(x1) + x2 = self.base_encoder_2(x2) + + concat = torch.cat((x1, x2), 1) + + which_closer = self.which_clean.forward(concat) + sdr_diff = self.how_snr_sdr.forward(concat) + snr_diff = self.how_snr_snr.forward(concat) + + return which_closer, sdr_diff, snr_diff + + elif self.metric_type == 1: + + x1 = self.linear_layer(self.main_model(x1)).permute(0, 2, 1) + y1 = self.linear_layer(self.main_model(x2)).permute(0, 2, 1) + concat = torch.cat((x1, y1), 1) + + n_wins = concat.shape[2] + B = [n_wins for n in range(concat.shape[0])] + n_wins_tensor = torch.from_numpy(np.asarray(B)).to(concat.device) + + pref = self.preference(concat.permute(0, 2, 1), n_wins_tensor) + quantf = self.quantification(concat.permute(0, 2, 1), n_wins_tensor) + + att = F.softmax(quantf, dim=1) + B = torch.linspace(0, 4, steps=5).to(concat.device) + C = (att * B).sum(axis=1) + return C + + +def weights_init(m): + classname = m.__class__.__name__ + if ( + classname.find("Conv") != -1 + or classname.find("BatchNorm") != -1 + or classname.find("Linear") != -1 + ): + torch.nn.init.normal_(m.weight) + try: + torch.nn.init.constant_(m.bias, 0.01) + except: + pass + + +class TemporalBlock(nn.Module): + def __init__( + self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2 + ): + super(TemporalBlock, self).__init__() + self.conv1 = weight_norm( + nn.Conv1d( + n_inputs, + n_outputs, + kernel_size, + stride=stride, + padding=dilation, + dilation=dilation, + ) + ) + + self.relu1 = nn.ReLU() + self.dropout1 = nn.Dropout(dropout) + + self.conv2 = weight_norm( + nn.Conv1d( + n_outputs, + n_outputs, + kernel_size, + stride=stride, + padding=dilation, + dilation=dilation, + ) + ) + self.relu2 = nn.ReLU() + self.dropout2 = nn.Dropout(dropout) + + self.net = nn.Sequential( + self.conv1, self.relu1, self.dropout1, self.conv2, self.relu2, self.dropout2 + ) + self.downsample = ( + nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None + ) + self.relu = nn.ReLU() + self.init_weights() + + def init_weights(self): + self.conv1.weight.data.normal_(0, 0.01) + self.conv2.weight.data.normal_(0, 0.01) + if self.downsample is not None: + self.downsample.weight.data.normal_(0, 0.01) + + def forward(self, x): + out = self.net(x) + res = x if self.downsample is None else self.downsample(x) + return self.relu(out + res) + + +class TemporalConvNet(nn.Module): + def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): + super(TemporalConvNet, self).__init__() + layers = [] + num_levels = len(num_channels) + for i in range(num_levels): + dilation_size = 2**i + in_channels = num_inputs if i == 0 else num_channels[i - 1] + out_channels = num_channels[i] + layers += [ + TemporalBlock( + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=dilation_size, + padding=(kernel_size - 1) * dilation_size, + dropout=dropout, + ) + ] + + self.network = 
nn.Sequential(*layers) + + def forward(self, x1): + + x1 = x1.reshape(x1.shape[0], -1, x1.shape[2]) + x = self.network(x1) + return x + + +class MosPredictor(nn.Module): + def __init__(self, ssl_model, ssl_out_dim): + super(MosPredictor, self).__init__() + self.ssl_model = ssl_model + self.ssl_features = ssl_out_dim + + def forward(self, wav): + wav = wav.squeeze(1) ## [batches, audio_len] + res = self.ssl_model(wav, mask=False, features_only=True) + x = res["x"] + + return x + + +class PoolAtt(torch.nn.Module): + """ + PoolAtt: Attention-Pooling module. + """ + + def __init__(self, d_input, output_size): + super().__init__() + + self.linear1 = nn.Linear(d_input, 1) + self.linear2 = nn.Linear(d_input, output_size) + + def forward(self, x, n_wins): + + att = self.linear1(x) # B X T X C + + att = att.transpose(2, 1) # B X 1 X T + mask = torch.arange(att.shape[2])[None, :] < n_wins[:, None].to("cpu").to( + torch.long + ) + att[~mask.unsqueeze(1)] = float("-Inf") + att = F.softmax(att, dim=2) + x = torch.bmm(att, x) + x = x.squeeze(1) + x = self.linear2(x) + + return x diff --git a/versa/utterance_metrics/noresqa_utils/noresqa_utils.py b/versa/utterance_metrics/noresqa_utils/noresqa_utils.py new file mode 100644 index 0000000..39b5ef8 --- /dev/null +++ b/versa/utterance_metrics/noresqa_utils/noresqa_utils.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import argparse +import librosa as librosa +import torch.nn.functional as F +import numpy as np +import torch.nn as nn +from versa.utterance_metrics.noresqa_utils.noresqa_model import NORESQA +from scipy import signal + + +def argument_parser(): + """ + Get an argument parser. + """ + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--metric_type", help="NORESQA->0, NORESQA-MOS->1", default=1, type=int + ) + parser.add_argument( + "--GPU_id", help="GPU Id to use (-1 for cpu)", default=-1, type=int + ) + parser.add_argument( + "--mode", + choices=["file", "list"], + help="predict noresqa for test file with another file (mode = file) as NMR or, with a database given as list of files (mode=list) as NMRs", + default="file", + type=str, + ) + parser.add_argument( + "--test_file", + help="test speech file", + required=False, + type=str, + default="sample_clips/noisy.wav", + ) + parser.add_argument( + "--nmr", + help="for mode=file, path of nmr speech file. 
+
+
+# STFT feature extraction
+def extract_stft(audio, sampling_rate=16000):
+
+    fx, tx, stft_out = signal.stft(
+        audio, sampling_rate, window="hann", nperseg=512, noverlap=256, nfft=512
+    )
+    stft_out = stft_out[:256, :]
+    # stack magnitude and phase along a trailing channel axis
+    feat = np.concatenate(
+        (
+            np.abs(stft_out).reshape([stft_out.shape[0], stft_out.shape[1], 1]),
+            np.angle(stft_out).reshape([stft_out.shape[0], stft_out.shape[1], 1]),
+        ),
+        axis=2,
+    )
+    return feat
+
+
+# NORESQA and NORESQA-MOS prediction calls
+def model_prediction_noresqa(test_feat, nmr_feat, model):
+
+    intervals_sdr = np.arange(0.5, 40, 1)
+
+    with torch.no_grad():
+        ranking_frame, sdr_frame, snr_frame = model(
+            test_feat.permute(0, 3, 2, 1), nmr_feat.permute(0, 3, 2, 1)
+        )
+        sfmax = nn.Softmax(dim=1)
+        # preference task prediction
+        ranking = sfmax(ranking_frame).mean(2).detach().cpu().numpy()
+        pout = ranking[0][0]
+        # quantification task
+        sdr = intervals_sdr * (sfmax(sdr_frame).mean(2).detach().cpu().numpy())
+        qout = sdr.sum()
+
+    return pout, qout
+
+
+def model_prediction_noresqa_mos(test_feat, nmr_feat, model):
+
+    with torch.no_grad():
+        score = model(nmr_feat, test_feat).detach().cpu().numpy()[0]
+
+    return score
+
+
+# reading audio clips
+def audio_loading(path, sampling_rate=16000):
+
+    audio, fs = librosa.load(path, sr=None)
+    if len(audio.shape) > 1:
+        audio = librosa.to_mono(audio)
+
+    if fs != sampling_rate:
+        audio = librosa.resample(audio, orig_sr=fs, target_sr=sampling_rate)
+
+    return audio
+
+
+# check that the two inputs have the same length; if not, adjust the
+# reference audio's length to match the test audio
+def check_size(audio_ref, audio_test):
+
+    if len(audio_ref) > len(audio_test):
+        print("Durations don't match. Adjusting duration of reference.")
+        audio_ref = audio_ref[: len(audio_test)]
+
+    elif len(audio_ref) < len(audio_test):
+        print("Durations don't match. Adjusting duration of reference.")
+        # tile the reference until it is at least as long as the test clip
+        while len(audio_test) > len(audio_ref):
+            audio_ref = np.append(audio_ref, audio_ref)
+        audio_ref = audio_ref[: len(audio_test)]
+
+    return audio_ref, audio_test
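+
+# A quick shape check for extract_stft() above (a sketch; assumes a 1 s,
+# 16 kHz mono input):
+#     feat = extract_stft(np.random.randn(16000))
+#     feat.shape  # -> (256, n_frames, 2): 256 frequency bins, magnitude/phase planes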
+
+
+# audio loading and feature extraction
+def feats_loading(test_path, ref_path=None, noresqa_or_noresqaMOS=0):
+
+    if noresqa_or_noresqaMOS == 0 or noresqa_or_noresqaMOS == 1:
+
+        # NOTE: Versa passes audio arrays directly rather than file paths,
+        # so no audio_loading() call happens here.
+        # audio_ref = audio_loading(ref_path)
+        # audio_test = audio_loading(test_path)
+        audio_ref, audio_test = ref_path, test_path
+        audio_ref, audio_test = check_size(audio_ref, audio_test)
+
+        if noresqa_or_noresqaMOS == 0:
+            ref_feat = extract_stft(audio_ref)
+            test_feat = extract_stft(audio_test)
+            return ref_feat, test_feat
+        else:
+            return audio_ref, audio_test
+
+
+if __name__ == "__main__":
+    args = argument_parser().parse_args()
+    CONFIG_PATH = "models/wav2vec_small.pt"
+
+    # NORESQA model
+    model = NORESQA(
+        output=40, output2=40, metric_type=args.metric_type, config_path=CONFIG_PATH
+    )
+
+    # Loading checkpoint
+    if args.metric_type == 0:
+        model_checkpoint_path = "models/model_noresqa.pth"
+        state = torch.load(model_checkpoint_path, map_location="cpu")["state_base"]
+    elif args.metric_type == 1:
+        model_checkpoint_path = "models/model_noresqa_mos.pth"
+        state = torch.load(model_checkpoint_path, map_location="cpu")["state_dict"]
+
+    # strip DataParallel's "module." prefix from checkpoint keys
+    pretrained_dict = {}
+    for k, v in state.items():
+        if "module" in k:
+            pretrained_dict[k.replace("module.", "")] = v
+        else:
+            pretrained_dict[k] = v
+    model_dict = model.state_dict()
+    model_dict.update(pretrained_dict)
+    model.load_state_dict(model_dict)
+
+    # change device as needed
+    if args.GPU_id >= 0 and torch.cuda.is_available():
+        device = torch.device("cuda:{}".format(args.GPU_id))
+    else:
+        device = torch.device("cpu")
+
+    model.to(device)
+    model.eval()
+
+    if args.mode == "file":
+
+        nmr_feat, test_feat = feats_loading(
+            audio_loading(args.test_file),
+            audio_loading(args.nmr),
+            noresqa_or_noresqaMOS=args.metric_type,
+        )
+        test_feat = torch.from_numpy(test_feat).float().to(device).unsqueeze(0)
+        nmr_feat = torch.from_numpy(nmr_feat).float().to(device).unsqueeze(0)
+
+        if args.metric_type == 0:
+            noresqa_pout, noresqa_qout = model_prediction_noresqa(
+                test_feat, nmr_feat, model
+            )
+            print(
+                "Probability of the test speech cleaner than the given NMR =",
+                noresqa_pout,
+            )
+            print(
+                "NORESQA score of the test speech with respect to the given NMR =",
+                noresqa_qout,
+            )
+
+        elif args.metric_type == 1:
+            mos_score = model_prediction_noresqa_mos(test_feat, nmr_feat, model)
+            print(
+                "MOS score of the test speech (assuming NMR is clean) =",
+                str(5.0 - mos_score),
+            )
+
+    elif args.mode == "list":
+
+        with open(args.nmr) as f:
+            for ln in f:
+
+                nmr_feat, test_feat = feats_loading(
+                    audio_loading(args.test_file),
+                    audio_loading(ln.strip()),
+                    noresqa_or_noresqaMOS=args.metric_type,
+                )
+                test_feat = torch.from_numpy(test_feat).float().to(device).unsqueeze(0)
+                nmr_feat = torch.from_numpy(nmr_feat).float().to(device).unsqueeze(0)
+
+                if args.metric_type == 0:
+                    pout, qout = model_prediction_noresqa(test_feat, nmr_feat, model)
+                    print(
+                        f"Prob. of test cleaner than {ln.strip()} = {pout}. Noresqa score = {qout}"
+                    )
+
+                elif args.metric_type == 1:
+                    score = model_prediction_noresqa_mos(test_feat, nmr_feat, model)
+                    print(f"MOS of test with respect to clean {ln.strip()} = {5.0 - score}")
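The `__main__` block above also demonstrates a common checkpoint-loading pattern: keys saved from a `torch.nn.DataParallel` model carry a `module.` prefix that must be stripped before `load_state_dict`. A minimal standalone sketch of the same idea, assuming a generic `nn.Module` and a hypothetical `ckpt.pth` (not a path from this repository):

```python
import torch
import torch.nn as nn


def load_dataparallel_checkpoint(model: nn.Module, path: str) -> nn.Module:
    """Load a state dict whose keys may carry DataParallel's 'module.' prefix."""
    state = torch.load(path, map_location="cpu")
    # normalize keys: "module.encoder.weight" -> "encoder.weight"
    cleaned = {k.replace("module.", "", 1): v for k, v in state.items()}
    # merge into a copy of the model's own state dict, then load the merged dict
    merged = model.state_dict()
    merged.update({k: v for k, v in cleaned.items() if k in merged})
    model.load_state_dict(merged)
    return model
```

Loading the merged dict, rather than the raw prefix-stripped dict, is what makes the `model_dict.update(...)` step above take effect.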
diff --git a/versa/utterance_metrics/owsm_lid.py b/versa/utterance_metrics/owsm_lid.py
index b9286b6..4648ff3 100644
--- a/versa/utterance_metrics/owsm_lid.py
+++ b/versa/utterance_metrics/owsm_lid.py
@@ -3,39 +3,160 @@
 # Copyright 2024 Jiatong Shi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 
-import os
+"""Module for OWSM Language Identification (LID) metrics."""
+
+import logging
+from typing import Dict, Any, Optional, Union
 
-import librosa
 import numpy as np
-from espnet2.bin.s2t_inference_language import Speech2Language
-
-
-def owsm_lid_model_setup(model_tag="default", nbest=3, use_gpu=False):
-    if use_gpu:
-        device = "cuda"
-    else:
-        device = "cpu"
-    if model_tag == "default":
-        model_tag = "espnet/owsm_v3.1_ebf"
-    model = Speech2Language.from_pretrained(
-        model_tag=model_tag,
-        device=device,
-        nbest=nbest,
+
+logger = logging.getLogger(__name__)
+
+# Handle optional espnet2 dependency
+try:
+    from espnet2.bin.s2t_inference_language import Speech2Language
+
+    ESPNET2_AVAILABLE = True
+except ImportError:
+    logger.warning(
+        "espnet2 is not properly installed. Please install espnet2 and retry"
     )
+    Speech2Language = None
+    ESPNET2_AVAILABLE = False
+
+from versa.audio_utils import resample_audio
+from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType
+
+
+class Espnet2NotAvailableError(RuntimeError):
+    """Exception raised when espnet2 is required but not available."""
+
+    pass
+
+
+def is_espnet2_available():
+    """
+    Check if the espnet2 package is available.
+
+    Returns:
+        bool: True if espnet2 is available, False otherwise.
+    """
+    return ESPNET2_AVAILABLE
+
+
+class OwsmLidMetric(BaseMetric):
+    """OWSM Language Identification (LID) metric."""
+
+    TARGET_FS = 16000  # OWSM model's expected sampling rate
+
+    def _setup(self):
+        """Initialize OWSM LID-specific components."""
+        if not ESPNET2_AVAILABLE:
+            raise ImportError(
+                "espnet2 is not properly installed. Please install espnet2 and retry"
+            )
 
-    return model
+        self.model_tag = self.config.get("model_tag", "default")
+        self.nbest = self.config.get("nbest", 3)
+        self.use_gpu = self.config.get("use_gpu", False)
+        self.cache_dir = self.config.get("cache_dir", "versa_cache/espnet_model_zoo")
 
+        try:
+            self.model = self._setup_model()
+        except Exception as e:
+            raise RuntimeError(f"Failed to initialize OWSM LID model: {str(e)}") from e
 
-def language_id(model, pred_x, fs):
-    # NOTE(jiatong): only work for 16000 Hz
-    if fs != 16000:
-        pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000)
+    def _setup_model(self):
+        """Setup the OWSM LID model."""
+        device = "cuda" if self.use_gpu else "cpu"
 
-    result = model(pred_x)
-    return {"language": result}
+        if self.model_tag == "default":
+            model_tag = "espnet/owsm_v3.1_ebf"
+        else:
+            model_tag = self.model_tag
+
+        try:
+            from espnet_model_zoo.downloader import ModelDownloader
+        except ImportError:
+            raise ImportError(
+                "owsm_lid requires espnet_model_zoo. Please install it and retry"
+            )
+        model_kwargs = ModelDownloader(cachedir=self.cache_dir).download_and_unpack(
+            model_tag
+        )
+        model = Speech2Language(device=device, nbest=self.nbest, **model_kwargs)
+
+        return model
+
+    def compute(
+        self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None
+    ) -> Dict[str, Union[float, str]]:
+        """Calculate language identification for speech.
+ + Args: + predictions: Audio signal to be evaluated. + references: Not used for LID (single-ended metric). + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing language identification result. + """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + + pred_x = np.asarray(pred_x) + + # Resample if necessary (OWSM only works with 16kHz) + if fs != self.TARGET_FS: + pred_x = resample_audio(pred_x, fs, self.TARGET_FS) + + result = self.model(pred_x) + return {"language": result} + + def get_metadata(self) -> MetricMetadata: + """Return OWSM LID metric metadata.""" + return MetricMetadata( + name="lid", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.LIST, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["espnet2", "librosa", "numpy"], + description="OWSM Language Identification (LID) for speech", + paper_reference="https://arxiv.org/abs/2309.16588", + implementation_source="https://github.com/espnet/espnet", + ) + + +def register_owsm_lid_metric(registry): + """Register OWSM LID metric with the registry.""" + metric_metadata = MetricMetadata( + name="lid", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.LIST, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["espnet2", "librosa", "numpy"], + description="OWSM Language Identification (LID) for speech", + paper_reference="https://arxiv.org/abs/2309.16588", + implementation_source="https://github.com/espnet/espnet", + ) + registry.register( + OwsmLidMetric, + metric_metadata, + aliases=["OwsmLid", "owsm_lid", "lid", "language_id"], + ) if __name__ == "__main__": a = np.random.random(16000) - model = owsm_lid_model_setup() - print("metrics: {}".format(language_id(model, a, 16000))) + model = OwsmLidMetric() + print("metrics: {}".format(model.compute(a, None, {"sample_rate": 16000}))) diff --git a/versa/utterance_metrics/pam.py b/versa/utterance_metrics/pam.py index c02eb0c..256b8af 100644 --- a/versa/utterance_metrics/pam.py +++ b/versa/utterance_metrics/pam.py @@ -8,6 +8,7 @@ warnings.filterwarnings("ignore") import argparse +import logging as py_logging import os import re import sys @@ -20,17 +21,42 @@ from huggingface_hub.file_download import hf_hub_download from transformers import AutoTokenizer, logging -from versa.utterance_metrics.pam_utils.clap import CLAP - +logger = py_logging.getLogger(__name__) logging.set_verbosity_error() import collections import numpy as np import torch.nn.functional as F +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + +# Handle optional dependencies +try: + from versa.utterance_metrics.pam_utils.clap import CLAP + + PAM_AVAILABLE = True +except ImportError: + logger.warning( + "PAM dependencies are not installed. Please install required dependencies" + ) + CLAP = None + PAM_AVAILABLE = False + # Constants HF_REPO = "microsoft/msclap" CLAP_VERSION = "CLAP_weights_2023.pth" + + +def is_pam_available(): + """ + Check if the PAM dependencies are available. + + Returns: + bool: True if PAM dependencies are available, False otherwise. 
+ """ + return PAM_AVAILABLE + + PAM_PROMPTS = [ "the sound is clear and clean.", "the sound is noisy and with artifacts.", @@ -154,11 +180,12 @@ def _preprocess_text(self, text_queries: List[str]) -> Dict[str, torch.Tensor]: if "gpt" in self.args.text_model: text = text + " <|endoftext|>" - tok = self.tokenizer.encode_plus( - text=text, + tok = self.tokenizer( + text, add_special_tokens=True, max_length=self.args.text_len, padding="max_length", + truncation=True, return_tensors="pt", ) @@ -240,140 +267,197 @@ def evaluate(self, audio_tensor: torch.Tensor) -> float: return pam_score -def load_audio( - audio_file: Union[str, torch.Tensor], sample_rate: int, repro: bool = True -) -> torch.Tensor: - """ - Load and preprocess audio file. +class PamMetric(BaseMetric): + """PAM (Perceptual Audio Metric) for audio quality assessment.""" - Args: - audio_file: Path to audio file or audio tensor - sample_rate: Sample rate of the input audio - repro: If True, use reproducible processing (taking first 7 seconds) + TARGET_FS = 44100 # PAM model's expected sampling rate - Returns: - Processed audio tensor - """ - # Load audio file if path is provided - if isinstance(audio_file, str): - audio, sample_rate = torchaudio.load(audio_file) - else: - audio = audio_file.clone() # Create a copy to avoid modifying the original - - # Ensure audio is a FloatTensor - audio = torch.FloatTensor(audio) - - # Resample audio if needed - if sample_rate != RESAMPLE_RATE: - resampler = T.Resample(sample_rate, RESAMPLE_RATE) - audio = resampler(audio) - - # Convert to mono if stereo - if audio.shape[0] > 1: - audio = torch.mean(audio, dim=0, keepdim=True) - - # Reshape to 1D - audio = audio.reshape(-1) - - # Process audio to be exactly AUDIO_DURATION seconds - if SAMPLES >= audio.shape[0]: - # Audio is shorter than required duration, repeat to match - repeat_factor = int(np.ceil(SAMPLES / audio.shape[0])) - audio = audio.repeat(repeat_factor) - # Trim to exact length - audio = audio[:SAMPLES] - else: - # Audio is longer than required duration - if repro: - # Take first AUDIO_DURATION seconds - audio = audio[:SAMPLES] - else: - # Take chunks of AUDIO_DURATION seconds plus remaining portion - cutoff = int(np.floor(audio.shape[0] / SAMPLES)) - initial_audio = audio[: cutoff * SAMPLES] - - remaining = audio[cutoff * SAMPLES :] - if remaining.shape[0] > 0: - # If remaining is non-empty, take the last AUDIO_DURATION seconds - remaining = ( - audio[-SAMPLES:] - if remaining.shape[0] <= SAMPLES - else remaining[:SAMPLES] - ) - audio = torch.cat([initial_audio, remaining]) - else: - audio = initial_audio + def _setup(self): + """Initialize PAM-specific components.""" + if not PAM_AVAILABLE: + raise ImportError( + "PAM dependencies are not installed. 
Please install required dependencies" + ) - return audio + self.repro = self.config.get("repro", True) + self.cache_dir = self.config.get("cache_dir", "versa_cache/pam") + self.use_gpu = self.config.get("use_gpu", False) + + # Extract model configuration from config + model_config = { + "text_model": self.config.get("text_model", "gpt2"), + "text_len": self.config.get("text_len", 77), + "transformer_embed_dim": self.config.get("transformer_embed_dim", 768), + "audioenc_name": self.config.get("audioenc_name", "HTSAT"), + "out_emb": self.config.get("out_emb", 768), + "sampling_rate": self.config.get("sampling_rate", 44100), + "duration": self.config.get("duration", 7), + "fmin": self.config.get("fmin", 50), + "fmax": self.config.get("fmax", 8000), + "n_fft": self.config.get("n_fft", 1024), + "hop_size": self.config.get("hop_size", 320), + "mel_bins": self.config.get("mel_bins", 64), + "window_size": self.config.get("window_size", 1024), + "d_proj": self.config.get("d_proj", 1024), + "temperature": self.config.get("temperature", 0.003), + "num_classes": self.config.get("num_classes", 527), + "batch_size": self.config.get("batch_size", 1024), + "demo": self.config.get("demo", False), + } + try: + self.model = PAM(model_config=model_config, use_cuda=self.use_gpu) + except Exception as e: + raise RuntimeError(f"Failed to initialize PAM model: {str(e)}") from e -def pam_model_setup(model_config: Dict[str, Any], use_gpu: bool = False) -> PAM: - """ - Initialize PAM model with given configuration. + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate PAM score for audio quality assessment. - Args: - model_config: Model configuration dictionary - use_gpu: Whether to use GPU for computation + Args: + predictions: Audio signal to be evaluated. + references: Not used for PAM (single-ended metric). + metadata: Optional metadata containing sample_rate. - Returns: - Initialized PAM model - """ - model = PAM(model_config=model_config, use_cuda=use_gpu) - return model + Returns: + dict: Dictionary containing PAM score. + """ + pred_x = predictions + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") -def pam_metric( - model: PAM, - pred_x: Union[str, torch.Tensor, np.ndarray], - gt_x: Optional[Union[str, torch.Tensor, np.ndarray]] = None, - fs: int = 16000, -) -> Dict[str, float]: - """ - Compute PAM metric for given audio. 
+ # Convert numpy array to tensor if needed + if isinstance(pred_x, np.ndarray): + pred_x = torch.FloatTensor(pred_x) - Args: - model: PAM model - pred_x: Predicted audio (file path or tensor) - gt_x: Ground truth audio (unused, kept for API compatibility) - fs: Sample rate of the input audio + # Load and preprocess audio + audio = self._load_audio(pred_x, fs) - Returns: - Dictionary containing PAM score - """ - # Convert numpy array to tensor if needed - if isinstance(pred_x, np.ndarray): - pred_x = torch.FloatTensor(pred_x) - - # Load and preprocess audio - audio = load_audio(pred_x, fs, repro=True) - - # Ensure audio has batch dimension - if len(audio.shape) < 2: - audio = audio.unsqueeze(0) - - # Compute PAM score - pam_score = model.evaluate(audio) - - return {"pam_score": pam_score} - - -if __name__ == "__main__": - # Example usage - a = np.random.random(16000) - - # Load configuration from YAML file - try: - with open("egs/separate_metrics/pam.yaml", "r", encoding="utf-8") as f: - config = yaml.safe_load(f)[0] - except (FileNotFoundError, yaml.YAMLError) as e: - print(f"Error loading configuration: {e}") - sys.exit(1) - - # Initialize model and compute metric - try: - model = pam_model_setup(config, use_gpu=torch.cuda.is_available()) - result = pam_metric(model, a, fs=16000) - print(f"PAM score: {result['pam_score']:.4f}") - except Exception as e: - print(f"Error computing PAM metric: {e}") - sys.exit(1) + # Ensure audio has batch dimension + if len(audio.shape) < 2: + audio = audio.unsqueeze(0) + + # Compute PAM score + pam_score = self.model.evaluate(audio) + + return {"pam_score": pam_score} + + def _load_audio( + self, audio_file: Union[str, torch.Tensor], sample_rate: int + ) -> torch.Tensor: + """ + Load and preprocess audio file. + + Args: + audio_file: Path to audio file or audio tensor + sample_rate: Sample rate of the input audio + + Returns: + Processed audio tensor + """ + # Load audio file if path is provided + if isinstance(audio_file, str): + audio, sample_rate = torchaudio.load(audio_file) + else: + audio = audio_file.clone() # Create a copy to avoid modifying the original + + # Ensure audio is a FloatTensor + audio = torch.FloatTensor(audio) + + # Resample audio if needed + if sample_rate != self.TARGET_FS: + resampler = T.Resample(sample_rate, self.TARGET_FS) + audio = resampler(audio) + + # Convert to mono if stereo + if audio.shape[0] > 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + # Reshape to 1D + audio = audio.reshape(-1) + + # Process audio to be exactly AUDIO_DURATION seconds + samples = self.TARGET_FS * AUDIO_DURATION + if samples >= audio.shape[0]: + # Audio is shorter than required duration, repeat to match + repeat_factor = int(np.ceil(samples / audio.shape[0])) + audio = audio.repeat(repeat_factor) + # Trim to exact length + audio = audio[:samples] + else: + # Audio is longer than required duration + if self.repro: + # Take first AUDIO_DURATION seconds + audio = audio[:samples] + else: + # Take chunks of AUDIO_DURATION seconds plus remaining portion + cutoff = int(np.floor(audio.shape[0] / samples)) + initial_audio = audio[: cutoff * samples] + + remaining = audio[cutoff * samples :] + if remaining.shape[0] > 0: + # If remaining is non-empty, take the last AUDIO_DURATION seconds + remaining = ( + audio[-samples:] + if remaining.shape[0] <= samples + else remaining[:samples] + ) + audio = torch.cat([initial_audio, remaining]) + else: + audio = initial_audio + + return audio + + def get_metadata(self) -> MetricMetadata: + """Return PAM metric 
metadata.""" + return MetricMetadata( + name="pam", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=[ + "torch", + "torchaudio", + "transformers", + "huggingface_hub", + "numpy", + ], + description="PAM: Prompting Audio-Language Models for Audio Quality Assessment", + paper_reference="https://arxiv.org/abs/2309.07317", + implementation_source="https://github.com/soham97/PAM", + ) + + +def register_pam_metric(registry): + """Register PAM metric with the registry.""" + metric_metadata = MetricMetadata( + name="pam", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=[ + "torch", + "torchaudio", + "transformers", + "huggingface_hub", + "numpy", + ], + description="PAM: Prompting Audio-Language Models for Audio Quality Assessment", + paper_reference="https://arxiv.org/abs/2309.07317", + implementation_source="https://github.com/soham97/PAM", + ) + registry.register( + PamMetric, + metric_metadata, + aliases=["Pam", "pam", "perceptual_audio_metric"], + ) diff --git a/versa/utterance_metrics/pam_utils/clap.py b/versa/utterance_metrics/pam_utils/clap.py index 9e6de52..1fe590b 100644 --- a/versa/utterance_metrics/pam_utils/clap.py +++ b/versa/utterance_metrics/pam_utils/clap.py @@ -57,7 +57,7 @@ def interpolate(x, ratio): Returns: upsampled: (batch_size, time_steps * ratio, classes_num) """ - (batch_size, time_steps, classes_num) = x.shape + batch_size, time_steps, classes_num = x.shape upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) return upsampled diff --git a/versa/utterance_metrics/pesq_score.py b/versa/utterance_metrics/pesq_score.py index b6062b5..310c05d 100644 --- a/versa/utterance_metrics/pesq_score.py +++ b/versa/utterance_metrics/pesq_score.py @@ -3,46 +3,148 @@ # Copyright 2024 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import logging +"""Module for PESQ (Perceptual Evaluation of Speech Quality) metrics.""" -logger = logging.getLogger(__name__) +import logging +from typing import Dict, Any, Optional, Union -import librosa import numpy as np -from pesq import pesq +logger = logging.getLogger(__name__) + +# Handle optional pesq dependency try: from pesq import pesq + + PESQ_AVAILABLE = True except ImportError: - raise ImportError("Please install pesq and retry: pip install pesq") - - -def pesq_metric(pred_x, gt_x, fs): - try: - if fs == 8000: - pesq_value = pesq(8000, gt_x, pred_x, "nb") - elif fs < 16000: - logging.info("not support fs {}, resample to 8khz".format(fs)) - new_gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=8000) - new_pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=8000) - pesq_value = pesq(16000, new_gt_x, new_pred_x, "nb") - elif fs == 16000: - pesq_value = pesq(16000, gt_x, pred_x, "wb") - else: - logging.info("not support fs {}, resample to 16khz".format(fs)) - new_gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - new_pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - pesq_value = pesq(16000, new_gt_x, new_pred_x, "wb") - except BaseException: - logging.warning( - "Error from pesq calculation. Please check the audio (likely due to silence)" + logger.warning( + "pesq is not properly installed. 
Please install pesq and retry: pip install pesq" + ) + pesq = None + PESQ_AVAILABLE = False + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType + + +class PesqNotAvailableError(RuntimeError): + """Exception raised when pesq is required but not available.""" + + pass + + +def is_pesq_available(): + """ + Check if the pesq package is available. + + Returns: + bool: True if pesq is available, False otherwise. + """ + return PESQ_AVAILABLE + + +class PesqMetric(BaseMetric): + """PESQ (Perceptual Evaluation of Speech Quality) metric.""" + + def _setup(self): + """Initialize PESQ-specific components.""" + if not PESQ_AVAILABLE: + raise ImportError( + "pesq is not properly installed. Please install pesq and retry: pip install pesq" + ) + + def compute( + self, predictions: Any, references: Any, metadata: Dict[str, Any] = None + ) -> Dict[str, Union[float, str]]: + """Calculate PESQ score for speech quality assessment. + + Args: + predictions: Predicted audio signal. + references: Ground truth audio signal. + metadata: Optional metadata containing sample_rate. + + Returns: + dict: Dictionary containing PESQ score. + """ + pred_x = predictions + gt_x = references + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + + # Validate inputs + if pred_x is None: + raise ValueError("Predicted signal must be provided") + if gt_x is None: + raise ValueError("Reference signal must be provided") + + pred_x = np.asarray(pred_x) + gt_x = np.asarray(gt_x) + + try: + if fs == 8000: + pesq_value = pesq(8000, gt_x, pred_x, "nb") + elif fs < 16000: + logger.info("not support fs {}, resample to 8khz".format(fs)) + new_gt_x = resample_audio(gt_x, fs, 8000) + new_pred_x = resample_audio(pred_x, fs, 8000) + pesq_value = pesq(8000, new_gt_x, new_pred_x, "nb") + elif fs == 16000: + pesq_value = pesq(16000, gt_x, pred_x, "wb") + else: + logger.info("not support fs {}, resample to 16khz".format(fs)) + new_gt_x = resample_audio(gt_x, fs, 16000) + new_pred_x = resample_audio(pred_x, fs, 16000) + pesq_value = pesq(16000, new_gt_x, new_pred_x, "wb") + except BaseException: + logger.warning( + "Error from pesq calculation. 
Please check the audio (likely due to silence)" + ) + pesq_value = 0.0 + + return {"pesq": pesq_value} + + def get_metadata(self) -> MetricMetadata: + """Return PESQ metric metadata.""" + return MetricMetadata( + name="pesq", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["pesq", "librosa", "numpy"], + description="PESQ: Perceptual Evaluation of Speech Quality", + paper_reference="https://www.itu.int/rec/T-REC-P.862", + implementation_source="https://github.com/ludlows/python-pesq", ) - pesq_value = 0.0 - return {"pesq": pesq_value} + + +def register_pesq_metric(registry): + """Register PESQ metric with the registry.""" + metric_metadata = MetricMetadata( + name="pesq", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["pesq", "librosa", "numpy"], + description="PESQ: Perceptual Evaluation of Speech Quality", + paper_reference="https://www.itu.int/rec/T-REC-P.862", + implementation_source="https://github.com/ludlows/python-pesq", + ) + registry.register( + PesqMetric, + metric_metadata, + aliases=["Pesq", "pesq", "perceptual_evaluation_speech_quality"], + ) if __name__ == "__main__": a = np.random.random(16000) b = np.random.random(16000) - scores = pesq_metric(a, b, 16000) + metric = PesqMetric() + scores = metric.compute(a, b, metadata={"sample_rate": 16000}) print(scores) diff --git a/versa/utterance_metrics/pseudo_mos.py b/versa/utterance_metrics/pseudo_mos.py index 7402cb1..cb946e7 100644 --- a/versa/utterance_metrics/pseudo_mos.py +++ b/versa/utterance_metrics/pseudo_mos.py @@ -5,10 +5,9 @@ # Copyright 2025 Jionghao Han # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import logging - -logger = logging.getLogger(__name__) +# flake8: noqa: E501 +import logging import librosa import numpy as np import torch @@ -16,6 +15,11 @@ from pathlib import Path from typing import Optional +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +logger = logging.getLogger(__name__) + try: import utmosv2 from utmosv2.dataset.multi_spec import process_audio_only_versa @@ -41,7 +45,9 @@ def pseudo_mos_setup( # first import utmos to resolve cross-import from the same model if "utmos" in predictor_types: torch.hub.set_dir(cache_dir) - utmos = torch.hub.load("ftshijt/SpeechMOS:main", "utmos22_strong").to(device) + utmos = torch.hub.load( + "ftshijt/SpeechMOS:main", "utmos22_strong", trust_repo=True + ).to(device) predictor_dict["utmos"] = utmos.float() predictor_fs["utmos"] = 16000 if "utmosv2" in predictor_types: @@ -67,7 +73,7 @@ def pseudo_mos_setup( or "plcmos" in predictor_types ): try: - import onnxruntime # NOTE(jiatong): a requirement of aecmos but not in requirements + import onnxruntime as _onnxruntime # noqa: F401 from speechmos import dnsmos, plcmos except ImportError: raise ImportError( @@ -89,6 +95,13 @@ def pseudo_mos_setup( predictor_fs["plcmos"] = predictor_args["plcmos"]["fs"] elif predictor == "utmos" or predictor == "utmosv2": continue # already initialized + elif predictor == "singmos": + torch.hub.set_dir(cache_dir) + singmos = torch.hub.load( + "South-Twilight/SingMOS:v0.2.0", "singing_ssl_mos", trust_repo=True + ).to(device) + predictor_dict["singmos"] = singmos + predictor_fs["singmos"] = 16000 elif predictor == "singmos_v1": 
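+            # NOTE: singmos_v1 and singmos_pro follow the same torch.hub setup
+            # pattern as singmos above; the unified scoring branch later in this
+            # file handles all three identically.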
torch.hub.set_dir(cache_dir) singmos = torch.hub.load( @@ -128,9 +141,7 @@ def pseudo_mos_metric(pred, fs, predictor_dict, predictor_fs, use_gpu=False): for predictor in predictor_dict.keys(): if predictor == "utmos": if fs != predictor_fs["utmos"]: - pred_utmos = librosa.resample( - pred, orig_sr=fs, target_sr=predictor_fs["utmos"] - ) + pred_utmos = resample_audio(pred, fs, predictor_fs["utmos"]) else: pred_utmos = pred pred_tensor = torch.from_numpy(pred_utmos).unsqueeze(0) @@ -143,9 +154,7 @@ def pseudo_mos_metric(pred, fs, predictor_dict, predictor_fs, use_gpu=False): elif predictor == "utmosv2": if fs != predictor_fs["utmosv2"]: - pred_utmosv2 = librosa.resample( - pred, orig_sr=fs, target_sr=predictor_fs["utmosv2"] - ) + pred_utmosv2 = resample_audio(pred, fs, predictor_fs["utmosv2"]) else: pred_utmosv2 = pred @@ -186,9 +195,7 @@ def pseudo_mos_metric(pred, fs, predictor_dict, predictor_fs, use_gpu=False): elif predictor == "dnsmos": if fs != predictor_fs["dnsmos"]: - pred_dnsmos = librosa.resample( - pred, orig_sr=fs, target_sr=predictor_fs["dnsmos"] - ) + pred_dnsmos = resample_audio(pred, fs, predictor_fs["dnsmos"]) fs = predictor_fs["dnsmos"] else: pred_dnsmos = pred @@ -199,9 +206,7 @@ def pseudo_mos_metric(pred, fs, predictor_dict, predictor_fs, use_gpu=False): scores.update(dns_overall=score["ovrl_mos"], dns_p808=score["p808_mos"]) elif predictor == "plcmos": if fs != predictor_fs["plcmos"]: - pred_plcmos = librosa.resample( - pred, orig_sr=fs, target_sr=predictor_fs["plcmos"] - ) + pred_plcmos = resample_audio(pred, fs, predictor_fs["plcmos"]) fs = predictor_fs["plcmos"] else: pred_plcmos = pred @@ -209,27 +214,9 @@ def pseudo_mos_metric(pred, fs, predictor_dict, predictor_fs, use_gpu=False): max_val = np.max(np.abs(pred_plcmos)) score = predictor_dict["plcmos"].run(pred_plcmos / max_val, sr=fs) scores.update(plcmos=score["plcmos"]) - elif predictor == "singmos_v1": - if fs != predictor_fs["singmos_v1"]: - pred_singmos = librosa.resample( - pred, orig_sr=fs, target_sr=predictor_fs["singmos_v1"] - ) - else: - pred_singmos = pred - pred_tensor = torch.from_numpy(pred_singmos).unsqueeze(0) - length_tensor = torch.tensor([pred_tensor.size(1)]).int() - if use_gpu: - pred_tensor = pred_tensor.to("cuda") - length_tensor = length_tensor.to("cuda") - score = predictor_dict["singmos_v1"](pred_tensor.float(), length_tensor)[ - 0 - ].item() - scores.update(singmos_v1=score) - elif predictor == "singmos_pro": - if fs != predictor_fs["singmos_pro"]: - pred_singmos = librosa.resample( - pred, orig_sr=fs, target_sr=predictor_fs["singmos_pro"] - ) + elif predictor in ("singmos", "singmos_v1", "singmos_pro"): + if fs != predictor_fs[predictor]: + pred_singmos = resample_audio(pred, fs, predictor_fs[predictor]) else: pred_singmos = pred pred_tensor = torch.from_numpy(pred_singmos).unsqueeze(0) @@ -237,15 +224,13 @@ def pseudo_mos_metric(pred, fs, predictor_dict, predictor_fs, use_gpu=False): if use_gpu: pred_tensor = pred_tensor.to("cuda") length_tensor = length_tensor.to("cuda") - score = predictor_dict["singmos_pro"](pred_tensor.float(), length_tensor)[ + score = predictor_dict[predictor](pred_tensor.float(), length_tensor)[ 0 ].item() - scores.update(singmos_pro=score) + scores[predictor] = score elif predictor.startswith("dnsmos_pro_"): if fs != predictor_fs[predictor]: - pred_dnsmos_pro = librosa.resample( - pred, orig_sr=fs, target_sr=predictor_fs[predictor] - ) + pred_dnsmos_pro = resample_audio(pred, fs, predictor_fs[predictor]) else: pred_dnsmos_pro = pred @@ -286,8 +271,6 @@ def 
stft( return spec spec = torch.FloatTensor(stft(pred_dnsmos_pro)) - if use_gpu: - spec = spec.to("cuda") with torch.no_grad(): prediction = predictor_dict[predictor](spec[None, None, ...]) scores[predictor] = prediction[0, 0].item() @@ -297,6 +280,76 @@ def stft( return scores +class PseudoMosMetric(BaseMetric): + """Pseudo-subjective MOS predictors.""" + + def _setup(self): + self.predictor_types = self.config.get( + "predictor_types", ["utmos", "dnsmos", "plcmos"] + ) + self.predictor_args = self.config.get("predictor_args", {}) + self.cache_dir = self.config.get("cache_dir", "versa_cache") + self.use_gpu = self.config.get("use_gpu", False) + self.predictor_dict, self.predictor_fs = pseudo_mos_setup( + self.predictor_types, + self.predictor_args, + cache_dir=self.cache_dir, + use_gpu=self.use_gpu, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return pseudo_mos_metric( + np.asarray(predictions), + fs=fs, + predictor_dict=self.predictor_dict, + predictor_fs=self.predictor_fs, + use_gpu=self.use_gpu, + ) + + def get_metadata(self): + return _pseudo_mos_metadata() + + +def _pseudo_mos_metadata(): + return MetricMetadata( + name="pseudo_mos", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.DICT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy", "requests"], + description=( + "Pseudo-subjective MOS predictors including UTMOS, DNSMOS, " + "PLCMOS, SingMOS, and DNSMOS Pro" + ), + implementation_source="https://github.com/tarepan/SpeechMOS", + ) + + +def register_pseudo_mos_metric(registry): + """Register pseudo MOS metric suite with the registry.""" + registry.register( + PseudoMosMetric, + _pseudo_mos_metadata(), + aliases=[ + "utmos", + "dnsmos", + "plcmos", + "singmos", + "singmos_v1", + "singmos_pro", + "utmosv2", + "dnsmos_pro", + ], + ) + + if __name__ == "__main__": a = np.random.random(16000) print(a) @@ -305,8 +358,7 @@ def stft( "utmos", "dnsmos", "plcmos", - "singmos_v1", - "singmos_pro", + "singmos", "dnsmos_pro_bvcc", "dnsmos_pro_nisqa", "dnsmos_pro_vcc2018", diff --git a/versa/utterance_metrics/pysepm.py b/versa/utterance_metrics/pysepm.py index 9343fc7..65a47b1 100644 --- a/versa/utterance_metrics/pysepm.py +++ b/versa/utterance_metrics/pysepm.py @@ -1,12 +1,14 @@ #!/usr/bin/env python3 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import librosa import logging -logger = logging.getLogger(__name__) import numpy as np +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +logger = logging.getLogger(__name__) try: import pysepm # Import the pysepm package for speech quality metrics @@ -17,9 +19,13 @@ pysepm = None +def is_pysepm_available(): + return pysepm is not None + + def fwsegsnr(pred_x, gt_x, fs, frame_len=0.03, overlap=0.75): """ - Compute the Frequency-Weighted Segmental SNR (fwsegSNR) between predicted and ground truth signals. + Compute Frequency-Weighted Segmental SNR. Args: pred_x (np.array): Audio signal to be evaluated signal. @@ -64,7 +70,7 @@ def llr(pred_x, gt_x, fs, frame_len=0.03, overlap=0.75): def wss(pred_x, gt_x, fs, frame_len=0.03, overlap=0.75): """ - Compute the Weighted Spectral Slope (WSS) measure between predicted and ground truth signals. 
+ Compute the Weighted Spectral Slope (WSS) measure. Args: pred_x (np.array): Audio signal to be evaluated signal. @@ -99,14 +105,18 @@ def cd(pred_x, gt_x, fs): float: Cepstral Distance score. """ cep_dist_score = pysepm.cepstrum_distance( - clean_speech=gt_x, processed_speech=pred_x, fs=fs, frameLen=0.03, overlap=0.75 + clean_speech=gt_x, + processed_speech=pred_x, + fs=fs, + frameLen=0.03, + overlap=0.75, ) return cep_dist_score def composite(pred_x, gt_x, fs): """ - Compute the composite objective measure scores (c_sig, c_bak, c_ovl) for speech quality. + Compute composite objective speech quality scores. Args: pred_x (np.array): Audio signal to be evaluated signal. @@ -126,7 +136,7 @@ def composite(pred_x, gt_x, fs): def csii(pred_x, gt_x, fs): """ - Compute the Coherence Speech Intelligibility Index (CSII) between predicted and ground truth signals. + Compute the Coherence Speech Intelligibility Index (CSII). Args: pred_x (np.array): Audio signal to be evaluated signal. @@ -146,7 +156,7 @@ def csii(pred_x, gt_x, fs): def ncm(pred_x, gt_x, fs): """ - Compute the Normalized Covariance Measure (NCM) between predicted and ground truth signals. + Compute the Normalized Covariance Measure (NCM). Args: pred_x (np.array): Audio signal to be evaluated signal. @@ -167,7 +177,7 @@ def ncm(pred_x, gt_x, fs): def pysepm_metric(pred_x, gt_x, fs, frame_len=0.03, overlap=0.75): if pysepm is None: raise ImportError( - # Error message if pysepm is not installed + "pysepm is not installed. Please use `tools/install_pysepm.sh` to install" ) fwsegsnr_score = fwsegsnr(pred_x, gt_x, fs, frame_len, overlap) llr_score = llr(pred_x, gt_x, fs, frame_len, overlap) @@ -179,19 +189,19 @@ def pysepm_metric(pred_x, gt_x, fs, frame_len=0.03, overlap=0.75): ncm_score = ncm(pred_x, gt_x, 8000) elif fs < 16000: logging.info("not support fs {}, resample to 8khz".format(fs)) - new_gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=8000) - new_pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=8000) - composite_score = composite(pred_x, gt_x, 8000) - ncm_score = ncm(pred_x, gt_x, 8000) + new_gt_x = resample_audio(gt_x, fs, 8000) + new_pred_x = resample_audio(pred_x, fs, 8000) + composite_score = composite(new_pred_x, new_gt_x, 8000) + ncm_score = ncm(new_pred_x, new_gt_x, 8000) elif fs == 16000: composite_score = composite(pred_x, gt_x, 16000) ncm_score = ncm(pred_x, gt_x, 16000) else: logging.info("not support fs {}, resample to 16khz".format(fs)) - new_gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - new_pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - composite_score = composite(pred_x, gt_x, 16000) - ncm_score = ncm(pred_x, gt_x, 16000) + new_gt_x = resample_audio(gt_x, fs, 16000) + new_pred_x = resample_audio(pred_x, fs, 16000) + composite_score = composite(new_pred_x, new_gt_x, 16000) + ncm_score = ncm(new_pred_x, new_gt_x, 16000) csii_score = csii(pred_x, gt_x, fs) @@ -210,9 +220,64 @@ def pysepm_metric(pred_x, gt_x, fs, frame_len=0.03, overlap=0.75): } +class PysepmMetric(BaseMetric): + """Composite pysepm reference-based speech quality metrics.""" + + def _setup(self): + if pysepm is None: + raise ImportError( + "pysepm is not installed. 
" + "Please use `tools/install_pysepm.sh` to install" + ) + self.frame_len = self.config.get("frame_len", 0.03) + self.overlap = self.config.get("overlap", 0.75) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + if references is None: + raise ValueError("Reference signal must be provided") + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return pysepm_metric( + np.asarray(predictions), + np.asarray(references), + fs, + frame_len=self.frame_len, + overlap=self.overlap, + ) + + def get_metadata(self): + return _pysepm_metadata() + + +def _pysepm_metadata(): + return MetricMetadata( + name="pysepm", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.DICT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["pysepm", "librosa", "numpy"], + description="pysepm composite reference-based speech quality metrics", + implementation_source="https://github.com/schmiph2/pysepm", + ) + + +def register_pysepm_metric(registry): + """Register pysepm metrics with the registry.""" + registry.register( + PysepmMetric, + _pysepm_metadata(), + aliases=["pysepm_metric"], + ) + + if __name__ == "__main__": a = np.random.random(16000) b = np.random.random(16000) - score = pysepm_metric(a, b, 16000) + metric = PysepmMetric() + score = metric.compute(a, b, metadata={"sample_rate": 16000}) print(score) diff --git a/versa/utterance_metrics/qwen2_audio.py b/versa/utterance_metrics/qwen2_audio.py index 1e1fc51..3e317ed 100644 --- a/versa/utterance_metrics/qwen2_audio.py +++ b/versa/utterance_metrics/qwen2_audio.py @@ -1,452 +1,547 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -""" -Speech Properties for Metadata Modeling - -This module provides functions for extracting various speech properties -from audio using Qwen2-Audio. The properties are organized into the -following categories: - -1. Speaker Characteristics - - qwen2_speaker_count_metric: Number of distinct speakers - - qwen2_speaker_gender_metric: Gender of speaker(s) - - qwen2_speaker_age_metric: Age group of speaker(s) - - qwen2_speech_impairment_metric: Presence and type of speech disorders - -2. Voice Properties - - qwen2_voice_pitch_metric: Overall pitch level - - qwen2_pitch_range_metric: Variation in intonation - - qwen2_voice_type_metric: Voice texture characteristics - - qwen2_speech_volume_level_metric: Loudness of speech - -3. Speech Content - - qwen2_language_metric: Language(s) being spoken - - qwen2_speech_register_metric: Level of formality in speech - - qwen2_vocabulary_complexity_metric: Sophistication of word choice - - qwen2_speech_purpose_metric: Communicative goal of speech - -4. Speech Delivery - - qwen2_speech_emotion_metric: Emotional state conveyed - - qwen2_speech_clarity_metric: Intelligibility of speech - - qwen2_speech_rate_metric: Speed of delivery - - qwen2_speaking_style_metric: Overall presentation manner - - qwen2_laughter_crying_metric: Presence of emotional vocalizations - -5. Interaction Patterns - - qwen2_overlapping_speech_metric: Degree of simultaneous speech - -6. Recording Environment - - qwen2_speech_background_environment_metric: Setting where recorded - - qwen2_recording_quality_metric: Technical quality of recording - - qwen2_channel_type_metric: Equipment used for recording - -7. 
Vocal Evaluation - - qwen2_singing_technique_metric: Singing Techniques (styles) - -Each function follows the same signature pattern: - qwen_utils: Dictionary containing model, processor, and conversation - pred_x: Audio signal as numpy array - fs: Sampling rate in Hz (default 16000) - custom_prompt: Optional custom prompt to override default - -Each function returns a dictionary with a single key-value pair where -the key is the metric name prefixed with "qwen_" and the value is the -model's response. -""" - -import copy -import logging -from typing import Dict, Optional, Any, Union - -import librosa -import numpy as np - -try: - from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration -except ImportError: - logging.warning( - "If Qwen2Audio is not found with key error, please install the latest version of transformers and retry." - ) - Qwen2AudioForConditionalGeneration, AutoProcessor = None, None - - -# Default prompts for different metrics -DEFAULT_PROMPTS = { - # Speaker Characteristics - "speaker_count": """Analyze the audio and determine the number of distinct speakers present. -Provide your answer as a single number between 1-10. -Examples: -- For a monologue: 1 -- For an interview with host and guest: 2 -- For a panel discussion with a moderator and three panelists: 4""", - "speaker_gender": """Identify the perceived gender of the speaker(s). -If multiple speakers, list each speaker with their perceived gender. -Choose from: -- Male -- Female -- Non-binary/unclear -- Multiple speakers with mixed genders""", - "speaker_age": """Identify the age group of the speaker. -Choose exactly one label from the following categories: -- Child: under 13 years -- Teen: 13-19 years -- Young adult: 20-35 years -- Middle-aged adult: 36-55 years -- Senior: over 55 years""", - "speech_impairment": """Assess whether there are any noticeable speech impairments or disorders in the speaker's voice. -Choose exactly one category: -- No apparent impairment: typical speech patterns -- Stuttering/disfluency: repetitions, blocks, or prolongations of sounds -- Articulation disorder: difficulty with specific speech sounds -- Voice disorder: abnormal pitch, loudness, or quality -- Fluency disorder: atypical rhythm, rate, or flow of speech -- Foreign accent: non-native pronunciation patterns -- Dysarthria: slurred or unclear speech from muscle weakness -- Apraxia: difficulty with motor planning for speech -- Other impairment: speech pattern that suggests a different disorder""", - # Voice Properties - "voice_pitch": """Analyze the voice pitch/tone of the speaker. -Choose exactly one category from the following: -- Very high: significantly higher than average for their perceived gender -- High: noticeably above average pitch -- Medium: average pitch range -- Low: noticeably below average pitch -- Very low: significantly lower than average for their perceived gender""", - "pitch_range": """Assess the pitch variation/intonation range in the speaker's voice. -Choose exactly one category: -- Wide range: highly expressive with significant variation between high and low tones -- Moderate range: normal variation in pitch during speech -- Narrow range: minimal pitch variation, relatively monotone delivery -- Monotone: almost no pitch variation""", - "voice_type": """Identify the dominant voice quality-related characteristic of the speaker. 
-Choose exactly one category: -- Clear: clean vocal production without noticeable texture issues -- Breathy: voice has audible breath sounds, less vocal cord closure -- Creaky/vocal fry: low-frequency rattling sound, especially at ends of phrases -- Hoarse: rough, raspy quality indicating vocal strain -- Nasal: voice resonates primarily through the nose -- Pressed/tense: strained quality from excessive vocal cord pressure -- Resonant: rich, vibrant voice with good projection -- Whispered: intentionally quiet with minimal vocal cord vibration -- Tremulous: shaky or quivery voice quality""", - "speech_volume_level": """Assess the overall volume or loudness level of the speaker. -Choose exactly one category: -- Very quiet: barely audible, whispering or very soft-spoken -- Quiet: below average volume, soft-spoken -- Moderate: normal conversational volume -- Loud: above average volume, projecting voice -- Very loud: shouting or extremely high volume -- Variable: significant changes in volume throughout the recording""", - # Speech Content - "language": """Identify all languages spoken in the audio. -List languages using their English names. -Choose from common languages: -- English -- Spanish -- Mandarin Chinese -- Hindi -- Arabic -- French -- Russian -- Portuguese -- German -- Japanese -- Other (specify if possible)""", - "speech_register": """Determine the speech register used by the speaker. -Choose exactly one category: -- Formal register: careful pronunciation, complex grammar, specialized vocabulary -- Standard register: proper grammar and pronunciation for professional or educational contexts -- Consultative register: mixture of formal and casual for everyday professional interactions -- Casual register: relaxed grammar, contractions, colloquialisms for friends/family -- Intimate register: highly familiar language used with close relations -- Technical register: specialized terminology for a specific field or profession -- Slang register: highly informal with group-specific vocabulary""", - "vocabulary_complexity": """Evaluate the vocabulary complexity level in the speech. -Choose exactly one category: -- Basic: simple, everyday vocabulary, mostly high-frequency words -- General: standard vocabulary for common topics, occasional advanced words -- Advanced: sophisticated vocabulary with specific terminology -- Technical: specialized/domain-specific terminology -- Academic: scholarly vocabulary with abstract concepts""", - "speech_purpose": """Identify the primary purpose of the speech. -Choose one category: -- Informative: primarily explains or educates -- Persuasive: attempts to convince or change opinions -- Entertainment: primarily aims to amuse or entertain -- Narrative: tells a story or relates events -- Conversational: casual exchange of information -- Instructional: provides specific directions or guidance -- Emotional expression: primarily conveys feelings or emotional state""", - # Speech Delivery - "speech_emotion": """Identify the dominant emotion expressed in this speech. 
-Choose exactly one label from the following categories: -- Neutral: even-toned, matter-of-fact delivery with minimal emotional expression -- Happy: upbeat, positive, enthusiastic tone -- Sad: downcast, melancholic, somber tone -- Angry: irritated, frustrated, hostile tone -- Fearful: anxious, worried, frightened tone -- Surprised: astonished, shocked tone -- Disgusted: repulsed, revolted tone -- Other: other emotion that cannot be classified by above classes""", - "speech_clarity": """Rate the overall clarity and intelligibility of the speech. -Choose one category: -- High clarity: perfectly intelligible, professional quality -- Medium clarity: generally understandable with occasional unclear segments -- Low clarity: difficult to understand, frequent unclear segments -- Very low clarity: mostly unintelligible""", - "speech_rate": """Assess the rate of speech in the audio. -Choose one category: -- Very slow: deliberate, significantly slower than average speech -- Slow: relaxed pace, slower than conversational speech -- Medium: average conversational pace -- Fast: quicker than average conversational speech -- Very fast: rapid delivery, difficult to follow""", - "speaking_style": """Identify the predominant speaking style of the speaker. -Choose exactly one category: -- Formal: structured, proper, adherence to linguistic conventions -- Professional: clear, efficient communication focused on task/topic -- Casual/conversational: relaxed, everyday speech -- Animated/enthusiastic: highly energetic, expressive speech -- Deliberate: careful, measured delivery -- Dramatic: theatrical, performance-oriented speech -- Authoritative: commanding, confident tone -- Hesitant: uncertain, tentative speech with pauses""", - "laughter_crying": """Identify if there is laughter, crying, or other emotional vocalizations in the audio. -Choose exactly one category: -- No laughter or crying: speech only -- Contains laughter: audible laughter is present -- Contains crying: audible crying or sobbing is present -- Contains both: both laughter and crying are present -- Contains other emotional sounds: sighs, gasps, etc. -- Contains multiple emotional vocalizations: combination of various emotional sounds""", - # Interaction Patterns - "overlapping_speech": """Determine if there is overlapping speech in the audio (people talking simultaneously). -Choose exactly one category: -- No overlap: clean turn-taking with no simultaneous speech -- Minimal overlap: occasional brief instances of overlapping speech -- Moderate overlap: noticeable instances where speakers talk over each other -- Significant overlap: frequent overlapping speech, making it difficult to follow -- Constant overlap: multiple speakers talking simultaneously throughout most of the audio""", - # Recording Environment - "speech_background_environment": """Identify the dominant background environment or setting. -Choose one category: -- Quiet indoor: minimal background noise, likely studio environment -- Noisy indoor: indoor setting with noticeable background sounds (cafe, office) -- Outdoor urban: city sounds, traffic -- Outdoor natural: nature sounds, birds, wind, water -- Event/crowd: audience sounds, applause, crowd noise -- Music background: music playing behind speech -- Multiple environments: changes throughout recording""", - "recording_quality": """Assess the technical quality of the audio recording. 
-Choose one category: -- Professional: studio-quality, broadcast standard -- Good: clear recording with minimal issues -- Fair: noticeable recording artifacts but generally clear -- Poor: significant recording issues affecting comprehension -- Very poor: severe technical problems making content difficult to understand""", - "channel_type": """Identify the likely recording channel or device type used to record this audio. -Choose exactly one category: -- Professional microphone: high-quality, full-range audio -- Consumer microphone: decent quality but less clarity than professional -- Smartphone: typical mobile phone recording quality -- Telephone/VoIP: limited frequency range, compression artifacts -- Webcam/computer mic: variable quality, often with computer fan noise -- Headset microphone: close to mouth, may have breathing sounds -- Distant microphone: recorded from a distance, may have room echo -- Radio/broadcast: compressed audio with limited frequency range -- Surveillance/hidden mic: typically lower quality with background noise""", - # Vocal Evaluation - "singing_technique": """You are an expert in vocal performance and singing technique. -Given the following audio clip of a singing voice, your task is to identify the predominant singing style used. -Choose one of the following seven styles based on the vocal characteristics: - -Breathy: Light, airy voice with noticeable breathiness. -Falsetto: High-pitched, flute-like sound, especially for male voices. -Mixed Voice: A blend of chest and head voice, balanced resonance. -Pharyngeal: Focused, twangy tone with forward placement in the pharynx. -Glissando: Smooth, sliding transitions between notes. -Vibrato: Regular, pulsating pitch variation while sustaining a note. -Control: A neutral, well-supported tone without stylistic effects. - -Carefully listen to the tone quality, pitch control, resonance, and transitions in the audio. -Then, output only the predicted singing style from the list above. -""", -} - - -def qwen2_model_setup( - model_tag: str = "Qwen/Qwen2-Audio-7B-Instruct", - start_prompt: str = "The following is a conversation with an AI assistant. The assistant is helpful, honest, and harmless.", -) -> Dict[str, Any]: - """Set up the Qwen2-Audio model for speech analysis. - - Args: - model_tag: Model identifier for Qwen2-Audio, defaults to Qwen2-Audio-7B-Instruct - start_prompt: Initial system prompt for the model conversation - - Returns: - Dictionary containing model, processor, and conversation starter - """ - if model_tag == "default": - model_tag = "Qwen/Qwen2-Audio-7B-Instruct" - if Qwen2AudioForConditionalGeneration is None or AutoProcessor is None: - raise RuntimeError( - "Qwen2Audio is used for evaluation while transformers is not installed (could be a version issue)." - ) - processor = AutoProcessor.from_pretrained(model_tag) - model = Qwen2AudioForConditionalGeneration.from_pretrained( - model_tag, device_map="auto" - ) - - start_conversation = [ - {"role": "system", "content": start_prompt}, - ] - return { - "model": model, - "processor": processor, - "start_conversation": start_conversation, - } - - -def qwen2_base_metric( - qwen_utils: Dict[str, Any], - pred_x: np.ndarray, - fs: int = 16000, - custom_prompt: Optional[str] = None, - max_length: int = 1000, -) -> str: - """Calculate the base metric from Qwen2Audio results. - - Args: - qwen_utils: A utility dict for Qwen2Audio calculation. 
- including: Qwen2Audio model ("model"), processor ("processor"), and start conversation ("start_conversation") - pred_x: Test signal (time,) - fs: Sampling rate in Hz - custom_prompt: Custom prompt for the model - max_length: Maximum length for model generation - - Returns: - Model's response as a string - """ - if custom_prompt is None: - raise ValueError("Custom prompt must be provided for the qwen2-audio model.") - - conversation = copy.deepcopy(qwen_utils["start_conversation"]) - processor = qwen_utils["processor"] - model = qwen_utils["model"] - - conversation.append( - { - "role": "user", - "content": [ - {"type": "audio", "audio_url": None}, - {"type": "text", "text": custom_prompt}, - ], - } - ) - - text = processor.apply_chat_template( - conversation, add_generation_prompt=True, tokenize=False - ) - audio = [ - librosa.resample( - pred_x, orig_sr=fs, target_sr=processor.feature_extractor.sampling_rate - ) - ] - - inputs = processor(text=text, audios=audio, return_tensors="pt", padding=True) - for key in inputs.keys(): - inputs[key] = inputs[key].to(model.device) - - generate_ids = model.generate(**inputs, max_length=max_length) - generate_ids = generate_ids[:, inputs["input_ids"].size(1) :] - response = processor.batch_decode( - generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True - )[0] - return response - - -def create_metric_fn(metric_name: str) -> callable: - """Factory function to create metric functions. - - Args: - metric_name: Name of the metric to create a function for - - Returns: - Function that calculates the specified metric - """ - - def metric_fn( - qwen_utils: Dict[str, Any], - pred_x: np.ndarray, - fs: int = 16000, - custom_prompt: Optional[str] = None, - ) -> Dict[str, str]: - """Calculate the specified metric from Qwen2Audio results. - - Args: - qwen_utils: A utility dict for Qwen2Audio calculation - pred_x: Test signal (time,) - fs: Sampling rate in Hz - custom_prompt: Custom prompt for the model - - Returns: - Dictionary containing the metric result - """ - if custom_prompt is None: - custom_prompt = DEFAULT_PROMPTS.get(metric_name) - if custom_prompt is None: - raise ValueError(f"No default prompt found for metric: {metric_name}") - - response = qwen2_base_metric(qwen_utils, pred_x, fs, custom_prompt) - return {f"qwen_{metric_name}": response} - - return metric_fn - - -# Create metric functions for all categories -# 1. Speaker Characteristics -qwen2_speaker_count_metric = create_metric_fn("speaker_count") -qwen2_speaker_gender_metric = create_metric_fn("speaker_gender") -qwen2_speaker_age_metric = create_metric_fn("speaker_age") -qwen2_speech_impairment_metric = create_metric_fn("speech_impairment") - -# 2. Voice Properties -qwen2_voice_pitch_metric = create_metric_fn("voice_pitch") -qwen2_pitch_range_metric = create_metric_fn("pitch_range") -qwen2_voice_type_metric = create_metric_fn("voice_type") -qwen2_speech_volume_level_metric = create_metric_fn("speech_volume_level") - -# 3. Speech Content -qwen2_language_metric = create_metric_fn("language") -qwen2_speech_register_metric = create_metric_fn("speech_register") -qwen2_vocabulary_complexity_metric = create_metric_fn("vocabulary_complexity") -qwen2_speech_purpose_metric = create_metric_fn("speech_purpose") - -# 4. 
Speech Delivery -qwen2_speech_emotion_metric = create_metric_fn("speech_emotion") -qwen2_speech_clarity_metric = create_metric_fn("speech_clarity") -qwen2_speech_rate_metric = create_metric_fn("speech_rate") -qwen2_speaking_style_metric = create_metric_fn("speaking_style") -qwen2_laughter_crying_metric = create_metric_fn("laughter_crying") - -# 5. Interaction Patterns -qwen2_overlapping_speech_metric = create_metric_fn("overlapping_speech") - -# 6. Recording Environment -qwen2_speech_background_environment_metric = create_metric_fn( - "speech_background_environment" -) -qwen2_recording_quality_metric = create_metric_fn("recording_quality") -qwen2_channel_type_metric = create_metric_fn("channel_type") - -qwen2_singing_technique_metric = create_metric_fn("singing_technique") - -if __name__ == "__main__": - a = np.random.random(16000) - qwen_utils = qwen2_model_setup() - # print("metrics: {}".format(qwen2_speaker_age_metric(qwen_utils, a, 16000))) - print("metrics: {}".format(qwen2_speech_emotion_metric(qwen_utils, a, 16000))) +#!/usr/bin/env python3 + +# Copyright 2025 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +# flake8: noqa: E501 + +""" +Speech Properties for Metadata Modeling + +This module provides functions for extracting various speech properties +from audio using Qwen2-Audio. The properties are organized into the +following categories: + +1. Speaker Characteristics + - qwen2_speaker_count_metric: Number of distinct speakers + - qwen2_speaker_gender_metric: Gender of speaker(s) + - qwen2_speaker_age_metric: Age group of speaker(s) + - qwen2_speech_impairment_metric: Presence and type of speech disorders + +2. Voice Properties + - qwen2_voice_pitch_metric: Overall pitch level + - qwen2_pitch_range_metric: Variation in intonation + - qwen2_voice_type_metric: Voice texture characteristics + - qwen2_speech_volume_level_metric: Loudness of speech + +3. Speech Content + - qwen2_language_metric: Language(s) being spoken + - qwen2_speech_register_metric: Level of formality in speech + - qwen2_vocabulary_complexity_metric: Sophistication of word choice + - qwen2_speech_purpose_metric: Communicative goal of speech + +4. Speech Delivery + - qwen2_speech_emotion_metric: Emotional state conveyed + - qwen2_speech_clarity_metric: Intelligibility of speech + - qwen2_speech_rate_metric: Speed of delivery + - qwen2_speaking_style_metric: Overall presentation manner + - qwen2_laughter_crying_metric: Presence of emotional vocalizations + +5. Interaction Patterns + - qwen2_overlapping_speech_metric: Degree of simultaneous speech + +6. Recording Environment + - qwen2_speech_background_environment_metric: Setting where recorded + - qwen2_recording_quality_metric: Technical quality of recording + - qwen2_channel_type_metric: Equipment used for recording + +7. Vocal Evaluation + - qwen2_singing_technique_metric: Singing Techniques (styles) + +Each function follows the same signature pattern: + qwen_utils: Dictionary containing model, processor, and conversation + pred_x: Audio signal as numpy array + fs: Sampling rate in Hz (default 16000) + custom_prompt: Optional custom prompt to override default + +Each function returns a dictionary with a single key-value pair where +the key is the metric name prefixed with "qwen_" and the value is the +model's response. 
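+
+Illustrative call pattern (a sketch only; `audio` stands in for a 1-D numpy
+array, and the setup call downloads the model on first use):
+
+    qwen_utils = qwen2_model_setup()
+    result = qwen2_speech_emotion_metric(qwen_utils, audio, fs=16000)
+    # -> {"qwen_speech_emotion": "<model response>"}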
+""" + +import copy +import logging +from typing import Dict, Optional, Any + +import numpy as np + +from versa.audio_utils import resample_audio + +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +try: + from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration +except ImportError: + logging.warning( + "If Qwen2Audio is not found with key error, please install the latest version of transformers and retry." + ) + Qwen2AudioForConditionalGeneration, AutoProcessor = None, None + + +# Default prompts for different metrics +DEFAULT_PROMPTS = { + # Speaker Characteristics + "speaker_count": """Analyze the audio and determine the number of distinct speakers present. +Provide your answer as a single number between 1-10. +Examples: +- For a monologue: 1 +- For an interview with host and guest: 2 +- For a panel discussion with a moderator and three panelists: 4""", + "speaker_gender": """Identify the perceived gender of the speaker(s). +If multiple speakers, list each speaker with their perceived gender. +Choose from: +- Male +- Female +- Non-binary/unclear +- Multiple speakers with mixed genders""", + "speaker_age": """Identify the age group of the speaker. +Choose exactly one label from the following categories: +- Child: under 13 years +- Teen: 13-19 years +- Young adult: 20-35 years +- Middle-aged adult: 36-55 years +- Senior: over 55 years""", + "speech_impairment": """Assess whether there are any noticeable speech impairments or disorders in the speaker's voice. +Choose exactly one category: +- No apparent impairment: typical speech patterns +- Stuttering/disfluency: repetitions, blocks, or prolongations of sounds +- Articulation disorder: difficulty with specific speech sounds +- Voice disorder: abnormal pitch, loudness, or quality +- Fluency disorder: atypical rhythm, rate, or flow of speech +- Foreign accent: non-native pronunciation patterns +- Dysarthria: slurred or unclear speech from muscle weakness +- Apraxia: difficulty with motor planning for speech +- Other impairment: speech pattern that suggests a different disorder""", + # Voice Properties + "voice_pitch": """Analyze the voice pitch/tone of the speaker. +Choose exactly one category from the following: +- Very high: significantly higher than average for their perceived gender +- High: noticeably above average pitch +- Medium: average pitch range +- Low: noticeably below average pitch +- Very low: significantly lower than average for their perceived gender""", + "pitch_range": """Assess the pitch variation/intonation range in the speaker's voice. +Choose exactly one category: +- Wide range: highly expressive with significant variation between high and low tones +- Moderate range: normal variation in pitch during speech +- Narrow range: minimal pitch variation, relatively monotone delivery +- Monotone: almost no pitch variation""", + "voice_type": """Identify the dominant voice quality-related characteristic of the speaker. 
+Choose exactly one category: +- Clear: clean vocal production without noticeable texture issues +- Breathy: voice has audible breath sounds, less vocal cord closure +- Creaky/vocal fry: low-frequency rattling sound, especially at ends of phrases +- Hoarse: rough, raspy quality indicating vocal strain +- Nasal: voice resonates primarily through the nose +- Pressed/tense: strained quality from excessive vocal cord pressure +- Resonant: rich, vibrant voice with good projection +- Whispered: intentionally quiet with minimal vocal cord vibration +- Tremulous: shaky or quivery voice quality""", + "speech_volume_level": """Assess the overall volume or loudness level of the speaker. +Choose exactly one category: +- Very quiet: barely audible, whispering or very soft-spoken +- Quiet: below average volume, soft-spoken +- Moderate: normal conversational volume +- Loud: above average volume, projecting voice +- Very loud: shouting or extremely high volume +- Variable: significant changes in volume throughout the recording""", + # Speech Content + "language": """Identify all languages spoken in the audio. +List languages using their English names. +Choose from common languages: +- English +- Spanish +- Mandarin Chinese +- Hindi +- Arabic +- French +- Russian +- Portuguese +- German +- Japanese +- Other (specify if possible)""", + "speech_register": """Determine the speech register used by the speaker. +Choose exactly one category: +- Formal register: careful pronunciation, complex grammar, specialized vocabulary +- Standard register: proper grammar and pronunciation for professional or educational contexts +- Consultative register: mixture of formal and casual for everyday professional interactions +- Casual register: relaxed grammar, contractions, colloquialisms for friends/family +- Intimate register: highly familiar language used with close relations +- Technical register: specialized terminology for a specific field or profession +- Slang register: highly informal with group-specific vocabulary""", + "vocabulary_complexity": """Evaluate the vocabulary complexity level in the speech. +Choose exactly one category: +- Basic: simple, everyday vocabulary, mostly high-frequency words +- General: standard vocabulary for common topics, occasional advanced words +- Advanced: sophisticated vocabulary with specific terminology +- Technical: specialized/domain-specific terminology +- Academic: scholarly vocabulary with abstract concepts""", + "speech_purpose": """Identify the primary purpose of the speech. +Choose one category: +- Informative: primarily explains or educates +- Persuasive: attempts to convince or change opinions +- Entertainment: primarily aims to amuse or entertain +- Narrative: tells a story or relates events +- Conversational: casual exchange of information +- Instructional: provides specific directions or guidance +- Emotional expression: primarily conveys feelings or emotional state""", + # Speech Delivery + "speech_emotion": """Identify the dominant emotion expressed in this speech. 
+Choose exactly one label from the following categories: +- Neutral: even-toned, matter-of-fact delivery with minimal emotional expression +- Happy: upbeat, positive, enthusiastic tone +- Sad: downcast, melancholic, somber tone +- Angry: irritated, frustrated, hostile tone +- Fearful: anxious, worried, frightened tone +- Surprised: astonished, shocked tone +- Disgusted: repulsed, revolted tone +- Other: other emotion that cannot be classified by above classes""", + "speech_clarity": """Rate the overall clarity and intelligibility of the speech. +Choose one category: +- High clarity: perfectly intelligible, professional quality +- Medium clarity: generally understandable with occasional unclear segments +- Low clarity: difficult to understand, frequent unclear segments +- Very low clarity: mostly unintelligible""", + "speech_rate": """Assess the rate of speech in the audio. +Choose one category: +- Very slow: deliberate, significantly slower than average speech +- Slow: relaxed pace, slower than conversational speech +- Medium: average conversational pace +- Fast: quicker than average conversational speech +- Very fast: rapid delivery, difficult to follow""", + "speaking_style": """Identify the predominant speaking style of the speaker. +Choose exactly one category: +- Formal: structured, proper, adherence to linguistic conventions +- Professional: clear, efficient communication focused on task/topic +- Casual/conversational: relaxed, everyday speech +- Animated/enthusiastic: highly energetic, expressive speech +- Deliberate: careful, measured delivery +- Dramatic: theatrical, performance-oriented speech +- Authoritative: commanding, confident tone +- Hesitant: uncertain, tentative speech with pauses""", + "laughter_crying": """Identify if there is laughter, crying, or other emotional vocalizations in the audio. +Choose exactly one category: +- No laughter or crying: speech only +- Contains laughter: audible laughter is present +- Contains crying: audible crying or sobbing is present +- Contains both: both laughter and crying are present +- Contains other emotional sounds: sighs, gasps, etc. +- Contains multiple emotional vocalizations: combination of various emotional sounds""", + # Interaction Patterns + "overlapping_speech": """Determine if there is overlapping speech in the audio (people talking simultaneously). +Choose exactly one category: +- No overlap: clean turn-taking with no simultaneous speech +- Minimal overlap: occasional brief instances of overlapping speech +- Moderate overlap: noticeable instances where speakers talk over each other +- Significant overlap: frequent overlapping speech, making it difficult to follow +- Constant overlap: multiple speakers talking simultaneously throughout most of the audio""", + # Recording Environment + "speech_background_environment": """Identify the dominant background environment or setting. +Choose one category: +- Quiet indoor: minimal background noise, likely studio environment +- Noisy indoor: indoor setting with noticeable background sounds (cafe, office) +- Outdoor urban: city sounds, traffic +- Outdoor natural: nature sounds, birds, wind, water +- Event/crowd: audience sounds, applause, crowd noise +- Music background: music playing behind speech +- Multiple environments: changes throughout recording""", + "recording_quality": """Assess the technical quality of the audio recording. 
+Choose one category:
+- Professional: studio-quality, broadcast standard
+- Good: clear recording with minimal issues
+- Fair: noticeable recording artifacts but generally clear
+- Poor: significant recording issues affecting comprehension
+- Very poor: severe technical problems making content difficult to understand""",
+    "channel_type": """Identify the likely recording channel or device type used to record this audio.
+Choose exactly one category:
+- Professional microphone: high-quality, full-range audio
+- Consumer microphone: decent quality but less clarity than professional
+- Smartphone: typical mobile phone recording quality
+- Telephone/VoIP: limited frequency range, compression artifacts
+- Webcam/computer mic: variable quality, often with computer fan noise
+- Headset microphone: close to mouth, may have breathing sounds
+- Distant microphone: recorded from a distance, may have room echo
+- Radio/broadcast: compressed audio with limited frequency range
+- Surveillance/hidden mic: typically lower quality with background noise""",
+    # Vocal Evaluation
+    "singing_technique": """You are an expert in vocal performance and singing technique.
+Given the following audio clip of a singing voice, your task is to identify the predominant singing style used.
+Choose one of the following seven styles based on the vocal characteristics:
+
+Breathy: Light, airy voice with noticeable breathiness.
+Falsetto: High-pitched, flute-like sound, especially for male voices.
+Mixed Voice: A blend of chest and head voice, balanced resonance.
+Pharyngeal: Focused, twangy tone with forward placement in the pharynx.
+Glissando: Smooth, sliding transitions between notes.
+Vibrato: Regular, pulsating pitch variation while sustaining a note.
+Control: A neutral, well-supported tone without stylistic effects.
+
+Carefully listen to the tone quality, pitch control, resonance, and transitions in the audio.
+Then, output only the predicted singing style from the list above.
+""",
+}
+
+
+def qwen2_model_setup(
+    model_tag: str = "Qwen/Qwen2-Audio-7B-Instruct",
+    start_prompt: str = "The following is a conversation with an AI assistant. The assistant is helpful, honest, and harmless.",
+) -> Dict[str, Any]:
+    """Set up the Qwen2-Audio model for speech analysis.
+
+    Args:
+        model_tag: Model identifier for Qwen2-Audio, defaults to Qwen2-Audio-7B-Instruct
+        start_prompt: Initial system prompt for the model conversation
+
+    Returns:
+        Dictionary containing model, processor, and conversation starter
+    """
+    if model_tag == "default":
+        model_tag = "Qwen/Qwen2-Audio-7B-Instruct"
+    if Qwen2AudioForConditionalGeneration is None or AutoProcessor is None:
+        raise RuntimeError(
+            "Qwen2Audio was requested for evaluation, but transformers is not installed or the installed version does not provide it."
+        )
+    processor = AutoProcessor.from_pretrained(model_tag)
+    model = Qwen2AudioForConditionalGeneration.from_pretrained(
+        model_tag, device_map="auto"
+    )
+
+    start_conversation = [
+        {"role": "system", "content": start_prompt},
+    ]
+    return {
+        "model": model,
+        "processor": processor,
+        "start_conversation": start_conversation,
+    }
+
+
+def qwen2_base_metric(
+    qwen_utils: Dict[str, Any],
+    pred_x: np.ndarray,
+    fs: int = 16000,
+    custom_prompt: Optional[str] = None,
+    max_length: int = 1000,
+) -> str:
+    """Calculate the base metric from Qwen2Audio results.
+
+    Args:
+        qwen_utils: A utility dict for Qwen2Audio calculation.
+ including: Qwen2Audio model ("model"), processor ("processor"), and start conversation ("start_conversation") + pred_x: Test signal (time,) + fs: Sampling rate in Hz + custom_prompt: Custom prompt for the model + max_length: Maximum length for model generation + + Returns: + Model's response as a string + """ + if custom_prompt is None: + raise ValueError("Custom prompt must be provided for the qwen2-audio model.") + + conversation = copy.deepcopy(qwen_utils["start_conversation"]) + processor = qwen_utils["processor"] + model = qwen_utils["model"] + + conversation.append( + { + "role": "user", + "content": [ + {"type": "audio", "audio_url": None}, + {"type": "text", "text": custom_prompt}, + ], + } + ) + + text = processor.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=False + ) + audio = [resample_audio(pred_x, fs, processor.feature_extractor.sampling_rate)] + + inputs = processor(text=text, audios=audio, return_tensors="pt", padding=True) + for key in inputs.keys(): + inputs[key] = inputs[key].to(model.device) + + generate_ids = model.generate(**inputs, max_length=max_length) + generate_ids = generate_ids[:, inputs["input_ids"].size(1) :] + response = processor.batch_decode( + generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + )[0] + return response + + +def create_metric_fn(metric_name: str) -> callable: + """Factory function to create metric functions. + + Args: + metric_name: Name of the metric to create a function for + + Returns: + Function that calculates the specified metric + """ + + def metric_fn( + qwen_utils: Dict[str, Any], + pred_x: np.ndarray, + fs: int = 16000, + custom_prompt: Optional[str] = None, + ) -> Dict[str, str]: + """Calculate the specified metric from Qwen2Audio results. + + Args: + qwen_utils: A utility dict for Qwen2Audio calculation + pred_x: Test signal (time,) + fs: Sampling rate in Hz + custom_prompt: Custom prompt for the model + + Returns: + Dictionary containing the metric result + """ + if custom_prompt is None: + custom_prompt = DEFAULT_PROMPTS.get(metric_name) + if custom_prompt is None: + raise ValueError(f"No default prompt found for metric: {metric_name}") + + response = qwen2_base_metric(qwen_utils, pred_x, fs, custom_prompt) + return {f"qwen_{metric_name}": response} + + return metric_fn + + +# Create metric functions for all categories +# 1. Speaker Characteristics +qwen2_speaker_count_metric = create_metric_fn("speaker_count") +qwen2_speaker_gender_metric = create_metric_fn("speaker_gender") +qwen2_speaker_age_metric = create_metric_fn("speaker_age") +qwen2_speech_impairment_metric = create_metric_fn("speech_impairment") + +# 2. Voice Properties +qwen2_voice_pitch_metric = create_metric_fn("voice_pitch") +qwen2_pitch_range_metric = create_metric_fn("pitch_range") +qwen2_voice_type_metric = create_metric_fn("voice_type") +qwen2_speech_volume_level_metric = create_metric_fn("speech_volume_level") + +# 3. Speech Content +qwen2_language_metric = create_metric_fn("language") +qwen2_speech_register_metric = create_metric_fn("speech_register") +qwen2_vocabulary_complexity_metric = create_metric_fn("vocabulary_complexity") +qwen2_speech_purpose_metric = create_metric_fn("speech_purpose") + +# 4. 
Speech Delivery +qwen2_speech_emotion_metric = create_metric_fn("speech_emotion") +qwen2_speech_clarity_metric = create_metric_fn("speech_clarity") +qwen2_speech_rate_metric = create_metric_fn("speech_rate") +qwen2_speaking_style_metric = create_metric_fn("speaking_style") +qwen2_laughter_crying_metric = create_metric_fn("laughter_crying") + +# 5. Interaction Patterns +qwen2_overlapping_speech_metric = create_metric_fn("overlapping_speech") + +# 6. Recording Environment +qwen2_speech_background_environment_metric = create_metric_fn( + "speech_background_environment" +) +qwen2_recording_quality_metric = create_metric_fn("recording_quality") +qwen2_channel_type_metric = create_metric_fn("channel_type") + +qwen2_singing_technique_metric = create_metric_fn("singing_technique") + + +class Qwen2AudioMetric(BaseMetric): + """Speech property extraction with Qwen2-Audio.""" + + metric_name = None + + def _setup(self): + self.model_tag = self.config.get("model_tag", "default") + self.start_prompt = self.config.get( + "start_prompt", + ( + "The following is a conversation with an AI assistant. " + "The assistant is helpful, honest, and harmless." + ), + ) + self.metric_name = self.config.get("metric_name", self.metric_name) + if self.metric_name is None: + raise ValueError("metric_name must be provided") + self.prompt = self.config.get("prompt") + self.max_length = self.config.get("max_length", 1000) + self.qwen_utils = qwen2_model_setup( + model_tag=self.model_tag, + start_prompt=self.start_prompt, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + prompt = self.prompt or DEFAULT_PROMPTS.get(self.metric_name) + response = qwen2_base_metric( + self.qwen_utils, + np.asarray(predictions), + fs=fs, + custom_prompt=prompt, + max_length=self.max_length, + ) + return {f"qwen_{self.metric_name}": response} + + def get_metadata(self): + return _qwen2_audio_metadata(self.registry_name()) + + @classmethod + def registry_name(cls): + return f"qwen2_audio_{cls.metric_name}" if cls.metric_name else "qwen2_audio" + + +def _qwen2_audio_metadata(name): + return MetricMetadata( + name=name, + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.STRING, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["transformers", "librosa", "numpy"], + description="Speech property extraction with Qwen2-Audio", + paper_reference="https://arxiv.org/abs/2407.10759", + implementation_source="https://github.com/QwenLM/Qwen2-Audio", + ) + + +def _make_qwen2_metric_class(metric_name): + class _SpecificQwen2AudioMetric(Qwen2AudioMetric): + pass + + _SpecificQwen2AudioMetric.metric_name = metric_name + class_name = "".join(part.title() for part in metric_name.split("_")) + _SpecificQwen2AudioMetric.__name__ = f"Qwen2Audio{class_name}Metric" + return _SpecificQwen2AudioMetric + + +QWEN2_AUDIO_METRIC_CLASSES = { + metric_name: _make_qwen2_metric_class(metric_name) + for metric_name in DEFAULT_PROMPTS.keys() +} + + +def register_qwen2_audio_metric(registry): + """Register Qwen2-Audio speech property metrics with the registry.""" + for metric_name, metric_class in QWEN2_AUDIO_METRIC_CLASSES.items(): + registry_name = f"qwen2_audio_{metric_name}" + registry.register( + metric_class, + _qwen2_audio_metadata(registry_name), + aliases=[ + f"qwen2_{metric_name}_metric", + f"qwen_{metric_name}", + ], + ) + + 
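+# Usage sketch (illustrative only): drives one generated class through the
+# object-oriented interface, mirroring the migrated ScoreQ/SeSnr __main__
+# examples later in this diff. It assumes BaseMetric accepts a plain config
+# dict; "default" resolves to the Qwen2-Audio-7B-Instruct checkpoint.
+def _example_qwen2_audio_usage(audio: np.ndarray) -> Dict[str, str]:
+    metric_cls = QWEN2_AUDIO_METRIC_CLASSES["speech_emotion"]
+    metric = metric_cls({"model_tag": "default"})  # _setup loads the model
+    return metric.compute(audio, metadata={"sample_rate": 16000})
+
+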
+if __name__ == "__main__": + a = np.random.random(16000) + qwen_utils = qwen2_model_setup() + # print("metrics: {}".format(qwen2_speaker_age_metric(qwen_utils, a, 16000))) + print("metrics: {}".format(qwen2_speech_emotion_metric(qwen_utils, a, 16000))) diff --git a/versa/utterance_metrics/qwen_omni.py b/versa/utterance_metrics/qwen_omni.py index 67d2dc5..42d724f 100644 --- a/versa/utterance_metrics/qwen_omni.py +++ b/versa/utterance_metrics/qwen_omni.py @@ -3,6 +3,8 @@ # Copyright 2025 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# flake8: noqa: E501 + """ Speech Properties for Metadata Modeling @@ -59,11 +61,12 @@ import copy import logging -import torch -from typing import Dict, Optional, Any, Union +from typing import Dict, Optional, Any -import librosa import numpy as np +import torch + +from versa.audio_utils import resample_audio try: from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor @@ -73,12 +76,15 @@ ) Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor = None, None +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType from versa.utterance_metrics.qwen2_audio import DEFAULT_PROMPTS def qwen_omni_model_setup( model_tag: str = "Qwen/Qwen2-Audio-7B-Instruct", start_prompt: str = "The following is a conversation with an AI assistant. The assistant is helpful, honest, and harmless.", + use_gpu: bool = True, + device_map: Optional[str] = None, ) -> Dict[str, Any]: """Set up the Qwen2-Audio model for speech analysis. @@ -95,14 +101,17 @@ def qwen_omni_model_setup( raise RuntimeError( "qwen2_5_omni is used for evaluation while transformers is not installed (could be a version issue)." ) + target_device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu" + device_map = device_map or target_device processor = Qwen2_5OmniProcessor.from_pretrained(model_tag) model = Qwen2_5OmniForConditionalGeneration.from_pretrained( model_tag, torch_dtype="auto", - device_map="cuda", + device_map=device_map, # attn_implementation="flash_attention_2", NOTE(jiatong): to add ) - model.to("cuda") + if device_map in {None, "cpu", "cuda"}: + model.to(target_device) start_conversation = [ {"role": "system", "content": [{"type": "text", "text": start_prompt}]} ] @@ -153,11 +162,7 @@ def qwen_omni_base_metric( text = processor.apply_chat_template( conversation, add_generation_prompt=True, tokenize=False ) - audio = [ - librosa.resample( - pred_x, orig_sr=fs, target_sr=processor.feature_extractor.sampling_rate - ) - ] + audio = [resample_audio(pred_x, fs, processor.feature_extractor.sampling_rate)] inputs = processor(text=text, audio=audio, return_tensors="pt", padding=True) inputs = inputs.to(model.device).to(model.dtype) @@ -253,6 +258,101 @@ def metric_fn( qwen_omni_singing_technique_metric = create_metric_fn("singing_technique") + +class QwenOmniMetric(BaseMetric): + """Speech property extraction with Qwen2.5-Omni.""" + + metric_name = None + + def _setup(self): + self.model_tag = self.config.get("model_tag", "default") + self.start_prompt = self.config.get( + "start_prompt", + ( + "The following is a conversation with an AI assistant. " + "The assistant is helpful, honest, and harmless." 
+ ), + ) + self.metric_name = self.config.get("metric_name", self.metric_name) + if self.metric_name is None: + raise ValueError("metric_name must be provided") + self.prompt = self.config.get("prompt") + self.max_length = self.config.get("max_length", 500) + self.use_gpu = self.config.get("use_gpu", True) + self.device_map = self.config.get("device_map") + self.qwen_utils = qwen_omni_model_setup( + model_tag=self.model_tag, + start_prompt=self.start_prompt, + use_gpu=self.use_gpu, + device_map=self.device_map, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + prompt = self.prompt or DEFAULT_PROMPTS.get(self.metric_name) + response = qwen_omni_base_metric( + self.qwen_utils, + np.asarray(predictions), + fs=fs, + custom_prompt=prompt, + max_length=self.max_length, + ) + return {f"qwen_omni_{self.metric_name}": response} + + def get_metadata(self): + return _qwen_omni_metadata(self.registry_name()) + + @classmethod + def registry_name(cls): + return f"qwen_omni_{cls.metric_name}" if cls.metric_name else "qwen_omni" + + +def _qwen_omni_metadata(name): + return MetricMetadata( + name=name, + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.STRING, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["transformers", "librosa", "numpy", "torch"], + description="Speech property extraction with Qwen2.5-Omni", + paper_reference="https://arxiv.org/abs/2503.20215", + implementation_source="https://github.com/QwenLM/Qwen2.5-Omni", + ) + + +def _make_qwen_omni_metric_class(metric_name): + class _SpecificQwenOmniMetric(QwenOmniMetric): + pass + + _SpecificQwenOmniMetric.metric_name = metric_name + class_name = "".join(part.title() for part in metric_name.split("_")) + _SpecificQwenOmniMetric.__name__ = f"QwenOmni{class_name}Metric" + return _SpecificQwenOmniMetric + + +QWEN_OMNI_METRIC_CLASSES = { + metric_name: _make_qwen_omni_metric_class(metric_name) + for metric_name in DEFAULT_PROMPTS.keys() +} + + +def register_qwen_omni_metric(registry): + """Register Qwen2.5-Omni speech property metrics with the registry.""" + for metric_name, metric_class in QWEN_OMNI_METRIC_CLASSES.items(): + registry_name = f"qwen_omni_{metric_name}" + registry.register( + metric_class, + _qwen_omni_metadata(registry_name), + aliases=[f"qwen_omni_{metric_name}_metric"], + ) + + if __name__ == "__main__": a = np.random.random(16000) qwen_utils = qwen_omni_model_setup() diff --git a/versa/utterance_metrics/scoreq.py b/versa/utterance_metrics/scoreq.py index e55fc99..e7cddcb 100644 --- a/versa/utterance_metrics/scoreq.py +++ b/versa/utterance_metrics/scoreq.py @@ -1,89 +1,239 @@ -#!/usr/bin/env python3 - -# Copyright 2024 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -import logging - -logger = logging.getLogger(__name__) - -import librosa -import numpy as np -import torch - -try: - from scoreq_versa import Scoreq -except ImportError: - logger.info( - "scoreq is not installed. Please use `tools/install_scoreq.sh` to install" - ) - Scoreq = None - - -def scoreq_nr_setup( - data_domain="synthetic", - cache_dir="versa_cache/scoreq_pt-models", - use_gpu=False, -): - if use_gpu: - device = "cuda" - else: - device = "cpu" - - if Scoreq is None: - raise ModuleNotFoundError( - "scoreq is not installed. 
Please use `tools/install_scoreq.sh` to install" - ) - - return Scoreq( - data_domain=data_domain, mode="nr", cache_dir=cache_dir, device=device - ) - - -def scoreq_ref_setup( - data_domain="synthetic", - cache_dir="./scoreq_pt-models", - use_gpu=False, -): - if use_gpu: - device = "cuda" - else: - device = "cpu" - - if Scoreq is None: - raise ModuleNotFoundError( - "scoreq is not installed. Please use `tools/install_scoreq.sh` to install" - ) - - return Scoreq( - data_domain=data_domain, mode="ref", cache_dir=cache_dir, device=device - ) - - -def scoreq_nr(model, pred_x, fs): - # NOTE(jiatong): current model only have 16k options - if fs != 16000: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - - return {"scoreq_nr": model.predict(test_path=pred_x, ref_path=None)} - - -def scoreq_ref(model, pred_x, gt_x, fs): - # NOTE(jiatong): current model only have 16k options - if fs != 16000: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - - return {"scoreq_ref": model.predict(test_path=pred_x, ref_path=gt_x)} - - -if __name__ == "__main__": - a = np.random.random(16000) - b = np.random.random(16000) - nr_model = scoreq_nr_setup(use_gpu=True) - ref_model = scoreq_ref_setup(use_gpu=True) - fs = 16000 - metric_nr = scoreq_nr(nr_model, a, fs) - metric_ref = scoreq_ref(ref_model, a, b, fs) - print(metric_nr) - print(metric_ref) +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import logging +import sys +import ast + +import numpy as np +from omegaconf import OmegaConf + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +logger = logging.getLogger(__name__) + +try: + import fairseq.logging.meters as fairseq_meters + import fairseq.checkpoint_utils as fairseq_checkpoint_utils + import fairseq.dataclass.utils as fairseq_dataclass_utils + + sys.modules.setdefault("fairseq.meters", fairseq_meters) + + def _legacy_fairseq_args_to_cfg(args): + values = dict(vars(args)) + for key in ("latent_temp",): + value = values.get(key) + if isinstance(value, str): + try: + parsed = ast.literal_eval(value) + except (SyntaxError, ValueError): + continue + values[key] = list(parsed) if isinstance(parsed, tuple) else parsed + + generation = dict(values) + generation.setdefault("print_alignment", None) + + def section(name, source_key=None): + data = dict(values) + data["_name"] = values.get(source_key or name) + return data + + return OmegaConf.create( + { + "common": dict(values), + "common_eval": dict(values), + "distributed_training": dict(values), + "dataset": dict(values), + "optimization": dict(values), + "checkpoint": dict(values), + "bmuf": dict(values), + "generation": generation, + "eval_lm": dict(values), + "interactive": dict(values), + "ema": dict(values), + "task": section("task"), + "model": section("model", "arch"), + "optimizer": section("optimizer"), + "lr_scheduler": section("lr_scheduler"), + "criterion": section("criterion"), + } + ) + + fairseq_dataclass_utils.convert_namespace_to_omegaconf = _legacy_fairseq_args_to_cfg + fairseq_checkpoint_utils.convert_namespace_to_omegaconf = ( + _legacy_fairseq_args_to_cfg + ) + from scoreq_versa import Scoreq +except ImportError: + logger.info( + "scoreq is not installed. 
Please use `tools/install_scoreq.sh` to install" + ) + Scoreq = None + + +def scoreq_nr_setup( + data_domain="synthetic", + cache_dir="versa_cache/scoreq_pt-models", + use_gpu=False, +): + if use_gpu: + device = "cuda" + else: + device = "cpu" + + if Scoreq is None: + raise ModuleNotFoundError( + "scoreq is not installed. Please use `tools/install_scoreq.sh` to install" + ) + + return Scoreq( + data_domain=data_domain, mode="nr", cache_dir=cache_dir, device=device + ) + + +def scoreq_ref_setup( + data_domain="synthetic", + cache_dir="./scoreq_pt-models", + use_gpu=False, +): + if use_gpu: + device = "cuda" + else: + device = "cpu" + + if Scoreq is None: + raise ModuleNotFoundError( + "scoreq is not installed. Please use `tools/install_scoreq.sh` to install" + ) + + return Scoreq( + data_domain=data_domain, mode="ref", cache_dir=cache_dir, device=device + ) + + +def scoreq_nr(model, pred_x, fs): + # NOTE(jiatong): current model only have 16k options + if fs != 16000: + pred_x = resample_audio(pred_x, fs, 16000) + + return {"scoreq_nr": model.predict(test_path=pred_x, ref_path=None)} + + +def scoreq_ref(model, pred_x, gt_x, fs): + # NOTE(jiatong): current model only have 16k options + if fs != 16000: + gt_x = resample_audio(gt_x, fs, 16000) + pred_x = resample_audio(pred_x, fs, 16000) + + return {"scoreq_ref": model.predict(test_path=pred_x, ref_path=gt_x)} + + +class ScoreqMetric(BaseMetric): + """ScoreQ speech quality metric.""" + + def _setup(self): + self.mode = self.config.get("mode", "nr") + if self.mode not in {"nr", "ref"}: + raise ValueError(f"Invalid ScoreQ mode: {self.mode}") + + self.data_domain = self.config.get("data_domain", "synthetic") + self.cache_dir = self.config.get( + "cache_dir", self.config.get("model_cache", "versa_cache/scoreq_pt-models") + ) + self.use_gpu = self.config.get("use_gpu", False) + + if self.mode == "ref": + self.model = scoreq_ref_setup( + data_domain=self.data_domain, + cache_dir=self.cache_dir, + use_gpu=self.use_gpu, + ) + else: + self.model = scoreq_nr_setup( + data_domain=self.data_domain, + cache_dir=self.cache_dir, + use_gpu=self.use_gpu, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + if self.mode == "ref" and references is None: + raise ValueError("Reference signal must be provided for ScoreQ ref mode") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + pred_x = np.asarray(predictions) + if self.mode == "ref": + return scoreq_ref(self.model, pred_x, np.asarray(references), fs) + return scoreq_nr(self.model, pred_x, fs) + + def get_metadata(self): + return _scoreq_metadata(f"scoreq_{self.mode}", self.mode) + + +class ScoreqNrMetric(ScoreqMetric): + """Reference-less ScoreQ speech quality metric.""" + + def _setup(self): + self.config = {**self.config, "mode": self.config.get("mode", "nr")} + super()._setup() + + +class ScoreqRefMetric(ScoreqMetric): + """Reference-based ScoreQ speech quality metric.""" + + def _setup(self): + self.config = {**self.config, "mode": self.config.get("mode", "ref")} + super()._setup() + + +def _scoreq_metadata(name, mode): + requires_reference = mode == "ref" + description = ( + "ScoreQ reference-based speech quality assessment" + if requires_reference + else "ScoreQ reference-less speech quality assessment" + ) + return MetricMetadata( + name=name, + category=( + MetricCategory.DEPENDENT + if requires_reference + else MetricCategory.INDEPENDENT + ), + metric_type=MetricType.FLOAT, + 
requires_reference=requires_reference, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["scoreq_versa", "torch", "librosa", "numpy"], + description=description, + paper_reference="https://arxiv.org/pdf/2410.06675", + implementation_source="https://github.com/ftshijt/scoreq", + ) + + +def register_scoreq_metric(registry): + """Register ScoreQ reference-less and reference-based metrics.""" + registry.register( + ScoreqNrMetric, + _scoreq_metadata("scoreq_nr", "nr"), + aliases=["scoreq", "scoreq_metric", "scoreq_no_ref"], + ) + registry.register( + ScoreqRefMetric, + _scoreq_metadata("scoreq_ref", "ref"), + aliases=["scoreq_reference"], + ) + + +if __name__ == "__main__": + a = np.random.random(16000) + b = np.random.random(16000) + metric_nr = ScoreqNrMetric({"use_gpu": True}) + metric_ref = ScoreqRefMetric({"use_gpu": True}) + print(metric_nr.compute(a, metadata={"sample_rate": 16000})) + print(metric_ref.compute(a, b, metadata={"sample_rate": 16000})) diff --git a/versa/utterance_metrics/se_snr.py b/versa/utterance_metrics/se_snr.py index e902668..10f0a2b 100644 --- a/versa/utterance_metrics/se_snr.py +++ b/versa/utterance_metrics/se_snr.py @@ -1,49 +1,131 @@ -#!/usr/bin/env python3 - -# Copyright 2024 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -import os - -import numpy as np -from espnet2.bin.enh_inference import SeparateSpeech - -from versa.sequence_metrics.signal_metric import signal_metric - - -def se_snr_setup( - model_tag="default", model_path=None, model_config=None, use_gpu=False -): - if use_gpu: - device = "cuda" - else: - device = "cpu" - if model_path is not None and model_config is not None: - model = SeparateSpeech.from_pretrained( - model_file=model_path, - train_config=model_config, - normalize_output_wav=True, - device=device, - ) - else: - if model_tag == "default": - model_tag = "wyz/tfgridnet_for_urgent24" - model = SeparateSpeech.from_pretrained( - model_tag=model_tag, normalize_output_wav=True, device=device - ) - return model - - -def se_snr(model, pred_x, fs): - enhanced_x = model(pred_x[None, :], fs=fs)[0] - signal_metrics = signal_metric(pred_x, enhanced_x) - updated_metrics = {f"se_{key}": value for key, value in signal_metrics.items()} - updated_metrics.pop("se_sir") - return updated_metrics - - -if __name__ == "__main__": - a = np.random.random(16000) - b = np.random.random(16000) - model = se_snr_setup() - print("metrics: {}".format(se_snr(model, a, 16000))) +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import numpy as np + +try: + from espnet2.bin.enh_inference import SeparateSpeech +except ImportError: + SeparateSpeech = None + +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType +from versa.sequence_metrics.signal_metric import signal_metric + + +def se_snr_setup( + model_tag="default", + model_path=None, + model_config=None, + use_gpu=False, + cache_dir=None, +): + if SeparateSpeech is None: + raise ImportError("se_snr requires espnet. 
Please install espnet and retry") + + if use_gpu: + device = "cuda" + else: + device = "cpu" + if model_path is not None and model_config is not None: + model = SeparateSpeech.from_pretrained( + model_file=model_path, + train_config=model_config, + normalize_output_wav=True, + device=device, + ) + else: + if model_tag == "default": + model_tag = "wyz/tfgridnet_for_urgent24" + if cache_dir is None: + model = SeparateSpeech.from_pretrained( + model_tag=model_tag, normalize_output_wav=True, device=device + ) + else: + try: + from espnet_model_zoo.downloader import ModelDownloader + except ImportError: + raise ImportError( + "se_snr requires espnet_model_zoo. Please install it and retry" + ) + model_kwargs = ModelDownloader(cachedir=cache_dir).download_and_unpack( + model_tag + ) + model = SeparateSpeech( + normalize_output_wav=True, device=device, **model_kwargs + ) + return model + + +def se_snr(model, pred_x, fs): + enhanced_x = model(pred_x[None, :], fs=fs)[0] + signal_metrics = signal_metric(pred_x, enhanced_x) + updated_metrics = {f"se_{key}": value for key, value in signal_metrics.items()} + updated_metrics.pop("se_sir") + return updated_metrics + + +class SeSnrMetric(BaseMetric): + """Speech enhancement-based signal quality metrics.""" + + def _setup(self): + self.model_tag = self.config.get("model_tag", "default") + self.model_path = self.config.get("model_path") + self.model_config = self.config.get("model_config") + self.use_gpu = self.config.get("use_gpu", False) + self.cache_dir = self.config.get("cache_dir", "versa_cache/espnet_model_zoo") + self.model = se_snr_setup( + model_tag=self.model_tag, + model_path=self.model_path, + model_config=self.model_config, + use_gpu=self.use_gpu, + cache_dir=self.cache_dir, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return se_snr(self.model, np.asarray(predictions), fs) + + def get_metadata(self): + return _se_snr_metadata() + + +def _se_snr_metadata(): + return MetricMetadata( + name="se_snr", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.DICT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=[ + "espnet2", + "ci_sdr", + "fast_bss_eval", + "mir_eval", + "numpy", + "torch", + ], + description="Speech enhancement-based SDR, SAR, SI-SNR, and CI-SDR metrics", + implementation_source="https://github.com/espnet/espnet", + ) + + +def register_se_snr_metric(registry): + """Register speech enhancement-based signal metrics with the registry.""" + registry.register( + SeSnrMetric, + _se_snr_metadata(), + aliases=["se_snr_metric", "speech_enhancement_snr"], + ) + + +if __name__ == "__main__": + a = np.random.random(16000) + metric = SeSnrMetric() + print("metrics: {}".format(metric.compute(a, metadata={"sample_rate": 16000}))) diff --git a/versa/utterance_metrics/sheet_ssqa.py b/versa/utterance_metrics/sheet_ssqa.py index 7e63edc..67a14e9 100644 --- a/versa/utterance_metrics/sheet_ssqa.py +++ b/versa/utterance_metrics/sheet_ssqa.py @@ -1,54 +1,109 @@ -#!/usr/bin/env python3 - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) -# Copyright 2024 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -import librosa -import numpy as np -import torch - - -def sheet_ssqa_setup( - model_tag="default", - model_path=None, - model_config=None, - 
cache_dir="versa_cache", - use_gpu=False, -): - if use_gpu: - device = "cuda" - else: - device = "cpu" - - if model_path is not None and model_config is not None: - raise NotImplementedError( - "Pending implementation for customized setup (Jiatong)" - ) - else: - if model_tag == "default": - model_tag = "unilight/sheet:v0.1.0" - torch.hub.set_dir(cache_dir) - model = torch.hub.load( - "unilight/sheet:v0.1.0", "default", trust_repo=True, force_reload=False - ) - - model.model.to(device) - return model - - -def sheet_ssqa(model, pred_x, fs, use_gpu=False): - # NOTE(jiatong): current model only work for 16000 Hz - if fs != 16000: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - pred_x = torch.tensor(pred_x).float() - if use_gpu: - pred_x = pred_x.to("cuda") - return {"sheet_ssqa": model.predict(wav=pred_x)} - - -if __name__ == "__main__": - a = np.random.random(16000) - model = sheet_ssqa_setup() - print("metrics: {}".format(sheet_ssqa(model, a, 16000))) +#!/usr/bin/env python3 + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import numpy as np +import torch + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + + +def sheet_ssqa_setup( + model_tag="default", + model_path=None, + model_config=None, + cache_dir="versa_cache", + use_gpu=False, +): + if use_gpu: + device = "cuda" + else: + device = "cpu" + + if model_path is not None and model_config is not None: + raise NotImplementedError( + "Pending implementation for customized setup (Jiatong)" + ) + else: + if model_tag == "default": + model_tag = "unilight/sheet:v0.1.0" + torch.hub.set_dir(cache_dir) + model = torch.hub.load( + "unilight/sheet:v0.1.0", "default", trust_repo=True, force_reload=False + ) + + model.model.to(device) + return model + + +def sheet_ssqa(model, pred_x, fs, use_gpu=False): + # NOTE(jiatong): current model only work for 16000 Hz + if fs != 16000: + pred_x = resample_audio(pred_x, fs, 16000) + pred_x = torch.tensor(pred_x).float() + if use_gpu: + pred_x = pred_x.to("cuda") + return {"sheet_ssqa": model.predict(wav=pred_x)} + + +class SheetSsqaMetric(BaseMetric): + """Sheet SSQA MOS prediction metric.""" + + def _setup(self): + self.model_tag = self.config.get("model_tag", "default") + self.model_path = self.config.get("model_path") + self.model_config = self.config.get("model_config") + self.cache_dir = self.config.get("cache_dir", "versa_cache") + self.use_gpu = self.config.get("use_gpu", False) + self.model = sheet_ssqa_setup( + model_tag=self.model_tag, + model_path=self.model_path, + model_config=self.model_config, + cache_dir=self.cache_dir, + use_gpu=self.use_gpu, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return sheet_ssqa(self.model, np.asarray(predictions), fs, use_gpu=self.use_gpu) + + def get_metadata(self): + return _sheet_ssqa_metadata() + + +def _sheet_ssqa_metadata(): + return MetricMetadata( + name="sheet_ssqa", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy"], + description="Sheet SSQA MOS prediction metric", + 
paper_reference="https://arxiv.org/abs/2411.03715", + implementation_source="https://github.com/unilight/sheet", + ) + + +def register_sheet_ssqa_metric(registry): + """Register Sheet SSQA with the registry.""" + registry.register( + SheetSsqaMetric, + _sheet_ssqa_metadata(), + aliases=["sheet", "sheet_ssqa_metric"], + ) + + +if __name__ == "__main__": + a = np.random.random(16000) + metric = SheetSsqaMetric() + print("metrics: {}".format(metric.compute(a, metadata={"sample_rate": 16000}))) diff --git a/versa/utterance_metrics/sigmos.py b/versa/utterance_metrics/sigmos.py index 725dbe8..59cfb8e 100644 --- a/versa/utterance_metrics/sigmos.py +++ b/versa/utterance_metrics/sigmos.py @@ -1,15 +1,21 @@ import os -import scipy -import librosa +import logging +from enum import Enum +import librosa import numpy as np import onnxruntime as ort -from enum import Enum -import logging - +import scipy +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType __all__ = ["SigMOS", "Version"] +SIGMOS_MODEL_FILENAME = "model-sigmos_1697718653_41d092e8-epo-200.onnx" +SIGMOS_MODEL_URL = ( + "https://github.com/microsoft/SIG-Challenge/raw/refs/heads/main/" + "ICASSP2024/sigmos/model-sigmos_1697718653_41d092e8-epo-200.onnx" +) + class Version(Enum): V1 = "v1" # 15.10.2023 @@ -26,9 +32,7 @@ def __init__(self, model_dir, model_version=Version.V1): assert model_version in [v for v in Version] model_path_history = { - Version.V1: os.path.join( - model_dir, "model-sigmos_1697718653_41d092e8-epo-200.onnx" - ) + Version.V1: os.path.join(model_dir, SIGMOS_MODEL_FILENAME) } self.sampling_rate = 48_000 @@ -113,8 +117,6 @@ def sigmos_setup(model_dir=None): # Get the absolute path to the current file (this script) script_dir = os.path.dirname(os.path.abspath(__file__)) model_dir = os.path.join(script_dir, "..", "..", "versa_cache", "sigmos_model") - # Permalink to the onnx file: - # https://github.com/microsoft/SIG-Challenge/raw/refs/heads/main/ICASSP2024/sigmos/model-sigmos_1697718653_41d092e8-epo-200.onnx # Check if the model directory exists if not os.path.exists(model_dir): @@ -123,26 +125,22 @@ def sigmos_setup(model_dir=None): # Check if the model file already exists # If it does not exist, download it - if not os.path.exists( - os.path.join(model_dir, "model-sigmos_1697718653_41d092e8-epo-200.onnx") - ): + if not os.path.exists(os.path.join(model_dir, SIGMOS_MODEL_FILENAME)): logging.info( f"Model file not found in {model_dir}. Downloading the model file..." ) # Download the model file from the web url import requests - model_url = "https://github.com/microsoft/SIG-Challenge/raw/refs/heads/main/ICASSP2024/sigmos/model-sigmos_1697718653_41d092e8-epo-200.onnx" - model_file = os.path.join( - model_dir, "model-sigmos_1697718653_41d092e8-epo-200.onnx" - ) - response = requests.get(model_url) + model_file = os.path.join(model_dir, SIGMOS_MODEL_FILENAME) + response = requests.get(SIGMOS_MODEL_URL) if response.status_code == 200: with open(model_file, "wb") as f: f.write(response.content) else: raise RuntimeError( - f"Failed to download the model file from {model_url}. Status code: {response.status_code}" + "Failed to download the model file from " + f"{SIGMOS_MODEL_URL}. 
Status code: {response.status_code}"
+            )
 
     sigmos_estimator = SigMOS(model_dir=model_dir)
@@ -161,13 +159,58 @@
     return result
 
 
+class SigmosMetric(BaseMetric):
+    """SIG-MOS metric using the ICASSP 2024 SIG Challenge ONNX model."""
+
+    def _setup(self):
+        model_dir = self.config.get("model_dir", self.config.get("cache_dir"))
+        self.model = sigmos_setup(model_dir=model_dir)
+
+    def compute(self, predictions, references=None, metadata=None):
+        if predictions is None:
+            raise ValueError("Predicted signal must be provided")
+        metadata = metadata or {}
+        sample_rate = metadata.get("sample_rate", 48000)
+        return sigmos_calculate(self.model, predictions, sample_rate)
+
+    def get_metadata(self):
+        return _sigmos_metadata()
+
+
+def _sigmos_metadata():
+    return MetricMetadata(
+        name="sigmos",
+        category=MetricCategory.INDEPENDENT,
+        metric_type=MetricType.DICT,
+        requires_reference=False,
+        requires_text=False,
+        gpu_compatible=False,
+        auto_install=False,
+        dependencies=["onnxruntime", "librosa", "scipy", "numpy"],
+        description="SIG-MOS P.804 speech quality estimator",
+        paper_reference="https://arxiv.org/pdf/2309.07385",
+        implementation_source=(
+            "https://github.com/microsoft/SIG-Challenge/tree/main/" "ICASSP2024/sigmos"
+        ),
+    )
+
+
+def register_sigmos_metric(registry):
+    """Register SIG-MOS with the metric registry."""
+    registry.register(
+        SigmosMetric,
+        _sigmos_metadata(),
+        aliases=["Sigmos", "sigmos", "sig_mos"],
+    )
+
+
 if __name__ == "__main__":
     """
     Sample code to run the SigMOS estimator.
     V1 (current model) is an alpha version and should be used in accordance.
     """
     # model_dir = r"."
-    sigmos_estimator = SigMOS()
+    sigmos_estimator = sigmos_setup()
     # input data must have sr=48kHz, otherwise please specify the sr (it will be resampled to 48kHz internally)
     sampling_rate = 48_000
diff --git a/versa/utterance_metrics/singer.py b/versa/utterance_metrics/singer.py
index f9ee5e2..f610665 100644
--- a/versa/utterance_metrics/singer.py
+++ b/versa/utterance_metrics/singer.py
@@ -3,24 +3,32 @@
 # Adapted from speaker similarity code for singer identity
 # Uses SSL singer identity models from SonyCSLParis/ssl-singer-identity
 
-import os
-import librosa
 import numpy as np
 import torch
 
+from versa.audio_utils import resample_audio
+from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType
+
 
 def singer_model_setup(
-    model_name="byol", model_path=None, use_gpu=False, input_sr=44100, torchscript=False
+    model_name="byol",
+    model_path=None,
+    use_gpu=False,
+    input_sr=44100,
+    torchscript=False,
+    cache_dir="versa_cache/singer_identity",
 ):
     """
     Setup singer identity model
 
     Args:
-        model_name (str): Name of the pretrained model ('byol', 'contrastive', 'contrastive-vc', 'uniformity', 'vicreg')
+        model_name (str): Pretrained model name such as 'byol',
+            'contrastive', 'contrastive-vc', 'uniformity', or 'vicreg'.
         
model_path (str): Path to local model (if None, downloads from HuggingFace) use_gpu (bool): Whether to use GPU input_sr (int): Input sample rate (will be upsampled to 44.1kHz if different) torchscript (bool): Whether to load torchscript version + cache_dir (str): Directory for downloaded pretrained model files Returns: model: Loaded singer identity model @@ -38,11 +46,20 @@ def singer_model_setup( if model_path is not None: # Load from local path model = load_model( - model_path, source=model_path, input_sr=input_sr, torchscript=torchscript + model_path, + source=model_path, + input_sr=input_sr, + torchscript=torchscript, + savedir=cache_dir, ) else: # Load from HuggingFace Hub - model = load_model(model_name, input_sr=input_sr, torchscript=torchscript) + model = load_model( + model_name, + input_sr=input_sr, + torchscript=torchscript, + savedir=cache_dir, + ) model = model.to(device) model.eval() @@ -65,8 +82,8 @@ def singer_metric(model, pred_x, gt_x, fs, target_sr=44100): """ # Resample to target sample rate if needed (singer models expect 44.1kHz) if fs != target_sr: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=target_sr) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=target_sr) + gt_x = resample_audio(gt_x, fs, target_sr) + pred_x = resample_audio(pred_x, fs, target_sr) # Convert to torch tensors and add batch dimension device = next(model.parameters()).device @@ -110,9 +127,7 @@ def singer_metric_batch(model, audio_batch, fs, target_sr=44100): if fs != target_sr: resampled_batch = [] for i in range(audio_batch.shape[0]): - resampled = librosa.resample( - audio_batch[i], orig_sr=fs, target_sr=target_sr - ) + resampled = resample_audio(audio_batch[i], fs, target_sr) resampled_batch.append(resampled) audio_batch = np.array(resampled_batch) @@ -146,6 +161,70 @@ def compute_similarity_matrix(embeddings): return similarity_matrix +class SingerMetric(BaseMetric): + """Singer identity embedding cosine similarity.""" + + def _setup(self): + self.model_name = self.config.get("model_name", "byol") + self.model_path = self.config.get("model_path") + self.use_gpu = self.config.get("use_gpu", False) + self.input_sr = self.config.get("input_sr", 44100) + self.torchscript = self.config.get("torchscript", False) + self.target_sr = self.config.get("target_sr", 44100) + self.cache_dir = self.config.get("cache_dir", "versa_cache/singer_identity") + self.model = singer_model_setup( + model_name=self.model_name, + model_path=self.model_path, + use_gpu=self.use_gpu, + input_sr=self.input_sr, + torchscript=self.torchscript, + cache_dir=self.cache_dir, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + if references is None: + raise ValueError("Reference signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return singer_metric( + self.model, + np.asarray(predictions), + np.asarray(references), + fs, + target_sr=self.target_sr, + ) + + def get_metadata(self): + return _singer_metadata() + + +def _singer_metadata(): + return MetricMetadata( + name="singer", + category=MetricCategory.NON_MATCH, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["singer_identity", "librosa", "numpy", "torch"], + description="Singer identity embedding cosine similarity", + paper_reference="https://hal.science/hal-04186048v1", + 
implementation_source="https://github.com/SonyCSLParis/ssl-singer-identity", + ) + + +def register_singer_metric(registry): + """Register singer similarity with the registry.""" + registry.register( + SingerMetric, + _singer_metadata(), + aliases=["singer_similarity", "singer_identity"], + ) + + if __name__ == "__main__": # Example usage @@ -162,12 +241,12 @@ def compute_similarity_matrix(embeddings): # Setup model (will download from HuggingFace on first use) try: - model = singer_model_setup(model_name="byol", use_gpu=False) + model = SingerMetric({"model_name": "byol", "use_gpu": False}) print("Model loaded successfully!") # Compute similarity between two audio signals print("Computing singer similarity...") - result = singer_metric(model, audio_a, audio_b, sample_rate) + result = model.compute(audio_a, audio_b, metadata={"sample_rate": sample_rate}) print(f"Singer similarity: {result['singer_similarity']:.4f}") # Example of batch processing diff --git a/versa/utterance_metrics/speaker.py b/versa/utterance_metrics/speaker.py index 321b1ca..56497b3 100644 --- a/versa/utterance_metrics/speaker.py +++ b/versa/utterance_metrics/speaker.py @@ -3,16 +3,28 @@ # Copyright 2024 Jiatong Shi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import os - -import librosa import numpy as np -from espnet2.bin.spk_inference import Speech2Embedding + +from versa.audio_utils import resample_audio + +try: + from espnet2.bin.spk_inference import Speech2Embedding +except ImportError: + Speech2Embedding = None + +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType def speaker_model_setup( - model_tag="default", model_path=None, model_config=None, use_gpu=False + model_tag="default", + model_path=None, + model_config=None, + use_gpu=False, + cache_dir=None, ): + if Speech2Embedding is None: + raise ImportError("speaker requires espnet. Please install espnet and retry") + if use_gpu: device = "cuda" else: @@ -24,15 +36,27 @@ def speaker_model_setup( else: if model_tag == "default": model_tag = "espnet/voxcelebs12_rawnet3" - model = Speech2Embedding.from_pretrained(model_tag=model_tag, device=device) + if cache_dir is None: + model = Speech2Embedding.from_pretrained(model_tag=model_tag, device=device) + else: + try: + from espnet_model_zoo.downloader import ModelDownloader + except ImportError: + raise ImportError( + "speaker requires espnet_model_zoo. 
Please install it and retry" + ) + model_kwargs = ModelDownloader(cachedir=cache_dir).download_and_unpack( + model_tag + ) + model = Speech2Embedding(device=device, **model_kwargs) return model def speaker_metric(model, pred_x, gt_x, fs): # NOTE(jiatong): only work for 16000 Hz if fs != 16000: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=16000) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) + gt_x = resample_audio(gt_x, fs, 16000) + pred_x = resample_audio(pred_x, fs, 16000) embedding_gen = model(pred_x).squeeze(0).cpu().numpy() embedding_gt = model(gt_x).squeeze(0).cpu().numpy() @@ -42,8 +66,65 @@ def speaker_metric(model, pred_x, gt_x, fs): return {"spk_similarity": similarity} +class SpeakerMetric(BaseMetric): + """Speaker embedding cosine similarity.""" + + def _setup(self): + self.model_tag = self.config.get("model_tag", "default") + self.model_path = self.config.get("model_path") + self.model_config = self.config.get("model_config") + self.use_gpu = self.config.get("use_gpu", False) + self.cache_dir = self.config.get("cache_dir", "versa_cache/espnet_model_zoo") + self.model = speaker_model_setup( + model_tag=self.model_tag, + model_path=self.model_path, + model_config=self.model_config, + use_gpu=self.use_gpu, + cache_dir=self.cache_dir, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + if references is None: + raise ValueError("Reference signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return speaker_metric( + self.model, np.asarray(predictions), np.asarray(references), fs + ) + + def get_metadata(self): + return _speaker_metadata() + + +def _speaker_metadata(): + return MetricMetadata( + name="speaker", + category=MetricCategory.NON_MATCH, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["espnet2", "librosa", "numpy"], + description="Speaker embedding cosine similarity", + paper_reference="https://arxiv.org/abs/2401.17230", + implementation_source="https://github.com/espnet/espnet", + ) + + +def register_speaker_metric(registry): + """Register speaker similarity with the registry.""" + registry.register( + SpeakerMetric, + _speaker_metadata(), + aliases=["spk_similarity", "speaker_similarity"], + ) + + if __name__ == "__main__": a = np.random.random(16000) b = np.random.random(16000) - model = speaker_model_setup() - print("metrics: {}".format(speaker_metric(model, a, b, 16000))) + metric = SpeakerMetric() + print("metrics: {}".format(metric.compute(a, b, metadata={"sample_rate": 16000}))) diff --git a/versa/utterance_metrics/speaking_rate.py b/versa/utterance_metrics/speaking_rate.py index 9868e61..3d69a91 100644 --- a/versa/utterance_metrics/speaking_rate.py +++ b/versa/utterance_metrics/speaking_rate.py @@ -1,82 +1,159 @@ -#!/usr/bin/env python3 - -# Copyright 2024 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -import logging - -logger = logging.getLogger(__name__) - -import librosa -import numpy as np -import torch - -try: - import whisper -except ImportError: - logger.info( - "Whisper is not properly installed. 
Please install following https://github.com/openai/whisper" - ) - whisper = None - -from espnet2.text.cleaner import TextCleaner - -TARGET_FS = 16000 -CHUNK_SIZE = 30 # seconds - - -def speaking_rate_model_setup( - model_tag="default", beam_size=5, text_cleaner="whisper_basic", use_gpu=True -): - if model_tag == "default": - model_tag = "large" - device = "cuda" if use_gpu else "cpu" - if whisper is None: - raise RuntimeError( - "Whisper WER is used for evaluation while openai-whisper is not installed" - ) - model = whisper.load_model(model_tag, device=device) - textcleaner = TextCleaner(text_cleaner) - wer_utils = {"model": model, "cleaner": textcleaner, "beam_size": beam_size} - return wer_utils - - -def speaking_rate_metric(wer_utils, pred_x, cache_text=None, fs=16000, use_char=False): - """Calculate the speaking rate from ASR results. - - Args: - wer_utils (dict): a utility dict for WER calculation. - including: whisper model ("model"), text cleaner ("textcleaner"), and - beam size ("beam size") - pred_x (np.ndarray): test signal (time,) - cache_text (string): transcription from cache (previous modules) - fs (int): sampling rate in Hz - use_char (bool): whether to use character-level speaking rate - Returns: - ret (dict): ditionary containing the speaking word rate - """ - if cache_text is not None: - inf_text = cache_text - else: - if fs != TARGET_FS: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=TARGET_FS) - fs = TARGET_FS - with torch.no_grad(): - inf_text = wer_utils["model"].transcribe( - torch.tensor(pred_x).float(), beam_size=wer_utils["beam_size"] - )["text"] - - if use_char: - length = len(inf_text) - else: - length = len(inf_text.split()) - return { - "speaking_rate": length / (len(pred_x) / fs), - "whisper_hyp_text": inf_text, - } - - -if __name__ == "__main__": - a = np.random.random(16000) - wer_utils = speaking_rate_model_setup() - print("metrics: {}".format(speaking_rate_metric(wer_utils, a, None, 16000))) +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import logging + +import numpy as np +import torch + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +logger = logging.getLogger(__name__) + +try: + import whisper +except ImportError: + logger.info( + "Whisper is not properly installed. Please install following " + "https://github.com/openai/whisper" + ) + whisper = None + +try: + from espnet2.text.cleaner import TextCleaner +except ImportError: + logger.info("ESPnet is not properly installed. Please install espnet and retry") + TextCleaner = None + +TARGET_FS = 16000 +CHUNK_SIZE = 30 # seconds + + +def speaking_rate_model_setup( + model_tag="default", beam_size=5, text_cleaner="whisper_basic", use_gpu=True +): + if model_tag == "default": + model_tag = "large" + device = "cuda" if use_gpu else "cpu" + if whisper is None: + raise ImportError( + "speaking_rate requires openai-whisper. " + "Please install following https://github.com/openai/whisper" + ) + if TextCleaner is None: + raise ImportError( + "speaking_rate requires espnet TextCleaner. Please install espnet" + ) + model = whisper.load_model(model_tag, device=device) + textcleaner = TextCleaner(text_cleaner) + wer_utils = {"model": model, "cleaner": textcleaner, "beam_size": beam_size} + return wer_utils + + +def speaking_rate_metric(wer_utils, pred_x, cache_text=None, fs=16000, use_char=False): + """Calculate the speaking rate from ASR results. 
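+
+    Example (illustrative only; loads the configured Whisper model, so
+    weights are downloaded on first use)::
+
+        utils = speaking_rate_model_setup(model_tag="tiny", use_gpu=False)
+        scores = speaking_rate_metric(utils, pred_x, fs=16000)
+        # -> {"speaking_rate": ..., "whisper_hyp_text": "..."}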
+
+    Args:
+        wer_utils (dict): a utility dict for WER calculation.
+            including: whisper model ("model"), text cleaner ("cleaner"), and
+            beam size ("beam_size")
+        pred_x (np.ndarray): test signal (time,)
+        cache_text (string): transcription from cache (previous modules)
+        fs (int): sampling rate in Hz
+        use_char (bool): whether to use character-level speaking rate
+    Returns:
+        ret (dict): dictionary containing the speaking rate
+    """
+    if cache_text is not None:
+        inf_text = cache_text
+    else:
+        if fs != TARGET_FS:
+            pred_x = resample_audio(pred_x, fs, TARGET_FS)
+            fs = TARGET_FS
+        with torch.no_grad():
+            inf_text = wer_utils["model"].transcribe(
+                torch.tensor(pred_x).float(), beam_size=wer_utils["beam_size"]
+            )["text"]
+
+    if use_char:
+        length = len(inf_text)
+    else:
+        length = len(inf_text.split())
+    return {
+        "speaking_rate": length / (len(pred_x) / fs),
+        "whisper_hyp_text": inf_text,
+    }
+
+
+class SpeakingRateMetric(BaseMetric):
+    """Speaking word or character rate estimated from Whisper ASR output."""
+
+    def _setup(self):
+        self.model_tag = self.config.get("model_tag", "default")
+        self.beam_size = self.config.get("beam_size", 5)
+        self.text_cleaner = self.config.get("text_cleaner", "whisper_basic")
+        self.use_gpu = self.config.get("use_gpu", True)
+        self.use_char = self.config.get("use_char", False)
+        self.wer_utils = speaking_rate_model_setup(
+            model_tag=self.model_tag,
+            beam_size=self.beam_size,
+            text_cleaner=self.text_cleaner,
+            use_gpu=self.use_gpu,
+        )
+
+    def compute(self, predictions, references=None, metadata=None):
+        if predictions is None:
+            raise ValueError("Predicted signal must be provided")
+
+        metadata = metadata or {}
+        cache_text = metadata.get("whisper_hyp_text")
+        general_cache = metadata.get("general_cache")
+        if cache_text is None and general_cache:
+            cache_text = general_cache.get("whisper_hyp_text")
+
+        fs = metadata.get("sample_rate", 16000)
+        pred_x = np.asarray(predictions)
+        return speaking_rate_metric(
+            self.wer_utils,
+            pred_x,
+            cache_text=cache_text,
+            fs=fs,
+            use_char=self.use_char,
+        )
+
+    def get_metadata(self):
+        return _speaking_rate_metadata()
+
+
+def _speaking_rate_metadata():
+    return MetricMetadata(
+        name="speaking_rate",
+        category=MetricCategory.INDEPENDENT,
+        metric_type=MetricType.DICT,
+        requires_reference=False,
+        requires_text=False,
+        gpu_compatible=True,
+        auto_install=False,
+        dependencies=["whisper", "espnet2", "librosa", "torch", "numpy"],
+        description="Speaking word or character rate estimated from Whisper ASR",
+        paper_reference="https://github.com/openai/whisper",
+        implementation_source="https://github.com/openai/whisper",
+    )
+
+
+def register_speaking_rate_metric(registry):
+    """Register speaking_rate with the registry."""
+    registry.register(
+        SpeakingRateMetric,
+        _speaking_rate_metadata(),
+        aliases=["speaking_rate_metric", "swr"],
+    )
+
+
+if __name__ == "__main__":
+    a = np.random.random(16000)
+    metric = SpeakingRateMetric()
+    print("metrics: {}".format(metric.compute(a, metadata={"sample_rate": 16000})))
diff --git a/versa/utterance_metrics/squim.py b/versa/utterance_metrics/squim.py
index d7b8c9a..1bfccd2 100644
--- a/versa/utterance_metrics/squim.py
+++ b/versa/utterance_metrics/squim.py
@@ -9,20 +9,30 @@ import torch
 
 try:
-    import torchaudio
     import torchaudio.functional as F
     from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
 except ImportError:
     logging.warning(
         "Import error. 
Please install pesq, pystoi, torchaudio for torch squim" ) + F = None + SQUIM_OBJECTIVE = None + SQUIM_SUBJECTIVE = None + +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +SQUIM_AVAILABLE = SQUIM_OBJECTIVE is not None and SQUIM_SUBJECTIVE is not None + + +def is_squim_available(): + return SQUIM_AVAILABLE def squim_metric(pred_x, gt_x, fs): """ Reference: - Kumar, Anurag, et al. “TorchAudio-Squim: Reference-less Speech Quality and Intelligibility measures in TorchAudio.”, - ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2023 + Kumar et al., "TorchAudio-Squim: Reference-less Speech Quality and + Intelligibility measures in TorchAudio", ICASSP 2023. https://pytorch.org/audio/main/tutorials/squim_tutorial.html """ @@ -45,8 +55,8 @@ def squim_metric(pred_x, gt_x, fs): def squim_metric_no_ref(pred_x, fs): """ Reference: - Kumar, Anurag, et al. “TorchAudio-Squim: Reference-less Speech Quality and Intelligibility measures in TorchAudio.”, - ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2023 + Kumar et al., "TorchAudio-Squim: Reference-less Speech Quality and + Intelligibility measures in TorchAudio", ICASSP 2023. https://pytorch.org/audio/main/tutorials/squim_tutorial.html """ @@ -66,8 +76,114 @@ def squim_metric_no_ref(pred_x, fs): } +class SquimMetric(BaseMetric): + """TorchAudio-SQUIM speech quality metric.""" + + def _setup(self): + if not SQUIM_AVAILABLE: + raise ImportError( + "SQUIM is not available. Please install pesq, pystoi, and torchaudio" + ) + self.mode = self.config.get("mode", "no_ref") + if self.mode not in {"ref", "no_ref"}: + raise ValueError(f"Invalid SQUIM mode: {self.mode}") + if self.mode == "ref": + self.model = SQUIM_SUBJECTIVE.get_model() + else: + self.model = SQUIM_OBJECTIVE.get_model() + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + if self.mode == "ref" and references is None: + raise ValueError("Reference signal must be provided for SQUIM ref mode") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + pred_x = torch.from_numpy(np.asarray(predictions)) + if fs != 16000: + pred_x = F.resample(pred_x, fs, 16000) + pred_x = pred_x.unsqueeze(0).float() + + if self.mode == "ref": + gt_x = torch.from_numpy(np.asarray(references)) + if fs != 16000: + gt_x = F.resample(gt_x, fs, 16000) + gt_x = gt_x.unsqueeze(0).float() + torch_squim_mos = self.model(pred_x, gt_x) + return {"torch_squim_mos": torch_squim_mos.detach().numpy()[0]} + + torch_squim_stoi, torch_squim_pesq, torch_squim_si_sdr = self.model(pred_x) + return { + "torch_squim_stoi": torch_squim_stoi.detach().numpy()[0], + "torch_squim_pesq": torch_squim_pesq.detach().numpy()[0], + "torch_squim_si_sdr": torch_squim_si_sdr.detach().numpy()[0], + } + + def get_metadata(self): + return _squim_metadata(f"squim_{self.mode}", self.mode) + + +class SquimRefMetric(SquimMetric): + """Reference-based TorchAudio-SQUIM MOS metric.""" + + def _setup(self): + self.config = {**self.config, "mode": self.config.get("mode", "ref")} + super()._setup() + + +class SquimNoRefMetric(SquimMetric): + """Reference-less TorchAudio-SQUIM objective metrics.""" + + def _setup(self): + self.config = {**self.config, "mode": self.config.get("mode", "no_ref")} + super()._setup() + + +def _squim_metadata(name, mode): + requires_reference = mode == "ref" + description = ( + 
"TorchAudio-SQUIM subjective MOS metric" + if requires_reference + else "TorchAudio-SQUIM reference-less PESQ, STOI, and SI-SDR metrics" + ) + return MetricMetadata( + name=name, + category=( + MetricCategory.DEPENDENT + if requires_reference + else MetricCategory.INDEPENDENT + ), + metric_type=MetricType.DICT, + requires_reference=requires_reference, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["torch", "torchaudio"], + description=description, + paper_reference="https://arxiv.org/abs/2302.01147", + implementation_source=( + "https://pytorch.org/audio/main/tutorials/squim_tutorial.html" + ), + ) + + +def register_squim_metric(registry): + """Register TorchAudio-SQUIM metrics with the registry.""" + registry.register( + SquimRefMetric, + _squim_metadata("squim_ref", "ref"), + aliases=["torch_squim_mos"], + ) + registry.register( + SquimNoRefMetric, + _squim_metadata("squim_no_ref", "no_ref"), + aliases=["squim", "torch_squim_objective"], + ) + + if __name__ == "__main__": a = np.random.random(16000) b = np.random.random(16000) - scores = squim_metric(a, b, 16000) + metric = SquimRefMetric() + scores = metric.compute(a, b, metadata={"sample_rate": 16000}) print(scores) diff --git a/versa/utterance_metrics/srmr.py b/versa/utterance_metrics/srmr.py index 7543d7a..c48392d 100644 --- a/versa/utterance_metrics/srmr.py +++ b/versa/utterance_metrics/srmr.py @@ -2,6 +2,7 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import logging +from typing import Dict, Any, Optional logger = logging.getLogger(__name__) @@ -13,39 +14,98 @@ logger.info("srmr is not installed. Please use `tools/install_srmr.sh` to install") srmr = None +from versa.definition import BaseMetric, MetricMetadata, MetricCategory, MetricType -def srmr_metric( - pred_x, - fs, - n_cochlear_filters=23, - low_freq=125, - min_cf=4, - max_cf=128, - fast=True, - norm=False, -): - if srmr is None: - raise ImportError( - # Error message if SRMRpy is not installed + +class SRMRMetric(BaseMetric): + """Speech-to-Reverberation Modulation energy Ratio (SRMR) metric.""" + + def _setup(self): + """Initialize SRMR-specific components.""" + if srmr is None: + raise ImportError( + "srmr is not installed. 
Please use `tools/install_srmr.sh` to install" + ) + + # Set default parameters from config + self.n_cochlear_filters = self.config.get("n_cochlear_filters", 23) + self.low_freq = self.config.get("low_freq", 125) + self.min_cf = self.config.get("min_cf", 4) + self.max_cf = self.config.get("max_cf", 128) + self.fast = self.config.get("fast", True) + self.norm = self.config.get("norm", False) + + def compute( + self, predictions: Any, references: Any = None, metadata: Dict[str, Any] = None + ) -> Dict[str, float]: + """Compute the SRMR score.""" + pred_x = predictions + sample_rate = metadata.get("sample_rate", 16000) if metadata else 16000 + + srmr_score = srmr( + pred_x, + sample_rate, + n_cochlear_filters=self.n_cochlear_filters, + low_freq=self.low_freq, + min_cf=self.min_cf, + max_cf=self.max_cf, + fast=self.fast, + norm=self.norm, ) - srmr_score = srmr( - pred_x, - fs, - n_cochlear_filters=n_cochlear_filters, - low_freq=low_freq, - min_cf=min_cf, - max_cf=max_cf, - fast=fast, - norm=norm, - ) - return { - "srmr": srmr_score, - } + return { + "srmr": srmr_score, + } + def get_metadata(self) -> MetricMetadata: + """Return SRMR metric metadata.""" + return MetricMetadata( + name="srmr", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["srmrpy"], + description="Speech-to-Reverberation Modulation energy Ratio (SRMR) for speech quality assessment", + paper_reference="http://www.individual.utoronto.ca/falkt/falk/pdf/FalkChan_TASLP2010.pdf", + implementation_source="https://github.com/shimhz/SRMRpy.git", + ) + + +# Auto-registration function +def register_srmr_metric(registry): + """Register SRMR metric with the registry.""" + metric_metadata = MetricMetadata( + name="srmr", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["srmrpy"], + description="Speech-to-Reverberation Modulation energy Ratio (SRMR) for speech quality assessment", + paper_reference="http://www.individual.utoronto.ca/falkt/falk/pdf/FalkChan_TASLP2010.pdf", + implementation_source="https://github.com/shimhz/SRMRpy.git", + ) + registry.register(SRMRMetric, metric_metadata, aliases=["SRMR"]) -if __name__ == "__main__": +if __name__ == "__main__": a = np.random.random(16000) - score = srmr_metric(a, 16000) - print(score) + + # Test the new class-based metric + config = { + "n_cochlear_filters": 23, + "low_freq": 125, + "min_cf": 4, + "max_cf": 128, + "fast": True, + "norm": False, + } + metric = SRMRMetric(config) + metadata = {"sample_rate": 16000} + score = metric.compute(a, metadata=metadata) + print("SRMR", score) diff --git a/versa/utterance_metrics/stoi.py b/versa/utterance_metrics/stoi.py index 7e5f060..c7ea047 100644 --- a/versa/utterance_metrics/stoi.py +++ b/versa/utterance_metrics/stoi.py @@ -10,6 +10,8 @@ except ImportError: raise ImportError("Please install pystoi and retry: pip install stoi") +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + def stoi_metric(pred_x, gt_x, fs): if pred_x.shape[0] != gt_x.shape[0]: @@ -29,8 +31,78 @@ def estoi_metric(pred_x, gt_x, fs): return {"estoi": score} +class StoiMetric(BaseMetric): + """Short-Time Objective Intelligibility metric.""" + + def _setup(self): + self.extended = self.config.get("extended", False) + self.output_key = "estoi" if self.extended else "stoi" + + def 
compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + if references is None: + raise ValueError("Reference signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + pred_x = np.asarray(predictions) + gt_x = np.asarray(references) + + if self.extended: + return estoi_metric(pred_x, gt_x, fs) + return stoi_metric(pred_x, gt_x, fs) + + def get_metadata(self): + return _stoi_metadata(self.output_key, self.extended) + + +class EstoiMetric(StoiMetric): + """Extended Short-Time Objective Intelligibility metric.""" + + def _setup(self): + self.extended = self.config.get("extended", True) + self.output_key = "estoi" if self.extended else "stoi" + + +def _stoi_metadata(name, extended): + label = "ESTOI" if extended else "STOI" + description = ( + "Extended Short-Time Objective Intelligibility" + if extended + else "Short-Time Objective Intelligibility" + ) + return MetricMetadata( + name=name, + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["pystoi", "numpy"], + description=f"{label}: {description}", + paper_reference="https://doi.org/10.1109/TASL.2010.2045551", + implementation_source="https://github.com/mpariente/pystoi", + ) + + +def register_stoi_metric(registry): + """Register STOI and ESTOI metrics with the registry.""" + registry.register( + StoiMetric, + _stoi_metadata("stoi", extended=False), + aliases=["STOI", "stoi_metric"], + ) + registry.register( + EstoiMetric, + _stoi_metadata("estoi", extended=True), + aliases=["ESTOI", "estoi_metric"], + ) + + if __name__ == "__main__": a = np.random.random(16000) b = np.random.random(16000) - scores = stoi_metric(a, b, 16000) + metric = StoiMetric() + scores = metric.compute(a, b, metadata={"sample_rate": 16000}) print(scores) diff --git a/versa/utterance_metrics/universa.py b/versa/utterance_metrics/universa.py index 24dd111..285482b 100644 --- a/versa/utterance_metrics/universa.py +++ b/versa/utterance_metrics/universa.py @@ -1,258 +1,486 @@ #!/usr/bin/env python3 -# Copyright 2024 Jiatong Shi +# Copyright 2025 Jiatong Shi +# Mainly adapted from ESPnet-SE (https://github.com/espnet/espnet.git) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import numpy as np import torch -import librosa -import soundfile as sf +import soundfile +from versa.audio_utils import resample_audio -def universa_model_setup(model_tag="default", use_gpu=False): + +def _ensure_torchaudio_legacy_backend_api(): + try: + import torchaudio + except ImportError: + return + + if not hasattr(torchaudio, "set_audio_backend"): + torchaudio.set_audio_backend = lambda *args, **kwargs: None + + +_ensure_torchaudio_legacy_backend_api() + +try: + from espnet2.bin.universa_inference import UniversaInference +except ImportError: + UniversaInference = None + +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +# Global model instances to avoid reloading +_universa_models = {} + + +def get_universa_model(model_type="noref", cache_dir=None): """ - Setup Universa model for inference. + Get or load Universa model instance. Args: - model_tag (str): Model tag to use. 
Options: - - "default": espnet/universa-wavlm_base_urgent24_multi-metric_noref - - "audioref": espnet/universa-wavlm_base_urgent24_multi-metric_audioref - - "textref": espnet/universa-wavlm_base_urgent24_multi-metric_textref - - "fullref": espnet/universa-wavlm_base_urgent24_multi-metric_fullref - use_gpu (bool): Whether to use GPU for inference + model_type (str): One of "noref", "audioref", "textref", "fullref" Returns: - UniversaInference: Loaded model for inference + UniversaInference: Loaded model instance """ - try: - from espnet2.bin.universa_inference import UniversaInference - except ImportError: - raise ImportError( - "Please install espnet and espnet_model_zoo to use Universa metric: " - "pip install espnet espnet_model_zoo" - ) - - # Map model tags to actual model names model_mapping = { - "default": "espnet/universa-wavlm_base_urgent24_multi-metric_noref", "noref": "espnet/universa-wavlm_base_urgent24_multi-metric_noref", "audioref": "espnet/universa-wavlm_base_urgent24_multi-metric_audioref", "textref": "espnet/universa-wavlm_base_urgent24_multi-metric_textref", "fullref": "espnet/universa-wavlm_base_urgent24_multi-metric_fullref", } - if model_tag not in model_mapping: - raise ValueError( - f"Unknown model_tag: {model_tag}. Available options: {list(model_mapping.keys())}" - ) + if UniversaInference is None: + raise ImportError("universa requires espnet. Please install espnet and retry") + + cache_key = (model_type, cache_dir) + if cache_key not in _universa_models: + if model_type not in model_mapping: + raise ValueError( + f"Unknown model_type: {model_type}. " + f"Choose from {list(model_mapping.keys())}" + ) + + print(f"Loading Universa model: {model_mapping[model_type]}") + if cache_dir is None: + _universa_models[cache_key] = UniversaInference.from_pretrained( + model_mapping[model_type] + ) + else: + try: + from espnet_model_zoo.downloader import ModelDownloader + except ImportError: + raise ImportError( + "universa requires espnet_model_zoo. Please install it and retry" + ) + model_kwargs = ModelDownloader(cachedir=cache_dir).download_and_unpack( + model_mapping[model_type] + ) + _universa_models[cache_key] = UniversaInference(**model_kwargs) + + return _universa_models[cache_key] + + +def audio_preprocess(audio_data, original_sr=None, target_sr=16000): + """ + Preprocess audio data for Universa inference. + + Args: + audio_data: numpy array or file path + original_sr: original sample rate (if audio_data is numpy array) + target_sr: target sample rate - model_name = model_mapping[model_tag] + Returns: + tuple: (audio_tensor, audio_lengths_tensor) + """ + if isinstance(audio_data, str): + # File path + audio, sr = soundfile.read(audio_data) + else: + # Numpy array + audio = audio_data + sr = original_sr or target_sr - # Load the model - model = UniversaInference.from_pretrained(model_name) + # Ensure audio is 1D + if audio.ndim > 1: + audio = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0] - # Set device - if use_gpu and torch.cuda.is_available(): - model = model.cuda() + # Resample if needed + if sr != target_sr: + audio = resample_audio(audio, sr, target_sr) + + # Convert to float32 and create tensor + audio = audio.astype(np.float32) + audio_tensor = torch.from_numpy(audio).unsqueeze(0) + audio_lengths = torch.tensor([len(audio_tensor[0])]) - return model + return audio_tensor, audio_lengths -def audio_preprocess(audio, fs, target_fs=16000): +def universa_metric_noref(audio_data, original_sr=None, cache_dir=None): """ - Preprocess audio for Universa model. 
+ Universa no-reference quality assessment. Args: - audio (np.ndarray): Input audio signal - fs (int): Original sampling rate - target_fs (int): Target sampling rate (default: 16000) + audio_data: numpy array or file path + original_sr: original sample rate (if audio_data is numpy array) Returns: - tuple: (processed_audio_tensor, audio_lengths_tensor) + dict: Universa quality metrics with float values and 'universa_' prefix """ - # Resample to target sampling rate - if fs != target_fs: - audio = librosa.resample(audio, orig_sr=fs, target_sr=target_fs) + model = get_universa_model("noref", cache_dir=cache_dir) + audio, audio_lengths = audio_preprocess(audio_data, original_sr) - # Convert to float32 - audio = audio.astype(np.float32) + with torch.no_grad(): + result = model(audio.float(), audio_lengths) - # Convert to tensor and add batch dimension - audio_tensor = torch.from_numpy(audio).unsqueeze(0) - audio_lengths = torch.tensor([len(audio)]) + # Convert to float values with universa_ prefix + formatted_result = {} + for key, value in result.items(): + if isinstance(value, (torch.Tensor, np.ndarray)): + formatted_result[f"universa_{key}"] = float( + value.item() if hasattr(value, "item") else value.flatten()[0] + ) + else: + formatted_result[f"universa_{key}"] = float(value) - return audio_tensor, audio_lengths + return formatted_result -def universa_metric(model, pred_x, fs, gt_x=None, ref_text=None): +def universa_metric_audioref( + audio_data, ref_audio_data, original_sr=None, ref_sr=None, cache_dir=None +): """ - Compute Universa metrics for audio evaluation. + Universa inference with audio reference. Args: - model: Universa model for inference - pred_x (np.ndarray): Audio signal to be evaluated - fs (int): Sampling rate of pred_x - gt_x (np.ndarray, optional): Reference audio signal - ref_text (str, optional): Reference text + audio_data: numpy array or file path (test audio) + ref_audio_data: numpy array or file path (reference audio) + original_sr: original sample rate for test audio + ref_sr: original sample rate for reference audio Returns: - dict: Dictionary containing Universa metric scores + dict: Universa quality metrics with float values and 'universa_' prefix """ - # Preprocess the prediction audio - pred_audio, pred_lengths = audio_preprocess(pred_x, fs) - - # Move to same device as model - if next(model.model.parameters()).is_cuda: - pred_audio = pred_audio.cuda() - pred_lengths = pred_lengths.cuda() - - # Prepare reference audio if provided - ref_audio = None - ref_lengths = None - if gt_x is not None: - ref_audio, ref_lengths = audio_preprocess(gt_x, fs) - if next(model.model.parameters()).is_cuda: - ref_audio = ref_audio.cuda() - ref_lengths = ref_lengths.cuda() - - # Run inference based on available references - if ref_audio is not None and ref_text is not None: - # Both audio and text reference - results = model( - pred_audio.float(), - pred_lengths, - ref_audio=ref_audio.float(), - ref_audio_lengths=ref_lengths, - ref_text=ref_text, - ) - elif ref_audio is not None: - # Only audio reference - results = model( - pred_audio.float(), - pred_lengths, + model = get_universa_model("audioref", cache_dir=cache_dir) + audio, audio_lengths = audio_preprocess(audio_data, original_sr) + ref_audio, ref_audio_lengths = audio_preprocess(ref_audio_data, ref_sr) + + with torch.no_grad(): + result = model( + audio.float(), + audio_lengths, ref_audio=ref_audio.float(), - ref_audio_lengths=ref_lengths, + ref_audio_lengths=ref_audio_lengths, ) - elif ref_text is not None: - # Only text 
reference - results = model(pred_audio.float(), pred_lengths, ref_text=ref_text) - else: - # No reference - results = model(pred_audio.float(), pred_lengths) - # Convert results to dictionary format - if isinstance(results, dict): - # If results is already a dictionary, use it directly - universa_scores = { - "universa_" + key: value[0][0] for key, value in results.items() - } - else: - # If results is a tensor or other format, convert to dictionary - # This might need adjustment based on actual Universa output format - universa_scores = {"universa_score": float(results.cpu().numpy())} + # Convert to float values with universa_ prefix + formatted_result = {} + for key, value in result.items(): + if isinstance(value, (torch.Tensor, np.ndarray)): + formatted_result[f"universa_{key}"] = float( + value.item() if hasattr(value, "item") else value.flatten()[0] + ) + else: + formatted_result[f"universa_{key}"] = float(value) - return universa_scores + return formatted_result -def universa_noref_metric(model, pred_x, fs): +def universa_metric_textref(audio_data, ref_text, original_sr=None, cache_dir=None): """ - Compute Universa metrics without reference. + Universa inference with text reference. Args: - model: Universa model for inference - pred_x (np.ndarray): Audio signal to be evaluated - fs (int): Sampling rate of pred_x + audio_data: numpy array or file path + ref_text: reference text string + original_sr: original sample rate (if audio_data is numpy array) Returns: - dict: Dictionary containing Universa metric scores + dict: Universa quality metrics with float values and 'universa_' prefix """ - return universa_metric(model, pred_x, fs) - - -def universa_audioref_metric(model, pred_x, fs, gt_x): + model = get_universa_model("textref", cache_dir=cache_dir) + audio, audio_lengths = audio_preprocess(audio_data, original_sr) + + with torch.no_grad(): + result = model(audio.float(), audio_lengths, ref_text=ref_text) + + # Convert to float values with universa_ prefix + formatted_result = {} + for key, value in result.items(): + if isinstance(value, (torch.Tensor, np.ndarray)): + formatted_result[f"universa_{key}"] = float( + value.item() if hasattr(value, "item") else value.flatten()[0] + ) + else: + formatted_result[f"universa_{key}"] = float(value) + + return formatted_result + + +def universa_metric_fullref( + audio_data, + ref_audio_data, + ref_text, + original_sr=None, + ref_sr=None, + cache_dir=None, +): """ - Compute Universa metrics with audio reference. + Universa inference with both audio and text reference. 
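+
+    Example (illustrative only; the pretrained model is fetched on first
+    use)::
+
+        scores = universa_metric_fullref(
+            pred_audio, ref_audio, "a reference transcript",
+            original_sr=16000, ref_sr=16000,
+        )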
Args: - model: Universa model for inference - pred_x (np.ndarray): Audio signal to be evaluated - fs (int): Sampling rate of pred_x - gt_x (np.ndarray): Reference audio signal + audio_data: numpy array or file path (test audio) + ref_audio_data: numpy array or file path (reference audio) + ref_text: reference text string + original_sr: original sample rate for test audio + ref_sr: original sample rate for reference audio Returns: - dict: Dictionary containing Universa metric scores + dict: Universa quality metrics with float values and 'universa_' prefix """ - return universa_metric(model, pred_x, fs, gt_x=gt_x) - + model = get_universa_model("fullref", cache_dir=cache_dir) + audio, audio_lengths = audio_preprocess(audio_data, original_sr) + ref_audio, ref_audio_lengths = audio_preprocess(ref_audio_data, ref_sr) + + with torch.no_grad(): + result = model( + audio.float(), + audio_lengths, + ref_audio=ref_audio.float(), + ref_audio_lengths=ref_audio_lengths, + ref_text=ref_text, + ) -def universa_textref_metric(model, pred_x, fs, ref_text): + # Convert to float values with universa_ prefix + formatted_result = {} + for key, value in result.items(): + if isinstance(value, (torch.Tensor, np.ndarray)): + formatted_result[f"universa_{key}"] = float( + value.item() if hasattr(value, "item") else value.flatten()[0] + ) + else: + formatted_result[f"universa_{key}"] = float(value) + + return formatted_result + + +def universa_metric( + audio_data, + ref_audio=None, + ref_text=None, + original_sr=16000, + ref_sr=None, + cache_dir=None, +): """ - Compute Universa metrics with text reference. + Universal Universa metric function that automatically selects the appropriate model + based on available references. Args: - model: Universa model for inference - pred_x (np.ndarray): Audio signal to be evaluated - fs (int): Sampling rate of pred_x - ref_text (str): Reference text + audio_data: numpy array or file path (test audio) + ref_audio: numpy array or file path (reference audio, optional) + ref_text: reference text string (optional) + original_sr: original sample rate for test audio + ref_sr: original sample rate for reference audio Returns: - dict: Dictionary containing Universa metric scores + dict: Universa quality metrics """ - return universa_metric(model, pred_x, fs, ref_text=ref_text) + if ref_audio is not None and ref_text is not None: + # Full reference (both audio and text) + return universa_metric_fullref( + audio_data, ref_audio, ref_text, original_sr, ref_sr, cache_dir=cache_dir + ) + elif ref_audio is not None: + # Audio reference only + return universa_metric_audioref( + audio_data, ref_audio, original_sr, ref_sr, cache_dir=cache_dir + ) + elif ref_text is not None: + # Text reference only + return universa_metric_textref( + audio_data, ref_text, original_sr, cache_dir=cache_dir + ) + else: + # No reference + return universa_metric_noref(audio_data, original_sr, cache_dir=cache_dir) -def universa_fullref_metric(model, pred_x, fs, gt_x, ref_text): - """ - Compute Universa metrics with both audio and text reference. 
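+
+# Illustrative summary (comment only, not an API addition): universa_metric()
+# above dispatches to a pretrained variant based on which references are
+# supplied, e.g.
+#
+#   universa_metric(pred)                                  -> "noref" model
+#   universa_metric(pred, ref_audio=ref)                   -> "audioref" model
+#   universa_metric(pred, ref_text="hello")                -> "textref" model
+#   universa_metric(pred, ref_audio=ref, ref_text="hi")   -> "fullref" model
+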
+class UniversaMetric(BaseMetric): + """Uni-VERSA speech assessment metric.""" - Args: - model: Universa model for inference - pred_x (np.ndarray): Audio signal to be evaluated - fs (int): Sampling rate of pred_x - gt_x (np.ndarray): Reference audio signal - ref_text (str): Reference text + def _setup(self): + self.model_type = self.config.get( + "model_type", self.config.get("model_tag", "auto") + ) + if self.model_type == "default": + self.model_type = "noref" + self.cache_dir = self.config.get("cache_dir") + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + + metadata = metadata or {} + fs = metadata.get("sample_rate", 16000) + ref_sr = metadata.get("reference_sample_rate", fs) + ref_text = metadata.get("text") + if isinstance(references, str): + ref_text = references + ref_audio = None + else: + ref_audio = references + + model_type = self.model_type + if model_type == "noref": + return universa_metric_noref(predictions, fs, cache_dir=self.cache_dir) + if model_type == "audioref": + if ref_audio is None: + raise ValueError("Audio reference must be provided") + return universa_metric_audioref( + predictions, ref_audio, fs, ref_sr, cache_dir=self.cache_dir + ) + if model_type == "textref": + if ref_text is None: + raise ValueError("Text reference must be provided") + return universa_metric_textref( + predictions, ref_text, fs, cache_dir=self.cache_dir + ) + if model_type == "fullref": + if ref_audio is None: + raise ValueError("Audio reference must be provided") + if ref_text is None: + raise ValueError("Text reference must be provided") + return universa_metric_fullref( + predictions, + ref_audio, + ref_text, + fs, + ref_sr, + cache_dir=self.cache_dir, + ) + + metric_kwargs = { + "ref_audio": ref_audio, + "ref_text": ref_text, + "original_sr": fs, + "ref_sr": ref_sr, + } + if self.cache_dir is not None: + metric_kwargs["cache_dir"] = self.cache_dir + return universa_metric(predictions, **metric_kwargs) + + def get_metadata(self): + return _universa_metadata() + + +def _universa_metadata(): + return MetricMetadata( + name="universa", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.DICT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["espnet2", "torch", "librosa", "numpy", "soundfile"], + description="Uni-VERSA speech assessment metrics", + paper_reference="https://arxiv.org/abs/2505.20741", + implementation_source=( + "https://huggingface.co/collections/espnet/" + "universa-6834e7c0a28225bffb6e2526" + ), + ) + + +def register_universa_metric(registry): + """Register Uni-VERSA with the registry.""" + registry.register( + UniversaMetric, + _universa_metadata(), + aliases=[ + "uni_versa", + "universal_speech_assessment", + "universa_noref", + "universa_audioref", + "universa_textref", + "universa_fullref", + ], + ) + + +# Debug code +if __name__ == "__main__": + # Generate test audio + test_audio = np.random.random(16000) + ref_audio = np.random.random(16000) + ref_text = "This is a test reference text" - Returns: - dict: Dictionary containing Universa metric scores - """ - return universa_metric(model, pred_x, fs, gt_x=gt_x, ref_text=ref_text) + print("=== Universa Metrics Tests ===") + # Test no-reference + try: + print("\n1. 
Testing no-reference Universa...") + noref_result = universa_metric_noref(test_audio, 16000) + print("No-ref result:", noref_result) + except Exception as e: + print(f"No-ref test failed: {e}") -if __name__ == "__main__": - # Test the implementation - print("Testing Universa metric implementation...") + # Test with audio reference + try: + print("\n2. Testing audio-reference Universa...") + audioref_result = universa_metric_audioref(test_audio, ref_audio, 16000, 16000) + print("Audio-ref result:", audioref_result) + except Exception as e: + print(f"Audio-ref test failed: {e}") - # Generate test audio - fs = 16000 - duration = 2.0 - t = np.linspace(0, duration, int(fs * duration)) - test_audio = np.sin(2 * np.pi * 440 * t) # 440 Hz sine wave + # Test with text reference + try: + print("\n3. Testing text-reference Universa...") + textref_result = universa_metric_textref(test_audio, ref_text, 16000) + print("Text-ref result:", textref_result) + except Exception as e: + print(f"Text-ref test failed: {e}") + # Test with full reference try: - # Test no-reference model - print("Testing no-reference model...") - model = universa_model_setup(model_tag="noref", use_gpu=False) - scores = universa_noref_metric(model, test_audio, fs) - print(f"No-reference scores: {scores}") - - # Test with audio reference - print("Testing audio reference model...") - model_audioref = universa_model_setup(model_tag="audioref", use_gpu=False) - scores_audioref = universa_audioref_metric( - model_audioref, test_audio, fs, test_audio + print("\n4. Testing full-reference Universa...") + fullref_result = universa_metric_fullref( + test_audio, ref_audio, ref_text, 16000, 16000 ) - print(f"Audio reference scores: {scores_audioref}") + print("Full-ref result:", fullref_result) + except Exception as e: + print(f"Full-ref test failed: {e}") + + # Test universal function + try: + print("\n5. 
Testing universal Universa function...") - # Test with text reference - print("Testing text reference model...") - model_textref = universa_model_setup(model_tag="textref", use_gpu=False) - scores_textref = universa_textref_metric( - model_textref, test_audio, fs, "test text" + # Auto-select no-ref + auto_noref = universa_metric(test_audio, original_sr=16000) + print("Auto no-ref:", auto_noref) + + # Auto-select audio-ref + auto_audioref = universa_metric( + test_audio, ref_audio=ref_audio, original_sr=16000, ref_sr=16000 ) - print(f"Text reference scores: {scores_textref}") + print("Auto audio-ref:", auto_audioref) + + # Auto-select text-ref + auto_textref = universa_metric(test_audio, ref_text=ref_text, original_sr=16000) + print("Auto text-ref:", auto_textref) - print("All tests completed successfully!") + # Auto-select full-ref + auto_fullref = universa_metric( + test_audio, + ref_audio=ref_audio, + ref_text=ref_text, + original_sr=16000, + ref_sr=16000, + ) + print("Auto full-ref:", auto_fullref) except Exception as e: - print(f"Test failed: {e}") - print("This is expected if espnet is not installed.") + print(f"Universal function test failed: {e}") diff --git a/versa/utterance_metrics/vad.py b/versa/utterance_metrics/vad.py index 1902138..864ac12 100644 --- a/versa/utterance_metrics/vad.py +++ b/versa/utterance_metrics/vad.py @@ -1,65 +1,134 @@ -#!/usr/bin/env python3 - -# Copyright 2024 Jiatong Shi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -import os - -import librosa -import numpy as np -import torch - - -def vad_model_setup( - threshold=0.5, - min_speech_duration_ms=250, - max_speech_duration_s=float("inf"), - min_silence_duration_ms=100, - speech_pad_ms=30, -): - - model, utils = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad") - (get_speech_ts, _, _, _, *_) = utils - return { - "module": model, - "util": get_speech_ts, - "threshold": threshold, - "min_speech_duration_ms": min_speech_duration_ms, - "max_speech_duration_s": max_speech_duration_s, - "min_silence_duration_ms": min_silence_duration_ms, - "speech_pad_ms": speech_pad_ms, - } - - -def vad_metric(model_info, pred_x, fs): - model = model_info["module"] - get_speech_ts = model_info["util"] - # NOTE(jiatong): only work for 16000 Hz - if fs > 16000: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000) - fs = 16000 - elif fs < 16000: - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=8000) - fs = 8000 - - speech_timestamps = get_speech_ts( - pred_x, - model, - sampling_rate=fs, - return_seconds=True, - threshold=model_info["threshold"], - min_speech_duration_ms=model_info["min_speech_duration_ms"], - max_speech_duration_s=model_info["max_speech_duration_s"], - min_silence_duration_ms=model_info["min_silence_duration_ms"], - speech_pad_ms=model_info["speech_pad_ms"], - ) - return {"vad_info": speech_timestamps} - - -if __name__ == "__main__": - torch.hub.download_url_to_file( - "https://models.silero.ai/vad_models/en.wav", "en_example.wav" - ) - a, fs = librosa.load("en_example.wav", sr=None) - model_info = vad_model_setup() - print("metrics: {}".format(vad_metric(model_info, a, 16000))) +#!/usr/bin/env python3 + +# Copyright 2024 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import librosa +import numpy as np +import torch + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + + +def vad_model_setup( + threshold=0.5, + min_speech_duration_ms=250, + 
max_speech_duration_s=float("inf"), + min_silence_duration_ms=100, + speech_pad_ms=30, + trust_repo=True, + force_reload=False, +): + + hub_kwargs = { + "repo_or_dir": "snakers4/silero-vad", + "model": "silero_vad", + "force_reload": force_reload, + } + if trust_repo is not None: + hub_kwargs["trust_repo"] = trust_repo + model, utils = torch.hub.load(**hub_kwargs) + get_speech_ts, _, _, _, *_ = utils + return { + "module": model, + "util": get_speech_ts, + "threshold": threshold, + "min_speech_duration_ms": min_speech_duration_ms, + "max_speech_duration_s": max_speech_duration_s, + "min_silence_duration_ms": min_silence_duration_ms, + "speech_pad_ms": speech_pad_ms, + } + + +def vad_metric(model_info, pred_x, fs): + model = model_info["module"] + get_speech_ts = model_info["util"] + # NOTE(jiatong): only work for 16000 Hz + if fs > 16000: + pred_x = resample_audio(pred_x, fs, 16000) + fs = 16000 + elif fs < 16000: + pred_x = resample_audio(pred_x, fs, 8000) + fs = 8000 + + speech_timestamps = get_speech_ts( + pred_x, + model, + sampling_rate=fs, + return_seconds=True, + threshold=model_info["threshold"], + min_speech_duration_ms=model_info["min_speech_duration_ms"], + max_speech_duration_s=model_info["max_speech_duration_s"], + min_silence_duration_ms=model_info["min_silence_duration_ms"], + speech_pad_ms=model_info["speech_pad_ms"], + ) + return {"vad_info": speech_timestamps} + + +class VadMetric(BaseMetric): + """Voice activity detection using Silero VAD.""" + + def _setup(self): + self.threshold = self.config.get("threshold", 0.5) + self.min_speech_duration_ms = self.config.get("min_speech_duration_ms", 250) + self.max_speech_duration_s = self.config.get( + "max_speech_duration_s", float("inf") + ) + self.min_silence_duration_ms = self.config.get("min_silence_duration_ms", 100) + self.speech_pad_ms = self.config.get("speech_pad_ms", 30) + self.trust_repo = self.config.get("trust_repo", True) + self.force_reload = self.config.get("force_reload", False) + self.model_info = vad_model_setup( + threshold=self.threshold, + min_speech_duration_ms=self.min_speech_duration_ms, + max_speech_duration_s=self.max_speech_duration_s, + min_silence_duration_ms=self.min_silence_duration_ms, + speech_pad_ms=self.speech_pad_ms, + trust_repo=self.trust_repo, + force_reload=self.force_reload, + ) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return vad_metric(self.model_info, np.asarray(predictions), fs) + + def get_metadata(self): + return _vad_metadata() + + +def _vad_metadata(): + return MetricMetadata( + name="vad", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.DICT, + requires_reference=False, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["torch", "librosa", "numpy"], + description="Voice activity detection timestamps from Silero VAD", + paper_reference="https://arxiv.org/abs/2111.14467", + implementation_source="https://github.com/snakers4/silero-vad", + ) + + +def register_vad_metric(registry): + """Register VAD with the registry.""" + registry.register( + VadMetric, + _vad_metadata(), + aliases=["vad_metric", "silero_vad"], + ) + + +if __name__ == "__main__": + torch.hub.download_url_to_file( + "https://models.silero.ai/vad_models/en.wav", "en_example.wav" + ) + a, fs = librosa.load("en_example.wav", sr=None) + metric = VadMetric() + print("metrics: 
{}".format(metric.compute(a, metadata={"sample_rate": fs}))) diff --git a/versa/utterance_metrics/visqol_score.py b/versa/utterance_metrics/visqol_score.py index 6e6148d..e93c2f5 100644 --- a/versa/utterance_metrics/visqol_score.py +++ b/versa/utterance_metrics/visqol_score.py @@ -5,16 +5,27 @@ import os -import librosa import numpy as np -import visqol -from visqol import visqol_lib_py -from visqol.pb2 import similarity_result_pb2, visqol_config_pb2 + +from versa.audio_utils import resample_audio +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType + +try: + from visqol import visqol_lib_py + from visqol.pb2 import visqol_config_pb2 +except ImportError: + visqol_lib_py = None + visqol_config_pb2 = None def visqol_setup(model): # model name related to # https://github.com/google/visqol/tree/master/model + if visqol_lib_py is None or visqol_config_pb2 is None: + raise ImportError( + "visqol is not installed. Please install visqol following " + "https://github.com/google/visqol and retry" + ) config = visqol_config_pb2.VisqolConfig() config.audio.sample_rate = 48000 @@ -33,7 +44,8 @@ def visqol_setup(model): config.audio.sample_rate = 16000 else: raise NotImplementedError( - "Not a valid tag for model, check https://github.com/google/visqol/tree/master/model for details" + "Not a valid tag for model, check " + "https://github.com/google/visqol/tree/master/model for details" ) config.options.svr_model_path = os.path.join( @@ -49,16 +61,67 @@ def visqol_setup(model): def visqol_metric(api, api_fs, pred_x, gt_x, fs): if api_fs != fs: - gt_x = librosa.resample(gt_x, orig_sr=fs, target_sr=api_fs) - pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=api_fs) + gt_x = resample_audio(gt_x, fs, api_fs) + pred_x = resample_audio(pred_x, fs, api_fs) similarity_result = api.Measure(gt_x, pred_x) return {"visqol": similarity_result.moslqo} +class VisqolMetric(BaseMetric): + """Virtual Speech Quality Objective Listener metric.""" + + def _setup(self): + self.model = self.config.get("model", "default") + self.api, self.api_fs = visqol_setup(self.model) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + if references is None: + raise ValueError("Reference signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return visqol_metric( + self.api, + self.api_fs, + np.asarray(predictions), + np.asarray(references), + fs, + ) + + def get_metadata(self): + return _visqol_metadata() + + +def _visqol_metadata(): + return MetricMetadata( + name="visqol", + category=MetricCategory.DEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=True, + requires_text=False, + gpu_compatible=False, + auto_install=False, + dependencies=["visqol", "librosa", "numpy"], + description="Virtual Speech Quality Objective Listener MOS-LQO metric", + paper_reference="https://arxiv.org/abs/2004.09584", + implementation_source="https://github.com/google/visqol", + ) + + +def register_visqol_metric(registry): + """Register VISQOL with the registry.""" + registry.register( + VisqolMetric, + _visqol_metadata(), + aliases=["visqol_metric", "VISQOL"], + ) + + if __name__ == "__main__": a = np.random.random(int(16000 * 1)) b = np.random.random(int(16000 * 1)) - predictor, fs = visqol_setup("default") - print(visqol_metric(predictor, fs, a, b, 16000)) + metric = VisqolMetric() + print(metric.compute(a, b, metadata={"sample_rate": 16000})) diff --git 
diff --git a/versa/utterance_metrics/vqscore.py b/versa/utterance_metrics/vqscore.py
index ed71b6c..338dcb1 100644
--- a/versa/utterance_metrics/vqscore.py
+++ b/versa/utterance_metrics/vqscore.py
@@ -1,133 +1,184 @@
-#!/usr/bin/env python3
-
-# Copyright 2025 Wangyou Zhang
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-import logging
-from pathlib import Path
-import sys
-import yaml
-
-logger = logging.getLogger(__name__)
-
-import librosa
-import numpy as np
-import torch
-
-
-vqscore_dir = str(Path(__file__).parent / "VQscore")
-sys.path.append(vqscore_dir)
-try:
-    from models.VQVAE_models import VQVAE_QE
-except ImportError:
-    logger.info(
-        "After cloning this repository, please run the following command to"
-        "initialize the submodule 'VQscore':"
-        "```bash"
-        "git submodule update --init --recursive"
-        "```"
-    )
-    VQVAE_QE = None
-
-
-def vqscore_setup(use_gpu=False):
-    if use_gpu:
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    if VQVAE_QE is None:
-        raise ModuleNotFoundError(
-            "After cloning this repository, please run the following command to"
-            "initialize the submodule 'VQscore':"
-            "```bash"
-            "git submodule update --init --recursive"
-            "```"
-        )
-
-    vqscore_conf = str(
-        Path(vqscore_dir)
-        / "config/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github.yaml"
-    )
-    vqscore_model = str(
-        Path(vqscore_dir)
-        / "exp/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github/checkpoint-dnsmos_ovr_CC=0.835.pkl"
-    )
-
-    with open(vqscore_conf, "r") as f:
-        config = yaml.load(f, Loader=yaml.FullLoader)
-    if device.startswith("cuda"):
-        torch.backends.cudnn.benchmark = True
-
-    model = VQVAE_QE(**config["VQVAE_params"]).to(device=device).eval()
-    model.load_state_dict(
-        torch.load(vqscore_model, map_location=device)["model"]["VQVAE"]
-    )
-    model.input_transform = config["input_transform"]
-    model.device = device
-    return model
-
-
-# ported from VQscore/inference.py
-def stft_magnitude(x, hop_size, fft_size=512, win_length=512):
-    if x.is_cuda:
-        x_stft = torch.stft(
-            x,
-            fft_size,
-            hop_size,
-            win_length,
-            window=torch.hann_window(win_length).to("cuda"),
-            return_complex=False,
-        )
-    else:
-        x_stft = torch.stft(
-            x,
-            fft_size,
-            hop_size,
-            win_length,
-            window=torch.hann_window(win_length),
-            return_complex=False,
-        )
-    real = x_stft[..., 0]
-    imag = x_stft[..., 1]
-
-    return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
-
-
-def cos_similarity(SP_noisy, SP_y_noisy, eps=1e-5):
-    SP_noisy_norm = torch.norm(SP_noisy, p=2, dim=-1, keepdim=True) + eps
-    SP_y_noisy_norm = torch.norm(SP_y_noisy, p=2, dim=-1, keepdim=True) + eps
-    Cos_frame = torch.sum(
-        SP_noisy / SP_noisy_norm * SP_y_noisy / SP_y_noisy_norm, dim=-1
-    )  # torch.Size([B, T, 1]
-
-    return torch.mean(Cos_frame)
-
-
-def vqscore_metric(model, pred_x, fs):
-    # NOTE(wangyou): current model only have 16k options
-    if fs != 16000:
-        pred_x = librosa.resample(pred_x, orig_sr=fs, target_sr=16000)
-
-    with torch.no_grad():
-        audio = torch.as_tensor(
-            pred_x, dtype=torch.float32, device=model.device
-        ).unsqueeze(0)
-        SP_input = stft_magnitude(audio, hop_size=256)
-        if model.input_transform == "log1p":
-            SP_input = torch.log1p(SP_input)
-        z = model.CNN_1D_encoder(SP_input)
-        zq, indices, vqloss, distance = model.quantizer(
-            z, stochastic=False, update=False
-        )
-        VQScore_cos_z = cos_similarity(z.transpose(2, 1).cpu(), zq.cpu()).numpy()
-
-    return {"vqscore": float(VQScore_cos_z)}
-
-
-if __name__ == "__main__":
-    a = np.random.random(16000)
-    qe_model = vqscore_setup(use_gpu=False)
-    fs = 16000
-    metric_qe = vqscore_metric(qe_model, a, fs)
-    print(metric_qe)
+#!/usr/bin/env python3
+
+# Copyright 2025 Wangyou Zhang
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import logging
+from pathlib import Path
+import sys
+import yaml
+
+import numpy as np
+import torch
+
+from versa.audio_utils import resample_audio
+from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType
+
+logger = logging.getLogger(__name__)
+
+
+vqscore_dir = str(Path(__file__).parent / "VQscore")
+sys.path.append(vqscore_dir)
+try:
+    from models.VQVAE_models import VQVAE_QE
+except ImportError:
+    logger.info(
+        "After cloning this repository, please run the following command to "
+        "initialize the submodule 'VQscore':\n"
+        "```bash\n"
+        "git submodule update --init --recursive\n"
+        "```"
+    )
+    VQVAE_QE = None
+
+
+def vqscore_setup(use_gpu=False):
+    if use_gpu:
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    if VQVAE_QE is None:
+        raise ModuleNotFoundError(
+            "After cloning this repository, please run the following command to "
+            "initialize the submodule 'VQscore':\n"
+            "```bash\n"
+            "./tools/install_vqscore.sh\n"
+            "```"
+        )
+
+    vqscore_conf = str(
+        Path(vqscore_dir)
+        / (
+            "config/"
+            "QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github.yaml"
+        )
+    )
+    vqscore_model = str(
+        Path(vqscore_dir)
+        / (
+            "exp/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github/"
+            "checkpoint-dnsmos_ovr_CC=0.835.pkl"
+        )
+    )
+
+    with open(vqscore_conf, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+    if device.startswith("cuda"):
+        torch.backends.cudnn.benchmark = True
+
+    model = VQVAE_QE(**config["VQVAE_params"]).to(device=device).eval()
+    model.load_state_dict(
+        torch.load(vqscore_model, map_location=device)["model"]["VQVAE"]
+    )
+    model.input_transform = config["input_transform"]
+    model.device = device
+    return model
+
+
+# ported from VQscore/inference.py
+def stft_magnitude(x, hop_size, fft_size=512, win_length=512):
+    if x.is_cuda:
+        x_stft = torch.stft(
+            x,
+            fft_size,
+            hop_size,
+            win_length,
+            window=torch.hann_window(win_length).to("cuda"),
+            return_complex=False,
+        )
+    else:
+        x_stft = torch.stft(
+            x,
+            fft_size,
+            hop_size,
+            win_length,
+            window=torch.hann_window(win_length),
+            return_complex=False,
+        )
+    real = x_stft[..., 0]
+    imag = x_stft[..., 1]
+
+    return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
+
+
+def cos_similarity(SP_noisy, SP_y_noisy, eps=1e-5):
+    SP_noisy_norm = torch.norm(SP_noisy, p=2, dim=-1, keepdim=True) + eps
+    SP_y_noisy_norm = torch.norm(SP_y_noisy, p=2, dim=-1, keepdim=True) + eps
+    Cos_frame = torch.sum(
+        SP_noisy / SP_noisy_norm * SP_y_noisy / SP_y_noisy_norm, dim=-1
+    )  # torch.Size([B, T, 1])
+
+    return torch.mean(Cos_frame)
+
+
+def vqscore_metric(model, pred_x, fs):
+    # NOTE(wangyou): current model only have 16k options
+    if fs != 16000:
+        pred_x = resample_audio(pred_x, fs, 16000)
+
+    with torch.no_grad():
+        audio = torch.as_tensor(
+            pred_x, dtype=torch.float32, device=model.device
+        ).unsqueeze(0)
+        SP_input = stft_magnitude(audio, hop_size=256)
+        if model.input_transform == "log1p":
+            SP_input = torch.log1p(SP_input)
+        z = model.CNN_1D_encoder(SP_input)
+        zq, indices, vqloss, distance = model.quantizer(
+            z, stochastic=False, update=False
+        )
+        VQScore_cos_z = cos_similarity(z.transpose(2, 1).cpu(), zq.cpu()).numpy()
+
+    return {"vqscore": float(VQScore_cos_z)}
+
+
+class VqscoreMetric(BaseMetric):
+    """VQScore speech quality assessment metric."""
+
+    def _setup(self):
+        self.use_gpu = self.config.get("use_gpu", False)
+        self.model = vqscore_setup(use_gpu=self.use_gpu)
+
+    def compute(self, predictions, references=None, metadata=None):
+        if predictions is None:
+            raise ValueError("Predicted signal must be provided")
+
+        fs = metadata.get("sample_rate", 16000) if metadata else 16000
+        return vqscore_metric(self.model, np.asarray(predictions), fs)
+
+    def get_metadata(self):
+        return _vqscore_metadata()
+
+
+def _vqscore_metadata():
+    return MetricMetadata(
+        name="vqscore",
+        category=MetricCategory.INDEPENDENT,
+        metric_type=MetricType.FLOAT,
+        requires_reference=False,
+        requires_text=False,
+        gpu_compatible=True,
+        auto_install=False,
+        dependencies=["torch", "librosa", "numpy", "yaml"],
+        description=(
+            "VQScore speech quality assessment from encoded and quantized features"
+        ),
+        paper_reference="https://arxiv.org/abs/2402.16321",
+        implementation_source="https://github.com/JasonSWFu/VQscore",
+    )
+
+
+def register_vqscore_metric(registry):
+    """Register VQScore with the registry."""
+    registry.register(
+        VqscoreMetric,
+        _vqscore_metadata(),
+        aliases=["vq_score", "vqscore_metric"],
+    )
+
+
+if __name__ == "__main__":
+    a = np.random.random(16000)
+    metric = VqscoreMetric()
+    print(metric.compute(a, metadata={"sample_rate": 16000}))
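Since `_setup` reads `use_gpu` from the metric config, CPU and GPU instances differ only in configuration. A minimal sketch under stated assumptions: the `BaseMetric` constructor is taken to accept the config dict positionally (the spelling below is hypothetical), and the VQscore checkpoint under `exp/...` has already been fetched, for example via the installer script named in the error message.

```python
import numpy as np

from versa.utterance_metrics.vqscore import VqscoreMetric

# CPU instance: _setup reads use_gpu from config, defaulting to False.
cpu_metric = VqscoreMetric()

# GPU instance (hypothetical config spelling, mirroring
# self.config.get("use_gpu", False) in _setup): vqscore_setup moves the
# VQVAE to CUDA and enables cudnn.benchmark.
gpu_metric = VqscoreMetric({"use_gpu": True})

a = np.random.random(16000)  # one second of noise at 16 kHz
print(cpu_metric.compute(a, metadata={"sample_rate": 16000}))
# Result: {"vqscore": <frame-averaged cosine similarity between z and zq>}
```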
metric.""" + + def _setup(self): + self.use_gpu = self.config.get("use_gpu", False) + self.model = vqscore_setup(use_gpu=self.use_gpu) + + def compute(self, predictions, references=None, metadata=None): + if predictions is None: + raise ValueError("Predicted signal must be provided") + + fs = metadata.get("sample_rate", 16000) if metadata else 16000 + return vqscore_metric(self.model, np.asarray(predictions), fs) + + def get_metadata(self): + return _vqscore_metadata() + + +def _vqscore_metadata(): + return MetricMetadata( + name="vqscore", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["torch", "librosa", "numpy", "yaml"], + description=( + "VQScore speech quality assessment from encoded and quantized features" + ), + paper_reference="https://arxiv.org/abs/2402.16321", + implementation_source="https://github.com/JasonSWFu/VQscore", + ) + + +def register_vqscore_metric(registry): + """Register VQScore with the registry.""" + registry.register( + VqscoreMetric, + _vqscore_metadata(), + aliases=["vq_score", "vqscore_metric"], + ) + + +if __name__ == "__main__": + a = np.random.random(16000) + metric = VqscoreMetric() + print(metric.compute(a, metadata={"sample_rate": 16000})) diff --git a/versa/utterance_metrics/wvmos.py b/versa/utterance_metrics/wvmos.py index ccbfb58..1fb7601 100644 --- a/versa/utterance_metrics/wvmos.py +++ b/versa/utterance_metrics/wvmos.py @@ -7,23 +7,17 @@ logger = logging.getLogger(__name__) import librosa -import numpy as np import torch - -try: - from wvmos import get_wvmos -except ImportError: - logger.info( - "WVMOS is not installed. Please use `tools/install_wvmos.sh` to install" - ) - get_wvmos = None +from versa.definition import BaseMetric, MetricCategory, MetricMetadata, MetricType def wvmos_setup(use_gpu=False): - if get_wvmos is None: + try: + from wvmos import get_wvmos + except ImportError as e: raise ModuleNotFoundError( "WVMOS is not installed. Please use `tools/install_wvmos.sh` to install" - ) + ) from e model = get_wvmos(cuda=use_gpu) @@ -51,3 +45,44 @@ def wvmos_calculate(model, pred_x, gen_sr): x = x.cuda() res = model.forward(x).mean() return {"wvmos": res.cpu().item()} + + +class WvmosMetric(BaseMetric): + """WV-MOS metric using a fine-tuned wav2vec2 model.""" + + def _setup(self): + self.use_gpu = self.config.get("use_gpu", False) + self.model = wvmos_setup(use_gpu=self.use_gpu) + + def compute(self, predictions, references=None, metadata=None): + metadata = metadata or {} + sample_rate = metadata.get("sample_rate", 16000) + return wvmos_calculate(self.model, predictions, sample_rate) + + def get_metadata(self): + return _wvmos_metadata() + + +def _wvmos_metadata(): + return MetricMetadata( + name="wvmos", + category=MetricCategory.INDEPENDENT, + metric_type=MetricType.FLOAT, + requires_reference=False, + requires_text=False, + gpu_compatible=True, + auto_install=False, + dependencies=["librosa", "torch", "transformers"], + description="WV-MOS score prediction using a fine-tuned wav2vec2 model", + paper_reference="https://arxiv.org/abs/2203.13086", + implementation_source="https://github.com/AndreevP/wvmos", + ) + + +def register_wvmos_metric(registry): + """Register WV-MOS with the metric registry.""" + registry.register( + WvmosMetric, + _wvmos_metadata(), + aliases=["Wvmos", "wvmos", "wv_mos"], + )