Skip to content

Commit 1bb638c

Browse files
committed
Fix device bug in IndependentNonlinearitiesLayer
1 parent e259e39 commit 1bb638c

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

neuralfields/custom_layers.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def __init__(
115115
in_features: Number of dimensions of each input sample.
116116
nonlin: The nonlinear function to apply.
117117
bias: If `True`, a learnable bias is subtracted, else no bias is used.
118-
weight: If `True`, the input is multiplied with a learnable scaling factor.
118+
weight: If `True`, the input is multiplied with a learnable scaling factor, else no weighting is used.
119119
"""
120120
if not callable(nonlin):
121121
if len(nonlin) != in_features:
@@ -131,11 +131,11 @@ def __init__(
131131
if weight:
132132
self.weight = nn.Parameter(torch.empty(in_features, dtype=torch.get_default_dtype()))
133133
else:
134-
self.weight = torch.ones(in_features, dtype=torch.get_default_dtype())
134+
self.register_buffer("weight", torch.ones(in_features, dtype=torch.get_default_dtype()))
135135
if bias:
136136
self.bias = nn.Parameter(torch.empty(in_features, dtype=torch.get_default_dtype()))
137137
else:
138-
self.bias = torch.zeros(in_features, dtype=torch.get_default_dtype())
138+
self.register_buffer("bias", torch.zeros(in_features, dtype=torch.get_default_dtype()))
139139

140140
init_param_(self)
141141

@@ -153,8 +153,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
153153
Returns:
154154
Output tensor.
155155
"""
156-
tmp = inp + self.bias
157-
tmp = self.weight * tmp
156+
tmp = self.weight * (inp + self.bias)
158157

159158
# Every dimension runs through an individual nonlinearity.
160159
if _is_iterable(self.nonlin):

0 commit comments

Comments
 (0)