Training support? #39
Comments
You can find details for data creation and training in the examples: https://github.com/edwko/OuteTTS/tree/main/examples/v1. This should give you an idea of how to get started. To add a new language, you'll need to verify that it processes text correctly in these two functions. Adding a completely new language might not be very straightforward and would require a significant amount of high-quality data to train effectively. |
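As a quick first check before digging into the repository's text-processing functions, one could verify that the base tokenizer at least round-trips text in the target language. This is a sketch added here for illustration, not code from the repository; the checkpoint name is one of the models mentioned in this thread, and the sample sentence is a placeholder:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("OuteAI/OuteTTS-0.2-500M")

sample = "Your sentence in the new language goes here."
ids = tokenizer(sample, add_special_tokens=False)["input_ids"]
decoded = tokenizer.decode(ids)

# If the decoded text differs from the input, or the sentence explodes into an
# unusually large number of tokens, the language likely needs extra text handling.
print("token count:", len(ids))
print("round-trip ok:", decoded == sample)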
I face this problem all the time:
I modify interface.py to add my language. |
You don't strictly need to use parquet; you can convert to anything you want. Also, "HF or torchtune" should probably be used more for fine-tuning. The NaN value can happen if there's an issue with how you load the model or your data. It could also be a hardware issue, hard to say. Also, in your code you passed
Try this as a test and see what loss you get (the loss from this example should be around 8). You can also pass one of your training samples, just shift the tgt by 1, and test it.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained(
"OuteAI/OuteTTS-0.2-500M",
torch_dtype=torch.bfloat16
).to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002, betas=(0.9, 0.95))
# Create dummy batch
batch_size = 1
seq_length = 100
src = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
tgt = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
print("Input tensor ->", src.size())
# Training step
optimizer.zero_grad()
logits = model(src).logits
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), tgt.view(-1))
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
print("Loss ->", round(loss.item(), 6)) |
I would be very grateful if you could provide a trainer for the new languages, I mean the full pipeline. I have 22 hours of good audio, transcriptions for it, and resources for training, but I'm stuck on preparing the data. |
While I can't provide you with the internal training library, for quick start on pre-training, I'd suggest you could use this trainer: https://github.com/karpathy/nanoGPT/blob/master/train.py |
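For anyone going that route: nanoGPT's train.py reads flat train.bin/val.bin arrays of token ids from a data directory, so the prepared prompts would first need to be tokenized and dumped into that format. The sketch below is rough glue code added for illustration, not part of OuteTTS or nanoGPT; the file paths and split ratio are arbitrary, and nanoGPT's loader assumes uint16 ids by default, which is likely too small for this vocabulary, so its memmap dtype would need to be changed to match:

import os
import glob
import numpy as np
import polars as pl
from transformers import AutoTokenizer

# Assumes prompts were saved to parquet files with a "prompt" column,
# as in the data preparation script later in this thread.
tokenizer = AutoTokenizer.from_pretrained("OuteAI/OuteTTS-0.3-1B")

ids = []
for file in sorted(glob.glob("prepared_data/*.parquet")):
    for prompt in pl.read_parquet(file)["prompt"].to_list():
        ids.extend(tokenizer(prompt, add_special_tokens=False)["input_ids"])

split = int(len(ids) * 0.9)
os.makedirs("data/outetts", exist_ok=True)
# uint32 to be safe with a large vocabulary; adjust nanoGPT's memmap dtype accordingly
np.array(ids[:split], dtype=np.uint32).tofile("data/outetts/train.bin")
np.array(ids[split:], dtype=np.uint32).tofile("data/outetts/val.bin")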
Has anybody replicated the training procedure? |
Prepare data from the dataset (dataset/wavs/*.wav and dataset/metadata.csv, where each line is file.wav|text):

import os
import csv
import polars as pl
import torchaudio
import torch
from tqdm import tqdm
from loguru import logger

# Import necessary OuteTTS modules
from outetts.wav_tokenizer.audio_codec import AudioCodec
from outetts.version.v2.alignment import CTCForcedAlignment
from outetts.version.v2.prompt_processor import PromptProcessor


class DataPreparation:
    def __init__(self, wavs_dir, metadata_file, save_dir, save_len=5000,
                 model_tokenizer_path="OuteAI/OuteTTS-0.3-1B"):
        self.wavs_dir = wavs_dir
        self.metadata_file = metadata_file
        self.save_dir = save_dir
        self.save_len = save_len
        self.data = []
        self.save_id = 0
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Initialize the audio codec (set load_decoder=False for data creation)
        self.audio_codec = AudioCodec(device=self.device, load_decoder=False)
        # The prompt processor will be used to generate full training prompts
        self.prompt_processor = PromptProcessor(model_tokenizer_path)
        # Set up CTC forced alignment
        self.ctc = CTCForcedAlignment(self.device)

    def create_speaker(self, wav_path, transcript):
        # Load the audio (if stereo, average channels)
        waveform, sr = torchaudio.load(wav_path)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        transcript_clean = transcript.strip()
        # The alignment method expects a file path (it will load the file itself)
        speaker_words = self.ctc.align(wav_path, transcript_clean)
        # Concatenate the audio segments for all words
        try:
            concatenated = torch.cat([w["audio"] for w in speaker_words], dim=1)
        except Exception as e:
            logger.error(f"Error concatenating audio for file {wav_path}: {e}")
            raise e
        # Convert and encode the concatenated audio
        converted_audio = self.audio_codec.convert_audio_tensor(
            audio=concatenated,
            sr=self.ctc.sample_rate
        ).to(self.audio_codec.device)
        full_codes = self.audio_codec.encode(converted_audio).tolist()
        data_words = []
        start = 0
        for word in speaker_words:
            # Map the word's end position (x1) to token indices (assume 75 tokens per second)
            end = int(round((word["x1"] / self.ctc.sample_rate) * 75))
            try:
                word_tokens = full_codes[0][0][start:end]
            except Exception:
                word_tokens = [1]
            start = end
            if not word_tokens:
                word_tokens = [1]
            data_words.append({
                "word": word["word"],
                "duration": round(len(word_tokens) / 75, 2),
                "codes": word_tokens
            })
        speaker_profile = {
            "text": transcript_clean,
            "words": data_words
        }
        return speaker_profile

    def process(self):
        # Read metadata CSV (each line: filename|text)
        with open(self.metadata_file, "r", encoding="utf-8") as f:
            reader = csv.reader(f, delimiter="|")
            rows = list(reader)
        total = len(rows)
        for row in tqdm(rows, total=total, desc="Processing audio files"):
            if len(row) < 2:
                continue
            filename, text = row[0].strip(), row[1].strip()
            wav_path = os.path.join(self.wavs_dir, filename)
            if not os.path.exists(wav_path):
                logger.error(f"WAV file {wav_path} does not exist.")
                continue
            try:
                speaker = self.create_speaker(wav_path, text)
                # Generate the full training prompt (text + encoded audio)
                prompt = self.prompt_processor.get_training_prompt(speaker)
                self.data.append({"prompt": prompt})
            except Exception as e:
                logger.error(f"Error processing {wav_path}: {e}")
            if len(self.data) >= self.save_len:
                self.save_data()
        if self.data:
            self.save_data()
        # Free the CTC model resources
        self.ctc.free()

    def save_data(self):
        os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir, f"{self.save_id:06d}.parquet")
        logger.info(f"Saving data to {save_path}")
        pl.DataFrame(self.data).write_parquet(save_path)
        self.data = []
        self.save_id += 1


if __name__ == "__main__":
    wavs_dir = "dataset/wavs"
    metadata_file = "dataset/metadata.csv"
    save_dir = "prepared_data"
    # For demonstration you may use a smaller save_len (adjust as needed)
    dp = DataPreparation(wavs_dir, metadata_file, save_dir, save_len=1000)
    dp.process()

Train:

import glob
import os
import torch
import polars as pl
from loguru import logger
import datasets
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)


def load_dataset_from_parquet(folder):
    prompts = []
    for file in glob.glob(os.path.join(folder, "*.parquet")):
        try:
            df = pl.read_parquet(file).to_pandas()
            prompts.extend(df["prompt"].tolist())
        except Exception as e:
            logger.error(f"Error reading {file}: {e}")
    return datasets.Dataset.from_dict({"text": prompts})


def main():
    # Specify the model name (use your fine-tuned model's path if available)
    model_name = "OuteAI/OuteTTS-0.3-1B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Load the dataset created in the previous step
    dataset = load_dataset_from_parquet("prepared_data")
    # Split the dataset: 90% training, 10% evaluation
    dataset = dataset.train_test_split(test_size=0.1)

    # Tokenize the text (prompts) with a maximum length (adjust as necessary)
    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, max_length=4096)

    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Data collator (no masking since this is causal LM training)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="fine_tuned_model",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_steps=10,
        learning_rate=5e-5,
        weight_decay=0.01,
        report_to="none",
        save_total_limit=2,
        fp16=torch.cuda.is_available()
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
    )
    trainer.train()

    # Save the fine-tuned model and tokenizer for later inference
    model.save_pretrained("fine_tuned_model")
    tokenizer.save_pretrained("fine_tuned_model")


if __name__ == "__main__":
    main()

Test:

import outetts
from outetts.models.config import GenerationConfig
import torch


def main():
    # Configure the model.
    # If you fine-tuned your model, use the "fine_tuned_model" folder; otherwise, use the original model name.
    model_config = outetts.HFModelConfig_v2(
        model_path="fine_tuned_model",      # or "OuteAI/OuteTTS-0.3-1B"
        tokenizer_path="fine_tuned_model",  # same value as above
        max_seq_length=4096
    )
    # Initialize the Hugging Face interface for OuteTTS v0.3
    interface = outetts.InterfaceHF(model_version="0.3", cfg=model_config)

    # Prepare a generation configuration.
    # Here no speaker profile is provided; for synthesis, you can also load or create a speaker.
    gen_cfg = GenerationConfig(
        text="Speech synthesis is the artificial production of human speech.",
        temperature=0.1,
        repetition_penalty=1.1,
        max_length=4096,
        speaker=None
    )

    # Generate speech and obtain the model's output.
    output = interface.generate(config=gen_cfg)

    # Save the output audio to a WAV file.
    output.save("test_output.wav")
    print("Audio generated and saved as test_output.wav")
    # Optionally, to play the audio you could call:
    # output.play(backend="pygame") or output.play(backend="sounddevice")


if __name__ == "__main__":
    main()

Something like that, I guess. @edwko, did I miss something? |
@paleicikas Nice! That should do the trick :) Also, this part seems to be unused:
|
Hello! How can I train this model for another language on my dataset?