Skip to content

Commit 3b9a2b4

Browse files
committed
Add functional conversion of HardShrink
1 parent d2e3243 commit 3b9a2b4

File tree

6 files changed

+55
-44
lines changed

6 files changed

+55
-44
lines changed

HardShrink.lua

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,20 @@ function HardShrink:__init(lam)
66
end
77

88
-- Forward pass of HardShrink: f(x) = x if |x| > lambda, else 0.
-- Dispatches to the THNN C backend for the tensor's type; the kernel
-- resizes self.output to match input and fills it elementwise.
-- Returns self.output.
function HardShrink:updateOutput(input)
   input.THNN.HardShrink_updateOutput(
      input:cdata(),
      self.output:cdata(),
      self.lambda
   )
   return self.output
end
1216

1317
-- Backward pass of HardShrink: the gradient is passed through where
-- |input| > lambda and is zero inside the shrinkage band (zero slope there).
-- Dispatches to the THNN C backend, which resizes self.gradInput to match
-- input. Returns self.gradInput.
function HardShrink:updateGradInput(input, gradOutput)
   input.THNN.HardShrink_updateGradInput(
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata(),
      self.lambda
   )
   return self.gradInput
end

THNN.lua

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ TH_API void THNN_(DistKLDivCriterion_updateGradInput)(
5555
THTensor *target,
5656
THTensor *gradInput,
5757
bool sizeAverage);
58+
59+
TH_API void THNN_(HardShrink_updateOutput)(
60+
THNNState *state,
61+
THTensor *input,
62+
THTensor *output,
63+
real lambda);
64+
TH_API void THNN_(HardShrink_updateGradInput)(
65+
THNNState *state,
66+
THTensor *input,
67+
THTensor *gradOutput,
68+
THTensor *gradInput,
69+
real lambda);
5870
]]
5971

6072
-- THGenerator struct declaration copied from torch7/lib/TH/THRandom.h

init.c

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@
2929
#include "generic/Tanh.c"
3030
#include "THGenerateFloatTypes.h"
3131

32-
#include "generic/HardShrink.c"
33-
#include "THGenerateFloatTypes.h"
34-
3532
#include "generic/SoftShrink.c"
3633
#include "THGenerateFloatTypes.h"
3734

@@ -157,7 +154,6 @@ int luaopen_libnn(lua_State *L)
157154
nn_FloatSoftMax_init(L);
158155
nn_FloatSoftPlus_init(L);
159156
nn_FloatTanh_init(L);
160-
nn_FloatHardShrink_init(L);
161157
nn_FloatSoftShrink_init(L);
162158
nn_FloatThreshold_init(L);
163159
nn_FloatPReLU_init(L);
@@ -202,7 +198,6 @@ int luaopen_libnn(lua_State *L)
202198
nn_DoubleSoftMax_init(L);
203199
nn_DoubleSoftPlus_init(L);
204200
nn_DoubleTanh_init(L);
205-
nn_DoubleHardShrink_init(L);
206201
nn_DoubleSoftShrink_init(L);
207202
nn_DoubleThreshold_init(L);
208203
nn_DoublePReLU_init(L);

lib/THNN/generic/HardShrink.c

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,29 @@
22
#define TH_GENERIC_FILE "generic/HardShrink.c"
33
#else
44

5-
static int nn_(HardShrink_updateOutput)(lua_State *L)
5+
/*
 * HardShrink forward: output[i] = input[i] if |input[i]| > lambda, else 0.
 * `output` is resized to match `input`; `state` is the THNN backend handle
 * (unused by this CPU kernel but part of the uniform THNN signature).
 */
void THNN_(HardShrink_updateOutput)(THNNState *state, THTensor *input, THTensor *output, real lambda)
{
  THTensor_(resizeAs)(output, input);

  /* Elementwise: values outside [-lambda, lambda] pass through unchanged,
     values inside the band are shrunk to zero. */
  TH_TENSOR_APPLY2(real, output, real, input,
    if ((*input_data) > lambda)
      *output_data = *input_data;
    else if ((*input_data) < -lambda)
      *output_data = *input_data;
    else
      *output_data = 0;
  );
}
3618

37-
static const struct luaL_Reg nn_(HardShrink__) [] = {
38-
{"HardShrink_updateOutput", nn_(HardShrink_updateOutput)},
39-
{"HardShrink_updateGradInput", nn_(HardShrink_updateGradInput)},
40-
{NULL, NULL}
41-
};
42-
43-
static void nn_(HardShrink_init)(lua_State *L)
19+
/*
 * HardShrink backward: gradInput[i] = gradOutput[i] where |input[i]| > lambda,
 * else 0 (the shrinkage band has zero slope, so no gradient flows there).
 * `gradInput` is resized to match `input`; `state` is unused by this CPU
 * kernel but kept for the uniform THNN signature.
 */
void THNN_(HardShrink_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, real lambda)
{
  THTensor_(resizeAs)(gradInput, input);
  TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
    if ((*input_data) > lambda || (*input_data) < -lambda)
      *gradInput_data = (*gradOutput_data);
    else
      *gradInput_data = 0;
  );
}
4929

5030
#endif

lib/THNN/generic/THNN.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,16 @@ TH_API void THNN_(DistKLDivCriterion_updateGradInput)(
5555
THTensor *gradInput,
5656
bool sizeAverage);
5757

58+
TH_API void THNN_(HardShrink_updateOutput)(
59+
THNNState *state,
60+
THTensor *input,
61+
THTensor *output,
62+
real lambda);
63+
TH_API void THNN_(HardShrink_updateGradInput)(
64+
THNNState *state,
65+
THTensor *input,
66+
THTensor *gradOutput,
67+
THTensor *gradInput,
68+
real lambda);
69+
5870
#endif

lib/THNN/init.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,6 @@
1515

1616
#include "generic/DistKLDivCriterion.c"
1717
#include "THGenerateFloatTypes.h"
18+
19+
#include "generic/HardShrink.c"
20+
#include "THGenerateFloatTypes.h"

0 commit comments

Comments (0)