
Commit 0de15c9

Fix Audio Classification Pipeline top_k Documentation Mismatch and Bug #35736 (#35771)
* added condition for top_k doc mismatch fix
* initialization of test file for top_k changes
* added test for returning all labels
* added test for few labels
* tests/test_audio_classification_top_k.py
* final fix
* ruff fix

Co-authored-by: sambhavnoobcoder <[email protected]>
1 parent 694aaa7 commit 0de15c9

File tree

2 files changed: +71, -4 lines


src/transformers/pipelines/audio_classification.py

+11, -4
```diff
@@ -91,8 +91,11 @@ class AudioClassificationPipeline(Pipeline):
     """

     def __init__(self, *args, **kwargs):
-        # Default, might be overriden by the model.config.
-        kwargs["top_k"] = kwargs.get("top_k", 5)
+        # Only set default top_k if explicitly provided
+        if "top_k" in kwargs and kwargs["top_k"] is None:
+            kwargs["top_k"] = None
+        elif "top_k" not in kwargs:
+            kwargs["top_k"] = 5
         super().__init__(*args, **kwargs)

         if self.framework != "pt":
```
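The `__init__` hunk keeps an explicit `top_k=None` alive instead of overwriting it with the old default of 5, which only applies when the caller did not pass `top_k` at all. The defaulting rule in isolation looks like this (a standalone sketch with a hypothetical helper name, not code from this commit):

```python
def resolve_top_k(**kwargs):
    # Mirrors the __init__ logic above: an explicit None survives,
    # and the historical default of 5 applies only when top_k
    # was not passed at all.
    if "top_k" in kwargs and kwargs["top_k"] is None:
        kwargs["top_k"] = None  # explicit None is preserved
    elif "top_k" not in kwargs:
        kwargs["top_k"] = 5  # historical default
    return kwargs["top_k"]

assert resolve_top_k(top_k=None) is None
assert resolve_top_k() == 5
assert resolve_top_k(top_k=3) == 3
```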
```diff
@@ -141,12 +144,16 @@ def __call__(
         return super().__call__(inputs, **kwargs)

     def _sanitize_parameters(self, top_k=None, function_to_apply=None, **kwargs):
-        # No parameters on this pipeline right now
         postprocess_params = {}
-        if top_k is not None:
+
+        # If top_k is None, use all labels
+        if top_k is None:
+            postprocess_params["top_k"] = self.model.config.num_labels
+        else:
             if top_k > self.model.config.num_labels:
                 top_k = self.model.config.num_labels
             postprocess_params["top_k"] = top_k
+
         if function_to_apply is not None:
             if function_to_apply not in ["softmax", "sigmoid", "none"]:
                 raise ValueError(
```
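Taken together, the two hunks make `top_k=None` mean "return every label" rather than silently falling back to the default of 5. A minimal sketch of the new behavior, reusing the checkpoint and dummy waveform from the tests in this commit (running it needs a PyTorch install and network access to download the model):

```python
import numpy as np

from transformers import pipeline

# Checkpoint borrowed from the tests below; any audio classification
# model with more than 5 labels would demonstrate the same thing.
clf = pipeline(
    "audio-classification",
    model="superb/wav2vec2-base-superb-ks",
    top_k=None,  # previously overwritten by the default of 5
)

# One second of silence at 16 kHz as a dummy waveform.
signal = np.zeros((16000,), dtype=np.float32)

results = clf(signal)
# With this fix, the output covers every label in the model config
# instead of being truncated to 5.
assert len(results) == clf.model.config.num_labels
```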
tests/test_audio_classification_top_k.py

+60, -0
```diff
@@ -0,0 +1,60 @@
+import unittest
+
+import numpy as np
+
+from transformers import pipeline
+from transformers.testing_utils import require_torch
+
+
+@require_torch
+class AudioClassificationTopKTest(unittest.TestCase):
+    def test_top_k_none_returns_all_labels(self):
+        model_name = "superb/wav2vec2-base-superb-ks"  # model with more than 5 labels
+        classification_pipeline = pipeline(
+            "audio-classification",
+            model=model_name,
+            top_k=None,
+        )
+
+        # Create dummy input
+        sampling_rate = 16000
+        signal = np.zeros((sampling_rate,), dtype=np.float32)
+
+        result = classification_pipeline(signal)
+        num_labels = classification_pipeline.model.config.num_labels
+
+        self.assertEqual(len(result), num_labels, "Should return all labels when top_k is None")
+
+    def test_top_k_none_with_few_labels(self):
+        model_name = "superb/hubert-base-superb-er"  # model with fewer labels
+        classification_pipeline = pipeline(
+            "audio-classification",
+            model=model_name,
+            top_k=None,
+        )
+
+        # Create dummy input
+        sampling_rate = 16000
+        signal = np.zeros((sampling_rate,), dtype=np.float32)
+
+        result = classification_pipeline(signal)
+        num_labels = classification_pipeline.model.config.num_labels
+
+        self.assertEqual(len(result), num_labels, "Should handle models with fewer labels correctly")
+
+    def test_top_k_greater_than_labels(self):
+        model_name = "superb/hubert-base-superb-er"
+        classification_pipeline = pipeline(
+            "audio-classification",
+            model=model_name,
+            top_k=100,  # intentionally large number
+        )
+
+        # Create dummy input
+        sampling_rate = 16000
+        signal = np.zeros((sampling_rate,), dtype=np.float32)
+
+        result = classification_pipeline(signal)
+        num_labels = classification_pipeline.model.config.num_labels
+
+        self.assertEqual(len(result), num_labels, "Should cap top_k to number of labels")
```
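The third test exercises the capping path in `_sanitize_parameters`: asking for more labels than the model has should be clamped to `num_labels` rather than raising. A quick sanity check outside the test harness (a sketch assuming the same `superb/hubert-base-superb-er` checkpoint as above; each returned element should be a dict with `label` and `score` keys):

```python
import numpy as np

from transformers import pipeline

# top_k far larger than the label count is capped, not an error.
clf = pipeline(
    "audio-classification",
    model="superb/hubert-base-superb-er",
    top_k=100,
)

signal = np.zeros((16000,), dtype=np.float32)
results = clf(signal)

assert len(results) == clf.model.config.num_labels
print(results[0])  # a dict like {"score": ..., "label": ...}
```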
