Commit f1dd038

Remove obvious comments
Signed-off-by: Vladimir Suvorov <[email protected]>
1 parent f3a0e45 commit f1dd038

File tree

3 files changed: +13 −114 lines changed

bonsai/models/whisper/tests/run_model.py

Lines changed: 6 additions & 36 deletions
@@ -34,16 +34,14 @@ def load_audio_file(audio_path: str, sample_rate: int = 16000) -> np.ndarray:
         return audio
     except ImportError:
         print("librosa not available, using dummy audio data")
-        # Generate dummy audio for testing
-        return np.random.randn(sample_rate * 3)  # 3 seconds of random audio
+        return np.random.randn(sample_rate * 3)
 
 
 def extract_mel_features_whisper(audio: np.ndarray, sample_rate: int = 16000) -> np.ndarray:
     """Extract mel spectrogram features from audio using Whisper's exact preprocessing."""
     try:
         import librosa
 
-        # Whisper's exact mel spectrogram parameters
         mel_spec = librosa.feature.melspectrogram(
             y=audio,
             sr=sample_rate,
@@ -56,21 +54,17 @@ def extract_mel_features_whisper(audio: np.ndarray, sample_rate: int = 16000) -> np.ndarray:
             power=2.0
         )
 
-        # Convert to log10 scale (Whisper's approach)
         log_spec = np.log10(mel_spec + 1e-10)
 
-        # Clip values (Whisper's approach)
         log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
 
-        # Normalize (Whisper's approach)
         log_spec = (log_spec + 4.0) / 4.0
 
-        return log_spec.T  # Transpose to (time, n_mels)
+        return log_spec.T
 
     except ImportError:
         print("librosa not available, using dummy mel features")
-        # Generate dummy mel features for testing
-        time_steps = len(audio) // 160  # Approximate time steps
+        time_steps = len(audio) // 160
         return np.random.randn(time_steps, 80)
 
 
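For readers skimming the diff, the preprocessing that survives the cleanup reads as one small function. The sketch below re-assembles it, assuming librosa is installed; n_fft=400 is Whisper's published FFT size and is an assumption here (the hunk elides the middle of the melspectrogram call), while hop_length=160 and n_mels=80 match the fallback path visible in the diff.

    import numpy as np
    import librosa

    def log_mel_whisper(audio: np.ndarray, sample_rate: int = 16000) -> np.ndarray:
        """Whisper-style log-mel features, shape (time, n_mels). Sketch only."""
        mel_spec = librosa.feature.melspectrogram(
            y=audio, sr=sample_rate,
            n_fft=400,        # assumption: Whisper's published FFT size
            hop_length=160,   # 10 ms hop at 16 kHz, matches the diff's fallback path
            n_mels=80,        # matches the dummy-feature shape in the diff
            power=2.0,
        )
        log_spec = np.log10(mel_spec + 1e-10)                  # log10 scale
        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)  # clip to an 8-decade range
        log_spec = (log_spec + 4.0) / 4.0                      # normalize
        return log_spec.T                                      # (time, n_mels)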
@@ -80,64 +74,50 @@ def run_model(MODEL_CP_PATH: Optional[str] = None, audio_path: Optional[str] = None
     if MODEL_CP_PATH is None:
         MODEL_CP_PATH = "/tmp/models-bonsai/" + model_name.split("/")[1]
 
-    # Download model if not present
     if not os.path.isdir(MODEL_CP_PATH):
         print(f"Downloading {model_name} to {MODEL_CP_PATH}...")
         snapshot_download(model_name, local_dir=MODEL_CP_PATH)
 
-    # Load audio
     if audio_path is None:
         print("No audio file provided, using test speech sample")
-        # Use the Bush speech sample
         audio_path = Path(__file__).parent / "audio_samples" / "bush_speech.wav"
         print(f"Using default audio: {audio_path}")
 
     print(f"Loading audio from {audio_path}")
     audio = load_audio_file(str(audio_path))
 
-    # Extract mel features using Whisper's exact preprocessing
     mel_features = extract_mel_features_whisper(audio)
-    # Convert to JAX array after mel extraction
     mel_features = jnp.array(mel_features)
     print(f"Mel features shape: {mel_features.shape}")
 
-    # Pad or truncate to expected length (Whisper uses 3000 for full context)
-    max_time_steps = 3000  # Whisper's full audio context length
+    max_time_steps = 3000
     if mel_features.shape[0] > max_time_steps:
         mel_features = mel_features[:max_time_steps]
     else:
-        # Pad with zeros using JAX
         padding = jnp.zeros((max_time_steps - mel_features.shape[0], mel_features.shape[1]))
         mel_features = jnp.concatenate([mel_features, padding], axis=0)
 
-    # Add batch dimension and transpose to (batch, n_mels, time)
-    mel_features = mel_features[None, ...].transpose(0, 2, 1)  # (1, n_mels, time)
+    mel_features = mel_features[None, ...].transpose(0, 2, 1)
     mel_features = jnp.array(mel_features)
 
-    # Create model from pretrained weights
     config = modeling.WhisperConfig.whisper_tiny()
     model = params.create_model_from_safe_tensors(MODEL_CP_PATH, config)
     print("Loaded pretrained Whisper weights")
 
-    # Create dummy tokens for testing (BOS token)
-    tokens = jnp.array([[50258]])  # BOS token for Whisper
+    tokens = jnp.array([[50258]])
 
     print("Running Whisper model...")
 
-    # Time the forward pass
     start_time = time.perf_counter()
 
-    # Run forward pass
     logits = model(mel_features, tokens)
 
-    # Block until computation is complete
     jax.block_until_ready(logits)
 
     end_time = time.perf_counter()
     print(f"Forward pass completed in {end_time - start_time:.4f} seconds")
     print(f"Output logits shape: {logits.shape}")
 
-    # Test generation
     print("Testing text generation...")
     start_time = time.perf_counter()
 
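The pad-or-truncate block above is self-contained enough to lift out. A minimal sketch under the same shapes (the 3000-frame context length and the (1, n_mels, time) layout come straight from the diff; the helper name pad_or_truncate is hypothetical):

    import jax.numpy as jnp

    def pad_or_truncate(mel: jnp.ndarray, max_time_steps: int = 3000) -> jnp.ndarray:
        """Clamp (time, n_mels) features to a fixed length, then batch. Sketch."""
        if mel.shape[0] > max_time_steps:
            mel = mel[:max_time_steps]                 # truncate long clips
        else:
            pad = jnp.zeros((max_time_steps - mel.shape[0], mel.shape[1]))
            mel = jnp.concatenate([mel, pad], axis=0)  # zero-pad short clips
        return mel[None, ...].transpose(0, 2, 1)       # (1, n_mels, time)

Note also the timing idiom the commit keeps: jax.block_until_ready(logits) forces JAX's asynchronously dispatched computation to finish before end_time is read, so the measured interval covers actual execution rather than just dispatch.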
@@ -147,21 +127,17 @@ def run_model(MODEL_CP_PATH: Optional[str] = None, audio_path: Optional[str] = None
     end_time = time.perf_counter()
     print(f"Generation completed in {end_time - start_time:.4f} seconds")
 
-    # Simple token-to-text decoding (using Whisper vocabulary)
     print("Decoding transcription...")
     try:
-        # Load vocabulary from the model files
         import json
         vocab_path = os.path.join(MODEL_CP_PATH, "tokenizer.json")
         if os.path.exists(vocab_path):
             with open(vocab_path, 'r') as f:
                 tokenizer_data = json.load(f)
             vocab_data = tokenizer_data['model']['vocab']
 
-            # Create reverse mapping: token_id -> text
             vocab = {int(token_id): text for text, token_id in vocab_data.items()}
 
-            # Add special tokens
             special_tokens = {
                 50258: "<|startoftranscript|>",
                 50259: "<|en|>",
@@ -171,19 +147,15 @@ def run_model(MODEL_CP_PATH: Optional[str] = None, audio_path: Optional[str] = None
             }
             vocab.update(special_tokens)
 
-            # Decode tokens to text
             decoded_text = ""
             for token in generated_tokens[0]:
                 token_id = int(token)
                 if token_id in vocab:
                     text = vocab[token_id]
-                    # Skip special tokens for clean output
                     if not (text.startswith("<|") and text.endswith("|>")):
-                        # Replace BPE space marker with actual space
                         text = text.replace("Ġ", " ")
                     decoded_text += text
                 else:
-                    # For unknown tokens, show the token ID
                     decoded_text += f"[{token_id}]"
 
             print(f"Transcription: {decoded_text.strip()}")
@@ -195,11 +167,9 @@ def run_model(MODEL_CP_PATH: Optional[str] = None, audio_path: Optional[str] = None
         print(f"Could not decode tokens: {e}")
         print(f"Generated tokens: {generated_tokens[0]}")
 
-    # Test with JAX profiling
     print("Running with JAX profiling...")
     jax.profiler.start_trace("/tmp/profile-data")
 
-    # Run a few iterations for profiling
     for i in range(5):
         logits = model(mel_features, tokens)
         jax.block_until_ready(logits)
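The profiling hunk ends at the bottom of the visible diff, before any closing call. For context, the usual JAX idiom pairs start_trace with stop_trace; a sketch of that pattern, with the function wrapper and its parameters as assumptions (the diff only shows the loop body):

    import jax

    def profile_forward(model, mel_features, tokens, steps: int = 5) -> None:
        """Trace a few forward passes for the JAX profiler. Sketch only."""
        jax.profiler.start_trace("/tmp/profile-data")
        for _ in range(steps):
            logits = model(mel_features, tokens)  # dispatch a forward pass
            jax.block_until_ready(logits)         # make the work land inside the trace
        jax.profiler.stop_trace()                 # standard pairing; not visible in the hunk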

bonsai/models/whisper/tests/test_structure.py

Lines changed: 0 additions & 5 deletions
@@ -12,7 +12,6 @@ def test_imports():
     """Test if all required modules can be imported."""
     print("Testing imports...")
 
-    # Test basic Python modules
     try:
         import numpy as np
         print("✓ numpy imported successfully")
@@ -27,7 +26,6 @@ def test_imports():
         print(f"✗ transformers import failed: {e}")
         return False
 
-    # Test JAX-related modules (optional)
     try:
         import jax
         print("✓ jax imported successfully")
@@ -74,11 +72,9 @@ def test_modeling_structure():
     print("\nTesting modeling.py structure...")
 
     try:
-        # Try to read the file and check for key components
         with open("bonsai/models/whisper/modeling.py", "r") as f:
            content = f.read()
 
-        # Check for key classes and functions
         key_components = [
             "class WhisperConfig",
             "class MultiHeadAttention",
@@ -184,7 +180,6 @@ def main():
             print(f"✗ Test failed with error: {e}")
             results.append((test_name, False))
 
-    # Summary
     print("\n" + "=" * 60)
     print("SUMMARY")
     print("=" * 60)
