Skip to content

Commit 0d3283b

Browse files
Inception pool arg
1 parent 8c9edfc commit 0d3283b

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

Inception.lua

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,13 @@ function Inception:__init(config)
4949
self.kernelSize = config.kernelSize or {5,3}
5050
-- The stride (height=width) of the convolution.
5151
self.kernelStride = config.kernelStride or {1,1}
52+
-- The size (height=width) of the spatial max pooling used
53+
-- in the next-to-last column.
54+
self.poolSize = config.poolSize or 3
55+
-- The stride (height=width) of the spatial max pooling.
56+
self.poolStride = config.poolStride or 1
5257
-- The pooling layer.
53-
self.pool = config.pool or nn.SpatialMaxPooling(3, 3, 1, 1)
58+
self.pool = config.pool or nn.SpatialMaxPooling(self.poolSize, self.poolSize, self.poolStride, self.poolStride)
5459

5560
-- [[ Module Construction ]]--
5661
local depthConcat = nn.DepthConcat(2) -- concat on 'c' dimension

test/test.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ function dpnntest.ReinforceNormal()
591591
local stdev2 = torch.cmul(stdev,stdev)
592592
gradStdev:add(-1,stdev2)
593593
stdev2:cmul(stdev)
594-
gradStdev:cdiv(stdev2)
594+
gradStdev:cdiv(stdev2):mul(-1)
595595
mytester:assertTensorEq(gradInput[2], gradStdev, 0.000001, "ReinforceNormal backward table input - gradStdev err")
596596
end
597597

0 commit comments

Comments
 (0)