Skip to content

Commit 01e2392

Browse files
authored
Two fixes to SampleK (#1086)
* Fix: set device for best_hyp_indices in SampleK. * Fix: Take top-k values. * Changelog.
1 parent d912554 commit 01e2392

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa
1111

1212
Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
1313

14+
## [3.1.33]
15+
16+
### Fixed
17+
- Two small fixes to SampleK. Previously, the device was not set correctly, leading to issues when running sampling on GPUs. Furthermore, SampleK did not return the top-k values correctly.
18+
1419
## [3.1.32]
1520

1621
### Added

sockeye/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
1313

14-
__version__ = '3.1.32'
14+
__version__ = '3.1.33'

sockeye/beam_search.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -475,9 +475,9 @@ def forward(self, scores, target_dists, finished):
475475
# n == 0 means sample from the full vocabulary. Otherwise, we sample from the top n.
476476
if self.n != 0:
477477
# select the top n in each row, via a mask
478-
_, indices = pt.topk(target_dists, k=self.n, dim=1, largest=True, sorted=True)
478+
values, indices = pt.topk(target_dists, k=self.n, dim=1, largest=True, sorted=True)
479479
# set items not chosen by topk to 0
480-
target_dists = pt.scatter(pt.zeros_like(target_dists), 1, indices, target_dists)
480+
target_dists = pt.scatter(pt.zeros_like(target_dists), 1, indices, values)
481481
# renormalize
482482
target_dists = target_dists / target_dists.sum(1, keepdim=True)
483483

@@ -489,7 +489,7 @@ def forward(self, scores, target_dists, finished):
489489
# (batch, 1)
490490
values = scores.gather(dim=1, index=best_word_indices.long().unsqueeze(1))
491491
# (batch,)
492-
best_hyp_indices = pt.arange(0, best_word_indices.size()[0])
492+
best_hyp_indices = pt.arange(0, best_word_indices.size()[0], device=best_word_indices.device)
493493

494494
return best_hyp_indices, best_word_indices, values
495495

0 commit comments

Comments
 (0)