Skip to content

Commit

Permalink
add translate, fix word_timestamp error
Browse files Browse the repository at this point in the history
m-bain committed May 13, 2023
1 parent 4603f01 commit fd8f100
Showing 5 changed files with 17 additions and 8 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -32,12 +32,12 @@
<!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->


This repository provides fast automatic speaker recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.

- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
- 🎯 Accurate word-level timestamps using wav2vec2 alignment
- 👯‍♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (labels each segment/word with speaker ID)
- 👯‍♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
- 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation


@@ -75,9 +75,9 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst

### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:

`pip3 install torch torchvision torchaudio`
`conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia`

See other methods [here.](https://pytorch.org/get-started/locally/)
See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)

### 3. Install this repo

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
setup(
name="whisperx",
py_modules=["whisperx"],
version="3.1.0",
version="3.1.1",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
readme="README.md",
python_requires=">=3.8",
4 changes: 4 additions & 0 deletions whisperx/alignment.py
Original file line number Diff line number Diff line change
@@ -259,6 +259,10 @@ def align(
word_text = "".join(word_chars["char"].tolist()).strip()
if len(word_text) == 0:
continue

# dont use space character for alignment
word_chars = word_chars[word_chars["char"] != " "]

word_start = word_chars["start"].min()
word_end = word_chars["end"].max()
word_score = round(word_chars["score"].mean(), 3)
4 changes: 2 additions & 2 deletions whisperx/asr.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@


def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
vad_options=None, model=None):
vad_options=None, model=None, task="transcribe"):
'''Load a Whisper model for inference.
Args:
whisper_arch: str - The name of the Whisper model to load.
@@ -31,7 +31,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l

model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
if language is not None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language)
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
print("No language specified, language will be first be detected for each audio file (increases inference time).")
tokenizer = None
7 changes: 6 additions & 1 deletion whisperx/transcribe.py
Original file line number Diff line number Diff line change
@@ -86,6 +86,11 @@ def cli():
align_model: str = args.pop("align_model")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
task : str = args.pop("task")
if task == "translate":
# translation cannot be aligned
no_align = True

return_char_alignments: bool = args.pop("return_char_alignments")

hf_token: str = args.pop("hf_token")
@@ -139,7 +144,7 @@ def cli():
results = []
tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir)
model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)

for audio_path in args.pop("audio"):
audio = load_audio(audio_path)

0 comments on commit fd8f100

Please sign in to comment.