diff --git a/.gitignore b/.gitignore
index a83935fa..35a30644 100644
--- a/.gitignore
+++ b/.gitignore
@@ -51,4 +51,8 @@ pip-delete-this-directory.txt
 _site/
 .sass-cache/
 .jekyll-cache/
-.jekyll-metadata
\ No newline at end of file
+.jekyll-metadata
+# FT:
+data/
+checkpoints/
+*.onnx
\ No newline at end of file
diff --git a/convert-to-jit.py b/convert-to-jit.py
new file mode 100644
index 00000000..e93f4495
--- /dev/null
+++ b/convert-to-jit.py
@@ -0,0 +1,50 @@
+import os
+import argparse
+import torch
+from models.forward_tacotron import ForwardTacotron
+
+
+"""
+TorchScript exporter for ⏩ ForwardTacotron
+"""
+
+
+# Declaring the converter:
+def run_convertor(model_path, save_path):
+    if not os.path.exists(model_path):
+        raise FileNotFoundError("Please give me an existing model!")
+    tts_model = ForwardTacotron.from_checkpoint(model_path)
+    tts_model.eval()
+    # Script the TTS model defined in models/forward_tacotron.py:
+    model_script = torch.jit.script(tts_model)
+    # Generate a dummy input for testing:
+    x = torch.ones((1, 5)).long()
+    # Run a test generation to make sure the scripted model works:
+    y = model_script.generate_jit(x)
+    if save_path is None:
+        save_path = model_path[:-3] + ".ts"
+    # Finally, we export it:
+    torch.jit.save(
+        model_script,
+        save_path
+    )
+    print("Model successfully converted to TorchScript.")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="TorchScript converter for ⏩ForwardTacotron")
+    parser.add_argument(
+        '--checkpoint_path',
+        '-c',
+        required=True,
+        type=str,
+        help='The full checkpoint (*.pt) file to convert.'
+    )
+    parser.add_argument(
+        '--output_path',
+        '-o',
+        default=None,
+        type=str,
+        help='Output path to save the converted TorchScript model.'
+    )
+    args = parser.parse_args()
+    run_convertor(args.checkpoint_path, args.output_path)
\ No newline at end of file
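For orientation, a minimal sketch of how the exported TorchScript module might be loaded and used (the checkpoint path is hypothetical; generate_jit is the same entry point the smoke test above calls):

import torch

model = torch.jit.load("checkpoints/my_tts.forward/latest_model.ts")  # hypothetical path
x = torch.ones((1, 5)).long()  # dummy phoneme IDs, same shape as the smoke test
out = model.generate_jit(x)    # runs the exported generation path on the dummy input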
+""" + +# ======================global vars====================== +OPSET = 17 +SEED = 1234 +# ======================end global vars====================== + +# Declaring the convertor: +def run_convertor(model_path, save_path): + if not os.path.exists(model_path): + raise FileNotFoundError("Please give me an existing model!") + tts_model = ForwardTacotron.from_checkpoint(model_path) + tts_model.eval() + # Configure seed: + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + # We create the custom generate() function to return the mel post only and acomodate synthesizer options into a list array: + def custom_generate(text, synth_options): + alpha = synth_options[0] + pitch = synth_options[1] + energy = synth_options[2] + # Todo: try inferencing this pitch/energy with an ONNX model: + pitch_function = lambda x: x * pitch + energy_function = lambda x: x * energy + infer = tts_model.generate( + text, + alpha=alpha, + pitch_function=pitch_function, + energy_function=energy_function, + onnx=True + ) + mel = infer['mel_post'] + return mel + # We replace the forward function to the created one: + tts_model.forward = custom_generate + # We set the inputs and outputs for the ONNX model: + dummy_input_length = 50 + rand = torch.randint(low=0, high=len(phonemes), size=(1, dummy_input_length), dtype=torch.long) + synth_inputs = torch.FloatTensor( + [1.0, 1.0, 1.0] # Alpha, pitch, energy + ) + model_inputs = (rand, synth_inputs) + input_names = [ + "input", + "synth_options" + ] + if save_path is None: + save_path = model_path[:-3]+".onnx" + # Finally, we export it: + torch.onnx.export( + model = tts_model, + args = model_inputs, + f = save_path, + opset_version=OPSET, + input_names=input_names, + output_names=['output'], + dynamic_axes = { + "input": {0: "batch_size", 1: "text"}, + "output": {0: "batch_size", 1: "time"} + } + ) + print("Checkpoint successfully converted to ONNX.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Onnx conversor for ⏩ForwardTacotron") + parser.add_argument( + '--checkpoint_path', + '-c', + required=True, + type=str, + help='The full checkpoint (*.pt) file to convert.' + ) + parser.add_argument( + '--output_path', + '-o', + default=None, + type=str, + help='Output path to save the converted ONNX model.' 
    args = parser.parse_args()
    run_convertor(args.checkpoint_path, args.output_path)
\ No newline at end of file
diff --git a/gen_forward_onnx.py b/gen_forward_onnx.py
new file mode 100644
index 00000000..5ded2619
--- /dev/null
+++ b/gen_forward_onnx.py
@@ -0,0 +1,91 @@
+import argparse
+from pathlib import Path
+import numpy as np
+import onnxruntime
+import torch
+from utils.display import simple_table
+from utils.dsp import DSP
+from utils.files import read_config
+from utils.paths import Paths
+from utils.text.cleaners import Cleaner
+from utils.text.tokenizer import Tokenizer
+
+
+if __name__ == '__main__':
+
+    # Parse Arguments
+    parser = argparse.ArgumentParser(description='TTS Generator')
+    parser.add_argument('--input_text', '-i', default=None, type=str, help='[string] Type in something here and TTS will generate it!')
+    parser.add_argument('--checkpoint', type=str, default=None, help='[string/path] path to .onnx model file.')
+    parser.add_argument('--config', metavar='FILE', default='default.yaml', help='The config containing all hyperparams.')
+    parser.add_argument('--speaker', type=str, default=None, help='Speaker to generate audio for (only multispeaker).')
+
+    parser.add_argument('--alpha', type=float, default=1., help='Parameter for controlling length regulator for speedup '
+                                                                 'or slow-down of generated speech, e.g. alpha=2.0 is double-time')
+    parser.add_argument('--amp', type=float, default=1., help='Parameter for controlling pitch amplification')
+
+    # name of subcommand goes to args.vocoder
+    subparsers = parser.add_subparsers(dest='vocoder')
+    gl_parser = subparsers.add_parser('griffinlim')
+    mg_parser = subparsers.add_parser('melgan')
+    hg_parser = subparsers.add_parser('hifigan')
+
+    args = parser.parse_args()
+
+    assert args.vocoder in {'griffinlim', 'melgan', 'hifigan'}, \
+        'Please provide a valid vocoder! Choices: [griffinlim, melgan, hifigan]'
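Example generation command with an exported ONNX model (the checkpoint path is hypothetical; the vocoder subcommand goes last):

python gen_forward_onnx.py --checkpoint checkpoints/my_tts.forward/latest_model.onnx -i 'This is a test.' griffinlim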
+    checkpoint_path = args.checkpoint
+    if checkpoint_path is None:
+        config = read_config(args.config)
+        paths = Paths(config['data_path'], config['tts_model_id'])
+        checkpoint_path = paths.forward_checkpoints / 'latest_model.onnx'
+    sess_options = onnxruntime.SessionOptions()
+    checkpoint = onnxruntime.InferenceSession(str(checkpoint_path), sess_options=sess_options)
+    config = read_config(args.config)
+    dsp = DSP.from_config(config)
+
+    voc_model, voc_dsp = None, None
+    out_path = Path('model_outputs')
+    out_path.mkdir(parents=True, exist_ok=True)
+    cleaner = Cleaner.from_config(config)
+    tokenizer = Tokenizer()
+
+    if args.input_text:
+        texts = [args.input_text]
+    else:
+        with open('sentences.txt', 'r', encoding='utf-8') as f:
+            texts = f.readlines()
+
+    pitch_function = lambda x: x * args.amp
+    energy_function = lambda x: x
+
+    for i, x in enumerate(texts, 1):
+        print(f'\n| Generating {i}/{len(texts)}')
+        text = x
+        x = cleaner(x)
+        x = tokenizer(x)
+        text = np.expand_dims(np.array(x, dtype=np.int64), 0)
+        synth_options = np.array(
+            [args.alpha, 1.0, 1.0],
+            dtype=np.float32,
+        )
+        speaker_name = args.speaker if args.speaker is not None else 'default_speaker'
+        wav_name = f"test{i}"
+        m = checkpoint.run(
+            None,
+            {
+                "input": text,
+                "synth_options": synth_options,
+            },
+        )[0]
+        #m = (m * 32767).astype(np.int16)
+        if args.vocoder == 'melgan':
+            torch.save(m, out_path / f'{wav_name}.mel')
+        elif args.vocoder == 'hifigan':
+            np.save(str(out_path / f'{wav_name}.npy'), m, allow_pickle=False)
+        elif args.vocoder == 'griffinlim':
+            wav = dsp.griffinlim(m)
+            dsp.save_wav(wav, out_path / f'{wav_name}.wav')
+
+    print('\n\nDone.\n')
\ No newline at end of file
diff --git a/gen_forward_onnx_benchmark.py b/gen_forward_onnx_benchmark.py
new file mode 100644
index 00000000..ba0d9060
--- /dev/null
+++ b/gen_forward_onnx_benchmark.py
@@ -0,0 +1,63 @@
+import argparse
+from pathlib import Path
+import numpy as np
+import onnxruntime
+from utils.files import read_config
+from utils.dsp import DSP
+from utils.paths import Paths
+from utils.text.cleaners import Cleaner
+from utils.text.tokenizer import Tokenizer
+import time
+
+if __name__ == '__main__':
+
+    # Parse Arguments
+    parser = argparse.ArgumentParser(description='TTS Generator')
+    parser.add_argument('--checkpoint', type=str, default=None, help='[string/path] path to .onnx model file.')
+    parser.add_argument('--config', metavar='FILE', default='default.yaml', help='The config containing all hyperparams.')
+    parser.add_argument('--speaker', type=str, default=None, help='Speaker to generate audio for (only multispeaker).')
+
+    args = parser.parse_args()
+
+    checkpoint_path = args.checkpoint
+    if checkpoint_path is None:
+        config = read_config(args.config)
+        paths = Paths(config['data_path'], config['tts_model_id'])
+        checkpoint_path = paths.forward_checkpoints / 'latest_model.onnx'
+    sess_options = onnxruntime.SessionOptions()
+    checkpoint = onnxruntime.InferenceSession(str(checkpoint_path), sess_options=sess_options)
+    config = read_config(args.config)
+    dsp = DSP.from_config(config)
+    cleaner = Cleaner.from_config(config)
+    tokenizer = Tokenizer()
+
+    with open('sentences.txt', 'r', encoding='utf-8') as f:
+        texts = f.readlines()
+    for i, x in enumerate(texts, 1):
+        print(f'\n| Generating {i}/{len(texts)}')
+        text = x
+        x = cleaner(x)
+        x = tokenizer(x)
+        text = np.expand_dims(np.array(x, dtype=np.int64), 0)
+        synth_options = np.array(
+            [1.0, 1.0, 1.0],
+            dtype=np.float32,
+        )
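The real-time factor reported by the benchmark above is inference time divided by the duration of the audio that the generated mel corresponds to. A worked example with assumed (not repo-specific) DSP settings:

# Assumed values for illustration only:
mel_frames = 500      # m.shape[-1]
hop_length = 256      # dsp.hop_length (assumed)
sample_rate = 22050   # dsp.sample_rate (assumed)
infer_sec = 0.29      # measured inference time in seconds

audio_sec = mel_frames * hop_length / sample_rate  # ~5.8 s of audio
rtf = infer_sec / audio_sec                        # ~0.05, i.e. about 20x faster than real time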
+        speaker_name = args.speaker if args.speaker is not None else 'default_speaker'
+        start_time = time.perf_counter()
+        m = checkpoint.run(
+            None,
+            {
+                "input": text,
+                "synth_options": synth_options,
+            },
+        )[0]
+        end_time = time.perf_counter()
+        mel_length = m.shape[-1]
+        spec_length = mel_length * dsp.hop_length
+        spec_sec = spec_length / dsp.sample_rate
+        infer_sec = end_time - start_time
+        rtf = infer_sec / spec_sec  # real-time factor: inference time / audio duration
+        print(f"Sentence {i} generation time: {infer_sec * 1000:.1f} ms, RTF: {rtf:.3f}.")
+
+    print('\n\nDone.\n')
\ No newline at end of file
diff --git a/models/common_layers.py b/models/common_layers.py
index ac56b3fd..13185ed8 100644
--- a/models/common_layers.py
+++ b/models/common_layers.py
@@ -6,8 +6,24 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import LayerNorm, MultiheadAttention
 from torch.nn.utils.rnn import pad_sequence
+
+
+# ONNX-friendly replacement for torch.nn.utils.rnn.pad_sequence (batch_first):
+class CustomPadSequence(nn.Module):
+    def __init__(self, padding_value=0.):
+        super().__init__()
+        self.padding_value = padding_value
+
+    def forward(self, sequences):
+        # Find the maximum length (number of time steps) among the sequences
+        max_length = max(len(seq) for seq in sequences)
+        # Pad each (T, C) sequence along the time dimension with the padding value
+        padded_sequences = [F.pad(seq, (0, 0, 0, max_length - len(seq)), value=self.padding_value) for seq in sequences]
+        # Stack the padded sequences into a single (B, T, C) batch
+        padded_sequences = torch.stack(padded_sequences, dim=0)
+        return padded_sequences
 
 
 class LengthRegulator(nn.Module):
@@ -23,6 +39,22 @@ def forward(self, x: torch.Tensor, dur: torch.Tensor) -> torch.Tensor:
         x_expanded = pad_sequence(x_expanded, padding_value=0., batch_first=True)
         return x_expanded
 
 
+class LengthRegulator_onnx(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor, dur: torch.Tensor) -> torch.Tensor:
+        dur[dur < 0] = 0.
+        x_expanded = []
+        for i in range(x.size(0)):
+            x_exp = torch.repeat_interleave(x[i], (dur[i] + 0.5).long(), dim=0)
+            x_expanded.append(x_exp)
+        customPadSequence = CustomPadSequence(padding_value=0.)
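The ONNX-friendly length regulator above relies only on repeat_interleave and explicit padding, which export cleanly. A toy example of the expansion it performs (values are illustrative):

import torch

x = torch.tensor([[[1.0], [2.0], [3.0]]])  # (batch=1, phonemes=3, channels=1)
dur = torch.tensor([[2.0, 0.6, 3.0]])      # predicted durations per phoneme

rep = torch.repeat_interleave(x[0], (dur[0] + 0.5).long(), dim=0)
# (dur + 0.5).long() rounds to [2, 1, 3], so rep is
# [[1.], [1.], [2.], [3.], [3.], [3.]] -- 6 mel frames for 3 phonemes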
+        x_expanded = customPadSequence(x_expanded)
+        return x_expanded
+
 
 class HighwayNetwork(nn.Module):
diff --git a/models/forward_tacotron.py b/models/forward_tacotron.py
index 2069370e..7d962cbf 100644
--- a/models/forward_tacotron.py
+++ b/models/forward_tacotron.py
@@ -7,7 +7,7 @@
 from torch.nn import Embedding
 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 
-from models.common_layers import CBHG, LengthRegulator, BatchNormConv
+from models.common_layers import CBHG, LengthRegulator, LengthRegulator_onnx, BatchNormConv
 from utils.text.symbols import phonemes
 
 
@@ -71,7 +71,8 @@ def __init__(self,
         self.rnn_dims = rnn_dims
         self.padding_value = padding_value
         self.embedding = nn.Embedding(num_chars, embed_dims)
         self.lr = LengthRegulator()
+        self.lr_onnx = LengthRegulator_onnx()
         self.dur_pred = SeriesPredictor(num_chars=num_chars,
                                         emb_dim=series_embed_dims,
                                         conv_dims=durpred_conv_dims,
@@ -168,7 +169,8 @@ def generate(self,
                  x: torch.Tensor,
                  alpha=1.0,
                  pitch_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
-                 energy_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x) -> Dict[str, torch.Tensor]:
+                 energy_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+                 onnx: bool = False) -> Dict[str, torch.Tensor]:
         self.eval()
         with torch.no_grad():
             dur_hat = self.dur_pred(x, alpha=alpha)
@@ -181,7 +183,8 @@ def generate(self,
             energy_hat = energy_function(energy_hat)
         return self._generate_mel(x=x, dur_hat=dur_hat,
                                   pitch_hat=pitch_hat,
-                                  energy_hat=energy_hat)
+                                  energy_hat=energy_hat,
+                                  onnx=onnx)
 
     @torch.jit.export
     def generate_jit(self,
@@ -206,7 +209,8 @@ def _generate_mel(self,
                       x: torch.Tensor,
                       dur_hat: torch.Tensor,
                       pitch_hat: torch.Tensor,
-                      energy_hat: torch.Tensor) -> Dict[str, torch.Tensor]:
+                      energy_hat: torch.Tensor,
+                      onnx: bool = False) -> Dict[str, torch.Tensor]:
         x = self.embedding(x)
         x = x.transpose(1, 2)
         x = self.prenet(x)
@@ -218,9 +222,10 @@ def _generate_mel(self,
         energy_proj = self.energy_proj(energy_hat)
         energy_proj = energy_proj.transpose(1, 2)
         x = x + energy_proj * self.energy_strength
-
-        x = self.lr(x, dur_hat)
-
+        if not onnx:
+            x = self.lr(x, dur_hat)
+        else:
+            x = self.lr_onnx(x, dur_hat)
         x, _ = self.lstm(x)
         x = self.lin(x)