-
Notifications
You must be signed in to change notification settings - Fork 2
added regression tests for official models - fix for #164 #166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nkundiushuti
wants to merge
2
commits into
main
Choose a base branch
from
marius/regression-tests-checkpoints-load
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
193 changes: 193 additions & 0 deletions
193
scripts/regenerate_official_model_output_fingerprints.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,193 @@ | ||
| """Regenerate expected output fingerprints for official ESP HF models. | ||
|
|
||
| This utility builds the same deterministic labeled mini-batch used by | ||
| `tests/unittests/test_official_models_output_regression.py`, runs all official | ||
| HF-backed models in feature mode, and prints a Python dictionary literal with | ||
| updated SHA-256 fingerprints. | ||
|
|
||
| Usage: | ||
| uv run python scripts/regenerate_official_model_output_fingerprints.py | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| import hashlib | ||
| import json | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from avex import load_model | ||
| from avex.models.utils.registry import get_checkpoint_path, list_models | ||
|
|
||
| _HF_PREFIX = "hf://" | ||
|
|
||
|
|
||
| def _official_hf_model_names() -> list[str]: | ||
| """Return official ESP model names with HF-backed checkpoints. | ||
|
|
||
| Returns: | ||
| Sorted official model names. | ||
| """ | ||
| names: list[str] = [] | ||
| for model_name in list_models().keys(): | ||
| if not model_name.startswith("esp_"): | ||
| continue | ||
| checkpoint_path = get_checkpoint_path(model_name) | ||
| if checkpoint_path is not None and checkpoint_path.startswith(_HF_PREFIX): | ||
| names.append(model_name) | ||
| return sorted(names) | ||
|
|
||
|
|
||
| def _build_labeled_audio_batch(seed: int) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Build deterministic labeled mini-batch with three synthetic classes. | ||
|
|
||
| Args: | ||
| seed: Torch seed used for deterministic setup. | ||
|
|
||
| Returns: | ||
| Tuple of `(audio, labels)` with shapes `(6, 16000)` and `(6,)`. | ||
| """ | ||
| torch.manual_seed(seed) | ||
| sample_rate = 16_000 | ||
| duration_seconds = 1 | ||
| t = torch.linspace(0.0, float(duration_seconds), steps=sample_rate, dtype=torch.float32) | ||
| freqs = (220.0, 440.0, 880.0) | ||
|
|
||
| clips: list[torch.Tensor] = [] | ||
| labels: list[int] = [] | ||
| for class_index, freq in enumerate(freqs): | ||
| base = torch.sin(2.0 * torch.pi * freq * t) | ||
| for amplitude in (0.8, 0.9): | ||
| clips.append((amplitude * base).to(torch.float32)) | ||
| labels.append(class_index) | ||
|
|
||
| return torch.stack(clips, dim=0), torch.tensor(labels, dtype=torch.long) | ||
|
|
||
|
|
||
| def _pool_output(output: torch.Tensor, model_name: str) -> torch.Tensor: | ||
| """Pool model outputs to clip-level shape `(B, D)`. | ||
|
|
||
| Args: | ||
| output: Raw model output tensor. | ||
| model_name: Name used for error messages. | ||
|
|
||
| Returns: | ||
| Pooled output tensor. | ||
|
|
||
| Raises: | ||
| ValueError: If ``output`` rank is not 2, 3, or 4. | ||
| """ | ||
| if output.dim() == 2: | ||
| return output | ||
| if output.dim() == 3: | ||
| return output.mean(dim=1) | ||
| if output.dim() == 4: | ||
| return output.mean(dim=(2, 3)) | ||
| raise ValueError(f"Unsupported output rank for {model_name}: shape={tuple(output.shape)}") | ||
|
|
||
|
|
||
| def _fingerprint_output(output: torch.Tensor, decimals: int) -> str: | ||
| """Compute SHA-256 fingerprint from rounded float output. | ||
|
|
||
| Args: | ||
| output: Pooled output tensor. | ||
| decimals: Number of decimal places for rounding. | ||
|
|
||
| Returns: | ||
| Hex SHA-256 digest string. | ||
| """ | ||
| array = output.detach().cpu().to(torch.float32).numpy() | ||
| rounded = np.round(array, decimals=decimals) | ||
| return hashlib.sha256(rounded.tobytes()).hexdigest() | ||
|
|
||
|
|
||
| def _compute_fingerprints( | ||
| model_names: list[str], | ||
| audio: torch.Tensor, | ||
| decimals: int, | ||
| ) -> tuple[dict[str, str], dict[str, str]]: | ||
| """Compute per-model fingerprints and capture load/run errors. | ||
|
|
||
| Args: | ||
| model_names: Official model names to evaluate. | ||
| audio: Input audio batch. | ||
| decimals: Number of decimal places for rounding. | ||
|
|
||
| Returns: | ||
| Tuple of `(fingerprints, errors)`. | ||
| """ | ||
| fingerprints: dict[str, str] = {} | ||
| errors: dict[str, str] = {} | ||
|
|
||
| for model_name in model_names: | ||
| try: | ||
| model = load_model(model_name, device="cpu", return_features_only=True) | ||
| model.eval() | ||
| with torch.no_grad(): | ||
| output = model(audio) | ||
| pooled = _pool_output(output, model_name=model_name) | ||
| fingerprints[model_name] = _fingerprint_output(pooled, decimals=decimals) | ||
| except Exception as exc: # pragma: no cover - depends on environment/network | ||
| errors[model_name] = str(exc) | ||
| return fingerprints, errors | ||
|
|
||
|
|
||
| def parse_args() -> argparse.Namespace: | ||
| """Parse command-line arguments. | ||
|
|
||
| Returns: | ||
| Parsed arguments namespace. | ||
| """ | ||
| parser = argparse.ArgumentParser(description="Regenerate official model output fingerprints.") | ||
| parser.add_argument( | ||
| "--decimals", | ||
| type=int, | ||
| default=4, | ||
| help="Rounding decimals before hashing (default: 4).", | ||
| ) | ||
| parser.add_argument( | ||
| "--json", | ||
| action="store_true", | ||
| help="Print JSON output instead of Python dict literal.", | ||
| ) | ||
| return parser.parse_args() | ||
|
|
||
|
|
||
| def main() -> int: | ||
| """Run fingerprint generation and print output. | ||
|
|
||
| Returns: | ||
| ``0`` if every model produced a fingerprint, ``1`` if any model failed | ||
| and was recorded in the errors map. | ||
|
|
||
| Raises: | ||
| ValueError: If the labeled audio batch and labels have mismatched lengths. | ||
| """ | ||
| args = parse_args() | ||
| model_names = _official_hf_model_names() | ||
| audio, labels = _build_labeled_audio_batch(seed=7) | ||
| if labels.shape[0] != audio.shape[0]: | ||
| raise ValueError("Labeled batch mismatch between audio and labels.") | ||
|
|
||
| fingerprints, errors = _compute_fingerprints(model_names, audio=audio, decimals=args.decimals) | ||
|
|
||
| if args.json: | ||
| print(json.dumps(fingerprints, indent=2, sort_keys=True)) | ||
| else: | ||
| print("OFFICIAL_MODEL_OUTPUT_FINGERPRINTS: dict[str, str] = {") | ||
| for name in sorted(fingerprints.keys()): | ||
| print(f' "{name}": "{fingerprints[name]}",') | ||
| print("}") | ||
|
|
||
| if errors: | ||
| print("\nErrors (models skipped):") | ||
| for name in sorted(errors.keys()): | ||
| print(f"- {name}: {errors[name]}") | ||
| return 1 | ||
| return 0 | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| raise SystemExit(main()) | ||
146 changes: 146 additions & 0 deletions
146
tests/integration/test_official_models_output_regression.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| """Regression tests for official model checkpoint outputs on labeled audio. | ||
|
|
||
| This test complements checksum verification by asserting that official checkpoints | ||
| also produce stable numerical outputs on a deterministic labeled mini-batch. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import hashlib | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| import torch | ||
|
|
||
| from avex import load_model | ||
| from avex.models.utils.registry import get_checkpoint_path, list_models | ||
|
|
||
| # Expected pooled-output fingerprints from deterministic labeled mini-batch. | ||
| # Fingerprint is SHA-256 of np.round(output, 4).tobytes(). | ||
| OFFICIAL_MODEL_OUTPUT_FINGERPRINTS: dict[str, str] = { | ||
| "esp_aves2_eat_all": "d5d462c560352c1c3c9f498a0951f56ec9924e50f8fe1f0f0a4d285e316c17c8", | ||
| "esp_aves2_eat_bio": "d5d462c560352c1c3c9f498a0951f56ec9924e50f8fe1f0f0a4d285e316c17c8", | ||
| "esp_aves2_effnetb0_all": "7f1e8cc046287f79a3a2b7413042ff121a3f32c115cf3a487d2b5348e09a4931", | ||
| "esp_aves2_effnetb0_audioset": "8ba36f99b5e8245d7b61fc472339f5760fabca19d63a51e835309c11a379eab6", | ||
| "esp_aves2_effnetb0_bio": "c91dde6bee57788951a0fb9044703d301cb295e83fdc5e064874b63c99c70493", | ||
| "esp_aves2_naturelm_audio_v1_beats": "c1689532213d32cc16b0f7eb1774239c4d4bbd91a0500b551d4468acf52cb9d1", | ||
| "esp_aves2_sl_beats_all": "b6231fdcb855734ebfddf26e793a46d8e4b3bf61ee950273fdd85affcf85eefe", | ||
| "esp_aves2_sl_beats_bio": "1ad22272d36f3e74d64c5fb98ec31810c9281c1c32e9a2178f10c08004c8bcd6", | ||
| "esp_aves2_sl_eat_all_ssl_all": "0832f0c78523167e0a5439b9a4e96caf115131118549ff9161a01bd6d03a5b2e", | ||
| "esp_aves2_sl_eat_bio_ssl_all": "a9302a12a55bb6c1379b2dc42a22c15150eab12d039f7ad8c8d793a5dc31af70", | ||
| } | ||
|
|
||
| _HF_PREFIX = "hf://" | ||
|
|
||
|
|
||
| def _official_hf_model_names() -> list[str]: | ||
| """Return official ESP model names with HF-backed checkpoints. | ||
|
|
||
| Returns: | ||
| Sorted list of registry model names whose checkpoint paths start with | ||
| the Hugging Face URI prefix. | ||
| """ | ||
| names: list[str] = [] | ||
| for model_name in list_models().keys(): | ||
| if not model_name.startswith("esp_"): | ||
| continue | ||
| checkpoint_path = get_checkpoint_path(model_name) | ||
| if checkpoint_path is not None and checkpoint_path.startswith(_HF_PREFIX): | ||
| names.append(model_name) | ||
| return sorted(names) | ||
|
|
||
|
|
||
| def _build_labeled_audio_batch() -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Build deterministic labeled mini-batch with three synthetic classes. | ||
|
|
||
| Returns: | ||
| Tuple of `(audio, labels)` where audio has shape `(6, 16000)` and labels | ||
| has shape `(6,)`. | ||
| """ | ||
| torch.manual_seed(7) | ||
nkundiushuti marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| sample_rate = 16_000 | ||
| duration_seconds = 1 | ||
| t = torch.linspace(0.0, float(duration_seconds), steps=sample_rate, dtype=torch.float32) | ||
nkundiushuti marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| freqs = (220.0, 440.0, 880.0) | ||
|
|
||
| clips: list[torch.Tensor] = [] | ||
| labels: list[int] = [] | ||
| for class_index, freq in enumerate(freqs): | ||
| base = torch.sin(2.0 * torch.pi * freq * t) | ||
| for amplitude in (0.8, 0.9): | ||
| clips.append((amplitude * base).to(torch.float32)) | ||
| labels.append(class_index) | ||
|
|
||
| return torch.stack(clips, dim=0), torch.tensor(labels, dtype=torch.long) | ||
|
|
||
|
|
||
| def _pooled_model_output(model_name: str, audio: torch.Tensor) -> torch.Tensor: | ||
| """Load model and produce pooled clip-level outputs. | ||
|
|
||
| Args: | ||
| model_name: Official model registry key. | ||
| audio: Input batch shaped `(B, T)`. | ||
|
|
||
| Returns: | ||
| Tensor shaped `(B, D)` after temporal/spatial pooling when needed. | ||
|
|
||
| Raises: | ||
| ValueError: If model output tensor rank is not 2, 3, or 4. | ||
| """ | ||
| model = load_model(model_name, device="cpu", return_features_only=True) | ||
| model.eval() | ||
|
|
||
| with torch.no_grad(): | ||
| output = model(audio) | ||
|
|
||
| if output.dim() == 2: | ||
| return output | ||
| if output.dim() == 3: | ||
| return output.mean(dim=1) | ||
| if output.dim() == 4: | ||
| return output.mean(dim=(2, 3)) | ||
| raise ValueError(f"Unsupported output rank for {model_name}: shape={tuple(output.shape)}") | ||
|
|
||
|
|
||
| @pytest.mark.slow | ||
| class TestOfficialModelsOutputRegression: | ||
| """Regression tests for official model numerical output stability.""" | ||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def _ensure_registry_initialized(self) -> None: | ||
| """Ensure model registry is populated before each test.""" | ||
| from avex.models.utils import registry | ||
|
|
||
| registry.initialize_registry() | ||
|
|
||
| def test_reference_table_covers_all_official_hf_models(self) -> None: | ||
| """Ensure every official HF model has an expected output fingerprint.""" | ||
| official = set(_official_hf_model_names()) | ||
| expected = set(OFFICIAL_MODEL_OUTPUT_FINGERPRINTS.keys()) | ||
| assert expected == official, ( | ||
| "Fingerprint table mismatch. Update OFFICIAL_MODEL_OUTPUT_FINGERPRINTS to " | ||
| f"exactly match official HF models.\nExpected-only: {sorted(expected - official)}\n" | ||
| f"Official-only: {sorted(official - expected)}" | ||
| ) | ||
|
|
||
| @pytest.mark.parametrize("model_name", sorted(OFFICIAL_MODEL_OUTPUT_FINGERPRINTS.keys())) | ||
| def test_official_model_output_matches_expected_fingerprint(self, model_name: str) -> None: | ||
| """Assert model output fingerprint matches expected reference value.""" | ||
| audio, labels = _build_labeled_audio_batch() | ||
| assert labels.shape[0] == audio.shape[0], "Labeled batch must align audio and labels." | ||
nkundiushuti marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| try: | ||
| pooled = _pooled_model_output(model_name, audio) | ||
| except Exception as exc: # pragma: no cover - network/model availability | ||
nkundiushuti marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| pytest.skip(f"Unable to load/run model {model_name!r}: {exc}") | ||
|
|
||
| pooled_np = pooled.detach().cpu().to(torch.float32).numpy() | ||
| rounded = np.round(pooled_np, 4) | ||
| digest = hashlib.sha256(rounded.tobytes()).hexdigest() | ||
| expected_digest = OFFICIAL_MODEL_OUTPUT_FINGERPRINTS[model_name] | ||
|
|
||
| assert digest == expected_digest, ( | ||
| f"Output fingerprint mismatch for {model_name}. " | ||
| f"Expected {expected_digest}, got {digest}. " | ||
| "If checkpoint changed intentionally, regenerate and update reference." | ||
| ) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.