|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | from difflib import SequenceMatcher |
16 | | -from typing import List, Optional, Tuple |
| 16 | +from typing import List, Optional, Tuple, Union |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 |
|
@@ -395,10 +395,14 @@ def ensure_char_token(entry): |
395 | 395 | # When return_hypotheses is True, y_sequence contains logits (2D: [T, V]). |
396 | 396 | if return_hypotheses and hasattr(hypotheses[0], 'token_sequence') and hypotheses[0].token_sequence is not None: |
397 | 397 | merged_hypotheses.y_sequence = torch.cat([hyp.y_sequence for hyp in hypotheses], dim=0) |
398 | | - merged_hypotheses.token_sequence = torch.tensor(merged_tokens, dtype=torch.long) |
| 398 | + merged_hypotheses.token_sequence = torch.tensor(merged_tokens) |
399 | 399 | else: |
400 | 400 | merged_hypotheses.y_sequence = torch.tensor(merged_tokens) |
401 | 401 |
|
| 402 | + merged_alignments = join_alignments(hypotheses) |
| 403 | + if merged_alignments is not None: |
| 404 | + merged_hypotheses.alignments = merged_alignments |
| 405 | + |
402 | 406 | merged_hypotheses = join_confidence_values(merged_hypotheses, hypotheses) |
403 | 407 | merged_hypotheses.text = final_text |
404 | 408 | # Set score to minimum of all chunk scores, length to sum of all chunk lengths |
@@ -504,6 +508,48 @@ def update_timestamps(hypotheses, tokenizer=None, timestamps_type=None, lang_id= |
504 | 508 | return hypotheses |
505 | 509 |
|
506 | 510 |
|
| 511 | +def join_alignments( |
| 512 | + hypotheses: List[Hypothesis], |
| 513 | +) -> Optional[Union[torch.Tensor, List]]: |
| 514 | + """ |
| 515 | + Concatenate alignments from multiple chunk hypotheses into a single sequence. |
| 516 | +
|
| 517 | + Supports both CTC alignments (1D: list of ints or tensor) and RNNT alignments |
| 518 | + (2D: list of lists, one inner list per time step). If any hypothesis has no |
| 519 | + alignments, returns None and the caller should leave merged alignments unset. |
| 520 | +
|
| 521 | + Args: |
| 522 | + hypotheses: List of Hypothesis objects, each possibly having an alignments attribute. |
| 523 | +
|
| 524 | + Returns: |
| 525 | + Concatenated alignments (tensor or list), or None if any hypothesis has no alignments. |
| 526 | + """ |
| 527 | + if not hypotheses: |
| 528 | + return None |
| 529 | + if not all(getattr(h, 'alignments', None) is not None for h in hypotheses): |
| 530 | + return None |
| 531 | + |
| 532 | + alignments_list = [h.alignments for h in hypotheses] |
| 533 | + |
| 534 | + # CTC: alignments are a 1D tensor |
| 535 | + if isinstance(alignments_list[0], torch.Tensor): |
| 536 | + return torch.cat(alignments_list, dim=0) |
| 537 | + |
| 538 | + # RNNT: alignments are list of lists (one list per time step T) |
| 539 | + first_nonempty = next((a for a in alignments_list if len(a) > 0), None) |
| 540 | + if first_nonempty is not None and isinstance(first_nonempty[0], (list, tuple)): |
| 541 | + result = [] |
| 542 | + for a in alignments_list: |
| 543 | + result.extend(a) |
| 544 | + return result |
| 545 | + |
| 546 | + # CTC: alignments are a flat list of ints |
| 547 | + result = [] |
| 548 | + for a in alignments_list: |
| 549 | + result.extend(a.tolist() if isinstance(a, torch.Tensor) else a) |
| 550 | + return result |
| 551 | + |
| 552 | + |
507 | 553 | def join_confidence_values(merged_hypothesis, hypotheses): |
508 | 554 | """ |
509 | 555 | Concatenate confidence values from multiple hypotheses into a single sequence. |
@@ -1116,6 +1162,11 @@ def merge_hypotheses_of_same_audio( |
1116 | 1162 | if valid_y_sequences: |
1117 | 1163 | merged_hypothesis.y_sequence = torch.cat(valid_y_sequences) |
1118 | 1164 |
|
| 1165 | + # Merge alignments from all hypotheses (CTC: 1D tensor/list; RNNT: list of lists) |
| 1166 | + merged_alignments = join_alignments(hypotheses_list) |
| 1167 | + if merged_alignments is not None: |
| 1168 | + merged_hypothesis.alignments = merged_alignments |
| 1169 | + |
1119 | 1170 | # Merge confidence values from all hypotheses |
1120 | 1171 | merged_hypothesis = join_confidence_values(merged_hypothesis, hypotheses_list) |
1121 | 1172 | # Set score to minimum of all chunk scores, length to sum of all chunk lengths |
|
0 commit comments