diff --git a/ReinforceCategorical.lua b/ReinforceCategorical.lua index 7f66e21..15abf2e 100644 --- a/ReinforceCategorical.lua +++ b/ReinforceCategorical.lua @@ -11,7 +11,7 @@ local ReinforceCategorical, parent = torch.class("nn.ReinforceCategorical", "nn. function ReinforceCategorical:updateOutput(input) self.output:resizeAs(input) - self._index = self._index or ((torch.type(input) == 'torch.CudaTensor') and torch.CudaTensor() or torch.LongTensor()) + self._index = self._index or ((torch.type(input) == 'torch.CudaTensor') and torch.CudaLongTensor() or torch.LongTensor()) if self.stochastic or self.train ~= false then -- sample from categorical with p = input self._input = self._input or input.new()