53 changes: 45 additions & 8 deletions backend/python/chatterbox/backend.py
@@ -14,9 +14,15 @@
import torch
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS

from chatterbox.mtl_tts import ChatterboxMultilingualTTS
import grpc

def is_float(s):
    """Return True if the string s can be parsed as a float."""
    try:
        float(s)
        return True
    except ValueError:
        return False
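# e.g. (hypothetical values): is_float("0.7") -> True, is_float("true") -> False,
# so numeric option values are converted below while plain strings stay strings.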

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

@@ -47,6 +53,27 @@ def LoadModel(self, request, context):
        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        options = request.Options

        # empty dict
        self.options = {}

        # The options are a list of strings in the form optname:optvalue
        # We store all the options in a dict so we can use them later when
        # generating the audio
        for opt in options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            # if value is a number, convert it to the appropriate type
            if is_float(value):
                if float(value).is_integer():
                    value = int(float(value))
                else:
                    value = float(value)
            self.options[key] = value
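        # Example (hypothetical values): Options=["multilingual:true", "language:fr", "exaggeration:0.4"]
        # would produce self.options == {"multilingual": "true", "language": "fr", "exaggeration": 0.4}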

        self.AudioPath = None

        if os.path.isabs(request.AudioPath):
@@ -56,10 +83,14 @@ def LoadModel(self, request, context):
            modelFileBase = os.path.dirname(request.ModelFile)
            # make AudioPath relative to the model file's directory
            self.AudioPath = os.path.join(modelFileBase, request.AudioPath)

        try:
            print("Preparing models, please wait", file=sys.stderr)
            if "multilingual" in self.options:
                # drop the key so it is not forwarded to generate() later
                del self.options["multilingual"]
                self.model = ChatterboxMultilingualTTS.from_pretrained(device=device)
            else:
                self.model = ChatterboxTTS.from_pretrained(device=device)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        # Implement your logic here for the LoadModel service
@@ -68,12 +99,18 @@ def LoadModel(self, request, context):

    def TTS(self, request, context):
        try:
            kwargs = {}

            if "language" in self.options:
                kwargs["language_id"] = self.options["language"]
            if self.AudioPath is not None:
                kwargs["audio_prompt_path"] = self.AudioPath

            # add the remaining options to kwargs; "language" was already mapped
            # to language_id above, so skip it to avoid passing an unexpected
            # keyword argument to generate()
            kwargs.update({k: v for k, v in self.options.items() if k != "language"})
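            # Example (hypothetical values): with an audio prompt configured and
            # options {"language": "fr", "exaggeration": 0.4}, kwargs would be
            # {"language_id": "fr", "audio_prompt_path": "<model dir>/voice.wav", "exaggeration": 0.4}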

            # Generate audio using ChatterboxTTS
            wav = self.model.generate(request.text, **kwargs)
            # Save the generated audio
            ta.save(request.dst, wav, self.model.sr)
