Skip to content

Commit fe13830

Browse files
Updates
Signed-off-by: Nune <[email protected]>
1 parent 101e8da commit fe13830

File tree

3 files changed

+55
-4
lines changed

3 files changed

+55
-4
lines changed

nemo/collections/asr/parts/utils/chunking_utils.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from difflib import SequenceMatcher
16-
from typing import List, Optional, Tuple
16+
from typing import List, Optional, Tuple, Union
1717

1818
import torch
1919

@@ -395,10 +395,14 @@ def ensure_char_token(entry):
395395
# When return_hypotheses is True, y_sequence contains logits (2D: [T, V]).
396396
if return_hypotheses and hasattr(hypotheses[0], 'token_sequence') and hypotheses[0].token_sequence is not None:
397397
merged_hypotheses.y_sequence = torch.cat([hyp.y_sequence for hyp in hypotheses], dim=0)
398-
merged_hypotheses.token_sequence = torch.tensor(merged_tokens, dtype=torch.long)
398+
merged_hypotheses.token_sequence = torch.tensor(merged_tokens)
399399
else:
400400
merged_hypotheses.y_sequence = torch.tensor(merged_tokens)
401401

402+
merged_alignments = join_alignments(hypotheses)
403+
if merged_alignments is not None:
404+
merged_hypotheses.alignments = merged_alignments
405+
402406
merged_hypotheses = join_confidence_values(merged_hypotheses, hypotheses)
403407
merged_hypotheses.text = final_text
404408
# Set score to minimum of all chunk scores, length to sum of all chunk lengths
@@ -504,6 +508,48 @@ def update_timestamps(hypotheses, tokenizer=None, timestamps_type=None, lang_id=
504508
return hypotheses
505509

506510

511+
def join_alignments(
512+
hypotheses: List[Hypothesis],
513+
) -> Optional[Union[torch.Tensor, List]]:
514+
"""
515+
Concatenate alignments from multiple chunk hypotheses into a single sequence.
516+
517+
Supports both CTC alignments (1D: list of ints or tensor) and RNNT alignments
518+
(2D: list of lists, one inner list per time step). If any hypothesis has no
519+
alignments, returns None and the caller should leave merged alignments unset.
520+
521+
Args:
522+
hypotheses: List of Hypothesis objects, each possibly having an alignments attribute.
523+
524+
Returns:
525+
Concatenated alignments (tensor or list), or None if any hypothesis has no alignments.
526+
"""
527+
if not hypotheses:
528+
return None
529+
if not all(getattr(h, 'alignments', None) is not None for h in hypotheses):
530+
return None
531+
532+
alignments_list = [h.alignments for h in hypotheses]
533+
534+
# CTC: alignments are a 1D tensor
535+
if isinstance(alignments_list[0], torch.Tensor):
536+
return torch.cat(alignments_list, dim=0)
537+
538+
# RNNT: alignments are list of lists (one list per time step T)
539+
first_nonempty = next((a for a in alignments_list if len(a) > 0), None)
540+
if first_nonempty is not None and isinstance(first_nonempty[0], (list, tuple)):
541+
result = []
542+
for a in alignments_list:
543+
result.extend(a)
544+
return result
545+
546+
# CTC: alignments are a flat list of ints
547+
result = []
548+
for a in alignments_list:
549+
result.extend(a.tolist() if isinstance(a, torch.Tensor) else a)
550+
return result
551+
552+
507553
def join_confidence_values(merged_hypothesis, hypotheses):
508554
"""
509555
Concatenate confidence values from multiple hypotheses into a single sequence.
@@ -1116,6 +1162,11 @@ def merge_hypotheses_of_same_audio(
11161162
if valid_y_sequences:
11171163
merged_hypothesis.y_sequence = torch.cat(valid_y_sequences)
11181164

1165+
# Merge alignments from all hypotheses (CTC: 1D tensor/list; RNNT: list of lists)
1166+
merged_alignments = join_alignments(hypotheses_list)
1167+
if merged_alignments is not None:
1168+
merged_hypothesis.alignments = merged_alignments
1169+
11191170
# Merge confidence values from all hypotheses
11201171
merged_hypothesis = join_confidence_values(merged_hypothesis, hypotheses_list)
11211172
# Set score to minimum of all chunk scores, length to sum of all chunk lengths

tests/collections/asr/mixins/test_transcription.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def test_transcribe_return_hypothesis(self, test_data_dir, fast_conformer_ctc_mo
321321

322322
# Audio file test
323323
#setting enable_chunking False for alignment check
324-
outputs = fast_conformer_ctc_model.transcribe(audio_file, batch_size=1, return_hypotheses=True, enable_chunking=False)
324+
outputs = fast_conformer_ctc_model.transcribe(audio_file, batch_size=1, return_hypotheses=True)
325325
assert len(outputs) == 1
326326
assert isinstance(outputs[0], Hypothesis)
327327

tests/collections/asr/test_asr_context_biasing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_run_word_spotter(self, test_data_dir, conformer_ctc_bpe_model):
5959
target_text = "nineteen"
6060
target_tokenization = asr_model.tokenizer.text_to_ids(target_text)
6161
ctc_logprobs = (
62-
asr_model.transcribe([audio_file_path], batch_size=1, return_hypotheses=True, enable_chunking=False)[0].alignments.cpu().numpy()
62+
asr_model.transcribe([audio_file_path], batch_size=1, return_hypotheses=True)[0].alignments.cpu().numpy()
6363
)
6464
context_biasing_list = [[target_text, [target_tokenization]]]
6565
context_graph = context_biasing.ContextGraphCTC(blank_id=asr_model.decoding.blank_id)

0 commit comments

Comments
 (0)