@@ -34,16 +34,14 @@ def load_audio_file(audio_path: str, sample_rate: int = 16000) -> np.ndarray:
         return audio
     except ImportError:
         print("librosa not available, using dummy audio data")
-        # Generate dummy audio for testing
-        return np.random.randn(sample_rate * 3)  # 3 seconds of random audio
+        return np.random.randn(sample_rate * 3)
 
 
 def extract_mel_features_whisper(audio: np.ndarray, sample_rate: int = 16000) -> np.ndarray:
     """Extract mel spectrogram features from audio using Whisper's exact preprocessing."""
     try:
         import librosa
 
-        # Whisper's exact mel spectrogram parameters
         mel_spec = librosa.feature.melspectrogram(
             y=audio,
             sr=sample_rate,
@@ -56,21 +54,17 @@ def extract_mel_features_whisper(audio: np.ndarray, sample_rate: int = 16000) ->
             power=2.0
         )
 
-        # Convert to log10 scale (Whisper's approach)
         log_spec = np.log10(mel_spec + 1e-10)
 
-        # Clip values (Whisper's approach)
         log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
 
-        # Normalize (Whisper's approach)
         log_spec = (log_spec + 4.0) / 4.0
 
-        return log_spec.T  # Transpose to (time, n_mels)
+        return log_spec.T
 
     except ImportError:
         print("librosa not available, using dummy mel features")
-        # Generate dummy mel features for testing
-        time_steps = len(audio) // 160  # Approximate time steps
+        time_steps = len(audio) // 160
         return np.random.randn(time_steps, 80)
 
 
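Note: the three uncommented lines kept by this hunk (log10 with a 1e-10 floor, clipping to within 8 log10-units of the loudest bin, then shifting and scaling by (x + 4) / 4) are Whisper's standard log-mel normalization chain. A minimal standalone sketch of just that chain, assuming a NumPy power spectrogram of shape (n_mels, time) as librosa produces it:

import numpy as np

def whisper_log_mel(mel_spec: np.ndarray) -> np.ndarray:
    # log10 with a small floor so silent frames don't produce -inf
    log_spec = np.log10(mel_spec + 1e-10)
    # clip everything to within 8 log10-units of the loudest bin
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    # shift and scale into roughly [-1, 1]
    return (log_spec + 4.0) / 4.0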
@@ -80,64 +74,50 @@ def run_model(MODEL_CP_PATH: Optional[str] = None, audio_path: Optional[str] = N
     if MODEL_CP_PATH is None:
         MODEL_CP_PATH = "/tmp/models-bonsai/" + model_name.split("/")[1]
 
-    # Download model if not present
     if not os.path.isdir(MODEL_CP_PATH):
         print(f"Downloading {model_name} to {MODEL_CP_PATH}...")
         snapshot_download(model_name, local_dir=MODEL_CP_PATH)
 
-    # Load audio
     if audio_path is None:
         print("No audio file provided, using test speech sample")
-        # Use the Bush speech sample
         audio_path = Path(__file__).parent / "audio_samples" / "bush_speech.wav"
         print(f"Using default audio: {audio_path}")
 
     print(f"Loading audio from {audio_path}")
     audio = load_audio_file(str(audio_path))
 
-    # Extract mel features using Whisper's exact preprocessing
     mel_features = extract_mel_features_whisper(audio)
-    # Convert to JAX array after mel extraction
     mel_features = jnp.array(mel_features)
     print(f"Mel features shape: {mel_features.shape}")
 
-    # Pad or truncate to expected length (Whisper uses 3000 for full context)
-    max_time_steps = 3000  # Whisper's full audio context length
+    max_time_steps = 3000
     if mel_features.shape[0] > max_time_steps:
         mel_features = mel_features[:max_time_steps]
     else:
-        # Pad with zeros using JAX
         padding = jnp.zeros((max_time_steps - mel_features.shape[0], mel_features.shape[1]))
         mel_features = jnp.concatenate([mel_features, padding], axis=0)
 
-    # Add batch dimension and transpose to (batch, n_mels, time)
-    mel_features = mel_features[None, ...].transpose(0, 2, 1)  # (1, n_mels, time)
+    mel_features = mel_features[None, ...].transpose(0, 2, 1)
    mel_features = jnp.array(mel_features)
 
-    # Create model from pretrained weights
     config = modeling.WhisperConfig.whisper_tiny()
     model = params.create_model_from_safe_tensors(MODEL_CP_PATH, config)
     print("Loaded pretrained Whisper weights")
 
-    # Create dummy tokens for testing (BOS token)
-    tokens = jnp.array([[50258]])  # BOS token for Whisper
+    tokens = jnp.array([[50258]])
 
     print("Running Whisper model...")
 
-    # Time the forward pass
     start_time = time.perf_counter()
 
-    # Run forward pass
     logits = model(mel_features, tokens)
 
-    # Block until computation is complete
     jax.block_until_ready(logits)
 
     end_time = time.perf_counter()
     print(f"Forward pass completed in {end_time - start_time:.4f} seconds")
     print(f"Output logits shape: {logits.shape}")
 
-    # Test generation
     print("Testing text generation...")
     start_time = time.perf_counter()
 
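The pad-or-truncate block in this hunk clamps the features to Whisper's fixed 3000-frame encoder context (30 seconds at 100 frames per second, i.e. hop length 160 on 16 kHz audio). The same step restated as a small helper, a sketch assuming the (time, n_mels) layout the hunk uses before its batch/transpose step:

import jax.numpy as jnp

def pad_or_truncate(mel: jnp.ndarray, max_time_steps: int = 3000) -> jnp.ndarray:
    # truncate long clips to the encoder's fixed context
    if mel.shape[0] > max_time_steps:
        return mel[:max_time_steps]
    # zero-pad short clips up to the context length
    padding = jnp.zeros((max_time_steps - mel.shape[0], mel.shape[1]), dtype=mel.dtype)
    return jnp.concatenate([mel, padding], axis=0)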
@@ -147,21 +127,17 @@ def run_model(MODEL_CP_PATH: Optional[str] = None, audio_path: Optional[str] = N
     end_time = time.perf_counter()
     print(f"Generation completed in {end_time - start_time:.4f} seconds")
 
-    # Simple token-to-text decoding (using Whisper vocabulary)
     print("Decoding transcription...")
     try:
-        # Load vocabulary from the model files
         import json
         vocab_path = os.path.join(MODEL_CP_PATH, "tokenizer.json")
         if os.path.exists(vocab_path):
             with open(vocab_path, 'r') as f:
                 tokenizer_data = json.load(f)
                 vocab_data = tokenizer_data['model']['vocab']
 
-            # Create reverse mapping: token_id -> text
             vocab = {int(token_id): text for text, token_id in vocab_data.items()}
 
-            # Add special tokens
             special_tokens = {
                 50258: "<|startoftranscript|>",
                 50259: "<|en|>",
@@ -171,19 +147,15 @@ def run_model(MODEL_CP_PATH: Optional[str] = None, audio_path: Optional[str] = N
             }
             vocab.update(special_tokens)
 
-            # Decode tokens to text
             decoded_text = ""
             for token in generated_tokens[0]:
                 token_id = int(token)
                 if token_id in vocab:
                     text = vocab[token_id]
-                    # Skip special tokens for clean output
                     if not (text.startswith("<|") and text.endswith("|>")):
-                        # Replace BPE space marker with actual space
                        text = text.replace("Ġ", " ")
                         decoded_text += text
                 else:
-                    # For unknown tokens, show the token ID
                     decoded_text += f"[{token_id}]"
 
             print(f"Transcription: {decoded_text.strip()}")
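The decode loop in this hunk is a minimal BPE detokenizer: drop <|...|> special tokens, map the Ġ space marker to a real space, and keep unknown ids visible in brackets. The same logic as a reusable helper, a sketch where vocab is the id-to-text mapping built earlier in this function:

def decode_tokens(token_ids, vocab: dict) -> str:
    decoded = []
    for token in token_ids:
        text = vocab.get(int(token), f"[{int(token)}]")  # unknown ids stay visible
        if text.startswith("<|") and text.endswith("|>"):
            continue  # skip special tokens for clean output
        decoded.append(text.replace("Ġ", " "))  # BPE space marker -> space
    return "".join(decoded).strip()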
@@ -195,11 +167,9 @@ def run_model(MODEL_CP_PATH: Optional[str] = None, audio_path: Optional[str] = N
         print(f"Could not decode tokens: {e}")
         print(f"Generated tokens: {generated_tokens[0]}")
 
-    # Test with JAX profiling
     print("Running with JAX profiling...")
     jax.profiler.start_trace("/tmp/profile-data")
 
-    # Run a few iterations for profiling
     for i in range(5):
         logits = model(mel_features, tokens)
         jax.block_until_ready(logits)
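This last hunk opens a trace but ends before the matching stop call, which sits in the elided lines below it. For reference, the standard JAX pairing the loop relies on, sketched as a helper (the stop_trace placement here is an assumption, since the hunk cuts off before it):

import jax

def profile_forward(model, mel_features, tokens, logdir="/tmp/profile-data", iters=5):
    # start writing a trace, run a few synchronous iterations, then close the trace
    jax.profiler.start_trace(logdir)
    for _ in range(iters):
        logits = model(mel_features, tokens)
        jax.block_until_ready(logits)  # force execution to finish inside the trace window
    jax.profiler.stop_trace()  # assumed; not shown in the hunk above
    return logits

The resulting trace can then be inspected with TensorBoard's profile plugin, e.g. tensorboard --logdir /tmp/profile-data.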