Merged
89 changes: 60 additions & 29 deletions mlx_audio/tts/models/irodori_tts/README.md
@@ -1,39 +1,70 @@
# Irodori TTS

Japanese text-to-speech model based on Echo TTS architecture, ported to MLX.
Uses Rectified Flow diffusion with a DiT (Diffusion Transformer) and DACVAE codec (48kHz).
Flow Matching-based Japanese TTS model, ported to MLX.
Uses a Rectified Flow DiT over continuous DACVAE latents (48kHz).
Architecture and training follow [Echo-TTS](https://jordandarefsky.com/blog/2025/echo/).
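
The sampling loop behind this setup can be sketched as plain Euler integration of a learned velocity field from noise to data. This is a schematic in NumPy, not the model's actual API; `velocity_fn` stands in for the conditioned DiT forward pass:

```python
import numpy as np

def sample_rectified_flow(velocity_fn, shape, num_steps=40, seed=0):
    """Euler-integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data)."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)            # start from Gaussian noise
    ts = np.linspace(0.0, 1.0, num_steps + 1)
    for t0, t1 in zip(ts[:-1], ts[1:]):
        v = velocity_fn(x, t0)                # model predicts a velocity field
        x = x + (t1 - t0) * v                 # one Euler step along the flow
    return x

# Toy check with the closed-form velocity toward a fixed target,
# v(x, t) = (target - x) / (1 - t): integration lands on the target.
target = np.ones((1, 4))
out = sample_rectified_flow(
    lambda x, t: (target - x) / max(1.0 - t, 1e-3), (1, 4), num_steps=200
)
```

In the real model the velocity is additionally conditioned on text plus speaker (or caption) embeddings and combined with classifier-free guidance before each step.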

## Model
Original: [Aratako/Irodori-TTS](https://github.com/Aratako/Irodori-TTS)

Original: [Aratako/Irodori-TTS-500M](https://huggingface.co/Aratako/Irodori-TTS-500M) (500M parameters)
## Models

### v2 (recommended)

| Model | HuggingFace | Conditioning |
|---|---|---|
| `mlx-community/Irodori-TTS-500M-v2-fp16` | [link](https://huggingface.co/mlx-community/Irodori-TTS-500M-v2-fp16) | Voice cloning (reference audio) |
| `mlx-community/Irodori-TTS-500M-v2-8bit` | [link](https://huggingface.co/mlx-community/Irodori-TTS-500M-v2-8bit) | Voice cloning (reference audio) |
| `mlx-community/Irodori-TTS-500M-v2-4bit` | [link](https://huggingface.co/mlx-community/Irodori-TTS-500M-v2-4bit) | Voice cloning (reference audio) |
| `mlx-community/Irodori-TTS-500M-v2-VoiceDesign-fp16` | [link](https://huggingface.co/mlx-community/Irodori-TTS-500M-v2-VoiceDesign-fp16) | Voice design (text description) |
| `mlx-community/Irodori-TTS-500M-v2-VoiceDesign-8bit` | [link](https://huggingface.co/mlx-community/Irodori-TTS-500M-v2-VoiceDesign-8bit) | Voice design (text description) |
| `mlx-community/Irodori-TTS-500M-v2-VoiceDesign-4bit` | [link](https://huggingface.co/mlx-community/Irodori-TTS-500M-v2-VoiceDesign-4bit) | Voice design (text description) |

### v1

| Model | HuggingFace |
|---|---|
| `mlx-community/Irodori-TTS-500M-fp16` | [link](https://huggingface.co/mlx-community/Irodori-TTS-500M-fp16) |

## Usage

Python API:
### Voice cloning

```python
from mlx_audio.tts import load
from mlx_audio.tts.generate import generate_audio

model = load("mlx-community/Irodori-TTS-500M-fp16")
result = next(model.generate("こんにちは、音声合成のテストです。"))
audio = result.audio
generate_audio(
    model="mlx-community/Irodori-TTS-500M-v2-fp16",
    text="今日はいい天気ですね。",
    ref_audio="speaker.wav",
    file_prefix="output",
)
```

With reference audio for voice cloning:
```bash
python -m mlx_audio.tts.generate \
  --model mlx-community/Irodori-TTS-500M-v2-fp16 \
  --text "今日はいい天気ですね。" \
  --ref_audio speaker.wav
```

### VoiceDesign

Describe the desired voice in text instead of providing reference audio:

```python
result = next(model.generate(
    "こんにちは、音声合成のテストです。",
    ref_audio="speaker.wav",
))
generate_audio(
    model="mlx-community/Irodori-TTS-500M-v2-VoiceDesign-fp16",
    text="今日はいい天気ですね。",
    instruct="落ち着いた、近い距離感の女性話者",
    file_prefix="output",
)
```

CLI:

```bash
python -m mlx_audio.tts.generate \
  --model mlx-community/Irodori-TTS-500M-fp16 \
  --text "こんにちは、音声合成のテストです。"
  --model mlx-community/Irodori-TTS-500M-v2-VoiceDesign-fp16 \
  --text "今日はいい天気ですね。" \
  --instruct "落ち着いた、近い距離感の女性話者"
```

## Memory requirements
@@ -42,11 +73,13 @@
The default `sequence_length=750` requires approximately 24GB of unified memory.
On 16GB machines, use reduced settings:

```python
result = next(model.generate(
    "こんにちは。",
    sequence_length=300,  # ~9GB
    cfg_guidance_mode="alternating",  # ~1/3 of independent mode memory
))
generate_audio(
    model="mlx-community/Irodori-TTS-500M-v2-fp16",
    text="こんにちは。",
    sequence_length=300,
    cfg_guidance_mode="alternating",
    file_prefix="output",
)
```
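
For sizing, each latent frame corresponds to one DACVAE hop of 1920 samples at 48 kHz, so `sequence_length` caps the output duration. A quick back-of-the-envelope check (constants taken from the model config):

```python
SAMPLE_RATE = 48_000  # Hz, model output rate
HOP_LENGTH = 1_920    # DACVAE hop length: audio samples per latent frame

def max_duration_seconds(sequence_length: int) -> float:
    """Upper bound on generated audio duration for a given latent length."""
    return sequence_length * HOP_LENGTH / SAMPLE_RATE

print(max_duration_seconds(750))  # default setting -> 30.0 seconds
print(max_duration_seconds(300))  # reduced setting -> 12.0 seconds
```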

Approximate memory usage with `cfg_guidance_mode="alternating"`:
@@ -61,12 +94,10 @@
With `cfg_guidance_mode="independent"` (default), multiply memory by ~3.
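
The ~3x factor comes from classifier-free guidance with two guidance terms: "independent" mode evaluates the fully conditioned branch plus a text-dropped and a speaker-dropped branch on every step, while "alternating" applies one guidance term per step. A sketch of one plausible independent-mode combination (hypothetical variable names, not the actual mlx-audio internals; the defaults mirror `cfg_scale_text` / `cfg_scale_speaker`):

```python
import numpy as np

def guided_velocity(v_cond, v_no_text, v_no_spk, scale_text=3.0, scale_spk=5.0):
    """Combine CFG branches into one guided velocity.

    v_cond drops nothing, v_no_text drops the text condition, v_no_spk drops
    the speaker condition. Independent mode needs all three forward passes
    per sampling step, hence roughly 3x the activations of a single pass.
    """
    return (
        v_cond
        + scale_text * (v_cond - v_no_text)
        + scale_spk * (v_cond - v_no_spk)
    )

v = guided_velocity(np.full(4, 2.0), np.ones(4), np.ones(4))
```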

## Notes

- Input language: Japanese. Latin characters may not be pronounced correctly;
convert them to katakana beforehand (e.g. "MLX" → "エムエルエックス").
- The DACVAE codec weights (`facebook/dacvae-watermarked`) are automatically
downloaded on first use.
- v2 uses [Semantic-DACVAE-Japanese-32dim](https://huggingface.co/Aratako/Semantic-DACVAE-Japanese-32dim)
and is bundled in the converted model weights.
- v1 uses `facebook/dacvae-watermarked`, downloaded automatically on first use.

## License

Irodori-TTS weights are released under the [MIT License](https://opensource.org/licenses/MIT).
See [Aratako/Irodori-TTS-500M](https://huggingface.co/Aratako/Irodori-TTS-500M) for details.
MIT License. See [Aratako/Irodori-TTS-500M-v2](https://huggingface.co/Aratako/Irodori-TTS-500M-v2) for details.
73 changes: 68 additions & 5 deletions mlx_audio/tts/models/irodori_tts/config.py
@@ -8,8 +8,8 @@

@dataclass
class IrodoriDiTConfig(BaseModelArgs):
    # Audio latent dimensions (DACVAE: 128-dim, 48kHz)
    latent_dim: int = 128
    # Audio latent dimensions (v2: 32-dim Semantic-DACVAE, v1: 128-dim DACVAE)
    latent_dim: int = 32
    latent_patch_size: int = 1

    # DiT backbone
@@ -39,6 +39,64 @@ class IrodoriDiTConfig(BaseModelArgs):
    adaln_rank: int = 192
    norm_eps: float = 1e-5

    # Caption (Voice Design) conditioning — mutually exclusive with speaker
    use_caption_condition: bool = False
    caption_vocab_size: Optional[int] = None
    caption_tokenizer_repo: Optional[str] = None
    caption_add_bos: Optional[bool] = None
    caption_dim: Optional[int] = None
    caption_layers: Optional[int] = None
    caption_heads: Optional[int] = None
    caption_mlp_ratio: Optional[float] = None

    @property
    def use_speaker_condition(self) -> bool:
        return not self.use_caption_condition

    @property
    def caption_vocab_size_resolved(self) -> int:
        return (
            self.caption_vocab_size
            if self.caption_vocab_size is not None
            else self.text_vocab_size
        )

    @property
    def caption_tokenizer_repo_resolved(self) -> str:
        return (
            self.caption_tokenizer_repo
            if self.caption_tokenizer_repo is not None
            else self.text_tokenizer_repo
        )

    @property
    def caption_add_bos_resolved(self) -> bool:
        return (
            self.caption_add_bos
            if self.caption_add_bos is not None
            else self.text_add_bos
        )

    @property
    def caption_dim_resolved(self) -> int:
        return self.caption_dim if self.caption_dim is not None else self.text_dim

    @property
    def caption_layers_resolved(self) -> int:
        return (
            self.caption_layers if self.caption_layers is not None else self.text_layers
        )

    @property
    def caption_heads_resolved(self) -> int:
        return self.caption_heads if self.caption_heads is not None else self.text_heads

    @property
    def caption_mlp_ratio_resolved(self) -> float:
        if self.caption_mlp_ratio is not None:
            return float(self.caption_mlp_ratio)
        return self.text_mlp_ratio_resolved

    @property
    def patched_latent_dim(self) -> int:
        return self.latent_dim * self.latent_patch_size
@@ -69,6 +127,7 @@ class SamplerConfig(BaseModelArgs):
    num_steps: int = 40
    cfg_scale_text: float = 3.0
    cfg_scale_speaker: float = 5.0
    cfg_scale_caption: float = 3.0
    cfg_guidance_mode: str = "independent"
    cfg_min_t: float = 0.5
    cfg_max_t: float = 1.0
@@ -88,11 +147,12 @@ class ModelConfig(BaseModelArgs):
    sample_rate: int = 48000

    max_text_length: int = 256
    max_caption_length: int = 512
    max_speaker_latent_length: int = 6400
    # DACVAE hop_length = 2*8*10*12 = 1920
    # DACVAE hop_length = 2*8*10*12 = 1920 (48kHz)
    audio_downsample_factor: int = 1920

    dacvae_repo: str = "Aratako/Irodori-TTS-500M"
    dacvae_repo: str = "Aratako/Semantic-DACVAE-Japanese-32dim"
    model_path: Optional[str] = None

    dit: IrodoriDiTConfig = field(default_factory=IrodoriDiTConfig)
@@ -104,9 +164,12 @@ def from_dict(cls, config: dict) -> "ModelConfig":
            model_type=config.get("model_type", "irodori_tts"),
            sample_rate=config.get("sample_rate", 48000),
            max_text_length=config.get("max_text_length", 256),
            max_caption_length=config.get("max_caption_length", 512),
            max_speaker_latent_length=config.get("max_speaker_latent_length", 6400),
            audio_downsample_factor=config.get("audio_downsample_factor", 1920),
            dacvae_repo=config.get("dacvae_repo", "Aratako/Irodori-TTS-500M"),
            dacvae_repo=config.get(
                "dacvae_repo", "Aratako/Semantic-DACVAE-Japanese-32dim"
            ),
            model_path=config.get("model_path"),
            dit=IrodoriDiTConfig.from_dict(config.get("dit", {})),
            sampler=SamplerConfig.from_dict(config.get("sampler", {})),
57 changes: 50 additions & 7 deletions mlx_audio/tts/models/irodori_tts/irodori_tts.py
@@ -42,6 +42,7 @@ def __init__(self, config: ModelConfig):
        self.model = IrodoriDiT(config.dit)
        self.dacvae: DACVAE | None = None
        self._tokenizer = None
        self._caption_tokenizer = None

    # ------------------------------------------------------------------
    # Properties
@@ -127,6 +128,18 @@ def _get_tokenizer(self):
            )
        return self._tokenizer

    def _get_caption_tokenizer(self):
        if self._caption_tokenizer is None:
            from transformers import AutoTokenizer

            repo = self.config.dit.caption_tokenizer_repo_resolved
            # Reuse text tokenizer if repo is the same
            if repo == self.config.dit.text_tokenizer_repo:
                self._caption_tokenizer = self._get_tokenizer()
            else:
                self._caption_tokenizer = AutoTokenizer.from_pretrained(repo)
        return self._caption_tokenizer

    def _prepare_text(
        self, text: str, max_length: Optional[int] = None
    ) -> tuple[mx.array, mx.array]:
@@ -146,6 +159,18 @@ def _prepare_text(
            add_bos=self.config.dit.text_add_bos,
        )

    def _prepare_caption(
        self, caption: str, max_length: Optional[int] = None
    ) -> tuple[mx.array, mx.array]:
        if max_length is None:
            max_length = self.config.max_caption_length
        return encode_text(
            caption,
            tokenizer=self._get_caption_tokenizer(),
            max_length=max_length,
            add_bos=self.config.dit.caption_add_bos_resolved,
        )

    # ------------------------------------------------------------------
    # Reference audio encoding
    # ------------------------------------------------------------------
@@ -191,15 +216,23 @@ def generate_latents(
        text: str,
        ref_latent: Optional[mx.array] = None,
        ref_mask: Optional[mx.array] = None,
        caption: Optional[str] = None,
        rng_seed: int = 0,
        **sampling_kwargs,
    ) -> mx.array:
        text_input_ids, text_mask = self._prepare_text(text)

        if ref_latent is None:
            ref_latent = mx.zeros((1, 1, self.config.dit.latent_dim))
        if ref_mask is None:
            ref_mask = mx.zeros((1, ref_latent.shape[1]), dtype=mx.bool_)
        caption_input_ids: Optional[mx.array] = None
        caption_mask: Optional[mx.array] = None

        if self.config.dit.use_caption_condition:
            cap = caption or ""
            caption_input_ids, caption_mask = self._prepare_caption(cap)
        else:
            if ref_latent is None:
                ref_latent = mx.zeros((1, 1, self.config.dit.latent_dim))
            if ref_mask is None:
                ref_mask = mx.zeros((1, ref_latent.shape[1]), dtype=mx.bool_)

        sampler_cfg = dict(self.config.sampler.__dict__)
        for k, v in sampling_kwargs.items():
@@ -212,6 +245,8 @@ def generate_latents(
            text_mask=text_mask,
            ref_latent=ref_latent,
            ref_mask=ref_mask,
            caption_input_ids=caption_input_ids,
            caption_mask=caption_mask,
            rng_seed=rng_seed,
            latent_dim=self.config.dit.patched_latent_dim,
            **sampler_cfg,
@@ -226,9 +261,12 @@ def generate(
        text: str,
        voice: str | None = None,
        ref_audio: str | mx.array | None = None,
        caption: str | None = None,
        stream: bool = False,
        **kwargs,
    ) -> Generator[GenerationResult, None, None]:
        # instruct is an alias for caption (mlx-audio convention)
        caption = caption or kwargs.pop("instruct", None)
        if stream:
            raise NotImplementedError("Irodori-TTS streaming is not yet implemented.")

@@ -262,15 +300,20 @@ def generate(
            text=text,
            ref_latent=ref_latent,
            ref_mask=ref_mask,
            caption=caption,
            rng_seed=int(kwargs.get("rng_seed", 0)),
            **{k: v for k, v in kwargs.items() if k != "rng_seed"},
        )

        # Decode latent → waveform
        # latent_out: (1, T, 128)
        latent_for_decode = mx.transpose(latent_out, (0, 2, 1))  # (1, 128, T)
        audio_out = self.dacvae.decode(latent_for_decode)  # (1, L, 1)
        # latent_out: (1, T, latent_dim)
        latent_for_decode = mx.transpose(latent_out, (0, 2, 1))  # (1, latent_dim, T)
        # Use chunked decoding to avoid large ConvTranspose1d intermediates on
        # 16 GB unified-memory systems (stride-8 block creates a ~17 GB tensor
        # at T=750 in a single pass; 50-frame chunks keep it under ~1.2 GB).
        audio_out = self.dacvae.decode(latent_for_decode, chunk_size=50)  # (1, L, 1)
        audio_out = audio_out[:, :, 0]  # (1, L)
        mx.eval(audio_out)

        # Trim trailing silence
        silence_t = _find_silence_point(latent_out[0])