forked from Element-Research/dpnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathOneHot.lua
42 lines (33 loc) · 1.16 KB
/
OneHot.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
-- nn.OneHot: turns a tensor of class indices into one-hot encodings by
-- appending a trailing dimension of length `outputSize` (see updateOutput).
-- Adapted from https://github.com/karpathy/char-rnn
-- and https://github.com/hughperkins/char-rnn-er
local OneHot, parent = torch.class('nn.OneHot', 'nn.Module')

--- Construct a OneHot module.
-- @param outputSize length of each one-hot vector, i.e. the size of the
--   trailing dimension added to the output (the number of classes)
function OneHot:__init(outputSize)
   parent.__init(self)
   local numClasses = outputSize
   self.outputSize = numClasses
end
--- Forward pass: one-hot encode a tensor of class indices.
-- The output has shape `input:size() x self.outputSize`. For every element of
-- `input`, a 1 is written along the new trailing dimension at the position
-- given by that element's value, and every other entry is 0.
-- @param input tensor of indices of any shape. Values are presumably 1-based,
--   as torch's `scatter` expects; an out-of-range value makes `scatter` raise.
--   The tensor does not need to be contiguous.
-- @return self.output
function OneHot:updateOutput(input)
   -- Lua 5.2+ moved the global `unpack` to `table.unpack`; LuaJIT/5.1 keep
   -- the global form.
   local unpack = table.unpack or unpack

   local size = input:size():totable()
   table.insert(size, self.outputSize)
   self.output:resize(unpack(size)):zero()

   -- The index tensor has the output's shape but with a singleton trailing
   -- dim, so each input element scatters exactly one 1 along that dim.
   size[#size] = 1

   local index
   local inputType = torch.type(input)
   if inputType == 'torch.CudaTensor' or inputType == 'torch.ClTensor' then
      -- GPU backends index with their own tensor type. view() requires a
      -- contiguous tensor, and contiguous() returns `input` itself when
      -- it already is one.
      index = input:contiguous():view(unpack(size))
   elseif inputType == 'torch.LongTensor' and input:isContiguous() then
      -- Already a valid CPU index: view it in place without copying.
      index = input:view(unpack(size))
   else
      -- CPU scatter needs a contiguous LongTensor index. copy() handles any
      -- source type and memory layout, so this also covers non-contiguous
      -- inputs that view() would reject.
      self._input = self._input or torch.LongTensor()
      self._input:resize(input:size()):copy(input)
      index = self._input:view(unpack(size))
   end

   self.output:scatter(self.output:dim(), index, 1)
   return self.output
end
--- Backward pass w.r.t. the input.
-- Always yields an all-zero gradient shaped like `input`, since the indices
-- are discrete and `gradOutput` is never read.
-- @param input the index tensor that was passed to updateOutput
-- @param gradOutput gradient w.r.t. the output (unused)
-- @return self.gradInput, resized to input's size and zero-filled
function OneHot:updateGradInput(input, gradOutput)
   local grad = self.gradInput
   grad:resize(input:size())
   grad:zero()
   return grad
end
--- Convert the module's tensors to another type.
-- The `_input` scratch buffer is discarded before delegating to the parent.
-- updateOutput recreates it as a torch.LongTensor when needed. Dropping it
-- presumably keeps the parent conversion from retyping it away from
-- LongTensor.
-- @param tensorType target tensor type name, e.g. 'torch.CudaTensor'
--   (renamed from `type` to avoid shadowing Lua's builtin)
-- @param typecache optional cache forwarded to the parent implementation
-- @return whatever the parent Module:type returns
function OneHot:type(tensorType, typecache)
   self._input = nil
   return parent.type(self, tensorType, typecache)
end