diff --git a/.github/workflows/cicd-main-speech.yml b/.github/workflows/cicd-main-speech.yml index 691551ad9393..5ee43deb0620 100644 --- a/.github/workflows/cicd-main-speech.yml +++ b/.github/workflows/cicd-main-speech.yml @@ -41,7 +41,7 @@ jobs: - script: L0_Unit_Tests_CPU_ASR runner: azure-gpu-vm-runner1-cpu cpu-only: true - timeout: 20 + timeout: 30 - script: L0_Unit_Tests_GPU_TTS runner: self-hosted-azure-gpus-1 - script: L0_Unit_Tests_CPU_TTS @@ -129,6 +129,8 @@ jobs: script: L2_Speech_Transcription_Speech_to_Text_Streaming_Infer - runner: self-hosted-azure script: L2_Speech_Transcription_Speech_to_Text_Cache_Aware_Infer + - runner: self-hosted-azure + script: L2_Speech_Transcription_Streaming_Inference - runner: self-hosted-azure script: L2_Speech_Transcription_Canary_Transcribe_Full_Manifest - runner: self-hosted-azure diff --git a/examples/asr/asr_chunked_inference/README.md b/examples/asr/asr_chunked_inference/README.md index fec2e2901c18..f65a0b793a63 100644 --- a/examples/asr/asr_chunked_inference/README.md +++ b/examples/asr/asr_chunked_inference/README.md @@ -13,3 +13,4 @@ On the other hand, if you increase your chunk size, then the delay between spoke ## Chunked Inference For MultitaskAED models, we provide a script to perform chunked inference. This script will split the input audio into non-overlapping chunks and perform inference on each chunk. The script will then concatenate the results to provide the final transcript. + diff --git a/examples/asr/asr_streaming_inference/README.md b/examples/asr/asr_streaming_inference/README.md new file mode 100644 index 000000000000..c60bb042fbc4 --- /dev/null +++ b/examples/asr/asr_streaming_inference/README.md @@ -0,0 +1,11 @@ +# Universal Streaming Inference + +The `asr_streaming_infer.py` script enables streaming inference for both buffered (CTC/RNNT/TDT) and cache-aware (CTC/RNNT) ASR models. It supports processing a single audio file, a directory of audio files, or a manifest file. + +Beyond streaming ASR, the script also supports: + +* **Inverse Text Normalization (ITN)** +* **End-of-Utterance (EoU) Detection** +* **Word-level and Segment-level Output** + +All related configurations can be found in the `../conf/asr_streaming_inference/` directory. \ No newline at end of file diff --git a/examples/asr/asr_streaming_inference/asr_streaming_infer.py b/examples/asr/asr_streaming_inference/asr_streaming_infer.py new file mode 100644 index 000000000000..7318cf5a3bb5 --- /dev/null +++ b/examples/asr/asr_streaming_inference/asr_streaming_infer.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script serves as the entry point for local ASR inference, supporting buffered CTC/RNNT/TDT and cache-aware CTC/RNNT inference. + +The script performs the following steps: + (1) Accepts as input a single audio file, a directory of audio files, or a manifest file. + - Note: Input audio files must be 16 kHz, mono-channel WAV files. 
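        - Note: For manifest input, the usual NeMo convention of one JSON object per line is assumed, e.g.
          {"audio_filepath": "/path/to/audio.wav", "duration": 12.3} (illustrative entry with a hypothetical
          path; see nemo/collections/asr/inference/utils/manifest_io.py for the exact fields that are read).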
+ (2) Creates a pipeline object to perform inference. + (3) Runs inference on the input audio files. + (4) Writes the transcriptions to an output json/jsonl file. Word/Segment level output is written to a separate JSON file. + +Example usage: +python asr_streaming_infer.py \ + --config-path=../conf/asr_streaming_inference/ \ + --config-name=config.yaml \ + audio_file= \ + output_filename= \ + lang=en \ + enable_pnc=False \ + enable_itn=True \ + asr_output_granularity=segment \ + ... + # See ../conf/asr_streaming_inference/*.yaml for all available options + +Note: + The output file is a json file with the following structure: + {"audio_filepath": "path/to/audio/file", "text": "transcription of the audio file", "json_filepath": "path/to/json/file"} +""" + + +from time import time + +import hydra + + +from nemo.collections.asr.inference.factory.pipeline_builder import PipelineBuilder +from nemo.collections.asr.inference.utils.manifest_io import calculate_duration, dump_output, get_audio_filepaths +from nemo.collections.asr.inference.utils.progressbar import TQDMProgressBar +from nemo.utils import logging + +# disable nemo_text_processing logging +try: + from nemo_text_processing.utils import logger as nemo_text_logger + + nemo_text_logger.propagate = False +except ImportError: + # NB: nemo_text_processing requires pynini, which is tricky to install on MacOS + # since nemo_text_processing is not necessary for ASR, wrap the import + logging.warning("NeMo text processing library is unavailable.") + + +@hydra.main(version_base=None) +def main(cfg): + + # Set the logging level + logging.setLevel(cfg.log_level) + + # Reading audio filepaths + audio_filepaths = get_audio_filepaths(cfg.audio_file, sort_by_duration=True) + logging.info(f"Found {len(audio_filepaths)} audio files") + + # Build the pipeline + pipeline = PipelineBuilder.build_pipeline(cfg) + progress_bar = TQDMProgressBar() + + # Run the pipeline + start = time() + output = pipeline.run(audio_filepaths, progress_bar=progress_bar) + exec_dur = time() - start + + # Calculate RTFX + data_dur = calculate_duration(audio_filepaths) + rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf') + logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)") + + # Dump the transcriptions to a output file + dump_output(output, cfg.output_filename, cfg.output_dir) + logging.info(f"Transcriptions written to {cfg.output_filename}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/examples/asr/conf/asr_streaming_inference/buffered_ctc.yaml b/examples/asr/conf/asr_streaming_inference/buffered_ctc.yaml new file mode 100644 index 000000000000..98ab4bebbd05 --- /dev/null +++ b/examples/asr/conf/asr_streaming_inference/buffered_ctc.yaml @@ -0,0 +1,80 @@ +# ================================ +# ASR Configuration +# ================================ +asr: + model_name: nvidia/parakeet-ctc-1.1b # Pre-trained CTC/hybrid model from NGC/HuggingFace or local .nemo file path + device: cuda # Device for inference: 'cuda' or 'cpu' + device_id: 0 # GPU device ID + compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32' + use_amp: false # Enable Automatic Mixed Precision + + +# ========================================== +# Inverse Text Normalization Configuration +# ========================================== +itn: + input_case: lower_cased # Input text case handling: 'lower_cased', 'cased' + whitelist: null # Custom whitelist for ITN processing + overwrite_cache: false # Whether to 
overwrite existing cache files + max_number_of_permutations_per_split: 729 # Maximum permutations allowed per text split during ITN processing + left_padding_size: 4 # Padding size (#spans) for ITN context + batch_size: 32 # Batch size for ITN inference + n_jobs: 16 # Number of parallel jobs for ITN processing + + +# ======================== +# Confidence estimation +# ======================== +confidence: + exclude_blank: true # Exclude blank tokens when calculating confidence + aggregation: mean # Aggregation method for confidence across time steps + method_cfg: + name: entropy # Confidence estimation method: 'max_prob' or 'entropy' + entropy_type: tsallis + alpha: 0.5 + entropy_norm: exp + + +# ======================== +# Endpointing settings +# ======================== +endpointing: + stop_history_eou: 800 # Time window (ms) for evaluating EoU + residue_tokens_at_end: 2 # Number of residual tokens used for EoU + + +# ======================== +# Streaming configuration +# ======================== +streaming: + sample_rate: 16000 # Audio sample rate in Hz + batch_size: 256 # Number of audio frames per batch + left_padding_size: 1.6 # Left padding duration in seconds + right_padding_size: 1.6 # Right padding duration in seconds + chunk_size: 4.8 # Audio chunk size in seconds + word_boundary_tolerance: 4 # Tolerance for word boundaries + request_type: feature_buffer # Type of request: frame or feature_buffer + padding_mode: right # Padding mode: left or right. How to pad frames to match the required buffer length + + +# ======================== +# Pipeline settings +# ======================== +matmul_precision: high # Matrix multiplication precision: highest, high, medium +log_level: 20 # Logging level: 0 (NOTSET), 10 (DEBUG), 20 (INFO), 30 (WARNING), 40 (ERROR), 50 (CRITICAL) +pipeline_type: buffered # Pipeline type: buffered, cache_aware +asr_decoding_type: ctc # Decoding method: ctc or rnnt + + +# ======================== +# Runtime arguments defined at runtime via command line +# ======================== +audio_file: null # Path to audio file, directory, or manifest JSON +output_filename: null # Path to output transcription JSON file +output_dir: null # Directory to save time-aligned output +enable_pnc: false # Whether to apply punctuation & capitalization +enable_itn: false # Whether to apply inverse text normalization +asr_output_granularity: segment # Output granularity: word or segment +cache_dir: null # Directory to store cache (e.g., .far files) +lang: null # Language code for ASR model +return_tail_result: false # Whether to return the tail labels left in the right padded side of the buffer diff --git a/examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml b/examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml new file mode 100644 index 000000000000..2a53542cfdaa --- /dev/null +++ b/examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml @@ -0,0 +1,83 @@ +# ================================ +# ASR Configuration +# ================================ +asr: + model_name: nvidia/parakeet-rnnt-1.1b # Pre-trained RNNT/hybrid model from NGC/HuggingFace or local .nemo file path + device: cuda # Device for inference: 'cuda' or 'cpu' + device_id: 0 # GPU device ID + compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32' + use_amp: false # Enable Automatic Mixed Precision + ngram_lm_model: "" # Path to ngram language model + ngram_lm_alpha: 0.0 # Alpha for language model + + +# ========================================== 
+# Inverse Text Normalization Configuration +# ========================================== +itn: + input_case: lower_cased # Input text case handling: 'lower_cased', 'cased' + whitelist: null # Custom whitelist for ITN processing + overwrite_cache: false # Whether to overwrite existing cache files + max_number_of_permutations_per_split: 729 # Maximum permutations allowed per text split during ITN processing + left_padding_size: 4 # Padding size (#spans) for ITN context + batch_size: 32 # Batch size for ITN inference + n_jobs: 16 # Number of parallel jobs for ITN processing + + +# ======================== +# Confidence estimation +# ======================== +confidence: + exclude_blank: true # Exclude blank tokens when calculating confidence + aggregation: mean # Aggregation method for confidence across time steps + method_cfg: + name: entropy # Confidence estimation method: 'max_prob' or 'entropy' + entropy_type: tsallis + alpha: 0.5 + entropy_norm: exp + + +# ======================== +# Endpointing settings +# ======================== +endpointing: + stop_history_eou: 800 # Time window (ms) for evaluating EoU + residue_tokens_at_end: 2 # Number of residual tokens used for EoU + + +# ======================== +# Streaming configuration +# ======================== +streaming: + sample_rate: 16000 # Audio sample rate in Hz + batch_size: 256 # Number of audio frames per batch + left_padding_size: 1.6 # Left padding duration in seconds + right_padding_size: 1.6 # Right padding duration in seconds + chunk_size: 4.8 # Audio chunk size in seconds + word_boundary_tolerance: 4 # Tolerance for word boundaries + request_type: feature_buffer # Type of request: frame or feature_buffer + stateful: true # Whether to use stateful processing + padding_mode: right # Padding mode: left or right. 
How to pad frames to match the required buffer length + + +# ======================== +# Pipeline settings +# ======================== +matmul_precision: high # Matrix multiplication precision: highest, high, medium +log_level: 20 # Logging level: 0 (NOTSET), 10 (DEBUG), 20 (INFO), 30 (WARNING), 40 (ERROR), 50 (CRITICAL) +pipeline_type: buffered # Pipeline type: buffered, cache_aware +asr_decoding_type: rnnt # Decoding method: ctc or rnnt + + +# ======================== +# Runtime arguments defined at runtime via command line +# ======================== +audio_file: null # Path to audio file, directory, or manifest JSON +output_filename: null # Path to output transcription JSON file +output_dir: null # Directory to save time-aligned output +enable_pnc: false # Whether to apply punctuation & capitalization +enable_itn: false # Whether to apply inverse text normalization +asr_output_granularity: segment # Output granularity: word or segment +cache_dir: null # Directory to store cache (e.g., .far files) +lang: null # Language code for ASR model +return_tail_result: false # Whether to return the tail labels left in the right padded side of the buffer diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_ctc.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_ctc.yaml new file mode 100644 index 000000000000..534ec322e645 --- /dev/null +++ b/examples/asr/conf/asr_streaming_inference/cache_aware_ctc.yaml @@ -0,0 +1,80 @@ +# ================================ +# ASR Configuration +# ================================ +asr: + model_name: stt_en_fastconformer_hybrid_large_streaming_multi # Pre-trained CTC/hybrid model from NGC/HuggingFace or local .nemo file path + device: cuda # Device for inference: 'cuda' or 'cpu' + device_id: 0 # GPU device ID + compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32' + use_amp: true # Enable Automatic Mixed Precision + + +# ========================================== +# Inverse Text Normalization Configuration +# ========================================== +itn: + input_case: lower_cased # Input text case handling: 'lower_cased', 'cased' + whitelist: null # Custom whitelist for ITN processing + overwrite_cache: false # Whether to overwrite existing cache files + max_number_of_permutations_per_split: 729 # Maximum permutations allowed per text split during ITN processing + left_padding_size: 4 # Padding size (#spans) for ITN context + batch_size: 32 # Batch size for ITN inference + n_jobs: 16 # Number of parallel jobs for ITN processing + + +# ======================== +# Confidence estimation +# ======================== +confidence: + exclude_blank: true # Exclude blank tokens when calculating confidence + aggregation: mean # Aggregation method for confidence across time steps + method_cfg: + name: entropy # Confidence estimation method: 'max_prob' or 'entropy' + entropy_type: tsallis + alpha: 0.5 + entropy_norm: exp + + +# ======================== +# Endpointing settings +# ======================== +endpointing: + stop_history_eou: 800 # Time window (ms) for evaluating EoU + residue_tokens_at_end: 2 # Number of residual tokens used for EoU + + +# ======================== +# Streaming configuration +# ======================== +streaming: + sample_rate: 16000 # Audio sample rate in Hz + batch_size: 256 # Number of audio frames per batch + word_boundary_tolerance: 4 # Tolerance for word boundaries + att_context_size: [70,13] # Attention context size: [70,13],[70,6],[70,1],[70,0] + use_cache: true # 
Whether to use cache for streaming + use_feat_cache: true # Whether to cache mel-spec features, set false to re-calculate all mel-spec features in audio buffer + chunk_size_in_secs: null # Amount of audio to load for each streaming step, e.g., 0.08s for FastConformer. Set to `null` for using default size equal to 1+lookahead frames. + request_type: frame # Type of request: frame, only frame is supported for cache-aware streaming + num_slots: 1024 # Number of slots in the context manager: must be >= batch_size + + +# ======================== +# Pipeline settings +# ======================== +matmul_precision: high # Matrix multiplication precision: highest, high, medium +log_level: 20 # Logging level: 0 (NOTSET), 10 (DEBUG), 20 (INFO), 30 (WARNING), 40 (ERROR), 50 (CRITICAL) +pipeline_type: cache_aware # Pipeline type: buffered, cache_aware +asr_decoding_type: ctc # Decoding method: ctc or rnnt + +# ======================== +# Runtime arguments defined at runtime via command line +# ======================== +audio_file: null # Path to audio file, directory, or manifest JSON +output_filename: null # Path to output transcription JSON file +output_dir: null # Directory to save time-aligned output +enable_pnc: false # Whether to apply punctuation & capitalization +enable_itn: false # Whether to apply inverse text normalization +asr_output_granularity: segment # Output granularity: word or segment +cache_dir: null # Directory to store cache (e.g., .far files) +lang: null # Language code for ASR model +return_tail_result: false # Whether to return the tail labels left in the right padded side of the buffer diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml new file mode 100644 index 000000000000..285c8e7533f3 --- /dev/null +++ b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml @@ -0,0 +1,81 @@ +# ================================ +# ASR Configuration +# ================================ +asr: + model_name: stt_en_fastconformer_hybrid_large_streaming_multi # Pre-trained CTC/hybrid model from NGC/HuggingFace or local .nemo file path + device: cuda # Device for inference: 'cuda' or 'cpu' + device_id: 0 # GPU device ID + compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32' + use_amp: true # Enable Automatic Mixed Precision + + +# ========================================== +# Inverse Text Normalization Configuration +# ========================================== +itn: + input_case: lower_cased # Input text case handling: 'lower_cased', 'cased' + whitelist: null # Custom whitelist for ITN processing + overwrite_cache: false # Whether to overwrite existing cache files + max_number_of_permutations_per_split: 729 # Maximum permutations allowed per text split during ITN processing + left_padding_size: 4 # Padding size (#spans) for ITN context + batch_size: 32 # Batch size for ITN inference + n_jobs: 16 # Number of parallel jobs for ITN processing + + +# ======================== +# Confidence estimation +# ======================== +confidence: + exclude_blank: true # Exclude blank tokens when calculating confidence + aggregation: mean # Aggregation method for confidence across time steps + method_cfg: + name: entropy # Confidence estimation method: 'max_prob' or 'entropy' + entropy_type: tsallis + alpha: 0.5 + entropy_norm: exp + + +# ======================== +# Endpointing settings +# ======================== +endpointing: + stop_history_eou: 800 # 
Time window (ms) for evaluating EoU + residue_tokens_at_end: 2 # Number of residual tokens used for EoU + + +# ======================== +# Streaming configuration +# ======================== +streaming: + sample_rate: 16000 # Audio sample rate in Hz + batch_size: 256 # Number of audio frames per batch + word_boundary_tolerance: 4 # Tolerance for word boundaries + att_context_size: [70,13] # Attention context size: [70,13],[70,6],[70,1],[70,0] + use_cache: true # Whether to use cache for streaming + use_feat_cache: true # Whether to cache mel-spec features, set false to re-calculate all mel-spec features in audio buffer + chunk_size_in_secs: null # Amount of audio to load for each streaming step, e.g., 0.08s for FastConformer. Set to `null` for using default size equal to 1+lookahead frames. + request_type: frame # Type of request: frame, only frame is supported for cache-aware streaming + num_slots: 1024 # Number of slots in the context manager: must be >= batch_size + + +# ======================== +# Pipeline settings +# ======================== +matmul_precision: high # Matrix multiplication precision: highest, high, medium +log_level: 20 # Logging level: 0 (NOTSET), 10 (DEBUG), 20 (INFO), 30 (WARNING), 40 (ERROR), 50 (CRITICAL) +pipeline_type: cache_aware # Pipeline type: buffered, cache_aware +asr_decoding_type: rnnt # Decoding method: ctc or rnnt + + +# ======================== +# Runtime arguments defined at runtime via command line +# ======================== +audio_file: null # Path to audio file, directory, or manifest JSON +output_filename: null # Path to output transcription JSON file +output_dir: null # Directory to save time-aligned output +enable_pnc: false # Whether to apply punctuation & capitalization +enable_itn: false # Whether to apply inverse text normalization +asr_output_granularity: segment # Output granularity: word or segment +cache_dir: null # Directory to store cache (e.g., .far files) +lang: null # Language code for ASR model +return_tail_result: false # Whether to return the tail labels left in the right padded side of the buffer diff --git a/nemo/collections/asr/inference/__init__.py b/nemo/collections/asr/inference/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/factory/__init__.py b/nemo/collections/asr/inference/factory/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/factory/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/factory/base_builder.py b/nemo/collections/asr/inference/factory/base_builder.py new file mode 100644 index 000000000000..737556f18e6e --- /dev/null +++ b/nemo/collections/asr/inference/factory/base_builder.py @@ -0,0 +1,140 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from omegaconf import open_dict +from omegaconf.dictconfig import DictConfig + +from nemo.collections.asr.inference.model_wrappers.asr_inference_wrapper import ASRInferenceWrapper +from nemo.collections.asr.inference.model_wrappers.cache_aware_ctc_inference_wrapper import ( + CacheAwareCTCInferenceWrapper, +) +from nemo.collections.asr.inference.model_wrappers.cache_aware_rnnt_inference_wrapper import ( + CacheAwareRNNTInferenceWrapper, +) +from nemo.collections.asr.inference.model_wrappers.ctc_inference_wrapper import CTCInferenceWrapper +from nemo.collections.asr.inference.model_wrappers.rnnt_inference_wrapper import RNNTInferenceWrapper +from nemo.collections.asr.inference.utils.enums import ASRDecodingType, PipelineType +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.utils import logging + +if TYPE_CHECKING: + from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer + + +class BaseBuilder: + """ + Base Builder class. + Builds the ASR/ITN components. + Derived classes should implement the `build` method which should include the logic of creating concrete pipeline. + """ + + @classmethod + def _build_asr(cls, cfg: DictConfig, decoding_cfg: CTCDecodingConfig | RNNTDecodingConfig) -> ASRInferenceWrapper: + """ + Build the ASR model based on the config. 
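        The concrete wrapper class is selected from the (asr_decoding_type, pipeline_type) pair; for
        example, `asr_decoding_type: ctc` with `pipeline_type: buffered` resolves to CTCInferenceWrapper,
        while `rnnt` with `cache_aware` resolves to CacheAwareRNNTInferenceWrapper.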
+ Args: + cfg: (DictConfig) Config + decoding_cfg: (CTCDecodingConfig | RNNTDecodingConfig) Decoding config + Returns: + (ASRInferenceWrapper) ASR inference model + """ + + asr_decoding_type = ASRDecodingType.from_str(cfg.asr_decoding_type) + pipeline_type = PipelineType.from_str(cfg.pipeline_type) + match (asr_decoding_type, pipeline_type): + case (ASRDecodingType.CTC, PipelineType.BUFFERED): + asr_class = CTCInferenceWrapper + case (ASRDecodingType.RNNT, PipelineType.BUFFERED): + asr_class = RNNTInferenceWrapper + case (ASRDecodingType.CTC, PipelineType.CACHE_AWARE): + asr_class = CacheAwareCTCInferenceWrapper + case (ASRDecodingType.RNNT, PipelineType.CACHE_AWARE): + asr_class = CacheAwareRNNTInferenceWrapper + case _: + raise ValueError( + f"Wrong combination of ASR decoding type and pipeline type: {asr_decoding_type, pipeline_type}" + ) + + asr_model = asr_class( + model_name=cfg.asr.model_name, + decoding_cfg=decoding_cfg, + device=cfg.asr.device, + device_id=cfg.asr.device_id, + compute_dtype=cfg.asr.compute_dtype, + use_amp=cfg.asr.use_amp, + ) + + logging.info(f"ASR model `{cfg.asr.model_name}` loaded") + return asr_model + + @classmethod + def _build_itn(cls, cfg: DictConfig, input_is_lower_cased: bool) -> AlignmentPreservingInverseNormalizer | None: + """ + Build the ITN model based on the config. + Args: + cfg: (DictConfig) Config + input_is_lower_cased: (bool) Whether the input is lower cased + Returns: + (AlignmentPreservingInverseNormalizer | None) ITN model + """ + itn_model = None + if cfg.enable_itn: + # Do not remove this import. It is used to avoid nemo_text_processing import when verbatim transcripts is enabled. + from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer + + input_case = ( + AlignmentPreservingInverseNormalizer.LOWER_CASED + if input_is_lower_cased + else AlignmentPreservingInverseNormalizer.UPPER_CASED + ) + + target_lang = getattr(cfg, "lang", getattr(cfg, "target_lang", None)) + if target_lang is None: + raise ValueError("Language is not specified. Cannot load PnC model.") + + itn_cfg = cfg.itn + with open_dict(itn_cfg): + itn_cfg.lang = target_lang + itn_cfg.input_case = input_case + itn_cfg.cache_dir = cfg.cache_dir + + itn_model = AlignmentPreservingInverseNormalizer( + lang=itn_cfg.lang, + input_case=itn_cfg.input_case, + whitelist=itn_cfg.whitelist, + cache_dir=itn_cfg.cache_dir, + overwrite_cache=itn_cfg.overwrite_cache, + max_number_of_permutations_per_split=itn_cfg.max_number_of_permutations_per_split, + ) + logging.info(f"Built inverse text normalizer with the input case: `{input_case}`.") + + if itn_model is not None: + logging.info("ITN model loaded") + return itn_model + + @classmethod + def build(cls, cfg: DictConfig) -> Any: + """ + Build the pipeline based on the config. + Args: + cfg: (DictConfig) Config + Returns: + Returns object responsible for the inference + """ + raise NotImplementedError("This method should be implemented in subclasses.") diff --git a/nemo/collections/asr/inference/factory/buffered_pipeline_builder.py b/nemo/collections/asr/inference/factory/buffered_pipeline_builder.py new file mode 100644 index 000000000000..d875d5833206 --- /dev/null +++ b/nemo/collections/asr/inference/factory/buffered_pipeline_builder.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from omegaconf.dictconfig import DictConfig + +from nemo.collections.asr.inference.factory.base_builder import BaseBuilder +from nemo.collections.asr.inference.pipelines.buffered_ctc_pipeline import BufferedCTCPipeline +from nemo.collections.asr.inference.pipelines.buffered_rnnt_pipeline import BufferedRNNTPipeline +from nemo.collections.asr.inference.utils.enums import ASRDecodingType +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.utils import logging + + +class BufferedPipelineBuilder(BaseBuilder): + """ + Buffered Pipeline Builder class. + Builds the buffered CTC/RNNT/TDT pipelines. + """ + + @classmethod + def build(cls, cfg: DictConfig) -> BufferedRNNTPipeline | BufferedCTCPipeline: + """ + Build the buffered streaming pipeline based on the config. + Args: + cfg: (DictConfig) Config + Returns: + Returns BufferedRNNTPipeline or BufferedCTCPipeline object + """ + asr_decoding_type = ASRDecodingType.from_str(cfg.asr_decoding_type) + + if asr_decoding_type is ASRDecodingType.RNNT: + return cls.build_buffered_rnnt_pipeline(cfg) + elif asr_decoding_type is ASRDecodingType.CTC: + return cls.build_buffered_ctc_pipeline(cfg) + + raise ValueError("Invalid asr decoding type for buffered streaming. Need to be one of ['CTC', 'RNNT']") + + @classmethod + def get_rnnt_decoding_cfg(cls, cfg: DictConfig) -> RNNTDecodingConfig: + """ + Get the decoding config for the RNNT pipeline. + Returns: + (RNNTDecodingConfig) Decoding config + """ + decoding_cfg = RNNTDecodingConfig() + + # greedy_batch decoding strategy required for stateless streaming + decoding_cfg.strategy = "greedy_batch" + + # required to compute the middle token for transducers. + decoding_cfg.preserve_alignments = False + + # temporarily stop fused batch during inference. + decoding_cfg.fused_batch_size = -1 + + # return and write the best hypothesis only + decoding_cfg.beam.return_best_hypothesis = True + + # setup ngram language model + if hasattr(cfg.asr, "ngram_lm_model") and cfg.asr.ngram_lm_model != "": + decoding_cfg.greedy.ngram_lm_model = cfg.asr.ngram_lm_model + decoding_cfg.greedy.ngram_lm_alpha = cfg.asr.ngram_lm_alpha + + return decoding_cfg + + @classmethod + def get_ctc_decoding_cfg(cls) -> CTCDecodingConfig: + """ + Get the decoding config for the CTC pipeline. + Returns: + (CTCDecodingConfig) Decoding config + """ + decoding_cfg = CTCDecodingConfig() + decoding_cfg.strategy = "greedy" + return decoding_cfg + + @classmethod + def build_buffered_rnnt_pipeline(cls, cfg: DictConfig) -> BufferedRNNTPipeline: + """ + Build the RNNT streaming pipeline based on the config. 
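        A minimal usage sketch (assuming `cfg` is a fully populated config such as
        examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml with its runtime fields filled in):

            pipeline = BufferedPipelineBuilder.build_buffered_rnnt_pipeline(cfg)
            output = pipeline.run(audio_filepaths)  # audio file paths, as passed in asr_streaming_infer.py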
+ Args: + cfg: (DictConfig) Config + Returns: + Returns BufferedRNNTPipeline object + """ + # building ASR model + decoding_cfg = cls.get_rnnt_decoding_cfg(cfg) + asr_model = cls._build_asr(cfg, decoding_cfg) + + # building ITN model + itn_model = cls._build_itn(cfg, input_is_lower_cased=True) + + # building RNNT pipeline + rnnt_pipeline = BufferedRNNTPipeline(cfg, asr_model, itn_model) + logging.info(f"`{type(rnnt_pipeline).__name__}` pipeline loaded") + return rnnt_pipeline + + @classmethod + def build_buffered_ctc_pipeline(cls, cfg: DictConfig) -> BufferedCTCPipeline: + """ + Build the CTC buffered streaming pipeline based on the config. + Args: + cfg: (DictConfig) Config + Returns: + Returns BufferedCTCPipeline object + """ + # building ASR model + decoding_cfg = cls.get_ctc_decoding_cfg() + asr_model = cls._build_asr(cfg, decoding_cfg) + + # building ITN model + itn_model = cls._build_itn(cfg, input_is_lower_cased=True) + + # building CTC pipeline + ctc_pipeline = BufferedCTCPipeline(cfg, asr_model, itn_model) + logging.info(f"`{type(ctc_pipeline).__name__}` pipeline loaded") + return ctc_pipeline diff --git a/nemo/collections/asr/inference/factory/cache_aware_pipeline_builder.py b/nemo/collections/asr/inference/factory/cache_aware_pipeline_builder.py new file mode 100644 index 000000000000..c91c964939ce --- /dev/null +++ b/nemo/collections/asr/inference/factory/cache_aware_pipeline_builder.py @@ -0,0 +1,117 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from omegaconf.dictconfig import DictConfig + +from nemo.collections.asr.inference.factory.base_builder import BaseBuilder +from nemo.collections.asr.inference.pipelines.cache_aware_ctc_pipeline import CacheAwareCTCPipeline +from nemo.collections.asr.inference.pipelines.cache_aware_rnnt_pipeline import CacheAwareRNNTPipeline +from nemo.collections.asr.inference.utils.enums import ASRDecodingType +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.utils import logging + + +class CacheAwarePipelineBuilder(BaseBuilder): + """ + Cache Aware Pipeline Builder class. + Builds the cache aware CTC/RNNT pipelines. + """ + + @classmethod + def build(cls, cfg: DictConfig) -> CacheAwareCTCPipeline | CacheAwareRNNTPipeline: + """ + Build the cache aware streaming pipeline based on the config. + Args: + cfg: (DictConfig) Config + Returns: + Returns CacheAwareCTCPipeline or CacheAwareRNNTPipeline object + """ + asr_decoding_type = ASRDecodingType.from_str(cfg.asr_decoding_type) + + if asr_decoding_type is ASRDecodingType.RNNT: + return cls.build_cache_aware_rnnt_pipeline(cfg) + elif asr_decoding_type is ASRDecodingType.CTC: + return cls.build_cache_aware_ctc_pipeline(cfg) + + raise ValueError("Invalid asr decoding type for cache aware streaming. 
Need to be one of ['CTC', 'RNNT']") + + @classmethod + def get_rnnt_decoding_cfg(cls) -> RNNTDecodingConfig: + """ + Get the decoding config for the RNNT pipeline. + Returns: + (RNNTDecodingConfig) Decoding config + """ + decoding_cfg = RNNTDecodingConfig() + decoding_cfg.strategy = "greedy_batch" + decoding_cfg.preserve_alignments = False + decoding_cfg.greedy.use_cuda_graph_decoder = False + decoding_cfg.greedy.max_symbols = 10 + decoding_cfg.fused_batch_size = -1 + return decoding_cfg + + @classmethod + def get_ctc_decoding_cfg(cls) -> CTCDecodingConfig: + """ + Get the decoding config for the CTC pipeline. + Returns: + (CTCDecodingConfig) Decoding config + """ + decoding_cfg = CTCDecodingConfig() + decoding_cfg.strategy = "greedy" + decoding_cfg.preserve_alignments = False + return decoding_cfg + + @classmethod + def build_cache_aware_rnnt_pipeline(cls, cfg: DictConfig) -> CacheAwareRNNTPipeline: + """ + Build the cache aware RNNT streaming pipeline based on the config. + Args: + cfg: (DictConfig) Config + Returns: + Returns CacheAwareRNNTPipeline object + """ + # building ASR model + decoding_cfg = cls.get_rnnt_decoding_cfg() + asr_model = cls._build_asr(cfg, decoding_cfg) + + # building ITN model + itn_model = cls._build_itn(cfg, input_is_lower_cased=True) + + # building cache aware RNNT pipeline + ca_rnnt_pipeline = CacheAwareRNNTPipeline(cfg, asr_model, itn_model=itn_model) + logging.info(f"`{type(ca_rnnt_pipeline).__name__}` pipeline loaded") + return ca_rnnt_pipeline + + @classmethod + def build_cache_aware_ctc_pipeline(cls, cfg: DictConfig) -> CacheAwareCTCPipeline: + """ + Build the cache aware CTC streaming pipeline based on the config. + Args: + cfg: (DictConfig) Config + Returns: + Returns CacheAwareCTCPipeline object + """ + # building ASR model + decoding_cfg = cls.get_ctc_decoding_cfg() + asr_model = cls._build_asr(cfg, decoding_cfg) + + # building ITN model + itn_model = cls._build_itn(cfg, input_is_lower_cased=True) + + # building cache aware CTC pipeline + ca_ctc_pipeline = CacheAwareCTCPipeline(cfg, asr_model, itn_model=itn_model) + logging.info(f"`{type(ca_ctc_pipeline).__name__}` pipeline loaded") + return ca_ctc_pipeline diff --git a/nemo/collections/asr/inference/factory/pipeline_builder.py b/nemo/collections/asr/inference/factory/pipeline_builder.py new file mode 100644 index 000000000000..8b4e658ee01d --- /dev/null +++ b/nemo/collections/asr/inference/factory/pipeline_builder.py @@ -0,0 +1,75 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
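# Illustrative end-to-end sketch (not part of this module): the PipelineBuilder router defined below
# dispatches to one of the builders above based on `cfg.pipeline_type`, roughly:
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("examples/asr/conf/asr_streaming_inference/cache_aware_ctc.yaml")
#   cfg.audio_file = "/path/to/audio.wav"           # hypothetical path
#   cfg.output_filename = "transcripts.json"        # hypothetical path
#   pipeline = PipelineBuilder.build_pipeline(cfg)  # routes to CacheAwarePipelineBuilder.build(cfg)
#   output = pipeline.run([cfg.audio_file])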
+ + +from typing import Any + +import torch +from omegaconf.dictconfig import DictConfig + +from nemo.collections.asr.inference.factory.buffered_pipeline_builder import BufferedPipelineBuilder +from nemo.collections.asr.inference.factory.cache_aware_pipeline_builder import CacheAwarePipelineBuilder +from nemo.collections.asr.inference.utils.enums import PipelineType +from nemo.utils import logging + + +class PipelineBuilder: + """Router for building the pipeline based on the pipeline type.""" + + @staticmethod + def set_matmul_precision(matmul_precision: str) -> None: + """ + Set the matmul precision. + Args: + matmul_precision: (str) Matmul precision: highest, high, medium + """ + choices = ["highest", "high", "medium"] + matmul_precision = matmul_precision.lower() + if matmul_precision not in choices: + raise ValueError(f"Invalid matmul precision: {matmul_precision}. Need to be one of {choices}") + torch.set_float32_matmul_precision(matmul_precision) + logging.info(f"Using matmul precision: {matmul_precision}") + + @staticmethod + def set_log_level(log_level: int) -> None: + """ + Set the logging level. + Args: + log_level: (int) Logging level: 0 (NOTSET), 10 (DEBUG), 20 (INFO), 30 (WARNING), 40 (ERROR), 50 (CRITICAL) + """ + choices = [0, 10, 20, 30, 40, 50] + if log_level not in choices: + raise ValueError(f"Invalid log level: {log_level}. Need to be one of {choices}") + logging.setLevel(log_level) + + @staticmethod + def build_pipeline(cfg: DictConfig) -> Any: + """ + Build the pipeline based on the config. + Args: + cfg: (DictConfig) Config + Returns: + Returns Pipeline object + """ + PipelineBuilder.set_log_level(cfg.log_level) + PipelineBuilder.set_matmul_precision(cfg.matmul_precision) + pipeline_type = PipelineType.from_str(cfg.pipeline_type) + if pipeline_type is PipelineType.BUFFERED: + builder = BufferedPipelineBuilder + elif pipeline_type is PipelineType.CACHE_AWARE: + builder = CacheAwarePipelineBuilder + else: + raise ValueError(f"Invalid pipeline type: {cfg.pipeline_type}") + + return builder.build(cfg) diff --git a/nemo/collections/asr/inference/itn/__init__.py b/nemo/collections/asr/inference/itn/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/itn/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/itn/batch_inverse_normalizer.py b/nemo/collections/asr/inference/itn/batch_inverse_normalizer.py new file mode 100644 index 000000000000..99f52dd55882 --- /dev/null +++ b/nemo/collections/asr/inference/itn/batch_inverse_normalizer.py @@ -0,0 +1,185 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +from typing import Callable + +from joblib import Parallel, delayed + +from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer +from nemo.collections.asr.inference.utils.text_segment import Word + + +def merge_punctuation_and_itn_tags( + input_words: list[str], + output_words: list[str], + word_alignment: list[tuple], + pnc_words: list[Word], + punct_marks: set, + sep: str, + conf_aggregate_fn: Callable, +) -> list[Word]: + """ + Merge the punctuation marks and ITN tags to the final text. + It will also preserve first letter capitalization, start and end time of the span. + Args: + input_words: (list[str]) List of input words + output_words: (list[str]) List of output words + word_alignment: (list[tuple[list[int], list[int]]]) Word alignment between the input and output words + pnc_words: (list[Word]) List of words with punctuation marks + punct_marks: (set) Punctuation marks + sep: (str) Separator + conf_aggregate_fn: (Callable) Confidence aggregation function + Returns: + (list[Word]) Final words after merging the punctuation marks and ITN tags + """ + assert len(input_words) == len(pnc_words) + spans = [] + for s_idx, t_idx, semiotic_class in word_alignment: + if len(t_idx) == 1 and len(s_idx) == 1 and input_words[s_idx[0]] == output_words[t_idx[0]]: + span = pnc_words[s_idx[0]] + span.semiotic_class = semiotic_class + else: + span_text = sep.join([output_words[i] for i in t_idx]) + last_char = pnc_words[s_idx[-1]].text[-1] + first_char = pnc_words[s_idx[0]].text[0] + + # preserve the first char capitalization + first_word = pnc_words[s_idx[0]].copy() + first_char_is_upper = first_word.text[0].isupper() + first_word.normalize_text_inplace(punct_marks, sep) + if span_text.startswith(first_word.text): + if first_char_is_upper: + span_text = span_text[0].upper() + span_text[1:] + + # preserve the last punctuation mark + if last_char in punct_marks: + span_text += last_char + + # preserve the first punctuation mark + if first_char in punct_marks: + span_text = first_char + span_text + + scores = [pnc_words[i].conf for i in s_idx] + conf = conf_aggregate_fn(scores) if len(scores) > 0 else 0.0 + span = Word( + text=span_text, + start=pnc_words[s_idx[0]].start, + end=pnc_words[s_idx[-1]].end, + semiotic_class=semiotic_class, + conf=conf, + ) + spans.append(span) + return spans + + +class BatchAlignmentPreservingInverseNormalizer: + """ + Batch Alignment Preserving Inverse Text Normalizer. It is used to apply ITN to a batch of texts. + joblib.Parallel is used to parallelize the processing. + """ + + def __init__( + self, + itn_model: AlignmentPreservingInverseNormalizer, + sep: str, + asr_supported_puncts: set[str], + post_word_punctuation: set[str], + conf_aggregate_fn: Callable, + ): + """ + Batch Alignment Preserving Inverse Text Normalizer. It is used to apply ITN to a batch of texts. 
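        For example (illustrative), ASR words ["twenty", "five", "dollars"] paired with punctuated words
        ["Twenty", "five", "dollars."] can be merged into a single span such as "$25.", whose trailing
        punctuation, aggregated confidence, and start/end times are taken from the covered input words.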
+ Args: + itn_model: (AlignmentPreservingInverseNormalizer) Alignment Preserving Inverse Text Normalizer + sep: (str) Separator + asr_supported_puncts: (Set[str]) Punctuation marks supported by ASR model + post_word_punctuation: (Set[str]) Punctuation marks which usually appear after a word + conf_aggregate_fn: (Callable) Confidence aggregation function + """ + self.itn_model = itn_model + self.sep = sep + self.asr_supported_puncts = asr_supported_puncts + self.conf_aggregate_fn = conf_aggregate_fn + self.punct_marks = self.asr_supported_puncts | post_word_punctuation + + def apply_itn( + self, asr_words: list[Word], pnc_words: list[Word], return_alignment: bool = False + ) -> list[Word] | tuple[list[Word], list]: + """ + Apply Alignment Preserving Inverse Text Normalization. + Args: + asr_words: (list[Word]) List of ASR words + pnc_words: (list[Word]) List of words with punctuation/capitalization + return_alignment: (bool) Flag to return the word alignment + Returns: + (list[Word]) List of words after applying ITN + """ + input_words = [] + for word in asr_words: + word.normalize_text_inplace(self.asr_supported_puncts, self.sep) + input_words.append(word.text) + + input_words, output_words, word_alignment = self.itn_model.get_word_alignment(input_words, sep=self.sep) + spans = merge_punctuation_and_itn_tags( + input_words, output_words, word_alignment, pnc_words, self.punct_marks, self.sep, self.conf_aggregate_fn + ) + + if return_alignment: + # word alignment is needed for streaming inference + return spans, word_alignment + return spans + + def __call__( + self, + asr_words_list: list[list[Word]], + pnc_words_list: list[list[Word]], + itn_params: dict, + return_alignment: bool = False, + ) -> list[list[Word]] | list[tuple]: + """ + Batch Alignment Preserving Inverse Text Normalization. + Args: + asr_words_list: (list[list[Word]]) List of ASR words + pnc_words_list: (list[list[Word]]) List of words with punctuation/capitalization + itn_params: (dict) Parameters for the ITN model + return_alignment: (bool) Flag to return the word alignment + Returns: + (list[list[Word]]) List of words after applying ITN + """ + if len(asr_words_list) == 0: + return [] + + batch_size = itn_params.get("batch_size", 1) + n_texts = len(asr_words_list) + batch_size = min(n_texts, batch_size) + + def process_batch(batch_words, batch_words_with_pnc): + return [ + self.apply_itn(words, words_with_pnc, return_alignment) + for words, words_with_pnc in zip(batch_words, batch_words_with_pnc) + ] + + if n_texts <= 3 * batch_size or n_texts == 1: + # If the number of texts is less than 3 * batch_size, process the batch sequentially + # For small batch size, it is faster to process the batch sequentially + return process_batch(asr_words_list, pnc_words_list) + + n_jobs = itn_params.get("n_jobs", 1) + itn_words_list = Parallel(n_jobs=n_jobs)( + delayed(process_batch)(asr_words_list[i : i + batch_size], pnc_words_list[i : i + batch_size]) + for i in range(0, n_texts, batch_size) + ) + itn_words_list = list(itertools.chain(*itn_words_list)) + return itn_words_list diff --git a/nemo/collections/asr/inference/itn/inverse_normalizer.py b/nemo/collections/asr/inference/itn/inverse_normalizer.py new file mode 100644 index 000000000000..f1041410b70a --- /dev/null +++ b/nemo/collections/asr/inference/itn/inverse_normalizer.py @@ -0,0 +1,347 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import re +from multiprocessing import Manager + +from nemo.collections.asr.inference.utils.itn_utils import ( + DEFAULT_SEMIOTIC_CLASS, + fallback_to_trivial_alignment, + find_tokens, + get_semiotic_class, + split_text, +) +from nemo.utils import logging + +IN_MEM_CACHE = Manager().dict(lock=False) + +try: + import pynini + from nemo_text_processing.inverse_text_normalization.inverse_normalize import InverseNormalizer, Normalizer + from nemo_text_processing.text_normalization.en.graph_utils import INPUT_CASED, INPUT_LOWER_CASED +except ImportError as e: + raise ImportError("Failed to import pynini or nemo_text_processing.") from e + +try: + import diskcache + + CACHING_FROM_DISK = True +except ImportError: + logging.warning("diskcache package is not installed, caching from disk is disabled") + CACHING_FROM_DISK = False + + +class AlignmentPreservingInverseNormalizer: + """ + Inverse Text Normalizer that preserves the word alignment. + It is used to convert the spoken text to written text and preserve the alignment between the input and output words. + """ + + LOWER_CASED = INPUT_LOWER_CASED + UPPER_CASED = INPUT_CASED + GRAMMAR = "itn" + + def __init__( + self, + input_case: str = LOWER_CASED, + lang: str = "en", + whitelist: str = None, + cache_dir: str = None, + overwrite_cache: bool = False, + max_number_of_permutations_per_split: int = 729, + ): + """ + Inverse normalizer that converts text from spoken to written form. + Args: + input_case: Input text capitalization, set to 'cased' if text contains capital letters. + This flag affects normalization rules applied to the text. Note, `lower_cased` won't lower case input. + lang: language specifying the ITN + whitelist: path to a file with whitelist replacements. (each line of the file: written_form\tspoken_form\n), + e.g. nemo_text_processing/inverse_text_normalization/en/data/whitelist.tsv + cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache. + overwrite_cache: set to True to overwrite .far files + max_number_of_permutations_per_split: a maximum number + of permutations which can be generated from input sequence of tokens. + """ + self.itn_model = InverseNormalizer( + lang=lang, + input_case=input_case, + whitelist=whitelist, + cache_dir=cache_dir, + overwrite_cache=overwrite_cache, + max_number_of_permutations_per_split=max_number_of_permutations_per_split, + ) + if cache_dir and CACHING_FROM_DISK: + self.DISK_TAG_CACHE = diskcache.Cache(os.path.join(cache_dir, "itn_tag_cache")) + self.DISK_VERB_CACHE = diskcache.Cache(os.path.join(cache_dir, "itn_verb_cache")) + self.caching_from_disk_enabled = True + else: + self.DISK_TAG_CACHE = None + self.DISK_VERB_CACHE = None + self.caching_from_disk_enabled = False + + def inverse_normalize_list(self, texts: list[str], params: dict) -> list[str]: + """ + Applies Inverse Text Normalization to the list of texts. + Args: + texts: (list[str]) list of input strings. 
+ params: (dict) dictionary of runtime parameters. + Returns: + (list[str]) Returns converted list of input strings. + """ + normalized_texts = self.itn_model.normalize_list( + texts, + verbose=params.get('verbose', False), + punct_pre_process=params.get("punct_pre_process", False), + punct_post_process=params.get("punct_post_process", False), + batch_size=params.get("batch_size", 1), + n_jobs=params.get("n_jobs", 1), + ) + return normalized_texts + + def verbalize(self, tokens: list, sep: str) -> str | None: + """ + Appplies verbalization to the list of tokens. + Args: + tokens: (list) list of tokens + sep: (str) word separator + Returns: + (str | None) Returns verbalized text. If verbalization fails, returns None. + """ + split_tokens = self.itn_model._split_tokens_to_reduce_number_of_permutations(tokens) + output_str = "" + for s in split_tokens: + try: + tags_reordered = self.itn_model.generate_permutations(s) + verbalizer_lattice = None + for tagged_text_r in tags_reordered: + tagged_text_r = pynini.escape(tagged_text_r) + + verbalizer_lattice = self.itn_model.find_verbalizer(tagged_text_r) + if verbalizer_lattice.num_states() != 0: + break + + if verbalizer_lattice is None: + return None + + verbalized_text = Normalizer.select_verbalizer(verbalizer_lattice) + output_str += sep + verbalized_text + except Exception as e: + logging.warning("Failed to verbalize tagged text: " + str(e)) + return None + + output_str = output_str.strip(sep) + return re.sub(r"({sep})+".format(sep=sep), sep, output_str) + + def tag(self, text: str, no_cache: bool = False) -> str: + """ + Tags the input text. + Args: + text: (str) input text + no_cache: (bool) whether to use cache + Returns: + (str) tagged text + """ + if not no_cache: + # In-memory cache check + if text in IN_MEM_CACHE: + return IN_MEM_CACHE[text] + + # Disk cache check + if self.caching_from_disk_enabled and text in self.DISK_TAG_CACHE: + x = self.DISK_TAG_CACHE[text] + IN_MEM_CACHE[text] = x + return x + + text = text.strip() + if not text: + return text + + text = pynini.escape(text) + tagged_lattice = self.itn_model.find_tags(text) + tagged_text = Normalizer.select_tag(tagged_lattice) + IN_MEM_CACHE[text] = tagged_text + if self.caching_from_disk_enabled: + self.DISK_TAG_CACHE[text] = tagged_text + return tagged_text + + def parse_and_verbalize(self, tagged_text: str, sep: str) -> tuple[str, str]: + """ + Tags and verbalizes the input text. + Args: + tagged_text: (str) tagged input text + sep: (str) word separator + Returns: + (str, str) Returns the verbalized text, and the semiotic class. + """ + + # In-memory cache check + if tagged_text in IN_MEM_CACHE: + return IN_MEM_CACHE[tagged_text] + + # Disk cache check + if self.caching_from_disk_enabled and tagged_text in self.DISK_VERB_CACHE: + x = self.DISK_VERB_CACHE[tagged_text] + IN_MEM_CACHE[tagged_text] = x + return x + + self.itn_model.parser(tagged_text) + tokens = self.itn_model.parser.parse() + span_text = self.verbalize(tokens, sep) + semiotic_class = DEFAULT_SEMIOTIC_CLASS if span_text is None else get_semiotic_class(tokens) + + IN_MEM_CACHE[tagged_text] = (span_text, semiotic_class) + if self.caching_from_disk_enabled: + self.DISK_VERB_CACHE[tagged_text] = (span_text, semiotic_class) + return span_text, semiotic_class + + def find_token_words( + self, token: str, start_idx: int, input_words: list[str], sep: str + ) -> tuple[list[int], bool, int]: + """ + Finds the words that make up the token. 
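        The word window is grown greedily from `start_idx`: each candidate span is re-tagged with
        `self.tag` and compared against `token`, and the search stops at the longest span whose tag
        still matches. For example (illustrative), a token tagged from "twenty five" is matched by the
        two-word span rather than by "twenty" alone.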
+ Args: + token: (str) token + start_idx: (int) start index + input_words: (list[str]) list of input words + sep: (str) word separator + Returns: + (tuple) Returns a tuple of indices, success, and the new start index + """ + indices, tmp_text, success = [], "", False + length = len(input_words) + for i in range(start_idx, length): + tmp_text = tmp_text + sep + input_words[i] if tmp_text else input_words[i] + tmp_tagged_text = self.tag(tmp_text) + + if tmp_tagged_text == token: + indices.append(i) + + # Try to extend the token by one word + if i + 1 < length: + extended_tmp_text = tmp_text + sep + input_words[i + 1] + extended_tmp_tagged_text = self.tag(extended_tmp_text) + if extended_tmp_tagged_text == token: + continue + + success = True + break + else: + indices.append(i) + + return indices, success, i + + def find_alignment( + self, + tokens: list[str], + input_words: list[str], + sep: str, + iwords: list[str], + owords: list[str], + word_alignment: list[tuple], + ) -> bool: + """ + Finds the word alignment for the input text. + Args: + tokens: (list[str]) list of tokens + input_words: (list[str]) list of input words + sep: (str) word separator + iwords: (list[str]) list of input words to be updated + owords: (list[str]) list of output words to be updated + word_alignment: (list[tuple]) list of word alignments to be updated + Returns: + (bool) True if the word alignment is found, False otherwise + """ + success = True + token_start_idx = word_start_idx = 0 + iwords_len = owords_len = 0 + + while token_start_idx < len(tokens): + token = tokens[token_start_idx] + current_word = input_words[word_start_idx] + if token == f"tokens {{ name: \"{current_word}\" }}": + word_alignment.append(([iwords_len], [owords_len], DEFAULT_SEMIOTIC_CLASS)) + iwords.append(current_word) + owords.append(current_word) + iwords_len += 1 + owords_len += 1 + else: + indices, success, word_start_idx = self.find_token_words(token, word_start_idx, input_words, sep) + if success: + span_text, semiotic_class = self.parse_and_verbalize(token, sep) + if span_text is None: + logging.warning(f"Failed to verbalize the token: {token}") + return False + + span_words, n_span_words = split_text(span_text, sep) + word_alignment.append( + ( + [iwords_len + i for i in range(len(indices))], + [owords_len + i for i in range(n_span_words)], + semiotic_class, + ) + ) + owords.extend(span_words) + iwords.extend([input_words[i] for i in indices]) + iwords_len += len(indices) + owords_len += n_span_words + else: + success = False + break + + token_start_idx += 1 + word_start_idx += 1 + + return success + + def get_word_alignment(self, input: str | list[str], sep: str) -> tuple[list[str], list[str], list[tuple]]: + """ + Returns a word alignment for the input text. 
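        For example (illustrative), the input "it costs twenty five dollars" may produce input words
        ["it", "costs", "twenty", "five", "dollars"], output words ["it", "costs", "$25"], and an
        alignment such as [([0], [0], cls), ([1], [1], cls), ([2, 3, 4], [2], cls)], where each tuple
        maps input word indices to output word indices together with a semiotic class label.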
+ Args: + input: (str | list[str]) input text or list of input words + sep: (str) word separator + Returns: + (tuple) Returns a tuple of input words, output words, and a word alignment between input and output words + """ + + if isinstance(input, str): + input_text = input + input_words, n_words = split_text(input_text, sep) + else: + input_words, n_words = input, len(input) + input_text = sep.join(input_words) + + # If input_text is empty, return empty lists + if n_words == 0: + return [], [], [] + + # Tag the input text + tagged_text = self.tag(input_text, no_cache=False) + + # Find the tokens in the tagged text + tokens = find_tokens(tagged_text) + + # Find the word alignment + iwords, owords, word_alignment = [], [], [] + success = self.find_alignment( + tokens, input_words, sep, iwords=iwords, owords=owords, word_alignment=word_alignment + ) + + # If the word alignment is not found, fallback to the trivial alignment + if not success: + return fallback_to_trivial_alignment(input_words, i_shift=0, o_shift=0) + + return iwords, owords, word_alignment diff --git a/nemo/collections/asr/inference/model_wrappers/__init__.py b/nemo/collections/asr/inference/model_wrappers/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/model_wrappers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/model_wrappers/asr_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/asr_inference_wrapper.py new file mode 100644 index 000000000000..323a96e45e5a --- /dev/null +++ b/nemo/collections/asr/inference/model_wrappers/asr_inference_wrapper.py @@ -0,0 +1,289 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
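The alignment returned by `get_word_alignment` above pairs index spans of spoken-form input words with index spans of written-form output words, plus a semiotic class per span, and falls back to the trivial word-for-word alignment on failure. Below is a minimal usage sketch, assuming `normalizer` is an already-constructed `AlignmentPreservingInverseNormalizer`; the example sentence, written forms, and class names are illustrative assumptions that depend on the grammar in use.

# Hypothetical usage sketch; `normalizer` and the example sentence are assumptions.
iwords, owords, alignment = normalizer.get_word_alignment("it costs twenty five dollars", sep=" ")
for in_idx, out_idx, semiotic_class in alignment:
    spoken = " ".join(iwords[i] for i in in_idx)
    written = " ".join(owords[j] for j in out_idx)
    print(f"{spoken!r} -> {written!r} [{semiotic_class}]")
# Plain words map one-to-one with the default semiotic class; a span such as
# "twenty five dollars" would collapse to a single written token (e.g. "$25")
# with a money-like class.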
+ + +import copy +from functools import cached_property +from typing import Callable + +import torch +from omegaconf import DictConfig, open_dict + +from nemo.collections.asr.inference.utils.constants import SENTENCEPIECE_UNDERSCORE +from nemo.collections.asr.inference.utils.device_utils import setup_device +from nemo.collections.asr.inference.utils.pipeline_utils import make_preprocessor_deterministic +from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.collections.asr.parts.utils.asr_confidence_utils import get_confidence_aggregation_bank + +SUPPORTED_CONFIDENCE_AGGREGATORS = get_confidence_aggregation_bank() + + +class ASRInferenceWrapper: + """ + Base class for ASR inference wrappers. + It provides a common interface for ASR inference wrappers. + Derived classes MUST implement the following methods: + - __post_init__: Additional post initialization steps that must be implemented in the derived classes. + - get_blank_id: Returns the blank id for the model. + - get_vocabulary: Returns the vocabulary for the model. + - get_subsampling_factor: Returns the subsampling factor for the model. + """ + + def __init__( + self, + model_name: str, + decoding_cfg: CTCDecodingConfig | RNNTDecodingConfig, + device: str = 'cuda', + device_id: int = 0, + compute_dtype: str = 'bfloat16', + use_amp: bool = True, + ): + """ + Initialize the ASR inference wrapper. + Args: + model_name: (str) path to the model checkpoint or a model name from the NGC cloud. + decoding_cfg: (CTCDecodingConfig | RNNTDecodingConfig) decoding configuration. + device: (str) device to run the model on. + device_id: (int) device ID to run the model on. + compute_dtype: (str) compute dtype to run the model on. + use_amp: (bool) Use Automatic Mixed Precision + """ + + self.decoding_cfg = decoding_cfg + self.device_str, self.device_id, self.compute_dtype = setup_device(device.strip(), device_id, compute_dtype) + self.device = torch.device(self.device_str) + self.use_amp = use_amp + self.asr_model = self.load_model(model_name, self.device) + self.asr_model_cfg = self.asr_model._cfg + self.set_dither_to_zero() + self.tokenizer = self.asr_model.tokenizer + + # post initialization steps that must be implemented in the derived classes + self.__post_init__() + + @staticmethod + def load_model(model_name: str, map_location: torch.device) -> ASRModel: + """ + Load the ASR model. + Args: + model_name: (str) path to the model checkpoint or a model name from the NGC cloud. + map_location: (torch.device) device to load the model on. + Returns: + (ASRModel) loaded ASR model. + """ + try: + if model_name.endswith('.nemo'): + asr_model = ASRModel.restore_from(model_name, map_location=map_location) + else: + asr_model = ASRModel.from_pretrained(model_name, map_location=map_location) + asr_model.eval() + return asr_model + except Exception as e: + raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") + + @property + def word_separator(self) -> str: + """ + Returns word separator. + Returns: + (str) word separator. + """ + return self.decoding_cfg.word_seperator + + @property + def confidence_aggregator(self) -> Callable: + """ + Returns confidence aggregator function. + Returns: + (Callable) confidence aggregator function. 
+ """ + return SUPPORTED_CONFIDENCE_AGGREGATORS[self.decoding_cfg.confidence_cfg.aggregation] + + def copy_asr_config(self) -> DictConfig: + """ + Copies the ASR model config. + Returns: + (DictConfig) copy of the ASR model configuration. + """ + return copy.deepcopy(self.asr_model_cfg) + + def create_preprocessor(self) -> tuple[Callable, DictConfig]: + """ + Creates a deterministic preprocessor from the ASR model configuration. + Disables normalization, dither and padding. + Returns: + (Callable, DictConfig) deterministic preprocessor and its configuration. + """ + new_asr_config = self.copy_asr_config() + new_asr_config = make_preprocessor_deterministic(new_asr_config) + preprocessor_config = copy.deepcopy(new_asr_config.preprocessor) + preprocessor = ASRModel.from_config_dict(preprocessor_config) + preprocessor.to(self.device) + return preprocessor, preprocessor_config + + def supports_capitalization(self) -> bool: + """ + Checks if the ASR model supports capitalization. + Returns: + (bool) True if the ASR model supports capitalization, False otherwise. + """ + if not hasattr(self, "asr_model") or self.asr_model is None: + raise ValueError("ASR model is not initialized.") + return self.tokenizer.supports_capitalization + + def supports_punctuation(self) -> bool: + """ + Checks if the ASR model supports punctuation. + Returns: + (bool) True if the ASR model supports punctuation, False otherwise. + """ + if not hasattr(self, "asr_model") or self.asr_model is None: + raise ValueError("ASR model is not initialized.") + return self.supported_punctuation() != set() + + def supported_punctuation(self) -> set: + """ + Returns supported punctuation symbol set without single quote. + Returns: + (set) Set of supported punctuation symbols. + """ + return self.tokenizer.supported_punctuation - set("'") + + @cached_property + def punctuation_ids(self) -> set: + """ + Returns ids of supported punctuation symbols. + Returns: + (set) Set of punctuation ids. + """ + punctuation_ids = set() + if self.supports_punctuation(): + for punctuation in self.supported_punctuation(): + punctuation_ids.add(self.tokenizer.tokens_to_ids(punctuation)[0]) + return punctuation_ids + + @cached_property + def underscore_id(self) -> int: + """ + Returns id of the underscore token. + Returns: + (int) underscore id for the model. + """ + if getattr(self.asr_model.tokenizer, "spm_separator_id", None) is not None: + return self.asr_model.tokenizer.spm_separator_id + else: + return self.asr_model.tokenizer.tokens_to_ids(SENTENCEPIECE_UNDERSCORE) + + @cached_property + def language_token_ids(self) -> set: + """ + This property is used for some Riva models that have language tokens included in the vocabulary. + Returns: + (set) Set of language token ids. + """ + vocab = self.get_vocabulary() + language_token_ids = set() + for token in vocab: + if token.startswith("<") and token.endswith(">") and token != "": + language_token_ids.add(self.asr_model.tokenizer.tokens_to_ids(token)[0]) + return language_token_ids + + def reset_decoding_strategy(self, decoder_type: str) -> None: + """ + Reset the decoding strategy for the model. + Args: + decoder_type: (str) decoding type either 'ctc', 'rnnt'. + """ + if isinstance(self.asr_model, EncDecHybridRNNTCTCModel): + self.asr_model.change_decoding_strategy(decoding_cfg=None, decoder_type=decoder_type) + else: + self.asr_model.change_decoding_strategy(None) + + def set_decoding_strategy(self, decoder_type: str) -> None: + """ + Set the decoding strategy for the model. 
+ Args: + decoder_type: (str) decoding type either 'ctc', 'rnnt'. + """ + if isinstance(self.asr_model, EncDecHybridRNNTCTCModel): + self.asr_model.change_decoding_strategy(decoding_cfg=self.decoding_cfg, decoder_type=decoder_type) + else: + self.asr_model.change_decoding_strategy(self.decoding_cfg) + + def set_dither_to_zero(self) -> None: + """ + To remove randomness from preprocessor set the dither value to zero. + """ + self.asr_model.preprocessor.featurizer.dither = 0.0 + with open_dict(self.asr_model_cfg): + self.asr_model_cfg.preprocessor.dither = 0.0 + + def get_window_stride(self) -> float: + """ + Get the window stride for the model. + Returns: + (float) window stride for the model. + """ + return self.asr_model_cfg.preprocessor.window_stride + + def get_model_stride(self, in_secs: bool = False, in_milliseconds: bool = False) -> float: + """ + Get the model stride in seconds for the model. + Args: + in_secs: (bool) Whether to return the model stride in seconds. + in_milliseconds: (bool) Whether to return the model stride in milliseconds. + Returns: + (float) model stride in seconds or milliseconds. + """ + if in_secs and in_milliseconds: + raise ValueError("Cannot return both seconds and milliseconds at the same time.") + if in_secs: + return self.get_window_stride() * self.get_subsampling_factor() + if in_milliseconds: + return self.get_window_stride() * self.get_subsampling_factor() * 1000 + + return self.get_window_stride() * self.get_subsampling_factor() + + # Methods that must be implemented in the derived classes. + def __post_init__(self): + """ + Additional post initialization steps that must be implemented in the derived classes. + """ + raise NotImplementedError() + + def get_blank_id(self) -> int: + """ + Returns id of the blank token. + Returns: + (int) blank id for the model. + """ + raise NotImplementedError() + + def get_vocabulary(self) -> list[str]: + """ + Returns the list of vocabulary tokens. + Returns: + (list[str]) list of vocabulary tokens. + """ + raise NotImplementedError() + + def get_subsampling_factor(self) -> int: + """ + Returns the subsampling factor for the model. + Returns: + (int) subsampling factor for the model. + """ + raise NotImplementedError() diff --git a/nemo/collections/asr/inference/model_wrappers/cache_aware_asr_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/cache_aware_asr_inference_wrapper.py new file mode 100644 index 000000000000..e05c0d029e0b --- /dev/null +++ b/nemo/collections/asr/inference/model_wrappers/cache_aware_asr_inference_wrapper.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any + +from torch import Tensor + +from nemo.collections.asr.inference.model_wrappers.asr_inference_wrapper import ASRInferenceWrapper + + +class CacheAwareASRInferenceWrapper(ASRInferenceWrapper): + """ + Base class for Cache-Aware inference wrappers. + It provides a common interface for Cache-Aware models. 
+ Derived classes MUST implement the following methods: + - stream_step: Executes a single streaming step. + """ + + def get_input_features(self) -> int: + """ + Returns the number of channels in the input features. + Returns: + (int) number of channels in the input features. + """ + return self.asr_model.encoder._feat_in + + def get_sampling_frames(self) -> list[int] | int | None: + """ + It is used for checking to make sure the audio chunk has enough frames to produce at least one output after downsampling. + Returns: + (list[int] | int | None) sampling frames for the encoder. + """ + self.sampling_frames = None + if hasattr(self.asr_model.encoder, "pre_encode") and hasattr( + self.asr_model.encoder.pre_encode, "get_sampling_frames" + ): + self.sampling_frames = self.asr_model.encoder.pre_encode.get_sampling_frames() + return self.sampling_frames + + def get_initial_cache_state(self, batch_size: int) -> tuple[Tensor, Tensor, Tensor]: + """ + Returns the initial cache state for the encoder. + Returns: + (tuple[Tensor, Tensor, Tensor]) the initial cache state of the encoder. + """ + return self.asr_model.encoder.get_initial_cache_state(batch_size=batch_size) + + def get_drop_extra_pre_encoded(self) -> int: + """ + Returns the number of extra pre-encoded frames to drop. + Returns: + (int) drop_extra_pre_encoded. + """ + return self.asr_model.encoder.streaming_cfg.drop_extra_pre_encoded + + def get_chunk_size(self) -> list[int] | int: + """ + Returns the chunk size for the encoder. + Returns: + (list[int] | int) the chunk size. + """ + return self.asr_model.encoder.streaming_cfg.chunk_size + + def get_shift_size(self) -> list[int] | int: + """ + Returns the shift size for the encoder. + Returns: + (list[int] | int) the shift size. + """ + return self.asr_model.encoder.streaming_cfg.shift_size + + def get_pre_encode_cache_size(self) -> list[int] | int: + """ + Returns the pre-encode cache size for the encoder. + Returns: + (list[int] | int) the pre_encode cache size. + """ + return self.asr_model.encoder.streaming_cfg.pre_encode_cache_size + + def get_subsampling_factor(self) -> int: + """ + Returns the subsampling factor for the ASR encoder. + Returns: + (int) subsampling factor for the ASR encoder model. + """ + return self.asr_model.encoder.subsampling_factor + + def get_att_context_size(self) -> list: + """ + Returns the attention context size for the encoder. + Returns: + (list) copy of the attention context size. + """ + return self.asr_model.encoder.att_context_size.copy() + + def set_default_att_context_size(self, att_context_size: list) -> None: + """ + Set the default attention context size for the encoder. + The list of the supported look-ahead: [[70, 13], [70, 6], [70, 1], [70, 0]] + Args: + att_context_size: (list) the attention context size. + """ + if hasattr(self.asr_model.encoder, "set_default_att_context_size"): + self.asr_model.encoder.set_default_att_context_size(att_context_size=att_context_size) + else: + raise ValueError("Model does not support multiple lookaheads.") + + def setup_streaming_params(self, chunk_size: int, shift_size: int) -> None: + """ + Setup the streaming parameters (chunk_size, shift_size) for the encoder. + Args: + chunk_size: (int) the chunk size. + shift_size: (int) the shift size. + """ + self.asr_model.encoder.setup_streaming_params(chunk_size=chunk_size, shift_size=shift_size) + + def stream_step(self, *args, **kwargs) -> Any: + """ + Executes a single streaming step. 
+ Each derived class must implement this method, with arguments and return types specific to that class. + """ + raise NotImplementedError( + "`stream_step` method is not implemented. It is required for cache-aware transcribers." + ) diff --git a/nemo/collections/asr/inference/model_wrappers/cache_aware_ctc_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/cache_aware_ctc_inference_wrapper.py new file mode 100644 index 000000000000..d644e2eda14c --- /dev/null +++ b/nemo/collections/asr/inference/model_wrappers/cache_aware_ctc_inference_wrapper.py @@ -0,0 +1,203 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from torch import Tensor + +from nemo.collections.asr.inference.model_wrappers.cache_aware_asr_inference_wrapper import ( + CacheAwareASRInferenceWrapper, +) +from nemo.collections.asr.inference.utils.context_manager import CacheAwareContext +from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel +from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder + + +class CacheAwareCTCInferenceWrapper(CacheAwareASRInferenceWrapper): + """ + Provides a unified interface to work with Cache-Aware CTC models. + """ + + def __post_init__(self) -> None: + """ + Additional post initialization step + Checks if the model is a ctc model and sets the decoding strategy to ctc. + """ + + if not isinstance(self.asr_model, (EncDecCTCModel, EncDecHybridRNNTCTCModel)): + raise ValueError( + "Provided model is not a CTC type. You are trying to use a CTC Inference with a non-CTC model." + ) + + if not isinstance(self.asr_model.encoder, StreamingEncoder): + raise NotImplementedError("Encoder of this model does not support streaming!") + + decoder_type = 'ctc' + if isinstance(self.asr_model, EncDecHybridRNNTCTCModel): + self.asr_model.cur_decoder = decoder_type + + # reset the decoding strategy + self.reset_decoding_strategy(decoder_type) + self.set_decoding_strategy(decoder_type) + + # setup streaming parameters + if self.asr_model.encoder.streaming_cfg is None: + self.asr_model.encoder.setup_streaming_params() + + self.drop_extra_pre_encoded = self.get_drop_extra_pre_encoded() + + def get_blank_id(self) -> int: + """ + Returns id of the blank token. + Returns: + (int) blank id for the model. + """ + if isinstance(self.asr_model, EncDecCTCModel): + blank_id = len(self.asr_model.decoder.vocabulary) + else: + blank_id = len(self.asr_model.ctc_decoder.vocabulary) + return blank_id + + def get_vocabulary(self) -> list[str]: + """ + Returns the list of vocabulary tokens. + Returns: + (list[str]) list of vocabulary tokens. 
+ """ + if isinstance(self.asr_model, EncDecCTCModel): + return self.asr_model.decoder.vocabulary + else: + return self.asr_model.ctc_decoder.vocabulary + + def execute_step( + self, + processed_signal: Tensor, + processed_signal_length: Tensor, + context: CacheAwareContext, + drop_extra_pre_encoded: int | None, + keep_all_outputs: bool, + drop_left_context: int | None = None, + valid_out_len: int | None = None, + return_tail_result: bool = False, + ) -> tuple[Tensor, Tensor | None, CacheAwareContext]: + """ + Executes a single streaming step. + Args: + processed_signal: (Tensor) input signal tensor. + processed_signal_length: (Tensor) input signal length tensor. + context: (CacheAwareContext) context object. + drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. + keep_all_outputs: (bool) whether to keep all outputs or not. + drop_left_context: (int | None) number of left context frames to drop. + valid_out_len: (int | None) number of valid output frames. + return_tail_result: (bool) whether to return tail result or not. + Returns: + (tuple[Tensor, Tensor | None, CacheAwareContext]) log probabilities, tail log probabilities and new context. + """ + + ( + encoded, + encoded_len, + cache_last_channel, + cache_last_time, + cache_last_channel_len, + ) = self.asr_model.encoder.cache_aware_stream_step( + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + cache_last_channel=context.cache_last_channel, + cache_last_time=context.cache_last_time, + cache_last_channel_len=context.cache_last_channel_len, + keep_all_outputs=keep_all_outputs, + drop_extra_pre_encoded=drop_extra_pre_encoded, + ) + + if drop_left_context: + # drop left context + encoded = encoded[:, :, drop_left_context:] + + if isinstance(self.asr_model, EncDecHybridRNNTCTCModel): + all_log_probs = self.asr_model.ctc_decoder(encoder_output=encoded) + else: + all_log_probs = self.asr_model.decoder(encoder_output=encoded) + + tail_log_probs = None + if valid_out_len and not keep_all_outputs: + # drop right context if any + log_probs = all_log_probs[:, :valid_out_len, :] + if return_tail_result: + tail_log_probs = all_log_probs[:, valid_out_len:, :] + else: + log_probs = all_log_probs + + # create a new context + new_context = CacheAwareContext( + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + ) + return log_probs, tail_log_probs, new_context + + def stream_step( + self, + processed_signal: Tensor, + processed_signal_length: Tensor, + context: CacheAwareContext = None, + drop_extra_pre_encoded: int | None = None, + keep_all_outputs: bool = False, + drop_left_context: int | None = None, + valid_out_len: int | None = None, + return_tail_result: bool = False, + ) -> tuple[Tensor, Tensor | None, CacheAwareContext]: + """ + Executes a single streaming step. + Args: + processed_signal: (Tensor) input signal tensor. + processed_signal_length: (Tensor) input signal length tensor. + context: (CacheAwareContext) context object. + drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. + keep_all_outputs: (bool) whether to keep all outputs or not. + drop_left_context: (int | None) number of left context frames to drop. + valid_out_len: (int | None) number of valid output frames. + return_tail_result: (bool) whether to return tail result or not. + Returns: + (tuple[Tensor, Tensor | None, CacheAwareContext]) log probabilities, tail log probabilities and new context. 
+ """ + + if processed_signal.device != self.device: + processed_signal = processed_signal.to(self.device) + + if processed_signal_length.device != self.device: + processed_signal_length = processed_signal_length.to(self.device) + + if context is None: + # create a dummy context + context = CacheAwareContext() + + with ( + torch.amp.autocast(device_type=self.device_str, dtype=self.compute_dtype, enabled=self.use_amp), + torch.inference_mode(), + torch.no_grad(), + ): + + log_probs, tail_log_probs, new_context = self.execute_step( + processed_signal, + processed_signal_length, + context, + drop_extra_pre_encoded, + keep_all_outputs, + drop_left_context, + valid_out_len, + return_tail_result, + ) + return log_probs, tail_log_probs, new_context diff --git a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py new file mode 100644 index 000000000000..1bc7181d5cb5 --- /dev/null +++ b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py @@ -0,0 +1,190 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import Tensor + +from nemo.collections.asr.inference.model_wrappers.cache_aware_asr_inference_wrapper import ( + CacheAwareASRInferenceWrapper, +) +from nemo.collections.asr.inference.utils.context_manager import CacheAwareContext +from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel +from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis + + +class CacheAwareRNNTInferenceWrapper(CacheAwareASRInferenceWrapper): + """ + Provides a unified interface to work with Cache-Aware RNNT models. + """ + + def __post_init__(self) -> None: + """ + Additional post initialization step + Checks if the model is a rnnt model and sets the decoding strategy to rnnt. + """ + if not isinstance(self.asr_model, (EncDecRNNTModel, EncDecHybridRNNTCTCModel)): + raise ValueError( + "Provided model is not a RNNT type. You are trying to use a RNNT Inference with a non-RNNT model." + ) + + if not isinstance(self.asr_model.encoder, StreamingEncoder): + raise NotImplementedError("Encoder of this model does not support streaming!") + + decoder_type = 'rnnt' + if isinstance(self.asr_model, EncDecHybridRNNTCTCModel): + self.asr_model.cur_decoder = decoder_type + + # reset the decoding strategy + self.reset_decoding_strategy(decoder_type) + self.set_decoding_strategy(decoder_type) + + # setup streaming parameters + if self.asr_model.encoder.streaming_cfg is None: + self.asr_model.encoder.setup_streaming_params() + + self.drop_extra_pre_encoded = self.get_drop_extra_pre_encoded() + + def get_blank_id(self) -> int: + """ + Returns id of the blank token. + Returns: + (int) blank id for the model. 
+ """ + blank_id = len(self.asr_model.joint.vocabulary) + return blank_id + + def get_vocabulary(self) -> list[str]: + """ + Returns the list of vocabulary tokens. + Returns: + (list[str]) list of vocabulary tokens. + """ + return self.asr_model.joint.vocabulary + + def execute_step( + self, + processed_signal: Tensor, + processed_signal_length: Tensor, + context: CacheAwareContext, + previous_hypotheses: list[Hypothesis] | None, + drop_extra_pre_encoded: int | None, + keep_all_outputs: bool, + drop_left_context: int | None = None, + valid_out_len: int | None = None, + ) -> tuple[list[Hypothesis], CacheAwareContext]: + """ + Executes a single streaming step. + Args: + processed_signal: (Tensor) input signal tensor. + processed_signal_length: (Tensor) input signal length tensor. + context: (CacheAwareContext) context object. + previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding. + drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. + keep_all_outputs: (bool) whether to keep all outputs or not. + drop_left_context: (int | None) number of left context frames to drop. + valid_out_len: (int | None) number of valid output frames. + Returns: + (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context. + """ + ( + encoded, + encoded_len, + cache_last_channel, + cache_last_time, + cache_last_channel_len, + ) = self.asr_model.encoder.cache_aware_stream_step( + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + cache_last_channel=context.cache_last_channel, + cache_last_time=context.cache_last_time, + cache_last_channel_len=context.cache_last_channel_len, + keep_all_outputs=keep_all_outputs, + drop_extra_pre_encoded=drop_extra_pre_encoded, + ) + new_context = CacheAwareContext( + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + ) + + if drop_left_context: + # drop left context + encoded = encoded[:, :, drop_left_context:] + encoded_len = encoded_len - drop_left_context + + if valid_out_len and not keep_all_outputs: + # drop right context if any + encoded = encoded[:, :, :valid_out_len] + encoded_len = torch.ones_like(encoded_len) * valid_out_len + + best_hyp = self.asr_model.decoding.rnnt_decoder_predictions_tensor( + encoded, encoded_len, return_hypotheses=True, partial_hypotheses=previous_hypotheses + ) + return best_hyp, new_context + + def stream_step( + self, + processed_signal: Tensor, + processed_signal_length: Tensor, + context: CacheAwareContext = None, + previous_hypotheses: list[Hypothesis] | None = None, + drop_extra_pre_encoded: int | None = None, + keep_all_outputs: bool = False, + drop_left_context: int | None = None, + valid_out_len: int | None = None, + ) -> tuple[list[Hypothesis], CacheAwareContext]: + """ + Executes a single streaming step. + Args: + processed_signal: (Tensor) input signal tensor. + processed_signal_length: (Tensor) input signal length tensor. + context: (CacheAwareContext) context object. + previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding. + drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. + keep_all_outputs: (bool) whether to keep all outputs or not. + drop_left_context: (int | None) number of left context frames to drop. + valid_out_len: (int | None) number of valid output frames. + Returns: + (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context. 
+ """ + + if processed_signal.device != self.device: + processed_signal = processed_signal.to(self.device) + + if processed_signal_length.device != self.device: + processed_signal_length = processed_signal_length.to(self.device) + + if context is None: + # create a dummy context + context = CacheAwareContext() + + with ( + torch.amp.autocast(device_type=self.device_str, dtype=self.compute_dtype, enabled=self.use_amp), + torch.inference_mode(), + torch.no_grad(), + ): + + best_hyp, new_context = self.execute_step( + processed_signal, + processed_signal_length, + context, + previous_hypotheses, + drop_extra_pre_encoded, + keep_all_outputs, + drop_left_context, + valid_out_len, + ) + + return best_hyp, new_context diff --git a/nemo/collections/asr/inference/model_wrappers/ctc_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/ctc_inference_wrapper.py new file mode 100644 index 000000000000..64304f6611be --- /dev/null +++ b/nemo/collections/asr/inference/model_wrappers/ctc_inference_wrapper.py @@ -0,0 +1,109 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import Tensor + +from nemo.collections.asr.inference.model_wrappers.asr_inference_wrapper import ASRInferenceWrapper +from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel + + +class CTCInferenceWrapper(ASRInferenceWrapper): + """ + Provides a unified interface to work with CTC/Hybrid-CTC models. + """ + + def __post_init__(self) -> None: + """ + Additional post initialization step + Checks if the model is a ctc model and sets the decoding strategy to ctc. + """ + if not isinstance(self.asr_model, (EncDecCTCModel, EncDecHybridRNNTCTCModel)): + raise ValueError( + "Provided model is not a CTC type. You are trying to use a CTC transcriber with a non-CTC model." + ) + + decoder_type = 'ctc' + if isinstance(self.asr_model, EncDecHybridRNNTCTCModel): + self.asr_model.cur_decoder = decoder_type + + # reset the decoding strategy + self.reset_decoding_strategy(decoder_type) + self.set_decoding_strategy(decoder_type) + + self.cast_dtype = torch.float32 if self.use_amp else self.compute_dtype + self.asr_model.to(self.cast_dtype) + + def get_blank_id(self) -> int: + """ + Returns id of the blank token. + Returns: + (int) blank id for the model. + """ + if isinstance(self.asr_model, EncDecCTCModel): + blank_id = len(self.asr_model.decoder.vocabulary) + else: + blank_id = len(self.asr_model.ctc_decoder.vocabulary) + return blank_id + + def get_vocabulary(self) -> list[str]: + """ + Returns the list of vocabulary tokens. + Returns: + (list[str]) list of vocabulary tokens. + """ + if isinstance(self.asr_model, EncDecCTCModel): + return self.asr_model.decoder.vocabulary + else: + return self.asr_model.ctc_decoder.vocabulary + + def get_subsampling_factor(self) -> int: + """ + Returns the subsampling factor for the ASR encoder. + Returns: + (int) subsampling factor for the ASR encoder model. 
+ """ + return self.asr_model.encoder.subsampling_factor + + def get_logprobs(self, processed_signal: Tensor, processed_signal_length: Tensor) -> Tensor: + """ + Get log probabilities from the model. It is used for streaming inference. + Args: + processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]). + processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]). + Returns: + (Tensor) log probabilities. Shape is torch.Size([B, T, V+1]). + """ + if processed_signal.device != self.device: + processed_signal = processed_signal.to(self.device) + + if processed_signal_length.device != self.device: + processed_signal_length = processed_signal_length.to(self.device) + + with ( + torch.amp.autocast(device_type=self.device_str, dtype=self.compute_dtype, enabled=self.use_amp), + torch.inference_mode(), + torch.no_grad(), + ): + + forward_outs = self.asr_model( + processed_signal=processed_signal.to(self.cast_dtype), processed_signal_length=processed_signal_length + ) + + if isinstance(self.asr_model, EncDecHybridRNNTCTCModel): + encoded, encoded_len = forward_outs + log_probs = self.asr_model.ctc_decoder(encoder_output=encoded.clone()) + else: + log_probs, encoded_len, predictions = forward_outs + return log_probs diff --git a/nemo/collections/asr/inference/model_wrappers/rnnt_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/rnnt_inference_wrapper.py new file mode 100644 index 000000000000..f15b5ff07ef7 --- /dev/null +++ b/nemo/collections/asr/inference/model_wrappers/rnnt_inference_wrapper.py @@ -0,0 +1,115 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from torch import Tensor + +from nemo.collections.asr.inference.model_wrappers.asr_inference_wrapper import ASRInferenceWrapper +from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel + + +class RNNTInferenceWrapper(ASRInferenceWrapper): + """ + Provides a unified interface to work with RNNT/TDT/Hybrid models. + """ + + def __post_init__(self) -> None: + """ + Additional post initialization step + Checks if the model is a rnnt model and sets the decoding strategy to rnnt. + """ + if not isinstance(self.asr_model, (EncDecRNNTModel, EncDecHybridRNNTCTCModel)): + raise ValueError( + "Provided model is not a RNNT type. You are trying to use a RNNT transcriber with a non-RNNT model." + ) + + decoder_type = 'rnnt' + if isinstance(self.asr_model, EncDecHybridRNNTCTCModel): + self.asr_model.cur_decoder = decoder_type + + # reset the decoding strategy + self.reset_decoding_strategy(decoder_type) + self.set_decoding_strategy(decoder_type) + + self.cast_dtype = torch.float32 if self.use_amp else self.compute_dtype + self.asr_model.to(self.cast_dtype) + + def get_blank_id(self) -> int: + """ + Returns id of the blank token. + Returns: + (int) blank id for the model. 
+ """ + blank_id = len(self.asr_model.joint.vocabulary) + return blank_id + + def get_vocabulary(self) -> list[str]: + """ + Returns the list of vocabulary tokens. + Returns: + (list[str]) list of vocabulary tokens. + """ + return self.asr_model.joint.vocabulary + + def get_subsampling_factor(self) -> int: + """ + Returns the subsampling factor for the ASR encoder. + Returns: + (int) subsampling factor for the ASR encoder model. + """ + return self.asr_model.encoder.subsampling_factor + + def encode(self, processed_signal: Tensor, processed_signal_length: Tensor) -> tuple[Tensor, Tensor]: + """ + Get encoder output from the model. It is used for streaming inference. + Args: + processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]). + processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]). + Returns: + (tuple[Tensor, Tensor]) encoder output and encoder output length of shape torch.Size([B, T, D]), torch.Size([B]). + """ + if processed_signal.device != self.device: + processed_signal = processed_signal.to(self.device) + + if processed_signal_length.device != self.device: + processed_signal_length = processed_signal_length.to(self.device) + + with ( + torch.amp.autocast(device_type=self.device_str, dtype=self.compute_dtype, enabled=self.use_amp), + torch.inference_mode(), + torch.no_grad(), + ): + + forward_outs = self.asr_model( + processed_signal=processed_signal.to(self.cast_dtype), processed_signal_length=processed_signal_length + ) + + encoded, encoded_len = forward_outs + return encoded, encoded_len + + def decode(self, encoded: Tensor, encoded_len: Tensor, partial_hypotheses: list) -> list: + """ + Decode the encoder output using the RNNT decoder. + Args: + encoded: (Tensor) encoder output. + encoded_len: (Tensor) encoder output length. + partial_hypotheses: (list) list of partial hypotheses for stateful decoding. + Returns: + (list) list of best hypotheses. + """ + best_hyp = self.asr_model.decoding.rnnt_decoder_predictions_tensor( + encoded.to(self.cast_dtype), encoded_len, return_hypotheses=True, partial_hypotheses=partial_hypotheses + ) + return best_hyp diff --git a/nemo/collections/asr/inference/pipelines/__init__.py b/nemo/collections/asr/inference/pipelines/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/pipelines/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/pipelines/base_pipeline.py b/nemo/collections/asr/inference/pipelines/base_pipeline.py new file mode 100644 index 000000000000..a52a3d22d0be --- /dev/null +++ b/nemo/collections/asr/inference/pipelines/base_pipeline.py @@ -0,0 +1,423 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import re
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Iterable
+
+from omegaconf import DictConfig
+
+from nemo.collections.asr.inference.model_wrappers.asr_inference_wrapper import ASRInferenceWrapper
+from nemo.collections.asr.inference.pipelines.pipeline_interface import PipelineInterface
+from nemo.collections.asr.inference.streaming.buffering.audio_bufferer import BatchedAudioBufferer
+from nemo.collections.asr.inference.streaming.buffering.cache_feature_bufferer import BatchedCacheFeatureBufferer
+from nemo.collections.asr.inference.streaming.buffering.feature_bufferer import BatchedFeatureBufferer
+from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer
+from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request
+from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions
+from nemo.collections.asr.inference.streaming.state.state import StreamingState
+from nemo.collections.asr.inference.streaming.text.text_processing import StreamingTextProcessor
+from nemo.collections.asr.inference.utils.bpe_decoder import BPEDecoder
+from nemo.collections.asr.inference.utils.context_manager import CacheAwareContextManager
+from nemo.collections.asr.inference.utils.enums import RequestType
+from nemo.collections.asr.inference.utils.pipeline_utils import (
+    check_existance_of_required_attributes,
+    get_leading_punctuation_regex_pattern,
+    ids_to_text_without_stripping,
+)
+from nemo.collections.asr.inference.utils.progressbar import ProgressBar
+from nemo.collections.asr.inference.utils.text_segment import TextSegment
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+
+if TYPE_CHECKING:
+    from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer
+
+
+@dataclass
+class TranscribeStepOutput:
+    """
+    Stores the output of a single transcribe step.
+    """
+
+    stream_id: int
+    # Final transcript is the transcript generated from the previous EoU to the current EoU.
+    # It is the finalized transcript, optionally punctuated and ITN-normalized, and is not subject to further modifications.
+    # Final segments contain metadata for each word/segment in the final transcript.
+    final_transcript: str | None = None
+    final_segments: list[TextSegment] | None = None
+    # Partial transcript is the transcript generated from the previous EoU up to the current frame.
+    # It is not finalized and may be subject to further modifications.
+    # It can also contain transcript from future frames.
+    partial_transcript: str | None = None
+    # Current step transcript is the transcript generated from the current frame.
+    current_step_transcript: str | None = None
+
+    @classmethod
+    def from_state(cls, state: StreamingState, request: Request, sep: str = ' ') -> 'TranscribeStepOutput':
+        """
+        Create a TranscribeStepOutput from a StreamingState.
+        Args:
+            state (StreamingState): The state to create the output from.
+ request (Request): The request to create the output from. + sep (str): The separator for the text postprocessor. + Returns: + TranscribeStepOutput: The output for the step. + """ + final_transcript = state.final_transcript.strip() + final_segments = [seg.copy() for seg in state.final_segments] + if final_transcript: + separator = '' + if not request.is_first and state.concat_with_space: + separator = sep + final_transcript = separator + final_transcript + if len(final_segments) > 0: + final_segments[0].text = separator + final_segments[0].text + return cls( + stream_id=request.stream_id, + final_transcript=final_transcript, + final_segments=final_segments, + partial_transcript=state.partial_transcript, + current_step_transcript=state.current_step_transcript, + ) + + +class BasePipeline(PipelineInterface): + """ + Base class for all pipelines. + """ + + def __init__(self): + """Initialize state pool to store the state for each stream""" + self._state_pool: dict[int, StreamingState] = {} + + def get_state(self, stream_id: int) -> StreamingState: + """Retrieve state for a given stream ID.""" + return self._state_pool.get(stream_id, None) + + def get_states(self, stream_ids: Iterable[int]) -> list[StreamingState]: + """Retrieve states for a list of stream IDs.""" + return [self.get_state(stream_id) for stream_id in stream_ids] + + def delete_state(self, stream_id: int) -> None: + """Delete the state from the state pool.""" + if stream_id in self._state_pool: + del self._state_pool[stream_id] + + def delete_states(self, stream_ids: Iterable[int]) -> None: + """Delete states for a list of stream IDs.""" + for stream_id in stream_ids: + self.delete_state(stream_id) + + def init_state(self, stream_id: int, options: ASRRequestOptions) -> StreamingState: + """Initialize the state of the stream""" + if stream_id not in self._state_pool: + state = self.create_state(options) + self._state_pool[stream_id] = state + return self._state_pool[stream_id] + + def reset_session(self) -> None: + """Reset the frame buffer and internal state pool""" + self._state_pool.clear() + + def open_session(self) -> None: + """Start a new session by resetting the internal state pool""" + self.reset_session() + + def close_session(self) -> None: + """Close the session by resetting the internal state pool""" + self.reset_session() + + @abstractmethod + def transcribe_step_for_frames(self, frames: list[Frame]) -> None: + """Transcribe a step for frames""" + pass + + @abstractmethod + def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: + """Transcribe a step for feature buffers""" + pass + + @abstractmethod + def get_request_generator(self) -> ContinuousBatchedRequestStreamer: + """Return the request generator.""" + pass + + @abstractmethod + def get_sep(self) -> str: + """Return the separator for the text postprocessor.""" + pass + + def transcribe_step(self, requests: list[Request]) -> list[TranscribeStepOutput]: + """ + Transcribe a step + Args: + requests (list[Request]): List of Request objects. + Returns: + list[TranscribeStepOutput]: List of TranscribeStepOutput objects. 
+ """ + + # Initialize the state if it is the first request for the stream + for request in requests: + if request.is_first: + self.init_state(request.stream_id, request.options) + + # Perform the transcribe step for the frames or feature buffers + if isinstance(requests[0], Frame): + self.transcribe_step_for_frames(frames=requests) + elif isinstance(requests[0], FeatureBuffer): + self.transcribe_step_for_feature_buffers(fbuffers=requests) + else: + raise ValueError(f"Invalid request type: {type(requests[0])}") + + # Create current step output for each request + outputs = [] + for request in requests: + + # Extract current step output from the state + state = self.get_state(request.stream_id) + step_output = TranscribeStepOutput.from_state(state=state, request=request, sep=self.get_sep()) + outputs.append(step_output) + + # Cleanup the state after the response is sent + state.cleanup_after_response() + + # If last request, delete state from the state pool to free memory + if request.is_last: + self.delete_state(request.stream_id) + return outputs + + def copy_asr_model_attributes(self, asr_model: ASRInferenceWrapper) -> None: + """ + Copy the attributes from the ASR model + Args: + asr_model (ASRInferenceWrapper): ASR model to copy the attributes from. + """ + self.asr_model = asr_model + self.tokenizer = asr_model.tokenizer + self.device = asr_model.device + self.supports_punctuation = asr_model.supports_punctuation() + self.asr_supported_puncts = asr_model.supported_punctuation() + self.leading_regex_pattern = get_leading_punctuation_regex_pattern(self.asr_supported_puncts) + self.blank_id = asr_model.get_blank_id() + self.vocabulary = asr_model.get_vocabulary() + self.sep = asr_model.word_separator + self.underscore_id = asr_model.underscore_id + self.punctuation_ids = asr_model.punctuation_ids + self.language_token_ids = asr_model.language_token_ids + self.preprocessor, self.preprocessor_config = asr_model.create_preprocessor() + self.subsampling_factor = asr_model.get_subsampling_factor() + self.window_stride = asr_model.get_window_stride() + self.model_stride_in_secs = asr_model.get_model_stride(in_secs=True) + self.model_stride_in_milliseconds = asr_model.get_model_stride(in_milliseconds=True) + + def update_partial_transcript( + self, requests: list[Request], tokenizer: TokenizerSpec, leading_regex_pattern: str + ) -> None: + """ + Update partial and current step transcripts from the state. + Args: + requests (list[Request]): List of Request objects. + tokenizer (TokenizerSpec): Used to convert tokens into text + leading_regex_pattern (str): Regex pattern for the punctuation marks. 
+ """ + word_separator = self.get_sep() + for request in requests: + state = self.get_state(request.stream_id) + # state tokens represent all tokens accumulated since the EOU + # incomplete segment tokens are the remaining tokens on the right side of the buffer after EOU + all_tokens = state.tokens + state.incomplete_segment_tokens + if len(all_tokens) > 0: + pt_string = ids_to_text_without_stripping(all_tokens, tokenizer, word_separator) + if leading_regex_pattern: + pt_string = re.sub(leading_regex_pattern, r'\1', pt_string) + state.partial_transcript = pt_string + else: + state.partial_transcript = "" + + current_step_tokens = state.current_step_tokens + if len(current_step_tokens) > 0: + step_transcript = ids_to_text_without_stripping(current_step_tokens, tokenizer, word_separator) + state.current_step_transcript = step_transcript + else: + state.current_step_transcript = "" + + def init_bpe_decoder(self) -> None: + """Initialize the BPE decoder""" + check_existance_of_required_attributes( + self, + [ + 'vocabulary', + 'tokenizer', + 'confidence_aggregator', + 'asr_supported_puncts', + 'word_boundary_tolerance', + 'model_stride_in_secs', + ], + ) + + self.bpe_decoder = BPEDecoder( + vocabulary=self.vocabulary, + tokenizer=self.tokenizer, + confidence_aggregator=self.confidence_aggregator, + asr_supported_puncts=self.asr_supported_puncts, + word_boundary_tolerance=self.word_boundary_tolerance, + token_duration_in_secs=self.model_stride_in_secs, + ) + + def init_text_processor( + self, + cfg: DictConfig, + itn_model: AlignmentPreservingInverseNormalizer | None, + ) -> None: + """ + Initialize the text processor. + Args: + cfg: (DictConfig) Configuration parameters. + itn_model: (AlignmentPreservingInverseNormalizer | None) Inverse Text Normalization model. 
+ """ + check_existance_of_required_attributes( + self, + [ + 'asr_supported_puncts', + 'supports_punctuation', + 'confidence_aggregator', + 'sep', + ], + ) + + self.text_processor = StreamingTextProcessor( + itn_cfg=cfg.itn, + itn_model=itn_model, + asr_supported_puncts=self.asr_supported_puncts, + asr_supports_punctuation=self.supports_punctuation, + confidence_aggregator=self.confidence_aggregator, + sep=self.sep, + enable_pnc=cfg.enable_pnc, + enable_itn=cfg.enable_itn, + ) + + def init_bufferer_for_buffered_streaming(self) -> None: + """Initialize the bufferer.""" + check_existance_of_required_attributes( + self, + [ + 'request_type', + 'sample_rate', + 'buffer_size_in_secs', + 'preprocessor_config', + 'device', + ], + ) + + if self.request_type is RequestType.FEATURE_BUFFER: + # Feature buffering: It will be used when the input is feature buffers + self.bufferer = BatchedFeatureBufferer( + sample_rate=self.sample_rate, + buffer_size_in_secs=self.buffer_size_in_secs, + preprocessor_cfg=self.preprocessor_config, + device=self.device, + ) + elif self.request_type is RequestType.FRAME: + # Audio buffering: It will be used when the input is audio frames + self.bufferer = BatchedAudioBufferer( + sample_rate=self.sample_rate, buffer_size_in_secs=self.buffer_size_in_secs + ) + else: + raise ValueError(f"Unknown request type: {self.request_type}") + + def init_bufferer_for_cache_aware_streaming(self) -> None: + """Initialize the bufferer for cache-aware streaming.""" + check_existance_of_required_attributes( + self, + [ + 'use_feat_cache', + 'chunk_size_in_secs', + 'buffer_size_in_secs', + 'sample_rate', + 'preprocessor_config', + 'device', + ], + ) + + if self.use_feat_cache: + # Only calculate mel-spec features for last chunk + chunk_size_for_feature_buffer = self.chunk_size_in_secs + else: + # Calculate mel-spec features for the whole buffer + chunk_size_for_feature_buffer = self.buffer_size_in_secs + + self.bufferer = BatchedCacheFeatureBufferer( + sample_rate=self.sample_rate, + buffer_size_in_secs=self.buffer_size_in_secs, + chunk_size_in_secs=chunk_size_for_feature_buffer, + preprocessor_cfg=self.preprocessor_config, + device=self.device, + ) + + def init_context_manager(self) -> None: + """Initialize the context manager.""" + check_existance_of_required_attributes(self, ['asr_model', 'num_slots', 'use_cache']) + self.context_manager = CacheAwareContextManager( + cache_aware_model=self.asr_model, num_slots=self.num_slots, use_cache=self.use_cache + ) + + def run( + self, + audio_filepaths: list[str], + options: list[ASRRequestOptions] | None = None, + progress_bar: ProgressBar | None = None, + ) -> dict: + """ + Orchestrates reading from audio_filepaths in a streaming manner, + transcribes them, and packs the results into a PipelineOutput. + Args: + audio_filepaths (list[str]): List of audio filepaths to transcribe. + options (list[ASRRequestOptions] | None): List of RequestOptions for each stream. + progress_bar (ProgressBar | None): Progress bar to show the progress. Default is None. + Returns: + dict: A dictionary containing transcriptions and segments for each stream. 
+ """ + if progress_bar is not None and not isinstance(progress_bar, ProgressBar): + raise ValueError("progress_bar must be an instance of ProgressBar.") + + if options is None: + # Use default options if not provided + options = [ASRRequestOptions() for _ in audio_filepaths] + + if len(options) != len(audio_filepaths): + raise ValueError("options must be the same length as audio_filepaths") + + request_generator = self.get_request_generator() + request_generator.set_audio_filepaths(audio_filepaths, options) + request_generator.set_progress_bar(progress_bar) + + pipeline_output = {} + self.open_session() + for requests in request_generator: + step_outputs = self.transcribe_step(requests) + for step_output in step_outputs: + stream_id = step_output.stream_id + if stream_id not in pipeline_output: + pipeline_output[stream_id] = { + "text": "", + "segments": [], + "audio_filepath": request_generator.get_audio_filepath(stream_id), + } + pipeline_output[stream_id]["text"] += step_output.final_transcript + pipeline_output[stream_id]["segments"].extend(step_output.final_segments) + self.close_session() + return pipeline_output diff --git a/nemo/collections/asr/inference/pipelines/buffered_ctc_pipeline.py b/nemo/collections/asr/inference/pipelines/buffered_ctc_pipeline.py new file mode 100644 index 000000000000..1bc8d1f98344 --- /dev/null +++ b/nemo/collections/asr/inference/pipelines/buffered_ctc_pipeline.py @@ -0,0 +1,440 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import torch +from omegaconf import DictConfig +from torch import Tensor + +from nemo.collections.asr.inference.model_wrappers.ctc_inference_wrapper import CTCInferenceWrapper +from nemo.collections.asr.inference.pipelines.base_pipeline import BasePipeline +from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_ctc_decoder import ClippedCTCGreedyDecoder +from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_ctc_endpointing import CTCGreedyEndpointing +from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer +from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request +from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions +from nemo.collections.asr.inference.streaming.state.ctc_state import CTCStreamingState +from nemo.collections.asr.inference.utils.enums import FeatureBufferPaddingMode, RequestType +from nemo.collections.asr.inference.utils.pipeline_utils import ( + check_existance_of_required_attributes, + drop_trailing_features, + get_confidence_utils, + normalize_features, + normalize_log_probs, +) + +if TYPE_CHECKING: + from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer + + +class BufferedCTCPipeline(BasePipeline): + """Buffered CTC pipeline.""" + + def __init__( + self, + cfg: DictConfig, + asr_model: CTCInferenceWrapper, + itn_model: AlignmentPreservingInverseNormalizer | None = None, + ): + """ + Initialize the BufferedCTCPipeline. + Args: + cfg: (DictConfig) Configuration parameters. + asr_model: (CTCInferenceWrapper) ASR model. + itn_model: (AlignmentPreservingInverseNormalizer | None) Inverse Text Normalization model. + """ + self.copy_asr_model_attributes(asr_model) + self.init_parameters(cfg) + self.init_bufferer_for_buffered_streaming() + self.conf_func, self.confidence_aggregator = get_confidence_utils(cfg.confidence) + self.init_endpointer() + self.init_bpe_decoder() + self.init_greedy_ctc_decoder() + self.init_text_processor(cfg, itn_model) + super().__init__() + + def init_parameters(self, cfg: DictConfig) -> None: + """ + Initialize the configuration parameters. + Args: + cfg: (DictConfig) Configuration parameters. 
+ """ + self.sample_rate = cfg.streaming.sample_rate + self.asr_output_granularity = cfg.asr_output_granularity + self.batch_size = cfg.streaming.batch_size + + self.chunk_size = cfg.streaming.chunk_size + self.left_padding_size = cfg.streaming.left_padding_size + self.right_padding_size = cfg.streaming.right_padding_size + self.buffer_size_in_secs = self.chunk_size + self.left_padding_size + self.right_padding_size + self.expected_feature_buffer_len = int(self.buffer_size_in_secs / self.window_stride) + self.tokens_per_frame_float = self.chunk_size / self.model_stride_in_secs + self.tokens_per_frame = math.ceil(self.tokens_per_frame_float) + self.initial_delay = (self.left_padding_size + self.right_padding_size) / self.model_stride_in_secs + self.mid_delay = math.ceil((self.chunk_size + self.right_padding_size) / self.model_stride_in_secs) + + self.stop_history_eou_in_milliseconds = cfg.endpointing.stop_history_eou + self.residue_tokens_at_end = cfg.endpointing.residue_tokens_at_end + self.request_type = RequestType.from_str(cfg.streaming.request_type) + self.word_boundary_tolerance = cfg.streaming.word_boundary_tolerance + self.padding_mode = FeatureBufferPaddingMode.from_str(cfg.streaming.padding_mode) + self.right_padding = self.padding_mode is FeatureBufferPaddingMode.RIGHT + self.return_tail_result = cfg.return_tail_result + + # Keep small amount of extra padding + self.tail_padding_in_samples = max(int(self.chunk_size * self.sample_rate * 0.45), 6400) + self.zero_log_probs = self.init_zero_log_probs() if self.right_padding else None + + def init_endpointer(self) -> None: + """Initialize the endpointing.""" + check_existance_of_required_attributes( + self, + [ + 'vocabulary', + 'model_stride_in_milliseconds', + 'stop_history_eou_in_milliseconds', + 'residue_tokens_at_end', + ], + ) + + self.endpointer = CTCGreedyEndpointing( + vocabulary=self.vocabulary, + ms_per_timestep=self.model_stride_in_milliseconds, + stop_history_eou=self.stop_history_eou_in_milliseconds, + residue_tokens_at_end=self.residue_tokens_at_end, + ) + + def init_greedy_ctc_decoder(self) -> None: + """Initialize the CTC decoder.""" + check_existance_of_required_attributes(self, ['vocabulary', 'conf_func', 'endpointer', 'tokens_per_frame']) + self.greedy_ctc_decoder = ClippedCTCGreedyDecoder( + vocabulary=self.vocabulary, + conf_func=self.conf_func, + endpointer=self.endpointer, + tokens_per_frame=self.tokens_per_frame, + ) + + def init_zero_log_probs(self) -> Tensor: + """ + Initialize the log probabilities for the zero buffer. + Returns: + (Tensor) Log probabilities for the zero buffer. + """ + check_existance_of_required_attributes( + self, ['asr_model', 'buffer_size_in_secs', 'sample_rate', 'device', 'expected_feature_buffer_len'] + ) + buffer_size_in_samples = int(self.buffer_size_in_secs * self.sample_rate) + zero_buffer = torch.zeros(1, buffer_size_in_samples, device=self.device) + zero_features, zero_features_len = self.preprocess( + buffers=zero_buffer, + buffer_lens=torch.tensor([zero_buffer.shape[1]], device=self.device), + expected_feature_buffer_len=self.expected_feature_buffer_len, + ) + return self.asr_model.get_logprobs(processed_signal=zero_features, processed_signal_length=zero_features_len)[ + 0 + ] + + def create_state(self, options: ASRRequestOptions) -> CTCStreamingState: + """ + Create new empty state. + Args: + options: (ASRRequestOptions) Request options for particular stream. + Returns: + (CTCStreamingState) New empty state. 
+ """ + state = CTCStreamingState() + state.set_global_offset(-self.initial_delay) + new_options = options.augment_with_defaults( + default_enable_itn=self.text_processor.is_itn_enabled(), + default_enable_pnc=self.text_processor.is_pnc_enabled(), + default_stop_history_eou=self.stop_history_eou_in_milliseconds, + default_asr_output_granularity=self.asr_output_granularity, + ) + state.set_options(new_options) + return state + + def get_sep(self) -> str: + """Return the separator for the text processor.""" + return self.sep + + def get_cut_off_range(self, T: int, is_last: bool) -> tuple[int, int]: + """ + Compute the start and end indices to clip the log probs. + Args: + T: (int) Time dimension of the log probabilities. + is_last: (bool) Whether the last frame is reached. + Returns: + (tuple[int, int]) Start and end indices to clip the log probs. + """ + start = max(T - 1 - self.mid_delay, 0) + end = T if is_last else min(start + self.tokens_per_frame, T) + return start, end + + def preprocess( + self, buffers: Tensor, buffer_lens: Tensor, expected_feature_buffer_len: int + ) -> tuple[Tensor, Tensor]: + """ + Preprocess the buffered frames and extract features. + Args: + buffers: (Tensor) Audio buffers. + buffer_lens: (Tensor) Lengths of the audio buffers. + expected_feature_buffer_len: (int) Expected length of the feature buffers. + Returns: + (tuple[Tensor, Tensor]) Processed feature buffers and their lengths. + """ + feature_buffers, feature_buffer_lens = self.preprocessor(input_signal=buffers, length=buffer_lens) + feature_buffers = drop_trailing_features(feature_buffers, expected_feature_buffer_len) + feature_buffers = normalize_features(feature_buffers, feature_buffer_lens) + feature_buffer_lens = feature_buffer_lens.clamp(max=feature_buffers.shape[2]) + return feature_buffers, feature_buffer_lens + + def get_logprobs_given_raw_signals( + self, frames: list[Frame], raw_signals: list[Tensor], left_paddings: list[int] + ) -> Tensor: + """ + Get log probs from the CTC model. + Args: + frames: (list[Frame]) Frames to transcribe. + raw_signals: (list[Tensor]) Audio buffers. + left_paddings: (list[int]) Left paddings for audio buffers. + Returns: + (Tensor) Log probabilities. 
+ """ + + if self.right_padding: + left_paddings = torch.tensor(left_paddings, dtype=torch.int64, device=self.device) + + buffers = [] + for i in range(len(raw_signals)): + buffer = raw_signals[i] + # Roll the buffered frames to the left by the left padding + # This is done to avoid the padding at the beginning of the buffered frames + # which can cause the performance degradation + if self.right_padding: + lpad = left_paddings[i].item() + if lpad > 0: + buffer = buffer.roll(shifts=-lpad) + buffers.append(buffer.unsqueeze_(0)) + + # Only final frames have right padding + # Keep some amount of extra padding to avoid the performance degradation + right_paddings = torch.tensor( + [frame.size - frame.valid_size - self.tail_padding_in_samples for frame in frames], device=self.device + ).clamp(min=0) + + # Create and adjust the buffer lens + buffer_lens = torch.tensor([buffers[0].size(1)] * len(buffers), device=self.device) + buffer_lens = buffer_lens - right_paddings + if self.right_padding: + buffer_lens = buffer_lens - left_paddings + + # Preprocess the buffers with corresponding buffer lens + feature_buffers, feature_buffer_lens = self.preprocess( + buffers=torch.cat(buffers).to(self.device), + buffer_lens=buffer_lens, + expected_feature_buffer_len=self.expected_feature_buffer_len, + ) + + # Get the log probabilities from the ASR model + log_probs = self.asr_model.get_logprobs( + processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens + ).clone() + + # Roll back the log probabilities to the right + if self.right_padding: + for i in range(len(log_probs)): + lpad = left_paddings[i] + if lpad > 0: + lpad = int(lpad / self.sample_rate / self.model_stride_in_secs) + log_probs[i] = log_probs[i].roll(lpad, dims=0) + log_probs[i][:lpad, :] = self.zero_log_probs[:lpad, :] + return log_probs + + def get_logprobs_given_processed_signals( + self, fbuffers: list[FeatureBuffer], processed_signals: list[Tensor] + ) -> Tensor: + """ + Get log probs from the ASR model. + Args: + fbuffers: (list[FeatureBuffer]) Feature buffers. + processed_signals: (list[Tensor]) Processed buffers. + Returns: + (Tensor) Log probabilities. + """ + processed_signals = torch.cat([sig.unsqueeze_(0) for sig in processed_signals]).to(self.device) + processed_signals = drop_trailing_features(processed_signals, self.expected_feature_buffer_len) + processed_signal_lengths = torch.tensor([f.valid_size for f in fbuffers], device=self.device) + processed_signals = normalize_features(processed_signals, processed_signal_lengths) + processed_signal_lengths = processed_signal_lengths.clamp(max=processed_signals.shape[2]) + + log_probs = self.asr_model.get_logprobs( + processed_signal=processed_signals, processed_signal_length=processed_signal_lengths + ).clone() + + if self.right_padding: + for i in range(len(log_probs)): + lpad = int(fbuffers[i].roll_size / self.subsampling_factor) + if lpad > 0: + log_probs[i] = log_probs[i].roll(lpad, dims=0) + log_probs[i][:lpad, :] = self.zero_log_probs[:lpad, :] + return log_probs + + def compute_logprobs_from_frames(self, frames: list[Frame]) -> Tensor: + """ + Buffer the frames and get the log probabilities. + Args: + frames: (list[Frame]) List of frames to transcribe. + Returns: + (Tensor) Log probabilities. 
+ """ + raw_signals, left_paddings = self.bufferer.update(frames) + log_probs = None + if len(raw_signals) > 0: + log_probs = self.get_logprobs_given_raw_signals(frames, raw_signals, left_paddings) + return log_probs + + def compute_logprobs_from_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> Tensor: + """ + Buffer the feature buffers and get the log probabilities. + Args: + fbuffers: (list[FeatureBuffer]) List of feature buffers to transcribe. + Returns: + (Tensor) Log probabilities. + """ + processed_signals = self.bufferer.update(fbuffers) + log_probs = None + if len(processed_signals) > 0: + log_probs = self.get_logprobs_given_processed_signals(fbuffers, processed_signals) + return log_probs + + def run_greedy_decoder( + self, state: CTCStreamingState, request: Request, buffer_log_probs: Tensor, start: int, end: int + ) -> bool: + """ + Run Greedy decoder, update state and trigger EOU detection. + Args: + state: (CTCStreamingState) Current state for the particular stream. + request: (Request) Current request for the particular stream. + buffer_log_probs: (Tensor) Log probabilities. + start: (int) Start index of the log probabilities. + end: (int) End index of the log probabilities. + Returns: + (bool) Whether EOU is detected. + """ + clipped_output, tail_output, eou_detected, start_idx, end_idx = self.greedy_ctc_decoder( + buffer_log_probs, + start, + end, + request.is_last, + is_start=request.is_first, + return_partial_result=self.return_tail_result, + state_start_idx=state.decoder_start_idx, + state_end_idx=state.decoder_end_idx, + stop_history_eou=state.options.stop_history_eou, + compute_confidence=True, + ) + + state.update_state(clipped_output, eou_detected) + state.set_last_token(clipped_output["last_token"], clipped_output["last_token_idx"]) + state.update_from_decoder_results(start_idx, end_idx) + state.increment_global_offset(self.tokens_per_frame_float) + state.set_incomplete_segment_tokens(tail_output["tokens"]) + return eou_detected + + def shared_transcribe_step(self, requests: list[Request], log_probs: Tensor) -> None: + """ + Shared transcribe step for frames and feature buffers. + Args: + requests: (list[Request]) List of frames or feature buffers to transcribe. + log_probs: (Tensor) Log probabilities. + """ + postponed_requests = [(ridx, request.stream_id) for ridx, request in enumerate(requests)] + next_postponed_requests = [] + + while len(postponed_requests) > 0: + + ready_state_ids = set() + for ridx, stream_id in postponed_requests: + + if stream_id in ready_state_ids: + # Skip if the state is already ready + next_postponed_requests.append((ridx, stream_id)) + continue + + request = requests[ridx] + state = self.get_state(stream_id) + lp = log_probs[ridx].cpu() + start, end = self.get_cut_off_range(lp.shape[0], request.is_last) + eou_detected = self.run_greedy_decoder(state, request, lp, start, end) + + if eou_detected: + self.bpe_decoder.decode_bpe_tokens(state) + state.cleanup_after_eou() + ready_state_ids.add(stream_id) + + if len(ready_state_ids) > 0: + self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) + ready_state_ids.clear() + + postponed_requests = next_postponed_requests.copy() + next_postponed_requests.clear() + + self.update_partial_transcript(requests, self.tokenizer, self.leading_regex_pattern) + + def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: + """ + Transcribe a step for feature buffers. 
+ Args: + fbuffers: (list[FeatureBuffer]) List of feature buffers to transcribe. + """ + log_probs = self.compute_logprobs_from_feature_buffers(fbuffers) + if log_probs is not None: + log_probs = normalize_log_probs(log_probs) + self.shared_transcribe_step(requests=fbuffers, log_probs=log_probs) + + def transcribe_step_for_frames(self, frames: list[Frame]) -> None: + """ + Transcribe step for frames. + Args: + frames: (list[Frame]) List of frames to transcribe. + """ + log_probs = self.compute_logprobs_from_frames(frames) + if log_probs is not None: + log_probs = normalize_log_probs(log_probs) + self.shared_transcribe_step(requests=frames, log_probs=log_probs) + + def get_request_generator(self) -> ContinuousBatchedRequestStreamer: + """ + Initialize the request generator. + Returns: + (ContinuousBatchedRequestStreamer) Request generator. + """ + request_generator = ContinuousBatchedRequestStreamer( + n_frames_per_stream=1, + frame_size_in_secs=self.chunk_size, + sample_rate=self.sample_rate, + batch_size=self.batch_size, + request_type=self.request_type, + preprocessor=self.preprocessor, + buffer_size_in_secs=self.buffer_size_in_secs, + device=self.device, + pad_last_frame=True, + right_pad_features=self.right_padding, + tail_padding_in_samples=self.tail_padding_in_samples, + ) + return request_generator diff --git a/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py new file mode 100644 index 000000000000..915063d4d22b --- /dev/null +++ b/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py @@ -0,0 +1,698 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
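The buffered CTC pipeline above derives its streaming geometry in init_parameters from the chunk and padding sizes. A small worked example of those formulas, using illustrative values rather than the defaults of any shipped config:

import math

chunk_size = 4.8             # seconds of new audio consumed per step (assumed)
left_padding_size = 1.6      # seconds of left context kept in the buffer (assumed)
right_padding_size = 1.6     # seconds of lookahead context (assumed)
model_stride_in_secs = 0.08  # encoder output stride (assumed)

buffer_size_in_secs = chunk_size + left_padding_size + right_padding_size        # 8.0 s fed to the encoder
tokens_per_frame = math.ceil(chunk_size / model_stride_in_secs)                  # 60 encoder frames emitted per step
mid_delay = math.ceil((chunk_size + right_padding_size) / model_stride_in_secs)  # 80 frames held back for lookahead
print(buffer_size_in_secs, tokens_per_frame, mid_delay)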
+ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import torch +from omegaconf import DictConfig +from torch import Tensor + +from nemo.collections.asr.inference.model_wrappers.rnnt_inference_wrapper import RNNTInferenceWrapper +from nemo.collections.asr.inference.pipelines.base_pipeline import BasePipeline +from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_rnnt_decoder import ClippedRNNTGreedyDecoder +from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_rnnt_endpointing import RNNTGreedyEndpointing +from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer +from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request +from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions +from nemo.collections.asr.inference.streaming.state.rnnt_state import RNNTStreamingState +from nemo.collections.asr.inference.utils.enums import FeatureBufferPaddingMode, RequestType +from nemo.collections.asr.inference.utils.pipeline_utils import ( + adjust_vad_segments, + check_existance_of_required_attributes, + drop_trailing_features, + get_confidence_utils, + normalize_features, + update_punctuation_and_language_tokens_timestamps, +) +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis as NemoHypothesis + +if TYPE_CHECKING: + from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer + + +class BufferedRNNTPipeline(BasePipeline): + """Buffered RNN-T/TDT pipeline.""" + + def __init__( + self, + cfg: DictConfig, + asr_model: RNNTInferenceWrapper, + itn_model: AlignmentPreservingInverseNormalizer | None = None, + ): + """ + Initialize the BufferedRNNTPipeline. + Args: + cfg: (DictConfig) Configuration parameters. + asr_model: (RNNTInferenceWrapper) ASR model. + itn_model: (AlignmentPreservingInverseNormalizer | None) Inverse Text Normalization model. + """ + + self.copy_asr_model_attributes(asr_model) + self.init_parameters(cfg) + self.init_bufferer_for_buffered_streaming() + self.conf_func, self.confidence_aggregator = get_confidence_utils(cfg.confidence) + self.init_endpointer() + self.init_greedy_rnnt_decoder() + self.init_bpe_decoder() + self.init_decoding_computer() + self.init_text_processor(cfg, itn_model) + super().__init__() + + def init_parameters(self, cfg: DictConfig) -> None: + """ + Initialize the configuration parameters. + Args: + cfg: (DictConfig) Configuration parameters. 
+ """ + self.asr_output_granularity = cfg.asr_output_granularity + self.sample_rate = cfg.streaming.sample_rate + self.stateful = cfg.streaming.stateful + self.stateless = not self.stateful + self.batch_size = cfg.streaming.batch_size + + self.chunk_size = cfg.streaming.chunk_size + self.left_padding_size = cfg.streaming.left_padding_size + self.right_padding_size = cfg.streaming.right_padding_size + self.buffer_size_in_secs = self.chunk_size + self.left_padding_size + self.right_padding_size + self.expected_feature_buffer_len = int(self.buffer_size_in_secs / self.window_stride) + + self.mid_delay = math.ceil((self.chunk_size + self.right_padding_size) / self.model_stride_in_secs) + self.tokens_per_frame_float = self.chunk_size / self.model_stride_in_secs + self.tokens_per_left_padding_float = self.left_padding_size / self.model_stride_in_secs + self.tokens_per_right_padding_float = self.right_padding_size / self.model_stride_in_secs + self.tokens_per_frame = math.ceil(self.tokens_per_frame_float) + self.tokens_per_left_padding = math.ceil(self.tokens_per_left_padding_float) + self.tokens_per_right_padding = math.ceil(self.tokens_per_right_padding_float) + + if self.stateful: + self.initial_delay = self.right_padding_size / self.model_stride_in_secs + else: + self.initial_delay = (self.left_padding_size + self.right_padding_size) / self.model_stride_in_secs + + if self.stateful and ( + abs(self.tokens_per_frame_float - self.tokens_per_frame) > 1e-5 + or abs(self.tokens_per_left_padding_float - self.tokens_per_left_padding) > 1e-5 + or abs(self.tokens_per_right_padding_float - self.tokens_per_right_padding) > 1e-5 + ): + self.tokens_per_frame_float = self.tokens_per_frame + self.tokens_per_left_padding_float = self.tokens_per_left_padding + self.left_padding_size = self.tokens_per_left_padding * self.model_stride_in_secs + self.chunk_size = self.tokens_per_frame * self.model_stride_in_secs + self.right_padding_size = self.tokens_per_right_padding * self.model_stride_in_secs + self.buffer_size_in_secs = self.chunk_size + self.left_padding_size + self.right_padding_size + + self.request_type = RequestType.from_str(cfg.streaming.request_type) + self.padding_mode = FeatureBufferPaddingMode.from_str(cfg.streaming.padding_mode) + self.right_padding = self.padding_mode is FeatureBufferPaddingMode.RIGHT + self.stop_history_eou_in_milliseconds = cfg.endpointing.stop_history_eou + self.residue_tokens_at_end = cfg.endpointing.residue_tokens_at_end + self.word_boundary_tolerance = cfg.streaming.word_boundary_tolerance + self.return_tail_result = cfg.return_tail_result + self.tokens_to_move = self.punctuation_ids.union(self.language_token_ids) + + # Keep small amount of extra padding + self.tail_padding_in_samples = max(int(self.chunk_size * self.sample_rate * 0.45), 6400) + self.zero_encoded = self.init_zero_enc() if self.right_padding else None + + def init_endpointer(self) -> None: + """Initialize the endpointer.""" + check_existance_of_required_attributes( + self, + [ + 'stateful', + 'chunk_size', + 'right_padding_size', + 'buffer_size_in_secs', + 'vocabulary', + 'model_stride_in_milliseconds', + 'stop_history_eou_in_milliseconds', + 'residue_tokens_at_end', + ], + ) + + if self.stateful: + effective_buffer_size_in_secs = self.chunk_size + self.right_padding_size + else: + effective_buffer_size_in_secs = self.buffer_size_in_secs + + self.endpointer = RNNTGreedyEndpointing( + vocabulary=self.vocabulary, + ms_per_timestep=self.model_stride_in_milliseconds, + 
effective_buffer_size_in_secs=effective_buffer_size_in_secs, + stop_history_eou=self.stop_history_eou_in_milliseconds, + residue_tokens_at_end=self.residue_tokens_at_end, + ) + + def init_greedy_rnnt_decoder(self) -> None: + """Initialize the greedy RNNT decoder.""" + check_existance_of_required_attributes(self, ['vocabulary', 'conf_func', 'endpointer', 'tokens_per_frame']) + self.greedy_rnnt_decoder = ClippedRNNTGreedyDecoder( + vocabulary=self.vocabulary, + conf_func=self.conf_func, + endpointer=self.endpointer, + tokens_per_frame=self.tokens_per_frame, + ) + + def init_decoding_computer(self) -> None: + """Initialize the decoding computer.""" + check_existance_of_required_attributes(self, ['stateful', 'asr_model']) + self.decoding_computer = None + if self.stateful: + self.decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer + + def init_zero_enc(self) -> Tensor: + """ + Initialize the encoder output for the zero buffer. + Returns: + (Tensor) Encoder output for the zero buffer. + """ + check_existance_of_required_attributes( + self, ['buffer_size_in_secs', 'sample_rate', 'device', 'expected_feature_buffer_len'] + ) + buffer_size_in_samples = int(self.buffer_size_in_secs * self.sample_rate) + zero_buffer = torch.zeros(1, buffer_size_in_samples, device=self.device) + zero_features, zero_features_len = self.preprocess( + buffers=zero_buffer, + buffer_lens=torch.tensor([zero_buffer.shape[1]], device=self.device), + expected_feature_buffer_len=self.expected_feature_buffer_len, + ) + zero_encoded, _ = self.asr_model.encode( + processed_signal=zero_features, processed_signal_length=zero_features_len + ) + return zero_encoded[0] + + def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState: + """ + Create new empty state. + Args: + options: (ASRRequestOptions) Request options for particular stream. + Returns: + (RNNTStreamingState) New empty state. + """ + state = RNNTStreamingState() + state.set_global_offset(-self.initial_delay) + new_options = options.augment_with_defaults( + default_enable_itn=self.text_processor.is_itn_enabled(), + default_enable_pnc=self.text_processor.is_pnc_enabled(), + default_stop_history_eou=self.stop_history_eou_in_milliseconds, + default_asr_output_granularity=self.asr_output_granularity, + ) + state.set_options(new_options) + return state + + def get_sep(self) -> str: + """Return the separator for the text processor.""" + return self.sep + + def preprocess( + self, buffers: Tensor, buffer_lens: Tensor, expected_feature_buffer_len: int + ) -> tuple[Tensor, Tensor]: + """ + Preprocess the buffered frames and extract features. + Args: + buffers: (Tensor) Audio buffers. + buffer_lens: (Tensor) Lengths of the audio buffers. + expected_feature_buffer_len: (int) Expected length of the feature buffers. + Returns: + (tuple[Tensor, Tensor]) Processed feature buffers and their lengths. + """ + feature_buffers, feature_buffer_lens = self.preprocessor(input_signal=buffers, length=buffer_lens) + feature_buffers = drop_trailing_features(feature_buffers, expected_feature_buffer_len) + feature_buffers = normalize_features(feature_buffers, feature_buffer_lens) + feature_buffer_lens = feature_buffer_lens.clamp(max=feature_buffers.shape[2]) + return feature_buffers, feature_buffer_lens + + def get_cut_off_range(self, T: int, is_last: bool) -> tuple[int, int]: + """ + Compute the start and end indices to clip. + Args: + T: (int) Time dimension of the alignment. + is_last: (bool) Whether the last frame is reached. 
+ Returns: + (tuple[int, int]) Start and end indices to clip. + """ + start = max(T - 1 - self.mid_delay, 0) + end = T if is_last else min(start + self.tokens_per_frame, T) + return start, end + + def encode_raw_signals( + self, frames: list[Frame], raw_signals: list[Tensor], left_paddings: list[int] + ) -> tuple[Tensor, Tensor]: + """ + Run Encoder part on the audio buffers. + Args: + frames: (list[Frame]) Frames to transcribe. + raw_signals: (list[Tensor]) Audio buffers. + left_paddings: (list[int]) Left paddings for audio buffers. + Returns: + (tuple[Tensor, Tensor]) Encoded signals and their lengths. + """ + + if self.right_padding: + left_paddings = torch.tensor(left_paddings, dtype=torch.int64, device=self.device) + + buffers = [] + for i in range(len(raw_signals)): + buffer = raw_signals[i] + if self.right_padding: + # Roll the buffered frames to the left by the left padding + # This is done to avoid the padding at the beginning of the buffered frames + # which can cause the performance degradation + lpad = left_paddings[i].item() + if lpad > 0: + buffer = buffer.roll(shifts=-lpad) + buffers.append(buffer.unsqueeze_(0)) + + # Only final frames have right padding + # Keep some amount of extra padding to avoid the performance degradation + right_paddings = torch.tensor( + [frame.size - frame.valid_size - self.extra_padding_in_samples for frame in frames], device=self.device + ).clamp(min=0) + + # Create and adjust the buffer lens + buffer_lens = torch.tensor([buffers[0].size(1)] * len(buffers), device=self.device) + buffer_lens = buffer_lens - right_paddings + if self.right_padding: + buffer_lens = buffer_lens - left_paddings + + feature_buffers, feature_buffer_lens = self.preprocess( + buffers=torch.cat(buffers).to(self.device), + buffer_lens=buffer_lens, + expected_feature_buffer_len=self.expected_feature_buffer_len, + ) + + encoded, encoded_len = self.asr_model.encode( + processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens + ) + encoded = encoded.clone() + encoded_len = encoded_len.clone() + + # Roll back the encoded signals to the right + if self.right_padding: + for i in range(encoded.shape[0]): + lpad = left_paddings[i] + if lpad > 0: + lpad = int(lpad / self.sample_rate / self.model_stride_in_secs) + encoded[i] = encoded[i].roll(lpad, dims=1) + encoded[i][:, :lpad] = self.zero_encoded[:, :lpad] + encoded_len[i] = encoded_len[i] + lpad + + return encoded, encoded_len + + def encode_processed_signals( + self, fbuffers: list[FeatureBuffer], processed_signals: list[Tensor] + ) -> tuple[Tensor, Tensor]: + """ + Run Encoder part on the feature buffers. + Args: + fbuffers: (list[FeatureBuffer]) Feature buffers. + processed_signals: (list[Tensor]) Processed buffers. + Returns: + (tuple[Tensor, Tensor]) Encoder output and their lengths. 
+ """ + + processed_signals = torch.cat([sig.unsqueeze_(0) for sig in processed_signals]).to(self.device) + processed_signals = drop_trailing_features(processed_signals, self.expected_feature_buffer_len) + processed_signal_lengths = torch.tensor([f.valid_size for f in fbuffers], device=self.device) + processed_signals = normalize_features(processed_signals, processed_signal_lengths) + processed_signal_lengths = processed_signal_lengths.clamp(max=processed_signals.shape[2]) + + encoded, encoded_len = self.asr_model.encode( + processed_signal=processed_signals, processed_signal_length=processed_signal_lengths + ) + encoded = encoded.clone() + encoded_len = encoded_len.clone() + + if self.right_padding: + for i in range(encoded.shape[0]): + lpad = int(fbuffers[i].roll_size / self.subsampling_factor) + if lpad > 0: + encoded[i] = encoded[i].roll(lpad, dims=1) + encoded[i][:, :lpad] = self.zero_encoded[:, :lpad] + encoded_len[i] = encoded_len[i] + lpad + return encoded, encoded_len + + def encode_frames(self, frames: list[Frame]) -> tuple[Tensor, Tensor]: + """ + Encode the frames using the Encoder part of the ASR model. + Args: + frames: (list[Frame]) Frames to transcribe. + Returns: + (tuple[Tensor, Tensor]) Encoder output and their lengths. + """ + raw_signals, left_paddings = self.bufferer.update(frames) + encs, enc_lens = None, None + if len(raw_signals) > 0: + encs, enc_lens = self.encode_raw_signals(frames, raw_signals, left_paddings) + return encs, enc_lens + + def encode_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> tuple[Tensor, Tensor]: + """ + Encode the feature buffers using the Encoder part of the ASR model. + Args: + fbuffers: (list[FeatureBuffer]) Feature buffers to transcribe. + Returns: + (tuple[Tensor, Tensor]) Encoder output and their lengths. + """ + processed_signals = self.bufferer.update(fbuffers) + encs, enc_lens = None, None + if len(processed_signals) > 0: + encs, enc_lens = self.encode_processed_signals(fbuffers, processed_signals) + return encs, enc_lens + + def run_greedy_decoder( + self, + state: RNNTStreamingState, + request: Request, + timesteps: torch.Tensor, + tokens: torch.Tensor, + start: int, + end: int, + alignment_length: int, + timestamp_offset: int = 0, + vad_segments: torch.Tensor = None, + ) -> bool: + """ + Greedy RNN-T decoder. + Args: + state: (RNNTStreamingState) Current state for the particular stream. + request: (Request) Current request for the particular stream. + timesteps: (Tensor) Timesteps. + tokens: (Tensor) Tokens. + start: (int) Start index. + end: (int) End index. + alignment_length: (int) Length of the alignment. + timestamp_offset: (int) Timestamp offset. + vad_segments: (Tensor) VAD segments. + Returns: + (bool) Whether EOU is detected. 
+ """ + if self.stateful and vad_segments is not None: + vad_segments = adjust_vad_segments(vad_segments, self.left_padding_size) + + clipped_output, tail_output, eou_detected, start_idx, end_idx = self.greedy_rnnt_decoder( + global_timesteps=timesteps, + tokens=tokens, + alignment_length=alignment_length, + clip_start=start, + clip_end=end, + is_last=request.is_last, + is_start=request.is_first, + return_tail_result=self.return_tail_result, + state_start_idx=state.decoder_start_idx, + state_end_idx=state.decoder_end_idx, + timestamp_offset=timestamp_offset, + vad_segments=vad_segments, + stop_history_eou=state.options.stop_history_eou, + ) + state.update_state(clipped_output, eou_detected) + state.update_from_decoder_results(start_idx, end_idx) + if self.stateless: + # For stateless mode, we need to set the last token, it will be used for filtering duplicate token + state.set_last_token(clipped_output["last_token"], clipped_output["last_token_idx"]) + # For stateless mode, we need to increment the global offset + state.increment_global_offset(self.tokens_per_frame_float) + state.set_incomplete_segment_tokens(tail_output["tokens"]) + return eou_detected + + def stateless_transcribe_step( + self, requests: list[Request], encs: Tensor, enc_lens: Tensor, ready_state_ids: set + ) -> None: + """ + Stateless transcribe step. + Stateless assumes that we don't keep track of partial hypotheses (partial_hypotheses=None). + Args: + requests: (list[Request]) List of requests to transcribe. + encs: (Tensor) Encoder output. + enc_lens: (Tensor) Encoder output lengths. + ready_state_ids: (set) Set of ready state IDs. + """ + states = [self.get_state(request.stream_id) for request in requests] + best_hyp = self.asr_model.decode(encs, enc_lens, partial_hypotheses=None) + # For stateless mode, use zero timestamp offsets since we don't track timestamps + ready_states = self.decode_step(best_hyp, requests, states) + ready_state_ids.update(ready_states) + + def stateful_transcribe_step( + self, requests: list[Request], encs: Tensor, enc_lens_chunk: Tensor, enc_lens: Tensor, ready_state_ids: set + ) -> None: + """ + Stateful transcribe step. + Stateful assumes that we keep track of partial hypotheses. + Args: + requests: (list[Request]) List of requests to transcribe. + encs: (Tensor) Encoder output. + enc_lens_chunk: (Tensor) Encoder output lengths for the chunk. + enc_lens: (Tensor) Encoder output lengths. + ready_state_ids: (set) Set of ready state IDs. 
+ """ + states = [self.get_state(request.stream_id) for request in requests] + partial_hypotheses, rnnt_states = [], [] + all_rnnt_states_are_none = True + for state in states: + hyp_state = state.hyp_decoding_state + if hyp_state is not None: + partial_hypotheses.append( + NemoHypothesis(score=0.0, y_sequence=torch.zeros([0], dtype=torch.long), dec_state=hyp_state) + ) + rnnt_states.append(hyp_state) + all_rnnt_states_are_none = False + else: + partial_hypotheses.append(None) + rnnt_states.append(None) + + batched_rnnt_states = None + if not all_rnnt_states_are_none: + batched_rnnt_states = self.decoding_computer.merge_to_batched_state(rnnt_states) + + batched_state = None + if self.tokens_per_right_padding > 0: + with torch.inference_mode(), torch.no_grad(): + best_hyp_chunk, alignments, batched_state = self.decoding_computer( + encs.transpose(1, 2), enc_lens_chunk, batched_rnnt_states + ) + + best_hyp = self.asr_model.decode(encs, enc_lens, partial_hypotheses=partial_hypotheses) + if self.tokens_per_right_padding > 0 and batched_state is not None: + for state, rnnt_state in zip(states, self.decoding_computer.split_batched_state(batched_state)): + state.hyp_decoding_state = rnnt_state + else: + for state, hyp in zip(states, best_hyp): + state.hyp_decoding_state = hyp.dec_state + + ready_states = self.decode_step(best_hyp, requests, states) + for curr_state in states: + curr_state.timestamp_offset += self.tokens_per_frame_float + ready_state_ids.update(ready_states) + + def decode_step(self, best_hyp: list, requests: list[Request], states: list[RNNTStreamingState]) -> set: + """ + Perform greedy RNNT decoding to get the best hypothesis and update the state. + If EOU is detected, push the words to the state and cleanup the state. + Args: + best_hyp: (list) Best hypothesis. + requests: (list[Request]) List of requests to transcribe. + states: (list[RNNTStreamingState]) List of states. + Returns: + (set) Set of ready state IDs. 
+ """ + ready_state_ids = set() + for idx, hyp in enumerate(best_hyp): + state = states[idx] + request = requests[idx] + # Perform timestamp based decoding for the hypothesis + if self.stateful: + alignment_length = self.tokens_per_right_padding + self.tokens_per_frame + else: + if self.request_type is RequestType.FEATURE_BUFFER: + alignment_length = math.ceil(request.size / self.subsampling_factor) + else: # RequestType.FRAME + alignment_length = math.ceil(self.expected_feature_buffer_len / self.subsampling_factor) + + if self.stateful: + start, end = 0, self.tokens_per_frame + else: + # For stateless mode + if request.is_first and request.is_last: + start, end = 0, alignment_length + else: + start, end = self.get_cut_off_range(alignment_length, request.is_last) + + timestamp = hyp.timestamp + tokens = hyp.y_sequence + timestamp = torch.tensor(timestamp) if isinstance(timestamp, list) else timestamp + tokens = torch.tensor(tokens) if isinstance(tokens, list) else tokens + timestamp = update_punctuation_and_language_tokens_timestamps( + tokens, timestamp, self.tokens_to_move, self.underscore_id + ) + vad_segments = request.vad_segments + eou_detected = self.run_greedy_decoder( + state=state, + request=request, + timesteps=timestamp, + tokens=tokens, + start=start, + end=end, + alignment_length=alignment_length, + timestamp_offset=state.timestamp_offset, + vad_segments=vad_segments, + ) + + if eou_detected: + self.bpe_decoder.decode_bpe_tokens(state) + state.cleanup_after_eou() + ready_state_ids.add(request.stream_id) + return ready_state_ids + + def shared_transcribe_step_stateful(self, requests: list[Request], encs: Tensor, enc_lens: Tensor) -> None: + """ + Stateful transcribe step. + After detecting EOU, it updates the state and run text processor. + If there are multiple streams, it waits until all states are ready to run text processor. + Args: + requests: (list[Request]) List of requests to transcribe. + encs: (Tensor) Encoder output. + enc_lens: (Tensor) Encoder output lengths. 
+ """ + tokens_per_left_padding_tensor = torch.tensor(self.tokens_per_left_padding, device=self.device) + tokens_per_frame_tensor = torch.tensor(self.tokens_per_frame, device=self.device) + postponed_requests = [(ridx, request.stream_id) for ridx, request in enumerate(requests)] + next_postponed_requests = [] + ready_state_ids = set() + while len(postponed_requests) > 0: + request_ids_to_process = [] + for ridx, stream_id in postponed_requests: + if stream_id in ready_state_ids: + next_postponed_requests.append((ridx, stream_id)) + continue + request_ids_to_process.append(ridx) + if len(request_ids_to_process) > 0: + requests_to_process = [requests[jdx] for jdx in request_ids_to_process] + request_is_last = torch.tensor( + [request.is_last for request in requests_to_process], dtype=torch.bool, device=self.device + ) + enc_lens_dec = enc_lens - tokens_per_left_padding_tensor + enc_lens_dec_trimmed = torch.where( + request_is_last, + enc_lens_dec, + torch.minimum(enc_lens_dec, tokens_per_frame_tensor.expand_as(enc_lens_dec)), + ) + self.stateful_transcribe_step( + requests_to_process, + encs[request_ids_to_process][:, :, self.tokens_per_left_padding :], + enc_lens_dec_trimmed, + enc_lens_dec, + ready_state_ids, + ) + if len(ready_state_ids) > 0: + self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) + ready_state_ids.clear() + postponed_requests = next_postponed_requests.copy() + next_postponed_requests.clear() + + self.update_partial_transcript(requests, self.tokenizer, self.leading_regex_pattern) + + def shared_transcribe_step(self, requests: list[Request], encs: Tensor, enc_lens: Tensor) -> None: + """ + Stateless transcribe step. + After detecting EOU, it updates the state and run text processor. + If there are multiple streams, it waits until all stated are ready to run text processor. + Args: + requests: (list[Request]) List of requests to transcribe. + encs: (Tensor) Encoder output. + enc_lens: (Tensor) Encoder output lengths. + """ + postponed_requests = [(ridx, request.stream_id) for ridx, request in enumerate(requests)] + next_postponed_requests = [] + ready_state_ids = set() + + while len(postponed_requests) > 0: + + request_ids_to_process = [] + for ridx, stream_id in postponed_requests: + + if stream_id in ready_state_ids: + # Skip if the state is already ready + next_postponed_requests.append((ridx, stream_id)) + continue + + request_ids_to_process.append(ridx) + + if len(request_ids_to_process) > 0: + requests_to_process = [requests[jdx] for jdx in request_ids_to_process] + self.stateless_transcribe_step( + requests_to_process, + encs=encs[request_ids_to_process], + enc_lens=enc_lens[request_ids_to_process], + ready_state_ids=ready_state_ids, + ) + + if len(ready_state_ids) > 0: + self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) + ready_state_ids.clear() + + postponed_requests = next_postponed_requests.copy() + next_postponed_requests.clear() + + self.update_partial_transcript(requests, self.tokenizer, self.leading_regex_pattern) + + def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: + """ + Transcribe a step for feature buffers. + Args: + fbuffers: (list[FeatureBuffer]) List of feature buffers to transcribe. 
+ """ + encs, enc_lens = self.encode_feature_buffers(fbuffers) + if encs is not None: + if self.stateful: + self.shared_transcribe_step_stateful(requests=fbuffers, encs=encs, enc_lens=enc_lens) + else: + self.shared_transcribe_step(requests=fbuffers, encs=encs, enc_lens=enc_lens) + + def transcribe_step_for_frames(self, frames: list[Frame]) -> None: + """ + Transcribe a step for frames. + Args: + frames: (list[Frame]) List of frames to transcribe. + """ + encs, enc_lens = self.encode_frames(frames) + if encs is not None: + if self.stateful: + self.shared_transcribe_step_stateful(requests=frames, encs=encs, enc_lens=enc_lens) + else: + self.shared_transcribe_step(requests=frames, encs=encs, enc_lens=enc_lens) + + def get_request_generator(self) -> ContinuousBatchedRequestStreamer: + """ + Initialize the request generator. + Returns: + (ContinuousBatchedRequestStreamer) Request generator. + """ + request_generator = ContinuousBatchedRequestStreamer( + n_frames_per_stream=1, + frame_size_in_secs=self.chunk_size, + sample_rate=self.sample_rate, + batch_size=self.batch_size, + request_type=self.request_type, + preprocessor=self.preprocessor, + buffer_size_in_secs=self.buffer_size_in_secs, + device=self.device, + pad_last_frame=True, + right_pad_features=self.right_padding, + tail_padding_in_samples=self.tail_padding_in_samples, + ) + return request_generator diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_ctc_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_ctc_pipeline.py new file mode 100644 index 000000000000..acedc265057d --- /dev/null +++ b/nemo/collections/asr/inference/pipelines/cache_aware_ctc_pipeline.py @@ -0,0 +1,384 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import numpy as np +import torch +from omegaconf import DictConfig +from torch import Tensor + +from nemo.collections.asr.inference.model_wrappers.cache_aware_ctc_inference_wrapper import ( + CacheAwareCTCInferenceWrapper, +) +from nemo.collections.asr.inference.pipelines.base_pipeline import BasePipeline +from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_ctc_decoder import CTCGreedyDecoder +from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_ctc_endpointing import CTCGreedyEndpointing +from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer +from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame +from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions +from nemo.collections.asr.inference.streaming.state.cache_aware_ctc_state import CacheAwareCTCStreamingState +from nemo.collections.asr.inference.utils.endpointing_utils import millisecond_to_frames +from nemo.collections.asr.inference.utils.enums import RequestType +from nemo.collections.asr.inference.utils.pipeline_utils import ( + check_existance_of_required_attributes, + get_confidence_utils, + normalize_log_probs, +) + +if TYPE_CHECKING: + from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer + + +class CacheAwareCTCPipeline(BasePipeline): + """Cache Aware CTC pipeline.""" + + def __init__( + self, + cfg: DictConfig, + asr_model: CacheAwareCTCInferenceWrapper, + itn_model: AlignmentPreservingInverseNormalizer | None = None, + ): + """ + Initialize the CacheAwareCTCPipeline. + Args: + cfg: (DictConfig) Configuration parameters. + asr_model: (CacheAwareCTCInferenceWrapper) ASR model. + itn_model: (AlignmentPreservingInverseNormalizer | None) Inverse Text Normalization model. + """ + self.copy_asr_model_attributes(asr_model) + self.init_parameters(cfg) + self.init_context_manager() + self.init_bufferer_for_cache_aware_streaming() + self.conf_func, self.confidence_aggregator = get_confidence_utils(cfg.confidence) + self.init_bpe_decoder() + self.init_greedy_ctc_decoder() + self.init_endpointer() + self.init_text_processor(cfg, itn_model) + super().__init__() + + def init_parameters(self, cfg: DictConfig) -> None: + """ + Initialize the configuration parameters. + Args: + cfg: (DictConfig) Configuration parameters. 
+ """ + if cfg.streaming.att_context_size is not None: + self.asr_model.set_default_att_context_size(att_context_size=cfg.streaming.att_context_size) + self.sample_rate = cfg.streaming.sample_rate + self.asr_output_granularity = cfg.asr_output_granularity + + self.use_cache = cfg.streaming.use_cache + self.use_feat_cache = cfg.streaming.use_feat_cache + self.batch_size = cfg.streaming.batch_size + self.num_slots = cfg.streaming.num_slots + if self.num_slots < self.batch_size: + raise ValueError( + f"Number of slots in the context manager must be >= batch_size: {self.num_slots} < {self.batch_size}" + ) + self.request_type = RequestType.from_str(cfg.streaming.request_type) + if self.request_type is not RequestType.FRAME: + raise ValueError(f"Request type {self.request_type} is not supported for cache-aware streaming.") + + self.word_boundary_tolerance = cfg.streaming.word_boundary_tolerance + self.stop_history_eou_in_milliseconds = cfg.endpointing.stop_history_eou + self.residue_tokens_at_end = cfg.endpointing.residue_tokens_at_end + self.return_tail_result = cfg.return_tail_result + + self.pre_encode_cache_size = self.asr_model.get_pre_encode_cache_size() + self.model_chunk_size = self.asr_model.get_chunk_size() + if isinstance(self.model_chunk_size, list): + self.model_chunk_size = self.model_chunk_size[1] + + if cfg.streaming.get("chunk_size_in_secs", None) is not None: + self.chunk_size_in_secs = cfg.streaming.chunk_size_in_secs + self.tokens_per_frame = math.ceil( + np.trunc(self.chunk_size_in_secs / self.window_stride) / self.subsampling_factor + ) + # overwrite the encoder streaming params with proper shift size for cache aware streaming + self.asr_model.setup_streaming_params( + chunk_size=self.model_chunk_size // self.subsampling_factor, shift_size=self.tokens_per_frame + ) + else: + self.chunk_size_in_secs = self.model_chunk_size * self.window_stride + self.tokens_per_frame = math.ceil(self.model_chunk_size / self.subsampling_factor) + + if isinstance(self.pre_encode_cache_size, list): + self.pre_encode_cache_size = self.pre_encode_cache_size[1] + self.pre_encode_cache_size_in_secs = self.pre_encode_cache_size * self.window_stride + + model_chunk_size_in_secs = self.model_chunk_size * self.window_stride + + if self.use_cache: + # if using cache, we need to pad some samples for pre_encode + self.buffer_size_in_secs = self.pre_encode_cache_size_in_secs + model_chunk_size_in_secs + self.drop_left_context = None + self.valid_out_len = None + else: + # if not using cache, we need to keep left context in buffer, but no extra padding in pre_encode + left_context_size = self.asr_model.get_att_context_size()[0] + if left_context_size < 0: + raise ValueError(f"Left context size should not be a negative value: {left_context_size}") + self.buffer_size_in_secs = ( + model_chunk_size_in_secs + left_context_size * self.subsampling_factor * self.window_stride + ) + self.drop_left_context = left_context_size + self.valid_out_len = self.tokens_per_frame + + def init_greedy_ctc_decoder(self) -> None: + """Initialize the CTC decoder.""" + check_existance_of_required_attributes(self, ['vocabulary', 'conf_func']) + self.greedy_ctc_decoder = CTCGreedyDecoder(vocabulary=self.vocabulary, conf_func=self.conf_func) + + def init_endpointer(self) -> None: + """Initialize the endpointer.""" + check_existance_of_required_attributes( + self, + [ + 'vocabulary', + 'model_stride_in_milliseconds', + 'stop_history_eou_in_milliseconds', + 'residue_tokens_at_end', + ], + ) + + self.endpointer = CTCGreedyEndpointing( + 
vocabulary=self.vocabulary, + ms_per_timestep=self.model_stride_in_milliseconds, + stop_history_eou=self.stop_history_eou_in_milliseconds, + residue_tokens_at_end=self.residue_tokens_at_end, + ) + + def reset_session(self) -> None: + """Reset the context manager.""" + self.context_manager.reset() + super().reset_session() + + def create_state(self, options: ASRRequestOptions) -> CacheAwareCTCStreamingState: + """ + Create new empty state. + Args: + options: (ASRRequestOptions) Request options for particular stream. + Returns: + (CacheAwareCTCStreamingState) New empty state. + """ + state = CacheAwareCTCStreamingState() + state.set_global_offset(0) + new_options = options.augment_with_defaults( + default_enable_itn=self.text_processor.is_itn_enabled(), + default_enable_pnc=self.text_processor.is_pnc_enabled(), + default_stop_history_eou=self.stop_history_eou_in_milliseconds, + default_asr_output_granularity=self.asr_output_granularity, + ) + + eou_label_buffer_size = 0 + if new_options.stop_history_eou > 0: + eou_label_buffer_size = millisecond_to_frames( + new_options.stop_history_eou, math.ceil(self.model_stride_in_milliseconds) + ) + eou_label_buffer_size += self.residue_tokens_at_end + state.setup_label_buffer(eou_label_buffer_size, self.blank_id) + state.set_options(new_options) + return state + + def get_sep(self) -> str: + """Return the separator for the text processor.""" + return self.sep + + def preprocess(self, buffers: list[Tensor], right_paddings: list[int] | None = None) -> tuple[Tensor, Tensor]: + """ + Preprocess the feature buffers by stacking them and computing the lengths + Args: + buffers: (list[Tensor]) List of feature buffers. + right_paddings: (list[int] | None) List of right paddings. + Returns: + (tuple[Tensor, Tensor]) Processed feature buffers and their lengths. + """ + feature_buffers = [f_buffer.unsqueeze_(0) for f_buffer in buffers] + feature_buffer_lens = torch.tensor([f_buffer.shape[2] for f_buffer in feature_buffers], device=self.device) + if right_paddings is not None: + right_paddings = torch.tensor(right_paddings, device=feature_buffer_lens.device) + feature_buffer_lens = feature_buffer_lens - right_paddings + feature_buffers = torch.cat(feature_buffers).to(self.device) + return feature_buffers, feature_buffer_lens + + def run_greedy_decoder(self, state: CacheAwareCTCStreamingState, frame: Frame, log_probs: Tensor): + """ + Run the greedy CTC decoder on the log_probs and update the state + Args: + state: (CacheAwareCTCStreamingState) The state of the stream + frame: (Frame) The current frame + log_probs: (Tensor) The log probabilities of the current frame + Returns: + (bool) Whether EOU is detected. 
+ """ + eou_detected = frame.is_last + last_token = state.label_buffer[-1] if len(state.label_buffer) > 0 else self.blank_id + cur_output = self.greedy_ctc_decoder(log_probs, compute_confidence=True, previous=last_token) + state.update_label_buffer(cur_output["labels"]) + + if not eou_detected: + emissions = state.get_label_buffer() + pivot_point = len(emissions) - 1 + eou_detected, _ = self.endpointer.detect_eou_near_pivot( + emissions, pivot_point, stop_history_eou=state.options.stop_history_eou + ) + + state.update_state(cur_output, eou_detected=eou_detected) + state.increment_global_offset(self.tokens_per_frame) + return eou_detected + + def decode_log_probs( + self, frames: list[Frame], log_probs: Tensor, tail_log_probs: Tensor | None, ready_state_ids: set + ) -> None: + """ + Decode the log probabilities and update the state + Args: + frames: (list[Frame]) List of frames to transcribe. + log_probs: (Tensor) Log probabilities. + tail_log_probs: (Tensor | None) Tail log probabilities. + ready_state_ids: (set) Set of ready state IDs. + """ + + for idx, frame in enumerate(frames): + state = self.get_state(frame.stream_id) + eou_detected = self.run_greedy_decoder(state, frame, log_probs[idx]) + + if eou_detected: + self.bpe_decoder.decode_bpe_tokens(state) + state.cleanup_after_eou() + ready_state_ids.add(frame.stream_id) + + if tail_log_probs is not None: + last_token = state.label_buffer[-1] if len(state.label_buffer) > 0 else self.blank_id + tail_output = self.greedy_ctc_decoder( + tail_log_probs[idx], compute_confidence=False, previous=last_token + ) + state.set_incomplete_segment_tokens(tail_output["tokens"]) + + def cache_aware_transcribe_step( + self, + frames: list[Frame], + buffered_features: list[Tensor], + right_paddings: list[int] | None, + ready_state_ids: set, + keep_all_outputs: bool = False, + ) -> None: + """ + Cache Aware Transcribe Step + It receives a list of frames and features and do the following: + + 1. Preprocess the features by stacking them and computing the lengths + 2. Get the context and mapping from the context manager for cache aware streaming + 3. Perform a streaming step with the ASR model + 4. Update the cache and reset the cache slots for the streams that has ended + 5. Decode the log probabilities and update the state + + Args: + frames: (list[Frame]) List of frames to transcribe. + buffered_features: (list[Tensor]) List of buffered features. + right_paddings: (list[int] | None) List of right paddings. + ready_state_ids: (set) Set of ready state IDs. + keep_all_outputs: (bool) Whether to keep all outputs or not. 
+ """ + feature_buffers, feature_buffer_lens = self.preprocess(buffered_features, right_paddings) + + stream_ids = [frame.stream_id for frame in frames] + eos_flags = [frame.is_last for frame in frames] + context, mapping = self.context_manager.get_context(stream_ids) + + drop_extra_pre_encoded = 0 if not self.use_cache else self.asr_model.drop_extra_pre_encoded + log_probs, tail_log_probs, new_context = self.asr_model.stream_step( + processed_signal=feature_buffers, + processed_signal_length=feature_buffer_lens, + context=context, + drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + drop_left_context=self.drop_left_context, + valid_out_len=self.valid_out_len, + return_tail_result=self.return_tail_result, + ) + + if log_probs is not None: + log_probs = normalize_log_probs(log_probs) + self.context_manager.update_cache(stream_ids, new_context, mapping) + self.context_manager.reset_slots(stream_ids, eos_flags) + self.decode_log_probs(frames, log_probs, tail_log_probs, ready_state_ids) + + def transcribe_step_for_frames(self, frames: list[Frame]) -> None: + """ + Transcribes the frames in a streaming manner. + After detecting EOU, it updates the state and run text processor. + If there are multiple streams, it waits until all states are ready to run text processor. + Args: + frames: (list[Frame]) List of frames to transcribe. + """ + all_fbuffers, right_paddings = self.bufferer.update(frames) + + ready_state_ids = set() + if len(all_fbuffers) > 0: + nonfinal_frames, nonfinal_fbuffers = [], [] + final_frames, final_fbuffers = [], [] + final_right_paddings = [] + for jdx, bfeature in enumerate(all_fbuffers): + frame = frames[jdx] + if frame.is_last: + final_frames.append(frame) + final_fbuffers.append(bfeature) + final_right_paddings.append(right_paddings[jdx]) + else: + nonfinal_frames.append(frame) + nonfinal_fbuffers.append(bfeature) + + if len(nonfinal_frames) > 0: + self.cache_aware_transcribe_step( + nonfinal_frames, nonfinal_fbuffers, None, ready_state_ids, keep_all_outputs=False + ) + if len(final_frames) > 0: + self.cache_aware_transcribe_step( + final_frames, final_fbuffers, final_right_paddings, ready_state_ids, keep_all_outputs=True + ) + + # Postprocess the ready states + if len(ready_state_ids) > 0: + self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) + ready_state_ids.clear() + + self.update_partial_transcript(frames, self.tokenizer, self.leading_regex_pattern) + + def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: + """Transcribe a step for feature buffers""" + raise NotImplementedError("Feature buffer type is not supported for cache aware streaming.") + + def get_request_generator(self) -> ContinuousBatchedRequestStreamer: + """ + Initialize the request generator. + Returns: + (ContinuousBatchedRequestStreamer) Request generator. 
+ """ + request_generator = ContinuousBatchedRequestStreamer( + n_frames_per_stream=1, + frame_size_in_secs=self.chunk_size_in_secs, + sample_rate=self.sample_rate, + batch_size=self.batch_size, + request_type=self.request_type, + preprocessor=None, + buffer_size_in_secs=None, + device=None, + pad_last_frame=True, + ) + return request_generator diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py new file mode 100644 index 000000000000..e23b8ca34e0c --- /dev/null +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -0,0 +1,390 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import numpy as np +import torch +from omegaconf import DictConfig +from torch import Tensor + +from nemo.collections.asr.inference.model_wrappers.cache_aware_rnnt_inference_wrapper import ( + CacheAwareRNNTInferenceWrapper, +) +from nemo.collections.asr.inference.pipelines.base_pipeline import BasePipeline +from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_rnnt_decoder import RNNTGreedyDecoder +from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_rnnt_endpointing import RNNTGreedyEndpointing +from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer +from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame +from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions +from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import CacheAwareRNNTStreamingState +from nemo.collections.asr.inference.utils.endpointing_utils import millisecond_to_frames +from nemo.collections.asr.inference.utils.enums import RequestType +from nemo.collections.asr.inference.utils.pipeline_utils import ( + check_existance_of_required_attributes, + get_confidence_utils, +) +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis + +if TYPE_CHECKING: + from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer + + +class CacheAwareRNNTPipeline(BasePipeline): + """Cache Aware RNNT pipeline.""" + + def __init__( + self, + cfg: DictConfig, + asr_model: CacheAwareRNNTInferenceWrapper, + itn_model: AlignmentPreservingInverseNormalizer | None = None, + ): + """ + Initialize the CacheAwareRNNTPipeline. + Args: + cfg: (DictConfig) Configuration parameters. + asr_model: (CacheAwareRNNTInferenceWrapper) ASR model. + itn_model: (AlignmentPreservingInverseNormalizer | None) Inverse Text Normalization model. 
+ """ + self.copy_asr_model_attributes(asr_model) + self.init_parameters(cfg) + self.init_context_manager() + self.init_bufferer_for_cache_aware_streaming() + self.conf_func, self.confidence_aggregator = get_confidence_utils(cfg.confidence) + self.init_bpe_decoder() + self.init_greedy_rnnt_decoder() + self.init_endpointer() + self.init_text_processor(cfg, itn_model) + super().__init__() + + def init_parameters(self, cfg: DictConfig) -> None: + """ + Initialize the parameters. + Args: + cfg: (DictConfig) Configuration parameters. + """ + if cfg.streaming.att_context_size is not None: + self.asr_model.set_default_att_context_size(att_context_size=cfg.streaming.att_context_size) + + self.sample_rate = cfg.streaming.sample_rate + self.asr_output_granularity = cfg.asr_output_granularity + self.pre_encode_cache_size = self.asr_model.get_pre_encode_cache_size() + self.model_chunk_size = self.asr_model.get_chunk_size() + if isinstance(self.model_chunk_size, list): + self.model_chunk_size = self.model_chunk_size[1] + + self.use_cache = cfg.streaming.use_cache + self.use_feat_cache = cfg.streaming.use_feat_cache + + if cfg.streaming.get("chunk_size_in_secs", None) is not None: + self.chunk_size_in_secs = cfg.streaming.chunk_size_in_secs + self.tokens_per_frame = math.ceil( + np.trunc(self.chunk_size_in_secs / self.window_stride) / self.subsampling_factor + ) + # overwrite the encoder streaming params with proper shift size for cache aware streaming + self.asr_model.setup_streaming_params( + chunk_size=self.model_chunk_size // self.subsampling_factor, shift_size=self.tokens_per_frame + ) + else: + self.chunk_size_in_secs = self.model_chunk_size * self.window_stride + self.tokens_per_frame = math.ceil(self.model_chunk_size / self.subsampling_factor) + + if isinstance(self.pre_encode_cache_size, list): + self.pre_encode_cache_size = self.pre_encode_cache_size[1] + self.pre_encode_cache_size_in_secs = self.pre_encode_cache_size * self.window_stride + + # Context Manager + self.batch_size = cfg.streaming.batch_size + self.num_slots = cfg.streaming.num_slots + if self.num_slots < self.batch_size: + raise ValueError( + f"Number of slots in the context manager must be >= batch_size: {self.num_slots} < {self.batch_size}" + ) + model_chunk_size_in_secs = self.model_chunk_size * self.window_stride + + if self.use_cache: + # if using cache, we need to pad some samples for pre_encode + self.buffer_size_in_secs = self.pre_encode_cache_size_in_secs + model_chunk_size_in_secs + self.drop_left_context = None + self.valid_out_len = None + else: + # if not using cache, we need to keep left context in buffer, but no extra padding in pre_encode + left_context_size = self.asr_model.get_att_context_size()[0] + if left_context_size < 0: + raise ValueError(f"Left context size should not be a negative value: {left_context_size}") + self.buffer_size_in_secs = ( + model_chunk_size_in_secs + left_context_size * self.subsampling_factor * self.window_stride + ) + self.drop_left_context = left_context_size + self.valid_out_len = self.tokens_per_frame + + self.stop_history_eou_in_milliseconds = cfg.endpointing.stop_history_eou + self.residue_tokens_at_end = cfg.endpointing.residue_tokens_at_end + self.word_boundary_tolerance = cfg.streaming.word_boundary_tolerance + self.return_tail_result = cfg.return_tail_result + + self.request_type = RequestType.from_str(cfg.streaming.request_type) + if self.request_type is not RequestType.FRAME: + raise ValueError(f"Request type {self.request_type} is not supported for cache-aware streaming.") 
+ + def init_greedy_rnnt_decoder(self) -> None: + """Initialize the RNNT decoder.""" + check_existance_of_required_attributes(self, ['vocabulary', 'conf_func']) + self.greedy_rnnt_decoder = RNNTGreedyDecoder(vocabulary=self.vocabulary, conf_func=self.conf_func) + + def init_endpointer(self) -> None: + """Initialize the endpointer.""" + check_existance_of_required_attributes( + self, + [ + 'vocabulary', + 'model_stride_in_milliseconds', + 'stop_history_eou_in_milliseconds', + 'residue_tokens_at_end', + ], + ) + + self.endpointer = RNNTGreedyEndpointing( + vocabulary=self.vocabulary, + ms_per_timestep=self.model_stride_in_milliseconds, + stop_history_eou=self.stop_history_eou_in_milliseconds, + residue_tokens_at_end=self.residue_tokens_at_end, + ) + + def reset_session(self) -> None: + """Reset the context manager.""" + self.context_manager.reset() + super().reset_session() + + def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingState: + """ + Create new empty state. + Args: + options: (ASRRequestOptions) Request options for particular stream. + Returns: + (CacheAwareRNNTStreamingState) New empty state. + """ + state = CacheAwareRNNTStreamingState() + state.set_global_offset(0) + new_options = options.augment_with_defaults( + default_enable_itn=self.text_processor.is_itn_enabled(), + default_enable_pnc=self.text_processor.is_pnc_enabled(), + default_stop_history_eou=self.stop_history_eou_in_milliseconds, + default_asr_output_granularity=self.asr_output_granularity, + ) + + eou_label_buffer_size = 0 + if new_options.stop_history_eou > 0: + eou_label_buffer_size = millisecond_to_frames( + new_options.stop_history_eou, math.ceil(self.model_stride_in_milliseconds) + ) + eou_label_buffer_size += self.residue_tokens_at_end + state.setup_label_buffer(eou_label_buffer_size, self.blank_id) + state.set_previous_hypothesis(None) + state.set_options(new_options) + return state + + def get_sep(self) -> str: + """Return the separator for the text processor.""" + return self.sep + + def preprocess(self, buffers: list[Tensor], right_paddings: list[int] | None = None) -> tuple[Tensor, Tensor]: + """ + Preprocess the feature buffers by stacking them and computing the lengths + Args: + buffers: (list[Tensor]) List of feature buffers. + right_paddings: (list[int] | None) List of right paddings. + Returns: + (tuple[Tensor, Tensor]) Processed feature buffers and their lengths. + """ + feature_buffers = [f_buffer.unsqueeze_(0) for f_buffer in buffers] + feature_buffer_lens = torch.tensor([f_buffer.shape[2] for f_buffer in feature_buffers], device=self.device) + if right_paddings is not None: + right_paddings = torch.tensor(right_paddings, device=feature_buffer_lens.device) + feature_buffer_lens = feature_buffer_lens - right_paddings + feature_buffers = torch.cat(feature_buffers).to(self.device) + return feature_buffers, feature_buffer_lens + + def run_greedy_decoder(self, state: CacheAwareRNNTStreamingState, frame: Frame, hyp: Hypothesis) -> bool: + """ + Run the greedy RNNT decoder on the hypothesis and update the state + Args: + state: (CacheAwareRNNTStreamingState) The state of the stream + frame: (Frame) The current frame + hyp: (Hypothesis) The hypothesis of the current frame + Returns: + (bool) Whether EOU is detected. 
+ """ + eou_detected = frame.is_last + cur_output, cur_labels, new_offset = self.greedy_rnnt_decoder( + global_timestamps=hyp.timestamp, + tokens=hyp.y_sequence, + length=self.tokens_per_frame, + offset=state.offset, + ) + state.set_offset(new_offset) + + # cur labels contains blank tokens as well, it is needed for EOU detection + state.update_label_buffer(cur_labels) + + if not eou_detected: + emissions = state.get_label_buffer() + pivot_point = len(emissions) - 1 + eou_detected, _ = self.endpointer.detect_eou_near_pivot( + emissions, pivot_point, stop_history_eou=state.options.stop_history_eou + ) + + state.update_state(cur_output, eou_detected=eou_detected) + return eou_detected + + def cache_aware_transcribe_step( + self, + frames: list[Frame], + features: list[Tensor], + right_paddings: list[int], + ready_state_ids: set, + keep_all_outputs: bool = False, + ) -> None: + """ + Cache Aware Transcribe Step + It receives a list of frames and features and do the following: + + 1. Preprocess the features by stacking them and computing the lengths + 2. Collecting previous hypotheses for stateful decoding + 3. Get the context and mapping from the context manager for cache aware streaming + 4. Perform a streaming step with the ASR model + 5. Update the cache and reset the cache slots for the streams that has ended + 6. Update the previous hypothesis and reset the previous hypothesis for the streams that has ended + 7. Perform greedy RNNT decoding to get the best hypothesis and update the states + 8. Update the ready states to indicate that the state is ready for text post-processing + Args: + frames: (list[Frame]) List of frames to transcribe. + features: (list[Tensor]) List of feature buffers. + right_paddings: (list[int] | None) List of right paddings. + ready_state_ids: (set) Set of ready state IDs. + keep_all_outputs: (bool) Whether to keep all outputs or not. 
+ """ + + feature_buffers, feature_buffer_lens = self.preprocess(features, right_paddings) + states, stream_ids, eos_flags = [], [], [] + for frame in frames: + states.append(self.get_state(frame.stream_id)) + stream_ids.append(frame.stream_id) + eos_flags.append(frame.is_last) + + previous_hypotheses = [state.get_previous_hypothesis() for state in states] + context, mapping = self.context_manager.get_context(stream_ids) + + drop_extra_pre_encoded = 0 if not self.use_cache else self.asr_model.drop_extra_pre_encoded + best_hyp, new_context = self.asr_model.stream_step( + processed_signal=feature_buffers, + processed_signal_length=feature_buffer_lens, + context=context, + previous_hypotheses=previous_hypotheses, + drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + drop_left_context=self.drop_left_context, + valid_out_len=self.valid_out_len, + ) + + # update the cache and reset the cache slots for the streams that has ended + self.context_manager.update_cache(stream_ids, new_context, mapping) + self.context_manager.reset_slots(stream_ids, eos_flags) + + # update the previous hypothesis and reset the previous hypothesis for the streams that has ended + for state, hyp, eos in zip(states, best_hyp, eos_flags): + if eos: + state.reset_previous_hypothesis() + else: + state.set_previous_hypothesis(hyp) + + # run greedy decoder for each frame-state-hypothesis tuple + for frame, state, hyp in zip(frames, states, best_hyp): + eou_detected = self.run_greedy_decoder(state, frame, hyp) + if eou_detected: + self.bpe_decoder.decode_bpe_tokens(state) + state.cleanup_after_eou() + ready_state_ids.add(frame.stream_id) + + def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: + """Transcribe a step for feature buffers""" + raise NotImplementedError("Feature buffer type is not supported for cache aware streaming.") + + def transcribe_step_for_frames(self, frames: list[Frame]) -> None: + """ + Transcribes the frames in a streaming manner. + After detecting EOU, it updates the state and run text processor. + If there are multiple streams, it waits until all states are ready to run text processor. + Args: + frames: (list[Frame]) List of frames to transcribe. + """ + + all_fbuffers, right_paddings = self.bufferer.update(frames) + ready_state_ids = set() + + # streams that contains multiple frames + if len(all_fbuffers) > 0: + final_frames, final_fbuffers = [], [] + nonfinal_frames, nonfinal_fbuffers = [], [] + final_right_paddings = [] + for jdx, bfeature in enumerate(all_fbuffers): + bframe = frames[jdx] + + if bframe.is_last: + final_frames.append(bframe) + final_fbuffers.append(bfeature) + final_right_paddings.append(right_paddings[jdx]) + else: + nonfinal_frames.append(bframe) + nonfinal_fbuffers.append(bfeature) + + if len(nonfinal_frames) > 0: + self.cache_aware_transcribe_step( + nonfinal_frames, nonfinal_fbuffers, None, ready_state_ids, keep_all_outputs=False + ) + + if len(final_frames) > 0: + self.cache_aware_transcribe_step( + final_frames, final_fbuffers, final_right_paddings, ready_state_ids, keep_all_outputs=True + ) + + # post-process the ready states + if len(ready_state_ids) > 0: + self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) + ready_state_ids.clear() + + self.update_partial_transcript(frames, self.tokenizer, self.leading_regex_pattern) + + def get_request_generator(self) -> ContinuousBatchedRequestStreamer: + """ + Initialize the request generator. 
+ Returns: + (ContinuousBatchedRequestStreamer) Request generator. + """ + # for cache aware streaming we need to process one frame at a time -> n_frames_per_stream=1 + request_generator = ContinuousBatchedRequestStreamer( + n_frames_per_stream=1, + frame_size_in_secs=self.chunk_size_in_secs, + sample_rate=self.sample_rate, + batch_size=self.batch_size, + request_type=self.request_type, + preprocessor=None, + buffer_size_in_secs=None, + device=None, + pad_last_frame=True, + ) + return request_generator diff --git a/nemo/collections/asr/inference/pipelines/pipeline_interface.py b/nemo/collections/asr/inference/pipelines/pipeline_interface.py new file mode 100644 index 000000000000..a66f879f21fc --- /dev/null +++ b/nemo/collections/asr/inference/pipelines/pipeline_interface.py @@ -0,0 +1,79 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod + +from nemo.collections.asr.inference.streaming.framing.request import Request +from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions + + +class PipelineInterface(ABC): + """ + The base interface for streaming speech pipelines + Base usage for all pipelines: + pipeline.start_session() + for requests in request_generator: + pipeline.transcribe_step(requests) + pipeline.close_session() + """ + + @abstractmethod + def open_session(self): + """ + Open a new session + """ + raise NotImplementedError + + @abstractmethod + def close_session(self): + """ + End the current session + """ + raise NotImplementedError + + @abstractmethod + def get_state(self, stream_id: int): + """ + Get the state of the stream + """ + raise NotImplementedError + + @abstractmethod + def delete_state(self, stream_id: int): + """ + Delete the state of the stream + """ + raise NotImplementedError + + @abstractmethod + def create_state(self, options: ASRRequestOptions): + """ + Create a new empty state + """ + raise NotImplementedError + + @abstractmethod + def init_state(self, stream_id: int, options: ASRRequestOptions): + """ + Initialize the state of the stream + """ + raise NotImplementedError + + @abstractmethod + def transcribe_step(self, requests: list[Request]): + """ + Transcribe a step + """ + raise NotImplementedError diff --git a/nemo/collections/asr/inference/streaming/__init__.py b/nemo/collections/asr/inference/streaming/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/streaming/buffering/__init__.py b/nemo/collections/asr/inference/streaming/buffering/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/buffering/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/streaming/buffering/audio_bufferer.py b/nemo/collections/asr/inference/streaming/buffering/audio_bufferer.py new file mode 100644 index 000000000000..c1bc862f3652 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/buffering/audio_bufferer.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from torch import Tensor +from nemo.collections.asr.inference.streaming.framing.request import Frame + + +class AudioBufferer: + """ + Audio bufferer class + It buffers the audio chunks and maintains the buffer. 
+ """ + + def __init__(self, sample_rate: int, buffer_size_in_secs: float): + """ + Args: + sample_rate (int): sample rate + buffer_size_in_secs (float): buffer size in seconds + """ + self.buffer_size = int(buffer_size_in_secs * sample_rate) + self.sample_buffer = torch.zeros(self.buffer_size, dtype=torch.float32) + self.left_padding = self.buffer_size + + def reset(self) -> None: + """ + Reset the buffer to zero + """ + self.sample_buffer.zero_() + self.left_padding = self.buffer_size + + def update(self, frame: Frame) -> None: + """ + Update the buffer with the new frame + Args: + frame (Frame): frame to update the buffer with + """ + if frame.size > self.buffer_size: + raise RuntimeError(f"Frame size ({frame.size}) exceeds buffer size ({self.buffer_size})") + + shift = frame.size + self.sample_buffer = torch.roll(self.sample_buffer, -shift) + self.sample_buffer[-shift:].copy_(frame.samples) + self.left_padding = max(0, self.left_padding - shift) + + def get_buffer(self) -> Tensor: + """ + Get the current buffer + Returns: + Tensor: current state of the buffer + """ + return self.sample_buffer.clone() + + def get_left_padding(self) -> int: + """ + Get the left padding + Returns: + int: left padding + """ + return self.left_padding + + +class BatchedAudioBufferer: + """ + Batched audio bufferer class + It buffers the audio chunks from multiple streams and returns the buffers. + """ + + def __init__(self, sample_rate: int, buffer_size_in_secs: float): + """ + Args: + sample_rate (int): sample rate + buffer_size_in_secs (float): buffer size in seconds + """ + self.sample_rate = sample_rate + self.buffer_size_in_secs = buffer_size_in_secs + self.bufferers = {} + + def reset(self) -> None: + """ + Reset bufferers + """ + self.bufferers = {} + + def rm_bufferer(self, stream_id: int) -> None: + """ + Remove bufferer for the given stream id + Args: + stream_id (int): stream id + """ + self.bufferers.pop(stream_id, None) + + def update(self, frames: list[Frame]) -> tuple[list[Tensor], list[int]]: + """ + Update the bufferers with the new frames. + Frames can come from different streams (audios), so we need to maintain a bufferer for each stream + Args: + frames (list[Frame]): list of frames + Returns: + tuple[list[Tensor], list[int]]: + buffers: list of buffered audio tensors, one per input frame + left_paddings: list of left paddings, one per input frame + """ + buffers, left_paddings = [], [] + for frame in frames: + bufferer = self.bufferers.get(frame.stream_id, None) + + if bufferer is None: + bufferer = AudioBufferer(self.sample_rate, self.buffer_size_in_secs) + self.bufferers[frame.stream_id] = bufferer + + bufferer.update(frame) + buffers.append(bufferer.get_buffer()) + left_paddings.append(bufferer.get_left_padding()) + + if frame.is_last: + self.rm_bufferer(frame.stream_id) + + return buffers, left_paddings diff --git a/nemo/collections/asr/inference/streaming/buffering/cache_feature_bufferer.py b/nemo/collections/asr/inference/streaming/buffering/cache_feature_bufferer.py new file mode 100644 index 000000000000..3fbbd4ec37b6 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/buffering/cache_feature_bufferer.py @@ -0,0 +1,265 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math + +import torch +from omegaconf import DictConfig + +from nemo.collections.asr.inference.streaming.buffering.audio_bufferer import AudioBufferer +from nemo.collections.asr.inference.streaming.framing.request import Frame +from nemo.collections.asr.inference.utils.constants import LOG_MEL_ZERO +from nemo.collections.asr.models import ASRModel + + +class CacheFeatureBufferer: + """ + Cache feature bufferer class + It buffers the feature chunks and maintains the buffer. + """ + + def __init__( + self, + sample_rate: int, + buffer_size_in_secs: float, + chunk_size_in_secs: float, + preprocessor_cfg: DictConfig, + device: torch.device, + fill_value: float = LOG_MEL_ZERO, + right_padding_ratio: float = 0.8, + ): + """ + Args: + sample_rate (int): sample rate + buffer_size_in_secs (float): buffer size in seconds + chunk_size_in_secs (float): chunk size in seconds + preprocessor_cfg (DictConfig): preprocessor config + device (torch.device): device + fill_value (float): value to fill the feature buffer with + right_padding_ratio (float): what fraction of actual right padding of the last frame to use for padding mask, + some models perform better with extra padding at the end of the audio + """ + if buffer_size_in_secs < chunk_size_in_secs: + raise ValueError( + f"Buffer size ({buffer_size_in_secs}s) should be no less than chunk size ({chunk_size_in_secs}s)" + ) + + self.sample_rate = sample_rate + self.buffer_size_in_secs = buffer_size_in_secs + self.chunk_size_in_secs = chunk_size_in_secs + self.device = device + + if hasattr(preprocessor_cfg, 'log') and preprocessor_cfg.log: + self.ZERO_LEVEL_SPEC_DB_VAL = LOG_MEL_ZERO # Log-Mel spectrogram value for zero signals + else: + self.ZERO_LEVEL_SPEC_DB_VAL = fill_value # Custom fill value for the feature buffer + + self.n_feat = preprocessor_cfg.features + self.timestep_duration = preprocessor_cfg.window_stride + self.n_chunk_look_back = int(self.timestep_duration * self.sample_rate) + self.chunk_size = int(self.chunk_size_in_secs * self.sample_rate) + self.sample_buffer = AudioBufferer(sample_rate, buffer_size_in_secs) + + self.feature_buffer_len = int(buffer_size_in_secs / self.timestep_duration) + self.feature_chunk_len = int(chunk_size_in_secs / self.timestep_duration) + self.feature_buffer = torch.full( + [self.n_feat, self.feature_buffer_len], + self.ZERO_LEVEL_SPEC_DB_VAL, + dtype=torch.float32, + device=self.device, + ) + + self.preprocessor = ASRModel.from_config_dict(preprocessor_cfg) + self.preprocessor.to(self.device) + + self.right_padding_ratio = right_padding_ratio + self.right_padding = 0 + + def reset(self) -> None: + """ + Reset the buffer to zero + """ + self.sample_buffer.reset() + self.feature_buffer.fill_(self.ZERO_LEVEL_SPEC_DB_VAL) + self.right_padding = 0 + + def _update_feature_buffer(self, feat_chunk: torch.Tensor) -> None: + """ + Add an extracted feature to `feature_buffer` + Args: + feat_chunk (torch.Tensor): feature chunk + """ + self.feature_buffer[:, : -self.feature_chunk_len] = self.feature_buffer[:, self.feature_chunk_len :].clone() + self.feature_buffer[:, -self.feature_chunk_len :] = 
feat_chunk.clone() + + def preprocess( + self, audio_signal: torch.Tensor, right_padding: int = 0, expected_feat_len: int = None + ) -> tuple[torch.Tensor, int]: + """ + Preprocess the audio signal using the preprocessor + Args: + audio_signal (torch.Tensor): audio signal + right_padding (int): right padding + expected_feat_len (int): expected feature length + Returns: + torch.Tensor: preprocessed features + int: right padding + """ + sig_len = len(audio_signal) + if right_padding > 0: + right_padding = int(right_padding * self.right_padding_ratio) + + sig_len -= right_padding + features, _ = self.preprocessor( + input_signal=audio_signal.unsqueeze_(0).to(self.device), + length=torch.tensor([sig_len], device=self.device), + ) + + if features.shape[2] > expected_feat_len: + features = features[:, :, :expected_feat_len] + + features = features.squeeze() + right_padding = math.floor(right_padding / self.sample_rate / self.timestep_duration) + return features, right_padding + + def update(self, frame: Frame) -> None: + """ + Update the sample and feature buffers with the new frame + Args: + frame (Frame): frame to update the buffer with + """ + + # Update the sample buffer with the new frame + self.sample_buffer.update(frame) + right_padding = frame.size - frame.valid_size + + plus_one = 0 + if math.isclose(self.buffer_size_in_secs, self.chunk_size_in_secs): + # If the buffer size is equal to the chunk size, just take the whole buffer + samples = self.sample_buffer.sample_buffer.clone() + else: + # Add look_back to have context for the first feature + samples = self.sample_buffer.sample_buffer[-(self.n_chunk_look_back + self.chunk_size) :] + plus_one = 1 + + # Get the mel spectrogram + features, right_padding = self.preprocess( + samples, right_padding, expected_feat_len=self.feature_chunk_len + plus_one + ) + + # Update the feature buffer with the new features + self._update_feature_buffer(features[:, -self.feature_chunk_len :]) + self.right_padding = right_padding + + def get_feature_buffer(self) -> torch.Tensor: + """ + Get the current feature buffer + Returns: + torch.Tensor: current state of the feature buffer + """ + return self.feature_buffer.clone() + + def get_right_padding(self) -> int: + """ + Get the right padding + Returns: + int: right padding + """ + return self.right_padding + + +class BatchedCacheFeatureBufferer: + """ + Batched cache feature bufferer class + It buffers the feature chunks from multiple streams and maintains the buffers. 
+ """ + + def __init__( + self, + sample_rate: int, + buffer_size_in_secs: float, + chunk_size_in_secs: float, + preprocessor_cfg: DictConfig, + device: torch.device, + right_padding_ratio: float = 0.8, + ): + """ + Args: + sample_rate (int): sample rate + buffer_size_in_secs (float): buffer size in seconds + chunk_size_in_secs (float): chunk size in seconds + preprocessor_cfg (DictConfig): preprocessor config + device (torch.device): device + right_padding_ratio (float): what fraction of actual right padding to use to create padding mask, + some models perform better with extra padding at the end of the audio + """ + + self.sample_rate = sample_rate + self.buffer_size_in_secs = buffer_size_in_secs + self.bufferers = {} + self.chunk_size_in_secs = chunk_size_in_secs + self.preprocessor_cfg = preprocessor_cfg + self.device = device + self.right_padding_ratio = right_padding_ratio + + def reset(self) -> None: + """ + Reset bufferers + """ + self.bufferers = {} + + def rm_bufferer(self, stream_id: int) -> None: + """ + Remove bufferer for the given stream id + Args: + stream_id (int): stream id + """ + self.bufferers.pop(stream_id, None) + + def update(self, frames: list[Frame]) -> tuple[list[torch.Tensor], list[int]]: + """ + Update the feature bufferers with the new frames. + Frames can come from different streams (audios), so we need to maintain a bufferer for each stream. + Args: + frames (list[Frame]): list of frames + Returns: + tuple[list[torch.Tensor], list[int]]: + feature_buffers: list of feature buffers, one per input frame + right_paddings: list of right paddings, one per input frame + """ + fbuffers = [] + right_paddings = [] + for frame in frames: + bufferer = self.bufferers.get(frame.stream_id, None) + + if bufferer is None: + bufferer = CacheFeatureBufferer( + sample_rate=self.sample_rate, + buffer_size_in_secs=self.buffer_size_in_secs, + chunk_size_in_secs=self.chunk_size_in_secs, + preprocessor_cfg=self.preprocessor_cfg, + device=self.device, + right_padding_ratio=self.right_padding_ratio, + ) + self.bufferers[frame.stream_id] = bufferer + + bufferer.update(frame) + fbuffers.append(bufferer.get_feature_buffer()) + right_paddings.append(bufferer.get_right_padding()) + + if frame.is_last: + self.rm_bufferer(frame.stream_id) + + return fbuffers, right_paddings diff --git a/nemo/collections/asr/inference/streaming/buffering/feature_bufferer.py b/nemo/collections/asr/inference/streaming/buffering/feature_bufferer.py new file mode 100644 index 000000000000..c1372bab61b0 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/buffering/feature_bufferer.py @@ -0,0 +1,160 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from omegaconf import DictConfig + +from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer +from nemo.collections.asr.inference.utils.constants import LOG_MEL_ZERO + + +class FeatureBufferer: + """ + Feature bufferer class + It buffers the feature chunks and maintains the buffer. + """ + + def __init__( + self, + sample_rate: int, + buffer_size_in_secs: float, + preprocessor_cfg: DictConfig, + device: torch.device, + fill_value: float = LOG_MEL_ZERO, + ): + """ + Args: + sample_rate (int): sample rate + buffer_size_in_secs (float): buffer size in seconds + preprocessor_cfg (DictConfig): preprocessor config + device (torch.device): device + fill_value (float): value to fill the feature buffer with + """ + self.sample_rate = sample_rate + self.buffer_size_in_secs = buffer_size_in_secs + self.device = device + + if hasattr(preprocessor_cfg, 'log') and preprocessor_cfg.log: + self.ZERO_LEVEL_SPEC_DB_VAL = LOG_MEL_ZERO + else: + self.ZERO_LEVEL_SPEC_DB_VAL = fill_value + + self.n_feat = preprocessor_cfg.features + self.feature_buffer_len = int(buffer_size_in_secs / preprocessor_cfg.window_stride) + self.feature_buffer = torch.full( + [self.n_feat, self.feature_buffer_len], + self.ZERO_LEVEL_SPEC_DB_VAL, + dtype=torch.float32, + device=self.device, + ) + + def reset(self) -> None: + """ + Reset the buffer to zero + """ + self.feature_buffer.fill_(self.ZERO_LEVEL_SPEC_DB_VAL) + + def update(self, fbuffer: FeatureBuffer) -> None: + """ + Replace feature buffer with new data + Args: + fbuffer (FeatureBuffer): feature buffer to update + """ + # Resize if needed (optional) + if fbuffer.size != self.feature_buffer.shape[1]: + self.feature_buffer = torch.full( + [self.n_feat, fbuffer.size], + self.ZERO_LEVEL_SPEC_DB_VAL, + dtype=torch.float32, + device=self.device, + ) + + self.feature_buffer.copy_(fbuffer.features) + + def get_feature_buffer(self) -> torch.Tensor: + """ + Get the current feature buffer + Returns: + torch.Tensor: current state of the feature buffer + """ + return self.feature_buffer.clone() + + +class BatchedFeatureBufferer: + """ + Batched feature bufferer class + It buffers the feature chunks from multiple streams and maintains the buffers. + """ + + def __init__( + self, + sample_rate: int, + buffer_size_in_secs: float, + preprocessor_cfg: DictConfig, + device: torch.device, + ): + """ + Args: + sample_rate (int): sample rate + buffer_size_in_secs (float): buffer size in seconds + preprocessor_cfg (DictConfig): preprocessor config + device (torch.device): device + """ + self.sample_rate = sample_rate + self.buffer_size_in_secs = buffer_size_in_secs + self.preprocessor_cfg = preprocessor_cfg + self.device = device + self.bufferers = {} + + def reset(self) -> None: + """Reset bufferers""" + self.bufferers = {} + + def rm_bufferer(self, stream_id: int) -> None: + """ + Remove bufferer for the given stream id + Args: + stream_id (int): stream id + """ + self.bufferers.pop(stream_id, None) + + def update(self, fbuffers: list[FeatureBuffer]) -> list[torch.Tensor]: + """ + Update the feature bufferers with the new feature buffers. + Feature buffers can come from different streams (audios), so we need to maintain a bufferer for each stream. 
+ Args: + fbuffers (list[FeatureBuffer]): list of feature buffers + Returns: + list[torch.Tensor]: list of feature buffers, one per input frame + """ + result_buffers = [] + for fbuffer in fbuffers: + bufferer = self.bufferers.get(fbuffer.stream_id, None) + + if bufferer is None: + bufferer = FeatureBufferer( + self.sample_rate, + self.buffer_size_in_secs, + self.preprocessor_cfg, + self.device, + ) + self.bufferers[fbuffer.stream_id] = bufferer + + bufferer.update(fbuffer) + result_buffers.append(bufferer.get_feature_buffer()) + + if fbuffer.is_last: + self.rm_bufferer(fbuffer.stream_id) + + return result_buffers diff --git a/nemo/collections/asr/inference/streaming/decoders/__init__.py b/nemo/collections/asr/inference/streaming/decoders/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/decoders/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/streaming/decoders/greedy/__init__.py b/nemo/collections/asr/inference/streaming/decoders/greedy/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/decoders/greedy/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/streaming/decoders/greedy/greedy_ctc_decoder.py b/nemo/collections/asr/inference/streaming/decoders/greedy/greedy_ctc_decoder.py new file mode 100644 index 000000000000..cc7001da9578 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/decoders/greedy/greedy_ctc_decoder.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
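The `CTCGreedyDecoder` defined below applies the standard greedy CTC collapse per buffered frame: take the argmax label at each time step, then drop repeated labels and blanks while recording the surviving timesteps. A self-contained sketch of that rule with toy probabilities and an assumed blank id (illustrative only):

```python
import torch

# Toy log-probabilities: 6 time steps, vocabulary of 3 tokens plus blank (id 3).
log_probs = torch.log(torch.tensor([
    [0.1, 0.7, 0.1, 0.1],  # argmax -> 1
    [0.1, 0.7, 0.1, 0.1],  # argmax -> 1 (repeat, dropped)
    [0.1, 0.1, 0.1, 0.7],  # argmax -> blank
    [0.7, 0.1, 0.1, 0.1],  # argmax -> 0
    [0.1, 0.1, 0.1, 0.7],  # argmax -> blank
    [0.1, 0.1, 0.7, 0.1],  # argmax -> 2
]))
blank_id = 3

labels = log_probs.argmax(dim=-1).tolist()
tokens, timesteps, previous = [], [], blank_id
for i, p in enumerate(labels):
    if p != previous and p != blank_id:
        tokens.append(p)
        timesteps.append(i)
    previous = p

print(tokens)     # [1, 0, 2]
print(timesteps)  # [0, 3, 5]
```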
+ + +from typing import Callable +import torch + +from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_decoder import GreedyDecoder + + +class CTCGreedyDecoder(GreedyDecoder): + """CTC Greedy decoder class""" + + def __init__(self, vocabulary: list[str], conf_func: Callable = None): + """ + Initialize the CTCGreedyDecoder + Args: + vocabulary (list[str]): list of vocabulary tokens + conf_func (Callable): function to compute confidence + """ + + super().__init__(vocabulary, conf_func) + + @staticmethod + def get_labels(log_probs: torch.Tensor) -> list[int]: + """ + Perform greedy decoding on the log probabilities + Args: + log_probs (torch.Tensor): log probabilities + Returns: + list[int]: list of tokens + """ + if log_probs.dim() != 2: + raise ValueError("log_probs must be 2D tensor") + + labels = log_probs.argmax(dim=-1).cpu() # T + return labels.tolist() + + def __call__(self, log_probs: torch.Tensor, compute_confidence: bool = True, previous: int = None) -> dict: + """ + Greedy decode the log probabilities + Args: + log_probs (torch.Tensor): log probabilities + compute_confidence (bool): compute confidence or not + Returns: + dict: output dictionary containing tokens, timesteps, and confidences + """ + + compute_confidence = compute_confidence and self.conf_func is not None + + if log_probs.dim() != 2: + raise ValueError("log_probs must be 2D tensor") + + if compute_confidence: + # Add batch dimension + log_probs = log_probs.unsqueeze(0) # 1 x T x N + # Compute confidences + confidences = torch.zeros(log_probs.shape[0], log_probs.shape[1]) # 1 x T + confidences[0] = self.conf_func(log_probs[0], v=log_probs.shape[2]) # 1 x T + # Remove batch dimension and convert to list + confidences = confidences.squeeze(0).tolist() # T + # Remove batch dimension + log_probs = log_probs.squeeze(0) # T x N + + labels = self.get_labels(log_probs) # T + output = {"tokens": [], "timesteps": [], "confidences": []} + previous = self.blank_id if previous is None else previous + for i, p in enumerate(labels): + if p != previous and p != self.blank_id: + output["tokens"].append(p) + output["timesteps"].append(i) + if compute_confidence: + output["confidences"].append(confidences[i]) + previous = p + + output["labels"] = labels + return output + + +class ClippedCTCGreedyDecoder: + """ + Clipped CTC Greedy decoder class + Decodes the tokens within a given clip range and returns the clipped tokens and timestamps. 
+ """ + + def __init__(self, vocabulary: list[str], tokens_per_frame: int, conf_func: Callable = None, endpointer=None): + """ + Initialize the ClippedCTCGreedyDecoder + Args: + vocabulary (list[str]): list of vocabulary tokens + tokens_per_frame (int): number of tokens per frame + conf_func (Callable): function to compute confidence + endpointer (Any): endpointer to detect EOU + """ + self.greedy_decoder = CTCGreedyDecoder(vocabulary, conf_func) + self.endpointer = endpointer + self.tokens_per_frame = tokens_per_frame + + def __call__( + self, + log_probs: torch.Tensor, + clip_start: int, + clip_end: int, + is_last: bool = False, + is_start: bool = True, + return_partial_result: bool = True, + state_start_idx: int = 0, + state_end_idx: int = 0, + stop_history_eou: int = None, + compute_confidence: bool = True, + ) -> tuple[dict, dict, bool, int, int]: + """ + Decode the log probabilities within the clip range (clip_start, clip_end) + Args: + log_probs (torch.Tensor): log probabilities + clip_start (int): start index of the clip + clip_end (int): end index of the clip + is_last (bool): is the last frame or not + is_start (bool): is the first frame for this stream or not + return_partial_result (bool): return partial result left after clip_end in the buffer + state_start_idx (int): start index from stream state + state_end_idx (int): end index from stream state + stop_history_eou (int): stop history of EOU, if None then use the default stop history + compute_confidence (bool): compute confidence or not + Returns: + tuple[dict, dict, bool, int, int]: + clipped output, tail output, is_eou, updated start_idx, updated end_idx + """ + + is_eou = is_last + eou_detected_at = len(log_probs) + # Initialize state tracking variables from input parameters + start_idx, end_idx = state_start_idx, state_end_idx + # Update indices for next processing step + if end_idx > clip_start: + end_idx -= self.tokens_per_frame + start_idx = end_idx + + if is_start or end_idx <= clip_start: + start_idx, end_idx = clip_start, clip_end + + all_output = self.greedy_decoder(log_probs, compute_confidence=compute_confidence) + + clipped_output = {"tokens": [], "timesteps": [], "confidences": [], "last_token": None, "last_token_idx": None} + tail_output = {"tokens": []} + + # check if EOU is detected or is the last frame + if not is_eou and self.endpointer is not None: + is_eou, eou_detected_at = self.endpointer.detect_eou( + log_probs, pivot_point=start_idx, search_start_point=clip_start, stop_history_eou=stop_history_eou + ) + + # if EOU is detected, and it is after the clip end, update the end index to the EOU + if is_eou and eou_detected_at > end_idx: + end_idx = eou_detected_at + + # if the end index is within the clip range, update the end index to the clip end + if clip_start <= end_idx < clip_end: + end_idx = clip_end + is_eou = False + + # clip the output within the clip range [clip_start, clip_end) + timesteps = all_output["timesteps"] + i = 0 + while i < len(timesteps): + if start_idx <= timesteps[i] < end_idx: + clipped_output["tokens"].append(all_output["tokens"][i]) + clipped_output["timesteps"].append(timesteps[i]) + if compute_confidence: + clipped_output["confidences"].append(all_output["confidences"][i]) + elif timesteps[i] >= end_idx: + break + i += 1 + + if end_idx - 1 < len(all_output["labels"]): + clipped_output["last_token"] = all_output["labels"][end_idx - 1] + clipped_output["last_token_idx"] = end_idx - 1 + + # return the partial result left after clip_end in the buffer + if return_partial_result: + 
while i < len(timesteps): + if timesteps[i] >= end_idx: + tail_output["tokens"] = all_output["tokens"][i:] + break + else: + i += 1 + + return clipped_output, tail_output, is_eou, start_idx, end_idx diff --git a/nemo/collections/asr/inference/streaming/decoders/greedy/greedy_decoder.py b/nemo/collections/asr/inference/streaming/decoders/greedy/greedy_decoder.py new file mode 100644 index 000000000000..9412330eb7a8 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/decoders/greedy/greedy_decoder.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable + +from nemo.collections.asr.inference.utils.constants import SENTENCEPIECE_UNDERSCORE + + +class GreedyDecoder: + """Base class for the greedy decoder""" + + def __init__(self, vocabulary: list[str], conf_func: Callable = None): + """ + Initialize the GreedyDecoder + Args: + vocabulary (list[str]): list of vocabulary tokens + conf_func (Callable): function to compute confidence + """ + + self.vocabulary = vocabulary + self.blank_id = len(vocabulary) + self.conf_func = conf_func + self.is_start_tokens = [token.startswith(SENTENCEPIECE_UNDERSCORE) for token in vocabulary] + + def count_silent_tokens(self, tokens: list[int], start: int, end: int) -> int: + """ + Count how many silent tokens appear in [start, end). + Args: + tokens (list[int]): list of tokens + start (int): start index + end (int): end index + Returns: + int: number of silent tokens + """ + if end <= start or start >= len(tokens): + return 0 + return sum(self.is_token_silent(tokens[i]) for i in range(start, min(end, len(tokens)))) + + def is_token_start_of_word(self, token_id: int) -> bool: + """ + Check if the token is the start of a word + Args: + token_id (int): token id + Returns: + bool: True if the token is the start of a word, False otherwise + """ + return self.is_start_tokens[token_id] + + def is_token_silent(self, token_id: int) -> bool: + """ + Check if the token is silent + Args: + token_id (int): token id + Returns: + bool: True if the token is silent, False otherwise + """ + return token_id == self.blank_id + + def first_non_silent_token(self, tokens: list[int], start: int, end: int) -> int: + """ + Return the index of the first non-silent token in [start, end). + If none found, return -1. + Args: + tokens (list[int]): list of tokens + start (int): start index + end (int): end index + Returns: + int: index of the first non-silent token + """ + for i in range(start, min(end, len(tokens))): + if not self.is_token_silent(tokens[i]): + return i + return -1 + + def count_non_silent_tokens(self, tokens: list[int], start: int, end: int) -> int: + """ + Count how many non-silent tokens appear in [start, end). 
+ Args: + tokens (list[int]): list of tokens + start (int): start index + end (int): end index + Returns: + int: number of non-silent tokens + """ + if end <= start or start >= len(tokens): + return 0 + return sum(not self.is_token_silent(tokens[i]) for i in range(start, min(end, len(tokens)))) + + def __call__(self, *args, **kwds): + raise NotImplementedError("Subclass of GreedyDecoder should implement `__call__` method!") diff --git a/nemo/collections/asr/inference/streaming/decoders/greedy/greedy_rnnt_decoder.py b/nemo/collections/asr/inference/streaming/decoders/greedy/greedy_rnnt_decoder.py new file mode 100644 index 000000000000..223462409a18 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/decoders/greedy/greedy_rnnt_decoder.py @@ -0,0 +1,235 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable + +import torch + +from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_decoder import GreedyDecoder + + +class RNNTGreedyDecoder(GreedyDecoder): + """RNNT Greedy decoder class""" + + def __init__(self, vocabulary: list[str], conf_func: Callable = None): + """ + Initialize the RNNTGreedyDecoder + Args: + vocabulary (list[str]): list of vocabulary tokens + conf_func (Callable): function to compute confidence + """ + super().__init__(vocabulary, conf_func) + + def __call__( + self, + global_timestamps: torch.Tensor | list[int], + tokens: torch.Tensor | list[int], + length: int, + offset: int = 0, + ) -> tuple[dict, list[int], int]: + """ + Decode the RNNT hypothesis using timestamps + Args: + global_timestamps (torch.Tensor | list[int]): global timestamps since the start of the stream + tokens (torch.Tensor | list[int]): tokens since the start of the stream + length (int): length of the alignment + offset (int): offset to apply to the timestamps to make them local + Returns: + tuple[dict, list[int], int]: + output: dictionary containing the decoded tokens, timestamps, and confidences + current labels: list of current labels including the blank token + new offset: new offset value for the next decoding step + """ + if isinstance(global_timestamps, list): + global_timestamps = torch.tensor(global_timestamps) + if isinstance(tokens, list): + tokens = torch.tensor(tokens) + + output = {"tokens": [], "timesteps": [], "confidences": [], "last_token": None, "last_token_idx": None} + cur_labels = [self.blank_id] * length + new_offset = len(tokens) + if offset > 0: + trimmed_tokens = tokens[offset:].tolist() + trimmed_timestamps = global_timestamps[offset:].tolist() + else: + trimmed_tokens = tokens.tolist() + trimmed_timestamps = global_timestamps.tolist() + + if len(trimmed_tokens) == 0: + return output, cur_labels, new_offset + + output["tokens"].extend(trimmed_tokens) + output["timesteps"].extend(trimmed_timestamps) + output["confidences"].extend([0.0] * len(trimmed_tokens)) + output["last_token"] = trimmed_tokens[-1] + output["last_token_idx"] = trimmed_timestamps[-1] + + for t, 
token in zip(trimmed_timestamps, trimmed_tokens): + cur_labels[t % length] = token + return output, cur_labels, new_offset + + +class ClippedRNNTGreedyDecoder: + """ + Clipped RNNT Greedy decoder class + Decodes the tokens within a given clip range and returns the clipped tokens and timestamps. + """ + + def __init__(self, vocabulary: list[str], tokens_per_frame: int, conf_func: Callable = None, endpointer=None): + """ + Initialize the ClippedRNNTGreedyDecoder + Args: + vocabulary (list[str]): list of vocabulary tokens + tokens_per_frame (int): number of tokens per frame + conf_func (Callable): function to compute confidence + endpointer (Any): endpointer to detect EOU + """ + self.greedy_decoder = RNNTGreedyDecoder(vocabulary, conf_func) + self.endpointer = endpointer + self.tokens_per_frame = tokens_per_frame + + @staticmethod + def extract_clipped_and_tail_single_pass( + timesteps: torch.Tensor, tokens: torch.Tensor, start_idx: int, end_idx: int, return_tail_result: bool + ) -> tuple[list[int], list[int], list[int]]: + """ + Extract clipped and tail data using tensor operations - no conversion overhead + """ + if len(timesteps) == 0: + return [], [], [] + clipped_mask = (timesteps >= start_idx) & (timesteps < end_idx) + clipped_timesteps = timesteps[clipped_mask].tolist() + clipped_tokens = tokens[clipped_mask].tolist() + tail_tokens = [] + if return_tail_result: + tail_mask = timesteps >= end_idx + if tail_mask.any(): + tail_tokens = tokens[tail_mask].tolist() + + return clipped_timesteps, clipped_tokens, tail_tokens + + def __call__( + self, + global_timesteps: torch.Tensor, + tokens: torch.Tensor, + clip_start: int, + clip_end: int, + alignment_length: int, + is_last: bool = False, + is_start: bool = True, + return_tail_result: bool = False, + state_start_idx: int = 0, + state_end_idx: int = 0, + timestamp_offset: int = 0, + vad_segments: torch.Tensor = None, + stop_history_eou: int = None, + ) -> tuple[dict, dict, bool, int, int]: + """ + Decode using timestamps instead of dense alignment + Optimized version with vectorized operations and single-pass processing + Args: + global_timesteps (torch.Tensor): global timestamps since the start of the stream + tokens (torch.Tensor): tokens + clip_start (int): start index of the clip + clip_end (int): end index of the clip + alignment_length (int): length of the alignment + is_last (bool): is the last frame or not. + is_start (bool): is the first frame for this stream or not. 
+ return_tail_result (bool): return tail result left after clip_end in the buffer + state_start_idx (int): start index from stream state + state_end_idx (int): end index from stream state + timestamp_offset (int): offset to apply to the timestamps to make them local + vad_segments (torch.Tensor): Optional VAD segments to use for end-of-utterance detection + stop_history_eou (int): stop history of EOU, if None then use the default stop history + Returns: + tuple[dict, dict, bool, int, int]: + clipped output, tail output, is_eou, updated start_idx, updated end_idx + """ + # Initialize end-of-utterance state based on input parameters + if timestamp_offset: + timesteps = global_timesteps - timestamp_offset + else: + timesteps = global_timesteps + is_eou = is_last + eou_detected_at = alignment_length + start_idx, end_idx = state_start_idx, state_end_idx + if end_idx > clip_start: + end_idx -= self.tokens_per_frame + start_idx = end_idx + if is_start: + start_idx, end_idx = clip_start, clip_start + elif end_idx <= clip_start: + start_idx, end_idx = clip_start, clip_end + + if len(timesteps) == 0 or len(tokens) == 0: + return ( + {"tokens": [], "timesteps": [], "confidences": [], "last_token": None, "last_token_idx": None}, + {"tokens": []}, + True, + start_idx, + end_idx, + ) + + mask = timesteps >= start_idx + timesteps_trimmed = timesteps[mask] + tokens_trimmed = tokens[mask] + # If not already at end of utterance and endpointer exists, try to detect end of utterance + if not is_eou and self.endpointer is not None: + if vad_segments is not None and len(vad_segments) > 0: + if vad_segments[-1][1] != 0.0: + is_eou, eou_detected_at = self.endpointer.detect_eou_vad( + vad_segments=vad_segments, search_start_point=start_idx, stop_history_eou=stop_history_eou + ) + else: + is_eou = True + eou_detected_at = -1 + else: + is_eou, eou_detected_at = self.endpointer.detect_eou_given_timestamps( + timesteps=timesteps_trimmed, + tokens=tokens_trimmed, + alignment_length=alignment_length, + stop_history_eou=stop_history_eou, + ) + # If EOU is detected beyond current end frame, extend end frame to include it + if is_eou and eou_detected_at > end_idx: + end_idx = min(eou_detected_at, alignment_length) + + # If the end frame is within the clip range, set the end frame to the clip end + if clip_start <= end_idx < clip_end: + end_idx = clip_end + is_eou = False + clipped_timesteps, clipped_tokens, tail_tokens = self.extract_clipped_and_tail_single_pass( + timesteps, tokens, start_idx, end_idx, return_tail_result + ) + # Make timestamps global again + if timestamp_offset: + clipped_timesteps = [t + timestamp_offset for t in clipped_timesteps] + # Initialize output with last_token tracking like in __call__ method + clipped_output = { + "tokens": clipped_tokens, + "timesteps": clipped_timesteps, + "confidences": [0.0] * len(clipped_tokens) if len(clipped_tokens) > 0 else [], + "last_token": None, + "last_token_idx": None, + } + + # Set last_token and last_token_idx if there are tokens + if len(clipped_tokens) > 0: + clipped_output["last_token"] = clipped_tokens[-1] + clipped_output["last_token_idx"] = clipped_timesteps[-1] if len(clipped_timesteps) > 0 else None + + # Create tail output + tail_output = {"tokens": tail_tokens} + return clipped_output, tail_output, is_eou, start_idx, end_idx diff --git a/nemo/collections/asr/inference/streaming/endpointing/__init__.py b/nemo/collections/asr/inference/streaming/endpointing/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ 
b/nemo/collections/asr/inference/streaming/endpointing/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/streaming/endpointing/greedy/__init__.py b/nemo/collections/asr/inference/streaming/endpointing/greedy/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/endpointing/greedy/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_ctc_endpointing.py b/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_ctc_endpointing.py new file mode 100644 index 000000000000..446c513ed49f --- /dev/null +++ b/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_ctc_endpointing.py @@ -0,0 +1,85 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_ctc_decoder import CTCGreedyDecoder +from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_endpointing import GreedyEndpointing + + +class CTCGreedyEndpointing(GreedyEndpointing): + """Greedy endpointing for the streaming CTC pipeline""" + + def __init__( + self, + vocabulary: list[str], + ms_per_timestep: int, + effective_buffer_size_in_secs: float = None, + stop_history_eou: int = -1, + residue_tokens_at_end: int = 0, + ) -> None: + """ + Initialize the CTCGreedyEndpointing class + Args: + vocabulary: (list[str]) List of vocabulary + ms_per_timestep: (int) Number of milliseconds per timestep + effective_buffer_size_in_secs: (float, optional) Effective buffer size for VAD-based EOU detection. Not used for CTC. 
+ stop_history_eou: (int) Number of silent tokens to trigger a EOU, if -1 then it is disabled + residue_tokens_at_end: (int) Number of residue tokens at the end, if 0 then it is disabled + """ + super().__init__( + vocabulary, ms_per_timestep, effective_buffer_size_in_secs, stop_history_eou, residue_tokens_at_end + ) + self.greedy_ctc_decoder = CTCGreedyDecoder(self.vocabulary, conf_func=None) + + def detect_eou( + self, + probs_seq: torch.Tensor, + pivot_point: int, + search_start_point: int = 0, + stop_history_eou: int | None = None, + ) -> tuple[bool, int]: + """ + Detect end of utterance (EOU) given the probabilities sequence and pivot point + Args: + probs_seq (torch.Tensor): probabilities sequence + pivot_point (int): pivot point + search_start_point (int): start point for searching EOU + stop_history_eou (int | None): stop history of EOU, if None then use the stop history of EOU from the class + Returns: + bool: True if EOU is detected, False otherwise + int: index of the EOU detected at + """ + emissions = self.greedy_ctc_decoder.get_labels(probs_seq) + return self.detect_eou_given_emissions(emissions, pivot_point, search_start_point, stop_history_eou) + + def is_token_start_of_word(self, token_id: int) -> bool: + """ + Check if the token is the start of a word + Args: + token_id (int): token id + Returns: + bool: True if the token is the start of a word, False otherwise + """ + return self.greedy_ctc_decoder.is_token_start_of_word(token_id=token_id) + + def is_token_silent(self, token_id: int) -> bool: + """ + Check if the token is silent + Args: + token_id (int): token id + Returns: + bool: True if the token is silent, False otherwise + """ + return self.greedy_ctc_decoder.is_token_silent(token_id=token_id) diff --git a/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_endpointing.py b/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_endpointing.py new file mode 100644 index 000000000000..1442043ae711 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_endpointing.py @@ -0,0 +1,312 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from nemo.collections.asr.inference.utils.endpointing_utils import get_custom_stop_history_eou, millisecond_to_frames + + +class GreedyEndpointing: + """Greedy endpointing for the streaming ASR pipelines""" + + def __init__( + self, + vocabulary: list[str], + ms_per_timestep: int, + effective_buffer_size_in_secs: float = None, + stop_history_eou: int = -1, + residue_tokens_at_end: int = 0, + ) -> None: + """ + Initialize the GreedyEndpointing class + Args: + vocabulary: (list[str]) List of vocabulary + ms_per_timestep: (int) Number of milliseconds per timestep + effective_buffer_size_in_secs: (float, optional) Effective buffer size for VAD-based EOU detection. 
+ stop_history_eou: (int) Number of silent tokens to trigger a EOU, if -1 then it is disabled + residue_tokens_at_end: (int) Number of residue tokens at the end, if 0 then it is disabled + """ + + self.vocabulary = vocabulary + self.ms_per_timestep = ms_per_timestep + self.sec_per_timestep = ms_per_timestep / 1000 + self.stop_history_eou = stop_history_eou + self.stop_history_eou_ms = stop_history_eou + self.effective_buffer_size_in_secs = effective_buffer_size_in_secs + if self.stop_history_eou > 0: + self.stop_history_eou = millisecond_to_frames(self.stop_history_eou, ms_per_timestep) + self.residue_tokens_at_end = residue_tokens_at_end + + def detect_eou_given_emissions( + self, + emissions: list[int], + pivot_point: int, + search_start_point: int = 0, + stop_history_eou: int | None = None, + ) -> tuple[bool, int]: + """ + Detect end of utterance (EOU) given the emissions and pivot point + Args: + emissions (list[int]): list of emissions at each timestep + pivot_point (int): pivot point around which to detect EOU + search_start_point (int): start point for searching EOU + stop_history_eou (int | None): stop history of EOU, if None then use the stop history of EOU from the class + Returns: + Tuple[bool, int]: True if EOU is detected, False otherwise, and the point at which EOU is detected + """ + sequence_length = len(emissions) + if pivot_point < 0 or pivot_point >= sequence_length: + raise ValueError("Pivot point is out of range") + + if search_start_point > pivot_point: + raise ValueError("Search start point is greater than pivot_point") + + if self.residue_tokens_at_end > 0: + sequence_length = max(0, sequence_length - self.residue_tokens_at_end) + + stop_history_eou = get_custom_stop_history_eou(stop_history_eou, self.stop_history_eou, self.ms_per_timestep) + eou_detected, eou_detected_at = False, -1 + + if stop_history_eou > 0: + n_silent_tokens = 0 + silence_start_position = -1 + fst_non_silent_token = None + end_point = max(0, search_start_point, pivot_point - stop_history_eou) + current_position = max(0, sequence_length - 1) + while current_position >= end_point: + if self.is_token_silent(emissions[current_position]): + n_silent_tokens += 1 + eou_detected = n_silent_tokens > stop_history_eou + is_token_start_of_word = (fst_non_silent_token is None) or self.is_token_start_of_word( + fst_non_silent_token + ) + eou_detected = eou_detected and is_token_start_of_word + if eou_detected: + silence_start_position = current_position + else: + if eou_detected: + break + n_silent_tokens = 0 + eou_detected = False + silence_start_position = -1 + fst_non_silent_token = emissions[current_position] + current_position -= 1 + + eou_detected = n_silent_tokens > stop_history_eou + if eou_detected: + eou_detected_at = int(silence_start_position + stop_history_eou // 2) + + return eou_detected, eou_detected_at + + def detect_eou_given_timestamps( + self, + timesteps: torch.Tensor, + tokens: torch.Tensor, + alignment_length: int, + stop_history_eou: int | None = None, + ) -> tuple[bool, int]: + """ + Detect end of utterance (EOU) given timestamps and tokens using tensor operations. 
+ Args: + timesteps (torch.Tensor): timestamps of the tokens + tokens (torch.Tensor): tokens + alignment_length (int): length of the alignment + stop_history_eou (int | None): stop history of EOU, if None then use the stop history of EOU from the class + Returns: + tuple[bool, int]: True if EOU is detected, False otherwise, and the point at which EOU is detected + """ + eou_detected, eou_detected_at = False, -1 + + if len(timesteps) != len(tokens): + raise ValueError("timesteps and tokens must have the same length") + + stop_history_eou = get_custom_stop_history_eou(stop_history_eou, self.stop_history_eou, self.ms_per_timestep) + + # If stop_history_eou is negative, don't detect EOU. + if len(timesteps) == 0 or stop_history_eou < 0: + return eou_detected, eou_detected_at + + # This is the condition for Riva streaming offline mode. The output of entire buffer needs to be sent as is to the client. + if stop_history_eou == 0: + return True, alignment_length + + if self.residue_tokens_at_end > 0: + alignment_length = max(0, alignment_length - self.residue_tokens_at_end) + + # Check trailing silence at the end + last_timestamp = timesteps[-1].item() + trailing_silence = max(0, alignment_length - last_timestamp - 1) + if trailing_silence > stop_history_eou: + eou_detected = True + eou_detected_at = last_timestamp + 1 + stop_history_eou // 2 + return eou_detected, eou_detected_at + + # Check gaps between consecutive non-silent tokens + if len(timesteps) > 1: + gaps = timesteps[1:] - timesteps[:-1] - 1 + large_gap_mask = gaps > stop_history_eou + if large_gap_mask.any(): + # Get the last (rightmost) large gap index for backwards compatibility + large_gap_indices = torch.where(large_gap_mask)[0] + gap_idx = large_gap_indices[-1].item() + + eou_detected = True + eou_detected_at = timesteps[gap_idx].item() + 1 + stop_history_eou // 2 + return eou_detected, eou_detected_at + return eou_detected, eou_detected_at + + def detect_eou_vad( + self, vad_segments: torch.Tensor, search_start_point: float = 0, stop_history_eou: int | None = None + ) -> tuple[bool, float]: + """ + Detect end of utterance (EOU) using VAD segments. 
+ + Args: + vad_segments (torch.Tensor): VAD segments in format [N, 2] where each row is [start_time, end_time] + search_start_point (float): Start time for searching EOU in seconds + stop_history_eou (int | None): Stop history of EOU in milliseconds, if None then use the stop history of EOU from the class + Returns: + tuple[bool, float]: (is_eou, eou_detected_at_time) + """ + if self.effective_buffer_size_in_secs is None: + raise ValueError("Effective buffer size in seconds is required for VAD-based EOU detection") + + # Use default stop history of EOU from the class if stop_history_eou is not provided + stop_history_eou = self.stop_history_eou_ms if stop_history_eou is None else stop_history_eou + if stop_history_eou < 0: + return False, -1 + + search_start_point = search_start_point * self.sec_per_timestep + stop_history_eou_in_secs = stop_history_eou / 1000 + # Round to 4 decimal places first (vectorized) + rounded_segments = torch.round(vad_segments, decimals=4) + + # Filter segments where end_time > search_start_point + valid_mask = rounded_segments[:, 1] > search_start_point + if not valid_mask.any(): + return False, -1 + + filtered_segments = rounded_segments[valid_mask] + + # Clip start times to search_start_point + filtered_segments[:, 0] = torch.clamp(filtered_segments[:, 0], min=search_start_point) + # Initialize EOU detection variables + is_eou = False + eou_detected_at = -1 + + # Check gap to buffer end + last_segment = filtered_segments[-1] + gap_to_buffer_end = self.effective_buffer_size_in_secs - last_segment[1] + if gap_to_buffer_end > stop_history_eou_in_secs: + # EOU detected at buffer end + is_eou = True + eou_detected_at = last_segment[1] + stop_history_eou_in_secs / 2 + + elif len(filtered_segments) >= 2: + # Check gaps between segments (reverse order to find last gap) + for i in range(len(filtered_segments) - 2, -1, -1): + segment = filtered_segments[i] + next_segment = filtered_segments[i + 1] + gap = next_segment[0] - segment[1] + if gap > stop_history_eou_in_secs: + is_eou = True + eou_detected_at = segment[1] + stop_history_eou_in_secs / 2 + break + + # Convert to timesteps (only if EOU was detected) + if is_eou: + eou_detected_at = int(eou_detected_at // self.sec_per_timestep) + else: + eou_detected_at = -1 + + return is_eou, eou_detected_at + + def is_token_start_of_word(self, token_id: int) -> bool: + """Check if the token is the start of a word""" + raise NotImplementedError("Subclass of GreedyEndpointing should implement `is_token_start_of_word` method!") + + def is_token_silent(self, token_id: int) -> bool: + """Check if the token is silent""" + raise NotImplementedError("Subclass of GreedyEndpointing should implement `is_token_silent` method!") + + def detect_eou_near_pivot( + self, + emissions: list[int], + pivot_point: int, + search_start_point: int = 0, + stop_history_eou: int | None = None, + ) -> tuple[bool, int]: + """ + Detect end of utterance (EOU) given the emissions and pivot point + Args: + emissions (list[int]): list of emissions at each timestep + pivot_point (int): pivot point around which to detect EOU + search_start_point (int): start point for searching EOU + stop_history_eou (int | None): stop history of EOU, if None then use the stop history of EOU from the class + Returns: + tuple[bool, int]: True if EOU is detected, False otherwise, and the point at which EOU is detected + """ + + sequence_length = len(emissions) + + if pivot_point < 0 or pivot_point >= sequence_length: + raise ValueError("Pivot point is out of range") + + if 
search_start_point > pivot_point: + raise ValueError("Search start point is greater then pivot_point") + + if self.residue_tokens_at_end > 0: + sequence_length = max(0, sequence_length - self.residue_tokens_at_end) + + stop_history_eou = get_custom_stop_history_eou(stop_history_eou, self.stop_history_eou, self.ms_per_timestep) + eou_detected, eou_detected_at = False, -1 + + if stop_history_eou > 0: + + # number of silent tokens in the range [search_start_point, pivot_point) + n_silent_tokens_before = 0 + i = pivot_point - 1 + while i >= search_start_point: + if self.is_token_silent(emissions[i]): + n_silent_tokens_before += 1 + else: + break + i -= 1 + + # number of silent tokens in the range [pivot_point, sequence_length) + n_silent_tokens_after = 0 + i = pivot_point + fst_non_silent_token_after = None + while i < sequence_length: + if self.is_token_silent(emissions[i]): + n_silent_tokens_after += 1 + else: + fst_non_silent_token_after = emissions[i] + break + i += 1 + + # additional check for the first non-silent token after the pivot point + if fst_non_silent_token_after is not None: + if not self.is_token_start_of_word(fst_non_silent_token_after): + eou_detected, eou_detected_at = False, -1 + else: + # check if the number of silent tokens before and after the pivot point is greater than the threshold + val_cnt = n_silent_tokens_before + n_silent_tokens_after + eou_detected = val_cnt > stop_history_eou + eou_detected_at = ( + int(pivot_point - n_silent_tokens_before + stop_history_eou // 2) if eou_detected else -1 + ) + + return eou_detected, eou_detected_at diff --git a/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_rnnt_endpointing.py b/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_rnnt_endpointing.py new file mode 100644 index 000000000000..71598fdc3809 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_rnnt_endpointing.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_rnnt_decoder import RNNTGreedyDecoder +from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_endpointing import GreedyEndpointing + + +class RNNTGreedyEndpointing(GreedyEndpointing): + """Greedy endpointing for the streaming RNNT pipeline""" + + def __init__( + self, + vocabulary: list[str], + ms_per_timestep: int, + effective_buffer_size_in_secs: float = None, + stop_history_eou: int = -1, + residue_tokens_at_end: int = 0, + ) -> None: + """ + Initialize the RNNTGreedyEndpointing class + Args: + vocabulary: (list[str]) List of vocabulary + ms_per_timestep: (int) Number of milliseconds per timestep + effective_buffer_size_in_secs: (float, optional) Effective buffer size for VAD-based EOU detection for stateless and stateful RNNT. If None, VAD functionality is disabled. 
+ stop_history_eou: (int) Number of silent tokens to trigger a EOU, if -1 then it is disabled + residue_tokens_at_end: (int) Number of residue tokens at the end, if 0 then it is disabled + """ + super().__init__( + vocabulary, ms_per_timestep, effective_buffer_size_in_secs, stop_history_eou, residue_tokens_at_end + ) + self.greedy_rnnt_decoder = RNNTGreedyDecoder(self.vocabulary, conf_func=None) + + def is_token_start_of_word(self, token_id: int) -> bool: + """ + Check if the token is the start of a word + Args: + token_id (int): token id + Returns: + bool: True if the token is the start of a word, False otherwise + """ + return self.greedy_rnnt_decoder.is_token_start_of_word(token_id=token_id) + + def is_token_silent(self, token_id: int) -> bool: + """ + Check if the token is silent + Args: + token_id (int): token id + Returns: + bool: True if the token is silent, False otherwise + """ + return self.greedy_rnnt_decoder.is_token_silent(token_id=token_id) diff --git a/nemo/collections/asr/inference/streaming/framing/__init__.py b/nemo/collections/asr/inference/streaming/framing/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/framing/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/streaming/framing/mono_stream.py b/nemo/collections/asr/inference/streaming/framing/mono_stream.py new file mode 100644 index 000000000000..f942d6d57712 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/framing/mono_stream.py @@ -0,0 +1,114 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from nemo.collections.asr.inference.streaming.framing.request import Frame, RequestOptions +from nemo.collections.asr.inference.streaming.framing.stream import Stream +from nemo.collections.asr.inference.utils.audio_io import read_audio + + +class MonoStream(Stream): + """ + Streamer for mono wav files. 
+ Iterates over the frames of the audio file + """ + + def __init__(self, rate: int, frame_size_in_secs: float, stream_id: int, pad_last_frame: bool = False): + """ + Initialize the MonoStream + Args: + rate (int): sampling rate + frame_size_in_secs (int): frame length in seconds + stream_id (int): stream id + """ + + self.rate = rate + self.frame_size = int(frame_size_in_secs * rate) + self.pad_last_frame = pad_last_frame + + self.samples = None + self.n_samples = None + self.options = None + super().__init__(stream_id) + + def load_audio(self, audio: str | torch.Tensor, options: RequestOptions | None = None) -> None: + """ + Load the audio file either from a file or from a torch tensor + Args: + audio (str | torch.Tensor): audio file path or torch tensor of audio samples + options (RequestOptions | None): optional options for the request + """ + if isinstance(audio, str): + # Read the audio file and convert to mono + self.samples = read_audio(audio, target_sr=self.rate, mono=True) + else: + self.samples = audio + self.n_samples = len(self.samples) + self.frame_count = 0 # Reset frame count + self.options = options + + def __iter__(self): + """Returns the frame iterator object""" + self.start = 0 + self.frame_count = 0 + return self + + def __next__(self) -> list[Frame]: + """ + Get the next frame in the stream + Returns: + list[Frame]: The next frame in the stream + """ + if self.samples is None: + raise RuntimeError("No audio samples loaded. Please call load_audio() first.") + + if self.start < self.n_samples: + + end = min(self.start + self.frame_size, self.n_samples) + + # Check if this is the last frame + is_end = False + chunk_length = end - self.start + if (end - self.start < self.frame_size) or (end == self.n_samples): + is_end = True + + # Pad the last frame if needed + if not is_end: + chunk_samples = self.samples[self.start : end] + else: + if self.pad_last_frame: + chunk_samples = torch.zeros(self.frame_size) + chunk_samples[:chunk_length] = self.samples[self.start : end] + else: + chunk_samples = self.samples[self.start : end] + + # Package the frame + is_first = self.frame_count == 0 + frame = Frame( + samples=chunk_samples, + stream_id=self.stream_id, + is_first=is_first, + is_last=is_end, + length=chunk_length, + options=self.options if is_first else None, + ) + + self.frame_count += 1 + self.start += frame.size + + return [frame] + + # End of stream + raise StopIteration diff --git a/nemo/collections/asr/inference/streaming/framing/multi_stream.py b/nemo/collections/asr/inference/streaming/framing/multi_stream.py new file mode 100644 index 000000000000..24bda4b03fde --- /dev/null +++ b/nemo/collections/asr/inference/streaming/framing/multi_stream.py @@ -0,0 +1,389 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Callable, Iterator + +import torch + +from nemo.collections.asr.inference.streaming.buffering.audio_bufferer import BatchedAudioBufferer +from nemo.collections.asr.inference.streaming.framing.mono_stream import MonoStream +from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request, RequestOptions +from nemo.collections.asr.inference.streaming.framing.stream import Stream +from nemo.collections.asr.inference.utils.enums import RequestType +from nemo.collections.asr.inference.utils.progressbar import ProgressBar + + +class MultiStream: + """MultiStreamer for multiple streams""" + + def __init__(self, n_frames_per_stream: int): + """ + Args: + n_frames_per_stream (int): Number of frames per stream + """ + self.n_frames_per_stream = n_frames_per_stream + self.streams = {} + + def add_stream(self, stream: Stream, stream_id: int) -> None: + """ + Add a stream to the streamer + Args: + stream (Stream): The stream to add + stream_id (int): The id of the stream + """ + self.streams[stream_id] = iter(stream) + + def rm_stream(self, stream_id: int) -> None: + """ + Remove a stream from the streamer + Args: + stream_id (int): The id of the stream + """ + self.streams.pop(stream_id, None) + + def __len__(self) -> int: + """Number of running streams""" + return len(self.streams) + + def __iter__(self) -> Iterator: + """Returns the iterator object""" + return self + + def __next__(self) -> list[Frame]: + """ + Get the next batch of frames + Returns: + list[Frame]: The next batch of frames + """ + frame_batch = [] + ids_to_remove = [] + for stream_id, stream_iter in self.streams.items(): + # Get n_frames_per_stream frames from each stream + for _ in range(self.n_frames_per_stream): + frame = next(stream_iter)[0] + frame_batch.append(frame) + if frame.is_last: + ids_to_remove.append(stream_id) + + # Remove streams that have ended + for stream_id in ids_to_remove: + self.rm_stream(stream_id) + + # If no frames are generated, raise StopIteration + if len(frame_batch) == 0: + raise StopIteration + + return frame_batch + + +class ContinuousBatchedFrameStreamer: + """ + A class that manages continuous streaming of audio frames from multiple audio files, providing + frame generation in batches. The class supports dynamically adding audio streams, updating + a progress bar, and yielding batches of frames for further processing. 
+ """ + + def __init__( + self, + sample_rate: int, + frame_size_in_secs: float, + batch_size: int, + n_frames_per_stream: int, + pad_last_frame: bool = False, + ): + """ + Args: + sample_rate (int): The sample rate of the audio + frame_size_in_secs (float): The size of the frame in seconds + batch_size (int): The batch size + n_frames_per_stream (int): The number of frames per stream + pad_last_frame (bool): Whether to pad the last frame + """ + + self.sample_rate = sample_rate + self.frame_size_in_secs = frame_size_in_secs + self.batch_size = batch_size + self.pad_last_frame = pad_last_frame + + self.multi_streamer = MultiStream(n_frames_per_stream=n_frames_per_stream) + self.stream_id = 0 + + self._progress_bar = None + self.processed_streams = set() + + def set_audio_filepaths(self, audio_filepaths: list[str], options: list[RequestOptions]) -> None: + """ + Set the audio filepaths + Args: + audio_filepaths (list[str]): The list of audio filepaths + options (list[RequestOptions]): The list of options + """ + if len(audio_filepaths) != len(options): + raise ValueError("audio_filepaths and options must have the same length") + + self.audio_filepaths = audio_filepaths + self.options = options + self.n_audio_files = len(audio_filepaths) + self.total_progress_steps = self.n_audio_files * 2 # One step for adding, one for processing + self.sid2filepath = {} + + def set_progress_bar(self, progress_bar: ProgressBar) -> None: + """ + Set the progress bar + Args: + progress_bar (ProgressBar): The progress bar to set + """ + self._progress_bar = progress_bar + self.restart_progress_bar() + + def restart_progress_bar(self) -> None: + """Restart the progress bar""" + if self._progress_bar: + self._progress_bar.restart() + + def update_progress_bar(self) -> None: + """Update the progress bar""" + if self._progress_bar: + self._progress_bar.update_bar(1 / self.total_progress_steps) + + def finish_progress_bar(self) -> None: + """Finish the progress bar""" + if self._progress_bar: + self._progress_bar.finish() + + def __iter__(self) -> Iterator: + """Returns the iterator object""" + return self + + def add_stream(self) -> None: + """Create a new stream and add it to the streamer""" + if self.stream_id >= self.n_audio_files: + return # No more files to add + + # Create a new stream + stream = MonoStream( + self.sample_rate, self.frame_size_in_secs, stream_id=self.stream_id, pad_last_frame=self.pad_last_frame + ) + # Load the next audio file + audio_filepath = self.audio_filepaths[self.stream_id] + options = self.options[self.stream_id] + self.sid2filepath[self.stream_id] = audio_filepath + stream.load_audio(audio_filepath, options) + + # Add the stream to the multi streamer + self.multi_streamer.add_stream(stream, stream_id=self.stream_id) + self.stream_id += 1 + + # Update the progress bar + self.update_progress_bar() + + def __next__(self) -> list[Frame]: + """ + Get the next batch of frames, continuously adding streams + Returns: + list[Frame]: The next batch of frames + """ + # If there are fewer streams than batch size, add more streams + while len(self.multi_streamer) < self.batch_size and self.stream_id < self.n_audio_files: + self.add_stream() + + try: + frames = next(self.multi_streamer) + # Update progress when a stream is fully processed + for frame in frames: + if frame.stream_id not in self.processed_streams and frame.is_last: + self.processed_streams.add(frame.stream_id) + self.update_progress_bar() + return frames + except StopIteration: + # if there are remaining streams, add them + if 
self.stream_id < self.n_audio_files: + return self.__next__() + + if self.stream_id == self.n_audio_files: + self.finish_progress_bar() + raise StopIteration + + raise ValueError("stream_id > self.n_audio_files unexpected") + + +class ContinuousBatchedRequestStreamer: + """ + A class that manages continuous streaming of requests from multiple audio files, providing + request generation in batches. Requests can be frames or feature buffers. + The class supports dynamically adding audio streams, updating a progress bar, + and yielding batches of requests for further processing. + """ + + def __init__( + self, + sample_rate: int, + frame_size_in_secs: float, + batch_size: int, + n_frames_per_stream: int, + request_type: RequestType = RequestType.FRAME, + preprocessor: Callable = None, + buffer_size_in_secs: float = None, + device: torch.device = None, + pad_last_frame: bool = False, + right_pad_features: bool = False, + tail_padding_in_samples: int = 0, + ): + """ + Args: + sample_rate (int): The sample rate of the audio + frame_size_in_secs (float): The size of the frame in seconds + batch_size (int): The batch size + n_frames_per_stream (int): The number of frames per stream + request_type (RequestType): The type of request + preprocessor (Callable): Preprocessor object, required for request type FEATURE_BUFFER + buffer_size_in_secs (float): The size of the buffer in seconds, required for request type FEATURE_BUFFER + device (torch.device): The device to use, required for request type FEATURE_BUFFER + pad_last_frame (bool): Whether to pad the last frame + right_pad_features (bool): Whether to right pad the features, optional for request type FEATURE_BUFFER + tail_padding_in_samples (int): The tail padding in samples, optional for request type FEATURE_BUFFER + """ + + if request_type is RequestType.FEATURE_BUFFER: + if buffer_size_in_secs is None: + raise ValueError("buffer_size_in_secs must be provided for request type FEATURE_BUFFER") + if preprocessor is None: + raise ValueError("preprocessor must be provided for request type FEATURE_BUFFER") + if device is None: + raise ValueError("device must be provided for request type FEATURE_BUFFER") + + self.request_type = request_type + self.multi_streamer = ContinuousBatchedFrameStreamer( + sample_rate=sample_rate, + frame_size_in_secs=frame_size_in_secs, + batch_size=batch_size, + n_frames_per_stream=n_frames_per_stream, + pad_last_frame=pad_last_frame, + ) + + if self.request_type is RequestType.FEATURE_BUFFER: + self.preprocessor = preprocessor + self.device = device + self.audio_bufferer = BatchedAudioBufferer( + sample_rate=sample_rate, buffer_size_in_secs=buffer_size_in_secs + ) + self.right_pad_features = right_pad_features + self.tail_padding_in_samples = tail_padding_in_samples + + def set_audio_filepaths(self, audio_filepaths: list[str], options: list[RequestOptions]) -> None: + """ + Set the audio filepaths + Args: + audio_filepaths (list[str]): The list of audio filepaths + options (list[RequestOptions]): The list of options + """ + self.multi_streamer.set_audio_filepaths(audio_filepaths, options) + + def set_progress_bar(self, progress_bar: ProgressBar) -> None: + """ + Set the progress bar + Args: + progress_bar (ProgressBar): The progress bar to set + """ + self.multi_streamer.set_progress_bar(progress_bar) + + def get_audio_filepath(self, stream_id: int) -> str: + """ + Get the audio filepath for a given stream id + Args: + stream_id (int): The id of the stream + Returns: + str: The audio filepath for the given stream id + """ + 
return self.multi_streamer.sid2filepath[stream_id] + + def to_feature_buffers(self, frames: list[Frame]) -> list[FeatureBuffer]: + """ + Convert frames to feature buffers + Args: + frames (list[Frame]): The list of frames + Returns: + list[FeatureBuffer]: The list of feature buffers + """ + + # Buffer input frames + buffered_frames, left_paddings = self.audio_bufferer.update(frames) + buffers = [] + + # If right padding is enabled, convert left paddings to tensor + if self.right_pad_features: + left_paddings = torch.tensor(left_paddings, dtype=torch.int64, device=self.device) + + # If right padding is enabled, roll the frames to the left + for i in range(len(buffered_frames)): + if self.right_pad_features: + lpad = left_paddings[i].item() + if lpad > 0: + buffered_frames[i] = buffered_frames[i].roll(shifts=-lpad) + buffers.append(buffered_frames[i].unsqueeze_(0)) + + buffer_lens = torch.tensor([buffers[0].size(1)] * len(buffers), device=self.device) + + # Calculate right paddings and subtract from buffer lens + # tail_padding_in_samples is used to keep some amount of padding at the end of the buffer + # some models perform better with this padding + right_paddings = torch.tensor( + [frame.size - frame.valid_size - self.tail_padding_in_samples for frame in frames], device=self.device + ).clamp(min=0) + + # Subtract right paddings from buffer lens + buffer_lens = buffer_lens - right_paddings + + # If right padding is enabled, subtract left paddings from buffer lens + # Becouse we rolled the frames to the left + if self.right_pad_features: + buffer_lens = buffer_lens - left_paddings + + # Apply preprocessor to get mel spectrograms + feature_buffers, feature_buffer_lens = self.preprocessor( + input_signal=torch.cat(buffers).to(self.device), length=buffer_lens + ) + + # Adjust left paddings after preprocessor + if self.right_pad_features: + left_paddings = left_paddings / self.preprocessor.featurizer.hop_length + left_paddings = left_paddings.to(torch.int64) + + return [ + FeatureBuffer( + features=feature_buffers[i], + is_first=frame.is_first, + is_last=frame.is_last, + stream_id=frame.stream_id, + right_pad_features=self.right_pad_features, + length=feature_buffer_lens[i].item(), + left_padding_length=left_paddings[i].item() if self.right_pad_features else 0, + options=frame.options, + ) + for i, frame in enumerate(frames) + ] + + def __iter__(self) -> Iterator: + """Returns the iterator object""" + return self + + def __next__(self) -> list[Request]: + """Get the next batch of requests. + Returns: + list of frames or feature buffers. + """ + if self.request_type is RequestType.FRAME: + return next(self.multi_streamer) + return self.to_feature_buffers(next(self.multi_streamer)) diff --git a/nemo/collections/asr/inference/streaming/framing/request.py b/nemo/collections/asr/inference/streaming/framing/request.py new file mode 100644 index 000000000000..5706bcc31232 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/framing/request.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import TypeAlias + +import torch + +from nemo.collections.asr.inference.streaming.framing.request_options import RequestOptions + + +@dataclass(frozen=True, slots=True) +class Frame: + """ + Immutable dataclass representing + + Args: + samples (torch.Tensor): The actual frame data. For audio, shape is (T,). + stream_id (int): Unique identifier for the stream this frame belongs to + is_first (bool): Flag indicating if this is the first frame in the stream + is_last (bool): Flag indicating if this is the last frame in the stream + length (int): Length of the frame without padding. + If -1, returns the size of the frame including padding. + vad_segments (torch.Tensor | None): Optional VAD segments to use for end-of-utterance detection. + Shape is [num_vad_segments, 2] where each segment contains + [start_time, end_time]. Variable for each stream. + options (RequestOptions | None): Optional options for the request + """ + + samples: torch.Tensor + stream_id: int + is_first: bool = False + is_last: bool = False + length: int = -1 + vad_segments: torch.Tensor | None = None + options: RequestOptions | None = None + + @property + def size(self) -> int: + """Returns the size of the frame including padding""" + return self.samples.shape[0] + + @property + def valid_size(self) -> int: + """Returns the size of the frame without padding""" + return self.size if self.length == -1 else self.length + + +@dataclass(frozen=True, slots=True) +class FeatureBuffer: + """ + Immutable dataclass representing a buffer of features. + Args: + features (torch.Tensor): The actual frame data. For features, shape is (feature_dim, T). + stream_id (int): Unique identifier for the stream this frame belongs to + is_first (bool): Flag indicating if this is the first frame in the stream + is_last (bool): Flag indicating if this is the last frame in the stream + right_pad_features (bool): Flag indicating if the features are right padded + length (int): Length of the valid features in the buffer + If -1, returns the size of the buffer including padding + left_padding_length (int): Length of the left padding in the buffer + It is used to roll features to the right + vad_segments (torch.Tensor | None): Optional VAD segments to use for end-of-utterance detection. + Shape is [num_vad_segments, 2] where each segment contains + [start_time, end_time]. Variable for each stream. + options (RequestOptions | None): Optional options for the request + """ + + features: torch.Tensor + stream_id: int + is_first: bool = False + is_last: bool = False + right_pad_features: bool = False + length: int = -1 + left_padding_length: int = 0 + vad_segments: torch.Tensor | None = None + options: RequestOptions | None = None + + @property + def size(self) -> int: + """Returns the number of features in the buffer including padding""" + return self.features.shape[1] + + @property + def valid_size(self) -> int: + """Returns the size of the buffer without padding. It is a actual length of the signal""" + return self.size if self.length == -1 else self.length + + @property + def roll_size(self) -> int: + """Returns the size of the buffer to roll to the right. 
It only makes sense for right padded feature buffers""" + return self.left_padding_length if self.right_pad_features else 0 + + +Request: TypeAlias = Frame | FeatureBuffer diff --git a/nemo/collections/asr/inference/streaming/framing/request_options.py b/nemo/collections/asr/inference/streaming/framing/request_options.py new file mode 100644 index 000000000000..fff6f7677c2a --- /dev/null +++ b/nemo/collections/asr/inference/streaming/framing/request_options.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import TypeAlias +from nemo.collections.asr.inference.utils.enums import ASROutputGranularity + + +@dataclass(slots=True) +class ASRRequestOptions: + """ + Immutable dataclass representing options for a request + None value means that the option is not set and the default value will be used + """ + + enable_itn: bool = None + enable_pnc: bool = None + stop_history_eou: int = None + asr_output_granularity: ASROutputGranularity | str = None + + def __post_init__(self) -> None: + """ + Post-init hook: + Converts the asr_output_granularity to ASROutputGranularity if it is a string + """ + if isinstance(self.asr_output_granularity, str): + self.asr_output_granularity = ASROutputGranularity.from_str(self.asr_output_granularity) + + def is_word_level_output(self) -> bool: + """ + Check if the output granularity is word level. + """ + return self.asr_output_granularity is ASROutputGranularity.WORD + + def is_segment_level_output(self) -> bool: + """ + Check if the output granularity is segment level. + """ + return self.asr_output_granularity is ASROutputGranularity.SEGMENT + + def augment_with_defaults( + self, + default_enable_itn: bool, + default_enable_pnc: bool, + default_stop_history_eou: int, + default_asr_output_granularity: ASROutputGranularity | str, + ) -> "ASRRequestOptions": + """ + Augment the options with the default values. + Args: + default_enable_itn (bool): Default enable ITN. + default_enable_pnc (bool): Default enable PNC. + default_stop_history_eou (int): Default stop history EOU. + default_asr_output_granularity (ASROutputGranularity | str): Default output granularity. + Returns: + ASRRequestOptions: Augmented options. 
+ """ + if isinstance(default_asr_output_granularity, str): + default_asr_output_granularity = ASROutputGranularity.from_str(default_asr_output_granularity) + return ASRRequestOptions( + enable_itn=default_enable_itn if self.enable_itn is None else self.enable_itn, + enable_pnc=default_enable_pnc if self.enable_pnc is None else self.enable_pnc, + stop_history_eou=default_stop_history_eou if self.stop_history_eou is None else self.stop_history_eou, + asr_output_granularity=( + default_asr_output_granularity if self.asr_output_granularity is None else self.asr_output_granularity + ), + ) + + +RequestOptions: TypeAlias = ASRRequestOptions diff --git a/nemo/collections/asr/inference/streaming/framing/stream.py b/nemo/collections/asr/inference/streaming/framing/stream.py new file mode 100644 index 000000000000..ecf2731d3f5b --- /dev/null +++ b/nemo/collections/asr/inference/streaming/framing/stream.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Stream: + """ + Minimal interface for a stream + """ + + def __init__(self, stream_id: int): + """ + Args: + stream_id: (int) The id of the stream. + """ + self._stream_id = stream_id + self.frame_count = 0 + + def __iter__(self): + """Returns the iterator object""" + return self + + def __next__(self): + """Get the next frame in the stream""" + raise NotImplementedError("Subclasses must implement __next__ method") + + @property + def stream_id(self) -> int: + """Get the stream id""" + return self._stream_id diff --git a/nemo/collections/asr/inference/streaming/state/__init__.py b/nemo/collections/asr/inference/streaming/state/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/state/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_ctc_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_ctc_state.py new file mode 100644 index 000000000000..5b0beda8b2cb --- /dev/null +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_ctc_state.py @@ -0,0 +1,24 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.collections.asr.inference.streaming.state.cache_aware_state import CacheAwareStreamingState + + +class CacheAwareCTCStreamingState(CacheAwareStreamingState): + """ + State of the cache aware CTC streaming pipelines + """ + + pass diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py new file mode 100644 index 000000000000..d3efd80e4396 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.collections.asr.inference.streaming.state.cache_aware_state import CacheAwareStreamingState +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis + + +class CacheAwareRNNTStreamingState(CacheAwareStreamingState): + """ + State of the cache aware RNNT streaming pipelines + """ + + def __init__(self): + """ + Initialize the CacheAwareRNNTStreamingState + """ + super().__init__() + self._additional_params_reset() + + def reset(self) -> None: + """ + Reset the state + """ + super().reset() + self._additional_params_reset() + + def _additional_params_reset(self) -> None: + """ + Reset non-inherited parameters + """ + super()._additional_params_reset() + self.previous_hypothesis = None + + def set_previous_hypothesis(self, previous_hypothesis: Hypothesis) -> None: + """ + Set the previous hypothesis + Args: + previous_hypothesis: (Hypothesis) The previous hypothesis to store for the next transcribe step + """ + self.previous_hypothesis = previous_hypothesis + + def get_previous_hypothesis(self) -> Hypothesis: + """ + Get the previous hypothesis + Returns: + (Hypothesis) The previous hypothesis + """ + return self.previous_hypothesis + + def reset_previous_hypothesis(self) -> None: + """ + Reset the previous hypothesis to None + """ + self.previous_hypothesis = None diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_state.py new file mode 100644 index 000000000000..f87b18edc002 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_state.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.collections.asr.inference.streaming.state.state import StreamingState + + +class CacheAwareStreamingState(StreamingState): + """ + State of the cache aware CTC/RNNT streaming pipelines + """ + + def __init__(self): + """ + Initialize the CacheAwareStreamingState + """ + super().__init__() + self._additional_params_reset() + + def reset(self) -> None: + """ + Reset the state + """ + super().reset() + self._additional_params_reset() + + def _additional_params_reset(self) -> None: + """ + Reset non-inherited parameters + """ + # label_buffer will be used to detect EoU + self.label_buffer = [] + self.label_buffer_size = 0 + self.offset = 0 + + def set_offset(self, offset: int) -> None: + """ + Set the offset + Args: + offset: (int) offset + """ + self.offset = offset + + def setup_label_buffer(self, label_buffer_size: int, blank_id: int) -> None: + """ + Set up the label buffer + Args: + label_buffer_size: (int) size of the label buffer + blank_id: (int) blank id + """ + self.label_buffer_size = label_buffer_size + self.label_buffer = [blank_id] * self.label_buffer_size + + def update_label_buffer(self, labels: list[int]) -> None: + """ + Update the label buffer + Args: + labels: (list[int]) list of labels + """ + shift = len(labels) + self.label_buffer[:-shift] = self.label_buffer[shift:].copy() + self.label_buffer[-shift:] = labels.copy() + + def get_label_buffer(self) -> list[int]: + """ + Get the current label buffer + Returns: + list[int]: current state of the label buffer + """ + return self.label_buffer.copy() + + def update_state(self, completed_output: dict, eou_detected: bool) -> None: + """ + Update the state with the completed output + Args: + completed_output: (dict) completed output + eou_detected: (bool) is EoU detected + """ + + if len(completed_output) == 0 or len(completed_output["tokens"]) == 0: + return + + timesteps = completed_output["timesteps"] + for i, t in enumerate(timesteps): + timesteps[i] = t + self.global_offset + + # we will not perform overlap aware merging of the tokens for CacheAware Models + # It is too error-prone to do this in the streaming mode -> skip=0 + self._update_state(completed_output, skip=0) + self.eou_detected_before = eou_detected diff --git a/nemo/collections/asr/inference/streaming/state/ctc_state.py b/nemo/collections/asr/inference/streaming/state/ctc_state.py new file mode 100644 index 000000000000..0ad5f21901d8 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/state/ctc_state.py @@ -0,0 +1,23 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nemo.collections.asr.inference.streaming.state.state import StreamingState + + +class CTCStreamingState(StreamingState): + """ + State of the streaming CTC pipeline + """ + + pass diff --git a/nemo/collections/asr/inference/streaming/state/rnnt_state.py b/nemo/collections/asr/inference/streaming/state/rnnt_state.py new file mode 100644 index 000000000000..b2eade1badc5 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/state/rnnt_state.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.inference.streaming.state.state import StreamingState + + +class RNNTStreamingState(StreamingState): + """ + State of the streaming RNNT pipeline + """ + + def __init__(self): + """ + Initialize the RNNTStreamingState + """ + super().__init__() + self._additional_params_reset() + + def reset(self) -> None: + """ + Reset the state + """ + super().reset() + self._additional_params_reset() + + def _additional_params_reset(self) -> None: + """ + Reset non-inherited parameters + """ + self.timestamp_offset = 0 + self.hyp_decoding_state = None diff --git a/nemo/collections/asr/inference/streaming/state/state.py b/nemo/collections/asr/inference/streaming/state/state.py new file mode 100644 index 000000000000..59f5031110f4 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/state/state.py @@ -0,0 +1,334 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Callable + +from nemo.collections.asr.inference.streaming.framing.request import RequestOptions +from nemo.collections.asr.inference.utils.constants import POST_WORD_PUNCTUATION +from nemo.collections.asr.inference.utils.state_management_utils import ( + detect_overlap, + merge_segment_tail, + merge_timesteps, + merge_word_tail, +) +from nemo.collections.asr.inference.utils.text_segment import TextSegment, Word + +CLOSE_IN_TIME_TH = 2.0 +OVERLAP_SEARCH_TH = 3 + + +class StreamingState: + """ + Generic state for the streaming ASR pipeline + """ + + def __init__(self): + """ + Initialize the StreamingState + """ + self._reset_streaming_state() + + def reset(self) -> None: + """ + Reset the state to its initial values + """ + self._reset_streaming_state() + + def _reset_streaming_state(self) -> None: + """ + Initialize the state with default values + """ + + # Global offset is used to keep track of the timestamps + self.global_offset = 0 + + # All tokens, timestamps and conf scores that have been processed since the last EOU + self.tokens = [] + self.timesteps = [] + self.confidences = [] + + # Predicted tokens for the current step + self.current_step_tokens = [] + + # Last token and its index are used to detect overlap between the current and the previous output + self.last_token = None + self.last_token_idx = None + + # Tokens left in the right padding segment of the buffer + self.incomplete_segment_tokens = [] + + # final_transcript, partial_transcript, current_step_transcript and final_segments will be sent to the client + self.final_transcript = "" + self.partial_transcript = "" + self.current_step_transcript = "" + self.concat_with_space = True + self.final_segments = [] + + # Word-level ASR output attributes (cleared after cleanup_after_response): + # - words: Raw word-level ASR output + # - pnc_words: Words with punctuation and capitalization applied + # * When automatic punctuation is ENABLED: Contains punctuation marks and capitalization + # (from either external PnC model or built-in ASR model PnC) + # * When automatic punctuation is DISABLED: No punctuation or capitalization + # (any punctuation in raw ASR output will be removed) + # - itn_words: Words after applying both PnC and ITN (Inverse Text Normalization) + # - word_alignment: ITN word alignment + # Segment-level ASR output attributes (cleared after cleanup_after_response): + # - segments: Raw segment-level ASR output + # - processed_segment_mask: Mask indicating which segments have been processed + # - final_segments: Final segment-level ASR output + self.words = [] + self.pnc_words = [] + self.itn_words = [] + self.word_alignment = [] + self.segments = [] + self.processed_segment_mask = [] + + # Flag to indicate if EOU was detected before, used in merging logic + self.eou_detected_before = False + + # Used in EoU detection logic + self.decoder_start_idx = 0 + self.decoder_end_idx = 0 + + # Request options + self.options = None + + def set_options(self, options: RequestOptions) -> None: + """ + Set the options + Args: + options: (RequestOptions) The request options to store in the state + """ + self.options = options + + def set_incomplete_segment_tokens(self, incomplete_segment_tokens: list) -> None: + """ + Set the partial tokens + Args: + incomplete_segment_tokens: (list) The partial tokens to store in the state + """ + self.incomplete_segment_tokens = incomplete_segment_tokens + + def set_global_offset(self, start_offset: float) -> None: + """ + Set the global offset + Args: + start_offset: 
(float) The global offset to store in the state + """ + self.global_offset = start_offset + + def set_last_token(self, token: int | None, idx: int | None) -> None: + """ + Set the last token + Args: + token: (int | None) The last token to store in the state + idx: (int | None) The index of the last token to store in the state + """ + if None not in [token, idx]: + self.last_token_idx = idx + self.global_offset + self.last_token = token + else: + self.last_token_idx = None + self.last_token = None + + def increment_global_offset(self, shift: float) -> None: + """ + Increment the global offset by the given shift + Args: + shift: (float) The shift to increment the global offset by + """ + self.global_offset += shift + + def _update_state(self, output: dict, skip: int) -> None: + """ + Extend the tokens, timesteps and confidences, optionally skipping the first few tokens + Args: + output: (dict) The output to update the state with + skip: (int) The number of tokens to skip + """ + current_tokens = output["tokens"] + current_timesteps = output["timesteps"] + current_confidences = output["confidences"] + if skip > 0: + current_tokens = current_tokens[skip:] + current_timesteps = current_timesteps[skip:] + current_confidences = current_confidences[skip:] + + self.current_step_tokens.extend(current_tokens) + self.tokens.extend(current_tokens) + self.confidences.extend(current_confidences) + self.timesteps = merge_timesteps(self.timesteps, current_timesteps) + + def update_state(self, completed_output: dict, eou_detected: bool) -> None: + """ + Update the state with the completed output + Args: + completed_output: (dict) The completed output to update the state with + eou_detected: (bool) Whether EOU was detected + """ + + if len(completed_output) == 0 or len(completed_output["tokens"]) == 0: + self.last_token = None + self.last_token_idx = None + return + + timesteps = completed_output["timesteps"] + for i, t in enumerate(timesteps): + timesteps[i] = t + self.global_offset + + overlap = 0 + if not self.eou_detected_before: + overlap = detect_overlap( + state_tokens=self.tokens, + state_timesteps=self.timesteps, + new_tokens=completed_output["tokens"], + new_timesteps=timesteps, + overlap_search_th=OVERLAP_SEARCH_TH, + close_in_time_th=CLOSE_IN_TIME_TH, + ) + + # In case when the tokens are empty after EoU, + # we need to check if the last token is the same as the first token of the completed output + if ( + self.eou_detected_before + and self.last_token == completed_output["tokens"][0] + and self.last_token_idx is not None + and abs(self.last_token_idx - timesteps[0]) <= CLOSE_IN_TIME_TH + ): + overlap = max(overlap, 1) + + self._update_state(completed_output, overlap) + self.eou_detected_before = eou_detected + + def update_from_decoder_results(self, start_idx: int, end_idx: int) -> None: + """ + Update state based on decoder results + This is used to dynamically understand current token start and end indices + Args: + start_idx: (int) The start index of the decoder results + end_idx: (int) The end index of the decoder results + """ + self.decoder_start_idx = start_idx + self.decoder_end_idx = end_idx + + def cleanup_after_eou(self) -> None: + """ + Cleanup the state after an EOU is detected + """ + self.tokens.clear() + self.timesteps.clear() + self.confidences.clear() + + def cleanup_after_response(self) -> None: + """ + Cleanup the state after a response is sent + Specifically used to clean the state after final transcript is sent + """ + + if self.options.is_word_level_output(): + 
self.words.clear() + self.pnc_words.clear() + self.itn_words.clear() + self.word_alignment.clear() + else: + self.segments.clear() + self.processed_segment_mask.clear() + + self.final_transcript = "" + self.final_segments.clear() + self.current_step_transcript = "" + self.current_step_tokens.clear() + self.concat_with_space = True + + def push_back_segment( + self, + segment: TextSegment, + need_merge: bool, + conf_aggregator: Callable = None, + ) -> None: + """ + Push back the decoded segment to the state + Args: + segment: (TextSegment) The decoded segment to push back to the state + need_merge: (bool) Whether to merge the segment with the last segment in the state + conf_aggregator: (Callable) The function to aggregate the confidence + """ + + # concat_with_space is used to determine if the final transcript should be concatenated with a space + if len(self.final_segments) == 0 and need_merge: + self.concat_with_space = False + else: + self.concat_with_space = True + + if need_merge and len(self.segments) > 0: + head = merge_segment_tail( + segment_head=self.segments[-1], + segment_tail=segment, + conf_aggregator=conf_aggregator, + ) + self.segments[-1] = head + self.processed_segment_mask[-1] = False + else: + self.segments.append(segment) + self.processed_segment_mask.append(False) + + def push_back_words( + self, + decoded_words: list[Word], + merge_first_word: bool = False, + merge_first_word_punctuation: bool = True, + conf_aggregator: Callable = None, + ) -> None: + """ + Push back the decoded words to the state + Args: + decoded_words: (list[Word]) The decoded words to push back to the state + merge_first_word: (bool) Whether to merge the first word with the last word in the state + merge_first_word_punctuation: (bool) Whether to merge the first word punctuation with the last word in the state + conf_aggregator: (Callable) The function to aggregate the confidence + """ + if not decoded_words: + return + + # concat_with_space is used to determine if the final transcript should be concatenated with a space + if len(self.final_segments) == 0 and merge_first_word: + self.concat_with_space = False + else: + self.concat_with_space = True + + if ( + (fst_word_txt := decoded_words[0].text) + and fst_word_txt in POST_WORD_PUNCTUATION + and merge_first_word_punctuation + ): + # if the first word is a punctuation mark, merge it with the last word stored in the state + if len(self.words) > 0: + self.words[-1].text += fst_word_txt + decoded_words = decoded_words[1:] + + elif merge_first_word and len(self.words) > 0: + head, pnc_head = merge_word_tail( + word_head=self.words[-1], + word_tail=decoded_words[0], + pnc_word_head=self.pnc_words[-1] if len(self.pnc_words) > 0 else None, + conf_aggregator=conf_aggregator, + ) + self.words[-1] = head + if pnc_head is not None: + self.pnc_words[-1] = pnc_head + decoded_words = decoded_words[1:] + + self.words.extend(decoded_words) diff --git a/nemo/collections/asr/inference/streaming/text/__init__.py b/nemo/collections/asr/inference/streaming/text/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/streaming/text/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/streaming/text/text_processing.py b/nemo/collections/asr/inference/streaming/text/text_processing.py new file mode 100644 index 000000000000..13f7c259ce1b --- /dev/null +++ b/nemo/collections/asr/inference/streaming/text/text_processing.py @@ -0,0 +1,414 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import re +from functools import partial +from typing import TYPE_CHECKING, Callable + +from omegaconf import DictConfig + +from nemo.collections.asr.inference.streaming.state.state import StreamingState +from nemo.collections.asr.inference.utils.constants import POST_WORD_PUNCTUATION +from nemo.collections.asr.inference.utils.pipeline_utils import ( + get_leading_punctuation_regex_pattern, + get_repeated_punctuation_regex_pattern, +) +from nemo.collections.asr.inference.utils.text_segment import Word, normalize_segments_inplace + +if TYPE_CHECKING: + from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer + + +class StreamingTextProcessor: + """ + A streaming text post-processing module that applies punctuation & capitalization (PnC) and + inverse text normalization (ITN) to ASR transcriptions in real-time. + + This class supports configurable pipelines where PnC and ITN can be enabled/disabled dynamically. + It ensures that the final output adheres to proper punctuation, capitalization, and normalized text. + """ + + def __init__( + self, + itn_cfg: DictConfig, + itn_model: AlignmentPreservingInverseNormalizer | None, + asr_supported_puncts: set, + asr_supports_punctuation: bool, + confidence_aggregator: Callable, + sep: str, + enable_pnc: bool = False, + enable_itn: bool = False, + ): + """ + Initialize the streaming text processor. + + Args: + itn_cfg (DictConfig): ITN parameters. + itn_model (AlignmentPreservingInverseNormalizer | None): Model for inverse text normalization (ITN). + asr_supported_puncts (set): Set of punctuation marks recognized by the ASR model. + asr_supports_punctuation (bool): Boolean indicating if the ASR model outputs punctuation. + confidence_aggregator (Callable): Function for aggregating confidence scores. + sep (str): String separator used in ASR output processing. + enable_pnc (bool): Boolean to enable PnC. Default is False. + enable_itn (bool): Boolean to enable ITN. Default is False. 
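+
+        Example (illustrative sketch; argument values below are placeholders, not defaults):
+            >>> processor = StreamingTextProcessor(
+            ...     itn_cfg=itn_cfg,
+            ...     itn_model=itn_model,
+            ...     asr_supported_puncts={".", ",", "?"},
+            ...     asr_supports_punctuation=True,
+            ...     confidence_aggregator=min,
+            ...     sep=" ",
+            ...     enable_pnc=True,
+            ...     enable_itn=True,
+            ... )
+            >>> processor.process(states)  # applies PnC/ITN in place on a list of StreamingState objects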
+ """ + + self.pnc_enabled = enable_pnc and asr_supports_punctuation + self.supports_punctuation = asr_supports_punctuation + + self.itn_model = itn_model + self.itn_enabled = False + if enable_itn: + self.itn_enabled = itn_model is not None + + self.itn_runtime_params = { + "batch_size": itn_cfg.batch_size, + "n_jobs": itn_cfg.n_jobs, + } + self.itn_left_padding_size = itn_cfg.left_padding_size + + self.asr_supported_puncts = asr_supported_puncts + self.asr_supported_puncts_str = ''.join(self.asr_supported_puncts) + self.sep = sep + self.rm_punctuation_capitalization_from_segments_fn = partial( + normalize_segments_inplace, punct_marks=self.asr_supported_puncts, sep=self.sep + ) + + puncts_to_process = self.asr_supported_puncts + self.leading_punctuation_regex_pattern = get_leading_punctuation_regex_pattern(puncts_to_process) + self.repeated_punctuation_regex_pattern = get_repeated_punctuation_regex_pattern(puncts_to_process) + + self.alignment_aware_itn_model = None + if self.itn_enabled: + from nemo.collections.asr.inference.itn.batch_inverse_normalizer import ( + BatchAlignmentPreservingInverseNormalizer, + ) + + self.alignment_aware_itn_model = BatchAlignmentPreservingInverseNormalizer( + itn_model=self.itn_model, + sep=self.sep, + asr_supported_puncts=self.asr_supported_puncts, + post_word_punctuation=POST_WORD_PUNCTUATION, + conf_aggregate_fn=confidence_aggregator, + ) + + def is_itn_enabled(self) -> bool: + """Check if ITN is enabled""" + return self.itn_enabled + + def is_pnc_enabled(self) -> bool: + """Check if PnC is enabled""" + return self.pnc_enabled + + def is_enabled(self) -> bool: + """Check if PnC or ITN is enabled""" + return self.is_pnc_enabled() or self.is_itn_enabled() + + def process(self, states: list[StreamingState]) -> None: + """ + Apply PnC and ITN on the states. + Args: + states: (list[StreamingState]) List of StreamingState objects + """ + word_boundary_states, segment_boundary_states = [], [] + for state in states: + if state.options.is_word_level_output(): + word_boundary_states.append(state) + else: + segment_boundary_states.append(state) + + # Process states with word boundaries + if word_boundary_states: + self.process_states_with_word_boundaries(word_boundary_states) + + # Process states with segment boundaries + if segment_boundary_states: + self.process_states_with_segment_boundaries(segment_boundary_states) + + # Generate final transcript + self.generate_final_transcript(word_boundary_states, segment_boundary_states) + + def process_states_with_segment_boundaries(self, states: list[StreamingState]) -> None: + """ + Process states with segment boundaries. 
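+        If neither PnC nor ITN is enabled globally, punctuation and capitalization are
+        simply stripped from the unprocessed segments. Otherwise, PnC is removed for
+        requests that disable it, ITN is applied to unprocessed segments of requests that
+        enable it, and all segments are then marked as processed.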
+ Args: + states (list[StreamingState]): List of StreamingState objects that have segments + """ + states_with_text = [state for state in states if len(state.segments) > 0] + if len(states_with_text) == 0: + return + + # if PnC & ITN DISABLED globally, remove PnC from the words if ASR supports punctuation + if not self.is_enabled(): + if self.supports_punctuation: + segments = [] + for state in states_with_text: + for i, seg in enumerate(state.segments): + if not state.processed_segment_mask[i]: + segments.append(seg) + state.processed_segment_mask[i] = True + self.rm_punctuation_capitalization_from_segments_fn(segments) + return + + # Remove PnC from states where PnC is disabled + for state in states_with_text: + if (not state.options.enable_pnc) or (not self.is_pnc_enabled()): + if self.supports_punctuation: + self.rm_punctuation_capitalization_from_segments_fn(state.segments) + + # Apply ITN + if self.is_itn_enabled(): # If ITN ENABLED globally + # collect texts + texts = [] + for i, state in enumerate(states_with_text): + # if ITN is disabled for this state + if not state.options.enable_itn: + continue + + for j, seg in enumerate(state.segments): + if state.processed_segment_mask[j]: # if the segment is already processed, skip it + continue + texts.append((i, j, seg.text)) + + if len(texts) > 0: + # apply ITN + processed_texts = self.itn_model.inverse_normalize_list( + texts=[text for _, _, text in texts], params=self.itn_runtime_params + ) + # update states with ITN-processed texts + for (i, j, _), processed_text in zip(texts, processed_texts): + states_with_text[i].segments[j].text = processed_text + + # --> Apply External PnC here (if needed) + + # mark all segments as processed + for state in states_with_text: + if state.options.enable_pnc: + for seg in state.segments: + if self.leading_punctuation_regex_pattern: + seg.text = re.sub(self.leading_punctuation_regex_pattern, r'\1', seg.text) + if self.repeated_punctuation_regex_pattern: + seg.text = re.sub(self.repeated_punctuation_regex_pattern, r'\1', seg.text) + state.processed_segment_mask = [True] * len(state.segments) + + def process_states_with_word_boundaries(self, states: list[StreamingState]) -> None: + """ + Apply PnC and ITN on the states. + Args: + states: (list[StreamingState]) List of StreamingState objects + """ + # Get the indices of the states that have new words to process + indices, asr_words_list = self.prepare_asr_words(states) + + # If PnC & ITN DISABLED globally, remove PnC from the words + # Does not matter that individual request has enabled itn or pnc + if not self.is_enabled(): + self.handle_plain_asr_transcriptions(states, indices, asr_words_list) + return + + # Keep or remove PnC from the words + for idx, jdx, z in indices: + if not states[idx].options.enable_pnc and self.supports_punctuation: + self.rm_punctuation_capitalization_from_segments_fn(asr_words_list[jdx]) + states[idx].pnc_words[-z:] = asr_words_list[jdx][-z:] + + # If ITN is disabled globally, do nothing + if not self.itn_enabled: + return + + # Apply Inverse Text Normalization (ITN) + self.apply_itn(states, indices) + self.realign_punctuated_words(states, indices) + + def realign_punctuated_words(self, states: list[StreamingState], indices: list[tuple]) -> None: + """ + Realign punctuation and capitalization after applying ITN. + Ensures that capitalization and punctuation marks from the original ASR output + are properly reflected in the final ITN-processed text. 
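+        For example, if the punctuated ASR output ends with "dollars." and ITN rewrites
+        the span as "$10", the trailing period is re-attached to give "$10."; leading
+        punctuation and first-word capitalization are carried over in the same way when
+        present (illustrative; the exact spans depend on the ITN word alignment).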
+ + Args: + states (list[StreamingState]): List of StreamingState objects to be updated. + indices (list[tuple]): Indices of words within states that need realignment. + """ + for idx, _, z in indices: + state = states[idx] + if not state.options.enable_itn: + continue + + z_idx = len(state.words) - z + + itn_idx = len(state.itn_words) + for sids, _, _ in reversed(state.word_alignment): + st, et = sids[0], sids[-1] + itn_idx -= 1 + if st < z_idx and et < z_idx: + break + + last_char = state.pnc_words[et].text[-1] + first_char = state.pnc_words[st].text[0] + + itn_word_orig = state.itn_words[itn_idx] + itn_word_copy = itn_word_orig.copy() + itn_word_text = itn_word_copy.text.lower() + + # preserve the first char capitalization + first_word = state.pnc_words[st].copy() + first_char_is_upper = first_word.text[0].isupper() + first_word.normalize_text_inplace(self.asr_supported_puncts, self.sep) + if first_char_is_upper and itn_word_text.startswith(first_word.text): + itn_word_orig.capitalize() + + # preserve the last punctuation mark + if last_char in self.asr_supported_puncts: + itn_word_orig.text = itn_word_orig.text.rstrip(self.asr_supported_puncts_str) + last_char + + # preserve the first punctuation mark + if first_char in self.asr_supported_puncts: + itn_word_orig.text = first_char + itn_word_orig.text.lstrip(self.asr_supported_puncts_str) + + state.itn_words[itn_idx] = itn_word_orig + + def prepare_asr_words(self, states: list[StreamingState]) -> tuple[list[tuple], list[list[Word]]]: + """ + Find the indices of the states that have words to process. + Args: + states: (list[StreamingState]) List of StreamingState objects + Returns: + tuple[list[tuple], list[list[Word]]]: + indices: list of indices of the states that have words to process + asr_words_list: list of words to process + """ + indices, asr_words_list = [], [] + + jdx = 0 + for idx, state in enumerate(states): + if (n_not_punctuated_words := len(state.words) - len(state.pnc_words)) == 0: + continue + + words_list = [word.copy() for word in state.words[-n_not_punctuated_words:]] + asr_words_list.append(words_list) + state.pnc_words.extend([None] * n_not_punctuated_words) + indices.append((idx, jdx, len(words_list))) + jdx += 1 + + return indices, asr_words_list + + def handle_plain_asr_transcriptions( + self, states: list[StreamingState], indices: list[tuple], asr_words_list: list[list[Word]] + ) -> None: + """ + Handle scenarios where PnC and ITN are disabled. + In such cases, remove Punctuation and Capitalization from the words. + Args: + states: (list[StreamingState]) List of StreamingState objects + indices: (list[tuple]) List of indices of the states that have words to process + asr_words_list: (list[list[Word]]) List of words + """ + if self.supports_punctuation: + self.rm_punctuation_capitalization_from_segments_fn(asr_words_list) + + for idx, jdx, z in indices: + states[idx].pnc_words[-z:] = asr_words_list[jdx][-z:] + + def apply_itn(self, states: list[StreamingState], indices: list[tuple]) -> None: + """ + Apply Inverse Text Normalization (ITN) on the states. + Calculates the lookback for ITN and updates the states with the ITN results. 
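+        Only requests with ITN enabled are processed. A lookback of previously normalized
+        words (up to `itn_left_padding_size` alignment entries) is re-normalized together
+        with the new words, and the resulting spans and alignment are written back via
+        `update_itn_words`.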
+ Args: + states: (list[StreamingState]) List of StreamingState objects + indices: (list[tuple]) List of indices of the states that have words to process + """ + itn_indices, asr_words_list, pnc_words_list = [], [], [] + jdx = 0 + for state_idx, _, _ in indices: + state = states[state_idx] + if not state.options.enable_itn: + continue + s, t, cut_point = self.calculate_itn_lookback(state) + asr_words_list.append([word.copy() for word in state.words[s:]]) + pnc_words_list.append([word.copy() for word in state.pnc_words[s:]]) + itn_indices.append((state_idx, jdx, s, t, cut_point)) + jdx += 1 + output = self.alignment_aware_itn_model( + asr_words_list, pnc_words_list, self.itn_runtime_params, return_alignment=True + ) + self.update_itn_words(states, output, itn_indices) + + def calculate_itn_lookback(self, state: StreamingState) -> tuple[int, int, int]: + """ + Calculate the lookback for ITN. + Args: + state: (StreamingState) StreamingState object + Returns: + Start index (int): Start index of the source (non itn-ed) words + Target index (int): Start index of the target (itn-ed) words + Cut point (int): Index to cut the source words + """ + s, t, cut_point = 0, 0, len(state.itn_words) + word_alignment = list(reversed(state.word_alignment)) + for idx, (sidx, tidx, _) in enumerate(word_alignment, start=1): + s, t = sidx[0], tidx[0] + state.word_alignment.pop() + cut_point -= 1 + if idx == self.itn_left_padding_size: + break + return s, t, cut_point + + @staticmethod + def update_itn_words(states: list[StreamingState], output: list[tuple], indices: list[tuple]) -> None: + """ + Update the states with the ITN results. + Updates the word_alignment and itn_words in the states. + Args: + states: (list[StreamingState]) List of StreamingState objects + output: (list[tuple]) List of output tuples containing the spans and alignment + indices: (list[tuple]) List of indices of the states that have words to process + """ + for state_idx, jdx, s, t, cut_point in indices: + state = states[state_idx] + spans, alignment = output[jdx] + for sidx, tidx, sclass in alignment: + sidx = [k + s for k in sidx] + tidx = [k + t for k in tidx] + state.word_alignment.append((sidx, tidx, sclass)) + + state.itn_words = state.itn_words[:cut_point] + spans + assert len(state.word_alignment) == len(state.itn_words) + + def generate_final_transcript( + self, word_boundary_states: list[StreamingState], segment_boundary_states: list[StreamingState] + ) -> None: + """ + Generate final transcript based on enabled features and word count. 
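+        For word-level states the ITN words (or the PnC words when ITN is disabled for the
+        request) are appended to `final_segments` and joined with the separator; for
+        segment-level states the segment texts are used instead.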
+ Args: + word_boundary_states (list[StreamingState]): The streaming state containing words + segment_boundary_states (list[StreamingState]): The streaming state containing segments + """ + # Generate final transcript for word boundary states + for state in word_boundary_states: + attr_name = "itn_words" if state.options.enable_itn else "pnc_words" + words = getattr(state, attr_name) + for word in words: + state.final_segments.append(word.copy()) + state.final_transcript += word.text + self.sep + state.final_transcript = state.final_transcript.rstrip(self.sep) + + # Generate final transcript for segment boundary states + for state in segment_boundary_states: + for segment in state.segments: + state.final_segments.append(segment.copy()) + state.final_transcript += segment.text + self.sep + state.final_transcript = state.final_transcript.rstrip(self.sep) diff --git a/nemo/collections/asr/inference/utils/__init__.py b/nemo/collections/asr/inference/utils/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/collections/asr/inference/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/inference/utils/audio_io.py b/nemo/collections/asr/inference/utils/audio_io.py new file mode 100644 index 000000000000..17fc9d1a9a3f --- /dev/null +++ b/nemo/collections/asr/inference/utils/audio_io.py @@ -0,0 +1,29 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import librosa +import torch + + +def read_audio(audio_file: str, target_sr: int, mono: bool = True) -> torch.Tensor: + """ + Read audio file and return samples with target sampling rate + Args: + audio_file: (str) audio file path + target_sr: (int) target sampling rate + mono: (bool) whether to convert to mono + Returns: + (torch.Tensor) audio samples + """ + return torch.tensor(librosa.load(audio_file, sr=target_sr, mono=mono)[0]).float() diff --git a/nemo/collections/asr/inference/utils/bpe_decoder.py b/nemo/collections/asr/inference/utils/bpe_decoder.py new file mode 100644 index 000000000000..18148801c64e --- /dev/null +++ b/nemo/collections/asr/inference/utils/bpe_decoder.py @@ -0,0 +1,269 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import lru_cache +from typing import Callable + +import numpy as np + +from nemo.collections.asr.inference.streaming.state.state import StreamingState +from nemo.collections.asr.inference.utils.constants import ( + POST_WORD_PUNCTUATION, + ROUND_PRECISION, + SENTENCEPIECE_UNDERSCORE, +) +from nemo.collections.asr.inference.utils.text_segment import TextSegment, Word +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +class BPEDecoder: + """ + BPEDecoder class for decoding BPE (Byte Pair Encoding) tokens into words and segments by preserving timestamps and confidence scores + """ + + def __init__( + self, + vocabulary: list[str], + tokenizer: TokenizerSpec, + confidence_aggregator: Callable, + asr_supported_puncts: set, + word_boundary_tolerance: float, + token_duration_in_secs: float, + ): + """ + Initialize the BPEDecoder. + Args: + vocabulary (list[str]): List of vocabulary tokens. + tokenizer (TokenizerSpec): Tokenizer object. + confidence_aggregator (Callable): Confidence aggregator function. + asr_supported_puncts (set): Set of supported punctuation symbols. + word_boundary_tolerance (float): Word boundary tolerance for timestamp refinement. + token_duration_in_secs (float): Token duration in seconds. + """ + + self.vocabulary = vocabulary + self.tokenizer = tokenizer + self.confidence_aggregator = confidence_aggregator + self.asr_supported_puncts = asr_supported_puncts + self.punct_marks_with_underscore = asr_supported_puncts.union({SENTENCEPIECE_UNDERSCORE}) + self.word_boundary_tolerance = word_boundary_tolerance + self.token_duration_in_secs = token_duration_in_secs + self.start_of_word_cache = { + token_id: token.startswith(SENTENCEPIECE_UNDERSCORE) for token_id, token in enumerate(self.vocabulary) + } + self.punct_cache = { + token_id: (token in self.asr_supported_puncts, token in self.punct_marks_with_underscore) + for token_id, token in enumerate(self.vocabulary) + } + + @lru_cache(maxsize=10000) + def cached_ids_to_text(self, tokens_slice: tuple[int]) -> str: + """ + Cached tokenizer output to avoid repeated calls to the tokenizer. + Args: + tokens_slice (tuple): Tuple of token indices to be detokenized. + Returns: + str: Detokenized text. + """ + word_text = self.tokenizer.ids_to_text(list(tokens_slice)).strip() + return word_text + + def decode_bpe_tokens(self, state: StreamingState) -> None: + """ + Decodes BPE tokens into words or segments with timestamps and confidence scores. + Args: + state (StreamingState): The state object containing the BPE tokens, timestamps, and confidence scores. 
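+
+        Depending on the request options, tokens are grouped either into Word objects
+        (word-level output) or into a single TextSegment (segment-level output) and
+        pushed back onto the state.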
+ """ + if state.options.is_word_level_output(): + # Form words and push them to the state + decoded_words, need_merge = self.group_tokens_into_words(state.tokens, state.timesteps, state.confidences) + state.push_back_words(decoded_words, need_merge, self.confidence_aggregator) + elif state.options.is_segment_level_output(): + # Form text segment and push it to the state + if state.tokens: + decoded_segment, need_merge = self.group_tokens_into_segment( + state.tokens, state.timesteps, state.confidences + ) + state.push_back_segment(decoded_segment, need_merge, self.confidence_aggregator) + else: + raise ValueError(f"Invalid output granularity: {state.options.asr_output_granularity}") + + def group_tokens_into_segment( + self, tokens: list, timesteps: list, confidences: list + ) -> tuple[TextSegment | None, bool]: + """ + Group tokens into a text segment with timestamps and confidence scores. + Args: + tokens (list): List of token indices. + timesteps (list): List of token timestamps. + confidences (list): List of token confidence scores. + Returns: + (tuple[TextSegment | None, bool]) Text segment with text, start time, end time, and confidence score. + Also returns a boolean to indicate if the text segment should be merged with the last segment stored in the state + """ + n_tokens = len(tokens) + + if n_tokens != len(timesteps) or n_tokens != len(confidences): + raise ValueError("tokens, timesteps and confidences must have the same length") + + if n_tokens == 0: + return None, False + + need_merge = not bool(self.start_of_word_cache[tokens[0]]) + + # Get the segment text + segment_text = self.tokenizer.ids_to_text(tokens).strip() + + # Refine the start and end timestamps of the text segment + start, end = self.refine_text_segment_timestamp(tokens, timesteps) + + # Convert token timestamps to seconds + start = round(start * self.token_duration_in_secs, ROUND_PRECISION) + end = round(end * self.token_duration_in_secs, ROUND_PRECISION) + + # Aggregate the confidence score of the text segment + conf = self.confidence_aggregator(confidences) + + # Create a text segment + return TextSegment(text=segment_text, start=start, end=end, conf=conf), need_merge + + def group_tokens_into_words(self, tokens: list, timesteps: list, confidences: list) -> tuple[list[Word], bool]: + """ + Decodes BPE tokens into words with timestamps and confidence scores. + Args: + tokens (list): List of token indices. + timesteps (list): List of token timesteps. + confidences (list): List of token confidence scores. + Returns: + (tuple[list[Word], bool]) List of decoded words with text, start time, end time, and confidence score. 
+ Also returns a boolean to indicate if the first word should be merged with the last word stored in the state + """ + n_tokens = len(tokens) + + if n_tokens != len(timesteps) or n_tokens != len(confidences): + raise ValueError("tokens, timesteps and confidences must have the same length") + + if n_tokens == 0: + return [], False + + # Group tokens into words + is_start_mask = np.fromiter((self.start_of_word_cache[tok] for tok in tokens), dtype=np.int32) + word_ids = np.cumsum(is_start_mask) + + start_indices = np.nonzero(np.diff(word_ids, prepend=word_ids[0] - 1))[0] + end_indices = np.append(start_indices[1:], n_tokens) + + decoded_words, prev_word_end = [], None + + # If the first word is the start of a word, we need to merge it with the last word stored in the state + need_merge = not bool(is_start_mask[0]) + + for start_idx, end_idx in zip(start_indices, end_indices): + + tokens_slice = tokens[start_idx:end_idx] + time_slice = timesteps[start_idx:end_idx] + conf_slice = confidences[start_idx:end_idx] + + word_text = self.cached_ids_to_text(tuple(tokens_slice)) + + # Ignore empty text + if not word_text: + continue + + # Append the post word punctuation to the previous word + if word_text in POST_WORD_PUNCTUATION and len(decoded_words) > 0: + prev_word = decoded_words[-1] + prev_word.text += word_text + continue + + # Refine timestamps + word_start_tms, word_end_tms = self.refine_text_segment_timestamp( + current_tokens=tokens_slice, + current_timesteps=time_slice, + next_segment_start_timestep=timesteps[end_idx] if end_idx < n_tokens else None, + need_merge_with_next_segment=( + self.start_of_word_cache[tokens[end_idx]] if end_idx < n_tokens else None + ), + prev_segment_end=prev_word_end, + ) + prev_word_end = word_end_tms + + # Aggregate confidence + word_conf = self.confidence_aggregator(conf_slice) + + # Convert token timestamps to seconds + start_sec = round(word_start_tms * self.token_duration_in_secs, ROUND_PRECISION) + end_sec = round(word_end_tms * self.token_duration_in_secs, ROUND_PRECISION) + + decoded_words.append(Word(text=word_text, start=start_sec, end=end_sec, conf=word_conf)) + + return decoded_words, need_merge + + def refine_text_segment_timestamp( + self, + current_tokens: list[int], + current_timesteps: list[float], + next_segment_start_timestep: float | None = None, + need_merge_with_next_segment: bool | None = None, + prev_segment_end: float | None = None, + ) -> tuple[float, float]: + """ + Refines the text segment timestamp based on the current tokens, timestamps, and the next segment start timestamp. + Args: + current_tokens (list[int]): List of token indices. + current_timesteps (list[float]): List of token timestamps. + next_segment_start_timestep (float | None): The start timestamp of the next segment. + need_merge_with_next_segment (bool | None): True if the current segment should be merged with the next segment. + prev_segment_end (float | None): The end timestamp of the previous segment. + Returns: + tuple(float, float): The refined start and end timestamps. 
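+
+        Note: leading underscore/punctuation tokens are skipped when picking the start,
+        trailing punctuation tokens are skipped for the end, the end may be snapped to the
+        next segment's start when the gap is within `word_boundary_tolerance`, and both
+        timestamps are shifted forward if they would overlap the previous segment's end.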
+ """ + + start, end = current_timesteps[0], current_timesteps[-1] + + # --- Correct the start timestamp if the first token is underscore or punctuation --- + first_token = current_tokens[0] + if self.punct_cache[first_token][1]: + start = next( + (tms for tms, token in zip(current_timesteps, current_tokens) if not self.punct_cache[token][1]), + start, + ) + + # --- Correct the end timestamp if the last token is punctuation --- + last_token = current_tokens[-1] + if self.punct_cache[last_token][0]: + end = next( + ( + current_timesteps[i] + for i in reversed(range(len(current_tokens))) + if not self.punct_cache[current_tokens[i]][0] + ), + end, + ) + + # --- If the next segment is close to the end of the current segment, merge timestamps --- + if next_segment_start_timestep is not None and need_merge_with_next_segment: + if next_segment_start_timestep - end <= self.word_boundary_tolerance: + end = next_segment_start_timestep + + # --- Adjust the start and end timestamps based on the previous segment end --- + delta = 0 + if prev_segment_end is not None: + if prev_segment_end > start: + delta = prev_segment_end - start + + start = start + delta + end = end + delta + return start, end + (1 if start == end else 0) diff --git a/nemo/collections/asr/inference/utils/config_io.py b/nemo/collections/asr/inference/utils/config_io.py new file mode 100644 index 000000000000..e4e2fea982a5 --- /dev/null +++ b/nemo/collections/asr/inference/utils/config_io.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from hydra import compose, initialize +from hydra.core.global_hydra import GlobalHydra +from omegaconf import DictConfig + + +def read_config(config_path: str, config_name: str) -> DictConfig: + """ + Read configuration file + Args: + config_path: (str) Absolute path to the configuration file + config_name: (str) Name of the configuration file + Returns: + (DictConfig) Configuration object + """ + + rel_to = os.path.dirname(os.path.realpath(__file__)) + + # Reset the global Hydra instance if already initialized to prevent duplicate initialization errors + if GlobalHydra.instance().is_initialized(): + GlobalHydra.instance().clear() + + with initialize(version_base=None, config_path=os.path.relpath(config_path, rel_to)): + cfg = compose(config_name=config_name) + return cfg diff --git a/nemo/collections/asr/inference/utils/constants.py b/nemo/collections/asr/inference/utils/constants.py new file mode 100644 index 000000000000..c67488a96f76 --- /dev/null +++ b/nemo/collections/asr/inference/utils/constants.py @@ -0,0 +1,33 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Precision related constants +BIG_EPSILON = 1e-5 +SMALL_EPSILON = 1e-10 +ROUND_PRECISION = 9 + +# ASR Preprocessing related constants +LOG_MEL_ZERO = -16.635 + +# Punctuation related constants +POST_WORD_PUNCTUATION = set(".,?") +PRE_WORD_PUNCTUATION = set("¿") +SEP_REPLACEABLE_PUNCTUATION = set("-_") +SENTENCEPIECE_UNDERSCORE = "▁" + +# ITN related constants +DEFAULT_SEMIOTIC_CLASS = "name" + +# Default output directory name +DEFAULT_OUTPUT_DIR_NAME = "jsons" diff --git a/nemo/collections/asr/inference/utils/context_manager.py b/nemo/collections/asr/inference/utils/context_manager.py new file mode 100644 index 000000000000..5da474f1b598 --- /dev/null +++ b/nemo/collections/asr/inference/utils/context_manager.py @@ -0,0 +1,196 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from queue import Queue +from typing import Any + +import torch +from torch import Tensor + + +class CacheAwareContext: + """ + Stores the cache state for the Cache-Aware models. + """ + + def __init__( + self, + cache_last_channel: Tensor | None = None, + cache_last_time: Tensor | None = None, + cache_last_channel_len: Tensor | None = None, + ): + """ + Args: + cache_last_channel (Tensor | None): Last channel of the cache. + cache_last_time (Tensor | None): Last time of the cache. + cache_last_channel_len (Tensor | None): Last channel length of the cache. + """ + self.cache_last_channel = cache_last_channel + self.cache_last_time = cache_last_time + self.cache_last_channel_len = cache_last_channel_len + + +class CacheAwareContextManager: + """ + Manager class to manipulate the cached states for the Cache-Aware models. + """ + + def __init__( + self, + cache_aware_model: Any, + num_slots: int, + use_cache: bool = True, + ): + """ + Initialize the CacheAwareContextManager. + Args: + cache_aware_model (Any): Cache-Aware model object. It should have the get_initial_cache_state method. + num_slots (int): Number of slots to use for the cache. It should be greater than or equal to the batch size. + use_cache (bool): Whether to use the cache. Default is True. If False, the cache is disabled. 
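+
+        Example (illustrative sketch; variable names are placeholders):
+            >>> manager = CacheAwareContextManager(cache_aware_model=model, num_slots=batch_size)
+            >>> context, mapping = manager.get_context(stream_ids)
+            >>> # run the cache-aware encoder with `context`, obtaining `new_context` ...
+            >>> manager.update_cache(stream_ids, new_context, mapping)
+            >>> manager.reset_slots(stream_ids, eos_flags)  # free slots of finished streams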
+ """ + self.cache_aware_model = cache_aware_model + # Cache aware model should have the following methods: + if not hasattr(self.cache_aware_model, "get_initial_cache_state"): + raise ValueError("Cache aware model should have the get_initial_cache_state method") + + self.num_slots = num_slots + self.cache_disabled = not use_cache + self.cache_last_channel = None + self.cache_last_time = None + self.cache_last_channel_len = None + self.reset() + + def reset(self) -> None: + """Resets the context manager""" + if self.cache_disabled: + return + + self.streamidx2slotidx = {} + self.slotidx2streamidx = {} + self.free_slots = Queue(self.num_slots) + for i in range(self.num_slots): + self.free_slots.put(i) + ( + self.cache_last_channel, # [17, B, 70, 512] + self.cache_last_time, # [17, B, 512, 8] + self.cache_last_channel_len, # B + ) = self.cache_aware_model.get_initial_cache_state(self.num_slots) + + def reset_slot(self, slot_idx: int) -> None: + """ + Resets particular slot + Args: + slot_idx: slot index to reset + """ + if self.cache_disabled: + return + + # iterate over the layers + for i in range(self.cache_last_channel.size(0)): + self.cache_last_channel[i][slot_idx] = torch.zeros_like(self.cache_last_channel[i][slot_idx]) + self.cache_last_time[i][slot_idx] = torch.zeros_like(self.cache_last_time[i][slot_idx]) + self.cache_last_channel_len[slot_idx] = 0 + + # free the slot, so that it can be used by other streams + # remove the stream from the mappings + self.free_slots.put(slot_idx) + stream_id = self.slotidx2streamidx[slot_idx] + del self.slotidx2streamidx[slot_idx] + del self.streamidx2slotidx[stream_id] + + def update_cache(self, stream_ids: list[int], new_context: CacheAwareContext, mapping: dict) -> None: + """ + Updates the cache for the given stream_ids with the new_context + Args: + stream_ids (list[int]): list of stream ids + new_context (CacheAwareContext): new context to update corresponding to the stream_ids + mapping (dict): mapping between the old and new slots + """ + if self.cache_disabled: + return + + for stream_id in stream_ids: + slot_idx = self.streamidx2slotidx.get(stream_id, None) + if slot_idx is None: + raise RuntimeError(f"Stream {stream_id} is not registered in the context manager") + + # iterate over layers + tgt_slot_idx = mapping[slot_idx] + for i in range(self.cache_last_channel.size(0)): + self.cache_last_channel[i][slot_idx] = new_context.cache_last_channel[i][tgt_slot_idx].clone() + self.cache_last_time[i][slot_idx] = new_context.cache_last_time[i][tgt_slot_idx].clone() + self.cache_last_channel_len[slot_idx] = new_context.cache_last_channel_len[tgt_slot_idx] + + def reset_slots(self, stream_ids: list[int], eos_flags: list[bool]) -> None: + """ + Resets the slots for the finished streams + Args: + stream_ids (list[int]): list of stream ids + eos_flags (list[bool]): list of eos flags indicating whether the stream has finished + """ + if self.cache_disabled: + return + + if len(stream_ids) != len(eos_flags): + raise ValueError("stream_ids and eos_flags must have the same length") + + if len(stream_ids) == 0: + return + + # reset the slots for finished streams + for stream_id, eos_flag in zip(stream_ids, eos_flags): + if eos_flag: + slot_idx = self.streamidx2slotidx[stream_id] + self.reset_slot(slot_idx) + + def get_context(self, stream_ids: list[int]) -> tuple[CacheAwareContext, dict]: + """ + Retrieves the context from the cache for the given stream_ids + Args: + stream_ids (list[int]): list of stream ids + Returns: + context (CacheAwareContext): context 
for the given stream_ids + mapping (dict): mapping between the cache and retrieved context + """ + + if len(stream_ids) == 0 or self.cache_disabled: + # Create a dummy context with None values + return CacheAwareContext(), {} + + # if the stream_id is new, we need to assign a slot to it + for stream_id in stream_ids: + if stream_id not in self.streamidx2slotidx: + if self.free_slots.empty(): + raise RuntimeError("No free slots available") + slot_idx = self.free_slots.get() + self.streamidx2slotidx[stream_id] = slot_idx + self.slotidx2streamidx[slot_idx] = stream_id + + # get the cache for the particular stream_ids + slot_ids = [self.streamidx2slotidx[stream_id] for stream_id in stream_ids] + cache_last_channel = self.cache_last_channel[:, slot_ids, :, :] + cache_last_time = self.cache_last_time[:, slot_ids, :, :] + cache_last_channel_len = self.cache_last_channel_len[slot_ids] + + # create a context object + context = CacheAwareContext( + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + ) + + # mapping between cache and context + mapping = dict(zip(slot_ids, range(len(slot_ids)))) + return context, mapping diff --git a/nemo/collections/asr/inference/utils/device_utils.py b/nemo/collections/asr/inference/utils/device_utils.py new file mode 100644 index 000000000000..320a6da5c838 --- /dev/null +++ b/nemo/collections/asr/inference/utils/device_utils.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from nemo.utils import logging + +COMPUTE_DTYPE_MAP = { + 'bfloat16': torch.bfloat16, + 'float16': torch.float16, + 'float32': torch.float32, +} + +DEVICE_TYPES = ["cuda", "mps", "cpu"] + + +def setup_device(device: str, device_id: int | None, compute_dtype: str) -> tuple[str, int, torch.dtype]: + """ + Set up the compute device for the model. + + Args: + device (str): Requested device type ('cuda', 'mps' or 'cpu'). + device_id (int | None): Requested CUDA device ID (None for CPU or MPS). + compute_dtype (str): Requested compute dtype. + + Returns: + tuple(str, int, torch.dtype): Tuple of (device_string, device_id, compute_dtype) for model initialization. + """ + device = device.strip() + if device not in DEVICE_TYPES: + raise ValueError(f"Invalid device type: {device}. Must be one of {DEVICE_TYPES}") + + device_id = int(device_id) if device_id is not None else 0 + + # Handle CUDA devices + if torch.cuda.is_available() and device == "cuda": + if device_id >= torch.cuda.device_count(): + logging.warning(f"Device ID {device_id} is not available. Using GPU 0 instead.") + device_id = 0 + + compute_dtype = COMPUTE_DTYPE_MAP.get(compute_dtype, None) + if compute_dtype is None: + raise ValueError( + f"Invalid compute dtype: {compute_dtype}. 
Must be one of {list(COMPUTE_DTYPE_MAP.keys())}" + ) + + device_str = f"cuda:{device_id}" + return device_str, device_id, compute_dtype + + # Handle MPS devices + if torch.backends.mps.is_available() and device == "mps": + return "mps", -1, torch.float32 + + # Handle CPU devices + if device == "cpu": + return "cpu", -1, torch.float32 + + raise ValueError(f"Device {device} is not available.") diff --git a/nemo/collections/asr/inference/utils/endpointing_utils.py b/nemo/collections/asr/inference/utils/endpointing_utils.py new file mode 100644 index 000000000000..c72a6230164c --- /dev/null +++ b/nemo/collections/asr/inference/utils/endpointing_utils.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def millisecond_to_frames(millisecond: int, ms_per_timestep: int) -> int: + """ + Convert milliseconds to frames + Args: + millisecond (int): milliseconds to convert + ms_per_timestep (int): milliseconds per timestep + Returns: + int: number of frames + """ + return (millisecond + ms_per_timestep - 1) // ms_per_timestep + + +def get_custom_stop_history_eou( + stop_history_eou: int | None, default_stop_history_eou: int, ms_per_timestep: int +) -> int: + """ + Get the custom stop history of EOU + Args: + stop_history_eou (int): stop history of EOU + default_stop_history_eou (int): default stop history of EOU + ms_per_timestep (int): milliseconds per timestep + Returns: + int: custom stop history of EOU + """ + if stop_history_eou is None: + return default_stop_history_eou + if stop_history_eou > 0: + return millisecond_to_frames(stop_history_eou, ms_per_timestep) + return 0 if stop_history_eou == 0 else -1 diff --git a/nemo/collections/asr/inference/utils/enums.py b/nemo/collections/asr/inference/utils/enums.py new file mode 100644 index 000000000000..f9d80164d40d --- /dev/null +++ b/nemo/collections/asr/inference/utils/enums.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from enum import Enum, auto + + +class StrEnumMixin: + @classmethod + def from_str(cls, name: str): + """Convert a string to an Enum value (case-insensitive).""" + normalized = name.lower() + for member in cls: + if member.name.lower() == normalized or str(member.value).lower() == normalized: + return member + + choices = [member.name.lower() for member in cls] + raise ValueError(f"Invalid {cls.__name__} `{name}`: must be one of {choices}") + + +class ASRDecodingType(StrEnumMixin, Enum): + """ + Enumeration of the ASR decoding types. + """ + + CTC = auto() + RNNT = auto() + + +class ASROutputGranularity(StrEnumMixin, Enum): + """ + Enumeration of the ASR output granularity. + """ + + WORD = auto() + SEGMENT = auto() + + +class PipelineType(StrEnumMixin, Enum): + """ + Enumeration of the pipeline types. + """ + + BUFFERED = auto() + CACHE_AWARE = auto() + + +class RequestType(StrEnumMixin, Enum): + """ + Enumeration of the request types. + """ + + FRAME = auto() + FEATURE_BUFFER = auto() + + +class FeatureBufferPaddingMode(StrEnumMixin, Enum): + """ + Enumeration of the feature buffer padding modes. + """ + + LEFT = auto() + RIGHT = auto() diff --git a/nemo/collections/asr/inference/utils/itn_utils.py b/nemo/collections/asr/inference/utils/itn_utils.py new file mode 100644 index 000000000000..1316e099b771 --- /dev/null +++ b/nemo/collections/asr/inference/utils/itn_utils.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import re +from collections import OrderedDict + +from nemo.collections.asr.inference.utils.constants import DEFAULT_SEMIOTIC_CLASS + + +# Compile regex pattern once at module level for better performance +TOKEN_PATTERN = re.compile(r'tokens \{.*?(?=tokens \{|$)', re.DOTALL) + + +def get_semiotic_class(tokens: list[OrderedDict]) -> str: + """ + Returns the semiotic class of the given tokens. + """ + return list(tokens[0]["tokens"].keys())[0] + + +def split_text(text: str, sep: str = " ") -> tuple[list, int]: + """ + Splits the text into words based on the separator. + Args: + text: (str) input text + sep: (str) separator to split the text + Returns: + words: (list) list of words + n_words: (int) number of words + """ + words = [w for w in text.split(sep) if w] + return words, len(words) + + +def find_tokens(text: str) -> list[str]: + """ + Find the start and end positions of token blocks in the given text. + Args: + text: (str) input text containing token blocks + Returns: + token_blocks: (list[str]) list of token blocks + """ + + # Use compiled regex to find all token blocks in a single pass + token_blocks = TOKEN_PATTERN.findall(text) + + # Strip whitespace from each block + return [block.strip() for block in token_blocks] + + +def get_trivial_alignment(N: int, i_shift: int = 0, o_shift: int = 0) -> list[tuple]: + """ + Returns a trivial word alignment for N input words. 
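+        For example, N=2 with default shifts yields [([0], [0], "name"), ([1], [1], "name")],
+        i.e. each input word maps to itself with the default semiotic class.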
+ Args: + N: (int) number of input words + i_shift: (int) input shift + o_shift: (int) output shift + Returns: + (list) Returns a trivial word alignment + """ + return [([i + i_shift], [i + o_shift], DEFAULT_SEMIOTIC_CLASS) for i in range(N)] + + +def fallback_to_trivial_alignment( + input_words: list[str], i_shift: int = 0, o_shift: int = 0 +) -> tuple[list[str], list[str], list[tuple]]: + """ + Returns a trivial word alignment for the input words. + Args: + input_words: (list[str]) list of input words + i_shift: (int) input shift + o_shift: (int) output shift + Returns: + (tuple) Returns a tuple of input words, output words, and a trivial word alignment + """ + return input_words, input_words.copy(), get_trivial_alignment(N=len(input_words), i_shift=i_shift, o_shift=o_shift) diff --git a/nemo/collections/asr/inference/utils/manifest_io.py b/nemo/collections/asr/inference/utils/manifest_io.py new file mode 100644 index 000000000000..3a733d8e8c02 --- /dev/null +++ b/nemo/collections/asr/inference/utils/manifest_io.py @@ -0,0 +1,128 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import librosa + +from nemo.collections.asr.inference.utils.constants import DEFAULT_OUTPUT_DIR_NAME +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from nemo.collections.common.parts.preprocessing.manifest import get_full_path + + +def make_abs_path(path: str) -> str: + """ + Make a path absolute + Args: + path: (str) Path to the file or folder + Returns: + (str) Absolute path + """ + path = path.strip() + if not path: + raise ValueError("Path cannot be empty") + if not os.path.isabs(path): + path = os.path.abspath(path) + return path + + +def get_audio_filepaths(audio_file: str, sort_by_duration: bool = True) -> list[str]: + """ + Get audio filepaths from a folder or a single audio file + Args: + audio_file: (str) Path to the audio file, folder or manifest file + sort_by_duration: (bool) If True, sort the audio files by duration from shortest to longest + Returns: + (list[str]) List of audio filepaths + """ + audio_file = audio_file.strip() + audio_file = make_abs_path(audio_file) + if os.path.isdir(audio_file): + filepaths = filter(lambda x: x.endswith(".wav"), os.listdir(audio_file)) + filepaths = [os.path.join(audio_file, x) for x in filepaths] + elif audio_file.endswith(".wav"): + filepaths = [audio_file] + elif audio_file.endswith((".json", ".jsonl")): + manifest = read_manifest(audio_file) + filepaths = [get_full_path(entry["audio_filepath"], audio_file) for entry in manifest] + else: + raise ValueError(f"audio_file `{audio_file}` need to be folder, audio file or manifest file") + + if sort_by_duration: + durations = [librosa.get_duration(path=audio_filepath) for audio_filepath in filepaths] + filepaths_with_durations = list(zip(filepaths, durations)) + filepaths_with_durations.sort(key=lambda x: x[1]) + filepaths = [x[0] for x in filepaths_with_durations] + return filepaths + + +def 
get_stem(file_path: str) -> str: + """ + Get the stem of a file path + Args: + file_path: (str) Path to the file + Returns: + (str) Filename with extension + """ + return file_path.split('/')[-1] + + +def dump_output(output: dict, output_filename: str, output_dir: str | None = None) -> None: + """ + Dump the transcriptions to a output file + Args: + output (dict): Pipeline output, structured as {stream_id: {"text": str, "segments": list}} + output_filename: (str) Path to the output file + output_dir: (str | None) Path to the output directory, if None, will write at the same level as the output file + """ + if output_dir is None: + # Create default output directory, if not provided + output_dir = os.path.dirname(output_filename) + output_dir = os.path.join(output_dir, DEFAULT_OUTPUT_DIR_NAME) + + os.makedirs(output_dir, exist_ok=True) + with open(output_filename, 'w') as fout: + for stream_id, data in sorted(output.items(), key=lambda x: x[0]): + audio_filepath = data["audio_filepath"] + text = data["text"] + segments = data["segments"] + stem = get_stem(audio_filepath) + stem = os.path.splitext(stem)[0] + json_filepath = os.path.join(output_dir, f"{stem}.json") + json_filepath = make_abs_path(json_filepath) + with open(json_filepath, 'w') as json_fout: + for segment in segments: + json_line = json.dumps(segment.to_dict(), ensure_ascii=False) + json_fout.write(f"{json_line}\n") + + item = {"audio_filepath": audio_filepath, "text": text, "json_filepath": json_filepath} + json.dump(item, fout, ensure_ascii=False) + fout.write('\n') + fout.flush() + + +def calculate_duration(audio_filepaths: list[str]) -> float: + """ + Calculate the duration of the audio files + Args: + audio_filepaths: (list[str]) List of audio filepaths + Returns: + (float) Total duration of the audio files + """ + total_duration = 0 + for audio_filepath in audio_filepaths: + duration = librosa.get_duration(path=audio_filepath) + total_duration += duration + return total_duration diff --git a/nemo/collections/asr/inference/utils/pipeline_utils.py b/nemo/collections/asr/inference/utils/pipeline_utils.py new file mode 100644 index 000000000000..dcdcaf2f7ab6 --- /dev/null +++ b/nemo/collections/asr/inference/utils/pipeline_utils.py @@ -0,0 +1,312 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
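# A minimal sketch (hypothetical paths and values) of how the manifest helpers defined
# above in manifest_io.py lay out their results; Word comes from text_segment.py later
# in this change and provides the to_dict() used for the per-utterance JSON files.
from nemo.collections.asr.inference.utils.manifest_io import dump_output
from nemo.collections.asr.inference.utils.text_segment import Word

output = {
    0: {
        "audio_filepath": "/data/utt0.wav",  # hypothetical 16 kHz mono WAV
        "text": "hello world",
        "segments": [Word("hello", 0.0, 0.4, 0.9), Word("world", 0.5, 0.9, 0.8)],
    }
}
# Writes one JSON line per stream to out.jsonl ({"audio_filepath", "text", "json_filepath"})
# and one segment-level JSON file per utterance, e.g. word_level_out/utt0.json.
dump_output(output, output_filename="out.jsonl", output_dir="word_level_out")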
+ + +import re +from functools import partial, wraps +from typing import Iterable + +import torch +from omegaconf import DictConfig, open_dict +from torch import Tensor + +from nemo.collections.asr.inference.utils.constants import BIG_EPSILON, SENTENCEPIECE_UNDERSCORE, SMALL_EPSILON +from nemo.collections.asr.parts.preprocessing.features import normalize_batch +from nemo.collections.asr.parts.utils.asr_confidence_utils import ( + get_confidence_aggregation_bank, + get_confidence_measure_bank, +) +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +def check_existance_of_required_attributes(obj: object, required_args: list[str]) -> None: + """ + Check if the required attributes exist in the object + Args: + obj: (object) Object to check the attributes of + required_args: (list[str]) List of required attributes + """ + not_found_args = [] + for arg in required_args: + if not hasattr(obj, arg): + not_found_args.append(arg) + if not_found_args: + raise ValueError(f"Required attributes not found: {not_found_args}") + + +def normalize_features(features: Tensor, feature_lens: Tensor = None) -> Tensor: + """Normalize the features. + Args: + features: (Tensor) features. Shape is torch.Size([B, C, T]). + feature_lens: (Tensor) feature lengths. Shape is torch.Size([B]). + Returns: + (Tensor) normalized features. Shape is torch.Size([B, C, T]). + """ + return normalize_batch(features, feature_lens, "per_feature")[0] + + +def ids_to_text_without_stripping(tokens: list[int], tokenizer: TokenizerSpec, sep: str = ' ') -> str: + """ + Convert a list of token IDs to text without stripping. + Args: + tokens: (list[int]) List of token IDs. + tokenizer: (TokenizerSpec) Tokenizer. + sep: (str) Separator between words. Default is ' '. + Returns: + (str) Text. + """ + pieces = tokenizer.ids_to_tokens(tokens) + text = "".join( + [(p.replace(SENTENCEPIECE_UNDERSCORE, sep) if p.startswith(SENTENCEPIECE_UNDERSCORE) else p) for p in pieces] + ) + return text + + +def memoize_normalization_mode(): + """ + Decorator to memoize the normalization mode. + In the first call, the normalization mode is detected and cached. + In the subsequent calls, the cached normalization mode is used. + """ + + def decorator(func): + mode = None # Cache the detected format + + @wraps(func) + def wrapper(log_probs: torch.Tensor) -> torch.Tensor: + nonlocal mode + + if mode is None: + ONE = torch.tensor(1.0, dtype=log_probs.dtype) + if torch.allclose(log_probs[0][0].sum(), ONE, atol=BIG_EPSILON): + # assume that softmax is already applied + mode = 'prob' + else: + if not torch.allclose(log_probs[0][0].exp().sum(), ONE, atol=BIG_EPSILON): + # It's neither prob nor log-softmax, need to apply log_softmax + mode = "logits" + else: + # It's already in log-softmax form + mode = "log_softmax" + + # Fast-path execution + if mode == "prob": + return torch.log(log_probs + SMALL_EPSILON) + elif mode == 'logits': + return torch.log_softmax(log_probs, dim=-1) + else: + return log_probs + + return wrapper + + return decorator + + +@memoize_normalization_mode() +def normalize_log_probs(log_probs: torch.Tensor) -> torch.Tensor: + """ + log_probs: (B, T, vocab_size) log probabilities + Returns: + (Tensor) normalized log probabilities. Shape is torch.Size([B, T, vocab_size]). 
+ """ + # Ensure log_probs are normalized + return log_probs + + +def drop_trailing_features(features: Tensor, expected_feature_buffer_len: int) -> Tensor: + """Drop the trailing features if the number of features is greater than the expected feature buffer length. + Args: + features: (Tensor) features. Shape is torch.Size([B, C, T1]). + expected_feature_buffer_len: (int) Expected feature buffer length. + Returns: + (Tensor) features. Shape is torch.Size([B, C, T2]). + """ + if features.shape[2] > expected_feature_buffer_len: + features = features[:, :, :expected_feature_buffer_len] + return features + + +def make_preprocessor_deterministic(asr_model_cfg: DictConfig, disable_normalization: bool = True) -> DictConfig: + """ + Make the preprocessor deterministic by disabling normalization, dither and padding + Args: + asr_model_cfg: (DictConfig) ASR model configuration. + disable_normalization: (bool) Whether to disable normalization. Default is True. + Returns: + (DictConfig) ASR model configuration with deterministic preprocessor. + """ + # Enable config overwriting + with open_dict(asr_model_cfg): + # Normalization will be done per buffer in frame_bufferer + # Do not normalize whatever the model's preprocessor setting is + asr_model_cfg.preprocessor.dither = 0.0 + asr_model_cfg.preprocessor.pad_to = 0 + + if disable_normalization: + asr_model_cfg.preprocessor.normalize = "None" + + return asr_model_cfg + + +def get_confidence_utils(confidence_cfg: DictConfig) -> tuple: + """ + Get the confidence function and the confidence aggregator + Args: + confidence_cfg: (DictConfig) Confidence configuration. + Returns: + (tuple) Confidence function and the confidence aggregator. + """ + if confidence_cfg.method_cfg.name == "max_prob": + conf_type = "max_prob" + conf_alpha = 1.0 + else: + conf_type = f"entropy_{confidence_cfg.method_cfg.entropy_type}_{confidence_cfg.method_cfg.entropy_norm}" + conf_alpha = confidence_cfg.method_cfg.alpha + + conf_func = get_confidence_measure_bank()[conf_type] + conf_func = partial(conf_func, t=conf_alpha) + confidence_aggregator = get_confidence_aggregation_bank()[confidence_cfg.aggregation] + return conf_func, confidence_aggregator + + +def get_leading_punctuation_regex_pattern(puncts: set[str]) -> str: + """ + Get the regex pattern for the punctuation marks. + Args: + puncts (set[str]): Set of punctuation marks. + Returns: + (str) Regex pattern for the punctuation marks. + """ + if not puncts: + return "" + escaped_puncts = '|'.join(re.escape(punct) for punct in puncts) + return r'\s+(' + escaped_puncts + ')' + + +def get_repeated_punctuation_regex_pattern(puncts: set[str]) -> str: + """ + Get the regex pattern for the repeated punctuation marks. + Args: + puncts (set[str]): Set of punctuation marks. + Returns: + (str) Regex pattern for the repeated punctuation marks. + """ + if not puncts: + return "" + escaped_puncts = ''.join(re.escape(p) for p in puncts) + return r'([' + escaped_puncts + r']){2,}' + + +def update_punctuation_and_language_tokens_timestamps( + tokens: Tensor, timestamp: Tensor, tokens_to_move: set[int], underscore_id: int +) -> Tensor: + """ + RNNT models predict punctuations and language tokens at the end of the sequence. + Due to this, it appears as if there's a silence between the last word and the punctuation. + This function moves the tokens close to preceding word in the list. + Args: + tokens: (Tensor) Tokens tensor. + timestamp: (Tensor) Timestamps tensor. + tokens_to_move: (set[int]) Set of tokens to move. 
+ underscore_id: (int) ID of the underscore token. + Returns: + (Tensor) Updated timestamps tensor. + """ + + n_tokens = tokens.shape[0] + if n_tokens != timestamp.shape[0]: + raise ValueError("Tokens and timestamps must have the same length") + + tokens_to_move_with_underscore = tokens_to_move.union({underscore_id}) + # If all tokens need moving, don't change timestamps (no content words to attach to) + only_special_tokens = all(token.item() in tokens_to_move_with_underscore for token in tokens) + if only_special_tokens: + return timestamp + + groups = [] + i = 0 + while i < n_tokens: + if tokens[i].item() in tokens_to_move_with_underscore: + start_idx = i + end_idx = i + j = i + 1 + while j < n_tokens and (tokens[j].item() in tokens_to_move_with_underscore): + if tokens[j].item() != underscore_id: + end_idx = j + j += 1 + if j > start_idx and end_idx >= start_idx: + left_timestamp = int(timestamp[start_idx - 1]) if start_idx > 0 else 0 + if start_idx == end_idx: + if tokens[start_idx].item() in tokens_to_move: + groups.append((start_idx, end_idx + 1, left_timestamp)) + else: + groups.append((start_idx, end_idx + 1, left_timestamp)) + i = j + else: + i += 1 + + updated_timestamps = timestamp.clone() + for start_idx, end_idx, left_timestamp in groups: + for k in range(start_idx, end_idx): + # Give all tokens_to_move the same timestamp as the preceding word + updated_timestamps[k] = left_timestamp + + return updated_timestamps + + +def adjust_vad_segments(vad_segments: Tensor, left_padding_size: float) -> Tensor | None: + """ + Adjust VAD segments for stateful mode by subtracting left_padding and applying clipping rules. + Args: + vad_segments: (Tensor) VAD segments tensor with shape [num_segments, 2] (start_time, end_time) + left_padding_size: (float) Amount of left padding in seconds to subtract from segments + Returns: + (Tensor | None) Adjusted VAD segments tensor or None if no valid segments are left. + """ + if vad_segments is None or len(vad_segments) == 0: + return vad_segments + + # Vectorized operations on the entire tensor + adjusted_segments = vad_segments - left_padding_size + + # Filter out segments that end before or at 0 + valid_mask = adjusted_segments[:, 1] > 0 + + if not valid_mask.any(): + return None + + adjusted_segments = adjusted_segments[valid_mask] + + # Clip start times to 0 + adjusted_segments[:, 0] = torch.clamp(adjusted_segments[:, 0], min=0.0) + + return adjusted_segments + + +def seconds_to_frames(seconds: float | int | Iterable[float | int], model_stride_in_secs: float) -> int | list[int]: + """ + Convert seconds to frames. + Args: + seconds: (float | int | Iterable[float | int]) Time in seconds + model_stride_in_secs: (float) Stride of the model in seconds + Returns: + (int | list[int]) Number of frames + """ + if isinstance(seconds, (float, int)): + return int(seconds / model_stride_in_secs) + + if isinstance(seconds, Iterable): + return [int(s / model_stride_in_secs) for s in seconds] + + raise ValueError(f"Invalid type for seconds: {type(seconds)}") diff --git a/nemo/collections/asr/inference/utils/progressbar.py b/nemo/collections/asr/inference/utils/progressbar.py new file mode 100644 index 000000000000..ef6663bbdaa5 --- /dev/null +++ b/nemo/collections/asr/inference/utils/progressbar.py @@ -0,0 +1,100 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tqdm import tqdm + + +class ProgressBar: + """ + Base class for progress bars. + """ + + def __init__(self, value: float = 0.0, total: float = 1.0): + """ + Initialize the ProgressBar. + Args: + value: (float) Initial value. + total: (float) Total value. Must be greater than zero. + """ + if total <= 0: + raise ValueError("Total must be greater than zero.") + if value < 0 or value > total: + raise ValueError("Initial value must be between 0 and total.") + + self.value = value + self.total = total + self.start_value = value + + def restart(self) -> None: + """Restart progress from the initial value.""" + self.value = self.start_value + + def increment(self, by: float) -> None: + """ + Increase progress but do not exceed total. + Args: + by: (float) Amount to increment. + """ + self.value = min(self.value + by, self.total) + + def update_bar(self, by: float) -> None: + """ + Update progress and call update. + Args: + by: (float) Amount to increment. + """ + self.increment(by) + self.update() + + def finish(self) -> None: + """Complete progress bar.""" + self.value = self.total + self.update(True) + + def update(self, is_end: bool = False) -> None: + """ + Abstract method for updating the progress bar. + Args: + is_end: (bool) Whether the progress bar is at the end. + """ + raise NotImplementedError("Subclasses must implement update method.") + + +class TQDMProgressBar(ProgressBar): + """TQDM progress bar wrapper.""" + + def __init__(self, value: float = 0.0, total: float = 1.0): + """ + Initialize the TQDMProgressBar. + Args: + value: (float) Initial value. + total: (float) Total value. + """ + super().__init__(value, total) + self.bar = tqdm(total=self.total, bar_format='{l_bar}{bar}') + self.prev_value = value + + def update(self, is_end: bool = False) -> None: + """ + Update tqdm progress bar. + Args: + is_end: (bool) Whether the progress bar is at the end. + """ + increment = self.value - self.prev_value + if increment > 0: + self.bar.update(increment) + self.prev_value = self.value + + if is_end: + self.bar.close() diff --git a/nemo/collections/asr/inference/utils/state_management_utils.py b/nemo/collections/asr/inference/utils/state_management_utils.py new file mode 100644 index 000000000000..cabec9a9a54d --- /dev/null +++ b/nemo/collections/asr/inference/utils/state_management_utils.py @@ -0,0 +1,193 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
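# A minimal sketch of the progress-bar API defined above in progressbar.py
# (values are illustrative): progress is tracked as a fraction of `total`,
# and finish() forces completion and closes the underlying tqdm bar.
from nemo.collections.asr.inference.utils.progressbar import TQDMProgressBar

bar = TQDMProgressBar(value=0.0, total=1.0)
for _ in range(4):
    bar.update_bar(0.25)  # increment progress, capped at `total`
bar.finish()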
+ + +from typing import Callable + +from nemo.collections.asr.inference.utils.constants import POST_WORD_PUNCTUATION, PRE_WORD_PUNCTUATION +from nemo.collections.asr.inference.utils.text_segment import TextSegment, Word + + +def merge_timesteps(timesteps1: list, timesteps2: list) -> list: + """ + Merge two lists of timesteps by preserving the order and ensuring that the timesteps are in increasing order + Args: + timesteps1: (list) The first list of timesteps + timesteps2: (list) The second list of timesteps + Returns: + (list) The merged list of timesteps + """ + # If both lists are empty, return an empty list + if not timesteps1 and not timesteps2: + return [] + + # If timesteps1 is not empty and the first timestep is negative, + # shift all the timesteps by the absolute value of the first timestep + if timesteps1: + if (first := timesteps1[0]) < 0: # Assigns and checks in the same line + for i, t in enumerate(timesteps1): + timesteps1[i] = t - first + + # If timesteps2 is not empty and the first timestep is negative, + # shift all the timesteps by the absolute value of the first timestep + if timesteps2: + if (first := timesteps2[0]) < 0: + for i, t in enumerate(timesteps2): + timesteps2[i] = t - first + + # If the first list is empty, return the second list + if not timesteps1: + return timesteps2 + + # If the second list is empty, return the first list + if not timesteps2: + return timesteps1 + + # If the last timestep of the first list is greater than the first timestep of the second list, + # calculate the gap between the two timesteps and shift all the timesteps of the second list by the gap + if (gap := timesteps2[0] - timesteps1[-1]) <= 0: + return timesteps1 + [t + abs(gap) + 1 for t in timesteps2] + return timesteps1 + timesteps2 + + +def merge_segment_tail( + segment_head: TextSegment, + segment_tail: TextSegment, + conf_aggregator: Callable = None, +) -> TextSegment: + """ + Merge the segment_tail into the segment_head + Args: + segment_head: (TextSegment) The head segment + segment_tail: (TextSegment) The tail segment + conf_aggregator: (Callable) The function to aggregate the confidence + Returns: + (TextSegment) The merged segment + """ + head = segment_head.copy() + + # for models that have built-in punctuation, we need to rm the last punctuation before merging + if head.text and (last_char := head.text[-1]) and last_char in POST_WORD_PUNCTUATION: + head.text = head.text.rstrip(last_char) + + # merge the segment_tail text + head.text += segment_tail.text + + # update the end timestep + head.end = segment_tail.end + + # update the confidence + if conf_aggregator is not None: + head.conf = conf_aggregator([head.conf, segment_tail.conf]) + + return head + + +def merge_word_tail( + word_head: Word, word_tail: Word, pnc_word_head: Word = None, conf_aggregator: Callable = None +) -> tuple[Word, Word]: + """ + Merge the word_tail into the word_head + Args: + word_head: (Word) The head word + word_tail: (Word) The tail word + pnc_word_head: (Word) The head word with punctuation/capitalization + conf_aggregator: (Callable) The function to aggregate the confidence + Returns: + (tuple[Word, Word]) The merged word and the head word with punctuation/capitalization + """ + + head = word_head.copy() + head_text = head.text + + # for models that have built-in punctuation, we need to rm the last punctuation before merging + if head_text and (last_char := head_text[-1]) and last_char in POST_WORD_PUNCTUATION: + head.text = head_text.rstrip(last_char) + + # merge the word_tail text + 
head.text += word_tail.text + + # update the end timestep + head.end = word_tail.end + + # update the confidence + if conf_aggregator is not None: + head.conf = conf_aggregator([head.conf, word_tail.conf]) + + pnc_head = None + if pnc_word_head is not None: + + last_char = pnc_word_head.text[-1] if pnc_word_head.text else None + first_char = pnc_word_head.text[0] if pnc_word_head.text else None + + pnc_head = head.copy() + + if last_char in POST_WORD_PUNCTUATION: + if pnc_head.text and pnc_head.text[-1] not in POST_WORD_PUNCTUATION: + pnc_head.text = pnc_head.text + last_char + + if first_char in PRE_WORD_PUNCTUATION: + if pnc_head.text and pnc_head.text[0] not in PRE_WORD_PUNCTUATION: + pnc_head.text = first_char + pnc_head.text + + if first_char and first_char.isupper(): + pnc_head.capitalize() + + return head, pnc_head + + +def find_max_overlap(state_tokens: list, new_tokens: list, limit: int) -> int: + """ + Finds the maximum overlap between the state_tokens suffix and the new_tokens prefix + Args: + state_tokens: (list) The list of state tokens + new_tokens: (list) The list of new tokens + limit: (int) The limit on the overlap + Returns: + (int) The maximum overlap within the limit + """ + max_overlap = 0 + for k in range(1, min(len(state_tokens), len(new_tokens), limit) + 1): + if state_tokens[-k:] == new_tokens[:k]: + max_overlap = k + return max_overlap + + +def detect_overlap( + state_tokens: list[int], + state_timesteps: list[float], + new_tokens: list[int], + new_timesteps: list[float], + overlap_search_th: int = 3, + close_in_time_th: float = 2.0, +) -> int: + """ + Detect the overlap between state_tokens and new_tokens + Args: + state_tokens: (list[int]) The list of state tokens + state_timesteps: (list[float]) The list of state timesteps + new_tokens: (list[int]) The list of new tokens + new_timesteps: (list[float]) The list of new timesteps + overlap_search_th: (int) The threshold on the overlap + close_in_time_th: (float) The threshold on the close in time + Returns: + (int) The overlap between the state_tokens and the new_tokens + """ + overlap = 0 + if state_tokens: + overlap = find_max_overlap(state_tokens, new_tokens, overlap_search_th) + if overlap > 0: + close_in_time = (new_timesteps[overlap - 1] - state_timesteps[-overlap]) <= close_in_time_th + overlap = overlap if close_in_time else 0 + return overlap diff --git a/nemo/collections/asr/inference/utils/text_segment.py b/nemo/collections/asr/inference/utils/text_segment.py new file mode 100644 index 000000000000..77bc1c50e5e8 --- /dev/null +++ b/nemo/collections/asr/inference/utils/text_segment.py @@ -0,0 +1,319 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
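# A minimal sketch (hypothetical token IDs and timesteps) of the overlap helpers defined
# above in state_management_utils.py: the suffix/prefix match is only accepted when the
# duplicated tokens are close enough in time.
from nemo.collections.asr.inference.utils.state_management_utils import detect_overlap, find_max_overlap

state_tokens, new_tokens = [7, 12, 3, 9], [3, 9, 15]
assert find_max_overlap(state_tokens, new_tokens, limit=3) == 2  # suffix [3, 9] == prefix [3, 9]

overlap = detect_overlap(
    state_tokens,
    state_timesteps=[0.0, 0.4, 0.8, 1.2],
    new_tokens=new_tokens,
    new_timesteps=[1.3, 1.7, 2.1],
    overlap_search_th=3,
    close_in_time_th=2.0,
)
assert overlap == 2  # 1.7 - 0.8 = 0.9 <= close_in_time_th, so the overlap is kept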
+ + +from functools import lru_cache + +from nemo.collections.asr.inference.utils.constants import DEFAULT_SEMIOTIC_CLASS, SEP_REPLACEABLE_PUNCTUATION + + +@lru_cache(maxsize=5) +def get_translation_table(punct_marks_frozen: frozenset[str], sep: str) -> dict: + """ + Create and cache translation table for text normalization. + + Args: + punct_marks_frozen (frozenset[str]): Frozen set of punctuation marks to process + sep (str): Separator to replace certain punctuation marks + + Returns: + (dict) Translation table for str.translate() + """ + replace_map = {mark: sep if mark in SEP_REPLACEABLE_PUNCTUATION else "" for mark in punct_marks_frozen} + return str.maketrans(replace_map) + + +def normalize_text(text: str, punct_marks: set[str], sep: str) -> str: + """ + Helper to normalize text by removing/replacing punctuation and lowercasing. + + Args: + text (str): Text to normalize + punct_marks (set[str]): Set of punctuation marks to process + sep (str): Separator to replace certain punctuation marks + + Returns: + (str) Normalized text + """ + trans_table = get_translation_table(frozenset(punct_marks), sep) + return text.translate(trans_table).lower() + + +def validate_init_params( + text: str, start: float, end: float, conf: float, semiotic_class: str = None, strict: bool = False +) -> None: + """ + Validate initialization parameters. + Args: + text: (str) Text to validate + start: (float) Start time + end: (float) End time + conf: (float) Confidence score + semiotic_class: (str) Semiotic class + strict: (bool) Whether to strict validation + """ + if not isinstance(text, str): + raise TypeError(f"text must be a string, got {type(text).__name__}") + if not isinstance(start, (int, float)): + raise TypeError(f"start must be numeric, got {type(start).__name__}") + if not isinstance(end, (int, float)): + raise TypeError(f"end must be numeric, got {type(end).__name__}") + if not isinstance(conf, (int, float)): + raise TypeError(f"conf must be numeric, got {type(conf).__name__}") + + if semiotic_class is not None and not isinstance(semiotic_class, str): + raise TypeError(f"semiotic_class must be a string, got {type(semiotic_class).__name__}") + + if strict: + if start >= end: + raise ValueError(f"start time ({start}) must be less than end time ({end})") + if conf < 0 or conf > 1: + raise ValueError(f"confidence ({conf}) must be between 0 and 1") + + +class TextSegment: + """ + Text segment class. + Represents a continuous text segment with a start time, end time, and confidence score. + """ + + __slots__ = ['_text', '_start', '_end', '_conf'] + + def __init__(self, text: str, start: float, end: float, conf: float) -> None: + """ + Initialize a TextSegment instance. 
+ + Args: + text: The content of the text segment + start: Start time in seconds + end: End time in seconds + conf: Confidence score [0.0, 1.0] + Raises: + ValueError: If start >= end or if confidence is negative + TypeError: If text is not a string + """ + validate_init_params(text, start, end, conf, strict=True) + + self._text = text + self._start = start + self._end = end + self._conf = conf + + @property + def text(self) -> str: + """The content of the text segment.""" + return self._text + + @property + def start(self) -> float: + """Start time of the text segment in seconds.""" + return self._start + + @property + def end(self) -> float: + """End time of the text segment in seconds.""" + return self._end + + @property + def duration(self) -> float: + """Duration of the text segment in seconds.""" + return self._end - self._start + + @property + def conf(self) -> float: + """Confidence score of the text segment.""" + return self._conf + + @text.setter + def text(self, value: str) -> None: + """Set the content of the text segment.""" + if not isinstance(value, str): + raise TypeError(f"text must be a string, got {type(value).__name__}") + self._text = value + + @start.setter + def start(self, value: float) -> None: + """Set the start time.""" + if not isinstance(value, (int, float)): + raise TypeError(f"start time must be numeric, got {type(value).__name__}") + self._start = value + + @end.setter + def end(self, value: float) -> None: + """Set the end time.""" + if not isinstance(value, (int, float)): + raise TypeError(f"end must be numeric, got {type(value).__name__}") + self._end = value + + @conf.setter + def conf(self, value: float) -> None: + """Set the confidence score.""" + if not isinstance(value, (int, float)): + raise TypeError(f"conf must be numeric, got {type(value).__name__}") + if value < 0 or value > 1: + raise ValueError(f"confidence ({value}) must be between 0 and 1") + self._conf = value + + def copy(self) -> 'TextSegment': + """ + Create a deep copy of this TextSegment instance. + + Returns: + A new TextSegment instance with identical properties + """ + return TextSegment(text=self.text, start=self.start, end=self.end, conf=self.conf) + + def capitalize(self) -> None: + """Capitalize first letter of the text segment.""" + self._text = self._text.capitalize() + + def with_normalized_text(self, punct_marks: set[str], sep: str = "") -> 'TextSegment': + """ + Create a new TextSegment with normalized text (punctuation removed/replaced and lowercased). + + Args: + punct_marks (set[str]): Set of punctuation marks to process + sep: Separator to replace certain punctuation marks + + Returns: + New TextSegment instance with normalized text + """ + # Return new instance instead of modifying in place + obj_copy = self.copy() + obj_copy._text = normalize_text(self._text, punct_marks, sep) # Direct access + return obj_copy + + def normalize_text_inplace(self, punct_marks: set[str], sep: str = "") -> None: + """ + Normalize text in place (punctuation removed/replaced and lowercased). + + Args: + punct_marks (set[str]): Set of punctuation marks to process + sep (str): Separator to replace certain punctuation marks + + Note: + This method modifies the current instance. Consider using + with_normalized_text() for a functional approach. + """ + self._text = normalize_text(self._text, punct_marks, sep) # Direct access + + def to_dict(self) -> dict: + """ + Convert the TextSegment to a JSON-compatible dictionary. 
+ """ + return { + "text": self.text, + "start": self.start, + "end": self.end, + "conf": self.conf, + } + + +class Word(TextSegment): + """ + Word class. + Represents a word with a text, start time, end time, confidence score, and semiotic class. + """ + + __slots__ = ['_semiotic_class'] + + def __init__( + self, text: str, start: float, end: float, conf: float, semiotic_class: str = DEFAULT_SEMIOTIC_CLASS + ) -> None: + """ + Initialize a Word instance. + + Args: + text: The text content of the word + start: Start time in seconds + end: End time in seconds + conf: Confidence score [0.0, 1.0] + semiotic_class: Semiotic class of the word + + Raises: + ValueError: If start >= end or if confidence is negative + TypeError: If text is not a string + """ + validate_init_params(text, start, end, conf, semiotic_class, strict=True) + super().__init__(text, start, end, conf) + self._semiotic_class = semiotic_class + + @property + def semiotic_class(self) -> str: + """Semiotic class of the word.""" + return self._semiotic_class + + @semiotic_class.setter + def semiotic_class(self, value: str) -> None: + """Set the semiotic class.""" + if not isinstance(value, str): + raise TypeError(f"semiotic_class must be a string, got {type(value).__name__}") + self._semiotic_class = value + + def copy(self) -> 'Word': + """ + Create a deep copy of this Word instance. + + Returns: + A new Word instance with identical properties + """ + return Word(text=self.text, start=self.start, end=self.end, conf=self.conf, semiotic_class=self.semiotic_class) + + def to_dict(self) -> dict: + """ + Convert the Word to a JSON-compatible dictionary. + """ + return super().to_dict() | {"semiotic_class": self.semiotic_class} + + +def join_segments(segments: list[list[TextSegment]], sep: str) -> list[str]: + """ + Join the text segments to form transcriptions. + + Args: + segments (list[list[TextSegment]]): List of text segment sequences to join + sep (str): Separator to use when joining text segments + + Returns: + List of transcriptions, one for each text segment sequence + """ + return [sep.join([s.text for s in items]) for items in segments] + + +def normalize_segments_inplace( + segments: list[TextSegment] | list[list[TextSegment]], punct_marks: set[str], sep: str = ' ' +) -> None: + """ + Normalize text in text segments by removing punctuation and converting to lowercase. + + This function modifies the text segments in-place by calling normalize_text_inplace + on each TextSegment object. It handles both flat lists of text segments and nested lists. + + Args: + segments (list[TextSegment] | list[list[TextSegment]]): List of TextSegment objects or list of lists of TextSegment objects + punct_marks (set[str]): Set of punctuation marks to be processed + sep (str): Separator to replace certain punctuation marks (default: ' ') + + Note: + This function modifies the input text segments in-place. The original text + content of the text segments will be permanently changed. + """ + for item in segments: + if isinstance(item, list): + for segment in item: + segment.normalize_text_inplace(punct_marks, sep) + elif isinstance(item, TextSegment): + item.normalize_text_inplace(punct_marks, sep) + else: + raise ValueError(f"Invalid item type: {type(item)}. 
Expected `TextSegment` or `List[TextSegment]`.") diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index 1ff21cece1e2..cffc94d276e3 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -303,7 +303,6 @@ def __init__( f"{self} got an invalid value for either n_window_size or " f"n_window_stride. Both must be positive ints." ) - logging.info(f"PADDING: {pad_to}") self.sample_rate = sample_rate self.win_length = n_window_size diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index 5e49843ba4be..db55c673653e 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -21,3 +21,4 @@ soundfile sox<=1.5.0 kaldialign<=0.9.1 whisper_normalizer +diskcache diff --git a/tests/collections/asr/inference/__init__.py b/tests/collections/asr/inference/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/collections/asr/inference/test_audio_bufferer.py b/tests/collections/asr/inference/test_audio_bufferer.py new file mode 100644 index 000000000000..249b1ca6b420 --- /dev/null +++ b/tests/collections/asr/inference/test_audio_bufferer.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
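# A minimal sketch (hypothetical values) of the TextSegment/Word API defined above in
# text_segment.py; with the default empty separator, ',' is dropped during normalization.
from nemo.collections.asr.inference.utils.text_segment import Word, join_segments

words = [Word("Hello,", 0.0, 0.4, 0.95), Word("world", 0.5, 0.9, 0.90)]
print(join_segments([words], sep=" "))  # ['Hello, world']

# with_normalized_text() returns a new segment (lowercased, punctuation handled)
# instead of mutating in place; normalize_text_inplace() is the mutating variant.
normalized = [w.with_normalized_text({","}) for w in words]
print(join_segments([normalized], sep=" "))  # ['hello world']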
+ + +import pytest +import torch + +from nemo.collections.asr.inference.streaming.buffering.audio_bufferer import AudioBufferer, BatchedAudioBufferer +from nemo.collections.asr.inference.streaming.framing.mono_stream import MonoStream +from nemo.collections.asr.inference.streaming.framing.multi_stream import MultiStream + + +@pytest.fixture(scope="module") +def test_audios(): + return torch.ones(83200), torch.ones(118960) + + +class TestAudioBufferer: + + @pytest.mark.unit + def test_audio_bufferer(self, test_audios): + for audio in test_audios: + stream = MonoStream(16000, frame_size_in_secs=2.5, stream_id=0, pad_last_frame=False) + stream.load_audio(audio, options=None) + + frame_bufferer = AudioBufferer(16000, buffer_size_in_secs=5.0) + + for frame in iter(stream): + frame = frame[0] + frame_bufferer.update(frame) + buffer = frame_bufferer.get_buffer() + + assert len(buffer) == frame_bufferer.buffer_size + assert torch.allclose(buffer[-frame.size :], frame.samples, atol=1e-5) + + +class TestBatchedAudioBufferer: + + @pytest.mark.unit + def test_batched_audio_bufferer(self, test_audios): + + multi_stream = MultiStream(n_frames_per_stream=1) + for stream_id, audio in enumerate(test_audios): + stream = MonoStream(16000, 2.5, stream_id=stream_id, pad_last_frame=False) + stream.load_audio(audio, options=None) + multi_stream.add_stream(stream, stream_id=stream_id) + + batched_audio_bufferer = BatchedAudioBufferer(16000, buffer_size_in_secs=5.0) + + for frames in iter(multi_stream): + buffered_frames, left_paddings = batched_audio_bufferer.update(frames) + for idx, frame in enumerate(frames): + frame_buffer = buffered_frames[idx] + assert torch.allclose(frame_buffer[-frame.size :], frame.samples, atol=1e-5) diff --git a/tests/collections/asr/inference/test_bpe_decoder.py b/tests/collections/asr/inference/test_bpe_decoder.py new file mode 100644 index 000000000000..42bafd45a392 --- /dev/null +++ b/tests/collections/asr/inference/test_bpe_decoder.py @@ -0,0 +1,93 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
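# Note on the audio bufferer tests above (a reading of the assertions, not a separate spec):
# buffer_size_in_secs=5.0 at 16 kHz corresponds to an 80000-sample buffer, and each pushed
# 2.5 s frame (40000 samples) occupies the tail of that buffer, which is exactly what the
# allclose check on buffer[-frame.size:] verifies after every update().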
+ +import pytest +import torch + +from nemo.collections.asr.inference.model_wrappers.ctc_inference_wrapper import CTCInferenceWrapper +from nemo.collections.asr.inference.utils.bpe_decoder import BPEDecoder +from nemo.collections.asr.inference.utils.text_segment import TextSegment, Word +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig + + +@pytest.fixture(scope="module") +def bpe_decoder(): + asr_model = CTCInferenceWrapper( + model_name="stt_en_conformer_ctc_small", + decoding_cfg=CTCDecodingConfig(), + device="cuda" if torch.cuda.is_available() else "cpu", + ) + return BPEDecoder( + vocabulary=asr_model.get_vocabulary(), + tokenizer=asr_model.tokenizer, + confidence_aggregator=min, + asr_supported_puncts=asr_model.supported_punctuation(), + word_boundary_tolerance=0.0, # Set to 0.0 for easy testing + token_duration_in_secs=asr_model.get_model_stride(in_secs=True), + ) + + +class TestBPEDecoder: + + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.parametrize( + "text", + [ + "the quick brown fox jumps over the lazy dog", + "lorem ipsum dolor sit amet", + "this a test sentence", + ], + ) + def test_group_tokens_into_words(self, bpe_decoder, text): + ground_truth_words = text.split() + tokens = bpe_decoder.tokenizer.text_to_ids(text) + n_tokens = len(tokens) + timestamps = [float(i) for i in range(n_tokens)] + confidences = [0.1] * n_tokens + + words, need_merge = bpe_decoder.group_tokens_into_words(tokens, timestamps, confidences) + assert len(words) == len(ground_truth_words) + prev_word_end = -1 + for word, ground_truth_word in zip(words, ground_truth_words): + assert isinstance(word, Word) + assert word.text == ground_truth_word + assert word.conf == 0.1 + assert word.end > word.start and word.start >= prev_word_end + prev_word_end = word.end + assert need_merge == False + + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.parametrize( + "text", + [ + "the quick brown fox jumps over the lazy dog", + "lorem ipsum dolor sit amet", + "this a test sentence", + ], + ) + def test_group_tokens_into_segment(self, bpe_decoder, text): + tokens = bpe_decoder.tokenizer.text_to_ids(text) + n_tokens = len(tokens) + timestamps = [float(i) for i in range(n_tokens)] + confidences = [0.1] * n_tokens + + segment, need_merge = bpe_decoder.group_tokens_into_segment(tokens, timestamps, confidences) + assert isinstance(segment, TextSegment) + assert need_merge == False + assert segment.text == text + assert segment.start == 0.0 + assert segment.end == (n_tokens - 1) * bpe_decoder.token_duration_in_secs + assert segment.conf == 0.1 diff --git a/tests/collections/asr/inference/test_enums.py b/tests/collections/asr/inference/test_enums.py new file mode 100644 index 000000000000..e1c1f5945496 --- /dev/null +++ b/tests/collections/asr/inference/test_enums.py @@ -0,0 +1,61 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from nemo.collections.asr.inference.utils.enums import ( + ASRDecodingType, + ASROutputGranularity, + FeatureBufferPaddingMode, + PipelineType, + RequestType, +) + + +class TestEnums: + + @pytest.mark.unit + def test_ASRDecodingType(self): + assert ASRDecodingType.from_str("ctc") == ASRDecodingType.CTC + assert ASRDecodingType.from_str("RNNT") == ASRDecodingType.RNNT + with pytest.raises(ValueError): + ASRDecodingType.from_str("invalid") + + @pytest.mark.unit + def test_ASROutputGranularity(self): + assert ASROutputGranularity.from_str("word") == ASROutputGranularity.WORD + assert ASROutputGranularity.from_str("segment") == ASROutputGranularity.SEGMENT + with pytest.raises(ValueError): + ASROutputGranularity.from_str("invalid") + + @pytest.mark.unit + def test_PipelineType(self): + assert PipelineType.from_str("buffered") == PipelineType.BUFFERED + assert PipelineType.from_str("cache_aware") == PipelineType.CACHE_AWARE + with pytest.raises(ValueError): + PipelineType.from_str("invalid") + + @pytest.mark.unit + def test_RequestType(self): + assert RequestType.from_str("frame") == RequestType.FRAME + assert RequestType.from_str("feature_buffer") == RequestType.FEATURE_BUFFER + with pytest.raises(ValueError): + RequestType.from_str("invalid") + + @pytest.mark.unit + def test_FeatureBufferPaddingMode(self): + assert FeatureBufferPaddingMode.from_str("left") == FeatureBufferPaddingMode.LEFT + assert FeatureBufferPaddingMode.from_str("right") == FeatureBufferPaddingMode.RIGHT + with pytest.raises(ValueError): + FeatureBufferPaddingMode.from_str("invalid") diff --git a/tests/collections/asr/inference/test_framing.py b/tests/collections/asr/inference/test_framing.py new file mode 100644 index 000000000000..e6bf6db2bfda --- /dev/null +++ b/tests/collections/asr/inference/test_framing.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch + +from nemo.collections.asr.inference.streaming.framing.mono_stream import MonoStream +from nemo.collections.asr.inference.streaming.framing.multi_stream import MultiStream + + +@pytest.fixture(scope="module") +def test_audios(): + return torch.ones(83200), torch.ones(118960) + + +class TestMonoWavStream: + + @pytest.mark.unit + def test_mono_wav_stream_no_pad(self, test_audios): + for audio in test_audios: + stream = MonoStream(16000, 2.5, stream_id=0, pad_last_frame=False) + stream.load_audio(audio, options=None) + audio_len_in_samples = stream.samples.shape[0] + i = 0 + total_samples = 0 + for frame in iter(stream): + total_samples += len(frame[0].samples) + i += 1 + assert total_samples == audio_len_in_samples + assert frame[0].is_last == True + + @pytest.mark.unit + def test_mono_wav_stream_with_pad(self, test_audios): + for audio in test_audios: + stream = MonoStream(16000, 2.5, stream_id=0, pad_last_frame=True) + stream.load_audio(audio, options=None) + for frame in iter(stream): + last_frame_size = frame[0].size + assert last_frame_size == stream.frame_size + + +class TestMultiStream: + + @pytest.mark.unit + def test_multi_stream(self, test_audios): + multi_stream = MultiStream(n_frames_per_stream=1) + audio_len_in_samples = {} + for stream_id, audio in enumerate(test_audios): + stream = MonoStream(16000, 2.5, stream_id=stream_id, pad_last_frame=False) + stream.load_audio(audio, options=None) + multi_stream.add_stream(stream, stream_id=stream_id) + audio_len_in_samples[stream_id] = stream.samples.shape[0] + + total_samples = {} + for frames in iter(multi_stream): + for frame in frames: + total_samples[frame.stream_id] = total_samples.get(frame.stream_id, 0) + frame.size + + assert total_samples == audio_len_in_samples diff --git a/tests/collections/asr/inference/test_greedy_decoder.py b/tests/collections/asr/inference/test_greedy_decoder.py new file mode 100644 index 000000000000..f2aa65494fff --- /dev/null +++ b/tests/collections/asr/inference/test_greedy_decoder.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
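# Note on the framing fixtures above (derived from the assertions, not a separate spec):
# 83200 samples at 16 kHz is 5.2 s, so 2.5 s frames (40000 samples) split it into chunks of
# 40000, 40000 and 3200 when pad_last_frame=False (total sample count is preserved), while
# pad_last_frame=True pads the final chunk up to the full 40000-sample frame size.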
+ +import pytest +import torch + +from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_ctc_decoder import CTCGreedyDecoder +from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_rnnt_decoder import RNNTGreedyDecoder + + +class TestCTCGreedyDecoder: + + @pytest.mark.unit + def test_ctc_greedy_decoder(self): + + vocab = ["a", "b", "c", "d"] + decoder = CTCGreedyDecoder(vocabulary=vocab) + + assert decoder.blank_id == len(vocab) + assert decoder.is_token_silent(len(vocab)) == True + + for i in range(len(vocab)): + assert decoder.is_token_silent(i) == False + + for i in range(len(vocab)): + assert decoder.is_token_start_of_word(i) == False + + assert decoder.count_silent_tokens([0, 1, 2, 3, 4], 0, 5) == 1 + assert decoder.count_silent_tokens([0, 1, 2, 3, 4], 0, 3) == 0 + assert decoder.first_non_silent_token([1, 2, 3, 4], 0, 5) == 0 + + log_probs = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.05], [0.4, 0.3, 0.2, 0.1, 0.05]]) + assert decoder.get_labels(log_probs) == log_probs.argmax(dim=-1).tolist() + + @pytest.mark.unit + def test_ctc_greedy_decoder_with_previous_token(self): + vocab = ["a", "b", "c", "d"] + decoder = CTCGreedyDecoder(vocabulary=vocab) + + log_probs = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.05], [0.1, 0.2, 0.3, 0.4, 0.05], [0.4, 0.3, 0.2, 0.1, 0.05]]) + last_token_id = 3 + output = decoder(log_probs, compute_confidence=False, previous=last_token_id) + assert output["tokens"] == [0] + assert output["timesteps"] == [2] + + output = decoder(log_probs, compute_confidence=False, previous=None) + assert output["tokens"] == [3, 0] + assert output["timesteps"] == [0, 2] + + +class TestRNNTGreedyDecoder: + + @pytest.mark.unit + def test_rnnt_greedy_decoder(self): + + vocab = ["a", "b", "c", "d"] + decoder = RNNTGreedyDecoder(vocab) + + blank_id = len(vocab) + assert decoder.blank_id == blank_id + assert decoder.is_token_silent(blank_id) == True + + for i in range(len(vocab)): + assert decoder.is_token_silent(i) == False + + for i in range(len(vocab)): + assert decoder.is_token_start_of_word(i) == False + + assert decoder.count_silent_tokens([0, 1, 2, 3, 4], 0, 5) == 1 + assert decoder.count_silent_tokens([0, 1, 2, 3, 4], 0, 3) == 0 + assert decoder.first_non_silent_token([1, 2, 3, 4], 0, 5) == 0 diff --git a/tests/collections/asr/inference/test_greedy_endpointing.py b/tests/collections/asr/inference/test_greedy_endpointing.py new file mode 100644 index 000000000000..4527a62fe97e --- /dev/null +++ b/tests/collections/asr/inference/test_greedy_endpointing.py @@ -0,0 +1,231 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
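# Note on the CTC greedy expectations above (a reading of the test, not NeMo documentation):
# the frame-wise argmax of log_probs is [3, 3, 0]; greedy CTC collapses the repeat, so with
# previous=None the decoder emits tokens [3, 0] at timesteps [0, 2], while previous=3 treats
# the leading 3s as a continuation of the prior buffer's last label, leaving only token 0
# at timestep 2.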
+ +import pytest +import torch + +from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_ctc_endpointing import CTCGreedyEndpointing +from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_rnnt_endpointing import RNNTGreedyEndpointing +from nemo.collections.asr.inference.utils.endpointing_utils import millisecond_to_frames + + +class TestGreedyEndpointing: + + @pytest.mark.unit + @pytest.mark.parametrize( + "inputs, expected", + [ + ((100, 80), 2), + ((100, 100), 1), + ((100, 40), 3), + ], + ) + def test_millisecond_to_frames(self, inputs, expected): + assert millisecond_to_frames(*inputs) == expected + + @pytest.mark.unit + def test_endpointing_with_negative_stop_history_eou(self): + for endpointing_cls in [CTCGreedyEndpointing, RNNTGreedyEndpointing]: + greedy_endpointing = endpointing_cls(vocabulary=["a", "b", "c"], ms_per_timestep=100, stop_history_eou=-1) + if isinstance(greedy_endpointing, CTCGreedyEndpointing): + b = len(greedy_endpointing.greedy_ctc_decoder.vocabulary) + else: + b = len(greedy_endpointing.greedy_rnnt_decoder.vocabulary) + emissions = [0, 1, 2, b, b, b, b, b, b, b, b, b] + + # False case, because stop_history_eou = -1 + assert greedy_endpointing.detect_eou_given_emissions(emissions, 3) == (False, -1) + + @pytest.mark.unit + def test_endpointing_with_positive_stop_history_eou(self): + for endpointing_cls in [CTCGreedyEndpointing, RNNTGreedyEndpointing]: + greedy_endpointing = endpointing_cls( + vocabulary=["a", "b", "c"], ms_per_timestep=20, stop_history_eou=100, residue_tokens_at_end=0 + ) + if isinstance(greedy_endpointing, CTCGreedyEndpointing): + b = len(greedy_endpointing.greedy_ctc_decoder.vocabulary) + else: + b = len(greedy_endpointing.greedy_rnnt_decoder.vocabulary) + emissions = [0, 1, 2, b, b, b, b, b, b, b, b, b] + + for pivot_point in range(len(emissions)): + eou_detected, eou_detected_at = greedy_endpointing.detect_eou_given_emissions(emissions, pivot_point) + assert eou_detected == True + + @pytest.mark.unit + def test_detect_eou_given_timestamps_empty_inputs(self): + for endpointing_cls in [CTCGreedyEndpointing, RNNTGreedyEndpointing]: + greedy_endpointing = endpointing_cls( + vocabulary=["a", "b", "c"], ms_per_timestep=80, stop_history_eou=100, residue_tokens_at_end=0 + ) + + # Test with empty timesteps and tokens + timesteps = torch.tensor([]) + tokens = torch.tensor([]) + alignment_length = 10 + + eou_detected, eou_detected_at = greedy_endpointing.detect_eou_given_timestamps( + timesteps, tokens, alignment_length + ) + assert eou_detected == False + assert eou_detected_at == -1 + + @pytest.mark.unit + def test_detect_eou_given_timestamps_disabled_stop_history(self): + for endpointing_cls in [CTCGreedyEndpointing, RNNTGreedyEndpointing]: + greedy_endpointing = endpointing_cls( + vocabulary=["a", "b", "c"], + ms_per_timestep=80, + stop_history_eou=-1, # Disabled + residue_tokens_at_end=0, + ) + + timesteps = torch.tensor([0, 2, 4, 6]) + tokens = torch.tensor([0, 1, 2, 3]) + alignment_length = 10 + + eou_detected, eou_detected_at = greedy_endpointing.detect_eou_given_timestamps( + timesteps, tokens, alignment_length + ) + assert eou_detected == False + assert eou_detected_at == -1 + + @pytest.mark.unit + def test_detect_eou_given_timestamps_trailing_silence(self): + for endpointing_cls in [CTCGreedyEndpointing, RNNTGreedyEndpointing]: + greedy_endpointing = endpointing_cls( + vocabulary=["a", "b", "c"], ms_per_timestep=20, stop_history_eou=80, residue_tokens_at_end=0 + ) + + # Last token at position 5, 
alignment_length is 10 + # Trailing silence = 10 - 4 - 1 = 5 frames > stop_history_eou (4) + timesteps = torch.tensor([0, 1, 2, 3, 4]) + tokens = torch.tensor([0, 1, 2, 3, 4]) + alignment_length = 10 + + eou_detected, eou_detected_at = greedy_endpointing.detect_eou_given_timestamps( + timesteps, tokens, alignment_length + ) + assert eou_detected == True + # eou_detected_at = 4 + 1 + 4//2 = 7 + assert eou_detected_at == 7 + + @pytest.mark.unit + def test_detect_eou_given_timestamps_no_trailing_silence(self): + for endpointing_cls in [CTCGreedyEndpointing, RNNTGreedyEndpointing]: + greedy_endpointing = endpointing_cls( + vocabulary=["a", "b", "c"], ms_per_timestep=20, stop_history_eou=80, residue_tokens_at_end=0 + ) + + # Last token at position 8, alignment_length is 10 + # Trailing silence = 10 - 8 - 1 = 1 frame < stop_history_eou (4) + timesteps = torch.tensor([0, 1, 2, 3, 8]) + tokens = torch.tensor([0, 1, 2, 3, 4]) + alignment_length = 10 + + eou_detected, eou_detected_at = greedy_endpointing.detect_eou_given_timestamps( + timesteps, tokens, alignment_length + ) + assert eou_detected == False + assert eou_detected_at == -1 + + @pytest.mark.unit + def test_detect_eou_given_timestamps_gap_detection(self): + for endpointing_cls in [CTCGreedyEndpointing, RNNTGreedyEndpointing]: + greedy_endpointing = endpointing_cls( + vocabulary=["a", "b", "c"], ms_per_timestep=20, stop_history_eou=80, residue_tokens_at_end=0 + ) + + # Large gap between tokens: 8 - 2 - 1 = 5 frames > stop_history_eou (4) + timesteps = torch.tensor([0, 2, 8, 9]) + tokens = torch.tensor([0, 1, 2, 3]) + alignment_length = 10 + + eou_detected, eou_detected_at = greedy_endpointing.detect_eou_given_timestamps( + timesteps, tokens, alignment_length + ) + assert eou_detected == True + # eou_detected_at = 2 + 1 + 4//2 = 5 + assert eou_detected_at == 5 + + @pytest.mark.unit + def test_rnnt_vad_endpointing_disabled(self): + rnnt_endpointing = RNNTGreedyEndpointing( + vocabulary=["a", "b", "c"], + ms_per_timestep=100, + effective_buffer_size_in_secs=None, # VAD disabled + stop_history_eou=100, + ) + + # Test with VAD segments - should raise ValueError since VAD is disabled + vad_segments = torch.tensor([[0.0, 1.0], [1.5, 2.5]]) + + with pytest.raises( + ValueError, match="Effective buffer size in seconds is required for VAD-based EOU detection" + ): + rnnt_endpointing.detect_eou_vad(vad_segments) + + @pytest.mark.unit + def test_rnnt_vad_endpointing_enabled_no_eou(self): + rnnt_endpointing = RNNTGreedyEndpointing( + vocabulary=["a", "b", "c"], + ms_per_timestep=100, + effective_buffer_size_in_secs=2.0, # VAD enabled + stop_history_eou=100, + ) + + # Test with VAD segments that don't trigger EOU + vad_segments = torch.tensor([[0.0, 1.45], [1.5, 2.0]]) + eou_detected, eou_detected_at = rnnt_endpointing.detect_eou_vad(vad_segments, stop_history_eou=100) + + assert eou_detected == False + assert eou_detected_at == -1 + + @pytest.mark.unit + def test_rnnt_vad_endpointing_enabled_with_eou(self): + rnnt_endpointing = RNNTGreedyEndpointing( + vocabulary=["a", "b", "c"], + ms_per_timestep=100, + effective_buffer_size_in_secs=2.0, # VAD enabled + stop_history_eou=100, + ) + + # Test with VAD segments that should trigger EOU + # Create segments with enough silence to trigger EOU + vad_segments = torch.tensor([[0.0, 0.5], [1.0, 2.0]]) # Gap of 0.5s between segments + eou_detected, eou_detected_at = rnnt_endpointing.detect_eou_vad(vad_segments, stop_history_eou=100) + + # This should detect EOU if the silence gap is sufficient + # The exact 
behavior depends on the VAD logic implementation + assert eou_detected == True + assert eou_detected_at == 5 + + @pytest.mark.unit + def test_rnnt_vad_endpointing_enabled_with_eou_at_end(self): + rnnt_endpointing = RNNTGreedyEndpointing( + vocabulary=["a", "b", "c"], + ms_per_timestep=100, + effective_buffer_size_in_secs=2.0, # VAD enabled + stop_history_eou=100, + ) + + # Test with VAD segments that should trigger EOU + # Create segments with enough silence to trigger EOU + vad_segments = torch.tensor([[0.0, 0.5], [1.0, 1.8]]) # Gap of 0.5s between segments + eou_detected, eou_detected_at = rnnt_endpointing.detect_eou_vad(vad_segments, stop_history_eou=100) + + # This should detect EOU if the silence gap is sufficient + # The exact behavior depends on the VAD logic implementation + assert eou_detected == True + assert eou_detected_at == 18 diff --git a/tests/collections/asr/inference/test_itn.py b/tests/collections/asr/inference/test_itn.py new file mode 100644 index 000000000000..a0cb0d29cd8a --- /dev/null +++ b/tests/collections/asr/inference/test_itn.py @@ -0,0 +1,176 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer + + +@pytest.fixture(scope="module") +def en_itn_model(): + return AlignmentPreservingInverseNormalizer( + lang="en", input_case=AlignmentPreservingInverseNormalizer.LOWER_CASED, cache_dir=None + ) + + +@pytest.fixture(scope="module") +def de_itn_model(): + return AlignmentPreservingInverseNormalizer( + lang="de", input_case=AlignmentPreservingInverseNormalizer.LOWER_CASED, cache_dir=None + ) + + +@pytest.fixture(scope="module") +def es_itn_model(): + return AlignmentPreservingInverseNormalizer( + lang="es", input_case=AlignmentPreservingInverseNormalizer.LOWER_CASED, cache_dir=None + ) + + +class TestAlignmentPreservingInverseNormalizer: + + @pytest.mark.unit + def test_word_alignment_cardinal_en(self, en_itn_model): + text = "zzz minus twenty five thousand thirty seven zzz" + iwords, owords, alignment = en_itn_model.get_word_alignment(text, sep=" ") + assert iwords == ["zzz", "minus", "twenty", "five", "thousand", "thirty", "seven", "zzz"] + assert owords == ["zzz", "-25037", "zzz"] + assert alignment == [([0], [0], "name"), ([1, 2, 3, 4, 5, 6], [1], "cardinal"), ([7], [2], "name")] + + @pytest.mark.unit + def test_word_alignment_time_en(self, en_itn_model): + text = "zzz eleven fifty five p m zzz" + iwords, owords, alignment = en_itn_model.get_word_alignment(text, sep=" ") + assert iwords == ["zzz", "eleven", "fifty", "five", "p", "m", "zzz"] + assert owords == ["zzz", "11:55", "p.m.", "zzz"] + assert alignment == [([0], [0], "name"), ([1, 2, 3, 4, 5], [1, 2], "time"), ([6], [3], "name")] + + @pytest.mark.unit + def test_word_alignment_money_en(self, en_itn_model): + text = "zzz two hundred fifty dollars zzz" + iwords, owords, alignment = en_itn_model.get_word_alignment(text, sep=" ") + 
assert iwords == ["zzz", "two", "hundred", "fifty", "dollars", "zzz"] + assert owords == ["zzz", "$250", "zzz"] + assert alignment == [([0], [0], "name"), ([1, 2, 3, 4], [1], "money"), ([5], [2], "name")] + + @pytest.mark.unit + def test_word_alignment_combo_en(self, en_itn_model): + text = "eleven twenty seven fifty seven october twenty fourth nineteen seventy" + iwords, owords, alignment = en_itn_model.get_word_alignment(text, sep=" ") + assert iwords == [ + "eleven", + "twenty", + "seven", + "fifty", + "seven", + "october", + "twenty", + "fourth", + "nineteen", + "seventy", + ] + assert owords == ["1120", "07:57", "october", "24", "1970"] + assert alignment == [([0, 1], [0], "date"), ([2, 3, 4], [1], "time"), ([5, 6, 7, 8, 9], [2, 3, 4], "date")] + + @pytest.mark.unit + def test_word_alignment_measure_en(self, en_itn_model): + text = "it is two hundred fifty meters long" + iwords, owords, alignment = en_itn_model.get_word_alignment(text, sep=" ") + assert iwords == ["it", "is", "two", "hundred", "fifty", "meters", "long"] + assert owords == ["it", "is", "250", "m", "long"] + assert alignment == [ + ([0], [0], "name"), + ([1], [1], "name"), + ([2, 3, 4, 5], [2, 3], "measure"), + ([6], [4], "name"), + ] + + @pytest.mark.unit + def test_word_alignment_sterling_en(self, en_itn_model): + text = "trade turnover of three million pounds sterling" + iwords, owords, alignment = en_itn_model.get_word_alignment(text, sep=" ") + assert iwords == ["trade", "turnover", "of", "three", "million", "pounds", "sterling"] + assert owords == ["trade", "turnover", "of", "£3", "million"] + assert alignment == [ + ([0], [0], "name"), + ([1], [1], "name"), + ([2], [2], "name"), + ([3, 4, 5, 6], [3, 4], "money"), + ] + + @pytest.mark.unit + def test_word_alignment_time_de(self, de_itn_model): + text = "zzz drei uhr zwanzig zzz" + iwords, owords, alignment = de_itn_model.get_word_alignment(text, sep=" ") + assert iwords == ["zzz", "drei", "uhr", "zwanzig", "zzz"] + assert owords == ['zzz', '03:20', 'Uhr', 'zzz'] + assert alignment == [([0], [0], "name"), ([1, 2, 3], [1, 2], "time"), ([4], [3], "name")] + + @pytest.mark.unit + def test_word_alignment_money_de(self, de_itn_model): + text = "zzz zwei hundert fünfzig dollar zzz" + iwords, owords, alignment = de_itn_model.get_word_alignment(text, sep=" ") + assert iwords == ["zzz", "zwei", "hundert", "fünfzig", "dollar", "zzz"] + assert owords == ["zzz", "$250", "zzz"] + assert alignment == [([0], [0], "name"), ([1, 2, 3, 4], [1], "money"), ([5], [2], "name")] + + @pytest.mark.unit + def test_word_alignment_cardinal_de(self, de_itn_model): + text = "zzz minus fünfundzwanzigtausendsiebenunddreißig zzz" + iwords, owords, alignment = de_itn_model.get_word_alignment(text, sep=" ") + assert iwords == ["zzz", "minus", "fünfundzwanzigtausendsiebenunddreißig", "zzz"] + assert owords == ["zzz", "-25037", "zzz"] + assert alignment == [([0], [0], "name"), ([1, 2], [1], "cardinal"), ([3], [2], "name")] + + @pytest.mark.unit + def test_word_alignment_measure_de(self, de_itn_model): + text = "es ist zweihundertfünfzig meter lang" + iwords, owords, alignment = de_itn_model.get_word_alignment(text, sep=" ") + assert iwords == ["es", "ist", "zweihundertfünfzig", "meter", "lang"] + assert owords == ["es", "ist", "250", "m", "lang"] + assert alignment == [([0], [0], "name"), ([1], [1], "name"), ([2, 3], [2, 3], "measure"), ([4], [4], "name")] + + @pytest.mark.unit + def test_word_alignment_combo_es(self, es_itn_model): + text = "un mil intereses al diez por ciento a la semana estándar 
derecho diez por ciento" + iwords, owords, alignment = es_itn_model.get_word_alignment(text, sep=" ") + assert iwords == [ + 'un', + 'mil', + 'intereses', + 'al', + 'diez', + 'por', + 'ciento', + 'a', + 'la', + 'semana', + 'estándar', + 'derecho', + 'diez', + 'por', + 'ciento', + ] + assert owords == ['1000', 'intereses', 'al', '10', '%', 'a', 'la', 'semana', 'estándar', 'derecho', '10', '%'] + assert alignment == [ + ([0, 1], [0], 'cardinal'), + ([2], [1], 'name'), + ([3], [2], 'name'), + ([4, 5, 6], [3, 4], 'measure'), + ([7], [5], 'name'), + ([8], [6], 'name'), + ([9], [7], 'name'), + ([10], [8], 'name'), + ([11], [9], 'name'), + ([12, 13, 14], [10, 11], 'measure'), + ] diff --git a/tests/collections/asr/inference/test_itn_utils.py b/tests/collections/asr/inference/test_itn_utils.py new file mode 100644 index 000000000000..911ffa6316a3 --- /dev/null +++ b/tests/collections/asr/inference/test_itn_utils.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from nemo.collections.asr.inference.utils.itn_utils import ( + fallback_to_trivial_alignment, + find_tokens, + get_semiotic_class, + get_trivial_alignment, + split_text, +) + + +class TestItnUtils: + + @pytest.mark.unit + @pytest.mark.parametrize( + "text, expected_words, expected_n", + [ + ("hello world how are you", ["hello", "world", "how", "are", "you"], 5), + ("hello", ["hello"], 1), + ("a hello world b ccc d e", ["a", "hello", "world", "b", "ccc", "d", "e"], 7), + (" a hello world b ccc d e", ["a", "hello", "world", "b", "ccc", "d", "e"], 7), + ("a hello world b ccc d e ", ["a", "hello", "world", "b", "ccc", "d", "e"], 7), + (" a hello world b ccc d e ", ["a", "hello", "world", "b", "ccc", "d", "e"], 7), + (" a hello world b ccc d e ", ["a", "hello", "world", "b", "ccc", "d", "e"], 7), + ], + ) + def test_split_text(self, text, expected_words, expected_n): + words, n = split_text(text) + assert words == expected_words + assert n == expected_n + + @pytest.mark.unit + def test_get_semiotic_class(self): + tokens = [{"tokens": {"name": "hello"}}] + semiotic_class = get_semiotic_class(tokens) + assert semiotic_class == "name" + + @pytest.mark.unit + def test_find_tokens(self): + text = "tokens {name: hello} tokens {name: world} tokens {name: how} tokens {name: are} tokens {name: you}" + tokens = find_tokens(text) + assert tokens == [ + "tokens {name: hello}", + "tokens {name: world}", + "tokens {name: how}", + "tokens {name: are}", + "tokens {name: you}", + ] + + @pytest.mark.unit + def test_get_trivial_alignment(self): + N = 5 + i_shift = 1 + o_shift = 2 + alignment = get_trivial_alignment(N, i_shift, o_shift) + assert alignment == [ + ([1], [2], "name"), + ([2], [3], "name"), + ([3], [4], "name"), + ([4], [5], "name"), + ([5], [6], "name"), + ] + + @pytest.mark.unit + def test_fallback_to_trivial_alignment(self): + input_words = ["hello", "world", "how", "are", "you"] + input_words, output_words, word_alignment = 
fallback_to_trivial_alignment(input_words) + assert input_words == ["hello", "world", "how", "are", "you"] + assert output_words == ["hello", "world", "how", "are", "you"] + assert word_alignment == [ + ([0], [0], "name"), + ([1], [1], "name"), + ([2], [2], "name"), + ([3], [3], "name"), + ([4], [4], "name"), + ] diff --git a/tests/collections/asr/inference/test_pipeline_utils.py b/tests/collections/asr/inference/test_pipeline_utils.py new file mode 100644 index 000000000000..a4042064ee5a --- /dev/null +++ b/tests/collections/asr/inference/test_pipeline_utils.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import re + +import pytest +import torch + +from nemo.collections.asr.inference.utils.pipeline_utils import ( + check_existance_of_required_attributes, + drop_trailing_features, + get_leading_punctuation_regex_pattern, + get_repeated_punctuation_regex_pattern, +) + + +class TestPipelineUtils: + + @pytest.mark.unit + def test_drop_trailing_features(self): + x = torch.randn(10, 10, 20) + expected_feature_buffer_len = 15 + x_dropped = drop_trailing_features(x, expected_feature_buffer_len) + assert x_dropped.shape == (10, 10, 15) + assert x_dropped.allclose(x[:, :, :15]) + + @pytest.mark.unit + @pytest.mark.parametrize( + "text, expected_text", + [ + ("", ""), + (" ", " "), + ("simple text", "simple text"), + ("just a 2nd . Yeah, I hope", "just a 2nd. Yeah, I hope"), + ("Hello , world ! How are you ?", "Hello, world! How are you?"), + ("The quick, brown fox jumps ? over the lazy ! dog.", "The quick, brown fox jumps? over the lazy! dog."), + ], + ) + def test_remove_leading_punctuation_spaces(self, text, expected_text): + puncts = {"!", "?", ".", ","} + pattern = get_leading_punctuation_regex_pattern(puncts) + assert re.sub(pattern, r'\1', text) == expected_text + + @pytest.mark.unit + @pytest.mark.parametrize( + "text, expected_text", + [ + ("", ""), + (" ", " "), + ("simple text", "simple text"), + ("Hello, world!! How are you???", "Hello, world! How are you?"), + ("The quick,, brown fox jumps? over the lazy! dog..", "The quick, brown fox jumps? over the lazy! dog."), + ], + ) + def test_remove_repeated_punctuation(self, text, expected_text): + puncts = {"!", "?", ".", ","} + pattern = get_repeated_punctuation_regex_pattern(puncts) + assert re.sub(pattern, r'\1', text) == expected_text + + @pytest.mark.unit + def test_check_existance_of_required_attributes(self): + class TestClass: + pass + + with pytest.raises(ValueError): + check_existance_of_required_attributes(TestClass, ['test_attr']) diff --git a/tests/collections/asr/inference/test_state_management_utils.py b/tests/collections/asr/inference/test_state_management_utils.py new file mode 100644 index 000000000000..5b3f9db6626c --- /dev/null +++ b/tests/collections/asr/inference/test_state_management_utils.py @@ -0,0 +1,109 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from nemo.collections.asr.inference.utils.state_management_utils import ( + detect_overlap, + find_max_overlap, + merge_segment_tail, + merge_timesteps, + merge_word_tail, +) +from nemo.collections.asr.inference.utils.text_segment import TextSegment, Word + + +class TestStateManagementUtils: + + @pytest.mark.unit + @pytest.mark.parametrize( + "timesteps1, timesteps2, expected_merged_timesteps", + [ + ([0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7]), + ([0, 1, 2, 3], [], [0, 1, 2, 3]), + ([], [4, 5, 6, 7], [4, 5, 6, 7]), + ([-1, 0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7]), + ([-3, 1, 2, 3], [], [0, 4, 5, 6]), + ([], [-3, 1, 2, 3], [0, 4, 5, 6]), + ], + ) + def test_merge_timesteps(self, timesteps1, timesteps2, expected_merged_timesteps): + merged_timesteps = merge_timesteps(timesteps1, timesteps2) + assert merged_timesteps == expected_merged_timesteps + + @pytest.mark.unit + @pytest.mark.parametrize( + "state_tokens, new_tokens, limit, expected_max_overlap", + [ + ([0, 1, 2, 3], [2, 3, 4, 5], 4, 2), + ([0, 2, 3, 4], [2, 3, 4, 5], 4, 3), + ([0, 0, 0, 1], [2, 3, 4, 5], 4, 0), + ], + ) + def test_find_max_overlap(self, state_tokens, new_tokens, limit, expected_max_overlap): + max_overlap = find_max_overlap(state_tokens, new_tokens, limit) + assert max_overlap == expected_max_overlap + + @pytest.mark.unit + @pytest.mark.parametrize( + "state_tokens, state_timesteps, new_tokens, new_timesteps, expected_overlap", + [ + ([0, 1, 2, 3], [0.0, 1.0, 2.0, 3.0], [2, 3, 4, 5], [2.0, 3.0, 4.0, 5.0], 2), + ([0, 1, 2, 3], [0.0, 1.0, 2.0, 3.0], [2, 3, 4, 5], [1.0, 2.0, 4.0, 5.0], 2), + ([0, 1, 2, 3], [0.0, 1.0, 2.0, 3.0], [2, 3, 4, 5], [5.0, 7.0, 8.0, 9.0], 0), + ], + ) + def test_detect_overlap(self, state_tokens, state_timesteps, new_tokens, new_timesteps, expected_overlap): + overlap = detect_overlap(state_tokens, state_timesteps, new_tokens, new_timesteps) + assert overlap == expected_overlap + + @pytest.mark.unit + def test_merge_word_tail_without_pnc(self): + word_head = Word(text="meaning", start=0.0, end=1.0, conf=0.5) + word_tail = Word(text="ful", start=1.0, end=2.0, conf=0.6) + head, _ = merge_word_tail(word_head, word_tail, conf_aggregator=min) + + assert head.text == "meaningful" + assert head.start == 0.0 + assert head.end == 2.0 + assert head.conf == 0.5 + + @pytest.mark.unit + def test_merge_word_tail_with_pnc(self): + + word_head = Word(text="meaning", start=0.0, end=1.0, conf=0.5) + word_tail = Word(text="s", start=1.0, end=2.0, conf=0.6) + pnc_head = Word(text="Meaning?", start=0.0, end=1.0, conf=0.5) + new_head, new_pnc_head = merge_word_tail(word_head, word_tail, conf_aggregator=min, pnc_word_head=pnc_head) + + assert new_head.text == "meanings" + assert new_head.start == 0.0 + assert new_head.end == 2.0 + assert new_head.conf == 0.5 + assert new_pnc_head.text == "Meanings?" 
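+        # The PnC head keeps its capitalization and trailing punctuation while absorbing the
+        # tail's suffix ("Meaning?" + "s" -> "Meanings?"); the merged timing spans both words
+        # and the confidence comes from the aggregator (min of 0.5 and 0.6).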
+ assert new_pnc_head.start == 0.0 + assert new_pnc_head.end == 2.0 + assert new_pnc_head.conf == 0.5 + + @pytest.mark.unit + def test_merge_segment_tail(self): + seg1 = TextSegment(text="Good morn", start=0.0, end=1.0, conf=0.5) + seg2 = TextSegment(text="ing", start=1.0, end=2.0, conf=0.6) + merged_seg = merge_segment_tail(seg1, seg2, conf_aggregator=min) + + assert merged_seg.text == "Good morning" + assert merged_seg.start == 0.0 + assert merged_seg.end == 2.0 + assert merged_seg.conf == 0.5 diff --git a/tests/collections/asr/inference/test_text_segment.py b/tests/collections/asr/inference/test_text_segment.py new file mode 100644 index 000000000000..49f9da5d5fb8 --- /dev/null +++ b/tests/collections/asr/inference/test_text_segment.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from nemo.collections.asr.inference.utils.text_segment import ( + TextSegment, + Word, + join_segments, + normalize_segments_inplace, +) + + +class TestTextSegment: + + @pytest.mark.unit + @pytest.mark.parametrize("text, expected_text", [("Hello!", "hello"), ("HeLLo!", "hello")]) + def test_normalize_text_inplace(self, text, expected_text): + for cls in [Word, TextSegment]: + text_segment = cls(text, 0, 1, 0.5) + text_segment.normalize_text_inplace(punct_marks='!', sep=' ') + assert text_segment.text == expected_text + + @pytest.mark.unit + @pytest.mark.parametrize("text, expected_text", [("Hello!", "hello"), ("HeLLo!", "hello")]) + def test_with_normalized_text(self, text, expected_text): + for cls in [Word, TextSegment]: + text_segment = cls(text, 0, 1, 0.5) + text_segment_copy = text_segment.with_normalized_text(punct_marks='!', sep=' ') + assert text_segment_copy.text == expected_text + assert text_segment.text == text + + @pytest.mark.unit + def test_join_segments(self): + for cls in [Word, TextSegment]: + segments = [ + [cls('hello', 0, 1, 0.5), cls('world', 1, 2, 0.5)], + [cls('how', 2, 3, 0.5), cls('are', 3, 4, 0.5), cls('you', 4, 5, 0.5)], + ] + transcriptions = join_segments(segments, sep=' ') + assert transcriptions == ['hello world', 'how are you'] + + @pytest.mark.unit + def test_normalize_segments_inplace(self): + for cls in [Word, TextSegment]: + segments = [cls('Hello!', 0, 1, 0.5), cls('world?', 1, 2, 0.5)] + normalize_segments_inplace(segments, punct_marks=set("!?"), sep=' ') + assert segments[0].text == 'hello' + assert segments[1].text == 'world' + + @pytest.mark.unit + @pytest.mark.parametrize("text, expected_text", [("hello", "Hello"), ("World!", "World!")]) + def test_capitalize(self, text, expected_text): + for cls in [Word, TextSegment]: + text_segment = cls(text, 0, 1, 0.5) + text_segment.capitalize() + assert text_segment.text == expected_text diff --git a/tests/functional_tests/L2_Speech_Transcription_Streaming_Inference.sh b/tests/functional_tests/L2_Speech_Transcription_Streaming_Inference.sh new file mode 100644 index 000000000000..cb118359366c --- /dev/null +++ 
b/tests/functional_tests/L2_Speech_Transcription_Streaming_Inference.sh @@ -0,0 +1,34 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/asr/asr_streaming_inference/asr_streaming_infer.py \ + --config-path="../conf/asr_streaming_inference/" \ + --config-name=buffered_ctc.yaml \ + audio_file="/home/TestData/an4_transcribe/test_subset/" \ + output_filename="/tmp/buffered_ctc_test_res.json" \ + output_dir="/tmp/buffered_ctc_test_dir" \ + lang=en \ + enable_pnc=False \ + enable_itn=False \ + asr_output_granularity=segment + +coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/asr/asr_streaming_inference/asr_streaming_infer.py \ + --config-path="../conf/asr_streaming_inference/" \ + --config-name=buffered_rnnt.yaml \ + audio_file="/home/TestData/an4_transcribe/test_subset/" \ + output_filename="/tmp/buffered_rnnt_test_res.json" \ + output_dir="/tmp/buffered_rnnt_test_dir" \ + lang=en \ + enable_pnc=False \ + enable_itn=False \ + asr_output_granularity=segment
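+
+# Illustrative only, not executed by CI: the same entry point also advertises ITN and
+# word-level output. A minimal sketch, under the assumption that buffered_ctc.yaml exposes
+# the same options used above and that "word" is a valid asr_output_granularity value:
+#
+# python examples/asr/asr_streaming_inference/asr_streaming_infer.py \
+#     --config-path="../conf/asr_streaming_inference/" \
+#     --config-name=buffered_ctc.yaml \
+#     audio_file="/path/to/16khz_mono_wavs_or_manifest" \
+#     output_filename="/tmp/streaming_word_itn.json" \
+#     lang=en \
+#     enable_pnc=False \
+#     enable_itn=True \
+#     asr_output_granularity=word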