diff --git a/src/chatterbox/mtl_tts.py b/src/chatterbox/mtl_tts.py index fccfbf3d..f6a4ee7d 100644 --- a/src/chatterbox/mtl_tts.py +++ b/src/chatterbox/mtl_tts.py @@ -163,7 +163,7 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS': ve = VoiceEncoder() ve.load_state_dict( - torch.load(ckpt_dir / "ve.pt", weights_only=True) + torch.load(ckpt_dir / "ve.pt", map_location=device, weights_only=True) ) ve.to(device).eval() @@ -176,7 +176,7 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS': s3gen = S3Gen() s3gen.load_state_dict( - torch.load(ckpt_dir / "s3gen.pt", weights_only=True) + torch.load(ckpt_dir / "s3gen.pt", map_location=device, weights_only=True) ) s3gen.to(device).eval()