diff --git a/examples/huggingface/reverb_config.py b/examples/huggingface/reverb_config.py new file mode 100644 index 0000000..f7e1553 --- /dev/null +++ b/examples/huggingface/reverb_config.py @@ -0,0 +1,90 @@ +# Following https://huggingface.co/docs/transformers/en/custom_models +import math +from typing import Dict, List, Optional +from transformers import PretrainedConfig +import numpy as np +import yaml +from pyctcdecode import build_ctcdecoder + + +def cmvn(means: List[float], variance: List[float], count: int): + """ Calculate cmvn from stats + + Returns: + a numpy array of [means, vars] + """ + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + return [means, variance] + + +class ReverbConfig(PretrainedConfig): + model_type = "reverb_asr" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Set default special tokens if not provided + if not hasattr(self, 'special_tokens'): + self.special_tokens = { + "": 0, + "": 2, + "": 2, + "": 1, + } + + # Calculate CMVN if the required stats are provided + if hasattr(self, 'cmvn_mean_stat') and hasattr(self, 'cmvn_var_stat') and hasattr(self, 'cmvn_frame_num'): + self.cmvn_mean, self.cmvn_istd = cmvn( + self.cmvn_mean_stat, + self.cmvn_var_stat, + self.cmvn_frame_num + ) + + # Set default ratio if not provided + if not hasattr(self, 'inputs_to_logits_ratio'): + self.inputs_to_logits_ratio = 1 + + # Tokenizer configuration + if not hasattr(self, 'tokenizer_path'): + self.tokenizer_path = "path/to/tokenizer.model" + if not hasattr(self, 'units_path'): + self.units_path = "path/to/units.txt" + if not hasattr(self, 'decoder_beam_width'): + self.decoder_beam_width = 8 + if not hasattr(self, 'decoder_token_min_logp'): + self.decoder_token_min_logp = -10 + + # Load units and build decoder + self._load_units_and_build_decoder() + + def _load_units_and_build_decoder(self): + """Load units from file and build the CTC decoder.""" + decoder_ids = [] + with open(self.units_path, 'r') as units_file: + for line in units_file: + token = line.split()[0] + if len(token) == 0: + continue + if token == '': + token = '' + decoder_ids.append(token) + self.decoder = build_ctcdecoder(decoder_ids) + + @classmethod + def from_yaml_file(cls, yaml_file: str) -> "ReverbConfig": + """Load a ReverbConfig from a YAML file. + + Args: + yaml_file: Path to the YAML file containing the configuration + + Returns: + A ReverbConfig instance loaded from the file + """ + with open(yaml_file, 'r') as f: + config_dict = yaml.safe_load(f) + return cls(**config_dict) \ No newline at end of file diff --git a/examples/huggingface/reverb_config.yaml b/examples/huggingface/reverb_config.yaml new file mode 100644 index 0000000..6003377 --- /dev/null +++ b/examples/huggingface/reverb_config.yaml @@ -0,0 +1,48 @@ +model_type: reverb_asr +input_dim: 80 +output_dim: 10001 +cmvn_mean_stat: [33596438528.0, 35418329088.0, 39182106624.0, 41983324160.0, 44419112960.0, 46015381504.0, 46934564864.0, 47058870272.0, 47288012800.0, 47522979840.0, 48491438080.0, 49308729344.0, 50230493184.0, 50796900352.0, 51020386304.0, 51297456128.0, 51333586944.0, 51126181888.0, 51455569920.0, 50636410880.0, 49947033600.0, 50365546496.0, 49383075840.0, 49540546560.0, 49066065920.0, 49236889600.0, 48820707328.0, 49071112192.0, 48968024064.0, 49024458752.0, 49202397184.0, 49374433280.0, 49620660224.0, 49947111424.0, 50326310912.0, 50717818880.0, 51046891520.0, 51345678336.0, 51655733248.0, 51505459200.0, 51813666816.0, 51577262080.0, 51776524288.0, 51754237952.0, 51918598144.0, 52158758912.0, 52405276672.0, 52596776960.0, 52639731712.0, 52631220224.0, 52443103232.0, 52315619328.0, 52219695104.0, 52178399232.0, 52083040256.0, 52064792576.0, 51980918784.0, 51824164864.0, 51550973952.0, 51002216448.0, 50422747136.0, 49847754752.0, 49474338816.0, 48997863424.0, 48617009152.0, 48309174272.0, 48084140032.0, 48095608832.0, 47965765632.0, 47909335040.0, 47780065280.0, 47762370560.0, 47757099008.0, 47731314688.0, 47574110208.0, 47336361984.0, 47009054720.0, 46283513856.0, 44821860352.0, 42771775488.0] +cmvn_var_stat: [360475131904.0, 401487724544.0, 484368646144.0, 548414357504.0, 608912080896.0, 651613241344.0, 678013698048.0, 683624693760.0, 689524047872.0, 695375822848.0, 722376851456.0, 746773872640.0, 774244204544.0, 791678353408.0, 798920015872.0, 807307444224.0, 808713453568.0, 802957754368.0, 812319899648.0, 788076953600.0, 767619497984.0, 777970712576.0, 748566544384.0, 751065628672.0, 736340869120.0, 739872473088.0, 727466704896.0, 734006083584.0, 731017904128.0, 732582576128.0, 737590444032.0, 742469861376.0, 749455671296.0, 758746972160.0, 769666121728.0, 781107331072.0, 790730506240.0, 799342002176.0, 808164917248.0, 803454713856.0, 812040585216.0, 804632395776.0, 809866821632.0, 808861499392.0, 813548044288.0, 820701954048.0, 828343779328.0, 834335604736.0, 835754590208.0, 835251011584.0, 829192929280.0, 824705744896.0, 821224734720.0, 819399753728.0, 816182853632.0, 815243788288.0, 812578177024.0, 807846281216.0, 799796035584.0, 784661544960.0, 770915631104.0, 756696285184.0, 746462183424.0, 734193254400.0, 724980072448.0, 717529612288.0, 711156563968.0, 710358204416.0, 706386919424.0, 704228884480.0, 700537110528.0, 699519008768.0, 699025129472.0, 698035535872.0, 693109391360.0, 686047887360.0, 676213948416.0, 655917645824.0, 616676458496.0, 563932168192.0] +cmvn_frame_num: 3519342927 +encoder: conformer +encoder_activation_type: swish +encoder_attention_dropout_rate: 0.1 +encoder_attention_heads: 8 +encoder_causal: true +encoder_cnn_module_kernel: 31 +encoder_cnn_module_norm: layer_norm +encoder_dropout_rate: 0.1 +encoder_input_layer: conv2d +encoder_linear_units: 2048 +encoder_normalize_before: true +encoder_num_blocks: 18 +encoder_num_langs: 2 +encoder_output_size: 640 +encoder_pos_enc_layer_type: rel_pos +encoder_positional_dropout_rate: 0.1 +encoder_selfattention_layer_type: rel_selfattn +encoder_use_cnn_module: true +encoder_use_dynamic_chunk: true +decoder: lslbitransformer +decoder_attention_heads: 8 +decoder_dropout_rate: 0.1 +decoder_linear_units: 2048 +decoder_num_blocks: 6 +decoder_num_langs: 2 +decoder_positional_dropout_rate: 0.1 +decoder_r_num_blocks: 6 +decoder_self_attention_dropout_rate: 0.1 +decoder_src_attention_dropout_rate: 0.1 +ctc_blank_id: 0 +ctc_weight: 0.3 +lsm_weight: 0.1 +reverse_weight: 0.3 +special_tokens: + "": 0 + "": 2 + "": 2 + "": 1 +tokenizer_path: "hf-reverb/tk.model" +units_path: "hf-reverb/tk.units.txt" +decoder_beam_width: 8 +decoder_token_min_logp: -10 \ No newline at end of file diff --git a/examples/huggingface/reverb_hf.py b/examples/huggingface/reverb_hf.py new file mode 100644 index 0000000..d1908c0 --- /dev/null +++ b/examples/huggingface/reverb_hf.py @@ -0,0 +1,98 @@ +# Following https://huggingface.co/docs/transformers/en/custom_models + +from typing import List, Optional, Tuple, Union +import torch +from transformers import PreTrainedModel +from transformers.modeling_outputs import Seq2SeqLMOutput +from wenet.transformer.asr_model import ASRModel +from wenet.transformer.cmvn import GlobalCMVN +from wenet.transformer.ctc import CTC +from wenet.transformer.decoder import LanguageSpecificBiTransformerDecoder +from wenet.transformer.encoder import ConformerEncoder +from reverb_config import ReverbConfig + +class ReverbModel(PreTrainedModel): + config_class = ReverbConfig + main_input_name = "input_features" + + def __init__(self, config): + super().__init__(config) + self.config = config + global_cmvn = GlobalCMVN( + torch.Tensor(config.cmvn_mean), + torch.Tensor(config.cmvn_istd), + ) + encoder = ConformerEncoder( + config.input_dim, + global_cmvn=global_cmvn, + activation_type=config.encoder_activation_type, + attention_dropout_rate=config.encoder_attention_dropout_rate, + attention_heads=config.encoder_attention_heads, + causal=config.encoder_causal, + cnn_module_kernel=config.encoder_cnn_module_kernel, + cnn_module_norm=config.encoder_cnn_module_norm, + dropout_rate=config.encoder_dropout_rate, + input_layer=config.encoder_input_layer, + linear_units=config.encoder_linear_units, + normalize_before=config.encoder_normalize_before, + num_blocks=config.encoder_num_blocks, + num_langs=config.encoder_num_langs, + output_size=config.encoder_output_size, + pos_enc_layer_type=config.encoder_pos_enc_layer_type, + positional_dropout_rate=config.encoder_positional_dropout_rate, + selfattention_layer_type=config.encoder_selfattention_layer_type, + use_cnn_module=config.encoder_use_cnn_module, + use_dynamic_chunk=config.encoder_use_dynamic_chunk, + ) + + decoder = LanguageSpecificBiTransformerDecoder( + config.output_dim, + config.encoder_output_size, + attention_heads=config.decoder_attention_heads, + dropout_rate=config.decoder_dropout_rate, + linear_units=config.decoder_linear_units, + num_blocks=config.decoder_num_blocks, + num_langs=config.decoder_num_langs, + positional_dropout_rate=config.decoder_positional_dropout_rate, + r_num_blocks=config.decoder_r_num_blocks, + self_attention_dropout_rate=config.decoder_self_attention_dropout_rate, + src_attention_dropout_rate=config.decoder_src_attention_dropout_rate, + ) + + ctc = CTC( + config.output_dim, + config.encoder_output_size, + config.ctc_blank_id, + ) + + self.model = ASRModel( + vocab_size=config.output_dim, + encoder=encoder, + decoder=decoder, + ctc=ctc, + special_tokens=config.special_tokens, + ctc_weight=config.ctc_weight, + lsm_weight=config.lsm_weight, + reverse_weight=config.reverse_weight, + ) + self.model.lsl_enc = True + self.model.lsl_dec = True + + def forward( + self, + input_features=None, + feats_lengths=None, + labels=None, + labels_lengths=None, + **kwargs, + ): + output = self.model.hf_forward( + input_features, + feats_lengths=feats_lengths, + labels=labels, + labels_lengths=labels_lengths, + ) + return Seq2SeqLMOutput( + logits=output['ctc_probs'], + loss=output['loss'], + ) diff --git a/examples/huggingface/reverb_processor.py b/examples/huggingface/reverb_processor.py new file mode 100644 index 0000000..c3aff5f --- /dev/null +++ b/examples/huggingface/reverb_processor.py @@ -0,0 +1,132 @@ +import json +from typing import List, Optional, Union +import numpy as np +import sentencepiece as spm +import torch +import torchaudio +from torchaudio.compliance import kaldi +from tqdm import tqdm +from transformers import BatchFeature, PreTrainedTokenizer, ProcessorMixin, SequenceFeatureExtractor +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class ReverbFeatureExtractor(SequenceFeatureExtractor): + model_input_names = ["input_features"] + def __init__( + self, + feature_size=80, + sampling_rate=16000, + frame_length=25, + frame_shift=10, + chunk_length=15, + padding_value=0.0, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=False, + **kwargs, + ) + self.frame_length = frame_length + self.frame_shift = frame_shift + self.chunk_length = chunk_length + self.max_chunk_size = 2051 + self._processor_class = "CTCWithLM" + + def __call__( + self, + raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + device: Optional[str] = "cpu", + sampling_rate: Optional[int] = None, + **kwargs, + ) -> BatchFeature: + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + ValueError( + f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" + f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" + f" was sampled with {self.sampling_rate} and not {sampling_rate}." + " Attempting a conversion." + ) + else: + logger.warning( + "It is strongly recommended to pass the `sampling_rate` argument to this function. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray([speech], dtype=np.float32) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + if not is_batched: + raw_speech = [np.asarray([raw_speech])] + + fbank_speech, feats_lengths = [], [] + for waveform in raw_speech: + fbank_speech.append( + kaldi.fbank( + torch.tensor(waveform), + num_mel_bins=self.feature_size, + frame_length=self.frame_length, + frame_shift=self.frame_shift, + dither=0.0, + energy_floor=0.0, + sample_frequency=self.sampling_rate, + ) + ) + feats_lengths.append(fbank_speech[-1].shape[0]) + fbank_speech = BatchFeature({ + "input_features": fbank_speech, + "feats_lengths": feats_lengths, + }) + padded = self.pad( + fbank_speech, + padding="max_length", + max_length=self.max_chunk_size, + ) + return padded + + +class ReverbTokenizer(PreTrainedTokenizer): + def __init__( + self, + model: str, + #units: str, + **kwargs, + ): + self.tokenizer = spm.SentencePieceProcessor(model) + """self.units = dict() + with open(units, 'r') as units_file: + for line in tqdm(units_file.readlines()): + token, id = line.split() + self.units[int(id)] = token.replace('▁', ' ')""" + + + def encode( + self, + text, + **kwargs + ): + return self.tokenizer.encode(text) + + def decode( + self, + token_ids, + **kwargs, + ): + return self.tokenizer.decode(token_ids[token_ids.nonzero()[0]].tolist()) diff --git a/examples/huggingface/transcribe.py b/examples/huggingface/transcribe.py new file mode 100644 index 0000000..d99077b --- /dev/null +++ b/examples/huggingface/transcribe.py @@ -0,0 +1,55 @@ +import sys +import os + +# Add the project root directory to Python path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../asr/")) +sys.path.append(project_root) + +import numpy as np +import torch +import torchaudio +from transformers import pipeline +from transformers import AutoConfig, AutoModelForSpeechSeq2Seq +from reverb_hf import ReverbModel +from reverb_config import ReverbConfig +from reverb_processor import ReverbFeatureExtractor, ReverbTokenizer + + +# Register the custom model and config +AutoConfig.register("reverb_asr", ReverbConfig) +AutoModelForSpeechSeq2Seq.register(ReverbConfig, ReverbModel) + +# Load configuration +config = ReverbConfig.from_yaml_file("reverb_config.yaml") + +# Initialize feature extractor and tokenizer using config +feature_extractor = ReverbFeatureExtractor(return_tensors='pt') +tokenizer = ReverbTokenizer(config.tokenizer_path) + +# Initialize model +model = ReverbModel(config) + +# Initialize transcription pipeline +transcribe = pipeline( + "automatic-speech-recognition", + model=model, + feature_extractor=feature_extractor, + tokenizer=tokenizer, + framework='pt', + device='cpu', #crucial + decoder=config.decoder, + decoder_kwargs={ + "beam_width": config.decoder_beam_width, + "token_min_logp": config.decoder_token_min_logp + } +) + +# Process audio +AUDIO_PATH = "" +waveform, sample_rate = torchaudio.load(AUDIO_PATH, normalize=False) +#print(waveform) +waveform = np.array(waveform.to(torch.float).reshape(-1)) + +chunk_size_samples = feature_extractor.chunk_length * sample_rate +for idx in range(0,len(waveform),chunk_size_samples): + print(transcribe(waveform[idx: idx+chunk_size_samples])['text'])