From fd8f1003cf6589f0962ec0ed7bdc709e1ec4d9a2 Mon Sep 17 00:00:00 2001
From: Max Bain <maxbain@robots.ox.ac.uk>
Date: Sat, 13 May 2023 12:14:06 +0100
Subject: [PATCH] add translate, fix word_timestamp error

---
 README.md              | 8 ++++----
 setup.py               | 2 +-
 whisperx/alignment.py  | 4 ++++
 whisperx/asr.py        | 4 ++--
 whisperx/transcribe.py | 7 ++++++-
 5 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index bccbca8b..d76a0701 100644
--- a/README.md
+++ b/README.md
@@ -32,12 +32,12 @@
 
 <!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->
 
-This repository provides fast automatic speaker recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
+This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
 
 - ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
 - 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
 - 🎯 Accurate word-level timestamps using wav2vec2 alignment
-- 👯‍♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (labels each segment/word with speaker ID)
+- 👯‍♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
 - 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation
 
@@ -75,9 +75,9 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst
 
 ### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
 
-`pip3 install torch torchvision torchaudio`
+`conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia`
 
-See other methods [here.](https://pytorch.org/get-started/locally/)
+See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
 
 ### 3. Install this repo
diff --git a/setup.py b/setup.py
index eea26adf..1e1e6f5e 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
     name="whisperx",
     py_modules=["whisperx"],
-    version="3.1.0",
+    version="3.1.1",
     description="Time-Accurate Automatic Speech Recognition using Whisper.",
     readme="README.md",
     python_requires=">=3.8",
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index b8734753..1e22a7b1 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -259,6 +259,10 @@ def align(
             word_text = "".join(word_chars["char"].tolist()).strip()
             if len(word_text) == 0:
                 continue
+
+            # dont use space character for alignment
+            word_chars = word_chars[word_chars["char"] != " "]
+
             word_start = word_chars["start"].min()
             word_end = word_chars["end"].max()
             word_score = round(word_chars["score"].mean(), 3)
diff --git a/whisperx/asr.py b/whisperx/asr.py
index f2c54f6c..66b58ad3 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -14,7 +14,7 @@
 def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
-               vad_options=None, model=None):
+               vad_options=None, model=None, task="transcribe"):
     '''Load a Whisper model for inference.
     Args:
         whisper_arch: str - The name of the Whisper model to load.
@@ -31,7 +31,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
     model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
 
     if language is not None:
-        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language)
+        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
         print("No language specified, language will be first be detected for each audio file (increases inference time).")
         tokenizer = None
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index d09c5f66..3edc746d 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -86,6 +86,11 @@ def cli():
     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")
+    task : str = args.pop("task")
+    if task == "translate":
+        # translation cannot be aligned
+        no_align = True
+
     return_char_alignments: bool = args.pop("return_char_alignments")
 
     hf_token: str = args.pop("hf_token")
@@ -139,7 +144,7 @@ def cli():
     results = []
     tmp_results = []
     # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
+    model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
 
     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)
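---
A minimal usage sketch of the `task` option this patch adds (an illustration, not part of the diff; the checkpoint, device, source language, and audio path below are placeholder values). Passing task="translate" routes through Whisper's translation mode, so output segments are English text; that is also why the CLI change above forces no_align = True, since the wav2vec2 aligner matches the spoken-language transcript, which a translation no longer is.

    import whisperx

    # Placeholder values throughout; task="translate" is the new bit from
    # this patch, threaded down to the faster-whisper tokenizer.
    model = whisperx.load_model("large-v2", "cuda", compute_type="float16",
                                language="fr", task="translate")

    audio = whisperx.load_audio("speech_fr.mp3")  # placeholder path
    result = model.transcribe(audio, batch_size=16)
    print(result["segments"])  # English text; word alignment is skipped

    # Rough CLI equivalent: whisperx speech_fr.mp3 --language fr --task translate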