diff --git a/egs/ljspeech/voc1/conf/hifigan_hubert_large_km500_LJ_24khz_hopsize_480.yaml b/egs/ljspeech/voc1/conf/hifigan_hubert_large_km500_LJ_24khz_hopsize_480.yaml new file mode 100644 index 00000000..210fefc0 --- /dev/null +++ b/egs/ljspeech/voc1/conf/hifigan_hubert_large_km500_LJ_24khz_hopsize_480.yaml @@ -0,0 +1,177 @@ +# This is the configuration file for LJSpeech dataset for 24k sample rate. +# This configuration is based on HiFiGAN V1, derived +# from official repository (https://github.com/jik876/hifi-gan). + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +sampling_rate: 24000 # Sampling rate. +fft_size: null # FFT size. +hop_size: 480 # Hop size. +win_length: null # Window length. + # If set to null, it will be the same as fft_size. +window: null # Window function. +num_mels: 2 # Number of mel basis. +fmin: null # Minimum freq in mel basis calculation. +fmax: null # Maximum frequency in mel basis calculation. +global_gain_scale: 1.0 # Will be multiplied to all of waveform. +trim_silence: false # Whether to trim the start and end of silence. +trim_threshold_in_db: 20 # Need to tune carefully if the recording is not good. +trim_frame_size: 1024 # Frame size in trimming. +trim_hop_size: 256 # Hop size in trimming. +format: "hdf5" # Feature file format. "npy" or "hdf5" is supported. + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_type: DiscreteSymbolHiFiGANGenerator +generator_params: + in_channels: 512 # Number of input channels. + out_channels: 1 # Number of output channels. + channels: 512 # Number of initial channels. + num_embs: 500 + num_spk_embs: 0 + spk_emb_dim: 512 + concat_spk_emb: false + kernel_size: 7 # Kernel size of initial and final conv layers. + upsample_scales: [12, 10, 2, 2] # Upsampling scales. 
+    nonlinear_activation_params: # Nonlinear activation parameters.
+        bias: true # Whether to use bias parameter in conv layer.
+generator_train_start_steps: 1 # Number of steps to start to train generator.
+ +########################################################### +# OTHER SETTING # +########################################################### +num_save_intermediate_results: 4 # Number of results to be saved as intermediate results. diff --git a/egs/ljspeech/voc1/local/preprocess_hubert.py b/egs/ljspeech/voc1/local/preprocess_hubert.py new file mode 100755 index 00000000..e0a830ee --- /dev/null +++ b/egs/ljspeech/voc1/local/preprocess_hubert.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Perform preprocessing and raw feature extraction.""" + +import argparse +import logging +import os + +import librosa +import numpy as np +import soundfile as sf +import resampy +import yaml + +from tqdm import tqdm + +from parallel_wavegan.datasets import AudioDataset +from parallel_wavegan.datasets import AudioSCPDataset +from parallel_wavegan.utils import write_hdf5 + + +def main(): + """Run preprocessing process.""" + parser = argparse.ArgumentParser( + description="Preprocess audio and then extract features (See detail in parallel_wavegan/bin/preprocess.py)." + ) + parser.add_argument( + "--wav-scp", + "--scp", + default=None, + type=str, + help="kaldi-style wav.scp file. you need to specify either scp or rootdir.", + ) + parser.add_argument( + "--segments", + default=None, + type=str, + help="kaldi-style segments file. if use, you must to specify both scp and segments.", + ) + parser.add_argument( + "--text", + default=None, + type=str, + help="kaldi-style text format hubert embedding index.", + ) + parser.add_argument( + "--utt2spk", + default=None, + type=str, + help="kaldi-style utt2spk file. If you want to add global conditionning with " + "speaker id, you need to specify this argument.", + ) + parser.add_argument( + "--spk2idx", + default=None, + type=str, + help="kaldi-style spk2idx file. 
If you want to add global conditionning with " + "speaker id, you need to specify this argument.", + ) + parser.add_argument( + "--rootdir", + default=None, + type=str, + help="directory including wav files. you need to specify either scp or rootdir.", + ) + parser.add_argument( + "--dumpdir", + type=str, + required=True, + help="directory to dump feature files.", + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="yaml format configuration file.", + ) + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. (default=1)", + ) + args = parser.parse_args() + + # set logger + if args.verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # load config + with open(args.config) as f: + config = yaml.load(f, Loader=yaml.Loader) + config.update(vars(args)) + + # check arguments + if (args.wav_scp is not None and args.rootdir is not None) or ( + args.wav_scp is None and args.rootdir is None + ): + raise ValueError("Please specify either --rootdir or --wav-scp.") + + # get dataset + print("=====args.rootdir",args.rootdir) + if args.rootdir is not None: + dataset = AudioDataset( + args.rootdir, + "*.wav", + audio_load_fn=sf.read, + return_utt_id=True, + ) + else: + dataset = AudioSCPDataset( + args.wav_scp, + segments=args.segments, + return_utt_id=True, + return_sampling_rate=True, + ) + + # get text + print("text filr", args.text) + print("spk2idx", args.spk2idx) + + with open(args.text) as f: + lines = [line.strip() for line in f.readlines()] + text = {line.split(" ")[0]: line.split(" 
")[1:-1] for line in lines} + + # load spk2utt file + if args.utt2spk is not None: + with open(args.utt2spk) as f: + lines = [l.replace("\n", "") for l in f.readlines()] + utt2spk = {l.split()[0]: l.split()[1] for l in lines} + if args.spk2idx is not None: + with open(args.spk2idx) as f: + lines = [l.replace("\n", "") for l in f.readlines()] + spk2idx = {l.split()[0]: int(l.split()[1]) for l in lines} + + # check directly existence + if not os.path.exists(args.dumpdir): + os.makedirs(args.dumpdir, exist_ok=True) + + # process each data + for utt_id, (audio, fs) in tqdm(dataset): + # check + if utt_id not in text.keys(): + continue + assert len(audio.shape) == 1, f"{utt_id} seems to be multi-channel signal." + assert ( + np.abs(audio).max() <= 1.0 + ), f"{utt_id} seems to be different from 16 bit PCM." + + # downsample + if fs != config["sampling_rate"]: + audio = resampy.resample(audio, fs, config["sampling_rate"], axis=0) + + # trim silence + if config["trim_silence"]: + audio, _ = librosa.effects.trim( + audio, + top_db=config["trim_threshold_in_db"], + frame_length=config["trim_frame_size"], + hop_length=config["trim_hop_size"], + ) + + # use hubert index instead of mel + mel = np.array(text[utt_id]).astype(np.int64).reshape(-1, 1) + + if args.spk2idx is not None: + spk = utt2spk[utt_id] + if spk in spk2idx: + idx = spk2idx[spk] + else: + logging.warn(f"{spk} is unknown speaker.") + max_idx = max(spk2idx.values()) + 1 + idx = max_idx + + # concatenate with mel + idx = np.repeat(np.array(idx).reshape(1, 1), len(mel), axis=0) + mel = np.concatenate([mel, idx], axis=1) + + # make sure the audio length and feature length are matched + logging.info(f"Mod: {len(audio) - len(mel) * config['hop_size']}") + assert len(mel) * config["hop_size"] <= len(audio), f"{utt_id}: {len(mel)}, {config['hop_size']}, {len(audio)}" + + audio = audio[: len(mel) * config["hop_size"]] + assert len(mel) * config["hop_size"] == len(audio) + + # apply global gain + if 
config["global_gain_scale"] > 0.0: + audio *= config["global_gain_scale"] + if np.abs(audio).max() >= 1.0: + logging.warn( + f"{utt_id} causes clipping. " + f"it is better to re-consider global gain scale." + ) + continue + + # save + if config["format"] == "hdf5": + write_hdf5( + os.path.join(args.dumpdir, f"{utt_id}.h5"), + "wave", + audio.astype(np.float32), + ) + write_hdf5( + os.path.join(args.dumpdir, f"{utt_id}.h5"), + "feats", + mel.astype(np.float32), + ) + elif config["format"] == "npy": + np.save( + os.path.join(args.dumpdir, f"{utt_id}-wave.npy"), + audio.astype(np.float32), + allow_pickle=False, + ) + np.save( + os.path.join(args.dumpdir, f"{utt_id}-feats.npy"), + mel.astype(np.float32), + allow_pickle=False, + ) + else: + raise ValueError("support only hdf5 or npy format.") + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/voc1/output.txt b/egs/ljspeech/voc1/output.txt new file mode 100644 index 00000000..772cadf4 --- /dev/null +++ b/egs/ljspeech/voc1/output.txt @@ -0,0 +1,3 @@ +Stage 2: Network training +ls: cannot access 'exp/tr_no_dev_ljspeech_24k_hifigan_hubert_large_km500_LJ_24khz_hopsize_480/*.pkl': No such file or directory +Training start. See the progress via exp/tr_no_dev_ljspeech_24k_hifigan_hubert_large_km500_LJ_24khz_hopsize_480/train.log. diff --git a/parallel_wavegan/models/hifigan_hubert_representation.py b/parallel_wavegan/models/hifigan_hubert_representation.py new file mode 100644 index 00000000..5e1dfdb7 --- /dev/null +++ b/parallel_wavegan/models/hifigan_hubert_representation.py @@ -0,0 +1,324 @@ +# -*- coding: utf-8 -*- + +"""HiFi-GAN Modules. + +This code is based on https://github.com/jik876/hifi-gan. 
+ +""" +from argparse import Namespace +import copy +import logging +import os + +import numpy as np +import torch +import torch.nn.functional as F + + +from parallel_wavegan.layers import HiFiGANResidualBlock as ResidualBlock +from parallel_wavegan.utils import read_hdf5 + + +def base_s3prl_setup(args): + args.upstream_feature_selection = getattr(args, "upstream_feature_selection", None) + args.upstream_model_config = getattr(args, "upstream_model_config", None) + args.upstream_refresh = getattr(args, "upstream_refresh", False) + args.upstream_ckpt = getattr(args, "upstream_ckpt", None) + args.init_ckpt = getattr(args, "init_ckpt", None) + args.verbose = getattr(args, "verbose", False) + args.tile_factor = getattr(args, "tile_factor", 1) + return args + + +class HuBERTREPRHiFiGANGenerator(torch.nn.Module): + """HiFiGAN generator with HuBERT representation module.""" + + def __init__( + self, + in_channels=1024, + out_channels=1, + channels=512, + num_spk_embs=128, + spk_emb_dim=128, + spk_emb_inventory=None, + concat_spk_emb=False, + kernel_size=7, + upsample_scales=(10, 8, 2, 2), + upsample_kernel_sizes=(20, 16, 4, 4), + resblock_kernel_sizes=(3, 7, 11), + resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)], + use_additional_convs=True, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.1}, + use_weight_norm=True, + ckpt_path=None, + layer_idx=9, + ): + """Initialize HiFiGANGenerator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + channels (int): Number of hidden representation channels. + kernel_size (int): Kernel size of initial and final conv layer. + upsample_scales (list): List of upsampling scales. + upsample_kernel_sizes (list): List of kernal sizes for upsampling layers. + resblock_kernal_sizes (list): List of kernal sizes for residual blocks. + resblock_dilations (list): List of dilation list for residual blocks. 
+ use_additional_convs (bool): Whether to use additional conv layers in residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + + """ + super().__init__() + self.num_spk_embs = num_spk_embs + + if self.num_spk_embs > 0: + # self.spk_emb = torch.nn.Embedding( + # num_embeddings=num_spk_embs, embedding_dim=spk_emb_dim + # ) + if spk_emb_inventory is None: + self.spk_emb = torch.nn.Embedding( + num_embeddings=num_spk_embs, embedding_dim=spk_emb_dim + ) + else: + spk_emb = torch.load(spk_emb_inventory) + self.spk_emb = torch.nn.Embedding.from_pretrained(spk_emb) + self.spk_emb.requires_grad = False + + self.concat_spk_emb = concat_spk_emb + if not concat_spk_emb: + assert in_channels == spk_emb_dim + else: + in_channels = in_channels + spk_emb_dim + + # check hyperparameters are valid + assert kernel_size % 2 == 1, "Kernal size must be odd number." 
+ assert len(upsample_scales) == len(upsample_kernel_sizes) + assert len(resblock_dilations) == len(resblock_kernel_sizes) + + # define modules + self.num_upsamples = len(upsample_kernel_sizes) + self.num_blocks = len(resblock_kernel_sizes) + self.input_conv = torch.nn.Conv1d( + in_channels, + channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + self.upsamples = torch.nn.ModuleList() + self.blocks = torch.nn.ModuleList() + for i in range(len(upsample_kernel_sizes)): + assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] + self.upsamples += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.ConvTranspose1d( + channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, + output_padding=upsample_scales[i] % 2, + ), + ) + ] + for j in range(len(resblock_kernel_sizes)): + self.blocks += [ + ResidualBlock( + kernel_size=resblock_kernel_sizes[j], + channels=channels // (2 ** (i + 1)), + dilations=resblock_dilations[j], + bias=bias, + use_additional_convs=use_additional_convs, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + ) + ] + self.output_conv = torch.nn.Sequential( + # NOTE(kan-bayashi): follow official implementation but why + # using different slope parameter here? (0.1 vs. 
0.01) + torch.nn.LeakyReLU(), + torch.nn.Conv1d( + channels // (2 ** (i + 1)), + out_channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ), + torch.nn.Tanh(), + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # HuBERT + assert ckpt_path is not None + self.layer = layer_idx + import fairseq + ( + model, + cfg, + task, + ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path]) + self.upstream = model[0].eval() # .cuda() + self.task = task + for k, p in self.upstream.named_parameters(): + logging.info(f"Setting {k}.requires_grad = False") + p.requires_grad = False + self.upstream_pretrained_params = copy.deepcopy(self.upstream.state_dict()) + + # reset parameters + self.reset_parameters() + self.upstream.load_state_dict(self.upstream_pretrained_params) + + def forward(self, c, l): + """Calculate forward propagation. + + Args: + c (Tensor): Input audio tensor (B, T, D). + l (Tensor): Input text Tensor (B, 2, T). + + Returns: + Tensor: Output tensor (B, out_channels, T). + + """ + + # convert idx to embedding + if self.num_spk_embs > 0: + assert l.size(1) == 2 + _, g_idx = l.long().split(1, dim=1) + g = self.spk_emb(g_idx[:, 0, 0]) + + # integrate global embedding + if not self.concat_spk_emb: + c = c + g.unsqueeze(2) + else: + g = g.unsqueeze(1).expand(-1, c.size(1), -1) # (B, T, D) + c = torch.cat([c, g], dim=-1) # (B, T, D1 + D2) + c = c.transpose(1, 2) # (B, D', T) + + c = self.input_conv(c) + for i in range(self.num_upsamples): + c = self.upsamples[i](c) + cs = 0.0 # initialize + for j in range(self.num_blocks): + cs += self.blocks[i * self.num_blocks + j](c) + c = cs / self.num_blocks + c = self.output_conv(c) + + return c + + def extract_features(self, c): + """Extract features from audio. + + Args: + c (Tensor): Input audio tensor (B, in_channels, T). + + Returns: + Tensor: Output tensor (B, L, D). 
+ + """ + + f = c.squeeze(1) + with torch.no_grad(): + if self.task.cfg.normalize: + f = torch.nn.functional.layer_norm(f, f.shape) + + f, _ = self.upstream.extract_features( + source=f, + padding_mask=None, + mask=False, + output_layer=self.layer, + ) + + return f + + def reset_parameters(self): + """Reset parameters. + + This initialization follows the official implementation manner. + https://github.com/jik876/hifi-gan/blob/master/models.py + + """ + + def _reset_parameters(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): + m.weight.data.normal_(0.0, 0.01) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.ConvTranspose1d + ): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def register_stats(self, stats): + """Register stats for de-normalization as buffer. + + Args: + stats (str): Path of statistics file (".npy" or ".h5"). 
+ + """ + assert stats.endswith(".h5") or stats.endswith(".npy") + if stats.endswith(".h5"): + mean = read_hdf5(stats, "mean").reshape(-1) + scale = read_hdf5(stats, "scale").reshape(-1) + else: + mean = np.load(stats)[0].reshape(-1) + scale = np.load(stats)[1].reshape(-1) + self.register_buffer("mean", torch.from_numpy(mean).float()) + self.register_buffer("scale", torch.from_numpy(scale).float()) + logging.info("Successfully registered stats as buffer.") + + def inference(self, c, l, normalize_before=False): + """Perform inference. + + Args: + c (Union[Tensor, ndarray]): Input tensor (T, D). + l (Tensor or Int): Input spkid Tensor (1) or int. + normalize_before (bool): Whether to perform normalization. + + Returns: + Tensor: Output tensor (T ** prod(upsample_scales), out_channels). + + """ + if not isinstance(c, torch.Tensor): + c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device) + if not isinstance(l, torch.Tensor): + l = torch.tensor(l, dtype=torch.long).to(c.device) + + l = l[None, :].repeat(2, 1) + + if normalize_before: + c = (c - self.mean) / self.scale + c = self.forward(c.unsqueeze(0), l.unsqueeze(0)) + return c.squeeze(0).transpose(1, 0)