diff --git a/README.md b/README.md index a0ddce678..b587e744e 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ for result in model.generate("Hello from MLX-Audio!", voice="af_heart"): | **Voxtral TTS** | Mistral's 4B multilingual TTS (20 voices, 9 languages) | EN, FR, ES, DE, IT, PT, NL, AR, HI | [mlx-community/Voxtral-4B-TTS-2603-mlx-bf16](https://huggingface.co/mlx-community/Voxtral-4B-TTS-2603-mlx-bf16) | | **LongCat-AudioDiT** | SOTA diffusion TTS in waveform latent space with voice cloning | ZH, EN | [mlx-community/LongCat-AudioDiT-1B-bf16](https://huggingface.co/mlx-community/LongCat-AudioDiT-1B-bf16) | | **MeloTTS** | Lightweight VITS2-based TTS with streaming | EN (more coming) | [mlx-community/MeloTTS-English-MLX](https://huggingface.co/mlx-community/MeloTTS-English-MLX) | +| **Higgs Audio v2** | 3B Llama-backed TTS with real-time voice cloning | EN, ZH, KO, DE, ES | [bf16 (upstream)](https://huggingface.co/bosonai/higgs-audio-v2-generation-3B-base), [q8](https://huggingface.co/mlx-community/higgs-audio-v2-3B-mlx-q8), [q6](https://huggingface.co/mlx-community/higgs-audio-v2-3B-mlx-q6) | ### Speech-to-Text (STT) diff --git a/docs/models/tts/higgs_audio.md b/docs/models/tts/higgs_audio.md new file mode 100644 index 000000000..271a1b48f --- /dev/null +++ b/docs/models/tts/higgs_audio.md @@ -0,0 +1,183 @@ +# Higgs Audio v2 + +Higgs Audio v2 is a Llama-3.2-3B-backed TTS with multi-codebook acoustic tokens and delay-pattern streaming. The MLX port targets the 3B open-weights release from Boson AI and reuses the in-tree HiggsAudio acoustic tokenizer (originally added for OmniVoice). + +## Highlights + +- Real-time voice cloning on Apple Silicon (RTF ≈ 0.6× bf16 / 0.36× q8 / 0.33× q6 on M5 Max) +- Reference-audio voice cloning via ChatML prompt format +- Full `AUDIO_INIT` + delay-pattern ramp-in/out state machine +- Repetition-avoidance sampling (RAS) for stable long-form output +- MLX native 4/6/8-bit quantization with optional per-layer protection + +## Basic usage + +### Top-level CLI + +```bash +python -m mlx_audio.tts.generate \ + --model mlx-community/higgs-audio-v2-3B-mlx-q8 \ + --text "Hello from Higgs Audio on MLX." \ + --ref_audio path/to/reference.wav \ + --ref_text "Transcript of the reference clip." +``` + +The `Model` class conforms to the standard mlx-audio interface, so the +existing `mlx_audio.tts.generate` CLI and `mlx_audio.server` both work +unchanged against Higgs. + +### Python API (standard) + +```python +from mlx_audio.tts.utils import load +import soundfile as sf + +model = load("mlx-community/higgs-audio-v2-3B-mlx-q8") + +for result in model.generate( + text="Hello from Higgs Audio on MLX.", + ref_audio="path/to/reference.wav", # optional; strongly recommended + ref_text="Transcript of the reference clip.", + temperature=0.7, + top_p=0.95, + max_new_frames=1200, + fade_in_ms=30.0, +): + sf.write("output.wav", result.audio, result.sample_rate) +``` + +Without `ref_audio`, generation runs in "smart voice" mode (random voice +per sample). This works but is less reliable than voice cloning — the +sampling occasionally collapses to `stream_eos` early and produces silent +output. If that happens, rerun (each call draws fresh noise) or pass +`ref_audio`. For production use, a reference voice is strongly recommended. + +### Python API (Higgs-specific kwargs) + +For direct access to the full Higgs parameter surface (RAS windowing, +sampling warmup, pre-loaded codec override, etc.), use `HiggsAudioServer`: + +```python +from mlx_audio.tts.models.higgs_audio import HiggsAudioServer +import soundfile as sf + +server = HiggsAudioServer.from_pretrained( + model_path="bosonai/higgs-audio-v2-generation-3B-base", # bf16 base + codec_path="mlx-community/higgs-audio-v2-tokenizer", # acoustic tokenizer +) + +result = server.generate( + target_text="Hello from Higgs Audio on MLX.", + temperature=0.7, + top_p=0.95, + max_new_frames=1200, + fade_in_ms=30.0, +) +sf.write("output.wav", result.pcm, result.sampling_rate) +``` + +### Recommended parameters + +- `temperature=0.7`, `top_p=0.95` — proven stable across prompt lengths during the M5 benchmark +- `max_new_frames=1200` — generous cap; generation stops naturally at the EOS ramp +- `fade_in_ms=30.0`, `fade_out_ms=15.0` — suppresses the first-frame transient that the 5ms default occasionally lets through + +## Voice cloning + +Pass `ref_audio` (path or pre-loaded mx.array at 24 kHz mono) together with +`ref_text` (the transcript of that clip). Reference audio is encoded through +the in-tree `HiggsAudioTokenizer` and stitched into the assistant turn of a +ChatML prompt — the transcript is required for stable alignment between the +cloned voice and the target text. + +```python +for result in model.generate( + text="Hello, this is a cloned voice.", + ref_audio="reference.wav", + ref_text="Transcript of the reference clip.", + temperature=0.7, + top_p=0.95, + max_new_frames=1200, + fade_in_ms=30.0, +): + sf.write("output.wav", result.audio, result.sample_rate) +``` + +Best results come from 5–15 seconds of clean reference speech. + +### Bundled sample voices + +Three drop-in reference voices ship in `examples/voice_prompts/`, generated via Higgs smart-voice mode so they're license-clean: + +- `en_woman.wav` — English, feminine register +- `en_man.wav` — English, masculine register +- `en_man_deep.wav` — English, masculine register, lower pitch + +Each `.wav` is paired with a matching `.txt` transcript. See `examples/voice_prompts/README.md` for the usage snippet. + +## Streaming + +For chunked streaming output (e.g. Pipecat pipelines), use +`HiggsAudioServer.generate_stream`: + +```python +for pcm_chunk in server.generate_stream( + target_text="Generating in chunks for live playback.", + reference_audio_path="reference.wav", + reference_text="...", + chunk_ms=640.0, +): + # emit or resample pcm_chunk (float32 at 24 kHz) + ... +``` + +Current shape: full generate, then chunk the resulting PCM. Per-chunk quality matches non-streaming exactly. Mid-generation streaming (emit-as-you-go) is not yet supported because the neural-vocoder codec produces subtly different PCM at the same sample position when called with different accumulated lengths — boundary discontinuities become audible. Proper overlap-add streaming is follow-up work. + +## Quantization + +MLX native 4/6/8-bit quantization works on the Llama backbone. The audio head and audio codebook embeddings benefit from staying at bf16 — quantizing them introduces voice-character drift (pitch register shifts at q6, trajectory instability at q4). + +Already-quantized checkpoints load transparently via `load(...)` — config.json carries a `quantization` block that the framework applies before weight load. To quantize in place on a fresh bf16 load, use `model.model_quant_predicate`: + +```python +import mlx.core as mx +import mlx.nn as nn +from mlx_audio.tts.utils import load + +model = load("bosonai/higgs-audio-v2-generation-3B-base") +nn.quantize(model, group_size=64, bits=8, class_predicate=model.model_quant_predicate) +mx.eval(model.parameters()) +``` + +Benchmark on M5 Max (warm), long-prompt RTF: + +| variant | RTF | weights size | notes | +|---------|-------|--------------|---------------------------------------------| +| bf16 | 0.60× | 6.8 GB | `bosonai/higgs-audio-v2-generation-3B-base` (authoritative) | +| q8 | 0.36× | 6.18 GB | `mlx-community/higgs-audio-v2-3B-mlx-q8` | +| q6 | 0.33× | 4.75 GB | `mlx-community/higgs-audio-v2-3B-mlx-q6` | +| q4 | 0.26× | 3.32 GB | deferred — seed-sensitive, follow-up PR | + +bf16 is served directly from the authoritative `bosonai/*` upload — no need for a redundant mlx-community re-host. q8 and q6 are MLX-specific selectively-quantized variants. + +## Sampling controls + +- `temperature=0.7`, `top_p=0.95` are the Higgs defaults. +- `ras_win_len=7`, `ras_max_repeat=2` enables repetition-avoidance sampling (catches near-tie mispicks that compound into loops). Set `ras_win_len=None` to disable. +- `sampling_warmup_frames=N` uses greedy sampling for the first N frames, then switches to temperature. Exposed for experimentation; not helpful at default settings. +- `fade_in_ms=5.0`, `fade_out_ms=5.0` applies a short linear fade to the decoded PCM boundaries. Below onset perception threshold on bf16/q8; masks rounding-click transients on quantized variants. + +## Implementation notes + +The generation state machine is the non-obvious piece of this port. See source at `mlx_audio/tts/models/higgs_audio/higgs_audio.py:HiggsAudioModel._generate_raw_frames`. The first audio frame is **synthetic all `audio_stream_bos_id`** (AUDIO_INIT) — not sampled from audio_logits at the `<|audio_out_bos|>` text position, because those logits were never trained for direct audio prediction. Without this, the model emits the stream-EOS token on half the codebooks at step 1 and output collapses to a stuck pitch. + +Codebook `i` is emitted with `i`-frame delay, so the first K frames are a progressive ramp-in (cb₀ sampled at frame 1, cb₁ at frame 2, etc.; the rest forced to BOS). On any codebook emitting EOS, a K-frame ramp-out begins — trailing codebooks forced to EOS before termination. After `revert_delay_pattern`, the first and last aligned columns are dropped (BOS-seed and EOS-seal — they decode to arbitrary codec token 1023 and produce audible clicks otherwise). + +## References + +- Original repo: +- Paper / blog: +- HF model (reference): +- HF model (MLX q8): +- HF model (MLX q6): +- HF codec: diff --git a/examples/higgs_audio_clone_demo.py b/examples/higgs_audio_clone_demo.py new file mode 100644 index 000000000..1fd0b1adf --- /dev/null +++ b/examples/higgs_audio_clone_demo.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +"""Higgs Audio v2 voice cloning demo. + +Uses the Higgs-specific HiggsAudioServer API for full parameter surface. +For a drop-in example against the standard mlx_audio.tts.generate CLI, +see docs/models/tts/higgs_audio.md. + +Quick start with the bundled `en_woman` sample voice: + python examples/higgs_audio_clone_demo.py \\ + --text "Text to synthesize in the cloned voice." + +Supply your own reference: + python examples/higgs_audio_clone_demo.py \\ + --ref_audio reference.wav \\ + --ref_text "Reference transcript text." \\ + --text "Text to synthesize in the cloned voice." + +Reference audio is encoded through the in-tree HiggsAudioTokenizer and +stitched into the assistant turn of a ChatML prompt. ref_text is the +transcript of the reference clip — required for stable alignment +between the cloned voice and the target text. + +Best results come from 5-15 seconds of clean reference speech. +Three sample voices live in examples/voice_prompts/ (en_woman, +en_man, en_man_deep) for drop-in use. +""" + +import argparse +import sys +import time +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +import soundfile as sf + +from mlx_audio.tts.models.higgs_audio import HiggsAudioServer + + +def _quantize_predicate(name: str, module: nn.Module) -> bool: + """Keep audio head + audio codebook embeddings at bf16 — they're most + sensitive to quantization noise. Everything else (Llama backbone + text + head) gets compressed.""" + if not isinstance(module, (nn.Linear, nn.Embedding)): + return False + protected = ("audio_codebook_embeddings", "audio_decoder_proj.audio_lm_head") + return not any(b in name for b in protected) + + +def _default_voice_prompt() -> tuple[str, str]: + """Return the bundled `en_woman` voice prompt (wav path + transcript).""" + here = Path(__file__).resolve().parent / "voice_prompts" + return str(here / "en_woman.wav"), (here / "en_woman.txt").read_text().strip() + + +def main() -> int: + p = argparse.ArgumentParser(description="Higgs Audio v2 voice cloning demo") + default_ref_audio, default_ref_text = _default_voice_prompt() + p.add_argument( + "--ref_audio", + default=default_ref_audio, + help="Reference audio WAV (defaults to bundled en_woman sample)", + ) + p.add_argument( + "--ref_text", default=default_ref_text, help="Transcript of the reference audio" + ) + p.add_argument("--text", required=True, help="Target text to synthesize") + p.add_argument("--output", default="higgs_clone_output.wav", help="Output WAV path") + p.add_argument( + "--model", + default="mlx-community/higgs-audio-v2-3B-mlx-bf16", + help="Higgs Audio v2 MLX model repo or path", + ) + p.add_argument( + "--codec", + default="mlx-community/higgs-audio-v2-tokenizer", + help="Higgs Audio v2 tokenizer repo or path", + ) + p.add_argument( + "--quantize_bits", + type=int, + default=None, + choices=[4, 6, 8], + help="Optionally quantize the loaded model in-place (4/6/8-bit)", + ) + p.add_argument("--temperature", type=float, default=0.7) + p.add_argument("--top_p", type=float, default=0.95) + p.add_argument("--max_new_frames", type=int, default=1200) + p.add_argument("--ras_win_len", type=int, default=7) + p.add_argument("--ras_max_repeat", type=int, default=2) + p.add_argument( + "--fade_in_ms", + type=float, + default=30.0, + help="Leading fade (ms) — 30ms suppresses the first-frame transient cleanly", + ) + p.add_argument("--fade_out_ms", type=float, default=15.0) + args = p.parse_args() + + print(f"[load] HiggsAudioServer from {args.model}") + t0 = time.monotonic() + server = HiggsAudioServer.from_pretrained( + model_path=args.model, + codec_path=args.codec, + ) + if args.quantize_bits is not None: + print( + f"[quantize] group_size=64 bits={args.quantize_bits} (audio head protected)" + ) + nn.quantize( + server.model, + group_size=64, + bits=args.quantize_bits, + class_predicate=_quantize_predicate, + ) + mx.eval(server.model.parameters()) + print(f" loaded in {time.monotonic() - t0:.2f}s") + + print(f"[generate] target: {args.text!r}") + t_gen = time.monotonic() + result = server.generate( + target_text=args.text, + reference_audio_path=args.ref_audio, + reference_text=args.ref_text, + max_new_frames=args.max_new_frames, + temperature=args.temperature, + top_p=args.top_p, + ras_win_len=args.ras_win_len, + ras_max_repeat=args.ras_max_repeat, + fade_in_ms=args.fade_in_ms, + fade_out_ms=args.fade_out_ms, + ) + wall = time.monotonic() - t_gen + audio_sec = len(result.pcm) / result.sampling_rate + rtf = wall / audio_sec if audio_sec > 0 else float("inf") + + sf.write(args.output, result.pcm, result.sampling_rate) + print( + f"[done] {audio_sec:.2f}s audio in {wall:.2f}s wall " + f"(RTF {rtf:.2f}×, {result.num_frames_raw} frames, " + f"stop={result.stop_reason}) → {args.output}" + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/voice_prompts/README.md b/examples/voice_prompts/README.md new file mode 100644 index 000000000..3f63c17f2 --- /dev/null +++ b/examples/voice_prompts/README.md @@ -0,0 +1,39 @@ +# Higgs Audio v2 — Sample Voice Prompts + +Drop-in reference voices for `HiggsAudioServer.generate(..., reference_audio_path=...)`. Each `.wav` is paired with a `.txt` containing the transcript of that clip (required for stable alignment between the cloned voice and the target text). + +| File | Character | +| --- | --- | +| `en_woman.wav` | English, feminine register | +| `en_man.wav` | English, masculine register | +| `en_man_deep.wav` | English, masculine register, lower pitch | + +All three were generated via Higgs Audio v2 smart-voice mode (no human recordings), so they're license-clean and can be freely redistributed. + +## Usage + +```python +from mlx_audio.tts.models.higgs_audio import HiggsAudioServer +from pathlib import Path + +voice_dir = Path("examples/voice_prompts") +ref_wav = voice_dir / "en_woman.wav" +ref_txt = (voice_dir / "en_woman.txt").read_text().strip() + +server = HiggsAudioServer.from_pretrained( + model_path="mlx-community/higgs-audio-v2-3B-mlx-q8", + codec_path="mlx-community/higgs-audio-v2-tokenizer", +) + +result = server.generate( + target_text="Anything you want cloned in the chosen voice.", + reference_audio_path=str(ref_wav), + reference_text=ref_txt, + temperature=0.7, + top_p=0.95, + max_new_frames=1200, + fade_in_ms=30.0, +) +``` + +For the recommended parameter set, see [`docs/models/tts/higgs_audio.md`](../../docs/models/tts/higgs_audio.md). diff --git a/examples/voice_prompts/en_man.txt b/examples/voice_prompts/en_man.txt new file mode 100644 index 000000000..e30751609 --- /dev/null +++ b/examples/voice_prompts/en_man.txt @@ -0,0 +1 @@ +The radio quietly played a familiar song. Outside, rain tapped against the window in a steady rhythm. Coffee cooled slowly in a ceramic mug. Somewhere down the hall, a door clicked shut. diff --git a/examples/voice_prompts/en_man.wav b/examples/voice_prompts/en_man.wav new file mode 100644 index 000000000..09be87280 Binary files /dev/null and b/examples/voice_prompts/en_man.wav differ diff --git a/examples/voice_prompts/en_man_deep.txt b/examples/voice_prompts/en_man_deep.txt new file mode 100644 index 000000000..e30751609 --- /dev/null +++ b/examples/voice_prompts/en_man_deep.txt @@ -0,0 +1 @@ +The radio quietly played a familiar song. Outside, rain tapped against the window in a steady rhythm. Coffee cooled slowly in a ceramic mug. Somewhere down the hall, a door clicked shut. diff --git a/examples/voice_prompts/en_man_deep.wav b/examples/voice_prompts/en_man_deep.wav new file mode 100644 index 000000000..aac6f7e3e Binary files /dev/null and b/examples/voice_prompts/en_man_deep.wav differ diff --git a/examples/voice_prompts/en_woman.txt b/examples/voice_prompts/en_woman.txt new file mode 100644 index 000000000..e30751609 --- /dev/null +++ b/examples/voice_prompts/en_woman.txt @@ -0,0 +1 @@ +The radio quietly played a familiar song. Outside, rain tapped against the window in a steady rhythm. Coffee cooled slowly in a ceramic mug. Somewhere down the hall, a door clicked shut. diff --git a/examples/voice_prompts/en_woman.wav b/examples/voice_prompts/en_woman.wav new file mode 100644 index 000000000..c0aa7e794 Binary files /dev/null and b/examples/voice_prompts/en_woman.wav differ diff --git a/mlx_audio/tts/models/higgs_audio/README.md b/mlx_audio/tts/models/higgs_audio/README.md new file mode 100644 index 000000000..48d405a77 --- /dev/null +++ b/mlx_audio/tts/models/higgs_audio/README.md @@ -0,0 +1,114 @@ +# Higgs Audio v2 + +Llama-3.2-3B-backed TTS with multi-codebook acoustic tokens and delay-pattern streaming, with real-time voice cloning on Apple Silicon. The MLX port targets the 3B open-weights release from Boson AI and reuses the in-tree HiggsAudio acoustic tokenizer (originally added for OmniVoice). + +- **Original repo:** [boson-ai/higgs-audio](https://github.com/boson-ai/higgs-audio) +- **Paper / blog:** [Higgs Audio v2](https://boson.ai/blog/higgs-audio-v2) +- **Full MLX port docs:** [`docs/models/tts/higgs_audio.md`](../../../../docs/models/tts/higgs_audio.md) + +## Highlights + +- Reference-audio voice cloning via ChatML prompt format +- Full `AUDIO_INIT` + delay-pattern ramp-in/out state machine +- Repetition-avoidance sampling (RAS) for stable long-form output +- MLX native 4/6/8-bit quantization with optional per-layer protection +- Conforms to the standard mlx-audio interface (`mlx_audio.tts.generate` CLI works unchanged) + +## Usage + +CLI: + +```bash +python -m mlx_audio.tts.generate \ + --model mlx-community/higgs-audio-v2-3B-mlx-q8 \ + --text "Hello from Higgs Audio on MLX." \ + --ref_audio path/to/reference.wav \ + --ref_text "Transcript of the reference clip." +``` + +Python API (standard): + +```python +from mlx_audio.tts.utils import load +import soundfile as sf + +model = load("mlx-community/higgs-audio-v2-3B-mlx-q8") + +for result in model.generate( + text="Hello from Higgs Audio on MLX.", + ref_audio="path/to/reference.wav", + ref_text="Transcript of the reference clip.", +): + sf.write("output.wav", result.audio, result.sample_rate) +``` + +Python API (Higgs-specific surface): + +```python +from mlx_audio.tts.models.higgs_audio import HiggsAudioServer + +server = HiggsAudioServer( + model_path="mlx-community/higgs-audio-v2-3B-mlx-q8", + codec_path="mlx-community/higgs-audio-v2-tokenizer", +) +result = server.generate( + text="Hello from Higgs Audio on MLX.", + ref_audio_path="path/to/reference.wav", + ref_text="Transcript of the reference clip.", +) +``` + +## Voice Cloning + +Best results come from **5-15 seconds of clean reference speech**. Reference audio is encoded through the in-tree HiggsAudioTokenizer and stitched into the assistant turn of a ChatML prompt. `ref_text` is the transcript of the reference clip and is required for stable alignment between the cloned voice and the target text. + +Three sample voices ship in [`examples/voice_prompts/`](../../../../examples/voice_prompts/): `en_woman`, `en_man`, `en_man_deep`. + +See [`examples/higgs_audio_clone_demo.py`](../../../../examples/higgs_audio_clone_demo.py) for a complete cloning walkthrough. + +Without `ref_audio`, generation runs in "smart voice" mode (random voice per sample) — works but is less reliable than voice cloning. For production use, a reference voice is strongly recommended. + +## Generation Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `temperature` | 0.7 | Sampling temperature | +| `top_p` | 0.95 | Nucleus sampling cutoff | +| `top_k` | `None` | Optional top-k cap | +| `max_new_frames` | 1200 | Max acoustic frames to generate (≈ 48s @ 25 fps) | +| `fade_in_ms` | 30.0 | Fade-in on decoded audio | +| `fade_out_ms` | 30.0 | Fade-out on decoded audio | +| `ref_audio` | `None` | Path to reference audio (voice cloning) | +| `ref_text` | `None` | Transcript of the reference clip | + +## Available Models + +| Model | Parameters | Format | RTF (M5 Max) | Memory | +|-------|-----------|--------|--------------|--------| +| [`bosonai/higgs-audio-v2-generation-3B-base`](https://huggingface.co/bosonai/higgs-audio-v2-generation-3B-base) | 3B | bf16 (authoritative) | 0.60× | 6.8 GB | +| [`mlx-community/higgs-audio-v2-3B-mlx-q8`](https://huggingface.co/mlx-community/higgs-audio-v2-3B-mlx-q8) | 3B | 8-bit | 0.36× | 6.18 GB | +| [`mlx-community/higgs-audio-v2-3B-mlx-q6`](https://huggingface.co/mlx-community/higgs-audio-v2-3B-mlx-q6) | 3B | 6-bit | 0.33× | 4.75 GB | +| [`mlx-community/higgs-audio-v2-tokenizer`](https://huggingface.co/mlx-community/higgs-audio-v2-tokenizer) | — | — | — | acoustic tokenizer (required) | + +## Conversion + +To quantize or save a pre-converted format: + +```bash +python -m mlx_audio.convert \ + --hf-path bosonai/higgs-audio-v2-generation-3B-base \ + --mlx-path ./higgs-audio-v2-3B-mlx-q8 \ + --quantize --q-bits 8 +``` + +## Architecture + +- **Llama-3.2-3B backbone** for the text/language stream +- **Multi-codebook acoustic tokens** with `AUDIO_INIT` initialization and delay-pattern ramp-in/out +- **HiggsAudio acoustic tokenizer** (shared with the in-tree OmniVoice entry) at 24kHz +- **ChatML prompt format** for ref-audio conditioning +- **RAS (repetition-avoidance sampling)** for long-form stability + +## License + +Higgs Audio v2 is released under the [Apache 2.0 License](https://github.com/boson-ai/higgs-audio/blob/main/LICENSE). diff --git a/mlx_audio/tts/models/higgs_audio/__init__.py b/mlx_audio/tts/models/higgs_audio/__init__.py new file mode 100644 index 000000000..fec5463ff --- /dev/null +++ b/mlx_audio/tts/models/higgs_audio/__init__.py @@ -0,0 +1,23 @@ +"""Higgs Audio v2 — MLX port.""" + +from .config import HiggsAudioConfig, HiggsTextConfig +from .higgs_audio import HiggsAudioModel +from .model import Model, ModelConfig +from .serve import HiggsAudioGenerationResult, HiggsAudioServer, build_prompt + +__all__ = [ + # Framework-conforming entry points — used by mlx_audio.tts.utils.load + # and `python -m mlx_audio.tts.generate`. + "Model", + "ModelConfig", + # Lower-level building blocks. + "HiggsAudioConfig", + "HiggsTextConfig", + "HiggsAudioModel", + "build_prompt", + # Additional Python API — composes model + codec + tokenizer with a + # kwarg-rich generate(target_text=..., reference_audio_path=..., ...) + # signature tailored to the Higgs serve flow. + "HiggsAudioServer", + "HiggsAudioGenerationResult", +] diff --git a/mlx_audio/tts/models/higgs_audio/config.py b/mlx_audio/tts/models/higgs_audio/config.py new file mode 100644 index 000000000..d01c8f558 --- /dev/null +++ b/mlx_audio/tts/models/higgs_audio/config.py @@ -0,0 +1,113 @@ +"""Higgs Audio v2 (Boson AI) — MLX port config. + +Maps the PyTorch HiggsAudioConfig (transformers-style) into dataclasses +usable by this MLX port. + +Upstream reference: + bosonai/higgs-audio-v2-generation-3B-base/config.json + boson_multimodal/model/higgs_audio/configuration_higgs_audio.py +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import List, Optional + + +@dataclass +class HiggsTextConfig: + """Llama-3.2-3B backbone configuration.""" + + hidden_size: int = 3072 + num_hidden_layers: int = 28 + num_attention_heads: int = 24 + num_key_value_heads: int = 8 + intermediate_size: int = 8192 + vocab_size: int = 128256 + rope_theta: float = 500000.0 + rms_norm_eps: float = 1e-5 + tie_word_embeddings: bool = True + rope_scaling: Optional[dict] = field( + default_factory=lambda: { + "factor": 32.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + } + ) + model_type: str = "llama" + + +@dataclass +class HiggsAudioConfig: + """Top-level Higgs v2 config: Llama backbone + Higgs-specific audio extensions.""" + + text_config: HiggsTextConfig = field(default_factory=HiggsTextConfig) + + # Audio codebook params (match bosonai/higgs-audio-v2-tokenizer) + audio_num_codebooks: int = 8 + audio_codebook_size: int = 1024 + audio_stream_bos_id: int = 1024 + audio_stream_eos_id: int = 1025 + + # Dual-FFN layer indices — which backbone layers run the audio MLP path. + # For v2 3B this is [0..27], i.e. all layers. + audio_dual_ffn_layers: List[int] = field(default_factory=lambda: list(range(28))) + + # If True, insert a separate audio-attention layer per dual-FFN block. + # For v2 3B this is 0 — attention is fully shared between text and audio. + use_audio_out_self_attention: bool = False + + # If > 0, add an extra transformer stack inside HiggsAudioDecoderProjector + # before the audio head. For v2 3B this is 0 — the decoder projector is + # literally just two nn.Linear heads (text_lm_head + audio_lm_head). + audio_decoder_proj_num_layers: int = 0 + + # If True, generation uses the delay-pattern trick (codebook i is emitted + # with i-frame delay so one forward pass predicts K codebooks at once). + # For v2 3B: True. Requires revert_delay_pattern at decode time. + use_delay_pattern: bool = True + + # Special-token ids in the text vocab that mark audio-in and audio-out + # positions. Used to build the audio_out_mask during generation. + audio_in_token_idx: Optional[int] = None + audio_out_token_idx: Optional[int] = None + audio_out_bos_token_id: Optional[int] = None + audio_eos_token_id: Optional[int] = None + pad_token_id: Optional[int] = None + + @classmethod + def from_dict(cls, d: dict) -> "HiggsAudioConfig": + """Construct from the HF config.json dict (permissive — ignores extras).""" + tc = d.get("text_config", {}) + text_config = HiggsTextConfig( + hidden_size=tc.get("hidden_size", 3072), + num_hidden_layers=tc.get("num_hidden_layers", 28), + num_attention_heads=tc.get("num_attention_heads", 24), + num_key_value_heads=tc.get("num_key_value_heads", 8), + intermediate_size=tc.get("intermediate_size", 8192), + vocab_size=tc.get("vocab_size", 128256), + rope_theta=tc.get("rope_theta", 500000.0), + rms_norm_eps=tc.get("rms_norm_eps", 1e-5), + tie_word_embeddings=tc.get("tie_word_embeddings", True), + rope_scaling=tc.get("rope_scaling"), + ) + return cls( + text_config=text_config, + audio_num_codebooks=d.get("audio_num_codebooks", 8), + audio_codebook_size=d.get("audio_codebook_size", 1024), + audio_stream_bos_id=d.get("audio_stream_bos_id", 1024), + audio_stream_eos_id=d.get("audio_stream_eos_id", 1025), + audio_dual_ffn_layers=d.get("audio_dual_ffn_layers", list(range(28))), + use_audio_out_self_attention=bool( + d.get("use_audio_out_self_attention", False) + ), + audio_decoder_proj_num_layers=d.get("audio_decoder_proj_num_layers", 0), + use_delay_pattern=d.get("use_delay_pattern", True), + audio_in_token_idx=d.get("audio_in_token_idx"), + audio_out_token_idx=d.get("audio_out_token_idx"), + audio_out_bos_token_id=d.get("audio_out_bos_token_id"), + audio_eos_token_id=d.get("audio_eos_token_id"), + pad_token_id=d.get("pad_token_id"), + ) diff --git a/mlx_audio/tts/models/higgs_audio/generation.py b/mlx_audio/tts/models/higgs_audio/generation.py new file mode 100644 index 000000000..aa46119c6 --- /dev/null +++ b/mlx_audio/tts/models/higgs_audio/generation.py @@ -0,0 +1,165 @@ +"""Higgs Audio v2 generation primitives — delay pattern, audio embedding +lookup, and a minimal generate_frames() loop. + +Not yet integrated with ChatML / reference audio / streaming — those +layers build on top of what's here. This module handles the per-frame +mechanics of autoregressive multi-codebook audio generation. +""" + +from __future__ import annotations + +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +from .config import HiggsAudioConfig + + +def revert_delay_pattern(data: mx.array) -> mx.array: + """Undo the delay pattern applied during generation. + + Input shape: (K, seq_len + K - 1) — codebook i is offset by i frames. + Output shape: (K, seq_len) — codebooks aligned at each timestep. + + Ported from boson_multimodal/model/higgs_audio/utils.py. + """ + assert data.ndim == 2, f"expected 2D, got {data.shape}" + K = data.shape[0] + L = data.shape[1] + rows = [data[i : i + 1, i : L - K + 1 + i] for i in range(K)] + return mx.concatenate(rows, axis=0) + + +def apply_delay_pattern(codebook_ids: mx.array, bos_id: int) -> mx.array: + """Apply the delay pattern to a sequence of aligned codebook frames. + + Inverse of revert_delay_pattern. Given aligned tokens [K, L], produce + delayed [K, L + K - 1] where codebook i starts emitting at frame i. + + Positions before codebook i's start are filled with bos_id. + Positions after codebook i's real content tail are untouched (we + typically don't need them, but they'd be EOS in a full generation). + """ + K = codebook_ids.shape[0] + L = codebook_ids.shape[1] + out = mx.full((K, L + K - 1), bos_id, dtype=codebook_ids.dtype) + for i in range(K): + out[i, i : i + L] = codebook_ids[i] + return out + + +def build_delay_pattern_mask( + input_ids: mx.array, bos_token_id: int, pad_token_id: int +) -> mx.array: + """Apply the delay pattern to a conditioning/prompt sequence of codebook tokens. + + Takes aligned [K, L] codebook ids and produces delayed [K, L + K - 1]: + - Lower triangle (i > j): bos_token_id + - Upper triangle (j >= L + i): pad_token_id + - Middle band: original input_ids aligned so codebook i is offset by i frames + + Ported from boson_multimodal/model/higgs_audio/utils.py build_delay_pattern_mask + (non-generation variant — no -1 placeholders since all positions are known). + """ + K, L = input_ids.shape + new_L = L + K - 1 + i_idx = mx.arange(K)[:, None] # [K, 1] + j_idx = mx.arange(new_L)[None, :] # [1, new_L] + bos_mask = j_idx < i_idx # below diag → BOS + eos_mask = j_idx >= (L + i_idx) # past content end → EOS + # Middle: input_ids[i, j - i] + src_j = mx.clip(j_idx - i_idx, 0, L - 1) + src_j_broad = mx.broadcast_to(src_j, (K, new_L)) + gathered = mx.take_along_axis(input_ids, src_j_broad, axis=1) + out = mx.where(bos_mask, mx.array(bos_token_id, dtype=input_ids.dtype), gathered) + out = mx.where(eos_mask, mx.array(pad_token_id, dtype=input_ids.dtype), out) + return out + + +def lookup_audio_embedding( + audio_codebook_embeddings: nn.Embedding, + codebook_ids: mx.array, + codebook_size_plus2: int, +) -> mx.array: + """Convert [K, T] codebook token ids to [T, hidden] summed embeddings. + + Each codebook's tokens are shifted by `k * (codebook_size + 2)` to index + into the shared embedding table. Per-codebook embeddings are summed + (per Higgs v2 config.audio_embed_avg = False). + + Args: + codebook_ids: shape [K, T] int32, values in [0, codebook_size + 1] + (includes stream BOS at codebook_size and EOS at codebook_size+1). + codebook_size_plus2: codebook_size + 2 (per-codebook stride in the + shared embedding table). + + Returns: + [T, hidden_size] float array. + """ + K, T = codebook_ids.shape + shift = mx.arange(K, dtype=mx.int32) * codebook_size_plus2 # [K] + shifted = codebook_ids + shift[:, None] # [K, T] + per_codebook = audio_codebook_embeddings(shifted) # [K, T, hidden] + return mx.sum(per_codebook, axis=0) # [T, hidden] + + +def greedy_sample_audio(audio_logits: mx.array) -> mx.array: + """Argmax per codebook on audio_logits. + + Args: + audio_logits: shape [B, T, K, C+2] (raw audio head output). + + Returns: + [B, T, K] int32 token ids. + """ + return mx.argmax(audio_logits, axis=-1).astype(mx.int32) + + +def sample_audio( + audio_logits: mx.array, + temperature: float = 0.7, + top_p: float = 0.95, + top_k: Optional[int] = None, +) -> mx.array: + """Temperature + (optional) top-p / top-k sampling per codebook. + + Args: + audio_logits: [B, T, K, C+2]. + temperature: 0 → greedy. + top_p: nucleus cutoff (applied per codebook independently). + top_k: top-k cutoff (applied per codebook independently, None → off). + + Returns: + [B, T, K] int32. + """ + if temperature <= 0.0: + return greedy_sample_audio(audio_logits) + + logits = audio_logits / temperature + + if top_k is not None and top_k > 0: + # Mask logits below top-k per codebook. + kth = mx.sort(logits, axis=-1)[..., -top_k : -top_k + 1] + logits = mx.where(logits < kth, -mx.inf, logits) + + if top_p is not None and 0.0 < top_p < 1.0: + # Nucleus sampling per codebook on the last axis. + sorted_idx = mx.argsort(-logits, axis=-1) # descending order + sorted_logits = mx.take_along_axis(logits, sorted_idx, axis=-1) + sorted_probs = mx.softmax(sorted_logits, axis=-1) + cumulative = mx.cumsum(sorted_probs, axis=-1) + # Keep positions where previous-cumsum < top_p (always keep idx 0). + shifted = mx.concatenate( + [mx.zeros_like(cumulative[..., :1]), cumulative[..., :-1]], axis=-1 + ) + keep_sorted = shifted < top_p + masked_sorted = mx.where(keep_sorted, sorted_logits, -mx.inf) + # Scatter back to original order. + inv_idx = mx.argsort(sorted_idx, axis=-1) + logits = mx.take_along_axis(masked_sorted, inv_idx, axis=-1) + + # Gumbel-max trick for categorical sampling. + u = mx.random.uniform(shape=logits.shape) + g = -mx.log(-mx.log(u + 1e-20) + 1e-20) + return mx.argmax(logits + g, axis=-1).astype(mx.int32) diff --git a/mlx_audio/tts/models/higgs_audio/higgs_audio.py b/mlx_audio/tts/models/higgs_audio/higgs_audio.py new file mode 100644 index 000000000..6f37e0b2d --- /dev/null +++ b/mlx_audio/tts/models/higgs_audio/higgs_audio.py @@ -0,0 +1,478 @@ +"""Higgs Audio v2 (Boson AI) — MLX port. + +Llama-3.2-3B backbone with a dual-FFN decoder layer (text path and audio path +share self-attention; LN + MLP are per-path, routed by `audio_out_mask`). Audio +tokens are emitted via a delay pattern (codebook i lags by i frames) — see +`generation.build_delay_pattern_mask` and `revert_delay_pattern`. + +The first generated audio frame must be all-`audio_stream_bos_id` (synthetic); +sampling from the bos-text-position audio_logits collapses to stream-EOS on +half the codebooks. See `Model.generate` for the full ramp-in / ramp-out state +machine. +""" + +from __future__ import annotations + +from typing import Iterator, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import create_causal_mask +from mlx_lm.models.cache import make_prompt_cache +from mlx_lm.models.llama import MLP as LlamaMLP +from mlx_lm.models.llama import Attention as LlamaAttention +from mlx_lm.models.llama import ModelArgs as LlamaModelArgs + +from .config import HiggsAudioConfig +from .generation import ( + greedy_sample_audio, + lookup_audio_embedding, + revert_delay_pattern, + sample_audio, +) + + +def _llama_args_from_text_config(text_cfg) -> LlamaModelArgs: + return LlamaModelArgs( + model_type="llama", + hidden_size=text_cfg.hidden_size, + num_hidden_layers=text_cfg.num_hidden_layers, + intermediate_size=text_cfg.intermediate_size, + num_attention_heads=text_cfg.num_attention_heads, + num_key_value_heads=text_cfg.num_key_value_heads, + rms_norm_eps=text_cfg.rms_norm_eps, + vocab_size=text_cfg.vocab_size, + rope_theta=text_cfg.rope_theta, + rope_scaling=text_cfg.rope_scaling, + tie_word_embeddings=text_cfg.tie_word_embeddings, + ) + + +class HiggsDualFFNDecoderLayer(nn.Module): + """One Llama-style decoder layer with dual-path norm+MLP for text vs audio tokens. + + Routing: positions where audio_out_mask==1 use the audio variants of the + input layernorm, post-attention layernorm, and MLP. All other positions use + the text variants. The self_attention module is shared. + + For v2 3B with use_audio_out_self_attention=0 and no fast_forward, every + layer (0..27) is this kind of block. + """ + + def __init__(self, llama_args: LlamaModelArgs, text_cfg): + super().__init__() + self.input_layernorm = nn.RMSNorm( + text_cfg.hidden_size, eps=text_cfg.rms_norm_eps + ) + self.audio_input_layernorm = nn.RMSNorm( + text_cfg.hidden_size, eps=text_cfg.rms_norm_eps + ) + self.self_attn = LlamaAttention(llama_args) + self.post_attention_layernorm = nn.RMSNorm( + text_cfg.hidden_size, eps=text_cfg.rms_norm_eps + ) + self.audio_post_attention_layernorm = nn.RMSNorm( + text_cfg.hidden_size, eps=text_cfg.rms_norm_eps + ) + self.mlp = LlamaMLP(llama_args) + self.audio_mlp = LlamaMLP(llama_args) + + def __call__( + self, + x: mx.array, + audio_out_mask: mx.array, # [B, T] bool + attn_mask: Optional[mx.array] = None, + cache=None, + ) -> mx.array: + # Pre-attention norm split: compute both paths, select by mask. + mask_expanded = audio_out_mask[..., None] # [B, T, 1] + h_text_norm = self.input_layernorm(x) + h_audio_norm = self.audio_input_layernorm(x) + h_norm = mx.where(mask_expanded, h_audio_norm, h_text_norm) + + # Shared attention + attn_out = self.self_attn(h_norm, attn_mask, cache) + h = x + attn_out + + # Post-attention: split norm + split MLP by mask. + post_text = self.post_attention_layernorm(h) + post_audio = self.audio_post_attention_layernorm(h) + mlp_text_out = self.mlp(post_text) + mlp_audio_out = self.audio_mlp(post_audio) + mlp_out = mx.where(mask_expanded, mlp_audio_out, mlp_text_out) + + return h + mlp_out + + +class HiggsAudioDecoderProjector(nn.Module): + """Final projection: hidden states → (text_logits, audio_logits). + + For v2 3B with audio_decoder_proj_num_layers=0 this is literally two linear + heads with no intermediate transformer layers. + """ + + def __init__(self, config: HiggsAudioConfig): + super().__init__() + hidden_size = config.text_config.hidden_size + vocab_size = config.text_config.vocab_size + audio_out_dim = config.audio_num_codebooks * (config.audio_codebook_size + 2) + + self.text_lm_head = nn.Linear(hidden_size, vocab_size, bias=False) + self.audio_lm_head = nn.Linear(hidden_size, audio_out_dim, bias=False) + self.audio_num_codebooks = config.audio_num_codebooks + self.audio_codebook_plus2 = config.audio_codebook_size + 2 + + def __call__( + self, hidden_states: mx.array, audio_out_mask: Optional[mx.array] = None + ) -> Tuple[mx.array, Optional[mx.array]]: + """Project hidden states to text + audio logits. + + Returns (text_logits [B, T, vocab], audio_logits [B, T, K, C+2]). + audio_logits is None only when audio_out_mask is None and we want to + skip the audio head entirely (pure text forward). Otherwise we compute + audio logits for all positions — caller filters to audio positions by + mask. This is cheap (audio head is one Linear layer) and avoids MLX's + unsupported boolean-mask indexing. + """ + text_logits = self.text_lm_head(hidden_states) + + if audio_out_mask is None: + return text_logits, None + + # audio_flat: [B, T, K * (C+2)] + audio_flat = self.audio_lm_head(hidden_states) + B, T = hidden_states.shape[:2] + audio_logits = audio_flat.reshape( + B, T, self.audio_num_codebooks, self.audio_codebook_plus2 + ) + return text_logits, audio_logits + + +class HiggsAudioModel(nn.Module): + """End-to-end Higgs Audio v2 for MLX. + + NOT fully wired yet — see TODOs below. This scaffold defines the structure + and the forward for a single decoder-layer pass with dual-FFN routing. + + TODO (M3 continuation): + - Wire up the full forward: embed → stack of HiggsDualFFNDecoderLayer → norm + - Audio codebook embedding lookup: audio tokens (from the codec) index a + separate table; resulting embeddings replace embed_tokens[audio_pos] + - Position embeddings + RoPE + attention mask with cache compatibility + - sanitize(): weight-name conversion from bosonai safetensors to this + module's state dict (M2 already did this for the pure-backbone keys; + need to extend for dual-FFN audio_* keys + decoder_proj + codebook_embed) + + TODO (M4): + - generate_delta_stream async loop — ChatML prompt + reference audio → + tokenized context → autoregressive decode yielding [K] codebook frames + with delay pattern applied + """ + + def __init__(self, config: HiggsAudioConfig): + super().__init__() + self.config = config + llama_args = _llama_args_from_text_config(config.text_config) + + self.embed_tokens = nn.Embedding( + config.text_config.vocab_size, config.text_config.hidden_size + ) + + # Separate embedding table for audio codebook tokens. + # Per-codebook has (codebook_size + 2) entries (values 0..C-1 plus BOS/EOS). + self.audio_codebook_embeddings = nn.Embedding( + config.audio_num_codebooks * (config.audio_codebook_size + 2), + config.text_config.hidden_size, + ) + + self.layers = [ + HiggsDualFFNDecoderLayer(llama_args, config.text_config) + for _ in range(config.text_config.num_hidden_layers) + ] + self.norm = nn.RMSNorm( + config.text_config.hidden_size, eps=config.text_config.rms_norm_eps + ) + self.audio_decoder_proj = HiggsAudioDecoderProjector(config) + + def __call__( + self, + input_ids: Optional[mx.array] = None, + inputs_embeds: Optional[mx.array] = None, + audio_out_mask: Optional[mx.array] = None, + attn_mask: Optional[mx.array] = None, + cache=None, + ) -> Tuple[mx.array, Optional[mx.array]]: + """Forward pass. + + Pass either `input_ids` (text tokens — looked up via `embed_tokens`) OR + `inputs_embeds` (pre-computed hidden states, e.g. for audio positions + where summed codebook embeddings replace a vocab lookup). + + Args: + input_ids: [B, T] int32 — text-vocab token ids. + inputs_embeds: [B, T, hidden] float — pre-computed embeddings. + Exactly one of input_ids / inputs_embeds must be provided. + audio_out_mask: [B, T] bool — 1 at positions that should be routed + through audio_mlp + audio_* layernorms. None → text-only (no + audio_logits computed). + attn_mask: [B, T, T] or broadcastable causal mask. Auto-created as + causal if None. + cache: optional kv-cache (list-per-layer) for incremental decoding. + + Returns: + (text_logits [B, T, vocab], audio_logits [B, T, K, C+2] or None). + """ + assert (input_ids is None) != ( + inputs_embeds is None + ), "pass exactly one of input_ids or inputs_embeds" + if input_ids is not None: + B, T = input_ids.shape + h = self.embed_tokens(input_ids) + else: + B, T = inputs_embeds.shape[:2] + h = inputs_embeds + + caller_wants_audio = audio_out_mask is not None + # Dual-FFN layers always need a mask; synthesize all-False when caller passes None. + layer_mask = ( + audio_out_mask if caller_wants_audio else mx.zeros((B, T), dtype=mx.bool_) + ) + + if attn_mask is None and T > 1: + attn_mask = create_causal_mask(T, offset=0) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, layer_mask, attn_mask, c) + + h = self.norm(h) + + # Only compute audio logits when caller explicitly passed a mask. + proj_mask = audio_out_mask if caller_wants_audio else None + text_logits, audio_logits = self.audio_decoder_proj(h, proj_mask) + return text_logits, audio_logits + + def sanitize(self, weights: dict) -> dict: + """Convert bosonai safetensors keys → our MLX state dict. + + Higgs's safetensors already use HuggingFace Llama naming with the + Higgs-specific audio_* additions, so this is essentially a pass-through. + Kept as a method so sanitize() can grow if we later add transposes + (e.g. if any conv-style weights need layout remapping). + """ + return dict(weights) + + # ------------------------------------------------------------------ + # Generation — AUDIO_INIT + delay-pattern ramp-in + EOS ramp-out + # ------------------------------------------------------------------ + + def _generate_raw_frames( + self, + inputs_embeds: mx.array, + audio_out_mask: mx.array, + *, + max_new_frames: int, + temperature: float, + top_p: Optional[float], + top_k: Optional[int], + ras_win_len: Optional[int], + ras_max_repeat: int, + sampling_warmup_frames: int, + ) -> Iterator[Tuple[mx.array, dict]]: + """Run prefill then yield ([K] int32 frame, info) per generation step. + + Frames are emitted in *delay-pattern* form (not aligned). Caller is + responsible for stacking, revert_delay_pattern, and boundary trim. + + Implements the Higgs v2 generation state machine documented at + project_higgs_audio_init_delay_pattern.md: + - frame 0: forced all audio_stream_bos_id (AUDIO_INIT — NOT sampled) + - frames 1..K-1: ramp-in (sample [0..num_delay], force tail to BOS) + - full sampling once num_delay >= K-1 + - any EOS in a frame triggers a K-frame EOS ramp-out, then stop + """ + cfg = self.config + K = cfg.audio_num_codebooks + BOS = cfg.audio_stream_bos_id + EOS = cfg.audio_stream_eos_id + stride = cfg.audio_codebook_size + 2 + + # Prefill. Discard logits (they come from a text position that was + # never trained for direct audio prediction — use AUDIO_INIT instead). + cache = make_prompt_cache(self) + _, _ = self( + inputs_embeds=inputs_embeds, audio_out_mask=audio_out_mask, cache=cache + ) + mx.eval(*[c.state for c in cache]) # force prefill graph materialization + + # Frame 0 = synthetic all-BOS (AUDIO_INIT). + frame0 = mx.full((K,), BOS, dtype=mx.int32) + yield frame0, {"step": 0, "source": "audio_init", "num_delay": 0} + + num_delay = 0 + num_remaining_delays: Optional[int] = None + step_mask = mx.ones((1, 1), dtype=mx.bool_) + prev_frame = frame0 + + # RAS window: per-codebook rolling history of recent sampled tokens. + # Seeded with frame 0 (all-BOS) so the rolling window is already filled + # enough to check from frame 1. We track python ints (small footprint). + ras_enabled = ras_win_len is not None and ras_win_len > 0 + # rows are codebooks, columns are recent frames (most recent last) + ras_window: list[list[int]] = [[BOS] for _ in range(K)] if ras_enabled else [] + + for step in range(max_new_frames): + last = prev_frame.reshape(K, 1) + embed = lookup_audio_embedding(self.audio_codebook_embeddings, last, stride) + _, audio_logits = self( + inputs_embeds=embed[None], audio_out_mask=step_mask, cache=cache + ) + # Greedy during warmup — pins trajectory through the low-context + # ramp-in region where quantization noise otherwise amplifies + # sampling variance into divergent outputs. After warmup, switch to + # temperature+top_p for natural prosody variance. + if step < sampling_warmup_frames: + sampled = greedy_sample_audio(audio_logits) + else: + sampled = sample_audio( + audio_logits, temperature=temperature, top_p=top_p, top_k=top_k + ) + mx.eval(sampled) + tok_list = sampled[0, 0].tolist() # list of K ints + + # --- RAS (repetition-avoidance sampling) BEFORE delay-pattern forcing. + # For each codebook, if the sampled token has appeared >= max_repeat + # times in the recent window, resample that codebook greedily (temp=0). + # This catches first-token-dominance loops that quantization noise + # triggers especially in the low-context prefix window. + if ras_enabled: + resample_mask_needed = False + per_cb_resample: list[bool] = [] + for cb_i in range(K): + window = ras_window[cb_i][-ras_win_len:] + count = sum(1 for v in window if v == tok_list[cb_i]) + need = count >= ras_max_repeat + per_cb_resample.append(need) + if need: + resample_mask_needed = True + if resample_mask_needed: + greedy = greedy_sample_audio(audio_logits) + mx.eval(greedy) + greedy_tok = greedy[0, 0].tolist() + for cb_i in range(K): + if per_cb_resample[cb_i]: + tok_list[cb_i] = greedy_tok[cb_i] + + if cfg.use_delay_pattern: + # Ramp-in: force tail codebooks to BOS. + if num_delay + 1 < K: + for k_i in range(num_delay + 1, K): + tok_list[k_i] = BOS + num_delay += 1 + + # Ramp-out in progress. + if num_remaining_delays is not None: + force_until = K - num_remaining_delays + for k_i in range(force_until): + tok_list[k_i] = EOS + num_remaining_delays -= 1 + else: + # Check whether any codebook emitted EOS → start ramp-out. + eos_positions = [i for i, v in enumerate(tok_list) if v == EOS] + if eos_positions: + last_eos = eos_positions[-1] + for k_i in range(last_eos): + tok_list[k_i] = EOS + num_remaining_delays = K - last_eos - 1 + + tok_arr = mx.array(tok_list, dtype=mx.int32) + if ras_enabled: + for cb_i in range(K): + ras_window[cb_i].append(tok_list[cb_i]) + if len(ras_window[cb_i]) > ras_win_len + 4: + # keep window bounded + ras_window[cb_i] = ras_window[cb_i][-ras_win_len:] + info = { + "step": step + 1, + "source": "sampled", + "num_delay": num_delay, + "num_remaining_delays": num_remaining_delays, + } + yield tok_arr, info + prev_frame = tok_arr + + if ( + cfg.use_delay_pattern + and num_remaining_delays is not None + and num_remaining_delays <= 0 + ): + return + + def generate( + self, + inputs_embeds: mx.array, + audio_out_mask: mx.array, + *, + max_new_frames: int = 900, + temperature: float = 0.7, + top_p: Optional[float] = 0.95, + top_k: Optional[int] = None, + ras_win_len: Optional[int] = 7, + ras_max_repeat: int = 2, + sampling_warmup_frames: int = 0, + trim_boundaries: bool = True, + ) -> Tuple[mx.array, dict]: + """Generate audio codebook tokens from a prefill of audio_out_mask=[...]. + + Args: + inputs_embeds: [1, T_prompt, hidden] prompt embeddings with any + text + reference-audio context already stitched. + audio_out_mask: [1, T_prompt] bool — True at positions routed + through the audio dual-FFN path. + max_new_frames: hard cap on generation length. + temperature / top_p / top_k: sampling controls. + trim_boundaries: drop the synthetic BOS-seed (col 0) and EOS-seal + (col -1) columns after revert_delay_pattern. Default True; + required for clean codec decode (otherwise those columns + clip to codec token 1023 → audible click at sample-zero and end). + + Returns: + (aligned_tokens, info): + aligned_tokens: [K, T_audio] int32, codec-ready + info: dict with num_frames, stop_reason, timing hooks, etc. + """ + frames = [] + stop_reason = "max-frames" + for tok, meta in self._generate_raw_frames( + inputs_embeds, + audio_out_mask, + max_new_frames=max_new_frames, + temperature=temperature, + top_p=top_p, + top_k=top_k, + ras_win_len=ras_win_len, + ras_max_repeat=ras_max_repeat, + sampling_warmup_frames=sampling_warmup_frames, + ): + frames.append(tok) + if ( + meta.get("num_remaining_delays") is not None + and meta["num_remaining_delays"] <= 0 + ): + stop_reason = f"eos-ramp-complete-at-frame-{meta['step']}" + if len(frames) - 1 >= max_new_frames and stop_reason == "max-frames": + stop_reason = f"max-frames-{max_new_frames}" + + sequence = mx.stack(frames, axis=1).astype(mx.int32) # [K, N] + aligned = revert_delay_pattern(sequence) # [K, N-K+1] + if trim_boundaries and aligned.shape[1] >= 2: + aligned = aligned[:, 1:-1] + aligned = mx.clip(aligned, 0, self.config.audio_codebook_size - 1) + info = { + "num_frames_raw": sequence.shape[1], + "num_frames_aligned": aligned.shape[1], + "stop_reason": stop_reason, + } + return aligned, info diff --git a/mlx_audio/tts/models/higgs_audio/model.py b/mlx_audio/tts/models/higgs_audio/model.py new file mode 100644 index 000000000..282a75d6d --- /dev/null +++ b/mlx_audio/tts/models/higgs_audio/model.py @@ -0,0 +1,217 @@ +"""Higgs Audio v2 — framework-conforming entry point. + +Defines `Model` and `ModelConfig` in the shape expected by +`mlx_audio.tts.utils.load()` and the top-level +`python -m mlx_audio.tts.generate` CLI. + +`Model` subclasses `HiggsAudioModel` so the safetensors checkpoint loads +directly against it (no key remapping). Tokenizer + codec are attached +via `post_load_hook`; `generate(text, voice=..., ref_audio=..., ref_text=...)` +matches the framework's standard signature and yields a +`mlx_audio.tts.models.base.GenerationResult`. + +The richer, kwarg-style API (`HiggsAudioServer.from_pretrained(...).generate(target_text=...)`) +in `serve.py` is preserved as an additional Python entrypoint. +""" + +from __future__ import annotations + +import time +from typing import Iterator, Optional + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from ..base import GenerationResult +from .config import HiggsAudioConfig +from .higgs_audio import HiggsAudioModel +from .serve import build_prompt + +# Framework loader reads `module.ModelConfig.from_dict(config_json)`; aliasing +# HiggsAudioConfig keeps a single source of truth. +ModelConfig = HiggsAudioConfig + + +_DEFAULT_CODEC_REPO = "mlx-community/higgs-audio-v2-tokenizer" + + +def _format_duration(seconds: float) -> str: + h = int(seconds // 3600) + m = int((seconds % 3600) // 60) + s = int(seconds % 60) + ms = int((seconds % 1) * 1000) + return f"{h:02d}:{m:02d}:{s:02d}.{ms:03d}" + + +class Model(HiggsAudioModel): + """Framework-conforming wrapper around `HiggsAudioModel`. + + Inherits the full model structure so the checkpoint loads without any + key remapping. Adds a tokenizer + codec (populated via `post_load_hook`) + and a `generate()` method matching the standard TTS interface. + """ + + def __init__(self, config: HiggsAudioConfig): + super().__init__(config) + self._config = config + self._tokenizer = None + self._codec = None + self._sample_rate = 24000 + + # --- framework hooks ------------------------------------------------ + + def model_quant_predicate(self, name: str, module: nn.Module) -> bool: + """Quantize everything except the audio head + codebook embeddings. + + These two components are most sensitive to quantization noise: q4/q6 + pushed the audio head to collapse to stream-EOS or drift a semitone. + Keeping them at bf16 preserves voice character; the Llama backbone + and text head compress cleanly. + """ + if not isinstance(module, (nn.Linear, nn.Embedding)): + return False + protected = ("audio_codebook_embeddings", "audio_decoder_proj.audio_lm_head") + return not any(p in name for p in protected) + + @classmethod + def post_load_hook(cls, model: "Model", model_path) -> "Model": + """Attach the HF text tokenizer (from model_path) and the Higgs codec + (from the default mlx-community repo, overridable post-hoc via + `model.codec = ...`).""" + from huggingface_hub import snapshot_download + from transformers import AutoTokenizer + + from ....codec.models.higgs_audio.higgs_audio import HiggsAudioTokenizer + + model._tokenizer = AutoTokenizer.from_pretrained(str(model_path)) + codec_dir = snapshot_download(repo_id=_DEFAULT_CODEC_REPO) + model._codec = HiggsAudioTokenizer.from_pretrained(codec_dir) + return model + + # --- accessors ------------------------------------------------------ + + @property + def sample_rate(self) -> int: + return self._sample_rate + + @property + def tokenizer(self): + return self._tokenizer + + @tokenizer.setter + def tokenizer(self, value) -> None: + self._tokenizer = value + + @property + def codec(self): + return self._codec + + @codec.setter + def codec(self, value) -> None: + self._codec = value + + # --- generation ----------------------------------------------------- + + def generate( # type: ignore[override] + self, + text: str, + voice: Optional[str] = None, + ref_audio=None, + ref_text: Optional[str] = None, + max_new_frames: int = 1200, + temperature: float = 0.7, + top_p: Optional[float] = 0.95, + top_k: Optional[int] = None, + ras_win_len: Optional[int] = 7, + ras_max_repeat: int = 2, + sampling_warmup_frames: int = 0, + fade_in_ms: float = 30.0, + fade_out_ms: float = 15.0, + verbose: bool = False, + **kwargs, # absorb framework kwargs we don't use (speed, lang_code, ...) + ) -> Iterator[GenerationResult]: + """Synthesize `text` and yield a single `GenerationResult`. + + Voice cloning: pass `ref_audio` (mono 24 kHz mx.array or numpy array, + loaded by the top-level CLI from `--ref_audio`) together with + `ref_text` (its transcript). Without `ref_audio`, runs smart-voice + mode — works but less reliable; recommended to always supply a + reference for production use. + """ + if self._tokenizer is None or self._codec is None: + raise RuntimeError( + "Model not fully loaded — tokenizer/codec missing. Load via " + "mlx_audio.tts.utils.load(...) so post_load_hook runs." + ) + + start = time.perf_counter() + + ref_audio_24k = None + if ref_audio is not None: + ref_audio_24k = np.asarray(ref_audio, dtype=np.float32).reshape(-1) + + full_embeds, audio_out_mask, _info = build_prompt( + text, + ref_text=ref_text, + ref_audio_24k=ref_audio_24k, + config=self._config, + tokenizer=self._tokenizer, + codec=self._codec, + embed_tokens=self.embed_tokens, + audio_codebook_embeddings=self.audio_codebook_embeddings, + ) + + # Call the inherited HiggsAudioModel.generate (delay-pattern state machine). + aligned, gen_info = HiggsAudioModel.generate( + self, + inputs_embeds=full_embeds, + audio_out_mask=audio_out_mask, + max_new_frames=max_new_frames, + temperature=temperature, + top_p=top_p, + top_k=top_k, + ras_win_len=ras_win_len, + ras_max_repeat=ras_max_repeat, + sampling_warmup_frames=sampling_warmup_frames, + ) + mx.eval(aligned) + + pcm = self._codec.decode(aligned.T) + mx.eval(pcm) + pcm_np = np.array(pcm).astype(np.float32).reshape(-1) + + sr = self._sample_rate + n_in = int(fade_in_ms * sr / 1000.0) + n_out = int(fade_out_ms * sr / 1000.0) + if n_in > 0 and pcm_np.size > n_in: + pcm_np[:n_in] *= np.linspace(0.0, 1.0, n_in, dtype=np.float32) + if n_out > 0 and pcm_np.size > n_out: + pcm_np[-n_out:] *= np.linspace(1.0, 0.0, n_out, dtype=np.float32) + + audio_mx = mx.array(pcm_np) + samples = audio_mx.shape[0] + elapsed = time.perf_counter() - start + duration_s = samples / sr + rtf = elapsed / duration_s if duration_s > 0 else 0.0 + tok = gen_info["num_frames_aligned"] + + yield GenerationResult( + audio=audio_mx, + samples=samples, + sample_rate=sr, + segment_idx=0, + token_count=tok, + audio_duration=_format_duration(duration_s), + real_time_factor=round(rtf, 3), + prompt={ + "tokens": tok, + "tokens-per-sec": round(tok / elapsed, 2) if elapsed > 0 else 0.0, + }, + audio_samples={ + "samples": samples, + "samples-per-sec": round(samples / elapsed, 2) if elapsed > 0 else 0.0, + }, + processing_time_seconds=elapsed, + peak_memory_usage=mx.get_peak_memory() / 1e9, + ) diff --git a/mlx_audio/tts/models/higgs_audio/serve.py b/mlx_audio/tts/models/higgs_audio/serve.py new file mode 100644 index 000000000..2abe26590 --- /dev/null +++ b/mlx_audio/tts/models/higgs_audio/serve.py @@ -0,0 +1,405 @@ +"""Higgs Audio v2 — MLX serve engine. + +High-level entry point that takes ChatML-style messages (with optional +reference audio for voice cloning) and produces PCM output. Mirrors the +PyTorch `boson_multimodal.serve.serve_engine.HiggsAudioServeEngine` shape +but uses the MLX port end-to-end. + +Usage: + + server = HiggsAudioServer.from_pretrained( + model_path="/path/to/higgs-v2-3b", + codec_path="/path/to/higgs-mlx-codec", + # tokenizer_path defaults to model_path + ) + pcm_24k, sr = server.generate( + target_text="Hello.", + reference_audio_path="/path/to/voice.wav", # optional + reference_text="Reference transcript.", # optional, paired with ref audio + temperature=0.7, top_p=0.95, + ) +""" + +from __future__ import annotations + +import json +import wave +from dataclasses import dataclass +from pathlib import Path +from typing import Iterator, Optional, Tuple + +import mlx.core as mx +import numpy as np + +from ....codec.models.higgs_audio.higgs_audio import HiggsAudioTokenizer +from .config import HiggsAudioConfig +from .generation import build_delay_pattern_mask, lookup_audio_embedding +from .higgs_audio import HiggsAudioModel + + +@dataclass +class HiggsAudioGenerationResult: + pcm: np.ndarray # float32 in [-1, 1], mono + sampling_rate: int # 24000 for Higgs v2 + num_frames_raw: int + num_frames_aligned: int + stop_reason: str + + +def _load_wav_as_24k_mono(path: str) -> np.ndarray: + """Minimal wav loader → float32 mono at 24 kHz. Pure-python fallback. + + Higgs codec expects 24 kHz mono input. We resample via scipy if available + and fall back to a noisy message if the source isn't already 24 kHz. + """ + with wave.open(path, "rb") as r: + sr = r.getframerate() + nchan = r.getnchannels() + raw = r.readframes(r.getnframes()) + samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0 + if nchan == 2: + samples = samples.reshape(-1, 2).mean(axis=1) + if sr != 24000: + try: + from scipy.signal import resample_poly + except ImportError as e: + raise RuntimeError( + f"reference audio is {sr} Hz; install scipy to resample, or pre-convert to 24 kHz mono wav" + ) from e + samples = resample_poly(samples, up=24000, down=sr).astype(np.float32) + return samples + + +def build_prompt( + target_text: str, + *, + ref_text: Optional[str], + ref_audio_24k, # np.ndarray (float32, mono, 24kHz) or None + config: HiggsAudioConfig, + tokenizer, + codec, + embed_tokens, # nn.Embedding, text vocab + audio_codebook_embeddings, # nn.Embedding, audio codebook table +) -> Tuple[mx.array, mx.array, dict]: + """Build (inputs_embeds [1,T,H], audio_out_mask [1,T], info) for a Higgs + generation step. + + Voice-clone mode (ref_audio_24k provided): ChatML layout + user: ref_text / assistant: / user: target / assistant: <|audio_out_bos|> + Smart-voice mode (no ref_audio_24k): target only, assistant kicks off at <|audio_out_bos|>. + """ + K = config.audio_num_codebooks + stride = config.audio_codebook_size + 2 + + def _encode(s: str) -> list[int]: + return tokenizer.encode(s, add_special_tokens=False) + + if ref_audio_24k is not None: + ref_codes = codec.encode(mx.array(ref_audio_24k).reshape(1, -1, 1)) + mx.eval(ref_codes) + ref_codes = ref_codes[0].T.astype(mx.int32) # [K, T_ref] + T_ref = ref_codes.shape[1] + + bos_col = mx.full((K, 1), config.audio_stream_bos_id, dtype=mx.int32) + eos_col = mx.full((K, 1), config.audio_stream_eos_id, dtype=mx.int32) + ref_wrapped = mx.concatenate([bos_col, ref_codes, eos_col], axis=1) + ref_delayed = build_delay_pattern_mask( + ref_wrapped, + bos_token_id=config.audio_stream_bos_id, + pad_token_id=config.audio_stream_eos_id, + ) + T_ref_d = ref_delayed.shape[1] + + prefix = ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + f"{ref_text or ''}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + "<|audio_out_bos|>" + ) + middle = ( + "<|audio_eos|><|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n" + f"{target_text}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + "<|audio_out_bos|>" + ) + prefix_ids = _encode(prefix) + middle_ids = _encode(middle) + + prefix_emb = embed_tokens(mx.array([prefix_ids], dtype=mx.int32))[0] + middle_emb = embed_tokens(mx.array([middle_ids], dtype=mx.int32))[0] + audio_emb = lookup_audio_embedding( + audio_codebook_embeddings, ref_delayed, stride + ) + + full_embeds = mx.concatenate([prefix_emb, audio_emb, middle_emb], axis=0)[None] + audio_out_mask = mx.concatenate( + [ + mx.zeros((len(prefix_ids),), dtype=mx.bool_), + mx.ones((T_ref_d,), dtype=mx.bool_), + mx.zeros((len(middle_ids),), dtype=mx.bool_), + ], + axis=0, + )[None] + info = { + "mode": "voice_clone", + "T_ref": T_ref, + "T_ref_delayed": T_ref_d, + "text_len": len(prefix_ids) + len(middle_ids), + } + else: + # Smart-voice mode: no reference audio context. Less reliable. + prompt = ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + f"{target_text}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + "<|audio_out_bos|>" + ) + prompt_ids = _encode(prompt) + full_embeds = embed_tokens(mx.array([prompt_ids], dtype=mx.int32)) + audio_out_mask = mx.zeros((1, len(prompt_ids)), dtype=mx.bool_) + info = {"mode": "smart_voice", "text_len": len(prompt_ids)} + + return full_embeds, audio_out_mask, info + + +class HiggsAudioServer: + """Serve engine — composes model + codec + tokenizer for end-to-end generation.""" + + def __init__( + self, + model: HiggsAudioModel, + codec: HiggsAudioTokenizer, + tokenizer, # HF AutoTokenizer, kept loose to avoid transformers import here + config: HiggsAudioConfig, + ): + self.model = model + self.codec = codec + self.tokenizer = tokenizer + self.config = config + + @classmethod + def from_pretrained( + cls, + model_path: str, + codec_path: str, + tokenizer_path: Optional[str] = None, + ) -> "HiggsAudioServer": + import mlx.nn as nn + from transformers import AutoTokenizer + + # Accept either a local directory or a Hugging Face repo id. Repo ids + # look like "owner/name" and do not resolve to an existing local path. + def _resolve(path: str) -> Path: + p = Path(path) + if p.exists(): + return p + if "/" in path and not path.startswith(("/", "./", "../", "~")): + from huggingface_hub import snapshot_download + + return Path(snapshot_download(repo_id=path)) + return p # will error later with a clear "not found" + + model_dir = _resolve(model_path) + codec_dir = _resolve(codec_path) + tokenizer_dir = _resolve(tokenizer_path) if tokenizer_path else model_dir + raw_config = json.loads((model_dir / "config.json").read_text()) + config = HiggsAudioConfig.from_dict(raw_config) + + model = HiggsAudioModel(config) + + # If config declares a quantization block, quantize the model skeleton + # BEFORE loading weights so the QuantizedLinear parameter shapes match + # the saved checkpoint. The skip list mirrors what the checkpoint was + # built with — protected layers stay at bf16. + q = raw_config.get("quantization") + if q is not None: + skip = set(q.get("class_predicate_skip", [])) + + def _predicate(name: str, module: nn.Module) -> bool: + if not isinstance(module, (nn.Linear, nn.Embedding)): + return False + return not any(s in name for s in skip) + + nn.quantize( + model, + group_size=q["group_size"], + bits=q["bits"], + class_predicate=_predicate, + ) + + # Accept either a sharded bosonai-style layout (index + shards) or a + # single-file weights.safetensors / model.safetensors dump. + index_path = model_dir / "model.safetensors.index.json" + if index_path.exists(): + idx = json.loads(index_path.read_text()) + shards = sorted(set(idx["weight_map"].values())) + weights: dict = {} + for shard in shards: + weights.update(mx.load(str(model_dir / shard))) + else: + single = None + for candidate in ("model.safetensors", "weights.safetensors"): + if (model_dir / candidate).exists(): + single = model_dir / candidate + break + if single is None: + raise FileNotFoundError( + f"No weights found in {model_dir}: expected " + "model.safetensors.index.json, model.safetensors, " + "or weights.safetensors" + ) + weights = mx.load(str(single)) + + model.load_weights(list(weights.items()), strict=True) + mx.eval(model.parameters()) + + codec = HiggsAudioTokenizer.from_pretrained(str(codec_dir)) + tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir)) + + return cls(model=model, codec=codec, tokenizer=tokenizer, config=config) + + # ------------------------------------------------------------------ + # Prompt assembly + # ------------------------------------------------------------------ + + def _build_prompt( + self, + target_text: str, + reference_text: Optional[str], + reference_audio_path: Optional[str], + ) -> Tuple[mx.array, mx.array, dict]: + """Backward-compatible thin wrapper around build_prompt().""" + ref_audio_24k = ( + _load_wav_as_24k_mono(reference_audio_path) + if reference_audio_path is not None + else None + ) + return build_prompt( + target_text, + ref_text=reference_text, + ref_audio_24k=ref_audio_24k, + config=self.config, + tokenizer=self.tokenizer, + codec=self.codec, + embed_tokens=self.model.embed_tokens, + audio_codebook_embeddings=self.model.audio_codebook_embeddings, + ) + + # ------------------------------------------------------------------ + # Generation + # ------------------------------------------------------------------ + + def generate( + self, + target_text: str, + *, + reference_audio_path: Optional[str] = None, + reference_text: Optional[str] = None, + max_new_frames: int = 900, + temperature: float = 0.7, + top_p: Optional[float] = 0.95, + top_k: Optional[int] = None, + ras_win_len: Optional[int] = 7, + ras_max_repeat: int = 2, + sampling_warmup_frames: int = 0, + fade_in_ms: float = 5.0, + fade_out_ms: float = 5.0, + ) -> HiggsAudioGenerationResult: + """Synthesize target_text through the MLX port, optionally voice-cloned. + + Returns PCM at 24 kHz mono as float32 in [-1, 1]. + """ + full_embeds, audio_out_mask, _info = self._build_prompt( + target_text, reference_text, reference_audio_path + ) + + aligned, gen_info = self.model.generate( + inputs_embeds=full_embeds, + audio_out_mask=audio_out_mask, + max_new_frames=max_new_frames, + temperature=temperature, + top_p=top_p, + top_k=top_k, + ras_win_len=ras_win_len, + ras_max_repeat=ras_max_repeat, + sampling_warmup_frames=sampling_warmup_frames, + ) + mx.eval(aligned) + + # Codec decode: expects [T, K] layout (per codec_roundtrip reference path). + pcm = self.codec.decode(aligned.T) + mx.eval(pcm) + pcm_np = np.array(pcm).astype(np.float32).reshape(-1) + + # Short linear fade-in/out masks sample-zero/sample-last rounding clicks + # from quantized variants. At 24 kHz a 5 ms fade is 120 samples — below + # the onset perception threshold, inaudible on clean material. + sr = 24000 + n_in = int(fade_in_ms * sr / 1000.0) + n_out = int(fade_out_ms * sr / 1000.0) + if n_in > 0 and pcm_np.size > n_in: + pcm_np[:n_in] *= np.linspace(0.0, 1.0, n_in, dtype=np.float32) + if n_out > 0 and pcm_np.size > n_out: + pcm_np[-n_out:] *= np.linspace(1.0, 0.0, n_out, dtype=np.float32) + + return HiggsAudioGenerationResult( + pcm=pcm_np, + sampling_rate=24000, + num_frames_raw=gen_info["num_frames_raw"], + num_frames_aligned=gen_info["num_frames_aligned"], + stop_reason=gen_info["stop_reason"], + ) + + # ------------------------------------------------------------------ + # Streaming + # ------------------------------------------------------------------ + + def generate_stream( + self, + target_text: str, + *, + reference_audio_path: Optional[str] = None, + reference_text: Optional[str] = None, + max_new_frames: int = 900, + temperature: float = 0.7, + top_p: Optional[float] = 0.95, + top_k: Optional[int] = None, + ras_win_len: Optional[int] = 7, + ras_max_repeat: int = 2, + sampling_warmup_frames: int = 0, + chunk_ms: float = 640.0, + fade_in_ms: float = 5.0, + fade_out_ms: float = 5.0, + ) -> Iterator[np.ndarray]: + """Yield PCM chunks (float32, 24 kHz mono) suitable for Pipecat. + + Current shape: full-generate then chunk. TTFB = full generation wall + time; per-chunk quality is identical to non-streaming generate(). + + Mid-generation emission was tried (re-decode every N frames on the + accumulated sequence) but produced audible discontinuities at chunk + boundaries — the codec is a neural vocoder whose output at a given + sample depends on full-sequence context, so re-decode at chunk N + vs N+1 produces slightly different PCM for the same positions. + Proper mid-generation streaming requires overlap-add or an + incremental-decoder API on the codec; deferred to follow-up. + """ + result = self.generate( + target_text, + reference_audio_path=reference_audio_path, + reference_text=reference_text, + max_new_frames=max_new_frames, + temperature=temperature, + top_p=top_p, + top_k=top_k, + ras_win_len=ras_win_len, + ras_max_repeat=ras_max_repeat, + sampling_warmup_frames=sampling_warmup_frames, + fade_in_ms=fade_in_ms, + fade_out_ms=fade_out_ms, + ) + samples_per_chunk = int(chunk_ms * result.sampling_rate / 1000.0) + pcm = result.pcm + for i in range(0, pcm.size, samples_per_chunk): + yield pcm[i : i + samples_per_chunk] diff --git a/mlx_audio/tts/tests/test_higgs_audio.py b/mlx_audio/tts/tests/test_higgs_audio.py new file mode 100644 index 000000000..ae8c6d04e --- /dev/null +++ b/mlx_audio/tts/tests/test_higgs_audio.py @@ -0,0 +1,300 @@ +# Copyright (c) 2025, Prince Canuma and contributors (https://github.com/Blaizzy/mlx-audio) + +import unittest + +import mlx.core as mx +import numpy as np + +from mlx_audio.tts.models.higgs_audio.config import HiggsAudioConfig +from mlx_audio.tts.models.higgs_audio.generation import ( + apply_delay_pattern, + build_delay_pattern_mask, + greedy_sample_audio, + lookup_audio_embedding, + revert_delay_pattern, + sample_audio, +) +from mlx_audio.tts.models.higgs_audio.higgs_audio import HiggsAudioModel + + +def _tiny_config(): + """Small config just enough to shape-check forward — not real weights.""" + from mlx_audio.tts.models.higgs_audio.config import HiggsTextConfig + + return HiggsAudioConfig( + text_config=HiggsTextConfig( + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=128, + vocab_size=256, + rope_theta=10000.0, + rms_norm_eps=1e-5, + tie_word_embeddings=True, + rope_scaling=None, + ), + audio_num_codebooks=4, + audio_codebook_size=16, + audio_stream_bos_id=16, + audio_stream_eos_id=17, + audio_dual_ffn_layers=[0, 1], + use_audio_out_self_attention=False, + audio_decoder_proj_num_layers=0, + use_delay_pattern=True, + ) + + +class TestDelayPattern(unittest.TestCase): + """Delay-pattern round-trip and structural invariants.""" + + def test_build_delay_pattern_mask_shape(self): + K, L = 4, 5 + x = mx.arange(K * L, dtype=mx.int32).reshape(K, L) + out = build_delay_pattern_mask(x, bos_token_id=99, pad_token_id=88) + self.assertEqual(out.shape, (K, L + K - 1)) + + def test_build_delay_pattern_mask_triangles(self): + """Lower triangle is BOS, upper triangle is EOS, middle is shifted content.""" + K, L = 3, 4 + x = mx.arange(1, K * L + 1, dtype=mx.int32).reshape(K, L) # nonzero content + out = build_delay_pattern_mask(x, bos_token_id=-1, pad_token_id=-2) + out_np = np.array(out) + + # Row 0: full content + EOS tail + np.testing.assert_array_equal(out_np[0], [1, 2, 3, 4, -2, -2]) + # Row 1: one BOS + content + one EOS + np.testing.assert_array_equal(out_np[1], [-1, 5, 6, 7, 8, -2]) + # Row 2: two BOS + content + np.testing.assert_array_equal(out_np[2], [-1, -1, 9, 10, 11, 12]) + + def test_revert_reverses_apply(self): + """revert_delay_pattern should be the inverse of apply_delay_pattern on the content band.""" + K, L = 4, 6 + content = mx.arange(1, K * L + 1, dtype=mx.int32).reshape(K, L) + delayed = apply_delay_pattern(content, bos_id=0) # [K, L+K-1] + reverted = revert_delay_pattern(delayed) + # After revert, shape is [K, L+K-1-K+1] = [K, L] + self.assertEqual(reverted.shape, (K, L)) + np.testing.assert_array_equal(np.array(reverted), np.array(content)) + + +class TestAudioEmbedding(unittest.TestCase): + """lookup_audio_embedding correctly offsets per codebook and sums.""" + + def test_lookup_shape_and_sum(self): + import mlx.nn as nn + + K = 4 + C_plus2 = 10 + hidden = 8 + T = 3 + + emb = nn.Embedding(K * C_plus2, hidden) + ids = mx.zeros((K, T), dtype=mx.int32) # all codebooks emit token 0 + out = lookup_audio_embedding(emb, ids, C_plus2) + self.assertEqual(out.shape, (T, hidden)) + + # Per-codebook row 0 uses embedding[0], row 1 uses embedding[C_plus2], etc. + # Sum across K codebooks at one timestep equals sum of those K embeddings. + expected_sum = mx.sum( + emb(mx.arange(K, dtype=mx.int32) * C_plus2), axis=0 + ) # [hidden] + # out[0] should equal expected_sum + np.testing.assert_allclose(np.array(out[0]), np.array(expected_sum), rtol=1e-5) + + +class TestSampling(unittest.TestCase): + """Sampling utilities produce correct shapes and respect temperature=0 = greedy.""" + + def test_greedy_picks_argmax(self): + # B=1, T=1, K=2, V=5 — craft logits with known argmax per codebook + logits = mx.array( + [[[[0.0, 1.0, 2.0, 0.5, 0.0], [3.0, 0.5, 0.2, 0.1, 0.0]]]], dtype=mx.float32 + ) + out = greedy_sample_audio(logits) + self.assertEqual(out.shape, (1, 1, 2)) + np.testing.assert_array_equal(np.array(out)[0, 0], [2, 0]) + + def test_sample_audio_zero_temp_equals_greedy(self): + logits = mx.random.normal(shape=(1, 1, 3, 8)) + g = greedy_sample_audio(logits) + s = sample_audio(logits, temperature=0.0) + np.testing.assert_array_equal(np.array(g), np.array(s)) + + def test_sample_audio_temperature_returns_valid_token(self): + mx.random.seed(42) + logits = mx.random.normal(shape=(1, 1, 3, 8)) + out = sample_audio(logits, temperature=0.7, top_p=0.95) + self.assertEqual(out.shape, (1, 1, 3)) + out_np = np.array(out) + self.assertTrue((out_np >= 0).all() and (out_np < 8).all()) + + +class TestHiggsAudioModel(unittest.TestCase): + """Model instantiates and forwards on a tiny synthetic config.""" + + def _tiny_config(self): + return _tiny_config() + + def test_forward_shapes(self): + cfg = self._tiny_config() + model = HiggsAudioModel(cfg) + mx.eval(model.parameters()) + + B, T = 1, 5 + input_ids = mx.zeros((B, T), dtype=mx.int32) + audio_out_mask = mx.zeros((B, T), dtype=mx.bool_) + text_logits, audio_logits = model( + input_ids=input_ids, audio_out_mask=audio_out_mask + ) + self.assertEqual(text_logits.shape, (B, T, cfg.text_config.vocab_size)) + self.assertEqual( + audio_logits.shape, + (B, T, cfg.audio_num_codebooks, cfg.audio_codebook_size + 2), + ) + + def test_forward_text_only_no_audio_logits(self): + cfg = self._tiny_config() + model = HiggsAudioModel(cfg) + mx.eval(model.parameters()) + + B, T = 1, 4 + input_ids = mx.zeros((B, T), dtype=mx.int32) + text_logits, audio_logits = model(input_ids=input_ids, audio_out_mask=None) + self.assertEqual(text_logits.shape, (B, T, cfg.text_config.vocab_size)) + self.assertIsNone(audio_logits) + + +class TestQuantizedLoad(unittest.TestCase): + """Verify that a config.json quantization block triggers nn.quantize on the + model skeleton before weight-loading in HiggsAudioServer.from_pretrained.""" + + def test_quantize_predicate_skips_protected_layers(self): + """The skip-list applied by from_pretrained must leave protected + layers (audio_codebook_embeddings, audio_decoder_proj.audio_lm_head) + as plain nn.Linear / nn.Embedding — only the Llama backbone quantizes.""" + import mlx.nn as nn + + from mlx_audio.tts.models.higgs_audio.config import ( + HiggsAudioConfig, + HiggsTextConfig, + ) + from mlx_audio.tts.models.higgs_audio.higgs_audio import HiggsAudioModel + + cfg = HiggsAudioConfig( + text_config=HiggsTextConfig( + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=128, + vocab_size=256, + rope_theta=10000.0, + rms_norm_eps=1e-5, + tie_word_embeddings=True, + rope_scaling=None, + ), + audio_num_codebooks=4, + audio_codebook_size=16, + audio_stream_bos_id=16, + audio_stream_eos_id=17, + audio_dual_ffn_layers=[0, 1], + use_audio_out_self_attention=False, + audio_decoder_proj_num_layers=0, + use_delay_pattern=True, + ) + model = HiggsAudioModel(cfg) + mx.eval(model.parameters()) + + skip = {"audio_codebook_embeddings", "audio_decoder_proj.audio_lm_head"} + + def predicate(name, module): + if not isinstance(module, (nn.Linear, nn.Embedding)): + return False + return not any(s in name for s in skip) + + nn.quantize(model, group_size=64, bits=8, class_predicate=predicate) + + # Protected: audio_codebook_embeddings remains a plain Embedding. + self.assertIsInstance(model.audio_codebook_embeddings, nn.Embedding) + self.assertNotIsInstance(model.audio_codebook_embeddings, nn.QuantizedEmbedding) + + # Protected: audio_lm_head remains a plain Linear. + self.assertIsInstance(model.audio_decoder_proj.audio_lm_head, nn.Linear) + self.assertNotIsInstance( + model.audio_decoder_proj.audio_lm_head, nn.QuantizedLinear + ) + + # Quantized: text_lm_head becomes QuantizedLinear. + self.assertIsInstance(model.audio_decoder_proj.text_lm_head, nn.QuantizedLinear) + + +class TestFrameworkInterface(unittest.TestCase): + """Model + ModelConfig conform to the mlx_audio.tts.utils.load convention.""" + + def test_modelconfig_aliases_higgs_audio_config(self): + from mlx_audio.tts.models.higgs_audio import ModelConfig + + self.assertIs(ModelConfig, HiggsAudioConfig) + + def test_modelconfig_from_dict(self): + from mlx_audio.tts.models.higgs_audio import ModelConfig + + cfg = ModelConfig.from_dict({}) # permissive defaults + self.assertIsInstance(cfg, HiggsAudioConfig) + self.assertEqual(cfg.audio_num_codebooks, 8) + + def test_model_subclasses_higgs_audio_model(self): + """Subclassing (not wrapping) keeps safetensors key paths unchanged.""" + from mlx_audio.tts.models.higgs_audio import Model + + self.assertTrue(issubclass(Model, HiggsAudioModel)) + + def test_model_exposes_sample_rate(self): + from mlx_audio.tts.models.higgs_audio import Model + + cfg = _tiny_config() + model = Model(cfg) + self.assertEqual(model.sample_rate, 24000) + + def test_model_quant_predicate_protects_audio_head(self): + """Model.model_quant_predicate must skip audio_codebook_embeddings and + audio_decoder_proj.audio_lm_head (voice-character preservation).""" + import mlx.nn as nn + + from mlx_audio.tts.models.higgs_audio import Model + + cfg = _tiny_config() + model = Model(cfg) + + dummy_lin = nn.Linear(4, 4) + dummy_emb = nn.Embedding(4, 4) + # Protected names → False (skip quantization) + self.assertFalse( + model.model_quant_predicate("audio_codebook_embeddings", dummy_emb) + ) + self.assertFalse( + model.model_quant_predicate("audio_decoder_proj.audio_lm_head", dummy_lin) + ) + # Unprotected names → True (quantize) + self.assertTrue( + model.model_quant_predicate("layers.0.mlp.gate_proj", dummy_lin) + ) + self.assertTrue( + model.model_quant_predicate("audio_decoder_proj.text_lm_head", dummy_lin) + ) + + def test_model_generate_requires_loaded_codec(self): + """Calling generate() before post_load_hook ran must raise a clear error.""" + from mlx_audio.tts.models.higgs_audio import Model + + cfg = _tiny_config() + model = Model(cfg) + # Realize the iterator so the guard actually runs. + with self.assertRaises(RuntimeError): + next(iter(model.generate("hello"))) + + +if __name__ == "__main__": + unittest.main() diff --git a/uv.lock b/uv.lock index bfc4e85ea..1f809979e 100644 --- a/uv.lock +++ b/uv.lock @@ -1334,6 +1334,7 @@ all = [ { name = "mistral-common", extra = ["audio"] }, { name = "num2words" }, { name = "phonemizer-fork" }, + { name = "pydub" }, { name = "sentencepiece" }, { name = "setuptools" }, { name = "spacy" }, @@ -1379,6 +1380,7 @@ tts = [ { name = "misaki" }, { name = "num2words" }, { name = "phonemizer-fork" }, + { name = "pydub" }, { name = "sentencepiece" }, { name = "spacy" }, { name = "tiktoken" }, @@ -1416,6 +1418,8 @@ requires-dist = [ { name = "phonemizer-fork", marker = "extra == 'tts'", specifier = ">=3.3.2" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0" }, { name = "protobuf", specifier = ">=6.33.5" }, + { name = "pydub", marker = "extra == 'all'", specifier = ">=0.25.1" }, + { name = "pydub", marker = "extra == 'tts'", specifier = ">=0.25.1" }, { name = "pyloudnorm", specifier = ">=0.2.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.0.0" }, @@ -2243,6 +2247,15 @@ pycountry = [ { name = "pycountry" }, ] +[[package]] +name = "pydub" +version = "0.25.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/9a/e6bca0eed82db26562c73b5076539a4a08d3cffd19c3cc5913a3e61145fd/pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f", size = 38326, upload-time = "2021-03-10T02:09:54.659Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/53/d78dc063216e62fc55f6b2eebb447f6a4b0a59f55c8406376f76bf959b08/pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6", size = 32327, upload-time = "2021-03-10T02:09:53.503Z" }, +] + [[package]] name = "pygments" version = "2.19.2"