Skip to content

Commit 8ad90ee

Browse files
committed
feat(chatterbox): support multilingual
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 59311d8 commit 8ad90ee

File tree

1 file changed

+45
-8
lines changed

1 file changed

+45
-8
lines changed

backend/python/chatterbox/backend.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@
1414
import torch
1515
import torchaudio as ta
1616
from chatterbox.tts import ChatterboxTTS
17-
17+
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
1818
import grpc
1919

20+
def is_float(s):
    """Return True when ``s`` can be converted with ``float()``.

    Args:
        s: The candidate value, normally an option value string.

    Returns:
        True if ``float(s)`` succeeds, False otherwise. Inputs that
        ``float()`` rejects by type (for example ``None``) also yield False
        instead of raising.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        # ValueError: not a numeric string; TypeError: not a str/number at all.
        return False
    return True
2026

2127
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
2228

@@ -47,6 +53,27 @@ def LoadModel(self, request, context):
4753
if not torch.cuda.is_available() and request.CUDA:
4854
return backend_pb2.Result(success=False, message="CUDA is not available")
4955

56+
57+
options = request.Options

# Parsed backend options, keyed by option name.
self.options = {}

# The options are a list of strings in the form optname:optvalue.
# Store them in a dict so they can be used later when generating audio.
for opt in options:
    if ":" not in opt:
        continue
    # Split on the first colon only, so a value may itself contain ":".
    key, value = opt.split(":", 1)
    # Convert numeric strings: integral values become int, others float.
    if is_float(value):
        number = float(value)
        # is_integer() is a float method (str has none); it is False for
        # inf/nan, so those stay floats.
        value = int(number) if number.is_integer() else number
    self.options[key] = value
76+
5077
self.AudioPath = None
5178

5279
if os.path.isabs(request.AudioPath):
@@ -56,10 +83,14 @@ def LoadModel(self, request, context):
5683
modelFileBase = os.path.dirname(request.ModelFile)
5784
# resolve AudioPath relative to modelFileBase
5885
self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
59-
6086
try:
6187
print("Preparing models, please wait", file=sys.stderr)
62-
self.model = ChatterboxTTS.from_pretrained(device=device)
88+
if "multilingual" in self.options:
89+
# remove key from options
90+
del self.options["multilingual"]
91+
self.model = ChatterboxMultilingualTTS.from_pretrained(device=device)
92+
else:
93+
self.model = ChatterboxTTS.from_pretrained(device=device)
6394
except Exception as err:
6495
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
6596
# Implement your logic here for the LoadModel service
@@ -68,12 +99,18 @@ def LoadModel(self, request, context):
6899

69100
def TTS(self, request, context):
70101
try:
71-
# Generate audio using ChatterboxTTS
102+
# Keyword arguments forwarded to self.model.generate().
kwargs = {}

if "language" in self.options:
    kwargs["language_id"] = self.options["language"]
if self.AudioPath is not None:
    kwargs["audio_prompt_path"] = self.AudioPath

# Forward the remaining options, which may override the entries set above.
# "language" is skipped because it is already passed as "language_id".
# NOTE(review): presumably generate() does not accept a "language" keyword;
# confirm against the Chatterbox generate() signature.
kwargs.update(
    {name: value for name, value in self.options.items() if name != "language"}
)

# Generate audio with the loaded Chatterbox model
wav = self.model.generate(request.text, **kwargs)
77114
# Save the generated audio
78115
ta.save(request.dst, wav, self.model.sr)
79116

0 commit comments

Comments
 (0)