Skip to content

Commit

Permalink
Add large (v1 and v2) Whisper models (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams authored Dec 8, 2022
1 parent 195bcae commit d95db97
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 22 deletions.
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,9 @@ notarize_log:

codesign_verify:
codesign --verify --deep --strict --verbose=2 dist/Buzz.app

VENV_PATH := $(shell poetry env info -p)

# Make GGML model from whisper. Example: make ggml model_path=/Users/chidiwilliams/.cache/whisper/medium.pt
ggml:
python3 ./whisper.cpp/models/convert-pt-to-ggml.py ${model_path} $(VENV_PATH)/src/whisper dist
14 changes: 9 additions & 5 deletions buzz/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class Quality(enum.Enum):
LOW = 'low'
MEDIUM = 'medium'
HIGH = 'high'
VERY_HIGH = 'very high'


class QualityComboBox(QComboBox):
Expand Down Expand Up @@ -333,11 +334,12 @@ def on_next_interval(self, stopped=False):

def get_model_name(quality: Quality) -> str:
return {
Quality.VERY_LOW: ('tiny', 'tiny.en'),
Quality.LOW: ('base', 'base.en'),
Quality.MEDIUM: ('small', 'small.en'),
Quality.HIGH: ('medium', 'medium.en'),
}[quality][0]
Quality.VERY_LOW: 'tiny',
Quality.LOW: 'base',
Quality.MEDIUM: 'small',
Quality.HIGH: 'medium',
Quality.VERY_HIGH: 'large',
}[quality]


def show_model_download_error_dialog(parent: QWidget, error: str):
Expand Down Expand Up @@ -707,6 +709,8 @@ def on_download_model_progress(self, progress: Tuple[int, int]):
def on_download_model_error(self, error: str):
show_model_download_error_dialog(self, error)
self.stop_recording()
self.record_button.force_stop()
self.record_button.setDisabled(False)

def on_transcriber_event_changed(self, event: RecordingTranscriber.Event):
if isinstance(event, RecordingTranscriber.TranscribedNextChunkEvent):
Expand Down
8 changes: 5 additions & 3 deletions buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b',
'medium': '6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208',
}


Expand Down Expand Up @@ -57,7 +58,8 @@ def run(self):
"/")[-2]
if os.path.isfile(model_path):
model_bytes = open(model_path, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
model_sha256 = hashlib.sha256(model_bytes).hexdigest()
if model_sha256 == expected_sha256:
self.signals.completed.emit(model_path)
return
else:
Expand Down Expand Up @@ -88,10 +90,10 @@ def run(self):
self.signals.error.emit(str(exc))
logging.exception('')
except requests.RequestException:
self.signals.error.emit('A connection error occurred.')
self.signals.error.emit('A connection error occurred')
logging.exception('')
except Exception:
self.signals.error.emit('An unknown error occurred.')
self.signals.error.emit('An unknown error occurred')
logging.exception('')

def stop(self):
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 5 additions & 13 deletions tests/gui_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pathlib
from unittest.mock import Mock, patch
from pytestqt.qtbot import QtBot


import sounddevice
from PyQt6.QtCore import QCoreApplication, Qt, pyqtBoundSignal
Expand Down Expand Up @@ -172,32 +174,22 @@ def test_should_init(self):


class TestFileTranscriberWidget:
def test_should_transcribe(self, qtbot, tmp_path: pathlib.Path):
def test_should_transcribe(self, qtbot: QtBot, tmp_path: pathlib.Path):
widget = FileTranscriberWidget(
file_path='testdata/whisper-french.mp3', parent=None)
qtbot.addWidget(widget)

output_file_path = tmp_path / 'whisper.txt'

with patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock:
with (patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock,
qtbot.wait_signal(widget.transcribed, timeout=30*1000)):
save_file_name_mock.return_value = (output_file_path, '')
widget.run_button.click()

wait_signal_while_processing(widget.transcribed)

output_file = open(output_file_path, 'r', encoding='utf-8')
assert 'Bienvenue dans Passe-Relle, un podcast' in output_file.read()


def wait_signal_while_processing(signal: pyqtBoundSignal):
mock = Mock()
signal.connect(mock)
while True:
QCoreApplication.processEvents()
if mock.call_count > 0:
break


class TestSettings:
def test_should_enable_ggml_inference(self):
settings = Settings()
Expand Down

0 comments on commit d95db97

Please sign in to comment.