From fd8f1003cf6589f0962ec0ed7bdc709e1ec4d9a2 Mon Sep 17 00:00:00 2001
From: Max Bain <maxbain@robots.ox.ac.uk>
Date: Sat, 13 May 2023 12:14:06 +0100
Subject: [PATCH] add translate task, fix word_timestamp error

---
 README.md              | 8 ++++----
 setup.py               | 2 +-
 whisperx/alignment.py  | 4 ++++
 whisperx/asr.py        | 4 ++--
 whisperx/transcribe.py | 7 ++++++-
 5 files changed, 17 insertions(+), 8 deletions(-)

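Review note: the user-visible effect of this patch is a new `task` argument
on `load_model`, threaded through to the faster-whisper tokenizer, plus a
CLI guard that switches off alignment for translated output. A rough sketch
of the resulting API, assuming the usual whisperx entry points; model size,
device, batch size, and file name below are placeholders, not part of the
change:

    import whisperx

    # task="translate" decodes English text from non-English speech.
    model = whisperx.load_model("large-v2", device="cuda",
                                compute_type="float16",
                                language="fr", task="translate")

    audio = whisperx.load_audio("audio.wav")
    result = model.transcribe(audio, batch_size=8)
    # result["segments"] holds English text; translated output is not
    # word-aligned (see the transcribe.py hunk below).
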
diff --git a/README.md b/README.md
index bccbca8b..d76a0701 100644
--- a/README.md
+++ b/README.md
@@ -32,12 +32,12 @@
 <!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->
 
 
-This repository provides fast automatic speaker recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
+This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
 
 - ⚡ī¸ Batched inference for 70x realtime transcription using whisper large-v2
 - đŸĒļ [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
 - đŸŽ¯ Accurate word-level timestamps using wav2vec2 alignment
-- đŸ‘¯â€â™‚ī¸ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (labels each segment/word with speaker ID) 
+- đŸ‘¯â€â™‚ī¸ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels) 
 - đŸ—Ŗī¸ VAD preprocessing, reduces hallucination & batching with no WER degradation
 
 
@@ -75,9 +75,9 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst
 
 ### 2. Install PyTorch 2.0, e.g. for Linux and Windows CUDA 11.7:
 
-`pip3 install torch torchvision torchaudio`
+`conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia`
 
-See other methods [here.](https://pytorch.org/get-started/locally/)
+See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
 
 ### 3. Install this repo
 
diff --git a/setup.py b/setup.py
index eea26adf..1e1e6f5e 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
     name="whisperx",
     py_modules=["whisperx"],
-    version="3.1.0",
+    version="3.1.1",
     description="Time-Accurate Automatic Speech Recognition using Whisper.",
     readme="README.md",
     python_requires=">=3.8",
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index b8734753..1e22a7b1 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -259,6 +259,10 @@ def align(
                 word_text = "".join(word_chars["char"].tolist()).strip()
                 if len(word_text) == 0:
                     continue
+
+                # don't use the space character for alignment
+                word_chars = word_chars[word_chars["char"] != " "]
+
                 word_start = word_chars["start"].min()
                 word_end = word_chars["end"].max()
                 word_score = round(word_chars["score"].mean(), 3)
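
Review note on the hunk above: each word group previously aggregated over
every character row, including the trailing space token, whose end time can
bleed into the following pause; that inflated word_end and dragged the mean
score down (the word_timestamp error named in the subject). A toy
illustration with invented timings, column names as in align():

    import pandas as pd

    # Char-level alignment rows for the word "hi" plus its trailing space;
    # timings and scores here are made up for illustration.
    word_chars = pd.DataFrame({
        "char":  ["h", "i", " "],
        "start": [0.50, 0.58, 0.66],
        "end":   [0.58, 0.66, 1.20],   # the space stretches into the pause
        "score": [0.95, 0.91, 0.10],
    })

    print(word_chars["end"].max())                # 1.2 -- inflated by the space

    # With the fix: aggregate over non-space characters only.
    word_chars = word_chars[word_chars["char"] != " "]
    print(word_chars["end"].max())                # 0.66
    print(round(word_chars["score"].mean(), 3))   # 0.93
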
diff --git a/whisperx/asr.py b/whisperx/asr.py
index f2c54f6c..66b58ad3 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -14,7 +14,7 @@
 
 
 def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
-               vad_options=None, model=None):
+               vad_options=None, model=None, task="transcribe"):
     '''Load a Whisper model for inference.
     Args:
         whisper_arch: str - The name of the Whisper model to load.
@@ -31,7 +31,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
 
     model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
     if language is not None:
-        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language)
+        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
         print("No language specified, language will be first be detected for each audio file (increases inference time).")
         tokenizer = None
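
Review note: `task` defaults to "transcribe", so existing callers of
`load_model` are unaffected; only the tokenizer construction changes.
Sketch of both paths (model size, device, and compute type are arbitrary):

    from whisperx.asr import load_model

    # Old behaviour, unchanged: task defaults to "transcribe".
    asr = load_model("small", device="cpu", compute_type="int8",
                     language="en")

    # New path: the tokenizer is built for Whisper's translate task, so
    # decoding emits English regardless of the source language.
    translator = load_model("small", device="cpu", compute_type="int8",
                            language="fr", task="translate")
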
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index d09c5f66..3edc746d 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -86,6 +86,11 @@ def cli():
     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")
+    task: str = args.pop("task")
+    if task == "translate":
+        # translation cannot be aligned
+        no_align = True
+
     return_char_alignments: bool = args.pop("return_char_alignments")
 
     hf_token: str = args.pop("hf_token")
@@ -139,7 +144,7 @@ def cli():
     results = []
     tmp_results = []
     # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
+    model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
 
     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)
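
Review note: translated text has no character-level correspondence with the
source-language audio, so wav2vec2 alignment is meaningless for it; hence
the forced `no_align`. Downstream code should therefore not assume word
timestamps exist. A defensive access pattern over the usual WhisperX result
shape (the shape is an assumption here, not part of this patch):

    # With --task translate the aligner never runs, so segments carry only
    # segment-level "start"/"end" and no "words" list.
    def iter_word_timestamps(result):
        for segment in result["segments"]:
            for word in segment.get("words", []):  # empty when alignment skipped
                yield word.get("word"), word.get("start"), word.get("end")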