Commit 67d32eb

Enhance inline documentation. Set compute_dtype=float32 by default.
Signed-off-by: Vladimir Bataev <[email protected]>
1 parent: 4a05ae7

File tree: 1 file changed, +24 −6 lines

examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py

@@ -23,7 +23,7 @@
 
 ## To evaluate a model in cache-aware streaming mode on a single audio file:
 
-python speech_to_text_streaming_infer.py \
+python speech_to_text_cache_aware_streaming_infer.py \
     model_path=asr_model.nemo \
     audio_file=audio_file.wav \
     compare_vs_offline=true \
@@ -32,20 +32,37 @@
 
 ## To evaluate a model in cache-aware streaming mode on a manifest file:
 
-python speech_to_text_streaming_infer.py \
+python speech_to_text_cache_aware_streaming_infer.py \
     model_path=asr_model.nemo \
     dataset_manifest=manifest_file.json \
     batch_size=16 \
     compare_vs_offline=true \
     amp=true \
     debug_mode=true
 
+## It is also possible to use phrase boosting or an external LM with cache-aware models:
+
+python speech_to_text_cache_aware_streaming_infer.py \
+    model_path=asr_model.nemo \
+    dataset_manifest=manifest_file.json \
+    batch_size=16 \
+    rnnt_decoding.greedy.boosting_tree.key_phrases_file=key_words_list.txt \
+    rnnt_decoding.greedy.boosting_tree_alpha=1.0 \
+    rnnt_decoding.greedy.ngram_lm_model=lm_model.nemo \
+    rnnt_decoding.greedy.ngram_lm_alpha=0.5 \
+    compare_vs_offline=true \
+    amp=true \
+    debug_mode=true
+
 You may drop 'debug_mode' and 'compare_vs_offline' to speed up the streaming evaluation.
 If compare_vs_offline is not used, then a significantly larger batch_size can be used.
 Setting `pad_and_drop_preencoded` would perform the caching for all steps including the first step.
 It may result in slightly different outputs from the sub-sampling module compared to offline mode for some techniques like striding and sw_striding.
 Enabling it would make it easier to export the model to ONNX.
 
+For customization details (key phrases list, n-gram LM), see the documentation:
+https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/asr_language_modeling_and_customization.html
+
 ## Hybrid ASR models
 For Hybrid ASR models, which have two decoders, you may select the decoder by setting decoder_type=DECODER_TYPE, where DECODER_TYPE can be "ctc" or "rnnt".
 If decoder_type is not set, then the default decoder would be used, which is the RNNT decoder for Hybrid ASR models.
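
Note: per the docstring above, dropping `debug_mode` and `compare_vs_offline` speeds up evaluation and permits a larger batch size. A purely illustrative invocation combining that with the `pad_and_drop_preencoded` flag discussed above (the batch_size value here is arbitrary, not a recommendation from this commit):

python speech_to_text_cache_aware_streaming_infer.py \
    model_path=asr_model.nemo \
    dataset_manifest=manifest_file.json \
    batch_size=64 \
    pad_and_drop_preencoded=true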
@@ -66,7 +83,7 @@
 The following command would simulate cache-aware streaming on a pretrained model from NGC with chunk_size of 100, shift_size of 50 and 2 left chunks as left context.
 The chunk_size of 100 would be 100*4*10=4000ms for a model with 4x downsampling and 10ms shift in feature extraction.
 
-python speech_to_text_streaming_infer.py \
+python speech_to_text_cache_aware_streaming_infer.py \
     pretrained_name=stt_en_conformer_ctc_large \
     chunk_size=100 \
     shift_size=50 \
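
Note: the duration arithmetic quoted in this hunk (100*4*10 = 4000 ms) can be sanity-checked in a few lines of Python. This is a standalone sketch, not code from the script; `chunk_duration_ms` is a hypothetical helper:

def chunk_duration_ms(chunk_size: int, downsampling: int = 4, feature_shift_ms: int = 10) -> int:
    # each encoder frame spans `downsampling` feature frames of `feature_shift_ms` milliseconds each
    return chunk_size * downsampling * feature_shift_ms

assert chunk_duration_ms(100) == 4000  # chunk_size=100 -> 4000 ms, as stated above
assert chunk_duration_ms(50) == 2000   # shift_size=50 -> the window advances 2000 ms per step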
@@ -147,8 +164,9 @@ class TranscriptionConfig:
     allow_mps: bool = False  # allow to select MPS device (Apple Silicon M-series GPU)
     amp: bool = False
     amp_dtype: str = "float16"  # can be set to "float16" or "bfloat16" when using amp
+    # NB: default compute_dtype is float32, since cache-aware models currently do not work with other dtypes
     compute_dtype: Optional[str] = (
-        None  # "float32", "bfloat16" or "float16"; if None (default): bfloat16 if available else float32
+        "float32"  # "float32" (default), "bfloat16" or "float16"; if None: bfloat16 if available else float32
     )
     matmul_precision: str = "high"  # Literal["highest", "high", "medium"]
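
Note: the new inline comment pins the default to "float32" while keeping the documented None fallback ("bfloat16 if available else float32"). A minimal sketch of that resolution logic, assuming a hypothetical `resolve_compute_dtype` helper (not the script's actual code):

from typing import Optional

import torch

def resolve_compute_dtype(compute_dtype: Optional[str]) -> torch.dtype:
    # explicit strings map directly to torch dtypes
    if compute_dtype is not None:
        return {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[compute_dtype]
    # if None: bfloat16 when the GPU supports it, else float32 (the fallback named in the comment)
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32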

@@ -306,9 +324,9 @@ def main(cfg: TranscriptionConfig):
     if compute_dtype != torch.float32:
         # NB: cache-aware models do not currently work with compute_dtype != float32
         # since in some layers output is force-casted to float32
-        # TODO(vbataev): implement support in future
+        # TODO(vbataev): implement support in future; set `compute_dtype` in config to None by default
         raise NotImplementedError(
-            f"Compute dtype {cfg.compute_dtype} is not yet supported for cache-aware models, use float32 instead"
+            f"Compute dtype {compute_dtype} is not yet supported for cache-aware models, use float32 instead"
         )
 
     if (cfg.audio_file is None and cfg.dataset_manifest is None) or (
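
Note: the f-string fix in this hunk replaces `cfg.compute_dtype` (the raw config string, which may be None) with `compute_dtype` (the resolved torch.dtype the script actually uses), so the error reports the dtype in effect. A small illustration with hypothetical values:

import torch

cfg_compute_dtype = None        # raw config value; the old message would have printed "None"
compute_dtype = torch.bfloat16  # resolved dtype, e.g. via the None -> bfloat16 fallback
print(f"Compute dtype {compute_dtype} is not yet supported for cache-aware models, use float32 instead")
# Compute dtype torch.bfloat16 is not yet supported for cache-aware models, use float32 instead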
