From 8ad90eea0d899fbef4ab76b6a154aa4799ee4e7d Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 9 Sep 2025 14:38:34 +0200 Subject: [PATCH] feat(chatterbox): support multilingual Signed-off-by: Ettore Di Giacinto --- backend/python/chatterbox/backend.py | 53 +++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/backend/python/chatterbox/backend.py b/backend/python/chatterbox/backend.py index 0944202b9457..40dac9bd31d6 100644 --- a/backend/python/chatterbox/backend.py +++ b/backend/python/chatterbox/backend.py @@ -14,9 +14,15 @@ import torch import torchaudio as ta from chatterbox.tts import ChatterboxTTS - +from chatterbox.mtl_tts import ChatterboxMultilingualTTS import grpc +def is_float(s): + try: + float(s) + return True + except ValueError: + return False _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -47,6 +53,27 @@ def LoadModel(self, request, context): if not torch.cuda.is_available() and request.CUDA: return backend_pb2.Result(success=False, message="CUDA is not available") + + options = request.Options + + # empty dict + self.options = {} + + # The options are a list of strings in this form optname:optvalue + # We are storing all the options in a dict so we can use it later when + # generating the images + for opt in options: + if ":" not in opt: + continue + key, value = opt.split(":") + # if value is a number, convert it to the appropriate type + if is_float(value): + if value.is_integer(): + value = int(value) + else: + value = float(value) + self.options[key] = value + self.AudioPath = None if os.path.isabs(request.AudioPath): @@ -56,10 +83,14 @@ def LoadModel(self, request, context): modelFileBase = os.path.dirname(request.ModelFile) # modify LoraAdapter to be relative to modelFileBase self.AudioPath = os.path.join(modelFileBase, request.AudioPath) - try: print("Preparing models, please wait", file=sys.stderr) - self.model = ChatterboxTTS.from_pretrained(device=device) + if "multilingual" in self.options: + # remove key from options + del self.options["multilingual"] + self.model = ChatterboxMultilingualTTS.from_pretrained(device=device) + else: + self.model = ChatterboxTTS.from_pretrained(device=device) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") # Implement your logic here for the LoadModel service @@ -68,12 +99,18 @@ def LoadModel(self, request, context): def TTS(self, request, context): try: - # Generate audio using ChatterboxTTS + kwargs = {} + + if "language" in self.options: + kwargs["language_id"] = self.options["language"] if self.AudioPath is not None: - wav = self.model.generate(request.text, audio_prompt_path=self.AudioPath) - else: - wav = self.model.generate(request.text) - + kwargs["audio_prompt_path"] = self.AudioPath + + # add options to kwargs + kwargs.update(self.options) + + # Generate audio using ChatterboxTTS + wav = self.model.generate(request.text, **kwargs) # Save the generated audio ta.save(request.dst, wav, self.model.sr)