Skip to content

Commit ee72ecb

Browse files
committed
implement variational weight noise
Fix #240
1 parent b14fe8a commit ee72ecb

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

nn/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@
99
from .label_smoothing import *
1010
from .stochastic_depth import *
1111
from .targets import *
12+
from .variational_weight_noise import *
1213
from .weight_norm import weight_norm, remove_weight_norm

nn/utils/variational_weight_noise.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""
2+
Variational weight noise
3+
4+
https://github.com/rwth-i6/returnn_common/issues/240
5+
"""
6+
7+
from __future__ import annotations
8+
from typing import TypeVar
9+
from ... import nn
10+
11+
12+
T_module = TypeVar('T_module', bound=nn.Module)
13+
14+
15+
def variational_weight_noise(module: T_module, name: str, weight_noise_std: float) -> T_module:
16+
"""
17+
:param module: module
18+
:param name: name of the weight parameter
19+
:param weight_noise_std: standard deviation of the weight noise
20+
"""
21+
assert weight_noise_std > 0
22+
assert hasattr(module, name)
23+
weight = getattr(module, name)
24+
assert isinstance(weight, nn.Parameter)
25+
26+
setattr(module, f"{name}_raw", weight)
27+
28+
with nn.Cond(nn.train_flag()) as cond:
29+
weight_noise = nn.random_normal(weight.shape_ordered, weight.dtype, stddev=weight_noise_std)
30+
cond.true = weight + weight_noise
31+
cond.false = weight
32+
setattr(module, name, cond.result)
33+
return module

tests/test_nn_utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,19 @@ def test_weight_norm():
4040
y.mark_as_default_output()
4141
config_str = nn.get_returnn_config().get_complete_py_code_str(net)
4242
dummy_run_net_single_custom(config_str, eval_flag=True)
43+
44+
45+
def test_variational_weight_noise():
46+
nn.reset_default_root_name_ctx()
47+
time_dim = nn.SpatialDim("time")
48+
in_dim = nn.FeatureDim("in", 3)
49+
x = nn.Data("data", dim_tags=[nn.batch_dim, time_dim, in_dim])
50+
x = nn.get_extern_data(x)
51+
net = nn.Linear(in_dim, nn.FeatureDim("out", 5))
52+
assert isinstance(net.weight, nn.Parameter)
53+
nn.variational_weight_noise(net, "weight", 0.075)
54+
assert not isinstance(net.weight, nn.Parameter) and isinstance(net.weight, nn.Tensor)
55+
y = net(x)
56+
y.mark_as_default_output()
57+
config_str = nn.get_returnn_config().get_complete_py_code_str(net)
58+
dummy_run_net_single_custom(config_str, train_flag=True)

0 commit comments

Comments
 (0)