Skip to content

Commit

Permalink
improve time codes and add optional chunking on generated subtitle se…
Browse files Browse the repository at this point in the history
…gments
  • Loading branch information
baxtree committed Jan 2, 2025
1 parent 59bdd89 commit c06d9e1
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 34 deletions.
3 changes: 2 additions & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ tensorflow = ">=1.15.5,<2.12"
termcolor = "==1.1.0"
toml = "==0.10.0"
toolz = "==0.9.0"
torch = "<2.2.0"
torch = "<2.3.0"
torchaudio = "<2.3.0"
transformers = "<4.27.0"
urllib3 = "~=1.26.5"
wrapt = "==1.14.0"
Expand Down
3 changes: 2 additions & 1 deletion requirements-llm.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
sentencepiece~=0.1.95
torch<2.3.0
torchaudio<2.3.0
transformers<4.37.0
openai-whisper==20240930
openai-whisper==20240930
17 changes: 15 additions & 2 deletions subaligner/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ def main():
default=None,
help="Optional text to provide the transcribing context or specific phrases"
)
parser.add_argument(
"-mcl",
"--max_char_length",
type=int,
default=None,
help="Maximum number of characters for each generated subtitle segment"
)
from subaligner.llm import TranslationRecipe
from subaligner.llm import HelsinkiNLPFlavour
parser.add_argument(
Expand Down Expand Up @@ -356,9 +363,15 @@ def main():
from subaligner.transcriber import Transcriber
transcriber = Transcriber(recipe=FLAGS.transcription_recipe, flavour=FLAGS.transcription_flavour)
if "_transcribe_temp" in local_subtitle_path:
subtitle, frame_rate = transcriber.transcribe(video_file_path=local_video_path, language_code=stretch_in_lang, initial_prompt=FLAGS.initial_prompt)
subtitle, frame_rate = transcriber.transcribe(video_file_path=local_video_path,
language_code=stretch_in_lang,
initial_prompt=FLAGS.initial_prompt,
max_char_length=FLAGS.max_char_length)
else:
subtitle, frame_rate = transcriber.transcribe_with_subtitle_as_prompts(video_file_path=local_video_path, subtitle_file_path=local_subtitle_path, language_code=stretch_in_lang)
subtitle, frame_rate = transcriber.transcribe_with_subtitle_as_prompts(video_file_path=local_video_path,
subtitle_file_path=local_subtitle_path,
language_code=stretch_in_lang,
max_char_length=FLAGS.max_char_length)
aligned_subs = subtitle.subs
else:
print("ERROR: Unknown mode {}".format(FLAGS.mode))
Expand Down
160 changes: 139 additions & 21 deletions subaligner/transcriber.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import os
import whisper
import torch
from typing import Tuple, Optional
import concurrent.futures
import math
import multiprocessing as mp
import torchaudio
import numpy as np
from functools import partial
from threading import Lock
from typing import Tuple, Optional, Dict, List
from pysrt import SubRipTime
from whisper import Whisper
from whisper.tokenizer import LANGUAGES
from tqdm import tqdm
from .subtitle import Subtitle
from .media_helper import MediaHelper
from .llm import TranscriptionRecipe, WhisperFlavour
Expand Down Expand Up @@ -38,14 +47,20 @@ def __init__(self, recipe: str = TranscriptionRecipe.WHISPER.value, flavour: str
self.__flavour = flavour
self.__media_helper = MediaHelper()
self.__LOGGER = Logger().get_logger(__name__)
self.__lock = Lock()

def transcribe(self, video_file_path: str, language_code: str, initial_prompt: Optional[str] = None) -> Tuple[Subtitle, Optional[float]]:
def transcribe(self,
video_file_path: str,
language_code: str,
initial_prompt: Optional[str] = None,
max_char_length: Optional[int] = None) -> Tuple[Subtitle, Optional[float]]:
"""Transcribe an audiovisual file and generate subtitles.
Arguments:
video_file_path {string} -- The input video file path.
language_code {string} -- An alpha 3 language code derived from ISO 639-3.
initial_prompt {string} -- Optional text to provide the transcribing context or specific phrases.
initial_prompt {string} -- Optional Text to provide the transcribing context or specific phrases.
max_char_length {int} -- Optional Maximum number of characters for each generated subtitle segment.
Returns:
tuple: Generated subtitle after transcription and the detected frame rate
Expand All @@ -64,14 +79,24 @@ def transcribe(self, video_file_path: str, language_code: str, initial_prompt: O
self.__LOGGER.info("Start transcribing the audio...")
verbose = False if Logger.VERBOSE and not Logger.QUIET else None
self.__LOGGER.debug("Prompting with: '%s'" % initial_prompt)
result = self.__model.transcribe(audio, task="transcribe", language=LANGUAGES[lang], verbose=verbose, initial_prompt=initial_prompt)
result = self.__model.transcribe(audio,
task="transcribe",
language=LANGUAGES[lang],
verbose=verbose,
word_timestamps=True,
initial_prompt=initial_prompt)
self.__LOGGER.info("Finished transcribing the audio")
srt_str = ""
for i, segment in enumerate(result["segments"], start=1):
srt_str += f"{i}\n" \
f"{Utils.format_timestamp(segment['start'])} --> {Utils.format_timestamp(segment['end'])}\n" \
f"{segment['text'].strip().replace('-->', '->')}\n" \
"\n"
srt_idx = 1
for segment in result["segments"]:
if max_char_length is not None and len(segment["text"]) > max_char_length:
srt_str, srt_idx = self._chunk_segment(segment, srt_str, srt_idx, max_char_length)
else:
srt_str += f"{srt_idx}\n" \
f"{Utils.format_timestamp(segment['words'][0]['start'])} --> {Utils.format_timestamp(segment['words'][-1]['end'])}\n" \
f"{segment['text'].strip().replace('-->', '->')}\n" \
"\n"
srt_idx += 1
subtitle = Subtitle.load_subrip_str(srt_str)
subtitle, frame_rate = self.__on_frame_timecodes(subtitle, video_file_path)
self.__LOGGER.debug("Generated the raw subtitle")
Expand All @@ -82,13 +107,19 @@ def transcribe(self, video_file_path: str, language_code: str, initial_prompt: O
else:
raise NotImplementedError(f"{self.__recipe} ({self.__flavour}) is not supported")

def transcribe_with_subtitle_as_prompts(self, video_file_path: str, subtitle_file_path: str, language_code: str) -> Tuple[Subtitle, Optional[float]]:
"""Transcribe an audiovisual file and generate subtitles using the original subtitle as prompts.
def transcribe_with_subtitle_as_prompts(self,
video_file_path: str,
subtitle_file_path: str,
language_code: str,
max_char_length: Optional[int] = None) -> Tuple[Subtitle, Optional[float]]:
"""Transcribe an audiovisual file and generate subtitles using the original subtitle (with accurate time codes) as prompts.
Arguments:
video_file_path {string} -- The input video file path.
subtitle_file_path {string} -- The input subtitle file path to provide prompts.
language_code {string} -- An alpha 3 language code derived from ISO 639-3.
max_char_length {int} -- Optional Maximum number of characters for each generated subtitle segment.
Returns:
tuple: Generated subtitle after transcription and the detected frame rate
Expand All @@ -104,27 +135,54 @@ def transcribe_with_subtitle_as_prompts(self, video_file_path: str, subtitle_fil
f'"{language_code}" is not supported by {self.__recipe} ({self.__flavour})')
audio_file_path = self.__media_helper.extract_audio(video_file_path, True, 16000)
subtitle = Subtitle.load(subtitle_file_path)
segment_paths = []
segment_paths: List[str] = []
try:
srt_str = ""
srt_idx = 1
self.__LOGGER.info("Start transcribing the audio...")
verbose = False if Logger.VERBOSE and not Logger.QUIET else None
for sub in subtitle.subs:
segment_paths = []
args = []
longest_segment_char_length = 0
for sub in tqdm(subtitle.subs, desc="Extracting audio segments"):
segment_path, _ = self.__media_helper.extract_audio_from_start_to_end(audio_file_path, str(sub.start), str(sub.end))
segment_paths.append(segment_path)
audio = whisper.load_audio(segment_path)
result = self.__model.transcribe(audio, task="transcribe", language=LANGUAGES[lang], verbose=verbose, initial_prompt=sub.text)
args.append((segment_path, sub.text, self.__lock, self.__LOGGER))
if len(sub.text) > longest_segment_char_length:
longest_segment_char_length = len(sub.text)
max_subtitle_char_length = max_char_length or longest_segment_char_length

max_workers = math.ceil(float(os.getenv("MAX_WORKERS", mp.cpu_count() / 2)))
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
results = list(executor.map(partial(self._whisper_transcribe, model=self.__model, lang=lang), args))
for sub, result in zip(subtitle.subs, results):
original_start_in_secs = sub.start.hours * 3600 + sub.start.minutes * 60 + sub.start.seconds + sub.start.milliseconds / 1000.0
original_end_in_secs = sub.end.hours * 3600 + sub.end.minutes * 60 + sub.end.seconds + sub.end.milliseconds / 1000.0
for segment in result["segments"]:
if segment["end"] <= segment["start"]:
continue
if len(result["segments"]) == 0:
srt_str += f"{srt_idx}\n" \
f"{Utils.format_timestamp(original_start_in_secs + segment['start'])} --> {Utils.format_timestamp(min(original_start_in_secs + segment['end'], original_end_in_secs))}\n" \
f"{segment['text'].strip().replace('-->', '->')}\n" \
f"{Utils.format_timestamp(original_start_in_secs)} --> {Utils.format_timestamp(original_end_in_secs)}\n" \
f"{sub.text.strip().replace('-->', '->')}\n" \
"\n"
srt_idx += 1
else:
for segment in result["segments"]:
if segment["end"] <= segment["start"]:
continue
segment_end = min(original_start_in_secs + segment["end"], original_end_in_secs)
if len(segment["text"]) > max_subtitle_char_length:
srt_str, srt_idx = self._chunk_segment(segment,
srt_str,
srt_idx,
max_subtitle_char_length,
original_start_in_secs,
original_end_in_secs)
else:
srt_str += f"{srt_idx}\n" \
f"{Utils.format_timestamp(original_start_in_secs + segment['start'])} --> {Utils.format_timestamp(segment_end)}\n" \
f"{segment['text'].strip().replace('-->', '->')}\n" \
"\n"
srt_idx += 1
if segment_end == original_end_in_secs:
break
self.__LOGGER.info("Finished transcribing the audio")
subtitle = Subtitle.load_subrip_str(srt_str)
subtitle, frame_rate = self.__on_frame_timecodes(subtitle, video_file_path)
Expand All @@ -139,6 +197,66 @@ def transcribe_with_subtitle_as_prompts(self, video_file_path: str, subtitle_fil
else:
raise NotImplementedError(f"{self.__recipe} ({self.__flavour}) is not supported")

@staticmethod
def _whisper_transcribe(args: Tuple, model: Whisper, lang: str) -> Dict:
segment_path, sub_text, lock, logger = args
verbose = False if Logger.VERBOSE and not Logger.QUIET else None
try:
waveform, _ = torchaudio.load(segment_path)
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0)
waveform = waveform.numpy().astype(np.float32)
with lock:
result = model.transcribe(waveform,
task="transcribe",
language=LANGUAGES[lang],
verbose=verbose,
initial_prompt=sub_text,
word_timestamps=True)
logger.debug("Segment transcribed : %s", result)
return result
except Exception as e:
logger.warning(f"Error while transcribing segment: {e}")
return {"segments": []}

@staticmethod
def _chunk_segment(segment: Dict,
srt_str: str,
srt_idx: int,
max_subtitle_char_length: int,
start_offset: float = 0.0,
end_ceiling: float = float("inf")) -> Tuple[str, int]:
chunked_text = ""
chunk_start_in_secs = 0.0
chunk_end_in_secs = 0.0
chunk_char_length = 0

for word in segment["words"]:
if chunk_char_length + len(word["word"]) > max_subtitle_char_length and chunked_text.strip() != "":
srt_str += f"{srt_idx}\n" \
f"{Utils.format_timestamp(start_offset + chunk_start_in_secs)} --> {Utils.format_timestamp(min(start_offset + chunk_end_in_secs, end_ceiling))}\n" \
f"{chunked_text.strip().replace('-->', '->')}\n" \
"\n"
srt_idx += 1
chunked_text = word["word"]
chunk_start_in_secs = word["start"]
chunk_char_length = len(word["word"])
else:
if chunk_start_in_secs == 0.0:
chunk_start_in_secs = word["start"]
chunked_text += word["word"]
chunk_char_length += len(word["word"])
chunk_end_in_secs = word["end"]

if len(chunked_text) > 0:
srt_str += f"{srt_idx}\n" \
f"{Utils.format_timestamp(start_offset + chunk_start_in_secs)} --> {Utils.format_timestamp(min(start_offset + chunk_end_in_secs, end_ceiling))}\n" \
f"{chunked_text.strip().replace('-->', '->')}\n" \
"\n"
srt_idx += 1

return srt_str, srt_idx

def __on_frame_timecodes(self, subtitle: Subtitle, video_file_path: str) -> Tuple[Subtitle, Optional[float]]:
frame_rate = None
try:
Expand Down
27 changes: 18 additions & 9 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,32 +1,41 @@
[tox]
envlist =
py36,
py37,
py38
py38,
py39,
py310,
py311
skipsdist=True
skip_missing_interpreters = True

[darglint]
ignore=DAR101

[testenv:py36]
basepython = python3.6
[testenv:py38]
basepython = python3.8
whitelist_externals = /bin/bash
commands =
bash -c \'cat requirements.txt | xargs -L 1 pip install\'
bash -c \'cat requirements-dev.txt | xargs -L 1 pip install'
python -m unittest discover

[testenv:py37]
basepython = python3.7
[testenv:py39]
basepython = python3.9
whitelist_externals = /bin/bash
commands =
bash -c \'cat requirements.txt | xargs -L 1 pip install\'
bash -c \'cat requirements-dev.txt | xargs -L 1 pip install'
python -m unittest discover

[testenv:py38]
basepython = python3.8
[testenv:py310]
basepython = python3.10
whitelist_externals = /bin/bash
commands =
bash -c \'cat requirements.txt | xargs -L 1 pip install\'
bash -c \'cat requirements-dev.txt | xargs -L 1 pip install'
python -m unittest discover

[testenv:py311]
basepython = python3.11
whitelist_externals = /bin/bash
commands =
bash -c \'cat requirements.txt | xargs -L 1 pip install\'
Expand Down

0 comments on commit c06d9e1

Please sign in to comment.