121 changes: 68 additions & 53 deletions generator.py
@@ -13,97 +13,99 @@

@dataclass
class Segment:
"""Represents a segment of audio and its associated speaker and text."""

speaker: int
text: str
# (num_samples,), sample_rate = 24_000
audio: torch.Tensor


def load_llama3_tokenizer():
"""
https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
"""
def load_llama3_tokenizer() -> AutoTokenizer:
"""Load and configure the Llama-3 tokenizer with custom post-processing."""
tokenizer_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
except Exception as e:
raise RuntimeError(f"Failed to load tokenizer '{tokenizer_name}': {e}")
bos = tokenizer.bos_token
eos = tokenizer.eos_token
tokenizer._tokenizer.post_processor = TemplateProcessing(
single=f"{bos}:0 $A:0 {eos}:0",
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
special_tokens=[
(f"{bos}", tokenizer.bos_token_id),
(f"{eos}", tokenizer.eos_token_id),
],
)

return tokenizer


class Generator:
"""Main class for generating audio from text and context segments."""

def __init__(
self,
model: Model,
):
"""Initialize the generator with a model and required tokenizers."""
self._model = model
self._model.setup_caches(1)

self._text_tokenizer = load_llama3_tokenizer()

device = next(model.parameters()).device
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)
try:
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)
except Exception as e:
raise RuntimeError(f"Failed to load audio tokenizer: {e}")
mimi.set_num_codebooks(32)
self._audio_tokenizer = mimi

self._watermarker = load_watermarker(device=device)

self.sample_rate = mimi.sample_rate
self.device = device

def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
def _tokenize_text_segment(
self, text: str, speaker: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Tokenize a text segment for a given speaker."""
frame_tokens = []
frame_masks = []

text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
text_frame = torch.zeros(len(text_tokens), 33).long()
text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
text_frame[:, -1] = torch.tensor(text_tokens)
text_frame_mask[:, -1] = True

frame_tokens.append(text_frame.to(self.device))
frame_masks.append(text_frame_mask.to(self.device))

return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Tokenize an audio tensor."""
assert audio.ndim == 1, "Audio must be single channel"

frame_tokens = []
frame_masks = []

# (K, T)
audio = audio.to(self.device)
audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
# add EOS frame
eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
audio_frame_mask[:, :-1] = True

frame_tokens.append(audio_frame)
frame_masks.append(audio_frame_mask)

return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
(seq_len, 33), (seq_len, 33)
"""
text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
"""Tokenize a Segment object into text and audio tokens."""
text_tokens, text_masks = self._tokenize_text_segment(
segment.text, segment.speaker
)
audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat(
[text_masks, audio_masks], dim=0
)

@torch.inference_mode()
def generate(
@@ -115,62 +117,75 @@ def generate(
temperature: float = 0.9,
topk: int = 50,
) -> torch.Tensor:
"""Generate audio for the given text, speaker, and context."""
self._model.reset_caches()

max_generation_len = int(max_audio_length_ms / 80)
tokens, tokens_mask = [], []
for segment in context:
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
tokens.append(segment_tokens)
tokens_mask.append(segment_tokens_mask)

gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
text, speaker
)
tokens.append(gen_segment_tokens)
tokens_mask.append(gen_segment_tokens_mask)

prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

samples = []
curr_tokens = prompt_tokens.unsqueeze(0)
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

curr_pos = (
torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
)
max_seq_len = 2048
max_context_len = max_seq_len - max_generation_len
if curr_tokens.size(1) >= max_context_len:
raise ValueError(
f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
)

for _ in range(max_generation_len):
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
sample = self._model.generate_frame(
curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
)
if torch.all(sample == 0):
break # eos

samples.append(sample)

curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
curr_tokens = torch.cat(
[sample, torch.zeros(1, 1).long().to(self.device)], dim=1
).unsqueeze(1)
curr_tokens_mask = torch.cat(
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
[
torch.ones_like(sample).bool(),
torch.zeros(1, 1).bool().to(self.device),
],
dim=1,
).unsqueeze(1)
curr_pos = curr_pos[:, -1:] + 1

audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)

audio = (
self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0))
.squeeze(0)
.squeeze(0)
)
# This applies an imperceptible watermark to identify audio as AI-generated.
# Watermarking ensures transparency, dissuades misuse, and enables traceability.
# Please be a responsible AI citizen and keep the watermarking in place.
# If using CSM 1B in another application, use your own private key and keep it secret.
audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)

audio, wm_sample_rate = watermark(
self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK
)
audio = torchaudio.functional.resample(
audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate
)
return audio


def load_csm_1b(device: str = "cuda") -> Generator:
model = Model.from_pretrained("sesame/csm-1b")
model.to(device=device, dtype=torch.bfloat16)

"""Load the CSM 1B model and return a Generator instance."""
try:
model = Model.from_pretrained("sesame/csm-1b")
model.to(device=device, dtype=torch.bfloat16)
except Exception as e:
raise RuntimeError(f"Failed to load CSM 1B model: {e}")
generator = Generator(model)
return generator
return generator
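
For reviewers trying this branch locally, here is a minimal usage sketch of the API as it stands after this diff. It relies only on names visible above (`load_csm_1b`, `Segment`, `Generator.generate`); the `reference.wav` path, utterance texts, and speaker ids are placeholders, and the context list may be empty.

```python
import torchaudio
from generator import Segment, load_csm_1b

# Load the model onto a GPU; raises RuntimeError if the checkpoint fails to load.
generator = load_csm_1b(device="cuda")

# Optional context: a prior utterance plus its mono audio, resampled to the
# generator's rate (Mimi runs at 24 kHz). "reference.wav" is a placeholder path.
ref_audio, sr = torchaudio.load("reference.wav")
ref_audio = torchaudio.functional.resample(
    ref_audio.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
)
context = [Segment(speaker=0, text="Hello there.", audio=ref_audio)]

# Generate up to ten seconds of watermarked speech for speaker 1.
audio = generator.generate(
    text="Hi, how are you today?",
    speaker=1,
    context=context,
    max_audio_length_ms=10_000,
)
torchaudio.save("out.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```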
61 changes: 46 additions & 15 deletions models.py
@@ -69,7 +69,9 @@ def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
return r


def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization
def _multinomial_sample_one_no_sync(
probs,
): # Does multinomial sampling without a cuda synchronization
q = torch.empty_like(probs).exponential_(1)
return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
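
A side note on the sampler above: drawing `q ~ Exponential(1)` elementwise and taking `argmax(probs / q)` is the exponential-race (equivalently, Gumbel-max) construction of a categorical sample, which avoids the host-device synchronization that `torch.multinomial` triggers on CUDA. A self-contained sanity check, assuming only what the function above shows:

```python
import torch

def multinomial_no_sync(probs: torch.Tensor) -> torch.Tensor:
    # argmax(p_i / q_i) == argmin(q_i / p_i), and q_i / p_i ~ Exponential(p_i),
    # so index i wins the race with probability p_i / sum(p).
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)

probs = torch.tensor([0.1, 0.2, 0.7])
draws = torch.stack([multinomial_no_sync(probs) for _ in range(20_000)])
freqs = torch.bincount(draws.flatten().long(), minlength=3).float() / 20_000
print(freqs)  # approximately tensor([0.1000, 0.2000, 0.7000])
```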

@@ -107,15 +109,27 @@ def __init__(self, config: ModelArgs):
super().__init__()
self.config = config

self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
self.backbone, backbone_dim = _prepare_transformer(
FLAVORS[config.backbone_flavor]()
)
self.decoder, decoder_dim = _prepare_transformer(
FLAVORS[config.decoder_flavor]()
)

self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
self.audio_embeddings = nn.Embedding(
config.audio_vocab_size * config.audio_num_codebooks, backbone_dim
)

self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
self.codebook0_head = nn.Linear(
backbone_dim, config.audio_vocab_size, bias=False
)
self.audio_head = nn.Parameter(
torch.empty(
config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size
)
)

def setup_caches(self, max_batch_size: int) -> torch.Tensor:
"""Setup KV caches and return a causal mask."""
@@ -124,10 +138,20 @@ def setup_caches(self, max_batch_size: int) -> torch.Tensor:

with device:
self.backbone.setup_caches(max_batch_size, dtype)
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
self.decoder.setup_caches(
max_batch_size,
dtype,
decoder_max_seq_len=self.config.audio_num_codebooks,
)

self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
self.register_buffer(
"backbone_causal_mask",
_create_causal_mask(self.backbone.max_seq_len, device),
)
self.register_buffer(
"decoder_causal_mask",
_create_causal_mask(self.config.audio_num_codebooks, device),
)

def generate_frame(
self,
@@ -155,7 +179,9 @@ def generate_frame(
embeds = self._embed_tokens(tokens)
masked_embeds = embeds * tokens_mask.unsqueeze(-1)
h = masked_embeds.sum(dim=2)
h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(
dtype=dtype
)

last_h = h[:, -1, :]
c0_logits = self.codebook0_head(last_h)
@@ -164,15 +190,19 @@

curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
curr_sample = c0_sample.clone()
curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
curr_pos = (
torch.arange(0, curr_h.size(1), device=curr_h.device)
.unsqueeze(0)
.repeat(curr_h.size(0), 1)
)

# Decoder caches must be reset every frame.
self.decoder.reset_caches()
for i in range(1, self.config.audio_num_codebooks):
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
dtype=dtype
)
decoder_h = self.decoder(
self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask
).to(dtype=dtype)
ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
ci_sample = sample_topk(ci_logits, topk, temperature)
ci_embed = self._embed_audio(i, ci_sample)
Expand All @@ -194,7 +224,8 @@ def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)

audio_tokens = tokens[:, :, :-1] + (
self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
self.config.audio_vocab_size
* torch.arange(self.config.audio_num_codebooks, device=tokens.device)
)
audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
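
One detail the reformatting in `_embed_tokens` makes easier to miss: all codebooks share a single `nn.Embedding` with `audio_vocab_size * audio_num_codebooks` rows, and codebook `k`'s tokens are offset by `k * audio_vocab_size` so each codebook indexes a disjoint block of rows. A toy-sized sketch of that indexing (dimensions are illustrative, not the model's real config):

```python
import torch
import torch.nn as nn

vocab, codebooks, dim = 5, 3, 8  # toy sizes, not the real config
table = nn.Embedding(vocab * codebooks, dim)

# tokens: (batch, seq, codebooks), every entry in [0, vocab)
tokens = torch.randint(0, vocab, (2, 4, codebooks))

# Codebook k reads rows [k * vocab, (k + 1) * vocab) of the shared table.
offsets = vocab * torch.arange(codebooks)
embeds = table((tokens + offsets).view(-1)).reshape(2, 4, codebooks, dim)
print(embeds.shape)  # torch.Size([2, 4, 3, 8])
```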