Whisper preliminary implementation #22
base: main
Conversation
bonsai/models/whisper/README.md
Outdated
Whisper is a general-purpose speech recognition model that can transcribe audio in multiple languages and perform various speech recognition tasks including transcription, translation, and language identification.

**🚀 Status: Production Ready** - The model successfully loads pretrained weights and performs accurate speech transcription with no HuggingFace dependencies during inference.
Please omit this statement; Bonsai is intended to be an academia / general-audience-targeting model repository.
A production-ready model must be proven to be extremely performant in various large-cluster HW settings, which this model has not been.
Removed
bonsai/models/whisper/README.md
Outdated
- **Robust**: Handles various accents, background noise, and technical language
- **Flexible**: Can perform transcription, translation, and language identification
- **Efficient**: Optimized for JAX with JIT compilation
- **Production Ready**: Successfully loads pretrained weights and generates accurate transcriptions
Please omit; Bonsai model implementations are focused on feature parity with the original model implementations, and the exact capabilities of the model need not be highlighted here.
Removed
@@ -0,0 +1,46 @@
import os
This file would be helpful if we were to do quality checks with a golden-logits comparison between the Bonsai impl and the original HF impl, but it should not be necessary otherwise.
removed
@@ -0,0 +1,148 @@
# Whisper Model Installation Guide
I don't think this file is needed, and if we have JAX-specific debugging tips they should go in the root Bonsai documents.
Individual library installation specifications can be in the root pyproject.toml file.
Much of this should be self-explanatory from the code (ex: "The first time you run the model, it will automatically download the Whisper weights from HuggingFace.").
removed
@@ -0,0 +1,3 @@
No need for 3 blank lines, please keep empty
Used for actual imports
bonsai/models/whisper/modeling.py
Outdated
# x shape: (batch, n_mels, time) -> (batch, time, n_mels) for Conv1D
x = x.transpose(0, 2, 1)

# Conv stack with GELU
Please omit self-evident comments and unnecessary spaces; we want concise, readable code with as few lines as possible.
bonsai/models/whisper/modeling.py
Outdated
return model(mel_features, tokens, mask)


def generate(model: WhisperModel, mel_features: Array, max_length: int = 448, temperature: float = 0.0) -> Array:
Please keep consistency (ex: naming, style, structure) with the qwen example.
Ex: this should functionally be the forward function.
return model(mel_features, tokens, mask)
Please apply jax.jit here; you may need to separate out the state as in the qwen example rather than passing the model as a whole.
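For reference, a minimal sketch of what that could look like with the Flax NNX functional API, assuming the model is an `nnx.Module`; the `forward` name and its argument list are illustrative placeholders, not the actual bonsai/qwen code:

```python
# Hypothetical sketch: jit the forward pass by splitting the NNX module into a
# static graph definition and a pytree of parameters, then merging inside the
# traced function. Names and signatures are assumptions, not the bonsai API.
import jax
from flax import nnx

@jax.jit
def forward(graphdef, state, mel_features, tokens, mask):
    model = nnx.merge(graphdef, state)        # rebuild the module inside the traced function
    return model(mel_features, tokens, mask)  # decoder logits

# Usage: split once outside the decoding loop, then reuse the jitted function.
# graphdef, state = nnx.split(model)
# logits = forward(graphdef, state, mel_features, tokens, mask)
```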
bonsai/models/whisper/modeling.py
Outdated
repetition_window = 10

for step in range(max_length - len(prompt_tokens[0])):
    # Create causal mask
ditto for removing all unnecessary comments and newlines
bonsai/models/whisper/modeling.py
Outdated
def __call__(self, mel_features: Array, tokens: Array, mask: Optional[Array] = None) -> Array:
    # Encode audio
    xa = self.encoder(mel_features)
Please avoid abbreviations for non-obvious names.
bonsai/models/whisper/modeling.py
Outdated
if len(recent_tokens) >= 6:
    last_3 = recent_tokens[-3:]
    prev_3 = recent_tokens[-6:-3]
    if jnp.array_equal(last_3, prev_3):
Is this stopping logic part of the original Whisper code?
No. Changed to the original greedy + stop-token logic.
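For context, a minimal sketch of plain greedy decoding that stops at the end-of-text token, reusing the hypothetical jitted `forward` from the sketch above; `eot_token_id` and the `max_length` default are illustrative assumptions, not the exact bonsai code:

```python
# Hypothetical sketch: greedy decoding with an end-of-text stop.
import jax.numpy as jnp

def greedy_decode(graphdef, state, mel_features, prompt_tokens, eot_token_id, max_length=448):
    tokens = list(prompt_tokens)
    for _ in range(max_length - len(tokens)):
        token_array = jnp.asarray(tokens)[None, :]                  # (1, seq_len)
        logits = forward(graphdef, state, mel_features, token_array, None)
        next_token = int(jnp.argmax(logits[0, -1]))                 # greedy: most likely next token
        tokens.append(next_token)
        if next_token == eot_token_id:                              # stop at <|endoftext|>
            break
    return tokens

# Note: a real implementation would pad the sequence or use a KV cache so the
# jitted forward is not recompiled at every new length.
```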
6554628 → 51cdea5 (compare)
bonsai/models/whisper/tokenizer.py
Outdated
# Use HuggingFace tokenizer instead of tiktoken
from transformers import WhisperTokenizer

LANGUAGES = {
Do we need this? We can just call whichever language directly via WhisperTokenizer(language=...)
Removed
# Generate tokens
# Initialize HuggingFace tokenizer
print(f"\n🔤 Initializing HuggingFace tokenizer...")
tokenizer_instance = get_tokenizer(multilingual=True, language="en", task="transcribe")
We can just directly use WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="english", task="transcribe") here.
If we're only using certain tokenizer imports (ex: tokenizer.sot, tokenizer.to_language_token("en"), tokenizer.transcribe, tokenizer.no_timestamps) once, can we just hardcode the variables (ex: <|startoftext|>) in this file? We want to make these implementations as simple and hackable as possible, with minimal dependencies that are not directly used.
https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperTokenizer
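For illustration, a minimal sketch of the direct-usage approach, assuming the standard HuggingFace `WhisperTokenizer` API; the special-token strings below are the usual Whisper ones and should be double-checked against the checkpoint's vocabulary:

```python
# Hypothetical sketch: use the HuggingFace tokenizer directly, with no wrapper module.
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-tiny", language="english", task="transcribe"
)

# Build the decoder prompt from the special tokens (assumed token strings).
prompt_ids = tokenizer.convert_tokens_to_ids(
    ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"]
)

# Decode generated ids back to text, dropping the special tokens.
# text = tokenizer.decode(generated_ids, skip_special_tokens=True)
```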
Yes, called it directly.
bonsai/models/whisper/tokenizer.py
Outdated
@@ -0,0 +1,381 @@
"""
Is this file a module or a test file? I don't think we should have the test logic and the function class together in the same file.
Also, I don't think most of the functions here are used (ex: decode_with_timestamps, etc.).
Could we remove this file and call WhisperTokenizer directly? We'd like to avoid adding an unnecessary abstraction, if possible.
Removed this completely
This seems to be still here, could we remove this?
We are using bush_speech as an example for decoding. It's CC labeled.
Can you also add a quality logits checker (ex: test_outputs.py) to verify that the resulting logits are correct? Also, it seems some comments are yet to be addressed (ex: unnecessary key mapping in tokenizer, need to remove data files and instead refer via wget) -- please let us know when it is ready to be reviewed so that we can get whisper checked in :)
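For reference, a minimal sketch of such a golden-logits check (e.g. a test_outputs.py), assuming the HuggingFace `openai/whisper-tiny` checkpoint as the reference; `bonsai_forward` and its call signature are placeholders for the actual bonsai entry point:

```python
# Hypothetical sketch: compare bonsai JAX logits against the HF PyTorch reference
# on the same mel features and decoder prompt.
import numpy as np
import torch
from transformers import WhisperForConditionalGeneration

def check_logits_match(mel_features, prompt_ids, bonsai_forward, atol=1e-3):
    # Reference logits from the HuggingFace PyTorch model.
    hf_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()
    with torch.no_grad():
        hf_logits = hf_model(
            input_features=torch.from_numpy(np.asarray(mel_features, dtype=np.float32)),
            decoder_input_ids=torch.tensor([prompt_ids]),
        ).logits.numpy()

    # Logits from the bonsai JAX implementation (placeholder call signature).
    jax_logits = np.asarray(bonsai_forward(mel_features, np.asarray([prompt_ids])))

    np.testing.assert_allclose(jax_logits, hf_logits, atol=atol)
```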
8609f32 → 3724566 (compare)
e2106d3 → d6c5e91 (compare)
- Implemented complete JAX Whisper model with real PyTorch weights
- Added audio processing with mel spectrogram computation
- Created clean modeling.py with WhisperConfig class methods
- Added comprehensive test suite with JAX vs PyTorch comparison
- Supports tiny, base, small, medium, large model sizes
- Includes real audio transcription example with Bush Moscow speech
- All tests pass with 81%+ similarity to PyTorch implementation
d6c5e91 → ff950b1 (compare)
Signed-off-by: Vladimir Suvorov <[email protected]>
Trying to port the Whisper model to NNX.
Includes an example of running the Whisper model on the G. Bush speech transcription.