This repository was archived by the owner on Aug 1, 2023. It is now read-only.

Commit e07df3c

James Reed authored and facebook-github-bot committed

Some fixes for running on CPU (#221)

Summary:
Pull Request resolved: #221

Running the benchmark with CUDA_VISIBLE_DEVICES='' brought up some issues in the code when running on CPU. These are the fixes.

Reviewed By: jhcross
Differential Revision: D9996497
fbshipit-source-id: e418f0d5f9bd98b97fb149d689f3571fb4b8f14d

1 parent eab1d26, commit e07df3c

File tree

1 file changed: +5 −2 lines changed


pytorch_translate/beam_decode.py

Lines changed: 5 additions & 2 deletions
@@ -99,6 +99,8 @@ def generate_batched_itr(
         for sample in data_itr:
             if cuda:
                 s = utils.move_to_cuda(sample)
+            else:
+                s = sample
             input = s["net_input"]
             srclen = input["src_tokens"].size(1)
             if self.use_char_source:
@@ -608,8 +610,9 @@ def gather_probs(all_translation_tokens, all_probs):
             # The corresponding model did not use vocab reduction if
             # possible_translation_tokens is None.
             mapped_probs = torch.zeros(
-                (probs.size(0), possible_translation_tokens.size(0))
-            ).cuda()
+                (probs.size(0), possible_translation_tokens.size(0)),
+                device=probs.device,
+            )

             mapped_probs[:, inv_ind] = probs
             if avg_probs is None:
