diff --git a/docs/supported_metrics.md b/docs/supported_metrics.md
index 3cbe700..fbad9d9 100644
--- a/docs/supported_metrics.md
+++ b/docs/supported_metrics.md
@@ -50,7 +50,9 @@ We include x mark if the metric is auto-installed in versa.
 | 43 | | Qwen2 Recording Environment - Background | qwen2_speech_background_environment_metric | qwen2_speech_background_environment_metric | [Qwen2 Audio](https://github.com/QwenLM/Qwen2-Audio) | [paper](https://arxiv.org/abs/2407.10759) |
 | 44 | | Qwen2 Recording Environment - Quality | qwen2_recording_quality_metric | qwen2_recording_quality_metric | [Qwen2 Audio](https://github.com/QwenLM/Qwen2-Audio) | [paper](https://arxiv.org/abs/2407.10759) |
 | 45 | | Qwen2 Recording Environment - Channel Type | qwen2_channel_type_metric | qwen2_channel_type_metric | [Qwen2 Audio](https://github.com/QwenLM/Qwen2-Audio) | [paper](https://arxiv.org/abs/2407.10759) |
-
+| 46 | | OpenBEATs - Embedding extraction | openbeats_embedding_extraction | openbeats_embedding_extraction | Released via VERSA | [Challenge report/OpenBEATs arxiv](todo) |
+| 47 | | OpenBEATs - Similarity | openbeats_embedding_similarity | openbeats_embedding_similarity | Released via VERSA | [Challenge report/OpenBEATs arxiv](todo) |
+| 48 | | OpenBEATs - Class prediction | openbeats_class_prediction | openbeats_class_prediction | Released via VERSA | [Challenge report/OpenBEATs arxiv](todo) |
 
 ### Dependent Metrics
 |Number| Auto-Install | Metric Name (Auto-Install) | Key in config | Key in report | Code Source | References |
diff --git a/egs/separate_metrics/openbeats.yaml b/egs/separate_metrics/openbeats.yaml
new file mode 100644
index 0000000..4371c82
--- /dev/null
+++ b/egs/separate_metrics/openbeats.yaml
@@ -0,0 +1,16 @@
+# Metrics with OpenBEATs
+# Inference pipeline is released via VERSA!
+
+# 1. Class prediction
+# TODO(shikhar): Add other checkpoints for fine-tuned models.
+- name: openbeats_class_prediction
+  model_path: /work/nvme/bbjs/sbharadwaj/OpenBEATs/audioset20k/cls_earlarge3/ckpt_w_cfg.ckpt
+
+# 2. Embedding extraction
+- name: openbeats_embedding_extraction
+  model_path: /work/nvme/bbjs/sbharadwaj/7Msounds/exp/beats_iter1_large1.tune_lr1.0e-4_warmup40000_bins1600000_totalsteps400000/epoch_latest.pt
+  embedding_output_file: test/test_samples/test2/embeddings/test_embeddings.npy
+
+# 3. Embedding similarity
+- name: openbeats_embedding_similarity
+  model_path: /work/nvme/bbjs/sbharadwaj/7Msounds/exp/beats_iter1_large1.tune_lr1.0e-4_warmup40000_bins1600000_totalsteps400000/epoch_latest.pt
diff --git a/test/test_pipeline/test_openbeats.py b/test/test_pipeline/test_openbeats.py
new file mode 100644
index 0000000..d7e1970
--- /dev/null
+++ b/test/test_pipeline/test_openbeats.py
@@ -0,0 +1,94 @@
+import logging
+import os
+
+import yaml
+import numpy as np
+
+from versa.scorer_shared import (
+    find_files,
+    list_scoring,
+    load_score_modules,
+)
+
+TEST_INFO = {
+    "openbeats_embedding_extraction": np.array([-0.42187455, -0.6287595, 0.1792216]),
+    "openbeats_embedding_similarity": 1.0,
+}
+
+
+def test_openbeats_embedding_extraction(embedding_result):
+    """Test OpenBEATs embedding extraction."""
+    # Read embedding
+    assert (
+        "embedding_file" in embedding_result
+    ), "Embedding result does not contain 'embedding_file'"
+    with open(embedding_result["embedding_file"], "rb") as f:
+        embedding_result["embedding"] = np.load(f)
+
+    assert embedding_result["embedding"].shape[:-1] == (
+        1,
+        48,
+    ), f'The frame size is off. 
Expected (1,48) but got {embedding_result["embedding"].shape[:-1]}' + summary_value = embedding_result["embedding"][0, :3, 0] + if np.any( + np.abs(TEST_INFO["openbeats_embedding_extraction"] - summary_value) > 1e-3 + ): + raise ValueError( + "Value issue in the test case, might be some issue in scorer {}".format( + "openbeats_embedding_extraction" + ) + ) + + +def test_openbeats_embedding_similarity(embedding_result): + """Test OpenBEATs embedding similarity.""" + assert ( + "similarity_score" in embedding_result + ), "Embedding result does not contain 'similarity_score'" + similarity_score = embedding_result["similarity_score"] + assert ( + np.abs(TEST_INFO["openbeats_embedding_similarity"] - similarity_score) < 1e-3 + ), "Similarity score should be 1.0, got {}".format(similarity_score) + + +def test_openbeats_class_prediction(class_prediction_result): + """Test OpenBEATs class prediction.""" + assert ( + "class_probabilities" in class_prediction_result + ), "Class prediction result does not contain 'class_probabilities'" + class_probabilities = class_prediction_result["class_probabilities"] + print("Multi-class log probabilities: {}".format(class_probabilities), flush=True) + + +def info_update(): + + # find files + if os.path.isdir("test/test_samples/test2"): + gen_files = find_files("test/test_samples/test2") + + logging.info("The number of utterances = %d" % len(gen_files)) + + with open("egs/separate_metrics/openbeats.yaml", "r", encoding="utf-8") as f: + score_config = yaml.full_load(f) + + score_modules = load_score_modules( + score_config, + use_gt=True, + use_gpu=False, + ) + + assert len(score_config) > 0, "no scoring function is provided" + + score_info = list_scoring( + gen_files, score_modules, gt_files=gen_files, output_file=None, io="soundfile" + ) + + test_openbeats_embedding_extraction(score_info[0]) + test_openbeats_embedding_similarity(score_info[0]) + test_openbeats_class_prediction(score_info[0]) + + print("check successful", flush=True) + + +if __name__ == "__main__": + info_update() diff --git a/versa/__init__.py b/versa/__init__.py index 782218a..137da66 100644 --- a/versa/__init__.py +++ b/versa/__init__.py @@ -103,3 +103,10 @@ ) from versa.utterance_metrics.squim import squim_metric, squim_metric_no_ref from versa.utterance_metrics.srmr import srmr_metric +from versa.utterance_metrics.openbeats import ( + openbeats_setup, + openbeats_class_prediction, + openbeats_embedding_extraction, + openbeats_embedding_similarity, +) +from versa import models diff --git a/versa/metrics.py b/versa/metrics.py index ae4f91c..15db6c2 100644 --- a/versa/metrics.py +++ b/versa/metrics.py @@ -31,6 +31,8 @@ "espnet_hyp_text", "owsm_hyp_text", "whisper_hyp_text", + "openbeats_class_prediction", + "openbeats_embedding_extraction", # HACK: using STR_METRIC to bypass summarization ] NUM_METRIC = [ diff --git a/versa/models/__init__.py b/versa/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/versa/models/openbeats/__init__.py b/versa/models/openbeats/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/versa/models/openbeats/decoder.py b/versa/models/openbeats/decoder.py new file mode 100644 index 0000000..afbc9a7 --- /dev/null +++ b/versa/models/openbeats/decoder.py @@ -0,0 +1,93 @@ +"""A simple linear layer decoder. + +This can be used for classification tasks from sequence input. 
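+
+A minimal usage sketch (illustrative only; the class count, feature dimension,
+and input shapes below are assumed for the example, not fixed by this module):
+
+    >>> import torch
+    >>> from versa.models.openbeats.decoder import LinearDecoder
+    >>> dec = LinearDecoder(vocab_size=527, encoder_output_size=768)
+    >>> hs = torch.randn(2, 10, 768)    # (batch, frames, encoder dim)
+    >>> lens = torch.tensor([10, 7])    # valid lengths per utterance
+    >>> logits = dec(hs, lens)          # mean-pooled over frames -> (2, 527) logits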
+""" + +from typing import Tuple +import torch +from typeguard import typechecked +from versa.models.openbeats.utils import make_pad_mask + + +class LinearDecoder(torch.nn.Module): + + @typechecked + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + pooling: str = "mean", + dropout: float = 0.0, + pre_layer_norm: bool = False, + ): + """Initialize the module.""" + super().__init__() + + self.input_dim = encoder_output_size + self.output_dim = vocab_size # No special symbols + self.dropout = None + if dropout != 0.0: + self.dropout = torch.nn.Dropout(p=dropout) + self.linear_out = torch.nn.Linear(self.input_dim, self.output_dim) + assert pooling in [ + "mean", + "max", + "CLS", + ], f"Invalid pooling: {pooling}. Should be 'mean', 'max' or 'CLS'." + self.pooling = pooling + self.layer_norm = torch.nn.LayerNorm(self.input_dim) if pre_layer_norm else None + + def forward( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + hs_pad: (B, Tmax, D) + hlens: (B,) + Returns: + output: (B, n_classes) + """ + + mask = make_pad_mask(lengths=hlens, xs=hs_pad, length_dim=1).to(hs_pad.device) + if self.layer_norm is not None: + hs_pad = self.layer_norm(hs_pad) + if self.dropout is not None: + hs_pad = self.dropout(hs_pad) + if self.pooling == "mean": + unmasked_entries = (~mask).to(dtype=hs_pad.dtype) + input_feature = (hs_pad * unmasked_entries).sum(dim=1) + input_feature = input_feature / unmasked_entries.sum(dim=1) + elif self.pooling == "max": + input_feature = hs_pad.masked_fill(mask, float("-inf")) + input_feature, _ = torch.max(input_feature, dim=1) + elif self.pooling == "CLS": + input_feature = hs_pad[:, 0, :] + + output = self.linear_out(input_feature) + return output + + def score(self, ys, state, x): + """Classify x. + Args: + ys: Not used + state: Not used + x: (T, D). this should be a single sample without + any padding ie batch size=1. + Returns: + ret1: logits over (n_classes,) + state: None + Assumes that x is a single unpadded sequence. + """ + assert len(x.shape) == 2, x.shape + hs_len = torch.tensor([x.shape[0]], dtype=torch.long).to(x.device) + logits = self.forward( + x.unsqueeze(0), + hs_len, + ) + return logits.squeeze(0), None + + def output_size(self) -> int: + """Get the output size.""" + return self.output_dim diff --git a/versa/models/openbeats/encoder.py b/versa/models/openbeats/encoder.py new file mode 100644 index 0000000..1ab2124 --- /dev/null +++ b/versa/models/openbeats/encoder.py @@ -0,0 +1,1814 @@ +"""OpenBEATs: Encoder implementation. 
+ +OpenBEATs is based on the original BEATs implementation (https://arxiv.org/abs/2212.09058) +from Github source: https://github.com/microsoft/unilm/tree/master/beats +Copyright (c) 2022 Microsoft +Licensed under The MIT License [see LICENSE in BEATs repo for details] +""" + +import logging +import math +import warnings +from typing import Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio.compliance.kaldi as ta_kaldi +from packaging.version import parse as V +from torch.nn import LayerNorm, Parameter +from contextlib import contextmanager + +try: + from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding + from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( + Wav2Vec2ConformerConfig, + Wav2Vec2ConformerEncoder, + ) + + transformers_import_error = None +except ImportError as e: + transformers_import_error = e + + +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet2.asr.specaug.specaug import SpecAug +from versa.models.openbeats.utils import ( + beats_frontend, + forward_padding_mask_conv, + freeze_conv_module, + roll_tensor, + make_pad_mask, +) + +if V(torch.__version__) >= V("1.6.0"): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +is_torch_v25_to_v26 = V(torch.__version__) >= V("2.5.0") and V(torch.__version__) <= V( + "2.6.0" +) + + +class BeatsConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = 16 # patch size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_wise_gradient_decay_ratio: float = ( + 1.0 # ratio for layer-wise gradient decay + ) + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = ( + 0.0 # dropout probability after activation in FFN + ) + self.encoder_layerdrop: float = ( + 0.0 # probability of dropping a tarnsformer layer + ) + self.dropout_input: float = ( + 0.0 # dropout to apply to the input (after feat extr) + ) + + # positional embeddings + self.conv_pos: int = ( + 128 # number of filters for convolutional positional embeddings + ) + self.conv_pos_groups: int = ( + 16 # number of groups for convolutional positional embedding + ) + + # relative position embedding + self.relative_position_embedding: bool = ( + False # apply relative position embedding + ) + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = ( + 1280 # maximum distance for relative position embedding + ) + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # label predictor + self.finetuned_model: bool = False # whether the model is a fine-tuned model. 
+ self.predictor_dropout: float = 0.1 # dropout probability for the predictor + self.predictor_class: int = 527 # target class number for the predictor + + # Decoder parameters + self.decoder_embed_dim: int = ( + 768 # decoder embedding dimension, audiomae is 512 + ) + self.decoder_pos_trainable: bool = False + self.decoder_attention_heads: int = 12 + self.decoder_mlp_ratio: float = 4.0 # MLP to transformer dimension ratio + self.decoder_layers: int = 3 # number of decoder layers + self.codebook_vocab_size: int = 1024 # targets vectors in codebook + self.mask_ratio = 0.75 # masking ratio for pre-training + self.use_flash_attn = False # use flash attention in MultiHead attention + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class BeatsEncoder(AbsEncoder): + """BEATs: Audio Pre-Training with Acoustic Tokenizers. + + (https://arxiv.org/abs/2212.09058) + Args: + beats_ckpt_path: Path to a pretrained Beats checkpoint. If + `beats_config` is provided and it does not match the + config in the checkpoint, code might throw an error. + max_layer: Propagate input through all layers for encoding + if None. Otherwise use upto `max_layer`. + downsampling_rate: Downsampling rate for the encoder. Applied if > 1. + adapter_config: Path to a config file for the wav2vec2 adapter. + use_weighted_representation: Use weighted representations + from max_layer if True. Weights are randomly initialized. + beats_config: `BeatsConfig` object. If provided, we will try + to override the config in the checkpoint. This can be used + to change dropouts etc for fine-tuning the model while + starting from a pretrained checkpoint. + specaug_config: Dictionary containing parameters for SpecAugment. + If provided, SpecAugment will be applied. + add_positional_information: Add learned positional embeddings. + max_positions: Maximum number of positions for positional embeddings. + Required if `add_positional_information` is True. + roll_augment: Apply roll augmentation to the input. + roll_interval: Interval for roll augmentation. All rolling is + quantized to this interval. + """ + + def __init__( + self, + input_size: int, + beats_ckpt_path: str = None, + max_layer: int = None, + downsampling_rate: int = 1, + adapter_config: str = "", + use_weighted_representation: bool = False, + beats_config: Optional[Dict] = None, + specaug_config: Optional[Dict] = None, + add_positional_information: bool = False, + max_positions: Optional[int] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + roll_augment: bool = False, + roll_interval: int = 1600, + is_pretraining: Optional[bool] = False, + ) -> None: + super().__init__() + + self.fbank_mean = fbank_mean + self.fbank_std = fbank_std + self.max_layer = max_layer + self.beats_ckpt_path = beats_ckpt_path + self.roll_augment = roll_augment + self.roll_interval = roll_interval + + # Four cases for loading Beats config: + # 1. No checkpoint and no config: Default config + # 2. Checkpoint and no user-provided config: Load config from + # checkpoint + # 3. Checkpoint and user-provided config: Merge the two, but + # override with user-provided config + # 4. No checkpoint and user-provided config: Use user-provided config + if adapter_config or add_positional_information: + # We need transformers library for adapter and positional embeddings + if transformers_import_error: + raise ImportError( + "`transformers` is not available. 
Please install it " + " via `pip install transformers` or" + " `cd /path/to/espnet/tools && " + ". ./activate_python.sh" + " && ./installers/install_transformers.sh`." + f"The original error was: {transformers_import_error}" + ) + config = BeatsConfig() # Default config + if beats_ckpt_path and beats_config: + logging.warning( + "Both pretrained checkpoint and config are provided." + " We will override ckpt config with user-provided config." + ) + self.loaded_state_dict_ = None + if beats_ckpt_path is not None: + self.loaded_state_dict_ = torch.load(beats_ckpt_path) + logging.info(f"Loaded Beats pretrained config from {beats_ckpt_path}.") + config = BeatsConfig(self.loaded_state_dict_["cfg"]) + if beats_config is not None: + config.update(beats_config) + logging.info("Overriding Beats config with user-provided config.") + + self.specaug = None + if specaug_config is not None: + self.specaug = SpecAug(**specaug_config) + + self._output_size = config.encoder_embed_dim + + self.embed = config.embed_dim + self.input_patch_size = config.input_patch_size + self.post_extract_proj = ( + nn.Linear(self.embed, config.encoder_embed_dim) + if self.embed != config.encoder_embed_dim + else None + ) + self.patch_embedding = nn.Conv2d( + 1, + self.embed, + kernel_size=self.input_patch_size, + stride=self.input_patch_size, + bias=config.conv_bias, + ) + self.patch_embedding_pad = nn.Conv2d( + 1, + 1, + kernel_size=self.input_patch_size, + stride=self.input_patch_size, + bias=False, + ) + self.raw2fbank_pad = nn.Conv1d( + 1, + 1, + kernel_size=400, + stride=160, + bias=False, + ) + self.dropout_input = nn.Dropout(config.dropout_input) + assert not config.deep_norm or not config.layer_norm_first + + self.encoder = TransformerEncoder(config) + self.layer_norm = LayerNorm(self.embed) + + self.use_weighted_representation = use_weighted_representation + if self.use_weighted_representation: + if self.max_layer is None: + logging.warning( + f"max_layer must be provided when using weighted" + f" representations. Set to {config.encoder_layers-1}." + ) + self.max_layer = config.encoder_layers - 1 # 0 based index + self.layer_weights = nn.Parameter( + torch.ones((self.max_layer + 1, 1)), requires_grad=True + ) + + # Downsampling modules + self.encoder_downsample_rate = downsampling_rate + self.downsample_conv = None + if self.encoder_downsample_rate > 1: + self.downsample_conv = nn.Conv1d( + in_channels=config.encoder_embed_dim, + out_channels=config.encoder_embed_dim, + kernel_size=int( + round(self.encoder_downsample_rate * 1.5) + ), # kernel multiplier from Shih-Lun's code + stride=self.encoder_downsample_rate, + ) + + # Adapter module + self.conformer_adapter = None + if adapter_config: + conformer_config = Wav2Vec2ConformerConfig.from_json_file(adapter_config) + self.conformer_adapter = Wav2Vec2ConformerEncoder(conformer_config) + + # Positional embeddings applied before cross-attention with decoder. + self.cross_embed_positions = None + if add_positional_information: + assert ( + max_positions is not None + ), "max_positions must be provided in the config." + learned_pos_dim = ( + config.encoder_embed_dim + if not self.conformer_adapter + else self.conformer_adapter.config.hidden_size + ) + self.cross_embed_positions = BartLearnedPositionalEmbedding( + max_positions, learned_pos_dim + ) + # FIXME(shikhar): This is a hack to make the model compatible with + # small audio inputs, without this the window sizes become larger + # than audio. We should add an option to use this via the config. 
+ self.min_input_length_at_16khz = 3200 + + self.is_pretraining = is_pretraining + self.mask_ratio = config.mask_ratio + self.use_flash_attn = config.use_flash_attn + if self.use_flash_attn: + assert V(torch.__version__) >= V( + "2.0.0" + ), "Flash attention requires PyTorch >= 2.0" + if is_pretraining: + assert config.mask_ratio > 0.0, "mask_ratio must be > 0.0 for pretraining." + self.config = config + self.initialize() + + def initialize(self): + logging.info("Beats Initialization function called.") + if self.post_extract_proj: + torch.nn.init.xavier_normal_(self.post_extract_proj.weight) + if self.post_extract_proj.bias is not None: + torch.nn.init.constant_(self.post_extract_proj.bias, 0) + torch.nn.init.xavier_normal_(self.patch_embedding.weight) + if self.patch_embedding.bias is not None: + torch.nn.init.constant_(self.patch_embedding.bias, 0) + freeze_conv_module(self.patch_embedding_pad) + freeze_conv_module(self.raw2fbank_pad) + self.reload_pretrained_parameters() + + def reload_pretrained_parameters(self): + """Initialization function for Beats. + + This must be called last in the initialization procedure. + For pre-training the flow is, + 1. All modules have an initialize function to set the weights. + 2. We do not call espnet initialization once modules are built. + + For fine-tuning using one of the recipes, + The initialization occurs in three steps: + 1. BEATs modules initialize themselves as in 1 above. + 2. ESPnet initializes all modules. + (this step inits the convnet weights in clotho_v2/asr1 recipe) + 3. Optionally, if we have the pretrained checkpoint, we load the + weights from the checkpoint overriding 2 and 1 in this function. + """ + logging.info("Beats parameter loading function called.") + if self.loaded_state_dict_ is not None: + + load_info = self.load_state_dict( + self.loaded_state_dict_["model"], strict=False + ) + # strict=False to ignore Weights in the predictor + logging.info( + f"Loaded Beats pretrained model. Following keys were missing" + f" in your custom model: {load_info.missing_keys}. " + f"Follwing keys could not be loaded from the pretrained" + f"checkpoint: {load_info.unexpected_keys}." + "It is expected to have 'predictor' listed above if you are " + "fine-tuning with only the Beats backbone." + ) + freeze_conv_module(self.patch_embedding_pad) + freeze_conv_module(self.raw2fbank_pad) + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """Forward padding mask. Features: BTC, padding_mask: BT.""" + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.contiguous().view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.any(-1) # remove totally empty sequences + # NOTE(shikhar): This should be any? snip_edges=True in kaldi + # This is problem with raw input, not with fbank input + # Probably replace this with a 1dconv kernel or abandon raw input. 
+ return padding_mask + + def preprocess( + self, + source: torch.Tensor, + ) -> torch.Tensor: + """Preprocess raw audio.""" + fbanks = [] + fbank_lens = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2**15 # float32 to int16 + fbank = ta_kaldi.fbank( + waveform, + num_mel_bins=128, + sample_frequency=16000, + frame_length=25, + frame_shift=10, + ) + fbanks.append(fbank) + fbank_lens.append(fbank.shape[0]) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - self.fbank_mean) / (2 * self.fbank_std) + fbank_lens = torch.tensor(fbank_lens).to(fbank.device) + return fbank, fbank_lens + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + waveform_input: Optional[bool] = True, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Wrapper for compatibility with ESPnets' AbsEncoder Interface. + + Args: + xs_pad: (B, T) or (B,T,D). If sound then (B,T) or (B,T,1) both work. + If features, then only (B,T,D) will work. + ilens: (B,) + prev_states: None + waveform_input: If True, then input is waveform. If False, then input is + already features. + Returns: + audio_representation: (B, T, D) + output_lens: (B,) + masks: None + """ + if xs_pad.dim() == 2 and waveform_input: + xs_pad = xs_pad.unsqueeze(-1) # (B,T) -> (B,T,1) for sound + if self.roll_augment and self.training: + xs_pad = roll_tensor(xs_pad, ilens, fixed_intervals=self.roll_interval) + mask = make_pad_mask(lengths=ilens, traceable=False).to(xs_pad.device) + audio_representation, mask, restore_ids, kept_mask, patch_padding_mask = ( + self.extract_features( + xs_pad, + mask, + max_layer=self.max_layer, + skip_fbank_extraction=not waveform_input, + ) + ) + # patch_padding_mask is the padding mask before any patch masking, + # only valid in pretraining mode + output_lens = (~mask).sum(-1) + + if self.is_pretraining: + patch_lengths = (~patch_padding_mask).sum(-1) + # audio_representation only contains the unmasked portion + # Therefore patch_lengths > audio_representation.shape[1] + return audio_representation, patch_lengths, restore_ids, kept_mask + + return audio_representation, output_lens, None + + def mask_sequence(self, x, padding_mask): + """Masks the input embedding sequence x for MLM style training. + Needs self.mask_ratio to be set. + + Args: + x: [N, L, D], sequence of embeddings. + padding_mask: [N, L], padding mask for x seq. + True means padded. + Returns: + x_unmasked: [N, l, D], only unmasked portion of + the input sequence is returned. + padding_mask: [N, l], portion of padding mask + corresponding to x_unmasked. True means padded. + ids_restore: [N, L], restore ids for unshuffling. + ids_restore[b,j] = position of x_unmasked[b,j] in x[b]. + No guarantees for masked positions. + kept: [N, L], binary mask for the unmasked(kept) positions. + True if the position is kept. Useful for loss computation. 
+ """ + N, L, D = x.shape # batch, length, dim + + seq_lengths = (~padding_mask).sum(-1) + len_keep = (seq_lengths * (1 - self.mask_ratio)).round().to(dtype=torch.long) + max_len_kept = len_keep.max() + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1]] + noise[padding_mask] = float("inf") + ids_shuffle = torch.argsort(noise, dim=1) + ids_keep = ids_shuffle[:, :max_len_kept] + + # make new masks + padding_mask = make_pad_mask(lengths=len_keep, traceable=False).to(x.device) + kept = torch.cat( + [ + ~padding_mask, + torch.zeros([N, L - max_len_kept], device=x.device, dtype=torch.bool), + ], + dim=1, + ) + + # sort only kept indices for maintaining same order x + ids_keep_sorted = ids_keep.clone() + ids_keep_sorted = torch.where( + padding_mask, + torch.tensor(L - 1, dtype=torch.long, device=x.device), + ids_keep_sorted, + ) # introduce L-1 for sorting only important elements + ids_keep_sorted = ids_keep_sorted.sort(dim=1)[0] + ids_keep_sorted = torch.where( + padding_mask, ids_keep, ids_keep_sorted + ) # handle L-1 + + ids_shuffle = torch.cat([ids_keep_sorted, ids_shuffle[:, max_len_kept:]], dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + x_unmasked = torch.gather( + x, dim=1, index=ids_keep_sorted.unsqueeze(-1).repeat(1, 1, D) + ) + + # unshuffle the loss mask + kept = torch.gather(kept, dim=1, index=ids_restore) + return x_unmasked, padding_mask, ids_restore, kept + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + max_layer: Optional[int] = None, + skip_fbank_extraction: bool = False, + ): + """Extract features from raw audio. + source: (B,T,D). for waveform input D=1, for features D=feature_dim + padding_mask: (B,T). If True then pad the element. + """ + if ( + not skip_fbank_extraction + and self.min_input_length_at_16khz + and source.size(1) < self.min_input_length_at_16khz + ): + # Only executed for raw waveform input + logging.warning( + f"Input shape: {source.shape}. This is less than" + f" the minimum size of {self.min_input_length_at_16khz}." + ) + # repeat the input to make it at least min_length + repeat_factor = self.min_input_length_at_16khz // source.size(1) + 1 + source = torch.cat([source] * repeat_factor, dim=1) + padding_mask = torch.cat([padding_mask] * repeat_factor, dim=1) + + with autocast(False): + fbank = ( + (source - self.fbank_mean) / (2 * self.fbank_std) + if skip_fbank_extraction + else beats_frontend( + source.squeeze(-1), + fbank_mean=self.fbank_mean, + fbank_std=self.fbank_std, + ) + ) + + if self.specaug is not None and self.training: + fbank = self.specaug(fbank)[0] + + if padding_mask is not None and not skip_fbank_extraction: + # padding_mask = self.forward_padding_mask(fbank, padding_mask) + padding_mask = forward_padding_mask_conv( + padding_mask=padding_mask, n_dim=0, conv_module=self.raw2fbank_pad + ) + + fbank = fbank.unsqueeze(1) + features = self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + + if padding_mask is not None: + # features is BTC + # padding_mask = self.forward_padding_mask(features, padding_mask) + # NOTE(shikhar): ESC-50, BEANS and Clotho_v2 use the previous version + # which is wrong (the one implmented above). 
+ padding_mask = forward_padding_mask_conv( + padding_mask=padding_mask, + n_dim=fbank.shape[-1], + conv_module=self.patch_embedding_pad, + ) + + patch_padding_mask = None + restore_ids = None + kept_mask = None + if self.is_pretraining: + assert ( + max_layer is None + ), "During pretraining max_layer should be set to None!" + patch_padding_mask = padding_mask.clone() + # kept_mask: 1 - kept, 0 - removed, corresponding to features + # features, padding_mask will be shortened to only keep the kept positions + features, padding_mask, restore_ids, kept_mask = self.mask_sequence( + features, padding_mask + ) + + features = self.layer_norm(features) + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + features, layer_results = self.encoder( + features, padding_mask=padding_mask, layer=max_layer + ) + + if max_layer is not None: + features = layer_results[max_layer][0].transpose( + 0, 1 + ) # use the output from the max_layer + + if self.use_weighted_representation: + repr_layer_weights = nn.functional.softmax(self.layer_weights, dim=-2) + assert ( + max_layer is not None + ), "max_layer must not be None when using weighted representations." + features = ( + torch.stack( + [ + layer_result_i.transpose(0, 1) + for layer_result_i, _ in layer_results[: max_layer + 1] + ], + dim=-2, + ) + * repr_layer_weights + ) + features = features.sum(dim=-2) # BTC + + if self.downsample_conv is not None: + features = self.downsample_conv(features.transpose(1, 2)).transpose( + 1, 2 + ) # BTC + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.conformer_adapter: + # to handle incompatibility btw torch & huggingface + conformer_attn_mask = ~padding_mask + # run through conformer + features = self.conformer_adapter( + features, + attention_mask=conformer_attn_mask, + ).last_hidden_state + + if self.cross_embed_positions is not None: + features = features + self.cross_embed_positions(features) + + return features, padding_mask, restore_ids, kept_mask, patch_padding_mask + + +class TransformerEncoder(nn.Module): + """Transformer encoder.""" + + def __init__(self, config): + super().__init__() + + self.dropout = config.dropout + self.embedding_dim = config.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=config.conv_pos, + padding=config.conv_pos // 2, + groups=config.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (config.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential( + self.pos_conv, SamePad(config.conv_pos), nn.GELU() + ) + + if hasattr(config, "relative_position_embedding"): + self.relative_position_embedding = config.relative_position_embedding + self.num_buckets = config.num_buckets + self.max_distance = config.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=config.encoder_ffn_embed_dim, + num_attention_heads=config.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=config.attention_dropout, + activation_dropout=config.activation_dropout, + activation_fn=config.activation_fn, + layer_norm_first=config.layer_norm_first, + 
deep_norm=config.deep_norm, + has_relative_attention_bias=self.relative_position_embedding, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=config.gru_rel_pos, + encoder_layers=config.encoder_layers, + use_flash_attn=config.use_flash_attn, + ) + for i in range(config.encoder_layers) + ] + ) + if self.relative_position_embedding: + for i in range(1, config.encoder_layers): + del self.layers[i].self_attn.relative_attention_bias + self.layers[i].self_attn.relative_attention_bias = self.layers[ + 0 + ].self_attn.relative_attention_bias + + self.layer_norm_first = config.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = config.encoder_layerdrop + + self.apply(init_bert_params) + + if config.deep_norm: + logging.info("Deep Norm is applied.") + deep_norm_beta = math.pow(8 * config.encoder_layers, -1 / 4) + for i in range(config.encoder_layers): + nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1) + nn.init.xavier_normal_( + self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta + ) + nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1) + nn.init.xavier_normal_( + self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta + ) + nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta) + + self.layer_wise_gradient_decay_ratio = getattr( + config, "layer_wise_gradient_decay_ratio", 1 + ) + + def forward(self, x, padding_mask=None, layer=None): + """Forward pass.""" + x, layer_results = self.extract_features(x, padding_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, tgt_layer=None): + """Extract features from the input sequence.""" + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + if self.layer_wise_gradient_decay_ratio != 1.0: + x = GradMultiply.apply((x, self.layer_wise_gradient_decay_ratio)) + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer( + x, + self_attn_padding_mask=padding_mask, + need_weights=False, + pos_bias=pos_bias, + ) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """Transformer encoder layer.""" + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + deep_norm: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + encoder_layers: int = 0, + use_flash_attn: bool = False, + ) -> None: + + super().__init__() 
+ self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + use_flash_attn=use_flash_attn, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + self.final_layer_norm = LayerNorm(self.embedding_dim) + + self.deep_norm = deep_norm + if self.deep_norm: + self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4) + else: + self.deep_norm_alpha = 1 + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None, + ): + """Forward pass.""" + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias, + ) + + x = self.dropout1(x) + x = residual * self.deep_norm_alpha + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual * self.deep_norm_alpha + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. 
+ """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + use_flash_attn=False, + ): + super().__init__() + self.use_flash_attn = use_flash_attn + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.initialize() + + def initialize(self): + """Initiate parameters in the transformer model.""" + # logging.info("Initiate parameters in the MultiheadAttention module.") + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = 
self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets = ( + relative_buckets + (relative_positions > 0).to(torch.long) * num_buckets + ) + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min( + relative_positions, torch.zeros_like(relative_positions) + ) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.to(torch.get_default_dtype()) / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, + torch.full_like(relative_postion_if_large, num_buckets - 1), + ) + + relative_buckets = relative_buckets + torch.where( + is_small, relative_positions, relative_postion_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute relative position bias.""" + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = memory_position - context_position # [i,j] = j-i + relative_position_bucket = self._relative_positions_bucket( + relative_position, bidirectional=True + ) # [qlen,klen] + # relative_position_bucket = relative_position_bucket.to( + # self.relative_attention_bias.weight.device + # ) + # [qlen,klen,head_dim] + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) # [head_dim,qlen,klen] + return values + + def forward( + self, + query, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + key_padding_mask: Optional[torch.Tensor] = None, + incremental_state: Optional[ + Dict[str, Dict[str, Optional[torch.Tensor]]] + ] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[torch.Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. 
+ """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + # headdim,qlen,klen + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = ( + position_bias.unsqueeze(0) + .repeat(bsz, 1, 1, 1) + .contiguous() + .view(bsz * self.num_heads, tgt_len, src_len) + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + alpha = 32 + if not self.use_flash_attn: + q *= self.scaling + q *= 1 / alpha + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape + # (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.contiguous().view( + bsz * self.num_heads, -1, self.head_dim + ) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.contiguous().view( + bsz * self.num_heads, -1, self.head_dim + ) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[torch.Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + 
static_kv=static_kv, + ) + + saved_state["prev_key"] = k.contiguous().view( + bsz, self.num_heads, -1, self.head_dim + ) + saved_state["prev_value"] = v.contiguous().view( + bsz, self.num_heads, -1, self.head_dim + ) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len = src_len + 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + # ----- Switch: use flash-attn attention if enabled ------------------------- + if self.use_flash_attn: + assert not before_softmax, "Flash attention does not support before_softmax" + assert ( + not need_weights + ), "Flash attention does not support returning attention weights" + # NOTE(shikhar): Pytorch < 2.5.0 has different shape requirements for qkv + if is_torch_v25_to_v26: + # NOTE(shikhar): The contiguous() call can be optimized. 
+ q = q.contiguous().view(bsz, self.num_heads, tgt_len, self.q_head_dim) + k = k.contiguous().view(bsz, self.num_heads, src_len, self.k_head_dim) + v = v.contiguous().view(bsz, self.num_heads, src_len, self.head_dim) + # NOTE(shikhar): Do causal attention via attn_mask + if key_padding_mask is not None: + assert ( + attn_mask is None + ), "key_padding_mask not supported with attn_mask" + attn_mask = key_padding_mask == 0 # B x keylen(srclen) + attn_mask = attn_mask.unsqueeze(1).expand( + -1, tgt_len, -1 + ) # B x tgtlen x srclen + + if attn_mask is not None: + if is_torch_v25_to_v26: + attn_mask = attn_mask.unsqueeze(1) # B x 1 x tgtlen x srclen + else: + attn_mask = attn_mask.unsqueeze(1).expand( + -1, self.num_heads, -1, -1 + ) + attn_mask = attn_mask.contiguous().view( + bsz * self.num_heads, tgt_len, src_len + ) + + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos == 1: + query_layer = q.contiguous().view( + bsz, self.num_heads, tgt_len, self.q_head_dim + ) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid( + self.grep_linear(query_layer) + .contiguous() + .view(_B, _H, _L, 2, 4) + .sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = ( + gate_a_1.contiguous().view(bsz * self.num_heads, tgt_len, 1) + * position_bias + ) + attn_mask_rel_pos = ( + attn_mask_rel_pos.contiguous() + .view(bsz * self.num_heads, tgt_len, src_len) + .contiguous() + ) + + if is_torch_v25_to_v26: + attn_mask_rel_pos = attn_mask_rel_pos.unsqueeze(1) + attn_mask_rel_pos = attn_mask_rel_pos.view( + bsz, self.num_heads, tgt_len, src_len + ).contiguous() + + if attn_mask is not None: + attn_mask_ = torch.zeros_like(attn_mask) + attn_mask_rel_pos + attn_mask_ = attn_mask_.masked_fill( + attn_mask.logical_not(), float("-inf") + ) + attn_mask = attn_mask_ + else: + attn_mask = attn_mask_rel_pos + + attn = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=attn_mask, + dropout_p=self.dropout_module.p, + is_causal=False, + ) # B x H x T x D + if is_torch_v25_to_v26: + attn = attn.permute(2, 0, 1, 3).reshape(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + # attention @ value, attn_weights, position_bias + return attn, None, None + # ------------------------------------------------------------------------------ + + # Original BEATs implementation below: + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = ( + attn_weights - attn_weights.max(dim=-1, keepdim=True)[0] + ) * alpha + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights = attn_weights + attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.contiguous().view( + bsz, self.num_heads, tgt_len, src_len + ) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.contiguous().view( + bsz * self.num_heads, tgt_len, src_len + ) + + if before_softmax: + return attn_weights, v, position_bias + + if 
position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos == 1: + query_layer = ( + q.contiguous().view(bsz, self.num_heads, tgt_len, self.q_head_dim) + * alpha + / self.scaling + ) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid( + self.grep_linear(query_layer) + .contiguous() + .view(_B, _H, _L, 2, 4) + .sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = ( + gate_a_1.contiguous().view(bsz * self.num_heads, tgt_len, 1) + * position_bias + ) + + attn_mask_rel_pos = attn_mask_rel_pos.contiguous().view(attn_weights.size()) + + attn_weights = attn_weights + attn_mask_rel_pos + + attn_weights_float = F.softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[torch.Tensor] = None + if need_weights: + attn_weights = ( + attn_weights_float.contiguous() + .view(bsz, self.num_heads, tgt_len, src_len) + .transpose(1, 0) + ) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[torch.Tensor], + prev_key_padding_mask: Optional[torch.Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[torch.Tensor]: + # NOTE(shikhar): Deepspeed trainer might not work with this + # code path due to the float calls + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] + ) -> Dict[str, Optional[torch.Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[torch.Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], + buffer: Dict[str, Optional[torch.Tensor]], + 
): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + """No op""" + return attn_weights + + +def init_bert_params(module): + """Initialize the weights specific to the BERT Model. + + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + # logging.info("Intializing Linear Layer") + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + # logging.info("Intializing Embedding Layer") + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + # logging.info("Intializing Multihead Attention") + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +class GradMultiply(torch.autograd.Function): + """A gradient modification function that scales the gradient by a fixed scalar""" + + @staticmethod + def forward(ctx, i): + """Forward pass""" + x, scale = i + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + """Backward pass""" + return grad * ctx.scale + + +class SamePad(nn.Module): + """Change input tensor shape according to the kernel size and type of LM""" + + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + """Forward pass""" + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish activation function""" + + def __init__(self): + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + """Forward pass""" + return x * self.act(x) + + +class GLU_Linear(nn.Module): + """GLU Linear layer""" + + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + """Forward pass""" + # to be consistent with GLU_Linear, we assume the input always has the + # #channel (#dim) in the last dimension of the tensor, so need to + # switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = ( + x[:, :, 0 
: self.output_dim] + * x[:, :, self.output_dim : self.output_dim * 2] + ) + else: + x = x[:, :, 0 : self.output_dim] * self.glu_act( + x[:, :, self.output_dim : self.output_dim * 2] + ) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return F.gelu(x).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate") + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError(f"--activation-fn {activation} not supported") + + +def quant_noise(module, p, block_size): + """Wraps modules and applies quantization noise to the weights for + + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.repeat_interleave(block_size, -1) + .contiguous() + .view(-1, in_features) + ) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) 
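+                    # NOTE: expand each block-level drop decision so that it
+                    # covers every weight inside that block (1x1 conv case).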
+ mask = ( + mask.repeat_interleave(block_size, -1) + .contiguous() + .view(-1, in_channels) + ) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module diff --git a/versa/models/openbeats/utils.py b/versa/models/openbeats/utils.py new file mode 100644 index 0000000..5f8064e --- /dev/null +++ b/versa/models/openbeats/utils.py @@ -0,0 +1,287 @@ +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio.compliance.kaldi as ta_kaldi +from einops import rearrange, repeat +from typing import Optional +from espnet.nets.pytorch_backend.nets_utils import ( + _make_pad_mask_traceable, + _make_pad_mask, +) + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def norm_ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(l2norm(moving_avg.data)) + + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): + dim, dtype, device = samples.shape[-1], samples.dtype, samples.device + logging.info( + f"Running K-means with {num_clusters} clusters, {num_iters} iterations, and cosine similarity: {use_cosine_sim}" + ) + + means = sample_vectors(samples, num_clusters) + logging.info(f"Init means!") + + for _ in range(num_iters): + logging.info(f"Running iteration {_ + 1}...") + if use_cosine_sim: + # Assumes samples are normalized + dists = samples @ means.t() + else: + diffs = rearrange(samples, "n d -> n () d") - rearrange( + means, "c d -> () c d" + ) + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +@torch.no_grad() +def beats_frontend( + source: torch.Tensor, + fbank_mean: float, + fbank_std: float, +) -> torch.Tensor: + """Preprocess raw audio.""" + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2**15 # float32 to int16 + fbank = ta_kaldi.fbank( + waveform, + num_mel_bins=128, + sample_frequency=16000, + frame_length=25, + frame_shift=10, + ) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + +@torch.no_grad() +def forward_padding_mask_conv( + padding_mask: torch.Tensor, + n_dim: int, + conv_module: nn.Module, +): + """Forward padding mask. 
+ To be applied after features are passed through conv module or after converting to spectrogram, + for consistency. + padding_mask: BT + n_dim: number of dimensions before the transformation was applied to features. + When applying after fbank computation set this to a non-positive value. + conv_module: conv module applied to features, + the channel dimension must be 1. + """ + assert padding_mask.dim() == 2 + if n_dim >= 1: + padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, n_dim) # btn + padding_mask = padding_mask.unsqueeze(1) # b1tn or b1t, depending on n_dim + dtype_ = next(conv_module.parameters()).dtype + padding_mask = conv_module(padding_mask.to(dtype_)) + padding_mask = padding_mask != 0 + padding_mask = ( + padding_mask.view(padding_mask.shape[0], padding_mask.shape[1], -1) + .squeeze(-2) + .contiguous() + ) + return padding_mask + + +def freeze_conv_module(conv_module: nn.Module): + # Fix patch embedding for padding + conv_module.weight.data.fill_(1) + conv_module.weight.requires_grad = False + if conv_module.bias is not None: + conv_module.bias.data.fill_(0) + conv_module.bias.requires_grad = False + + +def roll_tensor( + x: torch.Tensor, + lengths: torch.Tensor, + roll_amounts: Optional[torch.Tensor] = None, + fixed_intervals: Optional[int] = None, +) -> torch.Tensor: + """Left-roll tensor x by roll_amounts, only within lengths and optionally quantized. + + Args: + x: input tensor (B, T, D) + lengths: lengths of each sequence (B,) + roll_amounts: random shift amounts (B,). If None, random shift + amounts are generated. + fixed_intervals: if not None, roll_amounts are quantized to + multiples of this. + Returns: + rolled_x: rolled tensor (B, T, D) + Useful to apply roll augmentation to the input, while considering + the input length for each sample. + """ + B, T, D = x.shape + + indices = torch.arange(T).unsqueeze(0).expand(B, T).to(x.device) # (B, T) + lengths = lengths.unsqueeze(1) # (B, 1) + + if roll_amounts is None: + roll_amounts = torch.randint(0, lengths.max(), (B,), device=x.device) + if fixed_intervals is not None: + roll_amounts = (roll_amounts // fixed_intervals) * fixed_intervals + roll_indices = (indices - roll_amounts.unsqueeze(1)) % lengths # (B, T) + roll_indices = roll_indices.unsqueeze(2).expand(-1, -1, D) # (B, T, D) + + mask = indices < lengths # (B, T), True if position is valid + rolled_x = torch.empty_like(x) + rolled_x[mask] = x.gather(1, roll_indices)[mask] + rolled_x[~mask] = x[~mask] + return rolled_x + + +def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None, traceable=True): + """Make mask tensor containing indices of padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + traceable (bool, optional): If True, use a traceable implementation. + Traceable operations can be costly since they construct a + maxlen X maxlen triangular mask. + + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + + With the reference tensor. 
+ + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. + + >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + """ + if length_dim == 0: + raise ValueError("length_dim cannot be 0: {}".format(length_dim)) + + # If the input dimension is 2 or 3, + # then we use ESPnet-ONNX based implementation for tracable modeling. + # otherwise we use the traditional implementation for research use. + if isinstance(lengths, list): + logging.warning( + "Using make_pad_mask with a list of lengths is not tracable. " + + "If you try to trace this function with type(lengths) == list, " + + "please change the type of lengths to torch.LongTensor." 
+ ) + + if ( + (xs is None or xs.dim() in (2, 3)) + and length_dim <= 2 + and (not isinstance(lengths, list) and lengths.dim() == 1) + and traceable + ): + return _make_pad_mask_traceable(lengths, xs, length_dim, maxlen) + else: + return _make_pad_mask(lengths, xs, length_dim, maxlen) diff --git a/versa/scorer_shared.py b/versa/scorer_shared.py index 9cc632a..240f0ca 100644 --- a/versa/scorer_shared.py +++ b/versa/scorer_shared.py @@ -11,6 +11,7 @@ import yaml from tqdm import tqdm + from versa.metrics import STR_METRIC, NUM_METRIC from versa.utils_shared import ( check_all_same, @@ -823,6 +824,75 @@ def load_score_modules(score_config, use_gt=True, use_gt_text=False, use_gpu=Fal logging.info( "Initiate qwen2 audio metric: {} successfully".format(config["name"]) ) + + elif config["name"] == "openbeats_class_prediction": + from versa import openbeats_setup, openbeats_class_prediction + + assert ( + "model_path" in config + ), "model_path is required for openbeats_class_prediction" + logging.info( + f"Loading OpenBEATs class prediction using checkpoint {config.get('model_path')}" + ) + model = openbeats_setup( + model_path=config.get("model_path", None), + use_gpu=use_gpu, + ) + + score_modules["openbeats_class_prediction"] = { + "module": openbeats_class_prediction, + "model": model, + } + logging.info("Initialized OpenBEATs class prediction successfully.") + + elif config["name"] == "openbeats_embedding_extraction": + from versa import openbeats_setup, openbeats_embedding_extraction + + assert ( + "model_path" in config + ), "model_path is required for openbeats_embedding_extraction" + assert ( + "embedding_output_file" in config + ), "embedding_output_file is required for openbeats_embedding_extraction" + logging.info( + f"Loading OpenBEATs embedding extraction using checkpoint {config.get('model_path')}. 
" + f"Embedding will be saved to {config.get('embedding_output_file')}" + ) + model = openbeats_setup( + model_path=config.get("model_path", None), + use_gpu=use_gpu, + ) + + score_modules["openbeats_embedding_extraction"] = { + "module": openbeats_embedding_extraction, + "model": model, + "embedding_output_file": config.get("embedding_output_file", None), + } + logging.info("Initialized OpenBEATs embedding extraction successfully.") + elif config["name"] == "openbeats_embedding_similarity": + if not use_gt: + logging.warning( + "Cannot use openbeats_embedding_similarity because no gt audio is provided" + ) + continue + from versa import openbeats_setup, openbeats_embedding_similarity + + assert ( + "model_path" in config + ), "model_path is required for openbeats_embedding_similarity" + logging.info( + f"Loading OpenBEATs embedding similarity using checkpoint {config.get('model_path')}" + ) + model = openbeats_setup( + model_path=config.get("model_path", None), + use_gpu=use_gpu, + ) + + score_modules["openbeats_embedding_similarity"] = { + "module": openbeats_embedding_similarity, + "model": model, + } + logging.info("Initialized OpenBEATs embedding similarity successfully.") return score_modules @@ -1013,6 +1083,26 @@ def use_score_modules(score_modules, gen_wav, gt_wav, gen_sr, text=None): gen_sr, custom_prompt=score_modules[key]["prompt"], ) + elif key == "openbeats_class_prediction": + score = score_modules[key]["module"]( + score_modules[key]["model"], + gen_wav, + gen_sr, + ) + elif key == "openbeats_embedding_extraction": + score = score_modules[key]["module"]( + score_modules[key]["model"], + gen_wav, + gen_sr, + embedding_output_file=score_modules[key]["embedding_output_file"], + ) + elif key == "openbeats_embedding_similarity": + # NOTE(shikhar): using gt_wav as reference audio + score = score_modules[key]["module"]( + score_modules[key]["model"], + (gen_wav, gen_sr), + (gt_wav, gen_sr), + ) else: raise NotImplementedError(f"Not supported {key}") diff --git a/versa/utterance_metrics/openbeats.py b/versa/utterance_metrics/openbeats.py new file mode 100644 index 0000000..b5f9ab5 --- /dev/null +++ b/versa/utterance_metrics/openbeats.py @@ -0,0 +1,372 @@ +"""Script to run inference with OpenBEATs models. + +This script allows prediction of sound event/class, +and extraction of embeddings from +any layer using OpenBEATs models. + +Usage: +1. To predict sound events/classes: + python openbeats.py --model_path \ + --audio_path + +2. To extract embeddings: + python openbeats.py --model_path \ + --audio_path --extract_embeddings + +3. To compute similarity between two audio files: + python openbeats.py --model_path \ + --audio_path --compute_similarity \ + --reference_audio_path --output_dir + +Note: For predicting class, please use a fine-tuned checkpoint +with the decoder. + +Available checkpoints: +TODO(shikhar): Add list of checkpoints + +""" + +import torch +import os +import numpy as np +import librosa +from sklearn.metrics.pairwise import cosine_similarity + +# TODO(shikhar): after OpenBEATs is merged into ESPnet: +# 1. Switch the import to espnet. +# 2. Remove the model folder in versa +# PR: https://github.com/espnet/espnet/pull/6052 +from versa.models.openbeats.encoder import BeatsEncoder +from versa.models.openbeats.decoder import LinearDecoder + +# All checkpoints represent the best performing OpenBEATs +# checkpoint from TODO(shikhar): arxiv link1, link2. 
+OPENBEATS_CHECKPOINTS = { + # Pretrained models + "PT_MD_LARGE_iter1": "", # Pretrained on multi-domain audio (bioacoustics included) + "PT_MD_LARGE_iter2": "", # Pretrained on multi-domain audio (bioacoustics included) + "PT_MD_LARGE_iter3": "", # Pretrained on multi-domain audio (bioacoustics included) + "PT_AUDIO_LARGE_40K": "", # Pretrained on 40k hours of audio + "PT_SPEECH_AUDIO_LARGE_75K": "", # Pretrained on 75k hours of speech + audio + # Finetuned models (on top of PT_MD_LARGE) + "FT_MD_LARGE_AS20K": "", # Fine-tuned on AudioSet-20K from PT_MD_LARGE + "FT_MD_LARGE_AS2M": "", # Fine-tuned on AudioSet-2M from PT_MD_LARGE + "FT_MD_LARGE_FSD50K": "", # Fine-tuned on FSD50K from PT_MD_LARGE + "FT_MD_LARGE_ESC": "", # Fine-tuned on ESC-50 from PT_MD_LARGE + # Fine-tuned on BEANS bioacoustics datasets + "FT_MD_LARGE_WATKINS": "", + "FT_MD_LARGE_CBI": "", + "FT_MD_LARGE_HUMBUGDB": "", + "FT_MD_LARGE_DOGS": "", + "FT_MD_LARGE_BATS": "", + "FT_MD_LARGE_DCASE21": "", + "FT_MD_LARGE_RFCX": "", + "FT_MD_LARGE_GIBBONS": "", + "FT_MD_LARGE_HICEAS": "", + "FT_MD_LARGE_ENABIRDS": "", + # Fine-tuned on music-related datasets + "FT_MD_LARGE_GTZAN_GENRE": "", + "FT_MD_LARGE_NSYNTH_I": "", + "FT_MD_LARGE_NSYNTH_P": "", +} + + +def validate_input_arguments(args): + """Validate input arguments. + + Args: + args: Parsed command line arguments. + + Raises: + ValueError: If any required arguments are invalid or missing. + FileNotFoundError: If required files don't exist. + """ + if not os.path.exists(args.model_path): + raise FileNotFoundError(f"Model checkpoint not found: {args.model_path}") + + if not os.path.exists(args.audio_path): + raise FileNotFoundError(f"Audio file not found: {args.audio_path}") + + if args.compute_similarity: + if args.reference_audio_path is None: + raise ValueError( + "Reference audio path must be specified when computing similarity." + ) + if not os.path.exists(args.reference_audio_path): + raise FileNotFoundError( + f"Reference audio file not found: {args.reference_audio_path}" + ) + + if args.extract_embeddings and args.output_dir is None: + raise ValueError( + "Output directory must be specified when extracting embeddings." + ) + + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + +def _get_ckpt_components(model_path, checkpoint_type=None): + """Separate the components of the checkpoint. + Args: + model_path: Path or URL to the model checkpoint. + checkpoint_type: Type of checkpoint, either 'pretrained' or 'finetuned'. + Detect automatically if not specified. + Returns: + Tuple containing: + - encoder_state_dict: State dict for the encoder. + - decoder_state_dict: State dict for the decoder (if exists). + - cfg: ESPnet configuration of the BEATs model. + - token_list: List of tokens (if exists). 
+ """ + model_url = None + if model_path.startswith("http"): + model_url = model_path + elif model_path in OPENBEATS_CHECKPOINTS: + model_url = OPENBEATS_CHECKPOINTS[model_path] + if model_url: + checkpoint = torch.hub.load_state_dict_from_url(model_url, map_location="cpu") + else: + checkpoint = torch.load(model_path, map_location="cpu") + if checkpoint_type: + is_pretrained = checkpoint_type == "pretrained" + else: + is_pretrained = ( + len(checkpoint.get("token_list", [])) == 0 + ) # because pretrained checkpoints do not have token_list + if is_pretrained: + # Pre-trained checkpoint BEATs style + encoder_state_dict = checkpoint["model"] + decoder_state_dict = None + cfg = checkpoint["cfg"] + token_list = None + else: + # Fine-tuned checkpoint ESPnet style + encoder_state_dict = { + k[len("encoder.") :]: v + for k, v in checkpoint["model"].items() + if k.startswith("encoder.") + } + decoder_state_dict = { + k[len("decoder.") :]: v + for k, v in checkpoint["model"].items() + if k.startswith("decoder.") + } + cfg = checkpoint["cfg"] + token_list = checkpoint["token_list"][:-2] # Exclude and tokens + return encoder_state_dict, decoder_state_dict, cfg, token_list + + +def openbeats_setup( + model_path, + use_gpu=False, +): + device = "cuda" if use_gpu else "cpu" + encoder_state_dict, decoder_state_dict, cfg, token_list = _get_ckpt_components( + model_path + ) + encoder = BeatsEncoder( + input_size=1, + beats_config=cfg, + ) + encoder.load_state_dict(encoder_state_dict, strict=False) + encoder.to(device) + encoder.eval() + if decoder_state_dict is None: + return encoder + + decoder = LinearDecoder( + vocab_size=len(token_list), + encoder_output_size=encoder.output_size(), + ) + decoder.load_state_dict(decoder_state_dict) + decoder.to(device) + decoder.eval() + return (encoder, decoder, token_list) + + +def _prepare_openbeats_input(audio, fs, model): + """Prepare input audio for model inference. + + Args: + audio: Input audio. + fs: Sampling rate of the input audio. + + Returns: + Tensor containing the audio data. + """ + if fs != 16000: + import librosa + + audio = librosa.resample(audio, orig_sr=fs, target_sr=16000) + if audio.ndim == 1: + audio = np.expand_dims(audio, axis=0) + device = next(model.parameters()).device + audio = torch.tensor(audio, dtype=torch.float32).to(device) + return audio + + +def openbeats_class_prediction(model, x, fs): + """Predict sound events/classes from audio. + + Args: + model: OpenBEATs encoder, decoder and token_list. + x: Input audio. + fs: Sampling rate of the input audio. + + Returns: + Dictionary with predicted classes and their probabilities. + """ + assert isinstance( + model, tuple + ), "Model should be a tuple of encoder, decoder and token_list." + encoder, decoder, token_list = model + audio = _prepare_openbeats_input(x, fs, encoder) + with torch.no_grad(): + ilens = torch.full( + (audio.size(0),), audio.size(1), dtype=torch.long, device=audio.device + ) + embedding, hlens, _ = encoder(xs_pad=audio, ilens=ilens, waveform_input=True) + probs = decoder(hs_pad=embedding, hlens=hlens) + class_probs = {k: v for k, v in zip(token_list, probs.to("cpu").numpy()[0])} + + return {"class_probabilities": class_probs} + + +def openbeats_embedding_extraction(model, x, fs, embedding_output_file=None): + """Extract embeddings from audio. + + Args: + model: OpenBEATs model. + x: Input audio. + fs: Sampling rate of the input audio. + embedding_output_file: Path to save the extracted embeddings. 
+ + Returns: + Dictionary where value is the numpy file containing extracted embedding. + """ + assert embedding_output_file is not None, "Output file path must be specified." + audio = _prepare_openbeats_input(x, fs, model) + with torch.no_grad(): + ilens = torch.full( + (audio.size(0),), audio.size(1), dtype=torch.long, device=audio.device + ) + embedding, _, _ = model(xs_pad=audio, ilens=ilens, waveform_input=True) + + if not os.path.exists(embedding_output_file): + os.makedirs(os.path.dirname(embedding_output_file), exist_ok=True) + with open(embedding_output_file, "wb") as f: + np.save(f, embedding.to("cpu").numpy()) + + return {"embedding_file": embedding_output_file} + + +def openbeats_embedding_similarity(model, x, ref_x): + """Compute embedding similarity between input and reference audio. + + Args: + model: OpenBEATs model. + x: Input audio and sampling rate. + ref_x: Reference audio and sampling rate. + + Returns: + Dictionary where value is the similarity score between the embeddings. + """ + audio1, fs1 = x + audio2, fs2 = ref_x + + audio1_prepared = _prepare_openbeats_input(audio1, fs1, model) + audio2_prepared = _prepare_openbeats_input(audio2, fs2, model) + + with torch.no_grad(): + ilens1 = torch.full( + (audio1_prepared.size(0),), + audio1_prepared.size(1), + dtype=torch.long, + device=audio1_prepared.device, + ) + embedding1, _, _ = model( + xs_pad=audio1_prepared, ilens=ilens1, waveform_input=True + ) + ilens2 = torch.full( + (audio2_prepared.size(0),), + audio2_prepared.size(1), + dtype=torch.long, + device=audio2_prepared.device, + ) + embedding2, _, _ = model( + xs_pad=audio2_prepared, ilens=ilens2, waveform_input=True + ) + + embedding1_np = embedding1.to("cpu").numpy() + embedding2_np = embedding2.to("cpu").numpy() + + # (batch_size, features) + embedding1_flat = embedding1_np.reshape(embedding1_np.shape[0], -1) + embedding2_flat = embedding2_np.reshape(embedding2_np.shape[0], -1) + similarity_score = cosine_similarity(embedding1_flat, embedding2_flat)[0, 0] + return { + "similarity_score": similarity_score, + } + + +if __name__ == "__main__": + import argparse + import librosa + + parser = argparse.ArgumentParser(description="Run OpenBEATs inference.") + parser.add_argument("--model_path", required=True, help="Path to model checkpoint") + parser.add_argument("--audio_path", required=True, help="Path to input audio file") + parser.add_argument( + "--extract_embeddings", + action="store_true", + help="Extract embeddings instead of predicting classes", + ) + parser.add_argument( + "--compute_similarity", + action="store_true", + help="Compute similarity between input and reference audio embeddings", + ) + parser.add_argument( + "--reference_audio_path", + type=str, + default=None, + help="Path to reference audio file for embedding similarity", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory path to save output embeddings as npy files", + ) + parser.add_argument("--use_gpu", action="store_true", help="Use GPU if available") + args = parser.parse_args() + validate_input_arguments(args) + + # Load model + model = openbeats_setup( + model_path=args.model_path, + use_gpu=args.use_gpu, + ) + + # Load audio + audio, fs = librosa.load(args.audio_path, sr=None, mono=False) + + # Infer + if args.compute_similarity: + ref_audio, ref_fs = librosa.load(args.reference_audio_path, sr=None, mono=False) + similarity_result = openbeats_embedding_similarity( + model, (audio, fs), (ref_audio, ref_fs) + ) + print("Embedding Similarity Result: ", 
similarity_result)
+    elif args.extract_embeddings:
+        embedding_output_file = os.path.join(
+            args.output_dir, os.path.basename(args.audio_path) + "_embedding.npy"
+        )
+        embedding_result = openbeats_embedding_extraction(
+            model, audio, fs, embedding_output_file=embedding_output_file
+        )
+        print("Embedding Extraction Result: ", embedding_result)
+    else:
+        prediction = openbeats_class_prediction(model, audio, fs)
+        print("Class Prediction Result: ", prediction)
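+
+
+# ------------------------------------------------------------------------------
+# Hedged usage sketches (illustrative only). Checkpoint and audio paths below
+# are hypothetical placeholders, not released checkpoints.
+#
+# CLI (flags defined by the argparse block above):
+#
+#   python versa/utterance_metrics/openbeats.py \
+#       --model_path /path/to/openbeats_finetuned.ckpt \
+#       --audio_path /path/to/input.wav
+#
+#   python versa/utterance_metrics/openbeats.py \
+#       --model_path /path/to/openbeats_pretrained.pt \
+#       --audio_path /path/to/input.wav \
+#       --extract_embeddings --output_dir /path/to/embedding_dir
+#
+# Python API (a pretrained checkpoint makes openbeats_setup return only the
+# encoder, which is what the embedding metrics expect):
+#
+#   import librosa
+#   from versa import (
+#       openbeats_setup,
+#       openbeats_embedding_extraction,
+#       openbeats_embedding_similarity,
+#   )
+#
+#   model = openbeats_setup(model_path="/path/to/openbeats_pretrained.pt")
+#   gen, gen_sr = librosa.load("gen.wav", sr=None)
+#   ref, ref_sr = librosa.load("ref.wav", sr=None)
+#
+#   emb = openbeats_embedding_extraction(
+#       model, gen, gen_sr, embedding_output_file="/tmp/gen_embedding.npy"
+#   )
+#   sim = openbeats_embedding_similarity(model, (gen, gen_sr), (ref, ref_sr))
+#   print(emb["embedding_file"], sim["similarity_score"])
+# ------------------------------------------------------------------------------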