Commit 09e5724
authored
[CUDA] Fix beam search of num_beams > 32 (#23599)
### Description
* Pass topk_scores to beam scorer in slow topk path.
* Add an env variable `ORT_BEAM_SEARCH_USE_FAST_TOPK` to enable/disable fast topk.
* Add a test case for slow topk path.
### Motivation and Context
This bug was introduced in
#16272
Beam search uses fast cuda kernel when number of beams <= 32. When beam
size is larger than that threshold, we use another code path (slower
cuda kernel) to get topk. In such `slow topk path`, topk_scores shall be
passed to beam scorer but it is not.
This bug will cause incorrect result when num_beams > 32. It was not
found previously since such large beam size is rarely used.1 parent 82840f6 commit 09e5724
File tree
4 files changed
+35
-15
lines changed- onnxruntime
- contrib_ops
- cpu/transformers
- cuda/transformers
- test/contrib_ops
4 files changed
+35
-15
lines changedLines changed: 5 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2 | 2 | | |
3 | 3 | | |
4 | 4 | | |
| 5 | + | |
5 | 6 | | |
6 | 7 | | |
7 | 8 | | |
| |||
136 | 137 | | |
137 | 138 | | |
138 | 139 | | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
139 | 143 | | |
| 144 | + | |
140 | 145 | | |
141 | 146 | | |
142 | 147 | | |
| |||
Lines changed: 6 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
199 | 199 | | |
200 | 200 | | |
201 | 201 | | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
202 | 205 | | |
203 | 206 | | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
204 | 210 | | |
205 | 211 | | |
206 | 212 | | |
Lines changed: 11 additions & 14 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
524 | 524 | | |
525 | 525 | | |
526 | 526 | | |
527 | | - | |
| 527 | + | |
| 528 | + | |
528 | 529 | | |
529 | 530 | | |
530 | 531 | | |
| |||
546 | 547 | | |
547 | 548 | | |
548 | 549 | | |
549 | | - | |
550 | | - | |
551 | | - | |
552 | | - | |
553 | | - | |
554 | | - | |
555 | | - | |
556 | 550 | | |
557 | 551 | | |
558 | 552 | | |
| |||
588 | 582 | | |
589 | 583 | | |
590 | 584 | | |
591 | | - | |
592 | | - | |
593 | | - | |
594 | | - | |
595 | | - | |
| 585 | + | |
596 | 586 | | |
597 | 587 | | |
598 | 588 | | |
599 | | - | |
| 589 | + | |
600 | 590 | | |
601 | 591 | | |
602 | 592 | | |
| 593 | + | |
| 594 | + | |
| 595 | + | |
| 596 | + | |
| 597 | + | |
| 598 | + | |
603 | 599 | | |
604 | 600 | | |
605 | 601 | | |
| |||
735 | 731 | | |
736 | 732 | | |
737 | 733 | | |
| 734 | + | |
738 | 735 | | |
739 | 736 | | |
740 | 737 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9 | 9 | | |
10 | 10 | | |
11 | 11 | | |
| 12 | + | |
| 13 | + | |
12 | 14 | | |
13 | 15 | | |
14 | 16 | | |
| |||
19 | 21 | | |
20 | 22 | | |
21 | 23 | | |
22 | | - | |
| 24 | + | |
23 | 25 | | |
24 | 26 | | |
25 | 27 | | |
| |||
107 | 109 | | |
108 | 110 | | |
109 | 111 | | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
110 | 122 | | |
111 | 123 | | |
112 | 124 | | |
| |||
0 commit comments