diff --git a/README.md b/README.md index 25a3f9c2..ae3d5bdf 100644 --- a/README.md +++ b/README.md @@ -13,36 +13,36 @@ GitHub license + + ArXiv paper + Twitter

-

- What is it • - Setup • - Usage • - Multilingual • - Contribute • - More examples • - Paper -

- whisperx-arch -

Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy using forced alignment. + + + + -

+This repository provides fast automatic speaker recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization. -

What is it 🔎

+- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2 +- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5 +- 🎯 Accurate word-level timestamps using wav2vec2 alignment +- 👯‍♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (labels each segment/word with speaker ID) +- 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation -This repository refines the timestamps of openAI's Whisper model via forced aligment with phoneme-based ASR models (e.g. wav2vec2.0) and VAD preprocesssing, multilingual use-case. -**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. +**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching. **Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self). @@ -50,7 +50,7 @@ This repository refines the timestamps of openAI's Whisper model via forced alig **Voice Activity Detection (VAD)** is the detection of the presence or absence of human speech. -

New🚨

+**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker. - v3 pre-release [this branch](https://github.com/m-bain/whisperX/tree/v3) *70x speed-up open-sourced. Using batched whisper with faster-whisper backend*! - v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper. @@ -59,15 +59,39 @@ This repository refines the timestamps of openAI's Whisper model via forced alig - Character level timestamps (see `*.char.ass` file output) - Diarization (still in beta, add `--diarize`) +

New🚨

+ +- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization +- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend! +- v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper. +- Paper drop🎓👨‍🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.

Setup ⚙️

-Install this package using +Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!) + +GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html). + + +### 1. Create Python3.10 environment + +`conda create --name whisperx python=3.10` -`pip install git+https://github.com/m-bain/whisperx.git` +`conda activate whisperx` + + +### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7: + +`pip3 install torch torchvision torchaudio` + +See other methods [here.](https://pytorch.org/get-started/locally/) + +### 3. Install this repo + +`pip install git+https://github.com/m-bain/whisperx.git@v3` If already installed, update package to most recent commit -`pip install git+https://github.com/m-bain/whisperx.git --upgrade` +`pip install git+https://github.com/m-bain/whisperx.git@v3 --upgrade` If wishing to modify this package, clone and install in editable mode: ``` @@ -78,23 +102,6 @@ $ pip install -e . You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup. -### Docker -Alternatively, you can use the docker image provided in this repo. To build the image, run the following command from the root of this repo: -1. In this image you can find jupyter notebook where you can easily run and debug the code. -```bash -docker build -t whisperx . -``` -2. To run the image, run the following command: -```bash -docker run --gpus=all -it -v :/workspace -p 8888:8888 whisperx -``` -### Setup not working??? -Safest to use install pytorch as follows (for gpu) - -`conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 -c pytorch -` - - ### Speaker Diarization To **enable Speaker. Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization) @@ -103,15 +110,11 @@ To **enable Speaker. Diarization**, include your Hugging Face access token that ### English -Run whisper on example segment (using default params) +Run whisper on example segment (using default params, whisper small) add `--highlight_words True` to visualise word timings in the .srt file. whisperx examples/sample01.wav -For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g. - - whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H - Result using *WhisperX* with forced alignment to wav2vec2.0 large: https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4 @@ -120,6 +123,16 @@ Compare this to original whisper out the box, where many transcriptions are out https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov + +For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g. + + whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4 + + +To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`): + + whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True + ### Other languages The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22). @@ -129,7 +142,7 @@ Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}` #### E.g. German - whisperx --model large --language de examples/sample_de_01.wav + whisperx --model large-v2 --language de examples/sample_de_01.wav https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov @@ -140,70 +153,108 @@ See more examples in other languages [here](EXAMPLES.md). ```python import whisperx -import whisper +import gc device = "cuda" audio_file = "audio.mp3" +batch_size = 16 # reduce if low on GPU mem +compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy) -# transcribe with original whisper -model = whisper.load_model("large", device) -result = model.transcribe(audio_file) +# 1. Transcribe with original whisper (batched) +model = whisperx.load_model("large-v2", device, compute_type=compute_type) +audio = whisperx.load_audio(audio_file) +result = model.transcribe(audio, batch_size=batch_size) print(result["segments"]) # before alignment -# load alignment model and metadata +# delete model if low on GPU resources +# import gc; gc.collect(); torch.cuda.empty_cache(); del model + +# 2. Align whisper output model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device) +result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False) + +print(result["segments"]) # after alignment + +# delete model if low on GPU resources +# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a + +# 3. Assign speaker labels +diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device) -# align whisper output -result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device) +# add min/max number of speakers if known +diarize_segments = diarize_model(input_audio_path) +# diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) -print(result_aligned["segments"]) # after alignment -print(result_aligned["word_segments"]) # after alignment +result = assign_word_speakers(diarize_segments, result) +print(diarize_segments) +print(result["segments"]) # segments are now assigned speaker IDs ``` -

Whisper Modifications

+

Technical Details 👷‍♂️

-In addition to forced alignment, the following two modifications have been made to the whisper transcription method: +For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf). -1. `--condition_on_prev_text` is set to `False` by default (reduces hallucination) +To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality): +1. reduce batch size, e.g. `--batch_size 4` +2. use a smaller ASR model `--model base` +3. Use lighter compute type `--compute_type int8` + +Transcription differences from openai's whisper: +1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies the default whisper output. +2. VAD-based segment transcription, unlike the buffered transcription of openai's. In Wthe WhisperX paper we show this reduces WER, and enables accurate batched inference +3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)

Limitations ⚠️

-- Whisper normalises spoken numbers e.g. "fifty seven" to arabic numerals "57". Need to perform this normalization after alignment, so the phonemes can be aligned. Currently just ignores numbers. -- If setting `--vad_filter False`, then whisperx assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors) +- Transcript words which do not contain characters in the alignment models dictionary e.g. "2014." or "£13.60" cannot be aligned and therefore are not given a timing. - Overlapping speech is not handled particularly well by whisper nor whisperx -- Diariazation is far from perfect. +- Diarization is far from perfect (working on this with custom model v4 -- see contact me). +- Language specific wav2vec2 model is needed

Contribute 🧑‍🏫

-If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a merge request and some examples showing its success. +If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a pull request and some examples showing its success. Bug finding and pull requests are also highly appreciated to keep this project going, since it's already diverging from the original research scope. -

Coming Soon 🗓

+

TODO 🗓

* [x] Multilingual init -* [x] Subtitle .ass output - * [x] Automatic align model selection based on language detection * [x] Python usage -* [x] Character level timestamps - * [x] Incorporating speaker diarization * [x] Model flush, for low gpu mem resources -* [ ] Improve diarization (word level). *Harder than first thought... see #below* +* [x] Faster-whisper backend + +* [x] Add max-line etc. see (openai's whisper utils.py) + +* [x] Sentence-level segments (nltk toolbox) + +* [x] Improve alignment logic + +* [ ] update examples with diarization and word highlighting + +* [ ] Subtitle .ass output <- bring this back (removed in v3) + +* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation) + +* [ ] Allow silero-vad as alternative VAD option + +* [ ] Improve diarization (word level). *Harder than first thought...*

Contact/Support 📇

-Contact maxhbain@gmail.com for queries. WhisperX v4 development is underway with *with siginificantly improved diarization*. To support v4 and get early access, get in touch. + +Contact maxhbain@gmail.com for queries. WhisperX v4 development is underway with with siginificantly improved diarization. To support v4 and get early access, get in touch. Buy Me A Coffee @@ -212,13 +263,18 @@ Contact maxhbain@gmail.com for queries. WhisperX v4 development is underway with This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford. - - Of course, this is builds on [openAI's whisper](https://github.com/openai/whisper). Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html) And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio +Valuable VAD & Diarization Models from [pyannote audio][https://github.com/pyannote/pyannote-audio] + +Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2) + +Those who have [supported this work financially](https://www.buymeacoffee.com/maxhbain) 🙏 + +Finally, thanks to the OS [contributors](https://github.com/m-bain/whisperX/graphs/contributors) of this project, keeping it going and identifying bugs.

Citation

If you use this in your research, please cite the paper: diff --git a/requirements.txt b/requirements.txt index 139ee56d..ec90a07f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,8 @@ -numpy -pandas -torch >=1.9 -torchaudio >=0.10,<1.0 -tqdm -more-itertools -transformers>=4.19.0 +torch==2.0.0 +torchaudio==2.0.1 +faster-whisper +transformers ffmpeg-python==0.2.0 -pyannote.audio -openai-whisper==20230314 +pandas +setuptools==65.6.3 +nltk \ No newline at end of file diff --git a/setup.py b/setup.py index 33db0e16..c63f6534 100644 --- a/setup.py +++ b/setup.py @@ -6,8 +6,8 @@ setup( name="whisperx", py_modules=["whisperx"], - version="2.0.1", - description="Time-Accurate Automatic Speech Recognition using Whisper.", + version="3.1.0", + description="Time-Accurate Automatic Speech Recognition.", readme="README.md", python_requires=">=3.8", author="Max Bain", @@ -19,7 +19,7 @@ for r in pkg_resources.parse_requirements( open(os.path.join(os.path.dirname(__file__), "requirements.txt")) ) - ], + ] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"], entry_points = { 'console_scripts': ['whisperx=whisperx.transcribe:cli'], }, diff --git a/whisperx/__init__.py b/whisperx/__init__.py index 985ed328..20abaaed 100644 --- a/whisperx/__init__.py +++ b/whisperx/__init__.py @@ -1,3 +1,4 @@ -from .transcribe import transcribe, transcribe_with_vad +from .transcribe import load_model from .alignment import load_align_model, align -from .vad import load_vad_model \ No newline at end of file +from .audio import load_audio +from .diarize import assign_word_speakers, DiarizationPipeline \ No newline at end of file diff --git a/whisperx/alignment.py b/whisperx/alignment.py index c15310b0..b8734753 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -2,16 +2,18 @@ Forced Alignment with Whisper C. Max Bain """ +from dataclasses import dataclass +from typing import Iterator, Union + import numpy as np import pandas as pd -from typing import List, Union, Iterator, TYPE_CHECKING -from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor -import torchaudio import torch -from dataclasses import dataclass -from whisper.audio import SAMPLE_RATE, load_audio -from .utils import interpolate_nans +import torchaudio +from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor +from .audio import SAMPLE_RATE, load_audio +from .utils import interpolate_nans +import nltk LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] @@ -37,6 +39,7 @@ "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian", "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek", "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish", + "he": "imvladikon/wav2vec2-xls-r-300m-hebrew", } @@ -82,354 +85,227 @@ def align( align_model_metadata: dict, audio: Union[str, np.ndarray, torch.Tensor], device: str, - extend_duration: float = 0.0, - start_from_previous: bool = True, interpolate_method: str = "nearest", + return_char_alignments: bool = False, ): """ - Force align phoneme recognition predictions to known transcription - - Parameters - ---------- - transcript: Iterator[dict] - The Whisper model instance - - model: torch.nn.Module - Alignment model (wav2vec2) - - audio: Union[str, np.ndarray, torch.Tensor] - The path to the audio file to open, or the audio waveform - - device: str - cuda device - - diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]} - diarization segments with speaker labels. - - extend_duration: float - Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds - - If the gzip compression ratio is above this value, treat as failed - - interpolate_method: str ["nearest", "linear", "ignore"] - Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary. - "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output. - - Returns - ------- - A dictionary containing the resulting text ("text") and segment-level details ("segments"), and - the spoken language ("language"), which is detected when `decode_options["language"]` is None. + Align phoneme recognition predictions to known transcription. """ + if not torch.is_tensor(audio): if isinstance(audio, str): audio = load_audio(audio) audio = torch.from_numpy(audio) if len(audio.shape) == 1: audio = audio.unsqueeze(0) - + MAX_DURATION = audio.shape[1] / SAMPLE_RATE model_dictionary = align_model_metadata["dictionary"] model_lang = align_model_metadata["language"] model_type = align_model_metadata["type"] - aligned_segments = [] - - prev_t2 = 0 - - char_segments_arr = { - "segment-idx": [], - "subsegment-idx": [], - "word-idx": [], - "char": [], - "start": [], - "end": [], - "score": [], - } - + # 1. Preprocess to keep only characters in dictionary for sdx, segment in enumerate(transcript): - while True: - segment_align_success = False - - # strip spaces at beginning / end, but keep track of the amount. - num_leading = len(segment["text"]) - len(segment["text"].lstrip()) - num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) - transcription = segment["text"] - - # TODO: convert number tokenizer / symbols to phonetic words for alignment. - # e.g. "$300" -> "three hundred dollars" - # currently "$300" is ignored since no characters present in the phonetic dictionary + # strip spaces at beginning / end, but keep track of the amount. + num_leading = len(segment["text"]) - len(segment["text"].lstrip()) + num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) + text = segment["text"] + + # split into words + if model_lang not in LANGUAGES_WITHOUT_SPACES: + per_word = text.split(" ") + else: + per_word = text - # split into words + clean_char, clean_cdx = [], [] + for cdx, char in enumerate(text): + char_ = char.lower() + # wav2vec2 models use "|" character to represent spaces if model_lang not in LANGUAGES_WITHOUT_SPACES: - per_word = transcription.split(" ") - else: - per_word = transcription - - # first check that characters in transcription can be aligned (they are contained in align model"s dictionary) - clean_char, clean_cdx = [], [] - for cdx, char in enumerate(transcription): - char_ = char.lower() - # wav2vec2 models use "|" character to represent spaces - if model_lang not in LANGUAGES_WITHOUT_SPACES: - char_ = char_.replace(" ", "|") - - # ignore whitespace at beginning and end of transcript - if cdx < num_leading: - pass - elif cdx > len(transcription) - num_trailing - 1: - pass - elif char_ in model_dictionary.keys(): - clean_char.append(char_) - clean_cdx.append(cdx) - - clean_wdx = [] - for wdx, wrd in enumerate(per_word): - if any([c in model_dictionary.keys() for c in wrd]): - clean_wdx.append(wdx) - - # if no characters are in the dictionary, then we skip this segment... - if len(clean_char) == 0: - print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...') - break - - transcription_cleaned = "".join(clean_char) - tokens = [model_dictionary[c] for c in transcription_cleaned] - - # we only pad if not using VAD filtering - if "seg_text" not in segment: - # pad according original timestamps - t1 = max(segment["start"] - extend_duration, 0) - t2 = min(segment["end"] + extend_duration, MAX_DURATION) - - # use prev_t2 as current t1 if it"s later - if start_from_previous and t1 < prev_t2: - t1 = prev_t2 - - # check if timestamp range is still valid - if t1 >= MAX_DURATION: - print("Failed to align segment: original start time longer than audio duration, skipping...") - break - if t2 - t1 < 0.02: - print("Failed to align segment: duration smaller than 0.02s time precision") - break - - f1 = int(t1 * SAMPLE_RATE) - f2 = int(t2 * SAMPLE_RATE) - - waveform_segment = audio[:, f1:f2] - - with torch.inference_mode(): - if model_type == "torchaudio": - emissions, _ = model(waveform_segment.to(device)) - elif model_type == "huggingface": - emissions = model(waveform_segment.to(device)).logits - else: - raise NotImplementedError(f"Align model of type {model_type} not supported.") - emissions = torch.log_softmax(emissions, dim=-1) - - emission = emissions[0].cpu().detach() - - trellis = get_trellis(emission, tokens) - path = backtrack(trellis, emission, tokens) - if path is None: - print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') - break - char_segments = merge_repeats(path, transcription_cleaned) - # word_segments = merge_words(char_segments) + char_ = char_.replace(" ", "|") - - # sub-segments - if "seg-text" not in segment: - segment["seg-text"] = [transcription] - - seg_lens = [0] + [len(x) for x in segment["seg-text"]] - seg_lens_cumsum = list(np.cumsum(seg_lens)) - sub_seg_idx = 0 - - wdx = 0 - duration = t2 - t1 - ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) - for cdx, char in enumerate(transcription + " "): - is_last = False - if cdx == len(transcription): - break - elif cdx+1 == len(transcription): - is_last = True - - - start, end, score = None, None, None - if cdx in clean_cdx: - char_seg = char_segments[clean_cdx.index(cdx)] - start = char_seg.start * ratio + t1 - end = char_seg.end * ratio + t1 - score = char_seg.score - - char_segments_arr["char"].append(char) - char_segments_arr["start"].append(start) - char_segments_arr["end"].append(end) - char_segments_arr["score"].append(score) - char_segments_arr["word-idx"].append(wdx) - char_segments_arr["segment-idx"].append(sdx) - char_segments_arr["subsegment-idx"].append(sub_seg_idx) - - # word-level info - if model_lang in LANGUAGES_WITHOUT_SPACES: - # character == word - wdx += 1 - elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1: - wdx += 1 - - if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1: - wdx = 0 - sub_seg_idx += 1 - - prev_t2 = segment["end"] - - segment_align_success = True - # end while True loop - break - - # reset prev_t2 due to drifting issues - if not segment_align_success: - prev_t2 = 0 - - char_segments_arr = pd.DataFrame(char_segments_arr) - not_space = char_segments_arr["char"] != " " - - per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False) - char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index() - per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) - per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"]) - per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"]) - char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount() - per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup - - word_segments_arr = {} - - # start of word is first char with a timestamp - word_segments_arr["start"] = per_word_grp["start"].min().values - # end of word is last char with a timestamp - word_segments_arr["end"] = per_word_grp["end"].max().values - # score of word is mean (excluding nan) - word_segments_arr["score"] = per_word_grp["score"].mean().values - - word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values - word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1 - word_segments_arr = pd.DataFrame(word_segments_arr) - - word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int) - segments_arr = {} - segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"] - segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"] - segments_arr = pd.DataFrame(segments_arr) - segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]] - segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1 - - # interpolate missing words / sub-segments - if interpolate_method != "ignore": - wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False) - wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False) - # we still know which word timestamps are interpolated because their score == nan - word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - - word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - - sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False) - segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - - # merge words & subsegments which are missing times - word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"]) - - word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min) - word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max) - word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True) - - seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"]) - segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min) - segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max) - segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True) - else: - word_segments_arr.dropna(inplace=True) - segments_arr.dropna(inplace=True) - - # if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps... - segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True) - segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True) - segments_arr['subsegment-idx-start'].fillna(0, inplace=True) - segments_arr['subsegment-idx-end'].fillna(1, inplace=True) - - + # ignore whitespace at beginning and end of transcript + if cdx < num_leading: + pass + elif cdx > len(text) - num_trailing - 1: + pass + elif char_ in model_dictionary.keys(): + clean_char.append(char_) + clean_cdx.append(cdx) + + clean_wdx = [] + for wdx, wrd in enumerate(per_word): + if any([c in model_dictionary.keys() for c in wrd]): + clean_wdx.append(wdx) + + sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text)) + + segment["clean_char"] = clean_char + segment["clean_cdx"] = clean_cdx + segment["clean_wdx"] = clean_wdx + segment["sentence_spans"] = sentence_spans + aligned_segments = [] - aligned_segments_word = [] - - word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True) - char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True) - - for sdx, srow in segments_arr.iterrows(): - - seg_idx = int(srow["segment-idx"]) - sub_start = int(srow["subsegment-idx-start"]) - sub_end = int(srow["subsegment-idx-end"]) - - seg = transcript[seg_idx] - text = "".join(seg["seg-text"][sub_start:sub_end]) - - wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] - wseg["start"].fillna(srow["start"], inplace=True) - wseg["end"].fillna(srow["end"], inplace=True) - wseg["segment-text-start"].fillna(0, inplace=True) - wseg["segment-text-end"].fillna(len(text)-1, inplace=True) - - cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] - # fixes bug for single segment in transcript - cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0 - cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1 - if 'level_1' in cseg: del cseg['level_1'] - if 'level_0' in cseg: del cseg['level_0'] - cseg.reset_index(inplace=True) - aligned_segments.append( - { - "start": srow["start"], - "end": srow["end"], - "text": text, - "word-segments": wseg, - "char-segments": cseg - } - ) - def get_raw_text(word_row): - return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1] - - wdx = 0 - curr_text = get_raw_text(wseg.iloc[wdx]) - if len(wseg) > 1: - for _, wrow in wseg.iloc[1:].iterrows(): - if wrow['start'] != wseg.iloc[wdx]['start']: - aligned_segments_word.append( - { - "text": curr_text.strip(), - "start": wseg.iloc[wdx]["start"], - "end": wseg.iloc[wdx]["end"], - } - ) - curr_text = "" - curr_text += " " + get_raw_text(wrow) - wdx += 1 - aligned_segments_word.append( - { - "text": curr_text.strip(), - "start": wseg.iloc[wdx]["start"], - "end": wseg.iloc[wdx]["end"] - } - ) - - - return {"segments": aligned_segments, "word_segments": aligned_segments_word} + # 2. Get prediction matrix from alignment model & align + for sdx, segment in enumerate(transcript): + t1 = segment["start"] + t2 = segment["end"] + text = segment["text"] + + aligned_seg = { + "start": t1, + "end": t2, + "text": text, + "words": [], + } + + if return_char_alignments: + aligned_seg["chars"] = [] + + # check we can align + if len(segment["clean_char"]) == 0: + print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...') + aligned_segments.append(aligned_seg) + continue + + if t1 >= MAX_DURATION or t2 - t1 < 0.02: + print("Failed to align segment: original start time longer than audio duration, skipping...") + aligned_segments.append(aligned_seg) + continue + + text_clean = "".join(segment["clean_char"]) + tokens = [model_dictionary[c] for c in text_clean] + + f1 = int(t1 * SAMPLE_RATE) + f2 = int(t2 * SAMPLE_RATE) + + # TODO: Probably can get some speedup gain with batched inference here + waveform_segment = audio[:, f1:f2] + + with torch.inference_mode(): + if model_type == "torchaudio": + emissions, _ = model(waveform_segment.to(device)) + elif model_type == "huggingface": + emissions = model(waveform_segment.to(device)).logits + else: + raise NotImplementedError(f"Align model of type {model_type} not supported.") + emissions = torch.log_softmax(emissions, dim=-1) + + emission = emissions[0].cpu().detach() + + blank_id = 0 + for char, code in model_dictionary.items(): + if char == '[pad]' or char == '': + blank_id = code + + trellis = get_trellis(emission, tokens, blank_id) + path = backtrack(trellis, emission, tokens, blank_id) + + if path is None: + print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') + aligned_segments.append(aligned_seg) + continue + + char_segments = merge_repeats(path, text_clean) + + duration = t2 -t1 + ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) + + # assign timestamps to aligned characters + char_segments_arr = [] + word_idx = 0 + for cdx, char in enumerate(text): + start, end, score = None, None, None + if cdx in segment["clean_cdx"]: + char_seg = char_segments[segment["clean_cdx"].index(cdx)] + start = round(char_seg.start * ratio + t1, 3) + end = round(char_seg.end * ratio + t1, 3) + score = round(char_seg.score, 3) + + char_segments_arr.append( + { + "char": char, + "start": start, + "end": end, + "score": score, + "word-idx": word_idx, + } + ) + # increment word_idx, nltk word tokenization would probably be more robust here, but us space for now... + if model_lang in LANGUAGES_WITHOUT_SPACES: + word_idx += 1 + elif cdx == len(text) - 1 or text[cdx+1] == " ": + word_idx += 1 + + char_segments_arr = pd.DataFrame(char_segments_arr) + + aligned_subsegments = [] + # assign sentence_idx to each character index + char_segments_arr["sentence-idx"] = None + for sdx, (sstart, send) in enumerate(segment["sentence_spans"]): + curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)] + char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx + + sentence_text = text[sstart:send] + sentence_start = curr_chars["start"].min() + sentence_end = curr_chars["end"].max() + sentence_words = [] + + for word_idx in curr_chars["word-idx"].unique(): + word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx] + word_text = "".join(word_chars["char"].tolist()).strip() + if len(word_text) == 0: + continue + word_start = word_chars["start"].min() + word_end = word_chars["end"].max() + word_score = round(word_chars["score"].mean(), 3) + + # -1 indicates unalignable + word_segment = {"word": word_text} + + if not np.isnan(word_start): + word_segment["start"] = word_start + if not np.isnan(word_end): + word_segment["end"] = word_end + if not np.isnan(word_score): + word_segment["score"] = word_score + + sentence_words.append(word_segment) + + aligned_subsegments.append({ + "text": sentence_text, + "start": sentence_start, + "end": sentence_end, + "words": sentence_words, + }) + + if return_char_alignments: + curr_chars = curr_chars[["char", "start", "end", "score"]] + curr_chars.fillna(-1, inplace=True) + curr_chars = curr_chars.to_dict("records") + curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars] + aligned_subsegments[-1]["chars"] = curr_chars + + aligned_subsegments = pd.DataFrame(aligned_subsegments) + aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method) + aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method) + # concatenate sentences with same timestamps + agg_dict = {"text": " ".join, "words": "sum"} + if return_char_alignments: + agg_dict["chars"] = "sum" + aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict) + aligned_subsegments = aligned_subsegments.to_dict('records') + aligned_segments += aligned_subsegments + + # create word_segments list + word_segments = [] + for segment in aligned_segments: + word_segments += segment["words"] + + return {"segments": aligned_segments, "word_segments": word_segments} """ source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html diff --git a/whisperx/asr.py b/whisperx/asr.py index e78d77cb..f2c54f6c 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -1,433 +1,270 @@ +import os import warnings -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import List, Union + +import ctranslate2 +import faster_whisper import numpy as np import torch -import tqdm -import ffmpeg -from whisper.audio import ( - FRAMES_PER_SECOND, - HOP_LENGTH, - N_FRAMES, - N_SAMPLES, - SAMPLE_RATE, - CHUNK_LENGTH, - log_mel_spectrogram, - pad_or_trim, - load_audio -) -from whisper.decoding import DecodingOptions, DecodingResult -from whisper.timing import add_word_timestamps -from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer -from whisper.utils import ( - exact_div, - format_timestamp, - make_safe, -) - -if TYPE_CHECKING: - from whisper.model import Whisper - -from .vad import merge_chunks - -def transcribe( - model: "Whisper", - audio: Union[str, np.ndarray, torch.Tensor] = None, - mel: np.ndarray = None, - verbose: Optional[bool] = None, - temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), - compression_ratio_threshold: Optional[float] = 2.4, - logprob_threshold: Optional[float] = -1.0, - no_speech_threshold: Optional[float] = 0.6, - condition_on_previous_text: bool = True, - initial_prompt: Optional[str] = None, - word_timestamps: bool = False, - prepend_punctuations: str = "\"'“¿([{-", - append_punctuations: str = "\"'.。,,!!??::”)]}、", - **decode_options, -): - """ - Transcribe an audio file using Whisper. - We redefine the Whisper transcribe function to allow mel input (for sequential slicing of audio) - - Parameters - ---------- - model: Whisper - The Whisper model instance - - audio: Union[str, np.ndarray, torch.Tensor] - The path to the audio file to open, or the audio waveform - - mel: np.ndarray - Mel spectrogram of audio segment. - - verbose: bool - Whether to display the text being decoded to the console. If True, displays all the details, - If False, displays minimal details. If None, does not display anything - - temperature: Union[float, Tuple[float, ...]] - Temperature for sampling. It can be a tuple of temperatures, which will be successively used - upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. - - compression_ratio_threshold: float - If the gzip compression ratio is above this value, treat as failed - - logprob_threshold: float - If the average log probability over sampled tokens is below this value, treat as failed - - no_speech_threshold: float - If the no_speech probability is higher than this value AND the average log probability - over sampled tokens is below `logprob_threshold`, consider the segment as silent - - condition_on_previous_text: bool - if True, the previous output of the model is provided as a prompt for the next window; - disabling may make the text inconsistent across windows, but the model becomes less prone to - getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. - - word_timestamps: bool - Extract word-level timestamps using the cross-attention pattern and dynamic time warping, - and include the timestamps for each word in each segment. - - prepend_punctuations: str - If word_timestamps is True, merge these punctuation symbols with the next word - - append_punctuations: str - If word_timestamps is True, merge these punctuation symbols with the previous word - - initial_prompt: Optional[str] - Optional text to provide as a prompt for the first window. This can be used to provide, or - "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns - to make it more likely to predict those word correctly. +from transformers import Pipeline +from transformers.pipelines.pt_utils import PipelineIterator + +from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram +from .vad import load_vad_model, merge_chunks + + +def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None, + vad_options=None, model=None): + '''Load a Whisper model for inference. + Args: + whisper_arch: str - The name of the Whisper model to load. + device: str - The device to load the model on. + compute_type: str - The compute type to use for the model. + options: dict - A dictionary of options to use for the model. + language: str - The language of the model. (use English for now) + Returns: + A Whisper pipeline. + ''' + + if whisper_arch.endswith(".en"): + language = "en" + + model = WhisperModel(whisper_arch, device=device, compute_type=compute_type) + if language is not None: + tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language) + else: + print("No language specified, language will be first be detected for each audio file (increases inference time).") + tokenizer = None + + default_asr_options = { + "beam_size": 5, + "best_of": 5, + "patience": 1, + "length_penalty": 1, + "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": False, + "initial_prompt": None, + "prefix": None, + "suppress_blank": True, + "suppress_tokens": [-1], + "without_timestamps": True, + "max_initial_timestamp": 0.0, + "word_timestamps": False, + "prepend_punctuations": "\"'“¿([{-", + "append_punctuations": "\"'.。,,!!??::”)]}、" + } + + if asr_options is not None: + default_asr_options.update(asr_options) + default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options) + + default_vad_options = { + "vad_onset": 0.500, + "vad_offset": 0.363 + } + + if vad_options is not None: + default_vad_options.update(vad_options) + + vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options) + + return FasterWhisperPipeline(model, vad_model, default_asr_options, tokenizer) + + + +class WhisperModel(faster_whisper.WhisperModel): + ''' + FasterWhisperModel provides batched inference for faster-whisper. + Currently only works in non-timestamp mode and fixed prompt for all samples in batch. + ''' + + def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None): + batch_size = features.shape[0] + all_tokens = [] + prompt_reset_since = 0 + if options.initial_prompt is not None: + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + previous_tokens = all_tokens[prompt_reset_since:] + prompt = self.get_prompt( + tokenizer, + previous_tokens, + without_timestamps=options.without_timestamps, + prefix=options.prefix, + ) - decode_options: dict - Keyword arguments to construct `DecodingOptions` instances + encoder_output = self.encode(features) - Returns - ------- - A dictionary containing the resulting text ("text") and segment-level details ("segments"), and - the spoken language ("language"), which is detected when `decode_options["language"]` is None. - """ - dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 - if model.device == torch.device("cpu"): - if torch.cuda.is_available(): - warnings.warn("Performing inference on CPU when CUDA is available") - if dtype == torch.float16: - warnings.warn("FP16 is not supported on CPU; using FP32 instead") - dtype = torch.float32 - - if dtype == torch.float32: - decode_options["fp16"] = False - - # Pad 30-seconds of silence to the input audio, for slicing - if mel is None: - if audio is None: - raise ValueError("Transcribe needs either audio or mel as input, currently both are none.") - mel = log_mel_spectrogram(audio, padding=N_SAMPLES) - content_frames = mel.shape[-1] - N_FRAMES - - if decode_options.get("language", None) is None: - if not model.is_multilingual: - decode_options["language"] = "en" - else: - if verbose: - print( - "Detecting language using up to the first 30 seconds. Use `--language` to specify the language" - ) - mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) - _, probs = model.detect_language(mel_segment) - decode_options["language"] = max(probs, key=probs.get) - if verbose is not None: - print( - f"Detected language: {LANGUAGES[decode_options['language']].title()}" - ) - - language: str = decode_options["language"] - task: str = decode_options.get("task", "transcribe") - tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) - - if word_timestamps and task == "translate": - warnings.warn("Word-level timestamps on translations may not be reliable.") - - def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: - temperatures = ( - [temperature] if isinstance(temperature, (int, float)) else temperature + max_initial_timestamp_index = int( + round(options.max_initial_timestamp / self.time_precision) ) - decode_result = None - - for t in temperatures: - kwargs = {**decode_options} - if t > 0: - # disable beam_size and patience when t > 0 - kwargs.pop("beam_size", None) - kwargs.pop("patience", None) - else: - # disable best_of when t == 0 - kwargs.pop("best_of", None) - - options = DecodingOptions(**kwargs, temperature=t) - decode_result = model.decode(segment, options) - - needs_fallback = False - if ( - compression_ratio_threshold is not None - and decode_result.compression_ratio > compression_ratio_threshold - ): - needs_fallback = True # too repetitive - if ( - logprob_threshold is not None - and decode_result.avg_logprob < logprob_threshold - ): - needs_fallback = True # average log probability is too low - - if not needs_fallback: - break - - return decode_result - - seek = 0 - input_stride = exact_div( - N_FRAMES, model.dims.n_audio_ctx - ) # mel frames per output token: 2 - time_precision = ( - input_stride * HOP_LENGTH / SAMPLE_RATE - ) # time per output token: 0.02 (seconds) - all_tokens = [] - all_segments = [] - prompt_reset_since = 0 - - if initial_prompt is not None: - initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) - all_tokens.extend(initial_prompt_tokens) - else: - initial_prompt_tokens = [] - def new_segment( - *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult - ): - tokens = tokens.tolist() - text_tokens = [token for token in tokens if token < tokenizer.eot] - return { - "seek": seek, - "start": start, - "end": end, - "text": tokenizer.decode(text_tokens), - "tokens": tokens, - "temperature": result.temperature, - "avg_logprob": result.avg_logprob, - "compression_ratio": result.compression_ratio, - "no_speech_prob": result.no_speech_prob, - } - - - # show the progress bar when verbose is False (if True, transcribed text will be printed) - with tqdm.tqdm( - total=content_frames, unit="frames", disable=verbose is not False - ) as pbar: - while seek < content_frames: - time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) - mel_segment = mel[:, seek : seek + N_FRAMES] - segment_size = min(N_FRAMES, content_frames - seek) - segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE - mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) - - decode_options["prompt"] = all_tokens[prompt_reset_since:] - result: DecodingResult = decode_with_fallback(mel_segment) - tokens = torch.tensor(result.tokens) - if no_speech_threshold is not None: - # no voice activity check - should_skip = result.no_speech_prob > no_speech_threshold - if ( - logprob_threshold is not None - and result.avg_logprob > logprob_threshold - ): - # don't skip if the logprob is high enough, despite the no_speech_prob - should_skip = False - - if should_skip: - seek += segment_size # fast-forward to the next segment boundary - continue - - previous_seek = seek - current_segments = [] - - timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) - single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] - - consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] - consecutive.add_(1) - if len(consecutive) > 0: - # if the output contains two consecutive timestamp tokens - slices = consecutive.tolist() - if single_timestamp_ending: - slices.append(len(tokens)) - - last_slice = 0 - for current_slice in slices: - sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_pos = ( - sliced_tokens[0].item() - tokenizer.timestamp_begin - ) - end_timestamp_pos = ( - sliced_tokens[-1].item() - tokenizer.timestamp_begin - ) - - # clamp end-time to at least be 1 frame after start-time - end_timestamp_pos = max(end_timestamp_pos, start_timestamp_pos + time_precision) - - current_segments.append( - new_segment( - start=time_offset + start_timestamp_pos * time_precision, - end=time_offset + end_timestamp_pos * time_precision, - tokens=sliced_tokens, - result=result, - ) - ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. - seek += segment_size - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - last_timestamp_pos = ( - tokens[last_slice - 1].item() - tokenizer.timestamp_begin - ) - seek += last_timestamp_pos * input_stride - else: - duration = segment_duration - timestamps = tokens[timestamp_tokens.nonzero().flatten()] - if ( - len(timestamps) > 0 - and timestamps[-1].item() != tokenizer.timestamp_begin - ): - # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = ( - timestamps[-1].item() - tokenizer.timestamp_begin - ) - duration = last_timestamp_pos * time_precision - - current_segments.append( - new_segment( - start=time_offset, - end=time_offset + duration, - tokens=tokens, - result=result, - ) - ) - seek += segment_size - - if not condition_on_previous_text or result.temperature > 0.5: - # do not feed the prompt tokens if a high temperature was used - prompt_reset_since = len(all_tokens) - - if word_timestamps: - add_word_timestamps( - segments=current_segments, - model=model, - tokenizer=tokenizer, - mel=mel_segment, - num_frames=segment_size, - prepend_punctuations=prepend_punctuations, - append_punctuations=append_punctuations, - ) - word_end_timestamps = [ - w["end"] for s in current_segments for w in s["words"] - ] - if not single_timestamp_ending and len(word_end_timestamps) > 0: - seek_shift = round( - (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND - ) - if seek_shift > 0: - seek = previous_seek + seek_shift - - if verbose: - for segment in current_segments: - start, end, text = segment["start"], segment["end"], segment["text"] - line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" - print(make_safe(line)) - - # if a segment is instantaneous or does not contain text, clear it - for i, segment in enumerate(current_segments): - if segment["start"] == segment["end"] or segment["text"].strip() == "": - segment["text"] = "" - segment["tokens"] = [] - segment["words"] = [] - - all_segments.extend( - [ - {"id": i, **segment} - for i, segment in enumerate( - current_segments, start=len(all_segments) - ) - ] - ) - all_tokens.extend( - [token for segment in current_segments for token in segment["tokens"]] + result = self.model.generate( + encoder_output, + [prompt] * batch_size, + # length_penalty=options.length_penalty, + # max_length=self.max_length, + # return_scores=True, + # return_no_speech_prob=True, + # suppress_blank=options.suppress_blank, + # suppress_tokens=options.suppress_tokens, + # max_initial_timestamp_index=max_initial_timestamp_index, ) - - # update progress bar - pbar.update(min(content_frames, seek) - previous_seek) - - - return dict( - text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), - segments=all_segments, - language=language, - ) - - -def transcribe_with_vad( - model: "Whisper", - audio: str, - vad_pipeline, - mel = None, - verbose: Optional[bool] = None, - **kwargs -): + + tokens_batch = [x.sequences_ids[0] for x in result] + + def decode_batch(tokens: List[List[int]]) -> str: + res = [] + for tk in tokens: + res.append([token for token in tk if token < tokenizer.eot]) + # text_tokens = [token for token in tokens if token < self.eot] + return tokenizer.tokenizer.decode_batch(res) + + text = decode_batch(tokens_batch) + + return text + + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: + # When the model is running on multiple GPUs, the encoder output should be moved + # to the CPU since we don't know which GPU will handle the next job. + to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 + # unsqueeze if batch size = 1 + if len(features.shape) == 2: + features = np.expand_dims(features, 0) + features = faster_whisper.transcribe.get_ctranslate2_storage(features) + + return self.model.encode(features, to_cpu=to_cpu) + +class FasterWhisperPipeline(Pipeline): """ - Transcribe per VAD segment + Huggingface Pipeline wrapper for FasterWhisperModel. """ - - vad_segments = vad_pipeline(audio) - - # if not torch.is_tensor(audio): - # if isinstance(audio, str): - audio = load_audio(audio) - audio = torch.from_numpy(audio) - - prev = 0 - output = {"segments": []} - - # merge segments to approx 30s inputs to make whisper most appropraite - vad_segments = merge_chunks(vad_segments, chunk_size=CHUNK_LENGTH) - if len(vad_segments) == 0: - return output - - print(">>Performing transcription...") - for sdx, seg_t in enumerate(vad_segments): - if verbose: - print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~") - seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE), int(seg_t["end"] * SAMPLE_RATE) - local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev - audio = audio[local_f_start:] # seek forward - seg_audio = audio[:local_f_end-local_f_start] # seek forward - prev = seg_f_start - local_mel = log_mel_spectrogram(seg_audio, padding=N_SAMPLES) - # need to pad - - result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs) - seg_t["text"] = result["text"] - output["segments"].append( - { - "start": seg_t["start"], - "end": seg_t["end"], - "language": result["language"], - "text": result["text"], - "seg-text": [x["text"] for x in result["segments"]], - "seg-start": [x["start"] for x in result["segments"]], - "seg-end": [x["end"] for x in result["segments"]], + # TODO: + # - add support for timestamp mode + # - add support for custom inference kwargs + + def __init__( + self, + model, + vad, + options, + tokenizer=None, + device: Union[int, str, "torch.device"] = -1, + framework = "pt", + **kwargs + ): + self.model = model + self.tokenizer = tokenizer + self.options = options + self._batch_size = kwargs.pop("batch_size", None) + self._num_workers = 1 + self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) + self.call_count = 0 + self.framework = framework + if self.framework == "pt": + if isinstance(device, torch.device): + self.device = device + elif isinstance(device, str): + self.device = torch.device(device) + elif device < 0: + self.device = torch.device("cpu") + else: + self.device = torch.device(f"cuda:{device}") + else: + self.device = device + + super(Pipeline, self).__init__() + self.vad_model = vad + + def _sanitize_parameters(self, **kwargs): + preprocess_kwargs = {} + if "tokenizer" in kwargs: + preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] + return preprocess_kwargs, {}, {} + + def preprocess(self, audio): + audio = audio['inputs'] + features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0]) + return {'inputs': features} + + def _forward(self, model_inputs): + outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options) + return {'text': outputs} + + def postprocess(self, model_outputs): + return model_outputs + + def get_iterator( + self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params + ): + dataset = PipelineIterator(inputs, self.preprocess, preprocess_params) + if "TOKENIZERS_PARALLELISM" not in os.environ: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + # TODO hack by collating feature_extractor and image_processor + + def stack(items): + return {'inputs': torch.stack([x['inputs'] for x in items])} + dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack) + model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size) + final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params) + return final_iterator + + def transcribe( + self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0 + ): + if isinstance(audio, str): + audio = load_audio(audio) + + def data(audio, segments): + for seg in segments: + f1 = int(seg['start'] * SAMPLE_RATE) + f2 = int(seg['end'] * SAMPLE_RATE) + # print(f2-f1) + yield {'inputs': audio[f1:f2]} + + vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) + vad_segments = merge_chunks(vad_segments, 30) + + del_tokenizer = False + if self.tokenizer is None: + language = self.detect_language(audio) + self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language) + del_tokenizer = True + else: + language = self.tokenizer.language_code + + segments = [] + batch_size = batch_size or self._batch_size + for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)): + text = out['text'] + if batch_size in [0, 1, None]: + text = text[0] + segments.append( + { + "text": out['text'], + "start": round(vad_segments[idx]['start'], 3), + "end": round(vad_segments[idx]['end'], 3) } ) - - output["language"] = output["segments"][0]["language"] - - return output + + if del_tokenizer: + self.tokenizer = None + + return {"segments": segments, "language": language} + + + def detect_language(self, audio: np.ndarray): + if audio.shape[0] < N_SAMPLES: + print("Warning: audio is shorter than 30s, language detection may be inaccurate.") + segment = log_mel_spectrogram(audio[: N_SAMPLES], + padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) + encoder_output = self.model.encode(segment) + results = self.model.model.detect_language(encoder_output) + language_token, language_probability = results[0][0] + language = language_token[2:-2] + print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...") + return language diff --git a/whisperx/assets/mel_filters.npz b/whisperx/assets/mel_filters.npz new file mode 100644 index 00000000..1a783924 Binary files /dev/null and b/whisperx/assets/mel_filters.npz differ diff --git a/whisperx/audio.py b/whisperx/audio.py new file mode 100644 index 00000000..513ab7c9 --- /dev/null +++ b/whisperx/audio.py @@ -0,0 +1,147 @@ +import os +from functools import lru_cache +from typing import Optional, Union + +import ffmpeg +import numpy as np +import torch +import torch.nn.functional as F + +from .utils import exact_div + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +N_MELS = 80 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input + +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + try: + # This launches a subprocess to decode audio while down-mixing and resampling as necessary. + # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. + out, _ = ( + ffmpeg.input(file, threads=0) + .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) + .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) + ) + except ffmpeg.Error as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select( + dim=axis, index=torch.arange(length, device=array.device) + ) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + with np.load( + os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + ) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram( + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int = N_MELS, + padding: int = 0, + device: Optional[Union[str, torch.device]] = None, +): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 34dfc634..320d2a48 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -1,73 +1,63 @@ import numpy as np import pandas as pd from pyannote.audio import Pipeline +from typing import Optional, Union +import torch class DiarizationPipeline: def __init__( self, model_name="pyannote/speaker-diarization@2.1", use_auth_token=None, + device: Optional[Union[str, torch.device]] = "cpu", ): - self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token) + if isinstance(device, str): + device = torch.device(device) + self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) def __call__(self, audio, min_speakers=None, max_speakers=None): segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers) diarize_df = pd.DataFrame(segments.itertracks(yield_label=True)) diarize_df['start'] = diarize_df[0].apply(lambda x: x.start) diarize_df['end'] = diarize_df[0].apply(lambda x: x.end) + diarize_df.rename(columns={2: "speaker"}, inplace=True) return diarize_df -def assign_word_speakers(diarize_df, result_segments, fill_nearest=False): - for seg in result_segments: - wdf = seg['word-segments'] - if len(wdf['start'].dropna()) == 0: - wdf['start'] = seg['start'] - wdf['end'] = seg['end'] - speakers = [] - for wdx, wrow in wdf.iterrows(): - if not np.isnan(wrow['start']): - diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start']) - diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start']) - # remove no hit - if not fill_nearest: - dia_tmp = diarize_df[diarize_df['intersection'] > 0] - else: - dia_tmp = diarize_df - if len(dia_tmp) == 0: - speaker = None - else: - speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2] - else: - speaker = None - speakers.append(speaker) - seg['word-segments']['speaker'] = speakers - speaker_count = pd.Series(speakers).value_counts() - if len(speaker_count) == 0: - seg["speaker"]= "UNKNOWN" +def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False): + transcript_segments = transcript_result["segments"] + for seg in transcript_segments: + # assign speaker to segment (if any) + diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start']) + diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start']) + # remove no hit, otherwise we look for closest (even negative intersection...) + if not fill_nearest: + dia_tmp = diarize_df[diarize_df['intersection'] > 0] else: - seg["speaker"] = speaker_count.index[0] + dia_tmp = diarize_df + if len(dia_tmp) > 0: + # sum over speakers + speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] + seg["speaker"] = speaker + + # assign speaker to words + if 'words' in seg: + for word in seg['words']: + if 'start' in word: + diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start']) + diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start']) + # remove no hit + if not fill_nearest: + dia_tmp = diarize_df[diarize_df['intersection'] > 0] + else: + dia_tmp = diarize_df + if len(dia_tmp) > 0: + # sum over speakers + speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] + word["speaker"] = speaker + + return transcript_result - # create word level segments for .srt - word_seg = [] - for seg in result_segments: - wseg = pd.DataFrame(seg["word-segments"]) - for wdx, wrow in wseg.iterrows(): - if wrow["start"] is not None: - speaker = wrow['speaker'] - if speaker is None or speaker == np.nan: - speaker = "UNKNOWN" - word_seg.append( - { - "start": wrow["start"], - "end": wrow["end"], - "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])] - } - ) - - # TODO: create segments but split words on new speaker - - return result_segments, word_seg class Segment: def __init__(self, start, end, speaker=None): diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index ed918e00..d09c5f66 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -1,37 +1,31 @@ import argparse -import os import gc +import os import warnings -from typing import TYPE_CHECKING, Optional, Tuple, Union + import numpy as np import torch -import tempfile -import ffmpeg -from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE -from whisper.audio import SAMPLE_RATE -from whisper.utils import ( - optional_float, - optional_int, - str2bool, -) - -from .alignment import load_align_model, align -from .asr import transcribe, transcribe_with_vad + +from .alignment import align, load_align_model +from .asr import load_model +from .audio import load_audio from .diarize import DiarizationPipeline, assign_word_speakers -from .utils import get_writer -from .vad import load_vad_model +from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float, + optional_int, str2bool) -def cli(): - from whisper import available_models +def cli(): # fmt: off parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") - parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") + parser.add_argument("--model", default="small", help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") + parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference") + parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation") + parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") - parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="format of the output file; if not specified, all available formats will be produced") + parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") @@ -39,20 +33,18 @@ def cli(): # alignment params parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment") - parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment (if not using VAD).") - parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment (if not using VAD)") parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.") parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment") + parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file") # vad params - parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747") parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected") parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.") # diarization params parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word") - parser.add_argument("--min_speakers", default=None, type=int) - parser.add_argument("--max_speakers", default=None, type=int) + parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file") + parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file") parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") @@ -69,37 +61,34 @@ def cli(): parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") - parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") - parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") - parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") + + parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line") + parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment") + parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") + parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line") + parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") - # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.") - parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).") # fmt: on args = parser.parse_args().__dict__ model_name: str = args.pop("model") - model_dir: str = args.pop("model_dir") + batch_size: int = args.pop("batch_size") output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") device: str = args.pop("device") + compute_type: str = args.pop("compute_type") + # model_flush: bool = args.pop("model_flush") os.makedirs(output_dir, exist_ok=True) - tmp_dir: str = args.pop("tmp_dir") - if tmp_dir is not None: - os.makedirs(tmp_dir, exist_ok=True) - align_model: str = args.pop("align_model") - align_extend: float = args.pop("align_extend") - align_from_prev: bool = args.pop("align_from_prev") interpolate_method: str = args.pop("interpolate_method") no_align: bool = args.pop("no_align") + return_char_alignments: bool = args.pop("return_char_alignments") hf_token: str = args.pop("hf_token") - vad_filter: bool = args.pop("vad_filter") vad_onset: float = args.pop("vad_onset") vad_offset: float = args.pop("vad_offset") @@ -107,18 +96,6 @@ def cli(): min_speakers: int = args.pop("min_speakers") max_speakers: int = args.pop("max_speakers") - if vad_filter: - from pyannote.audio import Pipeline - from pyannote.audio import Model, Pipeline - vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token) - else: - vad_model = None - - # if model_flush: - # print(">>Model flushing activated... Only loading model after ASR stage") - # del align_model - # align_model = "" - if model_name.endswith(".en") and args["language"] not in {"en", "English"}: if args["language"] is not None: @@ -136,39 +113,43 @@ def cli(): if (threads := args.pop("threads")) > 0: torch.set_num_threads(threads) - from whisper import load_model + asr_options = { + "beam_size": args.pop("beam_size"), + "patience": args.pop("patience"), + "length_penalty": args.pop("length_penalty"), + "temperatures": temperature, + "compression_ratio_threshold": args.pop("compression_ratio_threshold"), + "log_prob_threshold": args.pop("logprob_threshold"), + "no_speech_threshold": args.pop("no_speech_threshold"), + "condition_on_previous_text": False, + "initial_prompt": args.pop("initial_prompt"), + } writer = get_writer(output_format, output_dir) - + word_options = ["highlight_words", "max_line_count", "max_line_width"] + if no_align: + for option in word_options: + if args[option]: + parser.error(f"--{option} requires --word_timestamps True") + if args["max_line_count"] and not args["max_line_width"]: + warnings.warn("--max_line_count has no effect without --max_line_width") + writer_args = {arg: args.pop(arg) for arg in word_options} + # Part 1: VAD & ASR Loop results = [] tmp_results = [] - model = load_model(model_name, device=device, download_root=model_dir) - for audio_path in args.pop("audio"): - input_audio_path = audio_path - tfile = None + # model = load_model(model_name, device=device, download_root=model_dir) + model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},) + for audio_path in args.pop("audio"): + audio = load_audio(audio_path) # >> VAD & ASR - if vad_model is not None: - if not audio_path.endswith(".wav"): - print(">>VAD requires .wav format, converting to wav as a tempfile...") - audio_basename = os.path.splitext(os.path.basename(audio_path))[0] - if tmp_dir is not None: - input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav") - else: - input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav") - ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"]) - print(">>Performing VAD...") - result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args) - else: - print(">>Performing transcription...") - result = transcribe(model, input_audio_path, temperature=temperature, **args) - - results.append((result, input_audio_path)) + print(">>Performing transcription...") + result = model.transcribe(audio, batch_size=batch_size) + results.append((result, audio_path)) # Unload Whisper and VAD del model - del vad_model gc.collect() torch.cuda.empty_cache() @@ -178,17 +159,23 @@ def cli(): results = [] align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified align_model, align_metadata = load_align_model(align_language, device, model_name=align_model) - for result, input_audio_path in tmp_results: + for result, audio_path in tmp_results: # >> Align + if len(tmp_results) > 1: + input_audio = audio_path + else: + # lazily load audio from part 1 + input_audio = audio + if align_model is not None and len(result["segments"]) > 0: if result.get("language", "en") != align_metadata["language"]: # load new language print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...") align_model, align_metadata = load_align_model(result["language"], device) print(">>Performing alignment...") - result = align(result["segments"], align_model, align_metadata, input_audio_path, device, - extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method) - results.append((result, input_audio_path)) + result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments) + + results.append((result, audio_path)) # Unload align model del align_model @@ -200,21 +187,16 @@ def cli(): if hf_token is None: print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...") tmp_results = results + print(">>Performing diarization...") results = [] - diarize_model = DiarizationPipeline(use_auth_token=hf_token) + diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) for result, input_audio_path in tmp_results: diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) - results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"]) - result = {"segments": results_segments, "word_segments": word_segments} + result = assign_word_speakers(diarize_segments, result) results.append((result, input_audio_path)) - # >> Write for result, audio_path in results: - writer(result, audio_path) - - # cleanup - if input_audio_path != audio_path: - os.remove(input_audio_path) + writer(result, audio_path, writer_args) if __name__ == "__main__": cli() \ No newline at end of file diff --git a/whisperx/utils.py b/whisperx/utils.py index 14e298b1..d042bb70 100644 --- a/whisperx/utils.py +++ b/whisperx/utils.py @@ -1,280 +1,320 @@ +import json import os +import re +import sys import zlib -from typing import Callable, TextIO, Iterator, Tuple -import pandas as pd -import numpy as np - -def interpolate_nans(x, method='nearest'): - if x.notnull().sum() > 1: - return x.interpolate(method=method).ffill().bfill() +from typing import Callable, Optional, TextIO + +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", +} + +# language code lookup by name, with a few language aliases +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", +} + + +system_encoding = sys.getdefaultencoding() + +if system_encoding != "utf-8": + + def make_safe(string): + # replaces any character not representable using the system default encoding with an '?', + # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). + return string.encode(system_encoding, errors="replace").decode(system_encoding) + +else: + + def make_safe(string): + # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding + return string + + +def exact_div(x, y): + assert x % y == 0 + return x // y + + +def str2bool(string): + str2val = {"True": True, "False": False} + if string in str2val: + return str2val[string] else: - return x.ffill().bfill() - - -def write_txt(transcript: Iterator[dict], file: TextIO): - for segment in transcript: - print(segment['text'].strip(), file=file, flush=True) - - -def write_vtt(transcript: Iterator[dict], file: TextIO): - print("WEBVTT\n", file=file) - for segment in transcript: - print( - f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" - f"{segment['text'].strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) + raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") -def write_tsv(transcript: Iterator[dict], file: TextIO): - print("start", "end", "text", sep="\t", file=file) - for segment in transcript: - print(segment['start'], file=file, end="\t") - print(segment['end'], file=file, end="\t") - print(segment['text'].strip().replace("\t", " "), file=file, flush=True) +def optional_int(string): + return None if string == "None" else int(string) -def write_srt(transcript: Iterator[dict], file: TextIO): - """ - Write a transcript to a file in SRT format. - Example usage: - from pathlib import Path - from whisper.utils import write_srt +def optional_float(string): + return None if string == "None" else float(string) - result = transcribe(model, audio_path, temperature=temperature, **args) - # save SRT - audio_basename = Path(audio_path).stem - with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: - write_srt(result["segments"], file=srt) - """ - for i, segment in enumerate(transcript, start=1): - # write srt lines - print( - f"{i}\n" - f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " - f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" - f"{segment['text'].strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) +def compression_ratio(text) -> float: + text_bytes = text.encode("utf-8") + return len(text_bytes) / len(zlib.compress(text_bytes)) -def write_ass(transcript: Iterator[dict], - file: TextIO, - resolution: str = "word", - color: str = None, underline=True, - prefmt: str = None, suffmt: str = None, - font: str = None, font_size: int = 24, - strip=True, **kwargs): - """ - Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py - Generate Advanced SubStation Alpha (ass) file from results to - display both phrase-level & word-level timestamp simultaneously by: - -using segment-level timestamps display phrases as usual - -using word-level timestamps change formats (e.g. color/underline) of the word in the displayed segment - Note: ass file is used in the same way as srt, vtt, etc. - Parameters - ---------- - transcript: dict - results from modified model - file: TextIO - file object to write to - resolution: str - "word" or "char", timestamp resolution to highlight. - color: str - color code for a word at its corresponding timestamp - reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00) - underline: bool - whether to underline a word at its corresponding timestamp - prefmt: str - used to specify format for word-level timestamps (must be use with 'suffmt' and overrides 'color'&'underline') - appears as such in the .ass file: - Hi, {}how{} are you? - reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm - suffmt: str - used to specify format for word-level timestamps (must be use with 'prefmt' and overrides 'color'&'underline') - appears as such in the .ass file: - Hi, {}how{} are you? - reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm - font: str - word font (default: Arial) - font_size: int - word font size (default: 48) - kwargs: - used for format styles: - 'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold', - 'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline', - 'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding' +def format_timestamp( + seconds: float, always_include_hours: bool = False, decimal_marker: str = "." +): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) - """ + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 - fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff', - 'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0', - 'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100', - 'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0', - 'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'} + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 - for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()): - kwargs[k] = f'&H{kwargs[k]}' + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 - fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict) + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + ) - if font: - fmt_style_dict.update(Fontname=font) - if font_size: - fmt_style_dict.update(Fontsize=font_size) - fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}' +class ResultWriter: + extension: str - styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}' + def __init__(self, output_dir: str): + self.output_dir = output_dir - ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \ - f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \ - f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n' + def __call__(self, result: dict, audio_path: str, options: dict): + audio_basename = os.path.basename(audio_path) + audio_basename = os.path.splitext(audio_basename)[0] + output_path = os.path.join( + self.output_dir, audio_basename + "." + self.extension + ) - if prefmt or suffmt: - if suffmt: - assert prefmt, 'prefmt must be used along with suffmt' + with open(output_path, "w", encoding="utf-8") as f: + self.write_result(result, file=f, options=options) + + def write_result(self, result: dict, file: TextIO, options: dict): + raise NotImplementedError + + +class WriteTXT(ResultWriter): + extension: str = "txt" + + def write_result(self, result: dict, file: TextIO, options: dict): + for segment in result["segments"]: + print(segment["text"].strip(), file=file, flush=True) + + +class SubtitlesWriter(ResultWriter): + always_include_hours: bool + decimal_marker: str + + def iterate_result(self, result: dict, options: dict): + raw_max_line_width: Optional[int] = options["max_line_width"] + max_line_count: Optional[int] = options["max_line_count"] + highlight_words: bool = options["highlight_words"] + max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width + preserve_segments = max_line_count is None or raw_max_line_width is None + + def iterate_subtitles(): + line_len = 0 + line_count = 1 + # the next subtitle to yield (a list of word timings with whitespace) + subtitle: list[dict] = [] + times = [] + last = result["segments"][0]["start"] + for segment in result["segments"]: + for i, original_timing in enumerate(segment["words"]): + timing = original_timing.copy() + long_pause = not preserve_segments + if "start" in timing: + long_pause = long_pause and timing["start"] - last > 3.0 + else: + long_pause = False + has_room = line_len + len(timing["word"]) <= max_line_width + seg_break = i == 0 and len(subtitle) > 0 and preserve_segments + if line_len > 0 and has_room and not long_pause and not seg_break: + # line continuation + line_len += len(timing["word"]) + else: + # new line + timing["word"] = timing["word"].strip() + if ( + len(subtitle) > 0 + and max_line_count is not None + and (long_pause or line_count >= max_line_count) + or seg_break + ): + # subtitle break + yield subtitle, times + subtitle = [] + times = [] + line_count = 1 + elif line_len > 0: + # line break + line_count += 1 + timing["word"] = "\n" + timing["word"] + line_len = len(timing["word"].strip()) + subtitle.append(timing) + times.append((segment["start"], segment["end"], segment.get("speaker"))) + if "start" in timing: + last = timing["start"] + if len(subtitle) > 0: + yield subtitle, times + + if "words" in result["segments"][0]: + for subtitle, _ in iterate_subtitles(): + sstart, ssend, speaker = _[0] + subtitle_start = self.format_timestamp(sstart) + subtitle_end = self.format_timestamp(ssend) + subtitle_text = " ".join([word["word"] for word in subtitle]) + has_timing = any(["start" in word for word in subtitle]) + + # add [$SPEAKER_ID]: to each subtitle if speaker is available + prefix = "" + if speaker is not None: + prefix = f"[{speaker}]: " + + if highlight_words and has_timing: + last = subtitle_start + all_words = [timing["word"] for timing in subtitle] + for i, this_word in enumerate(subtitle): + if "start" in this_word: + start = self.format_timestamp(this_word["start"]) + end = self.format_timestamp(this_word["end"]) + if last != start: + yield last, start, subtitle_text + + yield start, end, prefix + " ".join( + [ + re.sub(r"^(\s*)(.*)$", r"\1\2", word) + if j == i + else word + for j, word in enumerate(all_words) + ] + ) + last = end + else: + yield subtitle_start, subtitle_end, prefix + subtitle_text else: - suffmt = r'\r' - else: - if not color: - color = 'HFF00' - underline_code = r'\u1' if underline else '' - - prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}' - suffmt = r'{\r}' - - def secs_to_hhmmss(secs: Tuple[float, int]): - mm, ss = divmod(secs, 60) - hh, mm = divmod(mm, 60) - return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}' - - - def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str: - if idx_0 == -1: - text = chars - else: - text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}' - return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \ - f"Default,,0,0,0,,{text.strip() if strip else text}" - - if resolution == "word": - resolution_key = "word-segments" - elif resolution == "char": - resolution_key = "char-segments" - else: - raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution) - - ass_arr = [] - - for segment in transcript: - # if "12" in segment['text']: - # import pdb; pdb.set_trace() - if resolution_key in segment: - res_segs = pd.DataFrame(segment[resolution_key]) - prev = segment['start'] - if "speaker" in segment: - speaker_str = f"[{segment['speaker']}]: " - else: - speaker_str = "" - for cdx, crow in res_segs.iterrows(): - if not np.isnan(crow['start']): - if resolution == "char": - idx_0 = cdx - idx_1 = cdx + 1 - elif resolution == "word": - idx_0 = int(crow["segment-text-start"]) - idx_1 = int(crow["segment-text-end"]) - # fill gap - if crow['start'] > prev: - filler_ts = { - "chars": speaker_str + segment['text'], - "start": prev, - "end": crow['start'], - "idx_0": -1, - "idx_1": -1 - } - - ass_arr.append(filler_ts) - # highlight current word - f_word_ts = { - "chars": speaker_str + segment['text'], - "start": crow['start'], - "end": crow['end'], - "idx_0": idx_0 + len(speaker_str), - "idx_1": idx_1 + len(speaker_str) - } - ass_arr.append(f_word_ts) - prev = crow['end'] - - ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr)) - - file.write(ass_str) - - -from whisper.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteVTT, WriteSRT, WriteTSV, WriteJSON, format_timestamp - -class WriteASS(ResultWriter): - extension: str = "ass" - - def write_result(self, result: dict, file: TextIO): - write_ass(result["segments"], file, resolution="word") - -class WriteASSchar(ResultWriter): - extension: str = "ass" - - def write_result(self, result: dict, file: TextIO): - write_ass(result["segments"], file, resolution="char") - -class WritePickle(ResultWriter): - extension: str = "ass" - - def write_result(self, result: dict, file: TextIO): - pd.DataFrame(result["segments"]).to_pickle(file) - -class WriteSRTWord(ResultWriter): - extension: str = "word.srt" - always_include_hours: bool = True - decimal_marker: str = "," - - def iterate_result(self, result: dict): - for segment in result["word_segments"]: - segment_start = self.format_timestamp(segment["start"]) - segment_end = self.format_timestamp(segment["end"]) - segment_text = segment["text"].strip().replace("-->", "->") - - if word_timings := segment.get("words", None): - all_words = [timing["word"] for timing in word_timings] - all_words[0] = all_words[0].strip() # remove the leading space, if any - last = segment_start - for i, this_word in enumerate(word_timings): - start = self.format_timestamp(this_word["start"]) - end = self.format_timestamp(this_word["end"]) - if last != start: - yield last, start, segment_text - - yield start, end, "".join( - [ - f"{word}" if j == i else word - for j, word in enumerate(all_words) - ] - ) - last = end - - if last != segment_end: - yield last, segment_end, segment_text - else: + for segment in result["segments"]: + segment_start = self.format_timestamp(segment["start"]) + segment_end = self.format_timestamp(segment["end"]) + segment_text = segment["text"].strip().replace("-->", "->") + if "speaker" in segment: + segment_text = f"[{segment['speaker']}]: {segment_text}" yield segment_start, segment_end, segment_text - def write_result(self, result: dict, file: TextIO): - if "word_segments" not in result: - return - for i, (start, end, text) in enumerate(self.iterate_result(result), start=1): - print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) - def format_timestamp(self, seconds: float): return format_timestamp( seconds=seconds, @@ -282,36 +322,81 @@ def format_timestamp(self, seconds: float): decimal_marker=self.decimal_marker, ) -def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]: + +class WriteVTT(SubtitlesWriter): + extension: str = "vtt" + always_include_hours: bool = False + decimal_marker: str = "." + + def write_result(self, result: dict, file: TextIO, options: dict): + print("WEBVTT\n", file=file) + for start, end, text in self.iterate_result(result, options): + print(f"{start} --> {end}\n{text}\n", file=file, flush=True) + + +class WriteSRT(SubtitlesWriter): + extension: str = "srt" + always_include_hours: bool = True + decimal_marker: str = "," + + def write_result(self, result: dict, file: TextIO, options: dict): + for i, (start, end, text) in enumerate( + self.iterate_result(result, options), start=1 + ): + print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) + + +class WriteTSV(ResultWriter): + """ + Write a transcript to a file in TSV (tab-separated values) format containing lines like: + \t\t + + Using integer milliseconds as start and end times means there's no chance of interference from + an environment setting a language encoding that causes the decimal in a floating point number + to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. + """ + + extension: str = "tsv" + + def write_result(self, result: dict, file: TextIO, options: dict): + print("start", "end", "text", sep="\t", file=file) + for segment in result["segments"]: + print(round(1000 * segment["start"]), file=file, end="\t") + print(round(1000 * segment["end"]), file=file, end="\t") + print(segment["text"].strip().replace("\t", " "), file=file, flush=True) + + +class WriteJSON(ResultWriter): + extension: str = "json" + + def write_result(self, result: dict, file: TextIO, options: dict): + json.dump(result, file) + + +def get_writer( + output_format: str, output_dir: str +) -> Callable[[dict, TextIO, dict], None]: writers = { "txt": WriteTXT, "vtt": WriteVTT, "srt": WriteSRT, "tsv": WriteTSV, - "ass": WriteASS, - "srt-word": WriteSRTWord, - # "ass-char": WriteASSchar, - # "pickle": WritePickle, - # "json": WriteJSON, - } - - writers_other = { - "pkl": WritePickle, - "ass-char": WriteASSchar + "json": WriteJSON, } if output_format == "all": all_writers = [writer(output_dir) for writer in writers.values()] - def write_all(result: dict, file: TextIO): + def write_all(result: dict, file: TextIO, options: dict): for writer in all_writers: - writer(result, file) + writer(result, file, options) return write_all - if output_format in writers: - return writers[output_format](output_dir) - elif output_format in writers_other: - return writers_other[output_format](output_dir) + return writers[output_format](output_dir) + +def interpolate_nans(x, method='nearest'): + if x.notnull().sum() > 1: + return x.interpolate(method=method).ffill().bfill() else: - raise ValueError(f"Output format '{output_format}' not supported, choose from {writers.keys()} and {writers_other.keys()}") + return x.ffill().bfill() \ No newline at end of file diff --git a/whisperx/vad.py b/whisperx/vad.py index 933d270e..42b0bfbc 100644 --- a/whisperx/vad.py +++ b/whisperx/vad.py @@ -1,22 +1,23 @@ +import hashlib import os import urllib -import pandas as pd +from typing import Callable, Optional, Text, Union + import numpy as np +import pandas as pd import torch -import hashlib -from tqdm import tqdm -from typing import Optional, Callable, Union, Text -from pyannote.audio.core.io import AudioFile -from pyannote.core import Annotation, Segment, SlidingWindowFeature -from pyannote.audio.pipelines.utils import PipelineModel from pyannote.audio import Model +from pyannote.audio.core.io import AudioFile from pyannote.audio.pipelines import VoiceActivityDetection +from pyannote.audio.pipelines.utils import PipelineModel +from pyannote.core import Annotation, Segment, SlidingWindowFeature +from tqdm import tqdm + from .diarize import Segment as SegmentX -from typing import List, Tuple, Optional VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin" -def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None, model_fp=None): +def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None): model_dir = torch.hub._get_torch_home() os.makedirs(model_dir, exist_ok = True) if model_fp is None: