Skip to content

Commit 0ee15d8

Browse files
committed
implement weight dropout
Fix #100
1 parent 97b5f99 commit 0ee15d8

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

nn/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
from .stochastic_depth import *
1111
from .targets import *
1212
from .variational_weight_noise import *
13+
from .weight_dropout import *
1314
from .weight_norm import weight_norm, remove_weight_norm

nn/utils/weight_dropout.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""
2+
Weight dropout.
3+
4+
Also known as "variational dropout" or "Bayesian dropout",
5+
sometimes applied for LSTM weights,
6+
but this can also be applied to any other weights.
7+
8+
https://github.com/rwth-i6/returnn_common/issues/100
9+
"""
10+
11+
from __future__ import annotations
12+
from typing import Optional, Union, Sequence, TypeVar
13+
from ... import nn
14+
15+
16+
# Generic module type: lets weight_dropout declare that it returns the
# same module subtype it was given, not just a plain nn.Module.
T_module = TypeVar('T_module', bound=nn.Module)
17+
18+
19+
def weight_dropout(
    module: T_module, name: str, dropout: float,
    *,
    axis: Optional[Union[nn.Dim, Sequence[nn.Dim]]] = None,
) -> T_module:
  """
  Apply dropout to a weight parameter of the given module.

  The original parameter is moved to the attribute ``{name}_raw``,
  and the attribute ``name`` is replaced by the dropped-out tensor,
  so any later use of ``module.<name>`` sees the dropout applied.

  :param module: module owning the weight parameter
  :param name: attribute name of the weight parameter on ``module``
  :param dropout: dropout probability, in the range [0, 1)
  :param axis: axis (or axes) to apply dropout on; defaults to all axes of
    the weight. See :func:`nn.dropout`.
  :return: the same module instance, modified in place
  """
  assert hasattr(module, name)
  weight = getattr(module, name)
  assert isinstance(weight, nn.Parameter)
  assert 0.0 <= dropout < 1.0  # dropout == 1 would always zero out the whole weight
  # Check `axis is None`, not truthiness: an explicitly passed empty
  # sequence must not silently be reinterpreted as "all axes".
  if axis is None:
    axis = weight.shape_ordered

  # Refuse to wrap the same weight twice; `{name}_raw` keeps the parameter.
  assert not hasattr(module, f"{name}_raw")
  setattr(module, f"{name}_raw", weight)
  weight = nn.dropout(weight, dropout, axis=axis)
  setattr(module, name, weight)
  return module

tests/test_nn_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,19 @@ def test_variational_weight_noise():
5656
y.mark_as_default_output()
5757
config_str = nn.get_returnn_config().get_complete_py_code_str(net)
5858
dummy_run_net_single_custom(config_str, train_flag=True)
59+
60+
61+
def test_weight_dropout():
  """weight_dropout replaces the parameter attribute with a dropped-out tensor."""
  nn.reset_default_root_name_ctx()
  # Extern data with layout [batch, time, feature].
  feat_dim = nn.FeatureDim("in", 3)
  data = nn.Data("data", dim_tags=[nn.batch_dim, nn.SpatialDim("time"), feat_dim])
  inputs = nn.get_extern_data(data)
  linear = nn.Linear(feat_dim, nn.FeatureDim("out", 5))
  assert isinstance(linear.weight, nn.Parameter)
  nn.weight_dropout(linear, "weight", 0.3)
  # After wrapping, the attribute is a plain tensor, no longer a parameter.
  assert not isinstance(linear.weight, nn.Parameter)
  assert isinstance(linear.weight, nn.Tensor)
  out = linear(inputs)
  out.mark_as_default_output()
  config_str = nn.get_returnn_config().get_complete_py_code_str(linear)
  dummy_run_net_single_custom(config_str, train_flag=True)

0 commit comments

Comments
 (0)