@@ -115,7 +115,7 @@ def __init__(
115
115
in_features: Number of dimensions of each input sample.
116
116
nonlin: The nonlinear function to apply.
117
117
bias: If `True`, a learnable bias is added to the input, else no bias is used.
118
- weight: If `True`, the input is multiplied with a learnable scaling factor.
118
+ weight: If `True`, the input is multiplied with a learnable scaling factor, else no weighting is used.
119
119
"""
120
120
if not callable (nonlin ):
121
121
if len (nonlin ) != in_features :
@@ -131,11 +131,11 @@ def __init__(
131
131
if weight :
132
132
self .weight = nn .Parameter (torch .empty (in_features , dtype = torch .get_default_dtype ()))
133
133
else :
134
- self .weight = torch .ones (in_features , dtype = torch .get_default_dtype ())
134
+ self .register_buffer ( " weight" , torch .ones (in_features , dtype = torch .get_default_dtype () ))
135
135
if bias :
136
136
self .bias = nn .Parameter (torch .empty (in_features , dtype = torch .get_default_dtype ()))
137
137
else :
138
- self .bias = torch .zeros (in_features , dtype = torch .get_default_dtype ())
138
+ self .register_buffer ( " bias" , torch .zeros (in_features , dtype = torch .get_default_dtype () ))
139
139
140
140
init_param_ (self )
141
141
@@ -153,8 +153,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
153
153
Returns:
154
154
Output tensor.
155
155
"""
156
- tmp = inp + self .bias
157
- tmp = self .weight * tmp
156
+ tmp = self .weight * (inp + self .bias )
158
157
159
158
# Every dimension runs through an individual nonlinearity.
160
159
if _is_iterable (self .nonlin ):
0 commit comments