Skip to content

Commit

Permalink
[DONE] Add Whisper transcription (#991)
Browse files Browse the repository at this point in the history
* Add Whisper transcription

* Fix typo

* Refactor

* Add dependency

---------

Co-authored-by: Ptitloup <[email protected]>
  • Loading branch information
david42 and ptitloup authored Oct 24, 2023
1 parent 8ebd2a2 commit 8e17f90
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
1 change: 1 addition & 0 deletions pod/video_encode_transcript/transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
if (
importlib.util.find_spec("vosk") is not None
or importlib.util.find_spec("stt") is not None
or importlib.util.find_spec("whisper") is not None
):
from .transcript_model import start_transcripting

Expand Down
40 changes: 38 additions & 2 deletions pod/video_encode_transcript/transcript_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from vosk import Model, KaldiRecognizer
elif TRANSCRIPTION_TYPE == "STT":
from stt import Model
elif TRANSCRIPTION_TYPE == "WHISPER":
import whisper

TRANSCRIPTION_NORMALIZE = getattr(settings_local, "TRANSCRIPTION_NORMALIZE", False)
TRANSCRIPTION_NORMALIZE_TARGET_LEVEL = getattr(
Expand Down Expand Up @@ -91,8 +93,13 @@ def start_transcripting(mp3filepath, duration, lang):
"""
if TRANSCRIPTION_NORMALIZE:
mp3filepath = normalize_mp3(mp3filepath)
transript_model = get_model(lang)
msg, webvtt, all_text = start_main_transcript(mp3filepath, duration, transript_model)
if TRANSCRIPTION_TYPE == "WHISPER":
msg, webvtt, all_text = main_whisper_transcript(mp3filepath, lang)
else:
transript_model = get_model(lang)
msg, webvtt, all_text = start_main_transcript(
mp3filepath, duration, transript_model
)
if DEBUG:
print(msg)
print(webvtt)
Expand Down Expand Up @@ -395,6 +402,35 @@ def main_stt_transcript(norm_mp3_file, duration, transript_model):
return msg, webvtt, all_text


def main_whisper_transcript(norm_mp3_file, lang):
    """Transcribe an audio file with OpenAI Whisper.

    Args:
        norm_mp3_file (str): Path to the (possibly loudness-normalized) MP3 file.
        lang (str): Language code passed to Whisper and used to select the
            model entry in TRANSCRIPTION_MODEL_PARAM (e.g. "fr", "en").

    Returns:
        tuple: (msg, webvtt, all_text) where msg is a log string of timing
            info, webvtt is a WebVTT object with one caption per Whisper
            segment, and all_text is the concatenated transcript text.
    """
    msg = ""
    all_text = ""
    webvtt = WebVTT()
    inference_start = timer()
    msg += "\nInference start %0.3fs." % inference_start

    # Model name and cache directory come from the per-language settings.
    model_param = TRANSCRIPTION_MODEL_PARAM[TRANSCRIPTION_TYPE][lang]
    model = whisper.load_model(
        model_param["model"],
        download_root=model_param["download_root"],
    )

    transcription = model.transcribe(norm_mp3_file, language=lang)
    msg += "\nRunning inference."

    for segment in transcription["segments"]:
        caption = Caption(
            format_time_caption(segment["start"]),
            format_time_caption(segment["end"]),
            segment["text"],
        )
        webvtt.captions.append(caption)
        # Fix: all_text was previously returned empty; accumulate the full
        # transcript like the vosk/stt code paths do.
        all_text += segment["text"]

    inference_end = timer() - inference_start
    msg += "\nInference took %0.3fs." % inference_end
    return msg, webvtt, all_text


def change_previous_end_caption(webvtt, start_caption):
"""Change the end time for caption."""
if len(webvtt.captions) > 0:
Expand Down
3 changes: 3 additions & 0 deletions requirements-transcripts.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ stt==1.4.0 # linux only

# Offline open source speech recognition API based on Kaldi and Vosk
vosk==0.3.45

# Whisper
openai-whisper

0 comments on commit 8e17f90

Please sign in to comment.